In [None]:
# Mount Google Drive when running on Colab. No later cell shown here reads from
# /content/drive, so outside Colab the mount is skipped instead of failing.
try:
    from google.colab import drive
except ImportError:  # not running on Google Colab
    drive = None

if drive is not None:
    drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
from torchvision import transforms


In [None]:
class Discriminator(nn.Module):
  """Class-conditional critic (WGAN-style: no output activation, raw scores).

  The class label is embedded into an ``img_size * img_size`` vector, reshaped
  into one extra image plane and concatenated to the input channels.

  Args:
    channels_img: number of channels of the input images.
    features_d: base width; the three downsampling blocks use 2x, 4x, 8x this.
    num_classes: number of distinct integer labels accepted by ``forward``.
    img_size: spatial side of the label plane; must equal the input's H and W.
  """

  def __init__(self, channels_img, features_d, num_classes, img_size):
    super().__init__()
    self.img_size = img_size
    # Stem: image channels + 1 label plane -> features_d, halves H and W.
    layers = [
        nn.Conv2d(channels_img + 1, features_d, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(0.2),
    ]
    # Three conv/InstanceNorm/LeakyReLU blocks, each doubling the width and
    # halving the spatial size.
    width = features_d
    for _ in range(3):
      layers.append(self._block(width, width * 2, 4, 2, 1))
      width *= 2
    # Head: 4x4 feature map -> 1x1 score (for 64x64 inputs).
    layers.append(nn.Conv2d(width, 1, kernel_size=4, stride=2, padding=0))
    self.disc = nn.Sequential(*layers)
    self.embed = nn.Embedding(num_classes, img_size * img_size)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Bias-free Conv2d -> affine InstanceNorm2d -> LeakyReLU(0.2)."""
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
    norm = nn.InstanceNorm2d(out_channels, affine=True)
    return nn.Sequential(conv, norm, nn.LeakyReLU(0.2))

  def forward(self, x, labels):
    """Score images ``x`` (N, channels_img, img_size, img_size) given integer ``labels`` (N,).

    For 64x64 inputs the result has shape (N, 1, 1, 1).
    """
    side = self.img_size
    label_plane = self.embed(labels).view(labels.shape[0], 1, side, side)
    return self.disc(torch.cat((x, label_plane), dim=1))

In [None]:
class Generator(nn.Module):
  """Class-conditional DCGAN-style generator producing 64x64 images in (-1, 1).

  The integer label is embedded into an ``embed_size`` vector and concatenated
  to the noise along the channel axis before the transposed-conv stack.

  Args:
    z_dim: number of noise channels; noise is expected as (N, z_dim, 1, 1).
    channels_img: number of output image channels.
    features_g: base width; the stack narrows 16x -> 8x -> 4x -> 2x this.
    num_classes: number of distinct integer labels accepted by ``forward``.
    img_size: stored on the module only; the layer stack always upsamples
      1 -> 4 -> 8 -> 16 -> 32 -> 64 regardless of this value.
    embed_size: length of the label embedding.
  """

  def __init__(self, z_dim, channels_img, features_g, num_classes, img_size, embed_size):
    super().__init__()
    self.img_size = img_size
    widths = [features_g * mult for mult in (16, 8, 4, 2)]
    # First block projects the 1x1 (noise + label) input to a 4x4 map.
    stages = [self._block(z_dim + embed_size, widths[0], 4, 1, 0)]
    # Remaining blocks each double H and W while narrowing the width.
    stages += [self._block(c_in, c_out, 4, 2, 1) for c_in, c_out in zip(widths, widths[1:])]
    # Output layer: 32x32 -> 64x64 image, squashed to (-1, 1) by Tanh.
    stages.append(nn.ConvTranspose2d(widths[-1], channels_img, kernel_size=4, stride=2, padding=1))
    stages.append(nn.Tanh())
    self.net = nn.Sequential(*stages)
    self.embed = nn.Embedding(num_classes, embed_size)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Bias-free ConvTranspose2d -> BatchNorm2d -> ReLU."""
    upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
    return nn.Sequential(upconv, nn.BatchNorm2d(out_channels), nn.ReLU())

  def forward(self, x, labels):
    """Map noise ``x`` (N, z_dim, 1, 1) and integer ``labels`` (N,) to images."""
    label_vec = self.embed(labels)[:, :, None, None]
    return self.net(torch.cat([x, label_vec], dim=1))

In [None]:
def initialize_weights(model):
  """Re-draw conv / transposed-conv / BatchNorm weights from N(0, 0.02), in place.

  Every submodule yielded by ``model.modules()`` that is an ``nn.Conv2d``,
  ``nn.ConvTranspose2d`` or ``nn.BatchNorm2d`` gets its ``weight`` overwritten.
  Biases and all other layer types (e.g. ``nn.InstanceNorm2d``, ``nn.Embedding``)
  keep their default PyTorch initialisation.

  NOTE(review): the PyTorch DCGAN tutorial initialises BatchNorm weights with
  mean 1.0 rather than 0.0 — confirm mean 0.0 is intended here.
  """
  initialised_types = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
  matching = (mod for mod in model.modules() if isinstance(mod, initialised_types))
  for layer in matching:
    nn.init.normal_(layer.weight.data, mean=0.0, std=0.02)

def test(num_classes=10, embed_size=100):
  """Smoke-test the output shapes of the conditional Generator and Discriminator.

  Builds both networks for 64x64 3-channel images, runs one forward pass with
  random class labels, and asserts the critic returns (N, 1, 1, 1) and the
  generator returns images of the input size.

  Args:
    num_classes: number of label classes used by both label embeddings.
    embed_size: length of the generator's label embedding.

  Raises:
    AssertionError: if either network produces an unexpected output shape.
  """
  N, in_channels, H, W = 8, 3, 64, 64
  z_dim = 100
  x = torch.randn((N, in_channels, H, W))
  # Both networks are class-conditional, so every forward pass needs labels.
  labels = torch.randint(0, num_classes, (N,))

  disc = Discriminator(in_channels, 8, num_classes, H)
  initialize_weights(disc)
  assert disc(x, labels).shape == (N, 1, 1, 1), "Discriminator test failed"

  gen = Generator(z_dim, in_channels, 8, num_classes, H, embed_size)
  initialize_weights(gen)
  z = torch.randn((N, z_dim, 1, 1))
  assert gen(z, labels).shape == (N, in_channels, H, W), "Generator test failed"
  print("Success, tests passed!")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Hyperparameters etc.
SEED = 0
torch.manual_seed(SEED)  # reproducible weight init, noise and data shuffling
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 1e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 64
IMAGE_SIZE = 64  # the Generator/Discriminator layer stacks produce/consume 64x64 images
CHANNELS_IMG = 1  # 1 for MNIST; set to 3 for RGB datasets such as CelebA
NUM_CLASSES = 10
GEN_EMBEDDING = 100  # length of the label embedding concatenated to the noise
Z_DIM = 100
NOISE_DIM = Z_DIM
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64
CRITIC_ITERATIONS = 5  # critic updates per generator update
LAMDA_GP = 10  # gradient-penalty weight (lambda); spelling kept because later cells use it

transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        # Resize(int) only fixes the shorter side; crop so non-square images
        # (e.g. CelebA) also come out IMAGE_SIZE x IMAGE_SIZE. No-op for MNIST.
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

# MNIST is single-channel, so CHANNELS_IMG must be 1 (as set above).
dataset = datasets.MNIST(
    root="dataset/", train=True, transform=transform, download=True
)

# To train on CelebA instead: comment out MNIST above, uncomment below, and set
# CHANNELS_IMG = 3 (and NUM_CLASSES to the number of class sub-folders).
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN, NUM_CLASSES, IMAGE_SIZE, GEN_EMBEDDING).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_DISC, NUM_CLASSES, IMAGE_SIZE).to(device)
initialize_weights(gen)
initialize_weights(critic)

# Adam with beta1 = 0 and beta2 = 0.9, the settings used in the WGAN-GP paper.
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

# Noise reused at every logging step so the generated grids are comparable.
fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0



In [None]:
# Train the conditional WGAN-GP.
# NOTE: `gradient_penalty` is defined in a later cell of this notebook; run that
# cell before this one (or move it above) so a fresh kernel does not hit NameError.
gen.train()
critic.train()

# One fixed label set (0..NUM_CLASSES-1, repeated) paired with fixed_noise so the
# logged "Fake" grids show the same classes at every step.
fixed_labels = torch.arange(fixed_noise.shape[0], device=device) % NUM_CLASSES

for epoch in range(NUM_EPOCHS):
    # Labels ARE used: they condition both the generator and the critic.
    for batch_idx, (real, labels) in enumerate(dataloader):
        real = real.to(device)
        labels = labels.to(device)
        # The last batch of an epoch can be smaller than BATCH_SIZE; size the
        # noise by the actual batch so it stays aligned with `real` and `labels`.
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)] - lambda * GP
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, NOISE_DIM, 1, 1, device=device)
            fake = gen(noise, labels).detach()  # critic step: no generator gradients
            critic_real = critic(real, labels).reshape(-1)
            critic_fake = critic(fake, labels).reshape(-1)
            gp = gradient_penalty(critic, labels, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMDA_GP * gp
            )
            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

        # Train Generator: min -E[critic(gen(z))]. This fake batch must NOT be
        # detached, otherwise no gradient reaches the generator's parameters.
        noise = torch.randn(cur_batch_size, NOISE_DIM, 1, 1, device=device)
        fake_for_gen = gen(noise, labels)
        output = critic(fake_for_gen, labels).reshape(-1)
        loss_gen = -torch.mean(output)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and log image grids to tensorboard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} "
                f"Loss D: {loss_critic.item():.4f}, loss G: {loss_gen.item():.4f}"
            )

            with torch.no_grad():
                fake_fixed = gen(fixed_noise, fixed_labels)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake_fixed[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

Epoch [0/5] Batch 0/938                   Loss D: -159.9119, loss G: 152.4591


In [None]:
def gradient_penalty(critic, labels, real, fake, device="cpu"):
  """WGAN-GP gradient penalty for a conditional critic.

  Draws one random point per sample on the straight line between the paired
  ``real`` and ``fake`` inputs, evaluates the critic there, and penalises the
  squared deviation of the per-sample L2 norm of d(score)/d(input) from 1.

  Args:
    critic: callable ``critic(inputs, labels)`` returning scores whose leading
      dimension is the batch.
    labels: passed through unchanged to ``critic``.
    real: real samples, batch first (e.g. (N, C, H, W)).
    fake: generated samples, same shape as ``real``.
    device: device on which the interpolation weights are created; must match
      the device of ``real`` and ``fake``.

  Returns:
    0-dim tensor: batch mean of (||grad||_2 - 1)^2. It is built with
    ``create_graph=True`` so it can be backpropagated into the critic's
    parameters.
  """
  batch_size = real.shape[0]
  # One interpolation weight per sample, broadcast over all non-batch dims.
  eps_shape = (batch_size,) + (1,) * (real.dim() - 1)
  epsilon = torch.rand(eps_shape, device=device)
  # Detach so the interpolates are a fresh leaf whose input-gradient we can ask
  # for, whether or not `real` / `fake` belong to another autograd graph.
  interpolated_images = (real * epsilon + fake * (1 - epsilon)).detach()
  interpolated_images.requires_grad_(True)

  mixed_scores = critic(interpolated_images, labels)
  gradient = torch.autograd.grad(
      inputs=interpolated_images,
      outputs=mixed_scores,
      grad_outputs=torch.ones_like(mixed_scores),
      create_graph=True,  # the penalty itself must be differentiable
      retain_graph=True,
  )[0]

  # reshape (not view): the returned gradient is not guaranteed to be contiguous.
  gradient = gradient.reshape(batch_size, -1)
  gradient_norm = gradient.norm(2, dim=1)
  return torch.mean((gradient_norm - 1) ** 2)