# GAN

 - Implement a DCGAN (https://arxiv.org/abs/1511.06434)
 - Train the model to generate CIFAR-like images
    - Use Tensorboard
    - Use it to generate CIFAR-like images

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [2]:
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_disc):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # Input: N x channels_img x 64 x 64
            nn.Conv2d(
                channels_img, features_disc, kernel_size=4, stride=2, padding=1
            ), # 32 x 32
            nn.LeakyReLU(0.2),
            # _block(in_channels, out_channels, kernel_size, stride, padding)
            self.d_block(features_disc, features_disc * 2, 4, 2, 1), # 16 x 16
            self.d_block(features_disc * 2, features_disc * 4, 4, 2, 1), # 8 x 8
            self.d_block(features_disc * 4, features_disc * 8, 4, 2, 1), # 4 x 4
            # After all _block img output is 4x4 (Conv2d below makes into 1x1)
            nn.Conv2d(features_disc * 8, 1, kernel_size = 4, stride = 2, padding = 0), # 1 x 1
            nn.Sigmoid(),
        )
        
    def d_block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )
        
    def forward(self, x):
        return self.disc(x)
    
    
    
class Generator(nn.Module):
    def __init__(self, z_dim, channels_img, features_gen):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # Input: N x z_dim x 1 x 1
            self.g_block(z_dim, features_gen * 16, 4, 1, 0), # N x f_g * 16 x 4 x 4
            self.g_block(features_gen * 16, features_gen * 8, 4, 2, 1), # 8 x 8
            self.g_block(features_gen * 8, features_gen * 4, 4, 2, 1), # 16 x 16
            self.g_block(features_gen * 4, features_gen * 2, 4, 2, 1), # 32 x 32
            nn.ConvTranspose2d(
                features_gen * 2, channels_img, kernel_size = 4, stride = 2, padding = 1,
            ),
            nn.Tanh(), # [-1, 1]
        )
        
    def g_block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
    
    def forward(self, x):
        return self.gen(x)
    

In [3]:
def initialize_weights(model):
    # Initializes weights according to the DCGAN paper
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

In [4]:
def test():
    N, in_channels, H, W = 8, 3, 64, 64
    z_dim = 100
    x = torch.randn((N, in_channels, H, W))
    disc = Discriminator(in_channels, 8)
    initialize_weights(disc)
    assert disc(x).shape == (N, 1, 1, 1)
    gen = Generator(z_dim, in_channels, 8)
    initialize_weights(gen)
    z = torch.randn((N, z_dim, 1, 1))
    assert gen(z).shape == (N, in_channels, H, W)
    print("Success")

In [5]:
test()

Success


In [6]:
# Hyperparameters etc.
device = torch.device("cuda")
learning_rate = 2e-4 # could also use two lrs, one for gen and one for disc
batch_sz = 128
img_sz = 64
channels_img = 1 # MNIST channels_img = 1
z_dim = 100
num_epochs = 5
features_disc = 64
features_gen = 64

transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(img_sz),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            [0.5 for _ in range(channels_img)], [0.5 for _ in range(channels_img)]),
        
    ]
)

dataset = datasets.MNIST(root="../data", train=True, transform=transforms, download=False)
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
loader = DataLoader(dataset, batch_size=batch_sz, shuffle=True)
gen = Generator(z_dim, channels_img, features_gen).to(device)
disc = Discriminator(channels_img, features_disc).to(device)
initialize_weights(gen)
initialize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=learning_rate, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=learning_rate, betas=(0.5, 0.999))
criterion = nn.BCELoss()

fixed_noise = torch.randn(32, z_dim, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
# writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
# writer_fake = SummaryWriter(f"logs/fake")
step = 0

gen.train()
disc.train()

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        noise = torch.randn((batch_sz, z_dim, 1, 1)).to(device)
        fake = gen(noise)
        
        ### Train Discriminator max log(D(x)) + log(1 - D(G(Z)))
        disc_real = disc(real).reshape(-1) # N
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward(retain_graph=True)
        opt_disc.step()
        
        ### Train Generator min log(1 - D(G(z))) <--> max log(D(G(z)))
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()
        
        # Print losses to tensorboard
        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch} / {num_epochs}] Batch {batch_idx} / {len(loader)} / Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}")
            
            with torch.no_grad():
                fake = gen(fixed_noise)
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
                
                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)
                
            step += 1

Epoch [0 / 5] Batch 0 / 469 / Loss D: 0.6964, loss G: 0.7818
Epoch [0 / 5] Batch 100 / 469 / Loss D: 0.0156, loss G: 4.0726
Epoch [0 / 5] Batch 200 / 469 / Loss D: 1.0042, loss G: 0.2800
Epoch [0 / 5] Batch 300 / 469 / Loss D: 0.3033, loss G: 1.0671
Epoch [0 / 5] Batch 400 / 469 / Loss D: 0.5324, loss G: 0.4630
Epoch [1 / 5] Batch 0 / 469 / Loss D: 0.6772, loss G: 0.7211
Epoch [1 / 5] Batch 100 / 469 / Loss D: 0.6229, loss G: 1.0814
Epoch [1 / 5] Batch 200 / 469 / Loss D: 0.6265, loss G: 0.8070
Epoch [1 / 5] Batch 300 / 469 / Loss D: 0.6181, loss G: 0.9548
Epoch [1 / 5] Batch 400 / 469 / Loss D: 0.6320, loss G: 0.9531
Epoch [2 / 5] Batch 0 / 469 / Loss D: 0.5991, loss G: 0.9246
Epoch [2 / 5] Batch 100 / 469 / Loss D: 0.6538, loss G: 0.7472
Epoch [2 / 5] Batch 200 / 469 / Loss D: 0.6187, loss G: 0.7318
Epoch [2 / 5] Batch 300 / 469 / Loss D: 0.5905, loss G: 0.9050
Epoch [2 / 5] Batch 400 / 469 / Loss D: 0.6212, loss G: 0.7227
Epoch [3 / 5] Batch 0 / 469 / Loss D: 0.5954, loss G: 1.0293
