# A Minimal GAN on MNIST — From Scratch in PyTorch

This notebook trains a **basic GAN** on MNIST with the discriminator acting as a **binary classifier**:
- **Generator** \(G_\theta(z)\) maps noise \(z\sim\mathcal N(0,I)\) to images.
- **Discriminator** \(D_\psi(x)\in(0,1)\) predicts the probability that an image is **real**. In code the network outputs a raw logit; the sigmoid is applied inside `BCEWithLogitsLoss`.
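- **Losses (non‑saturating generator loss)**  
  \(\displaystyle \mathcal L_D = \mathrm{BCE}(D(x_{\text{real}}),1)+\mathrm{BCE}(D(G(z)),0)\)  
  \(\displaystyle \mathcal L_G = \mathrm{BCE}(D(G(z)),1)\)  
  In the training loop the real target in \(\mathcal L_D\) is smoothed from 1 to 0.9 (one‑sided label smoothing, `label_smooth_real`).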

## 1) Setup

In [None]:
import os, math, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, datasets, utils as vutils
import matplotlib.pyplot as plt

# Seed the torch and NumPy RNGs for repeatable runs, then pick the compute device.
torch.manual_seed(0)
np.random.seed(0)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)

## 2) Hyperparameters

In [None]:
# --- Model sizes ------------------------------------------------------------
img_size = 28                      # side length of a (square) image
img_dim = img_size * img_size * 1  # flattened length of one single-channel image
z_dim = 100                        # latent noise dimension
g_widths = [256, 512, 1024]        # generator hidden-layer widths
d_widths = [512, 256]              # discriminator hidden-layer widths

# --- Optimisation -----------------------------------------------------------
batch_size = 128
epochs = 30
lr = 2e-4
betas = (0.5, 0.999)               # Adam (beta1, beta2)
label_smooth_real = 0.9            # target used for real images in the D loss

# --- Sampling / output ------------------------------------------------------
sample_every = 1                   # plot and save samples every N epochs
fixed_z = torch.randn(64, z_dim, device=device)  # same noise every epoch -> comparable grids
out_dir = "gan_mnist_outputs"
os.makedirs(out_dir, exist_ok=True)

## 3) Data loading (MNIST)

In [None]:
# Map MNIST pixels to [-1, 1]: ToTensor gives [0, 1], Normalize(0.5, 0.5) shifts
# and scales that to [-1, 1], matching the generator's Tanh output range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_data = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
# Pinned host memory only speeds up host-to-GPU copies (and enables the
# non_blocking=True transfer in the training loop); without CUDA PyTorch just
# warns and ignores it, so enable it only when training on a GPU.
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=(device.type == "cuda"),
)

# Preview one batch of real images (make_grid rescales [-1, 1] back to [0, 1]).
real_batch = next(iter(train_loader))[0][:64]
grid = vutils.make_grid(real_batch, nrow=8, normalize=True, value_range=(-1,1))
plt.figure(figsize=(6,6)); plt.axis("off"); plt.title("Real MNIST samples (normalized)")
plt.imshow(np.transpose(grid.cpu().numpy(), (1,2,0)))
plt.show()

## 4) Models — simple MLP GAN

In [None]:
class Generator(nn.Module):
    """MLP generator mapping latent noise ``z`` to images with values in [-1, 1].

    Each hidden layer is Linear -> BatchNorm1d -> ReLU; the output layer is
    Linear -> Tanh, so outputs lie in [-1, 1].

    Args:
        z_dim: length of each latent vector.
        widths: hidden-layer widths, applied in order.
        out_dim: number of output values per sample; must equal the number of
            elements in ``img_shape``.
        img_shape: ``(C, H, W)`` shape each flat output is reshaped to.
            Defaults to a single-channel 28x28 image, which matches the
            original hard-coded behaviour.

    Raises:
        ValueError: if ``out_dim`` does not match the size of ``img_shape``.
    """

    def __init__(self, z_dim=100, widths=(256, 512, 1024), out_dim=784,
                 img_shape=(1, 28, 28)):
        super().__init__()
        img_shape = tuple(img_shape)
        n_pixels = math.prod(img_shape)
        if n_pixels != out_dim:
            raise ValueError(
                f"out_dim={out_dim} does not match img_shape={img_shape} "
                f"({n_pixels} elements)"
            )
        # Plain attribute (not a buffer), so state_dict contents are unchanged.
        self.img_shape = img_shape
        layers = []
        in_f = z_dim
        for w in widths:
            layers += [nn.Linear(in_f, w), nn.BatchNorm1d(w), nn.ReLU(True)]
            in_f = w
        layers += [nn.Linear(in_f, out_dim), nn.Tanh()]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Map ``z`` of shape (B, z_dim) to images of shape (B, *img_shape)."""
        # Reshape with the explicit batch size: view(-1, ...) could silently
        # change the number of images if out_dim and the image shape disagreed.
        return self.net(z).view(z.size(0), *self.img_shape)

class Discriminator(nn.Module):
    """MLP discriminator returning one real-vs-fake *logit* per image.

    There is no final Sigmoid: the output is an unnormalised score, meant to be
    paired with ``nn.BCEWithLogitsLoss`` (or ``torch.sigmoid`` for probabilities).

    Args:
        in_dim: number of elements per image after flattening.
        widths: hidden-layer widths, each followed by LeakyReLU(0.2).
    """

    def __init__(self, in_dim=784, widths=(512, 256)):
        super().__init__()
        layers = []
        in_f = in_dim
        for w in widths:
            layers += [nn.Linear(in_f, w), nn.LeakyReLU(0.2, inplace=True)]
            in_f = w
        layers += [nn.Linear(in_f, 1)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` to (B, in_dim) and return logits of shape (B,)."""
        # flatten() also accepts non-contiguous inputs (e.g. transposed tensors),
        # where view() would raise.
        x = x.flatten(1)
        return self.net(x).squeeze(-1)

# Build both networks and move them to the selected device.
G = Generator(z_dim, g_widths, img_dim).to(device)
D = Discriminator(img_dim, d_widths).to(device)

print(G)
print(D)
# Report the number of trainable parameters of each network.
for _tag, _net in (("G", G), ("D", D)):
    print(f"Params {_tag}:", sum(p.numel() for p in _net.parameters() if p.requires_grad))

## 5) Losses & optimizers (non‑saturating GAN)

In [None]:
# BCE computed directly on logits (the sigmoid is fused into the loss).
bce = nn.BCEWithLogitsLoss()
# One Adam optimizer per network, sharing the same lr / betas.
optG = optim.Adam(G.parameters(), lr=lr, betas=betas)
optD = optim.Adam(D.parameters(), lr=lr, betas=betas)

def d_loss(real_logits, fake_logits, smooth=0.0):
    """Discriminator loss: real logits vs. target ``1 - smooth``, fake logits vs. 0.

    Args:
        real_logits: raw discriminator outputs on real images.
        fake_logits: raw discriminator outputs on generated images.
        smooth: one-sided label-smoothing amount subtracted from the real target.

    Returns:
        The sum of the two ``bce`` terms (a scalar tensor).
    """
    real_term = bce(real_logits, torch.full_like(real_logits, 1.0 - smooth))
    fake_term = bce(fake_logits, torch.zeros_like(fake_logits))
    return real_term + fake_term

def g_loss(fake_logits):
    """Non-saturating generator loss: ``bce`` of D's logits on fakes against target 1."""
    return bce(fake_logits, torch.ones_like(fake_logits))

## 6) Training loop

In [None]:
log = {"D": [], "G": []}  # per-step losses, plotted in the loss-curve cell
step = 0                  # total number of training iterations (one per batch)

for ep in range(1, epochs+1):
    t0 = time.time()
    # Both nets train in train mode (BatchNorm uses batch statistics); the
    # sampling code at the end of each epoch temporarily switches G to eval.
    G.train(); D.train()
    for real, _ in train_loader:
        real = real.to(device, non_blocking=True)
        B = real.size(0)

        # Update D: real logits -> smoothed real target, fake logits -> 0.
        # The fakes are detached so this step does not backpropagate into G.
        z = torch.randn(B, z_dim, device=device)
        fake = G(z).detach()
        logits_real = D(real)
        logits_fake = D(fake)
        loss_D = d_loss(logits_real, logits_fake, smooth=1.0 - label_smooth_real)
        optD.zero_grad(); loss_D.backward(); optD.step()

        # Update G (non-saturating loss) on fresh noise: fake logits -> 1.
        # loss_G.backward() also leaves gradients on D's parameters; they are
        # cleared by optD.zero_grad() before the next D backward pass.
        z = torch.randn(B, z_dim, device=device)
        fake = G(z)
        logits_fake = D(fake)
        loss_G = g_loss(logits_fake)
        optG.zero_grad(); loss_G.backward(); optG.step()

        log["D"].append(loss_D.item()); log["G"].append(loss_G.item())
        step += 1

    # Sampling. eval() makes BatchNorm use its running statistics and, unlike
    # torch.no_grad() alone, stops it from updating them, so visualising
    # fixed_z does not alter the generator's buffers.
    if ep % sample_every == 0:
        G.eval()
        with torch.no_grad():
            fake = G(fixed_z).cpu()
        G.train()
        grid = vutils.make_grid(fake, nrow=8, normalize=True, value_range=(-1,1))
        plt.figure(figsize=(6,6)); plt.axis("off")
        plt.title(f"Generated samples @ epoch {ep}")
        plt.imshow(np.transpose(grid.numpy(), (1,2,0))); plt.show()
        vutils.save_image(fake, os.path.join(out_dir, f"samples_epoch_{ep:03d}.png"),
                          nrow=8, normalize=True, value_range=(-1,1))
    # Report the mean of this epoch's per-step losses (the last n_batches entries).
    n_batches = len(train_loader)
    print(f"[{ep:02d}/{epochs}] D={np.mean(log['D'][-n_batches:]):.3f} "
          f"G={np.mean(log['G'][-n_batches:]):.3f}  ({time.time()-t0:.1f}s)")

## 7) Loss curves

In [None]:
# Plot the per-step discriminator and generator losses recorded during training.
plt.figure(figsize=(6, 3))
for _key in ("D", "G"):
    plt.plot(log[_key], label=f"{_key} loss")
plt.xlabel("training step")
plt.ylabel("loss")
plt.legend()
plt.tight_layout()
plt.show()

## 8) Final sampling

In [None]:
# Draw 64 fresh samples. eval() makes BatchNorm use its running statistics and
# leaves them untouched (no_grad alone would not); G's previous mode is restored.
was_training = G.training
G.eval()
with torch.no_grad():
    z = torch.randn(64, z_dim, device=device)
    fake = G(z).cpu()
G.train(was_training)
grid = vutils.make_grid(fake, nrow=8, normalize=True, value_range=(-1,1))
plt.figure(figsize=(6,6)); plt.axis("off")
plt.title("Final generated samples")
plt.imshow(np.transpose(grid.numpy(), (1,2,0))); plt.show()

## 9) (Optional) Minimax generator loss

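To try the original **minimax** generator loss, replace `g_loss` with the function below. This loss is more prone to saturation: early in training, when \(D\) confidently rejects fakes, \(\log(1-D(G(z)))\) is nearly flat and gives \(G\) little gradient.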

```python
def g_loss_minimax(fake_logits):
    """Original minimax generator loss, E[log(1 - D(G(z)))], to be *minimized* by G.

    With D = sigmoid(logit), log(1 - sigmoid(l)) == logsigmoid(-l), which stays
    numerically stable for large |l|. Note the sign: BCE(D(G(z)), 0) equals
    -log(1 - D(G(z))), and minimizing that would push D(G(z)) towards 0 --
    the opposite of what the generator wants.
    """
    return F.logsigmoid(-fake_logits).mean()
```