# My experiments with GAN

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import multiprocessing

from PIL import Image
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

In [10]:
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # input: N x channels_img x 64 x64
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            # N x features_d x 32 x 32
            nn.LeakyReLU(0.2),
            self._block(features_d,   features_d*2, kernel_size=4, stride=2, padding=1),
            # N x features_d*2 x 16 x 16
            self._block(features_d*2, features_d*4, kernel_size=4, stride=2, padding=1),
            # N x features_d*4 x 8 x 8
            self._block(features_d*4, features_d*8, kernel_size=4, stride=2, padding=1),
            # N x features_d*8 x 4 x 4
            nn.Conv2d(features_d*8, 1, kernel_size=4, stride=1, padding=0),
            # N x 1 x 1 x 1
            nn.Flatten(),
            # output: N x 1
            # nn.Sigmoid(),  # not needed for torch.nn.BCEWithLogitsLoss()
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            # bias=False for BatchNorm
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            # do not normalize across the batches, normalize only across the layer (instance)
            nn.InstanceNorm2d(out_channels, affine=True),  # LayerNorm <--> InstanceNorm
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)


class Generator(nn.Module):
    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # input: N x z_dim x 1 x 1
            self._block(z_dim,         features_g*16, kernel_size=4, stride=1, padding=0),
            # N x features_g*16 x 4 x 4
            self._block(features_g*16, features_g*8,  kernel_size=4, stride=2, padding=1),
            # N x features_g*8 x 8 x 8
            self._block(features_g*8,  features_g*4,  kernel_size=4, stride=2, padding=1),
            # N x features_g*4 x 16 x 16
            self._block(features_g*4,  features_g*2,  kernel_size=4, stride=2, padding=1),
            # N x features_g*2 x 32 x 32
            nn.ConvTranspose2d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1),
            # output: N x channels_img x 64 x 64
            nn.Tanh(),  # between (-1, 1)
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            # do not normalize across the batches, normalize only across the layer (instance)
            nn.InstanceNorm2d(out_channels, affine=True),  # LayerNorm <--> InstanceNorm
            # nn.ReLU(),  # like in DCGAN paper
            nn.LeakyReLU(0.2),
        )
    
    def forward(self, x):
        return self.gen(x)


def init_weights(model):
    ''' Initialize weights of the model
        with mean of 0.0 and standard deviation of 0.02 '''
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)


def test():
    N, in_channels, height, width = 8, 3, 64, 64
    z_dim = 100
    features_d = features_g = 64
    
    x = torch.randn((N, in_channels, height, width))
    disc = Discriminator(in_channels, features_d)
    init_weights(disc)
    assert disc(x).shape == (N, 1)
    
    gen = Generator(z_dim, in_channels, features_g)
    init_weights(gen)
    z = torch.randn((N, z_dim, 1, 1))
    assert gen(z).shape == (N, in_channels, height, width)
    
    print('Test is OK')


test()

Test is OK


In [11]:
# hyperparameters
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LEARNING_RATE = 2e-4
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 1  # MNIST dataset
Z_DIM = 100
NUM_EPOCHS = 5
FEATURES_D = FEATURES_G = 64
LOG_DIR = 'logs'
CHECKPOINT_DISC = 'disc.pth.tar'
CHECKPOINT_GEN = 'gen.pth.tar'
NUM_WORKERS = multiprocessing.cpu_count()


transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5 for _ in range(CHANNELS_IMG)],
                         [0.5 for _ in range(CHANNELS_IMG)]),
])


def save_checkpoint(model, optimizer, step, filename):
    print("=> Saving checkpoint")
    checkpoint = {'state_dict': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'step': step,}
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print('=> Loading checkpoint')
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    # if we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    
    step = checkpoint["step"]
    return step

In [4]:
# get MNIST dataset
data_dir = '/content/dataset/MNIST/raw/'
if os.path.exists(data_dir):
    !rm -rf $data_dir

!mkdir $data_dir
!wget --directory-prefix=$data_dir https://github.com/golbin/TensorFlow-MNIST/raw/master/mnist/data/t10k-images-idx3-ubyte.gz
!wget --directory-prefix=$data_dir https://github.com/golbin/TensorFlow-MNIST/raw/master/mnist/data/t10k-labels-idx1-ubyte.gz
!wget --directory-prefix=$data_dir https://github.com/golbin/TensorFlow-MNIST/raw/master/mnist/data/train-images-idx3-ubyte.gz
!wget --directory-prefix=$data_dir https://github.com/golbin/TensorFlow-MNIST/raw/master/mnist/data/train-labels-idx1-ubyte.gz

dataset = datasets.MNIST(root='dataset/', train=True, transform=transform, download=True)

mkdir: cannot create directory ‘/content/dataset/MNIST/raw/’: No such file or directory
--2021-03-29 10:58:37--  https://github.com/golbin/TensorFlow-MNIST/raw/master/mnist/data/t10k-images-idx3-ubyte.gz
Resolving github.com (github.com)... 140.82.113.3
Connecting to github.com (github.com)|140.82.113.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/golbin/TensorFlow-MNIST/master/mnist/data/t10k-images-idx3-ubyte.gz [following]
--2021-03-29 10:58:37--  https://raw.githubusercontent.com/golbin/TensorFlow-MNIST/master/mnist/data/t10k-images-idx3-ubyte.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1648877 (1.6M) [application/octet-stream]
Saving to: ‘/content/dataset/MNIST/raw/t10k-images-id

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [15]:
def main():
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                        num_workers=NUM_WORKERS, pin_memory=True)
    # initialize generator and discriminator
    disc = Discriminator(CHANNELS_IMG, FEATURES_D).to(DEVICE)
    gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_G).to(DEVICE)
    init_weights(disc)
    init_weights(gen)

    optim_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    optim_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    scaler_disc = torch.cuda.amp.GradScaler()
    scaler_gen = torch.cuda.amp.GradScaler()
    criterion = nn.BCEWithLogitsLoss()

    fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(DEVICE)

    # set writer for TensorBoard
    writer_real = SummaryWriter(os.path.join(LOG_DIR, 'real'))
    writer_fake = SummaryWriter(os.path.join(LOG_DIR, 'fake'))
    step = 0

    # load models
    if os.path.exists(CHECKPOINT_DISC) and os.path.exists(CHECKPOINT_GEN):
        step = load_checkpoint(CHECKPOINT_DISC, disc, optim_disc, LEARNING_RATE)
        step = load_checkpoint(CHECKPOINT_GEN, gen, optim_gen, LEARNING_RATE)

    disc.train()
    gen.train()
    
    for epoch in range(1, NUM_EPOCHS+1):
        step = train(disc, gen, optim_disc, optim_gen, criterion, loader,
                     scaler_disc, scaler_gen, writer_real, writer_fake,
                     epoch, step, fixed_noise)

        # save models
        save_checkpoint(disc, optim_disc, step, CHECKPOINT_DISC)
        save_checkpoint(gen , optim_gen, step, CHECKPOINT_GEN)


def train(disc, gen, optim_disc, optim_gen, criterion, loader,
          scaler_disc, scaler_gen, writer_real, writer_fake,
          epoch, step, fixed_noise):
    loop = tqdm(loader, leave=False)
    loop.set_description(f'Epoch [{epoch}/{NUM_EPOCHS}]')

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]

        # train Discriminator: max log(D(real)) + log(1-D(G(noise)))
        noise = torch.randn((cur_batch_size, Z_DIM, 1, 1)).to(DEVICE)
        with torch.cuda.amp.autocast():
            fake = gen(noise)
            disc_real = disc(real)
            disc_fake = disc(fake.detach())
            loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
            loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            loss_disc = (loss_disc_real + loss_disc_fake) / 2.0

        optim_disc.zero_grad()
        scaler_disc.scale(loss_disc).backward()
        scaler_disc.step(optim_disc)
        scaler_disc.update()

        # train Generator: min log(1-D(G(noise))) or max log(D(G(noise)))
        with torch.cuda.amp.autocast():
            disc_fake = disc(fake)
            loss_gen = criterion(disc_fake, torch.ones_like(disc_fake))

        optim_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(optim_gen)
        scaler_gen.update()

        if batch_idx % 100 == 0:
            with torch.no_grad():
                fake = gen(fixed_noise)
            # take out up to 32 examples
            img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
            img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
            writer_fake.add_image('Fake', img_grid_fake, global_step=step)
            writer_real.add_image('Real', img_grid_real, global_step=step)
            step += 1

            loop.set_postfix(loss_d=f'{loss_disc:.4f}', loss_g=f'{loss_gen:.4f}')

    return step

In [None]:
# Run TensorBoard

# Delete previous logs dir
if os.path.exists(LOG_DIR):
    !rm -rf $LOG_DIR

# To fix the error, because PyTorch and TensorFlow are installed both:
# AttributeError: module 'tensorflow._api.v2.io.gfile' has no attribute 'get_filesystem'
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

# Load the TensorBoard notebook extension
%load_ext tensorboard

# Start TensorBoard before training to monitor it in progress
%tensorboard --logdir $LOG_DIR

# Reload TensorBoard
%reload_ext tensorboard

In [17]:
main()

HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


=> Saving checkpoint
=> Saving checkpoint
