In [None]:
# Mount Google Drive inside the Colab runtime at /gdrive so that the project
# files referenced by the absolute /gdrive/... paths used below are reachable.
from google.colab import drive
drive.mount('/gdrive')
# IPython line magic: make the mount point the current working directory.
%cd /gdrive

Mounted at /gdrive
/gdrive


In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn as nn
import torchvision
from tqdm import tqdm
import matplotlib.pyplot as plt

import sys
sys.path.append("/gdrive/MyDrive/work_space/GAN_stuff")

from utils.mnist_dataset import MNISTDataset

Constants

In [None]:
# Directory passed to MNISTDataset when the training set is constructed.
DATA_DIR = "/gdrive/MyDrive/work_space/GAN_stuff/MNIST"

# Side length (pixels) of the square images both networks are sized for:
# the generator starts from IMAGE_SIZE // 4 and upsamples x2 twice, while the
# discriminator halves the resolution four times (IMAGE_SIZE // 2**4).
# NOTE(review): presumably MNISTDataset yields images of this size - confirm.
IMAGE_SIZE = 32
# Channel count of generated images / of the discriminator's input.
N_CHANNELS = 1

NN classes

In [None]:
# NN classes

class Generator(nn.Module):
    """Convolutional generator mapping latent vectors to images.

    A linear layer projects each latent vector onto a 128-channel feature map
    of side ``IMAGE_SIZE // 4``; two (upsample x2 -> conv -> batch-norm -> ReLU)
    stages then bring it to ``IMAGE_SIZE`` and a final conv + Tanh produces an
    ``N_CHANNELS``-channel image with values in [-1, 1].

    Args:
        n_latent_dims: Size of the latent vector fed to ``forward``.
    """

    def __init__(self, n_latent_dims):
        super().__init__()

        # Side of the coarse feature map; the two upsampling stages multiply it by 4.
        self.init_size = IMAGE_SIZE // 4
        projected_features = 128 * self.init_size ** 2
        self.linear_1 = nn.Linear(n_latent_dims, projected_features)

        # Layer order is kept stable so that state_dict keys (gen.0, gen.2, ...)
        # remain compatible with existing checkpoints.
        layers = [
            nn.BatchNorm2d(128),
            # Stage 1: init_size -> 2 * init_size
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # Stage 2: 2 * init_size -> 4 * init_size
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # Output head: collapse to image channels, squash into [-1, 1].
            nn.Conv2d(64, N_CHANNELS, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, z):
        """Generate one image per row of ``z`` (shape ``(batch, n_latent_dims)``)."""
        n_samples = z.shape[0]
        projected = self.linear_1(z)
        # Reinterpret the flat projection as a (batch, 128, init_size, init_size) map.
        feature_map = projected.view(n_samples, 128, self.init_size, self.init_size)
        return self.gen(feature_map)


class Discriminator(nn.Module):
    """Convolutional discriminator scoring images as real (towards 1) or fake (towards 0).

    Four stride-2 convolutions each halve the spatial resolution, taking an
    ``N_CHANNELS`` x ``IMAGE_SIZE`` x ``IMAGE_SIZE`` input down to a 128-channel
    map of side ``IMAGE_SIZE // 16``; a linear layer followed by a sigmoid turns
    the flattened map into a single probability per image.
    """

    def __init__(self):
        super().__init__()

        # Spatial side remaining after the four stride-2 convolutions.
        self.feature_size = IMAGE_SIZE // 2**4

        # Layer order is kept stable so state_dict keys stay checkpoint-compatible.
        # The first conv deliberately has no batch-norm, exactly as before.
        downsampling = [
            nn.Conv2d(N_CHANNELS, 16, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        ]
        self.conv_layers = nn.Sequential(*downsampling)

        n_flat_features = 128 * self.feature_size**2
        self.classifier = nn.Sequential(
            nn.Linear(n_flat_features, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return a ``(batch, 1)`` tensor of sigmoid scores for ``img``."""
        feature_maps = self.conv_layers(img)
        n_samples = feature_maps.shape[0]
        # Flatten everything except the batch dimension for the linear head.
        flattened = feature_maps.view(n_samples, -1)
        return self.classifier(flattened)


Utility functions

In [None]:
def get_noise_vector(z_size, device=None):
    """Sample latent noise for the generator.

    The noise is drawn from the uniform distribution on [0, 1) via
    ``torch.rand`` (not from a standard normal).

    Args:
        z_size: Shape of the tensor to sample, e.g. ``(batch, n_latent_dims)``.
        device: Optional device on which to create the tensor directly,
            avoiding a host-to-device copy. ``None`` (the default) keeps the
            previous behaviour of allocating on torch's default device.

    Returns:
        A float tensor of shape ``z_size`` with values in [0, 1).
    """
    return torch.rand(z_size, device=device)

Training config

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# runtime without CUDA instead of failing at the first .to(device) call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Length of the latent noise vector fed to the generator.
n_latent_dims = 100

# Number of passes over the training set, and images per mini-batch.
n_epochs = 100
batch_size = 64

Initialization

In [None]:
# Data: the 'train' subset under DATA_DIR, served in reshuffled mini-batches.
train_dataset = MNISTDataset(DATA_DIR, subset_name='train')
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Networks: built in the same order as before (generator first), then moved
# to the configured device.
generator = Generator(n_latent_dims)
generator = generator.to(device)
discriminator = Discriminator()
discriminator = discriminator.to(device)

# One Adam optimizer per network with identical hyper-parameters.
gen_optimizer = Adam(
    generator.parameters(),
    lr=0.0002,
    betas=(0.5, 0.999),
)
discrim_optimizer = Adam(
    discriminator.parameters(),
    lr=0.0002,
    betas=(0.5, 0.999),
)

# Binary cross-entropy on the discriminator's sigmoid output, averaged per batch.
criterion = nn.BCELoss(reduction='mean')

Training loop

In [None]:
# Adversarial training loop.
#
# Each iteration performs one discriminator update (real batch labelled 1,
# generated batch labelled 0) followed by one generator update (generated batch
# labelled 1, i.e. the generator is rewarded for fooling the discriminator).

# Show a grid of generated samples every `plot_every` epochs.
plot_every = 1

# Epochs are numbered 1..n_epochs inclusive so that exactly n_epochs passes run
# (range(1, n_epochs) stopped one epoch short).
for epoch in range(1, n_epochs + 1):

  print("Epoch:", epoch)
  epoch_discrim_loss = 0.0
  epoch_gen_loss = 0.0

  for batch in train_loader:

    # Discriminator phase --
    real_images = batch['image'].to(device).float()
    # Label tensors are sized from the actual batch, which may be smaller than
    # batch_size for the final batch of an epoch.
    ones = torch.ones(real_images.shape[0], 1, device=device)

    noise_batch = get_noise_vector(z_size=(batch_size, n_latent_dims)).to(device)
    # The generator is not being trained in this phase, so no graph is built for it.
    with torch.no_grad():
      fake_images = generator(noise_batch)
    zeros = torch.zeros(fake_images.shape[0], 1, device=device)

    discrim_optimizer.zero_grad()
    reals_loss = criterion(discriminator(real_images), ones)
    fakes_loss = criterion(discriminator(fake_images), zeros)
    discrim_loss = (reals_loss + fakes_loss) / 2
    # This graph is never back-propagated through again (the generator phase
    # runs fresh forward passes), so it does not need to be retained.
    discrim_loss.backward()
    discrim_optimizer.step()


    # Generator phase --
    gen_optimizer.zero_grad()
    noise_batch = get_noise_vector(z_size=(batch_size, n_latent_dims)).to(device)
    fake_images = generator(noise_batch)
    discrim_pred = discriminator(fake_images)
    ones = torch.ones(batch_size, 1, device=device)

    # Gradients also accumulate in the discriminator here; they are discarded
    # by discrim_optimizer.zero_grad() at the start of the next iteration.
    gen_loss = criterion(discrim_pred, ones)
    gen_loss.backward()
    gen_optimizer.step()

    # Accumulate losses
    epoch_discrim_loss += discrim_loss.item()
    epoch_gen_loss += gen_loss.item()


  # Report the mean per-batch losses for the epoch.
  epoch_discrim_loss /= len(train_loader)
  epoch_gen_loss /= len(train_loader)
  print("D loss:", epoch_discrim_loss, "| G loss:", epoch_gen_loss)

  if epoch % plot_every == 0:
    # Visualise the last generated batch of the epoch.
    grid = torchvision.utils.make_grid(fake_images.detach().cpu())
    # The generator ends in Tanh, so pixels lie in [-1, 1]; imshow expects float
    # RGB data in [0, 1] and would otherwise clip every negative value to black.
    grid = ((grid + 1) / 2).clamp(0, 1)
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()



Output hidden; open in https://colab.research.google.com to view.

tensor([[0.4217, 0.3376, 0.6501],
        [0.0981, 0.6784, 0.5426],
        [0.7805, 0.9909, 0.1721]])