# Lab 2: Debug a Broken Vanilla GAN (find 12+ issues)

In [None]:
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms

# Train on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = (
    torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
)

# Preprocessing pipeline, applied in this order:
#   1. ToTensor   - converts the image to a float tensor
#   2. Normalize  - (x - 0.5) / 0.5, matching a Tanh generator output range
#   3. Resize     - 32x32, the input size the networks below are built for
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
        transforms.Resize((32, 32)),
    ]
)

# MNIST training split, downloaded to ./data on first use.
loader = DataLoader(
    torchvision.datasets.MNIST(
        root='./data',
        train=True,
        download=True,
        transform=transform,
    ),
    batch_size=256,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# Size of the latent noise vector fed to the generator.
z_dim = 100

# Learning rates for the generator and discriminator optimizers.
g_lr = 2e-4
d_lr = 2e-4

class D(nn.Module):
    """Convolutional discriminator for 1-channel 32x32 images.

    Three stride-2 convolutions halve the spatial size each time
    (32 -> 16 -> 8 -> 4); a final 4x4 convolution collapses the 4x4 map to a
    single value per image. The output is a raw logit (no sigmoid), intended
    to be paired with ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self):
        super().__init__()
        stages = [
            # 32x32 -> 16x16 (no batch norm on the first layer)
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 16x16 -> 8x8
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 8x8 -> 4x4
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 4x4 -> 1x1, one logit per image
            nn.Conv2d(128, 1, kernel_size=4, stride=4, padding=0),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of real/fake logits for ``x``."""
        batch_size = x.size(0)
        logits = self.net(x)
        return logits.view(batch_size, 1)

class G(nn.Module):
    """Transposed-convolution generator producing 1-channel 32x32 images.

    The latent vector is treated as a ``latent_dim``-channel 1x1 feature map
    and upsampled 1 -> 4 -> 8 -> 16 -> 32. The final ``Tanh`` keeps outputs
    in [-1, 1].

    Args:
        latent_dim: Number of latent dimensions per sample. Defaults to 100.
            The same value sizes the first layer and the reshape in
            ``forward``, so the two cannot drift apart.
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        self.latent_dim = latent_dim
        self.net = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 128, 4, 1, 0),  # latent 1x1 to 4x4
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # 4x4 to 8x8
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # 8x8 to 16x16
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 1, 4, 2, 1),  # 16x16 to 32x32
            nn.Tanh(),  # output in [-1, 1]
        )

    def forward(self, z):
        """Generate images from latent codes.

        Args:
            z: Latent batch holding ``latent_dim`` values per sample, e.g.
                shape ``(batch, latent_dim)`` or ``(batch, latent_dim, 1, 1)``.

        Returns:
            Tensor of shape ``(batch, 1, 32, 32)``.
        """
        # Use the width this instance was built with rather than a
        # module-level global, so the reshape always matches the first layer.
        z = z.view(z.size(0), self.latent_dim, 1, 1)
        return self.net(z)

# Instantiate both networks on the selected device.
Dnet = D().to(device)
Gnet = G().to(device)

# The discriminator returns raw logits, so use the logits form of BCE.
crit = nn.BCEWithLogitsLoss()

# beta1=0.5 instead of Adam's default 0.9: the DCGAN recipe lowers the
# momentum term because the default tends to make adversarial training
# oscillate.
opt_d = torch.optim.Adam(Dnet.parameters(), lr=d_lr, betas=(0.5, 0.999))
opt_g = torch.optim.Adam(Gnet.parameters(), lr=g_lr, betas=(0.5, 0.999))

In [None]:
from tqdm import tqdm
# Number of passes over the training set; used by both the loop and the log.
num_epochs = 10

for epoch in range(num_epochs):
    for real, _ in tqdm(loader):
        real = real.to(device)
        b = real.size(0)

        # ---- Discriminator step: real -> 1, fake -> 0 ----
        opt_d.zero_grad()
        z = torch.randn(b, z_dim, device=device)
        # The generator is not updated in this step, so build the fake batch
        # without an autograd graph; otherwise loss_d.backward() would also
        # backpropagate through Gnet for no benefit.
        with torch.no_grad():
            fake = Gnet(z)

        loss_d = (
            crit(Dnet(fake), torch.zeros(b, 1, device=device))
            + crit(Dnet(real), torch.ones(b, 1, device=device))
        )
        loss_d.backward()
        opt_d.step()

        # ---- Generator step: make the discriminator label fakes as real ----
        opt_g.zero_grad()
        z = torch.randn(b, z_dim, device=device)
        fake = Gnet(z)
        # Non-saturating generator loss.
        loss_g = crit(Dnet(fake), torch.ones(b, 1, device=device))
        loss_g.backward()
        opt_g.step()

    # Reports the losses of the last batch of the epoch.
    print(f'Epoch {epoch+1}/{num_epochs}, Loss D: {loss_d.item()}, Loss G: {loss_g.item()}')

print('Now fix all the issues.')

In [None]:

import matplotlib.pyplot as plt
import torchvision

# Function to plot a batch of images
def show_images(real_images, fake_images, n=8):
    """Plot the first ``n`` real images above the first ``n`` fake images.

    Args:
        real_images: Tensor of shape (batch_size, 1, H, W).
        fake_images: Tensor of shape (batch_size, 1, H, W).
        n: Number of images per row. Clamped to the smaller batch so a short
            batch does not raise ``IndexError``.

    Raises:
        ValueError: If there is not at least one image to show per row.

    Both batches are assumed to be normalized with
    ``Normalize((0.5,), (0.5,))`` and are mapped back to [0, 1] for display.
    """
    count = min(n, len(real_images), len(fake_images))
    if count < 1:
        raise ValueError("show_images needs at least one real and one fake image")

    rows = (("Real", real_images), ("Fake", fake_images))
    # squeeze=False keeps `axes` two-dimensional even when count == 1.
    fig, axes = plt.subplots(2, count, figsize=(count * 2, 4), squeeze=False)

    for row_axes, (label, batch) in zip(axes, rows):
        # Detach and move to CPU first so the un-normalization below does not
        # extend an autograd graph, then undo Normalize((0.5,), (0.5,)).
        batch = batch[:count].detach().cpu() * 0.5 + 0.5

        for ax, image in zip(row_axes, batch):
            # Pin the color range to [0, 1]; otherwise imshow rescales every
            # image to its own min/max and hides washed-out samples.
            ax.imshow(image.squeeze(), cmap='gray', vmin=0.0, vmax=1.0)
            # Hide ticks and frame individually instead of axis('off'),
            # which would also hide the row label set below.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

        row_axes[0].set_ylabel(label, fontsize=12)

    fig.tight_layout()
    plt.show()


# -----------------------------
# Example usage after one batch
# -----------------------------
# Grab one batch of real images to compare against generated ones.
real, _ = next(iter(loader))
real = real.to(device)
b = real.size(0)

z = torch.randn(b, z_dim, device=device)
# Sampling for display only, so skip autograd bookkeeping. G.forward already
# reshapes the latent batch, so z is passed as-is.
with torch.no_grad():
    fake = Gnet(z)

show_images(real, fake, n=8)  # shows first 8 real and fake images