# DCGAN on Oxford 102 Flower

In this notebook we train a Deep Convolutional GAN (DCGAN) on Oxford 102 Flower dataset. This datasets consists of 102 flower classes, each containing between 40 and 256 instances. We leverage data augmentation to increase the size of the dataset and reduce overfitting.

## 1. Background



Historically, DCGAN represents significant progress in developing more scalable GAN architectures. The authors adopt three changes to the standard CNN architecture and show that these allow for stable training on higher resolution images.

1. Replace deterministic pooling functions (e.g. max pooling) with learned downsampling/upsampling layers (e.g. strided convolutions).

2. Remove dense layers after convolutional stacks.

3. Use batch normalization in order to stabalize training, reduce the influence of poor initialization and improve gradient flow. In the context of GANs, batch normalization has shown to be effective in reducing mode collapse. To prevent perpetual oscilation between the generator and discriminator however, the authors do not apply batch normalization to either the generator output or discriminator input.

![](https://i.ibb.co/0jc930m/Capture.png)

The authors specifically use the following architecture for their generator.

![](https://i.ibb.co/YPhrm08/Capture.png)

## 2. Implementation

### 2.1. Setup

In [1]:
#@markdown

from typing import TypeAlias, Any, Callable, Tuple, List
from dataclasses import dataclass

import os
import sys
import tempfile
import time
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torch.optim import SGD, Adam

from torchvision.transforms import (
    Compose,
    Normalize,
    Resize,
    RandomCrop,
    RandomRotation,
    RandomHorizontalFlip,
    RandomVerticalFlip,
    ColorJitter,
    ToTensor,
    PILToTensor,
    ToPILImage,

)

from torchvision.datasets import Flowers102
from torch.utils.data import DataLoader
from torchvision.utils import save_image

import pandas as pd
import seaborn as sns

In [2]:
#@markdown

EPOCHS = 300
BATCH_SIZE = 128

IMAGE_WIDTH  = 64
IMAGE_HEIGHT = 64
IMAGE_NORMALIZATION_MEAN = 0.5
IMAGE_NORMALIZATION_STANDARD_DEVIATION = 0.5

# FABRIC_RANKS = 8
# FABRIC_ACCELERATOR = 'gpu'
# FABRIC_STRATEGY = 'ddp'

LATENT_DIMENSION = 100

In [3]:
#@markdown

transform = Compose([
    ToTensor(),
    Resize((IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=0),
    Normalize(
        mean=IMAGE_NORMALIZATION_MEAN,
        std=IMAGE_NORMALIZATION_STANDARD_DEVIATION,
    ),
])

train_dataset = Flowers102(
    root='.',
    split='train',
    transform=transform,
    download=True,
)

validation_dataset = Flowers102(
    root='.',
    split='test',
    transform=transform,
    download=True
)

dataset = ConcatDataset((train_dataset, validation_dataset))

Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/102flowers.tgz to flowers-102/102flowers.tgz


100%|██████████| 344862509/344862509 [00:02<00:00, 158644186.50it/s]


Extracting flowers-102/102flowers.tgz to flowers-102
Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/imagelabels.mat to flowers-102/imagelabels.mat


100%|██████████| 502/502 [00:00<00:00, 1471377.08it/s]


Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/setid.mat to flowers-102/setid.mat


100%|██████████| 14989/14989 [00:00<00:00, 19689452.76it/s]


In [4]:
#@markdown

# def setup(rank, world_size):
#     os.environ['MASTER_ADDR'] = 'localhost'
#     os.environ['MASTER_PORT'] = '12355'

#     # initialize the process group
#     dist.init_process_group("gloo", rank=rank, world_size=world_size)

# def cleanup():
#     dist.destroy_process_group()

# setup(0, 1)

In [5]:
#@markdown

# fabric = L.Fabric(
#     accelerator=FABRIC_ACCELERATOR,
#     devices=FABRIC_RANKS,
#     strategy=FABRIC_STRATEGY,
# )

# fabric.launch()

In [6]:
#@markdown

# Using for instance `number_of_crops = 4`, we can obtain a dataset with many
# more images - from 7000 to 100,000.

def renormalize(x: torch.Tensor) -> torch.Tensor:
    return (x - x.min()) / (x.max() - x.min())

def get_augments(x: torch.Tensor, number_of_crops: int = 4) -> List[torch.Tensor]:

    augments = []

    random_vertical_flip = RandomVerticalFlip(p=1)
    random_horizontal_flip = RandomHorizontalFlip(p=1)
    random_crop =  RandomCrop(size=(IMAGE_WIDTH // 2, IMAGE_HEIGHT // 2))
    # random_color_jitter = ColorJitter(hue=.04)
    resize = Resize(size=(IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=0)

    for _ in range(number_of_crops):
        crop = random_crop(x)
        crop = resize(crop)

        augments.append(crop)
        augments.append(random_horizontal_flip(crop))
        augments.append(random_vertical_flip(crop))

    augments.append(random_horizontal_flip(x))
    augments.append(random_vertical_flip(x))

    random.shuffle(augments)

    return augments


print('Creating augmented dataset (wait 1 minute)...')

augmented_dataset = []

for x, _ in dataset:
    augmented_dataset += get_augments(x.cuda())

Creating augmented dataset (wait 1 minute)...


### 2.2. Models

In [7]:
#@markdown

generator = nn.Sequential(

    # Projection from (-1, LATENT_DIMENSION, 1, 1) to (-1, 1024, 4, 4).

    nn.ConvTranspose2d(in_channels=LATENT_DIMENSION, out_channels=1024, kernel_size=4,  stride=2, padding=0),
    nn.ReLU(),
    nn.BatchNorm2d(1024),

    # Upsampling stack.

    nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(512),

    nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(256),

    nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(128),

    nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=4, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(3),

    nn.Tanh(),
).cuda()


discriminator = nn.Sequential(

    # Downsampling stack.

    nn.Conv2d(in_channels=3, out_channels=128, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.2),

    nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2),

    nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(512),
    nn.LeakyReLU(0.2),

    nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(1024),
    nn.LeakyReLU(0.2),

    # Projection from (-1, 1024, 4, 4) to (-1, 1, 1, 1).

    nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=1, padding=0),
    nn.Flatten(),

    nn.Sigmoid(),
).cuda()

### 2.3. Training

In [8]:
#@markdown

!rm -rf samples checkpoints
!mkdir samples checkpoints

In [None]:
#@markdown

#dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)


def get_fake_batch(batch_size: int) -> torch.Tensor:

    z = torch.randn(batch_size, LATENT_DIMENSION).view(-1, LATENT_DIMENSION, 1, 1).cuda()
    x = generator(z)

    return x


number_of_batches = len(augmented_dataset) // BATCH_SIZE
criterion = nn.BCELoss()
generator_optimizer = Adam(generator.parameters(), lr=1e-4)
discriminator_optimizer = Adam(discriminator.parameters(), lr=1e-4)

random.shuffle(augmented_dataset)

discriminator_losses = []
generator_losses = []

for epoch in range(EPOCHS):
    for batch_index in range(0, number_of_batches):
    # for batch_index, (real_batch, _) in enumerate(dataloader):

        # Train the discriminator.

        discriminator_optimizer.zero_grad()

        # if real_batch.size(0) != BATCH_SIZE:
        #     continue

        # real_batch = real_batch.cuda()

        real_batch_slice = slice(batch_index * BATCH_SIZE, (batch_index + 1) * BATCH_SIZE)
        real_batch = torch.cat(augmented_dataset[real_batch_slice]).view(BATCH_SIZE, 3, IMAGE_HEIGHT, IMAGE_WIDTH).cuda()
        real_labels = torch.ones((BATCH_SIZE, 1)).cuda()
        fake_batch = get_fake_batch(BATCH_SIZE).detach()
        fake_labels = torch.zeros((BATCH_SIZE, 1)).cuda()

        real_loss = criterion(discriminator(real_batch), real_labels)
        fake_loss = criterion(discriminator(fake_batch), fake_labels)
        discriminator_loss = real_loss + fake_loss
        discriminator_loss.backward()

        discriminator_optimizer.step()

        # Train the generator.

        generator_optimizer.zero_grad()

        fake_batch = get_fake_batch(BATCH_SIZE)
        real_labels = torch.ones((BATCH_SIZE, 1)).cuda()
        generator_loss = criterion(discriminator(fake_batch), real_labels)
        generator_loss.backward()

        generator_optimizer.step()

        # Record losses.

        discriminator_losses.append(discriminator_loss.item())
        generator_losses.append(generator_loss.item())

        if (batch_index + 1) % 5 == 0:
            mean_discriminator_loss = sum(discriminator_losses) / len(discriminator_losses)
            mean_generator_loss = sum(generator_losses) / len(generator_losses)
            discriminator_losses.clear()
            generator_losses.clear()

            print(f'[INFO] Epoch {epoch}/{EPOCHS}, Batch {batch_index}/{number_of_batches} - generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}')

    print('[INFO] Saving samples...')

    sample_images = renormalize(fake_batch[: 64].detach()).cpu().view(64, 3, IMAGE_HEIGHT, IMAGE_WIDTH)
    save_image(sample_images.data, os.path.join('samples', f'epoch-{epoch}-batch-{batch_index}.png'))

    print('[INFO] Saving checkpoints...')

    torch.save(discriminator, f'checkpoints/epoch-{epoch}-discriminator.pt')
    torch.save(generator, f'checkpoints/epoch-{epoch}-generator.pt')

[INFO] Epoch 0/300, Batch 4/784 - generator loss: 2.1348724603652953, discriminator loss: 0.8206498503684998
[INFO] Epoch 0/300, Batch 9/784 - generator loss: 3.823407793045044, discriminator loss: 0.43898946046829224
[INFO] Epoch 0/300, Batch 14/784 - generator loss: 6.043527984619141, discriminator loss: 0.09605499356985092
[INFO] Epoch 0/300, Batch 19/784 - generator loss: 8.511597061157227, discriminator loss: 0.09670139849185944
[INFO] Epoch 0/300, Batch 24/784 - generator loss: 11.256410217285156, discriminator loss: 0.10804294347763062
[INFO] Epoch 0/300, Batch 29/784 - generator loss: 13.341007614135743, discriminator loss: 0.08272217959165573
[INFO] Epoch 0/300, Batch 34/784 - generator loss: 13.845421028137206, discriminator loss: 0.026516332104802132
[INFO] Epoch 0/300, Batch 39/784 - generator loss: 12.294143104553223, discriminator loss: 0.031355221197009085
[INFO] Epoch 0/300, Batch 44/784 - generator loss: 12.533626365661622, discriminator loss: 0.01875983886420727
[INFO