# WGAN-GP — Bug-Fix Lab: 10 Bugs to Fix

In [None]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------
# Data  (CIFAR-10 32×32)
# ---------------------------
# Normalize(0.5, 0.5) maps ToTensor's [0, 1] pixels to [-1, 1], the same range as
# the generator's Tanh output, so real and fake samples are directly comparable.
transform = transforms.Compose([
    transforms.Resize(32),  # scales the shorter side to 32; effectively a no-op for 32x32 CIFAR-10
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
ds = torchvision.datasets.CIFAR10('./data', train=True, download=True, transform=transform)
loader = DataLoader(
    ds,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    # Pinned host memory only speeds up host->GPU copies; on a CPU-only machine it
    # does nothing useful and PyTorch warns about it.
    pin_memory=device.type == 'cuda',
    drop_last=True,  # every batch has exactly batch_size samples
)

# ---------------------------
# Hyperparams
# ---------------------------
z_dim = 128          # length of the latent vector fed to the generator
g_lr = d_lr = 2e-4   # one shared Adam learning rate for generator and critic
lambda_gp = 10.0     # weight of the gradient-penalty term in the critic loss
n_critic = 5         # critic updates performed per generator update

# ---------------------------
# Models
# ---------------------------
class Critic(nn.Module):
    """WGAN critic: maps a (batch, 3, 32, 32) image batch to one raw score per sample.

    The output is deliberately unbounded (no sigmoid). A Wasserstein critic
    produces real-valued scores, not probabilities. InstanceNorm is used instead
    of BatchNorm so that each sample's score, and hence its gradient penalty,
    does not depend on the other samples in the batch.
    """

    def __init__(self, ch=64):
        super().__init__()
        # 32x32 -> 16x16 (no normalization on the first layer).
        stack = [nn.Conv2d(3, ch, 4, 2, 1), nn.LeakyReLU(0.2, True)]
        # 16x16 -> 8x8 -> 4x4, doubling the channel count each time.
        for c_in, c_out in ((ch, ch * 2), (ch * 2, ch * 4)):
            stack += [
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.InstanceNorm2d(c_out, affine=True),
                nn.LeakyReLU(0.2, True),
            ]
        # 4x4 -> 1x1 single score map.
        stack.append(nn.Conv2d(ch * 4, 1, 4, 1, 0))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        scores = self.net(x)  # (batch, 1, 1, 1) for 32x32 inputs
        return scores.view(x.size(0))

class Gen(nn.Module):
    """DCGAN-style generator: latent batch (batch, z) -> images (batch, 3, 32, 32) in [-1, 1].

    Inputs shaped (batch, z, 1, 1) also work, because forward reshapes to that form.
    """

    def __init__(self, z=128, ch=64):
        super().__init__()
        widths = (z, ch * 4, ch * 2, ch)
        layers = []
        for i, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # The first layer projects the 1x1 latent to 4x4; later ones double H and W.
            stride, padding = (1, 0) if i == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # 16x16 -> 32x32 RGB, squashed to [-1, 1] to match the normalized real data.
        layers += [nn.ConvTranspose2d(ch, 3, 4, 2, 1, bias=False), nn.Tanh()]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        batch, latent = z.size(0), z.size(1)
        return self.net(z.view(batch, latent, 1, 1))

# Critic is constructed before the generator (left-to-right evaluation).
D, G = Critic().to(device), Gen(z_dim).to(device)

# ---------------------------
# Optimizers (no BCE: the WGAN losses are plain means of raw critic scores)
# ---------------------------
# Adam with beta1=0, beta2=0.9, as in the WGAN-GP paper's Algorithm 1.
optG, optD = [
    torch.optim.Adam(net.parameters(), lr=lr, betas=(0.0, 0.9))
    for net, lr in ((G, g_lr), (D, d_lr))
]

# ---------------------------
# Gradient penalty (WGAN-GP)
# ---------------------------
def gradient_penalty(Dnet, real, fake, weight=None):
    """Return the two-sided WGAN-GP gradient penalty for critic ``Dnet``.

    Draws ``x_hat`` uniformly on the segment between each paired real/fake sample
    and computes ``weight * mean((||grad_x Dnet(x_hat)||_2 - 1)^2)``.

    Args:
        Dnet: critic returning one score per sample (dim 0 is the batch).
        real: batch of real samples.
        fake: batch of generated samples, same shape as ``real``.
        weight: penalty coefficient; ``None`` (default) uses the global ``lambda_gp``.

    Returns:
        Scalar tensor. Because the input gradient is built with ``create_graph=True``,
        backpropagating this value updates ``Dnet``'s parameters.
    """
    if weight is None:
        weight = lambda_gp
    b = real.size(0)
    # One eps ~ U[0, 1) per sample, broadcast over all remaining dimensions.
    eps_shape = (b,) + (1,) * (real.dim() - 1)
    eps = torch.rand(eps_shape, device=real.device, dtype=real.dtype)
    # Detach the endpoints: the penalty constrains the critic only and must never
    # backpropagate into the generator, even if ``fake`` still carries G's graph.
    x_hat = (eps * real.detach() + (1 - eps) * fake.detach()).requires_grad_(True)
    d_hat = Dnet(x_hat)
    # Differentiating the summed scores yields each sample's own input gradient,
    # assuming Dnet treats samples independently (true for the InstanceNorm Critic,
    # not for BatchNorm). create_graph=True makes the gradient itself differentiable
    # (and implies retain_graph=True).
    (grads,) = torch.autograd.grad(outputs=d_hat.sum(), inputs=x_hat, create_graph=True)
    grad_norm = grads.reshape(b, -1).norm(2, dim=1)  # per-sample L2 norm
    return weight * ((grad_norm - 1.0) ** 2).mean()

# ---------------------------
# Training loop (WGAN-GP)
# ---------------------------
def _real_batches(dl):
    """Yield real image batches forever, starting a new (reshuffled) pass when one ends."""
    while True:
        for images, _ in dl:
            yield images


real_batches = _real_batches(loader)
n_gen_steps = 12  # keep the demo short (same number of generator updates as before)

for step in range(n_gen_steps):
    # -- Critic updates: each iteration samples a fresh real batch and fresh fakes --
    for _ in range(n_critic):
        real = next(real_batches).to(device, non_blocking=True)
        b = real.size(0)
        with torch.no_grad():  # the critic step must not build a graph through G
            fake = G(torch.randn(b, z_dim, device=device))

        # Critic minimises E[D(fake)] - E[D(real)] + GP on raw scores (no BCE/sigmoid).
        lossD = D(fake).mean() - D(real).mean() + gradient_penalty(D, real, fake)

        optD.zero_grad(set_to_none=True)  # also discards grads left in D by the G step
        lossD.backward()
        optD.step()
        # No weight clipping: in WGAN-GP the gradient penalty replaces it.

    # -- Generator update: maximise the critic's score on fresh fakes --
    z = torch.randn(b, z_dim, device=device)
    lossG = -D(G(z)).mean()
    optG.zero_grad(set_to_none=True)  # otherwise G grads accumulate across steps
    lossG.backward()
    optG.step()

    print(f"step {step:2d}  lossD {lossD.item():+.4f}  lossG {lossG.item():+.4f}")

print("Your task: fix all bugs until the WGAN-GP training runs stably.")