# Conditional GAN (cGAN) — Bug-Fix Labs - 10 Bugs to Fix

In [None]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------
# Data  (MNIST 28×28 → 32×32)
# ---------------------------
# Per-image preprocessing, applied in order:
#   1. Resize to 32 px so the images match the 32×32 tensors the models use.
#   2. ToTensor converts the image to a float tensor (range [0, 1] per the
#      torchvision docs).
#   3. Normalize with mean 0.5 / std 0.5 computes (x - 0.5) / 0.5, which moves
#      the pixels to [-1, 1] -- the same range as the generator's Tanh output.
#      (An earlier version stopped after ToTensor and left the data in [0, 1].)
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST training split, downloaded into ./data if it is not already there.
ds = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 128; drop_last discards the final short batch so
# every batch seen by the training loop has the same size.
loader = DataLoader(
    ds,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
)

# ---------------------------
# Hyperparams
# ---------------------------
z_dim = 128       # length of the noise vector fed to the generator
emb_dim = 50      # width of each class-label embedding vector
num_classes = 10  # number of distinct class labels

# Generator and discriminator use the same Adam step size. Earlier values were
# lopsided (generator 2e-3, discriminator 2e-5): the generator step was far
# larger than the discriminator's, which destabilised training.
g_lr = 2e-4
d_lr = 2e-4

# ---------------------------
# Models
# ---------------------------
class CGAN_G(nn.Module):
    """Conditional generator: turns (noise, class label) into a 1×32×32 image.

    The label is looked up in an embedding table, concatenated with the noise
    vector, projected to a ``(ch*8, 4, 4)`` feature map and then upsampled
    4 → 8 → 16 → 32 by three stride-2 transposed convolutions. A final
    stride-1 transposed convolution reduces to one channel and ``Tanh``
    bounds the output to [-1, 1].

    Args:
        z: Length of the input noise vector.
        emb_dim: Width of the label embedding.
        num_classes: Number of distinct labels (rows in the embedding table).
        ch: Base channel count; the first feature map has ``ch * 8`` channels.
    """

    def __init__(self, z=128, emb_dim=50, num_classes=10, ch=64):
        super().__init__()

        self.z = z
        self.emb_dim = emb_dim
        self.num_classes = num_classes
        self.ch = ch

        # Label -> dense vector, then (noise ++ label vector) -> flat 4×4 map.
        self.emb = nn.Embedding(num_classes, emb_dim)
        self.proj = nn.Linear(z + emb_dim, ch * 8 * 4 * 4)

        # Channel widths of the successive feature maps: ch*8 → ch*4 → ch*2 → ch.
        widths = (ch * 8, ch * 4, ch * 2, ch)

        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            # BN → ReLU → ConvTranspose(kernel 4, stride 2, pad 1) doubles H and W.
            stages.extend((
                nn.BatchNorm2d(c_in),
                nn.ReLU(True),
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1),
            ))

        # Output head: keep the 32×32 resolution (kernel 3, stride 1, pad 1),
        # collapse to a single channel and squash to [-1, 1].
        stages.extend((
            nn.BatchNorm2d(ch),
            nn.ReLU(True),
            nn.ConvTranspose2d(ch, 1, 3, 1, 1),
            nn.Tanh(),
        ))

        self.conv_blocks = nn.Sequential(*stages)

    def forward(self, z, labels):
        """Generate one image per (noise row, label) pair.

        Args:
            z: Noise tensor of shape ``(batch, z)``.
            labels: Integer class indices of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, 1, 32, 32)`` with values in [-1, 1].
        """
        label_vec = self.emb(labels)
        # Join along the feature axis so each sample carries its own label info.
        joint = torch.cat((z, label_vec), dim=1)
        # Flat projection -> stack of 4×4 feature maps, the seed of the image.
        feat = self.proj(joint).view(-1, self.ch * 8, 4, 4)
        return self.conv_blocks(feat)

class CGAN_D(nn.Module):
    """Conditional discriminator: scores a 1×32×32 image given its class label.

    The label embedding is projected to ``32*32`` values and reshaped into an
    extra image plane, which is stacked with the input image to form a
    2-channel tensor. Four stride-2 convolutions shrink it 32 → 16 → 8 → 4 → 2
    before a linear layer produces a single raw logit (no sigmoid is applied
    inside the module).

    Args:
        emb_dim: Width of the label embedding.
        num_classes: Number of distinct labels (rows in the embedding table).
        ch: Base channel count; the last conv layer has ``ch * 8`` channels.
    """

    def __init__(self, emb_dim=50, num_classes=10, ch=64):
        super().__init__()

        self.emb_dim = emb_dim
        self.num_classes = num_classes
        self.ch = ch

        # Label -> dense vector -> one value per pixel of a 32×32 plane, so the
        # label can be concatenated with the image as a second channel.
        self.emb = nn.Embedding(num_classes, emb_dim)
        self.proj = nn.Linear(emb_dim, 32 * 32)

        # First stage takes 2 input channels (image + label plane) and has no
        # batch norm.
        stages = [
            nn.Conv2d(2, ch, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Three further stride-2 stages, each doubling the channel count.
        width = ch
        for _ in range(3):
            stages.extend((
                nn.Conv2d(width, width * 2, 3, 2, 1),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ))
            width *= 2

        # After four stride-2 convolutions a 32×32 input is 2×2 with ch*8
        # channels, hence ch*8*2*2 input features for the classifier.
        stages.extend((
            nn.Flatten(),
            nn.Linear(width * 2 * 2, 1),
        ))

        self.conv_blocks = nn.Sequential(*stages)

    def forward(self, x, labels):
        """Return one real/fake logit per image.

        Args:
            x: Image tensor of shape ``(batch, 1, 32, 32)``.
            labels: Integer class indices of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding unnormalised logits.
        """
        label_plane = self.proj(self.emb(labels)).view(-1, 1, 32, 32)
        # Stack image and label plane along the channel axis -> (batch, 2, 32, 32).
        stacked = torch.cat((x, label_plane), dim=1)
        return self.conv_blocks(stacked)

# ---------------------------
# Training
# ---------------------------
# Build both networks from the shared hyperparameters and move them to the
# selected device. The generator is constructed first.
g = CGAN_G(z=z_dim, emb_dim=emb_dim, num_classes=num_classes).to(device)
d = CGAN_D(emb_dim=emb_dim, num_classes=num_classes).to(device)

# One Adam optimizer per network, each with its own learning rate and
# betas=(0.5, 0.999).
g_opt = torch.optim.Adam(
    params=g.parameters(),
    lr=g_lr,
    betas=(0.5, 0.999),
)
d_opt = torch.optim.Adam(
    params=d.parameters(),
    lr=d_lr,
    betas=(0.5, 0.999),
)

# The discriminator outputs raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()

# Per-step loss history; must exist before the loop starts appending to it.
g_losses, d_losses = [], []

# Number of full passes over the dataset. The progress line below reads this
# value too, so the reported total always matches the loop (it previously
# printed a hard-coded "/100" while only 5 epochs were run).
num_epochs = 5

for epoch in range(num_epochs):
    for i, (imgs, labels) in enumerate(loader):
        imgs = imgs.to(device)
        labels = labels.to(device)

        # ----- Discriminator step -----
        d_opt.zero_grad()

        # Real images paired with their true labels should be scored as 1.
        real_preds = d(imgs, labels)
        real_loss = criterion(real_preds, torch.ones_like(real_preds))

        # Sample noise directly on the target device (avoids a CPU allocation
        # followed by a host-to-device copy) and generate fakes conditioned on
        # the same labels as the real batch.
        z = torch.randn(imgs.size(0), z_dim, device=device)
        fake_imgs = g(z, labels)

        # detach() cuts the graph so this backward pass stops at the images
        # and does not compute gradients for the generator.
        fake_preds = d(fake_imgs.detach(), labels)
        fake_loss = criterion(fake_preds, torch.zeros_like(fake_preds))

        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_opt.step()

        # ----- Generator step -----
        g_opt.zero_grad()

        # Re-score the (non-detached) fakes with the just-updated
        # discriminator; the generator is rewarded when they are scored as 1.
        # This backward pass also leaves gradients on the discriminator's
        # parameters, which d_opt.zero_grad() clears at the next iteration.
        fake_preds = d(fake_imgs, labels)
        g_loss = criterion(fake_preds, torch.ones_like(fake_preds))

        g_loss.backward()
        g_opt.step()

        # Record scalar losses; .item() is read once per loss and reused for
        # the log line below.
        g_loss_value = g_loss.item()
        d_loss_value = d_loss.item()
        g_losses.append(g_loss_value)
        d_losses.append(d_loss_value)

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(loader)}], D_loss: {d_loss_value:.4f}, G_loss: {g_loss_value:.4f}')