<a href="https://colab.research.google.com/github/keshav-b/ML-DL-stuff/blob/master/GANs/GANs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [71]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

import tqdm.notebook as tq

# Simple GAN

In [70]:
transforms = transforms.Compose([
                                   transforms.ToTensor(), 
                                   transforms.Normalize((0.5,), (0.5,))
                                   ])

In [6]:
batch_size = 32
dataset = datasets.FashionMNIST(root="dataset/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [7]:
class Discriminator(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.disc(x)

In [8]:
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, img_dim),
            nn.Tanh(),  # normalize inputs to [-1, 1] so make outputs [-1, 1]
        )

    def forward(self, x):
        return self.gen(x)

In [9]:
# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
image_dim = 28 * 28 * 1  # 784
num_epochs = 5

In [75]:
gen = Generator(z_dim, image_dim).to(device)
disc = Discriminator(image_dim).to(device)

TypeError: ignored

In [11]:
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()

In [12]:
for epoch in range(num_epochs):
    for x, _ in tq.tqdm(loader):
        x = x.view(-1, 784).to(device)
        batch_size = x.shape[0]

        # DISC TRAIN: obj = max log(D(x)) + log(1-D(G(z)))
        #             obj = max log(D_x) + log(1-D_G_z)

        D_x = disc(x).view(-1)

        z = torch.randn(batch_size, z_dim).to(device) # noise
        G_z = gen(z)
        D_G_z = disc(G_z.detach()).view(-1) # flatten | detach: re-use G_z for GEN

        lossD = criterion(D_x, torch.ones_like(D_x)) + criterion(D_G_z, torch.zeros_like(D_G_z))
        lossD  = lossD / 2   # ????

        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # GEN TRAIN: obj = min  log(1-D(G(z))) = max log(D(G(z)))
        #             obj = min  log(1-D_G_z_) = max log(1-D_G_z_)

        D_G_z_ = disc(G_z).view(-1) # flatten

        lossG = criterion(D_G_z_, torch.ones_like(D_G_z_))

        gen.zero_grad()
        lossG.backward()
        opt_gen.step()


    print(f"Epoch [{epoch}/{num_epochs}] \t LossD: {lossD:.4f} \t LossG: {lossG:.4f}")

HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Epoch [0/5] 	 LossD: 0.2038 	 LossG: 1.7877


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Epoch [1/5] 	 LossD: 0.2965 	 LossG: 2.1221


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Epoch [2/5] 	 LossD: 0.3500 	 LossG: 1.9642


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Epoch [3/5] 	 LossD: 0.5073 	 LossG: 1.3444


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Epoch [4/5] 	 LossD: 0.4392 	 LossG: 1.4552


# DCGAN


In [18]:
class Discriminator(nn.Module):
    def __init__(self, in_channels, d_features):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Conv2d(in_channels, d_features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            self._conv_block(d_features, d_features*2, 4, 2, 1),
            self._conv_block(d_features*2, d_features*4, 4, 2, 1),
            self._conv_block(d_features*4, d_features*8, 4, 2, 1),
            nn.Conv2d(d_features*8, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid()
        )

    def _conv_block(self, in_channels, out_channels, kernel_size, stride, padding):
      return nn.Sequential(
          nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
      )

    def forward(self, x):
        return self.disc(x)

In [76]:
class Generator(nn.Module):
    def __init__(self, z_dim, in_channels, g_features):
        super().__init__()
        self.gen = nn.Sequential(
            self._conv_block(z_dim, g_features*16, 4, 1, 0),
            self._conv_block(g_features*16, g_features*8, 4, 2, 1),
            self._conv_block(g_features*8, g_features*4, 4, 2, 1),
            self._conv_block(g_features*4, g_features*2, 4, 2, 1),
            nn.ConvTranspose2d(g_features*2, in_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def _conv_block(self, in_channels, out_channels, kernel_size, stride, padding):
      return nn.Sequential(
          nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
          nn.BatchNorm2d(out_channels),
          nn.ReLU()
      )

    def forward(self, x):
        return self.gen(x)

In [77]:
def initialize_weights(model):
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data, 0.0, 0.02)

In [78]:
def test():
  N, in_channels, H, W = 8, 1, 64, 64
  z_dim = 100
  x = torch.randn((N, in_channels, H, W))
  disc = Discriminator(in_channels, d_features=8)
  initialize_weights(disc)
  assert disc(x).shape == (N, 1, 1, 1), "Wrong dim for D"

  gen = Generator(z_dim, in_channels, g_features=8)
  initialize_weights(gen)
  z = torch.randn((N, z_dim, 1, 1))
  assert gen(z).shape == (N, in_channels, H, W), "Wrong dim for G"

In [79]:
test()

In [80]:
# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 2e-4
z_dim = 100
image_dim = 64
num_epochs = 5
d_features = 64
g_features = 64

In [72]:
transforms = transforms.Compose([
                                  transforms.Resize(image_dim),
                                  transforms.ToTensor(), 
                                  transforms.Normalize((0.5,), (0.5,))
                                   ])

In [81]:
batch_size = 128
dataset = datasets.FashionMNIST(root="dataset/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [87]:
gen = Generator(z_dim=z_dim, in_channels=1, g_features=g_features).to(device)
disc = Discriminator(in_channels=1, d_features=d_features).to(device)

initialize_weights(disc)
initialize_weights(gen)

In [89]:
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
criterion = nn.BCELoss()

In [92]:
for epoch in range(num_epochs):
    for x, _ in tq.tqdm(loader):
        x = x.to(device)

        z = torch.randn(batch_size, z_dim, 1, 1).to(device) # noise
        # DISC TRAIN: obj = max log(D(x)) + log(1-D(G(z)))
        #             obj = max log(D_x) + log(1-D_G_z)

        D_x = disc(x).view(-1)

  
        G_z = gen(z)
        D_G_z = disc(G_z.detach()).view(-1) # flatten | detach: re-use G_z for GEN

        lossD = criterion(D_x, torch.ones_like(D_x)) + criterion(D_G_z, torch.zeros_like(D_G_z))
        lossD  = lossD / 2   # ????

        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # GEN TRAIN: obj = min  log(1-D(G(z))) = max log(D(G(z)))
        #             obj = min  log(1-D_G_z_) = max log(1-D_G_z_)

        D_G_z_ = disc(G_z).view(-1) # flatten

        lossG = criterion(D_G_z_, torch.ones_like(D_G_z_))

        gen.zero_grad()
        lossG.backward()
        opt_gen.step()


    print(f"Epoch [{epoch}/{num_epochs}] \t LossD: {lossD:.4f} \t LossG: {lossG:.4f}")

HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch [0/5] 	 LossD: 0.6442 	 LossG: 0.8332


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch [1/5] 	 LossD: 0.6956 	 LossG: 0.7654


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch [2/5] 	 LossD: 0.7048 	 LossG: 0.7607


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch [3/5] 	 LossD: 0.6399 	 LossG: 0.7673


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch [4/5] 	 LossD: 0.7016 	 LossG: 0.4932
