<a href="https://colab.research.google.com/github/bhushan1729/Machine-Learning-Algorithm/blob/main/WGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import os

In [2]:
# Setup
# Seed PyTorch's global RNG so weight initialisation and noise sampling are
# repeatable across "Restart & Run All" executions.
seed = 42
torch.manual_seed(seed)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

z_dim = 100          # length of the latent noise vector fed to the generator
batch_size = 128     # samples per DataLoader batch
epochs = 20          # full passes over the training set
lr = 5e-5            # learning rate shared by both RMSprop optimizers
# NOTE(review): the WGAN paper's reference settings use several critic updates
# per generator update; confirm that a 1:1 ratio is intentional here.
n_critic = 1         # critic updates per generator update
clip_value = 0.01    # critic weights are clamped to [-clip_value, clip_value]
img_size = 28        # images are resized to img_size x img_size
channels = 1         # number of image channels
img_shape = (channels, img_size, img_size)

# Directory that receives one sample grid per epoch.
os.makedirs("wgan_dcgan", exist_ok=True)

In [3]:
# Data Loader
# Preprocessing: resize to the configured resolution, convert to a tensor,
# then normalise with mean 0.5 / std 0.5 (the same statistics are undone with
# `* 0.5 + 0.5` when sample grids are saved).
transform = transforms.Compose([
    transforms.Resize(size=img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST training split, downloaded into the current directory on first use and
# served in shuffled mini-batches.
dataloader = DataLoader(
    dataset=datasets.MNIST(root=".", train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

In [17]:
class Generator(nn.Module):
    """Generator that maps a latent vector to an image with values in [-1, 1].

    A linear layer projects the ``z_dim``-dimensional noise to a
    ``128 x init_size x init_size`` feature map, which two stride-2 transposed
    convolutions upsample by a factor of four. The final ``Tanh`` bounds the
    output to [-1, 1]. The number of output channels follows the notebook-level
    ``channels`` setting.
    """

    def __init__(self):
        super().__init__()
        # Each transposed convolution below doubles the spatial size, so the
        # feature map starts at a quarter of the target resolution
        # (7x7 for img_size = 28).
        self.init_size = img_size // 4
        self.fc = nn.Linear(z_dim, 128 * self.init_size ** 2)
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),        # init_size -> 2 * init_size
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, channels, 4, stride=2, padding=1),   # 2 * init_size -> 4 * init_size
            nn.Tanh()
        )

    def forward(self, z):
        """Generate one image per row of ``z``.

        Args:
            z: Noise tensor whose first dimension is the batch size and whose
                second dimension is ``z_dim``.

        Returns:
            Output of ``conv_blocks`` applied to the projected and reshaped noise.
        """
        x = self.fc(z)
        # Reshape the flat projection into a (batch, 128, init_size, init_size) map.
        x = x.view(z.size(0), 128, self.init_size, self.init_size)
        return self.conv_blocks(x)


In [18]:
# Critic
class Critic(nn.Module):
    """Convolutional critic that maps each image to a single unbounded score.

    Two stride-2 convolutions downsample the input, and a linear layer reduces
    the flattened feature map to one scalar per sample. No output activation
    is applied.

    Args:
        channels: Number of input image channels. Defaults to 1.
        img_size: Height and width of the square input images; used only to
            size the final linear layer. Defaults to 28.
    """

    def __init__(self, channels=1, img_size=28):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(self._conv_out_features(channels, img_size), 1)

    def _conv_out_features(self, channels, img_size):
        """Return the flattened size of ``self.conv``'s output for one image.

        The probe runs in eval mode under ``torch.no_grad()`` so that it does
        not update the BatchNorm running statistics and does not record an
        autograd graph. The previous training mode is restored afterwards.
        """
        was_training = self.conv.training
        self.conv.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, channels, img_size, img_size)
                n_features = self.conv(probe).flatten(1).size(1)
        finally:
            self.conv.train(was_training)
        return n_features

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Image batch accepted by ``self.conv``.

        Returns:
            Tensor with one score per sample (second dimension of size 1).
        """
        x = self.conv(x)
        x = self.flatten(x)
        return self.fc(x)


In [19]:
# Initialize Models and Optimizers
# Build the generator first and the critic second, moving each onto `device`.
generator, critic = Generator().to(device), Critic().to(device)

# One RMSprop optimizer per network, both using the shared learning rate.
optimizer_G = optim.RMSprop(params=generator.parameters(), lr=lr)
optimizer_D = optim.RMSprop(params=critic.parameters(), lr=lr)

In [20]:
# Training Loop
# Per batch: `n_critic` critic updates (each followed by weight clipping),
# then one generator update. A grid of 64 samples is written after each epoch.
for epoch in range(1, epochs + 1):
    for i, (real_imgs, _) in enumerate(dataloader):

        real_imgs = real_imgs.to(device)
        b_size = real_imgs.size(0)

        # Train Critic
        for _ in range(n_critic):
            # Sample noise directly on the target device.
            z = torch.randn(b_size, z_dim, device=device)
            # The critic update must not backpropagate into the generator, so
            # the fake batch is produced without recording a graph.
            with torch.no_grad():
                fake_imgs = generator(z)

            # Critic loss: mean score on fakes minus mean score on reals.
            loss_c = -torch.mean(critic(real_imgs)) + torch.mean(critic(fake_imgs))
            # Also clears critic gradients left over from the generator step.
            optimizer_D.zero_grad()
            loss_c.backward()
            optimizer_D.step()

            # Weight clipping for Lipschitz constraint
            with torch.no_grad():
                for p in critic.parameters():
                    p.clamp_(-clip_value, clip_value)

        # Train Generator
        z = torch.randn(b_size, z_dim, device=device)
        gen_imgs = generator(z)
        # Generator loss: negated mean critic score on generated images.
        loss_g = -torch.mean(critic(gen_imgs))

        optimizer_G.zero_grad()
        loss_g.backward()
        optimizer_G.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch}/{epochs}] [Batch [{i}/{len(dataloader)}] "
                  f"[Loss Critic: {loss_c.item():.4f}] [Loss G: {loss_g.item():.4f}]")

    # Save sample every epoch
    generator.eval()
    with torch.no_grad():
        z = torch.randn(64, z_dim, device=device)
        sample = generator(z)
        sample = sample * 0.5 + 0.5  # Denormalize from [-1, 1] back to [0, 1]
        save_image(sample, f'wgan_dcgan/epoch_{epoch}.png', nrow=8)
    generator.train()


Epoch [1/20] [Batch [0/469] [Loss Critic: -0.1524] [Loss G: 0.0155]
Epoch [1/20] [Batch [100/469] [Loss Critic: -0.0550] [Loss G: -0.0024]
Epoch [1/20] [Batch [200/469] [Loss Critic: -0.0376] [Loss G: -0.0274]
Epoch [1/20] [Batch [300/469] [Loss Critic: -0.0373] [Loss G: -0.0242]
Epoch [1/20] [Batch [400/469] [Loss Critic: -0.0359] [Loss G: -0.0270]
Epoch [2/20] [Batch [0/469] [Loss Critic: -0.0341] [Loss G: -0.0255]
Epoch [2/20] [Batch [100/469] [Loss Critic: -0.0348] [Loss G: -0.0263]
Epoch [2/20] [Batch [200/469] [Loss Critic: -0.0465] [Loss G: -0.0124]
Epoch [2/20] [Batch [300/469] [Loss Critic: -0.0662] [Loss G: 0.0035]
Epoch [2/20] [Batch [400/469] [Loss Critic: -0.0730] [Loss G: 0.0212]
Epoch [3/20] [Batch [0/469] [Loss Critic: -0.0902] [Loss G: 0.0441]
Epoch [3/20] [Batch [100/469] [Loss Critic: -0.0915] [Loss G: 0.0561]
Epoch [3/20] [Batch [200/469] [Loss Critic: -0.0966] [Loss G: 0.0343]
Epoch [3/20] [Batch [300/469] [Loss Critic: -0.1034] [Loss G: 0.0486]
Epoch [3/20] [Batch

In [21]:
# Inspect one batch
# Pull the first batch from a fresh iterator over the loader; the second
# element of the batch is discarded and only the image tensor's shape is shown.
real_imgs, _ = next(iter(dataloader))
print("Dataset batch shape:", real_imgs.size())


Dataset batch shape: torch.Size([128, 1, 28, 28])


In [22]:
# Shape sanity check: push one noise vector through the generator and the
# critic's convolutional stack, then report the resulting tensor shapes.
# The pass runs in eval mode under no_grad so that this diagnostic neither
# updates the BatchNorm running statistics of the models nor records an
# autograd graph; each model's previous train/eval mode is restored afterwards.
gen_was_training, critic_was_training = generator.training, critic.training
generator.eval()
critic.eval()
try:
    with torch.no_grad():
        z = torch.randn(1, z_dim, device=device)
        gen_img = generator(z)
        crit_out = critic.conv(gen_img)
finally:
    generator.train(gen_was_training)
    critic.train(critic_was_training)

print("Generator output shape:", gen_img.shape)
print("Critic conv output shape:", crit_out.shape)
print("Flattened size:", crit_out.view(1, -1).size(1))


Generator output shape: torch.Size([1, 1, 28, 28])
Critic conv output shape: torch.Size([1, 128, 7, 7])
Flattened size: 6272
