In [2]:
from torch.utils.data import Dataset, DataLoader
import torchvision
import numpy as np
from torchvision import transforms
from torch import nn
import torch
from torch import optim
import tensorboard
import torchvision
from torch.utils.tensorboard import SummaryWriter
import torchvision.datasets as datasets

torch.cuda.empty_cache()

In [3]:
class Discriminator(nn.Module):
    def __init__(self, channels_img, fd):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Conv2d(channels_img, fd, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(fd, fd * 2, 4, 2, 1),
            nn.BatchNorm2d(fd * 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(fd * 2, fd * 4, 4, 2, 1),
            nn.BatchNorm2d(fd * 4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(fd * 4, fd * 8, 4, 2, 1),
            nn.BatchNorm2d(fd * 8),
            nn.LeakyReLU(0.2),
            nn.Conv2d(fd * 8, 1, 4, 2, 0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.disc(x)


class Generator(nn.Module):
    def __init__(self, channels_noise, channels_img, fg):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(channels_noise, fg * 16, 4, 1, 0),
            nn.BatchNorm2d(fg * 16),
            nn.ReLU(),  
            nn.ConvTranspose2d(fg * 16, fg * 8, 4, 2, 1),
            nn.BatchNorm2d(fg * 8),
            nn.ReLU(),  
            nn.ConvTranspose2d(fg * 8, fg * 4, 4, 2, 1),
            nn.BatchNorm2d(fg * 4),
            nn.ReLU(),  
            nn.ConvTranspose2d(fg * 4, fg * 2, 4, 2, 1),
            nn.BatchNorm2d(fg * 2),
            nn.ReLU(),  
            nn.ConvTranspose2d(fg * 2, channels_img, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.net(x)


def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        if isinstance(m,nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)


In [4]:
def gradient_penalty(critic, real, fake, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)

    # critic scores
    mixed_scores = critic(interpolated_images)

    # gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Hyperparameters 
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 5e-5
bs = 64
img_size = 64
no_channels = 3
z_size = 100
epochs = 5
fc = 64
fg = 64
crit_iter = 5
lamdaa = 10

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(img_size),
        transforms.Normalize(
            [0.5 for _ in range(no_channels)], [0.5 for _ in range(no_channels)]
        ),
    ]
)


dataset = datasets.ImageFolder(root='Bitmoji-Faces', transform=transform)

loader = DataLoader(dataset, batch_size=bs, shuffle=True)


gen = Generator(z_size, no_channels, fg).to(device)
critic = Discriminator(no_channels, fc).to(device)
initialize_weights(gen)
initialize_weights(critic)


opt_gen = optim.RMSprop(gen.parameters(), lr=lr)
opt_critic = optim.RMSprop(critic.parameters(), lr=lr)

# for tensorboard plotting
fixed_noise = torch.randn(32, z_size, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

gen.train()
critic.train()

for epoch in range(epochs):
    for batch_idx, (data, _) in enumerate(loader):
        data = data.to(device)
        cur_batch_size = data.shape[0]

        # max E[critic(real)] - E[critic(fake)]
        for _ in range(crit_iter):
            noise = torch.randn(cur_batch_size, z_size, 1, 1).to(device)
            fake = gen(noise)
            critic_real = critic(data).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            gp = gradient_penalty(critic,data,fake,device=device)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + lamdaa * gp 
            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()



        # max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # print to tensorboard
        if batch_idx % 100 == 0 and batch_idx > 0:
            gen.eval()
            critic.eval()
            print(
                f"Epoch [{epoch}/{epochs}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                img_grid_real = torchvision.utils.make_grid(
                    data[:32], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:32], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1
            gen.train()
            critic.train()

Epoch [0/5] Batch 100/2035                   Loss D: -0.5362, loss G: -0.0119
Epoch [0/5] Batch 200/2035                   Loss D: 0.5708, loss G: -0.0386
Epoch [0/5] Batch 300/2035                   Loss D: 0.2384, loss G: -0.2373
Epoch [0/5] Batch 400/2035                   Loss D: 0.6278, loss G: -0.3796
Epoch [0/5] Batch 500/2035                   Loss D: -0.1947, loss G: -0.2303
Epoch [0/5] Batch 600/2035                   Loss D: -0.2974, loss G: -0.1531
Epoch [0/5] Batch 700/2035                   Loss D: -0.5889, loss G: -0.1627
Epoch [0/5] Batch 800/2035                   Loss D: 0.2336, loss G: -0.1702
Epoch [0/5] Batch 900/2035                   Loss D: -0.4959, loss G: -0.4040
Epoch [0/5] Batch 1000/2035                   Loss D: -0.6435, loss G: -0.1689
Epoch [0/5] Batch 1100/2035                   Loss D: 0.1258, loss G: -0.8791
Epoch [0/5] Batch 1200/2035                   Loss D: -0.6311, loss G: -0.2651
Epoch [0/5] Batch 1300/2035                   Loss D: -0.4515, los