# 🧱 DCGAN - Bricks Data

This notebook is an **unofficial PyTorch implementation** of the excellent [Keras example](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/04_gan/01_dcgan/dcgan.ipynb) for Deep Convolutional GAN, originally created by David Foster as part of the companion code for the excellent book [Generative Deep Learning, 2nd Edition](https://www.oreilly.com/library/view/generative-deep-learning/9781098134174/).

_The original code is available [here](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition) and is licensed under the Apache License 2.0._
_This implementation is distributed under the Apache License 2.0. See the LICENSE file for details._

In this notebook, we'll walk through the steps required to train your own DCGAN on the bricks dataset using PyTorch.

In [None]:
# Automatically reload edited modules (e.g. notebooks.utils) before running cells.
%load_ext autoreload
%autoreload 2

import os

# Get the working directory and the current notebook directory.
# NOTE(review): both paths are built from os.getcwd(), so this presumably
# assumes the kernel was started from the repository root - confirm.
# exp_dir is the root for this notebook's logs, samples and checkpoints.
working_dir = os.getcwd()
exp_dir = os.path.join(working_dir, "notebooks/04_gan/01_dcgan/")

In [None]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset
from torch import nn
import torch.nn.functional as F
from torchinfo import summary
from torch import optim
import torch
from torch.utils.tensorboard import SummaryWriter
from matplotlib import pyplot as plt

from notebooks.utils import display

import math

In [None]:
# --- data -------------------------------------------------------------
IMAGE_SIZE = 64          # images are resized to IMAGE_SIZE x IMAGE_SIZE
CHANNELS = 1             # single (grayscale) channel
BATCH_SIZE = 128

# --- model ------------------------------------------------------------
Z_DIM = 100              # length of the latent vector fed to the generator

# --- training ---------------------------------------------------------
EPOCHS = 300
LOAD_MODEL = False       # resume from a saved checkpoint before training
LEARNING_RATE = 2e-4
ADAM_BETA_1 = 0.5
ADAM_BETA_2 = 0.999
NOISE_PARAM = 0.1        # max magnitude of the noise added to D's targets

## 1. Prepare the data <a name="prepare"></a>

In [None]:
# Raw images are read from <working_dir>/data/lego-brick-images.
data_dir = f"{working_dir}/data"
dataset_dir = f"{data_dir}/lego-brick-images"

In [None]:
# Preprocessing: resize, collapse to one channel, then map ToTensor's [0, 1]
# pixels to [-1, 1] via (x - 0.5) / 0.5 so they match the generator's tanh range.
preprocess_steps = [
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5),
]
transform = transforms.Compose(preprocess_steps)

full_data = datasets.ImageFolder(root=dataset_dir, transform=transform)


In [None]:
# Report how many images were found and which sub-folders became classes.
print(f"full dataset size =  {len(full_data)}")
print(f"full dataset labels =  {full_data.classes}")

In [None]:
# We only want the images that ImageFolder filed under the "dataset" sub-folder.
required_class = "dataset"
req_class_idx = full_data.class_to_idx[required_class]
# ImageFolder already records every sample's class index in `.targets`; reading
# it avoids loading and transforming every image just to inspect its label.
req_idxs = [i for i, label in enumerate(full_data.targets) if label == req_class_idx]
print("size of required indces = ", len(req_idxs))


In [None]:
# Restrict training to the selected indices.
train_data = Subset(dataset=full_data, indices=req_idxs)
print(f"size of training dataset =  {len(train_data)}")

In [None]:
# Wrap the training subset in a DataLoader that reshuffles every epoch and
# loads batches with two worker processes.
train_data_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Pull one batch (labels discarded) to inspect the preprocessed inputs.
train_iter = iter(train_data_loader)
first_batch = next(train_iter)
sample_images = first_batch[0]

In [None]:
# Show the sampled batch; pixel values are still normalized to [-1, 1] here.
display(sample_images)

## 2. Build the GAN <a name="build"></a>

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator: maps an image batch to per-image "real" probabilities.

    Four 4x4 stride-2 convolutions (64 -> 128 -> 256 -> 512 channels) each
    halve the spatial size; a final 4x4 stride-1 convolution without padding
    reduces the remaining map to one sigmoid score per image.

    Args:
        channels: Number of channels of the input images.
        image_size: Side length of the square input images. The final reshape
            to ``(B, 1)`` only works when ``conv_5`` sees a 4x4 map, i.e. for
            ``image_size == 64``.
    """

    def __init__(self, channels, image_size):
        super().__init__()
        self.channels = channels
        self.image_size = image_size

        self.dropout = nn.Dropout(0.3)

        # NOTE: PyTorch's BatchNorm `momentum` is the weight of the *new* batch
        # statistic - the complement of Keras' convention. Keras momentum=0.9
        # therefore corresponds to momentum=0.1 here.

        # layer group 1 (no batch norm on the first block)
        p = self._get_padding_size(input_w=self.image_size, stride=2, kernal_size=4)
        self.conv_1 = nn.Conv2d(in_channels=self.channels, out_channels=64,
                                kernel_size=4, stride=2, padding=p, bias=False)

        # layer group 2
        p = self._get_padding_size(input_w=self.image_size / 2, stride=2, kernal_size=4)
        self.conv_2 = nn.Conv2d(in_channels=64, out_channels=128,
                                kernel_size=4, stride=2, padding=p, bias=False)
        self.bn_2 = nn.BatchNorm2d(num_features=128, momentum=0.1)

        # layer group 3
        p = self._get_padding_size(input_w=self.image_size / 4, stride=2, kernal_size=4)
        self.conv_3 = nn.Conv2d(in_channels=128, out_channels=256,
                                kernel_size=4, stride=2, padding=p, bias=False)
        self.bn_3 = nn.BatchNorm2d(num_features=256, momentum=0.1)

        # layer group 4
        p = self._get_padding_size(input_w=self.image_size / 8, stride=2, kernal_size=4)
        self.conv_4 = nn.Conv2d(in_channels=256, out_channels=512,
                                kernel_size=4, stride=2, padding=p, bias=False)
        self.bn_4 = nn.BatchNorm2d(num_features=512, momentum=0.1)

        # output: 4x4 kernel, no padding -> 1x1 map for a 4x4 input
        self.conv_5 = nn.Conv2d(in_channels=512, out_channels=1,
                                kernel_size=4, stride=1, padding=0, bias=False)

    @staticmethod
    def _get_padding_size(input_w, stride, kernal_size):
        """Return the padding that makes a conv's output size ``input_w / 2``.

        Solves ``(input_w + 2p - k) / stride + 1 == input_w / 2`` for ``p`` and
        rounds up to an integer.
        """
        p = ((input_w / 2) - 1) * stride
        p = (p - input_w) + kernal_size
        return math.ceil(p / 2)

    def forward(self, x):
        """Score a batch ``x`` of shape ``(B, channels, H, W)``; returns ``(B, 1)``."""
        batch_size = x.shape[0]

        x = self.dropout(F.leaky_relu(self.conv_1(x), 0.2))
        for conv, bn in ((self.conv_2, self.bn_2),
                         (self.conv_3, self.bn_3),
                         (self.conv_4, self.bn_4)):
            x = self.dropout(F.leaky_relu(bn(conv(x)), 0.2))

        x = torch.sigmoid(self.conv_5(x))
        return x.view((batch_size, 1))


In [None]:
# Use the GPU when available; the model itself is moved to the device later.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
discriminator = Discriminator(CHANNELS, IMAGE_SIZE)

# Print the module itself to list its layers (`state_dict` is a method and
# would need to be called to show the weights).
print(discriminator)

In [None]:
# Layer-by-layer summary for a single dummy input image.
summary(discriminator, input_size=(1, CHANNELS, IMAGE_SIZE, IMAGE_SIZE))

In [None]:
class Generator(nn.Module):
    """DCGAN generator: maps latent vectors to images with values in [-1, 1].

    A 4x4 stride-1 transposed convolution expands the ``(B, num_dim, 1, 1)``
    input to a 4x4 map with 512 channels; four 4x4 stride-2 transposed
    convolutions then double the spatial size each time (4 -> 8 -> 16 -> 32
    -> 64) while reducing channels (512 -> 256 -> 128 -> 64 -> ``channels``).

    Args:
        num_dim: Length of the latent vector.
        channels: Number of channels of the generated images.
    """

    def __init__(self, num_dim, channels):
        super().__init__()

        self.num_dim = num_dim
        self.channels = channels

        # NOTE: PyTorch's BatchNorm `momentum` is the weight of the *new* batch
        # statistic - the complement of Keras' convention. Keras momentum=0.9
        # therefore corresponds to momentum=0.1 here.

        self.conv_trans_1 = nn.ConvTranspose2d(in_channels=self.num_dim, out_channels=512,
                                               kernel_size=4, stride=1, padding=0, bias=False, output_padding=0)
        self.bn_1 = nn.BatchNorm2d(num_features=512, momentum=0.1)

        p = self._get_padding_size(input_w=4, stride=2, kernal_size=4)
        self.conv_trans_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256,
                                               kernel_size=4, stride=2, padding=p, output_padding=0, bias=False)
        self.bn_2 = nn.BatchNorm2d(num_features=256, momentum=0.1)

        p = self._get_padding_size(input_w=4 * 2, stride=2, kernal_size=4)
        self.conv_trans_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128,
                                               kernel_size=4, stride=2, padding=p, output_padding=0, bias=False)
        self.bn_3 = nn.BatchNorm2d(num_features=128, momentum=0.1)

        p = self._get_padding_size(input_w=4 * 4, stride=2, kernal_size=4)
        self.conv_trans_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64,
                                               kernel_size=4, stride=2, padding=p, output_padding=0, bias=False)
        self.bn_4 = nn.BatchNorm2d(num_features=64, momentum=0.1)

        p = self._get_padding_size(input_w=4 * 8, stride=2, kernal_size=4)
        self.conv_trans_5 = nn.ConvTranspose2d(in_channels=64, out_channels=self.channels,
                                               kernel_size=4, stride=2, padding=p, output_padding=0, bias=False)

    @staticmethod
    def _get_padding_size(input_w, stride, kernal_size):
        """Return the padding that makes a transposed conv output ``2 * input_w``.

        Solves ``(input_w - 1) * stride - 2p + k == 2 * input_w`` for ``p``,
        rounding half up to an integer.
        """
        p = ((input_w - 1) * stride) / 2
        p = p - input_w
        p = p + (kernal_size / 2)
        p = p + 1 / 2
        return math.floor(p)

    def forward(self, x):
        """Generate images from latent vectors ``x`` of shape ``(B, num_dim)``."""
        batch_size = x.shape[0]
        # Treat each latent vector as a 1x1 feature map with num_dim channels.
        x = x.view((batch_size, self.num_dim, 1, 1))

        for conv, bn in ((self.conv_trans_1, self.bn_1),
                         (self.conv_trans_2, self.bn_2),
                         (self.conv_trans_3, self.bn_3),
                         (self.conv_trans_4, self.bn_4)):
            x = F.leaky_relu(bn(conv(x)), 0.2)

        # tanh keeps outputs in [-1, 1], matching the normalized training images.
        return torch.tanh(self.conv_trans_5(x))


In [None]:
generator = Generator(Z_DIM, CHANNELS)
# Print the module itself to list its layers (`state_dict` is a method and
# would need to be called to show the weights).
print(generator)

In [None]:
# Layer-by-layer summary for a single latent vector.
summary(generator, input_size=(1, Z_DIM))

DCGAN class

In [None]:
class DCGAN(nn.Module):
    """Pair a Generator and a Discriminator and train them adversarially.

    Args:
        num_dim: Length of the latent noise vector fed to the generator.
        channels: Number of image channels.
        image_size: Side length of the square images.
        log_dir: TensorBoard log root; scalars are written to ``<log_dir>/train``.

    The optimizers, loss functions, label-noise level and device are passed to
    ``fit``, which stores them on the instance for ``train_step``.
    """

    def __init__(self, num_dim, channels, image_size, log_dir="./log/"):
        super().__init__()
        self.num_dim = num_dim
        self.channels = channels
        self.image_size = image_size

        self.generator = Generator(num_dim, channels)
        self.discriminator = Discriminator(channels, image_size)

        self.writer_train = SummaryWriter(log_dir + "/train")

    # Not used during training; it exists only so torchinfo's summary can trace
    # the generator -> discriminator pipeline from a latent vector.
    def forward(self, x):
        x = self.generator(x)
        x = self.discriminator(x)
        return x

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: A batch of real images; moved to ``self.device``.

        Returns:
            dict: Per-batch losses as Python floats under the keys
            ``"d_loss_fake"``, ``"d_loss_real"``, ``"d_loss"`` and ``"g_loss"``.
        """
        self.generator.train()
        self.discriminator.train()

        batch_size = real_images.shape[0]
        real_images = real_images.to(self.device)
        input_noise = torch.randn((batch_size, self.num_dim), device=self.device)
        fake_images = self.generator(input_noise)

        # --- discriminator update ---------------------------------------
        self.d_optimizer.zero_grad()
        # Detach the fakes so the discriminator loss cannot leak gradients
        # into the generator's parameters.
        fake_pred = self.discriminator(fake_images.detach())
        real_pred = self.discriminator(real_images)

        # Noisy targets: fakes in [0, noise), reals in (1 - noise, 1], which
        # keeps the BCE targets inside [0, 1].
        fake_labels = self.label_noise * torch.rand_like(fake_pred)
        real_labels = 1.0 - self.label_noise * torch.rand_like(real_pred)

        d_loss_fake = self.d_loss_fn(fake_pred, fake_labels)
        d_loss_real = self.d_loss_fn(real_pred, real_labels)
        d_loss = (d_loss_fake + d_loss_real) / 2
        d_loss.backward()
        self.d_optimizer.step()

        # --- generator update -------------------------------------------
        self.g_optimizer.zero_grad()
        # Re-score the same (non-detached) fakes with the freshly updated
        # discriminator so gradients flow back into the generator. The
        # gradients this also leaves on the discriminator are cleared by
        # d_optimizer.zero_grad() at the start of the next step.
        fake_pred = self.discriminator(fake_images)
        # The generator wants its fakes classified as real: clean target of 1.
        g_loss = self.g_loss_fn(fake_pred, torch.ones_like(fake_pred))
        g_loss.backward()
        self.g_optimizer.step()

        return {
            "d_loss_fake": d_loss_fake.item(),
            "d_loss_real": d_loss_real.item(),
            "d_loss": d_loss.item(),
            "g_loss": g_loss.item(),
        }

    def fit(self, training_dataloader, epochs, g_optimizer, d_optimizer,
            d_loss_fn, g_loss_fn, device, labels_noise=0.1, callbacks=None):
        """Train both networks for ``epochs`` epochs.

        Args:
            training_dataloader: Iterable of ``(images, labels)`` batches; the
                labels are ignored.
            epochs: Number of passes over ``training_dataloader``.
            g_optimizer: Optimizer over the generator's parameters.
            d_optimizer: Optimizer over the discriminator's parameters.
            d_loss_fn: Discriminator loss, called as ``fn(prediction, target)``.
            g_loss_fn: Generator loss, called as ``fn(prediction, target)``.
            device: Device that batches and latent noise are placed on.
            labels_noise: Maximum magnitude of the uniform noise applied to the
                discriminator's targets.
            callbacks: Optional iterable of objects exposing
                ``on_epoch_end(epoch, logs=...)``.

        Each epoch's losses are averaged over batches, printed, and logged to
        TensorBoard.
        """
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn
        self.label_noise = labels_noise
        self.device = device

        for epoch in range(1, epochs + 1):
            losses = {"d_fake_loss_acc": 0,
                      "d_real_loss_acc": 0,
                      "d_loss_acc": 0,
                      "g_loss_acc": 0}
            num_batches = 0

            for images, _ in training_dataloader:
                loss_dict = self.train_step(images)
                losses["d_fake_loss_acc"] += loss_dict["d_loss_fake"]
                losses["d_real_loss_acc"] += loss_dict["d_loss_real"]
                losses["d_loss_acc"] += loss_dict["d_loss"]
                losses["g_loss_acc"] += loss_dict["g_loss"]
                num_batches += 1

            # train_step reports per-batch *mean* losses, so the epoch value is
            # their average over batches (guarded against an empty loader).
            for key in losses:
                losses[key] /= max(num_batches, 1)

            print(
                f"Epoch {epoch}/{epochs}: Training: d_fake_loss: {losses['d_fake_loss_acc'] :.4f} "
                f" d_real_loss: {losses['d_real_loss_acc']:.4f} "
                f" d_loss: {losses['d_loss_acc']:.4f}"
                f" g_loss: {losses['g_loss_acc']:.4f}"
            )

            self.writer_train.add_scalar("d_fake_loss", losses["d_fake_loss_acc"], global_step=epoch)
            self.writer_train.add_scalar("d_real_loss", losses["d_real_loss_acc"], global_step=epoch)
            self.writer_train.add_scalar("d_loss", losses["d_loss_acc"], global_step=epoch)
            self.writer_train.add_scalar("g_loss", losses["g_loss_acc"], global_step=epoch)

            if callbacks is not None:
                logs = {"device": self.device,
                        "generator": self.generator,
                        "model_state_dict": self.state_dict(),
                        "loss": losses
                        }

                for callback in callbacks:
                    callback.on_epoch_end(epoch, logs=logs)

        # Make sure all logged scalars reach disk once training finishes.
        self.writer_train.flush()


Create the required callbacks.

In [None]:
class Callback:
    """Base class for training hooks; subclasses override what they need."""

    def on_epoch_end(self, epoch, logs=None):
        """Hook run after every epoch; the default implementation does nothing."""
        return None

In [None]:
class GenerateImages(Callback):
    """Callback that samples the generator at the end of every epoch.

    Args:
        num_images: Number of images generated per epoch.
        latent_dim: Length of each latent noise vector.
        save_dir: Directory the sample grid is saved to as ``epoch_<epoch>.png``.
    """

    def __init__(self, num_images, latent_dim, save_dir="./gen_examples"):
        super().__init__()
        self.num_images = num_images
        self.latent_dim = latent_dim
        self.save_dir = save_dir

    def on_epoch_end(self, epoch, logs=None):
        """Generate fresh samples and pass them to ``display`` for saving.

        Reads ``logs["device"]`` and ``logs["generator"]``. The generator is
        switched to eval mode and left that way.
        """
        target_device = logs["device"]
        model = logs["generator"]

        latent = torch.randn((self.num_images, self.latent_dim)).to(target_device)
        out_path = self.save_dir + f"/epoch_{epoch}.png"

        with torch.no_grad():
            model.eval()
            # The generator ends in tanh ([-1, 1]); map linearly onto [0, 255].
            samples = model(latent).detach() * 127.5 + 127.5
            display(samples, save_to=out_path)


In [None]:
class SaveCheckpoint(Callback):
    """Callback that periodically writes the model weights with ``torch.save``.

    Args:
        save_dir: Directory the ``checkpoint_<epoch>.pth`` files go into.
        save_every: Save only on epochs that are a multiple of this value.
    """

    def __init__(self, save_dir, save_every=10):
        super().__init__()
        self.save_dir = save_dir
        self.save_every = save_every

    def on_epoch_end(self, epoch, logs=None):
        """Save epoch number, ``logs["model_state_dict"]`` and ``logs["loss"]``."""
        if epoch % self.save_every != 0:
            return
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": logs["model_state_dict"],
                "loss": logs["loss"],
            },
            self.save_dir + f"/checkpoint_{epoch}.pth",
        )

Create the DCGAN object and train it

In [None]:
# Output locations for TensorBoard logs, per-epoch samples and checkpoints.
log_dir, sample_dir, checkpoint_dir = (
    exp_dir + "/log",
    exp_dir + "/sample_gen",
    exp_dir + "/checkpoints",
)
for out_dir in (log_dir, sample_dir, checkpoint_dir):
    os.makedirs(out_dir, exist_ok=True)

In [None]:
dcgan = DCGAN(Z_DIM, CHANNELS, IMAGE_SIZE, log_dir).to(device)
# Print the module itself to list its layers (`state_dict` is a method and
# would need to be called to show the weights).
print(dcgan)

In [None]:
# Summary of the full generator -> discriminator pipeline for one latent vector.
summary(dcgan, input_size=(1, Z_DIM))

In [None]:
# Both networks use binary cross-entropy on the discriminator's sigmoid output.
g_loss_function = nn.BCELoss()
d_loss_function = nn.BCELoss()

# Separate Adam optimizers so each update only touches its own network.
# `betas` is documented as a (beta1, beta2) tuple.
adam_betas = (ADAM_BETA_1, ADAM_BETA_2)
g_optimizer = optim.Adam(dcgan.generator.parameters(), lr=LEARNING_RATE, betas=adam_betas)
d_optimizer = optim.Adam(dcgan.discriminator.parameters(), lr=LEARNING_RATE, betas=adam_betas)

In [None]:
# Sample 10 images after every epoch and write a checkpoint every 30 epochs.
image_sampler = GenerateImages(10, Z_DIM, save_dir=sample_dir)
checkpointer = SaveCheckpoint(save_dir=checkpoint_dir, save_every=30)
callbacks = [image_sampler, checkpointer]

In [None]:
# Optionally resume from a checkpoint written by SaveCheckpoint.
# NOTE: only the model weights are restored - optimizer state is not saved in
# the checkpoint, and fit() will still count epochs from 1.
if LOAD_MODEL:
    checkpoint_file = checkpoint_dir + "/checkpoint_270.pth"
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # machine (and vice versa).
    checkpoint = torch.load(checkpoint_file, map_location=device)
    dcgan.load_state_dict(checkpoint["model_state_dict"])

In [None]:
# Train for EPOCHS epochs; the callbacks run at the end of every epoch.
dcgan.fit(
    train_data_loader,
    epochs=EPOCHS,
    g_optimizer=g_optimizer,
    d_optimizer=d_optimizer,
    d_loss_fn=d_loss_function,
    g_loss_fn=g_loss_function,
    device=device,
    labels_noise=NOISE_PARAM,
    callbacks=callbacks,
)

## 3. Generate new images <a name="decode"></a>

In [None]:
# Sample some points in the latent space, from the standard normal distribution
grid_width, grid_height = (10, 3)
z_sample = torch.randn((grid_width * grid_height, Z_DIM))

In [None]:
# Decode the latent vectors without tracking gradients, then convert to
# channels-last NumPy arrays for matplotlib.
dcgan.eval()
with torch.no_grad():
    reconstructions = dcgan.generator(z_sample.to(device))

reconstructions_np = reconstructions.cpu().permute(0, 2, 3, 1).numpy()

In [None]:
# Draw a plot of decoded images
fig = plt.figure(figsize=(18, 5))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

# Output the grid of faces
for i in range(grid_width * grid_height):
    ax = fig.add_subplot(grid_height, grid_width, i + 1)
    ax.axis("off")
    ax.imshow(reconstructions_np[i, :, :], cmap="Greys")

In [None]:
def compare_images(img1, img2):
    """Return the mean absolute element-wise difference of two tensors (0-d tensor)."""
    return (img1 - img2).abs().mean()

In [None]:
# Split every batch into individual image tensors for the nearest-neighbour
# search further down.
all_data = [image for images, _ in train_data_loader for image in images]



In [None]:
r, c = 3, 5
fig, axs = plt.subplots(r, c, figsize=(10, 6))
fig.suptitle("Generated images", fontsize=20)

noise = torch.randn((r * c, Z_DIM))

# Generate r * c images in eval mode without tracking gradients.
dcgan.eval()
with torch.no_grad():
    gen_imgs = dcgan.generator(noise.to(device))

# Channels-last NumPy copies for matplotlib, drawn in row-major order.
gen_imgs_np = gen_imgs.cpu().permute(0, 2, 3, 1).numpy()
for ax, img in zip(axs.flat, gen_imgs_np):
    ax.imshow(img, cmap="gray_r")
    ax.axis("off")

plt.show()

In [None]:
fig, axs = plt.subplots(r, c, figsize=(10, 6))
fig.suptitle("Closest images in the training set", fontsize=20)


def closest_training_image(query, candidates, chunk_size=1024):
    """Return the candidate with the smallest mean absolute difference to ``query``.

    Candidates are compared in stacked chunks instead of one at a time, which
    avoids a Python-level loop per image while keeping peak memory bounded.
    Ties resolve to the earliest candidate, as in a sequential scan.
    Returns ``None`` if ``candidates`` is empty.
    """
    best_diff, best_img = float("inf"), None
    for start in range(0, len(candidates), chunk_size):
        chunk = torch.stack(candidates[start:start + chunk_size])
        # Per-image mean |difference|, same metric as compare_images.
        diffs = (chunk - query).abs().flatten(start_dim=1).mean(dim=1)
        idx = int(torch.argmin(diffs))
        diff = float(diffs[idx])
        if diff < best_diff:
            best_diff, best_img = diff, candidates[start + idx]
    return best_img


# Move the generated images to the CPU once, not once per comparison.
gen_imgs_cpu = gen_imgs.to("cpu")
for ax, gen_img in zip(axs.flat, gen_imgs_cpu):
    closest = closest_training_image(gen_img, all_data)
    ax.imshow(closest.permute(1, 2, 0).numpy(), cmap="gray_r")
    ax.axis("off")

plt.show()