# 🧱 DCGAN - Bricks Data

In this notebook, we'll walk through the steps required to train your own DCGAN on the LEGO bricks dataset.

In [1]:
%load_ext autoreload
%autoreload 2

import os, sys, glob, math
from PIL import Image
from tqdm import tqdm, trange

from dotenv import load_dotenv
# Pull PYTHONPATH / DATA_PATH from a local .env file, then make every PYTHONPATH
# entry importable from this kernel (the notebook's working directory may differ).
load_dotenv()
python_path = os.getenv('PYTHONPATH')
data_path = os.getenv('DATA_PATH')
for path in (python_path.split(os.pathsep) if python_path else []):
    if path in sys.path:
        continue
    sys.path.append(path)


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchmetrics import MeanMetric
from torchmetrics.classification import BinaryAccuracy


from notebooks.pt_utils import *

## 0. Parameters <a name="parameters"></a>

In [2]:
IMAGE_SIZE = 64          # images are resized to IMAGE_SIZE x IMAGE_SIZE; the DCGAN layer stack is built for 64
CHANNELS = 1             # grayscale input
BATCH_SIZE = 128
Z_DIM = 100              # length of the generator's latent vector
EPOCHS = 300
LOAD_MODEL = False       # when True, restore saved weights before training
ADAM_BETA_1 = 0.5
ADAM_BETA_2 = 0.999
LEARNING_RATE = 0.0002
NOISE_PARAM = 0.1        # max magnitude of the uniform noise added to the discriminator's target labels
# Never request more loader workers than this machine has cores.
NUM_WORKERS = min(24, os.cpu_count() or 1)
SEED = 42                # fixed seed so shuffling, latent samples and weight init are reproducible

torch.manual_seed(SEED)

## 1. Prepare the data <a name="prepare"></a>

In [3]:
class LegoBricksDataset(Dataset):
    """Unlabelled image dataset over every ``*.png`` file directly inside ``root``.

    Each item is the PIL image for one file, passed through ``transform`` when one is given.
    There are no labels: ``__getitem__`` returns the image only.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        # sorted() makes the index -> file mapping deterministic; raw glob order is filesystem-dependent.
        self.image_files = sorted(glob.glob(os.path.join(self.root, '*.png')))

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        # Image.open is lazy; read the pixels into a standalone copy and close the file
        # right away so no descriptor is left open (e.g. when transform is None).
        with Image.open(self.image_files[index]) as img:
            image = img.copy()

        if self.transform:
            image = self.transform(image)

        return image

In [4]:
train_dataset = LegoBricksDataset(
    root=os.path.join(data_path, 'lego-brick-images/dataset'),
    transform=transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Grayscale(),
        transforms.ToTensor(),
        # Map ToTensor's [0, 1] pixels to [-1, 1], the range of the generator's Tanh output.
        transforms.Normalize(0.5, 0.5),
    ]),
)
# Pin host memory only when a CUDA device exists. Hard-coding pin_memory_device='cuda'
# bypasses DataLoader's CUDA-availability check and breaks CPU-only runs.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Fetch one shuffled batch to eyeball the preprocessing (resize -> grayscale -> [-1, 1]).
# NOTE(review): `display` is presumably the image helper from notebooks.pt_utils (it is
# called with save_to= elsewhere in this notebook) — confirm it renders a whole batch sensibly.
train_sample = next(iter(train_loader))
display(train_sample)

## 2. Build the GAN <a name="build"></a>

In [6]:
def generate_sample_images(epoch, generator, device, latent_dim=Z_DIM, num_images=10, output_dir='./output'):
    """Sample ``num_images`` images from ``generator`` and display/save them.

    Args:
        epoch: epoch index, used in the file name ``generated_img_{epoch:03d}.png``.
        generator: module mapping (num_images, latent_dim) latent vectors to images.
        device: device the generator lives on; latent vectors are created there.
        latent_dim: latent vector length (defaults to the notebook's Z_DIM).
        num_images: how many images to sample.
        output_dir: directory the image file is written to; created if missing.

    The generator's train/eval mode is left to the caller.
    """
    # display(save_to=...) writes into output_dir, which may not exist on a fresh checkout.
    os.makedirs(output_dir, exist_ok=True)
    random_latent_vectors = torch.randn(num_images, latent_dim, device=device)
    # no_grad so the images carry no autograd graph and can be moved/converted for display.
    with torch.no_grad():
        images = generator(random_latent_vectors)
    display(images.cpu(), save_to=os.path.join(output_dir, f'generated_img_{epoch:03d}.png'))

In [7]:
class DCGAN(nn.Module):
    """Deep convolutional GAN (generator + discriminator) with its own adversarial training loop.

    Args:
        img_size: (channels, height, width) of the training images. Only the channel count is
            used; the layer stacks below are hard-wired for 64x64 images (four stride-2 stages).
        latent_dim: length of the latent vector fed to the generator.
        bn_momentum: BatchNorm momentum in PyTorch's convention,
            ``running = (1 - momentum) * running + momentum * batch``. That is the opposite of
            Keras' convention, so the default 0.1 corresponds to Keras' ``momentum=0.9``.
    """

    def __init__(self, img_size, latent_dim, bn_momentum=0.1):

        super().__init__()

        c, _, _ = img_size
        self.latent_dim = latent_dim

        class Discriminator(nn.Module):
            """(N, c, 64, 64) images -> (N, 1) probabilities of being real."""

            def __init__(self):
                super().__init__()

                self.seq = nn.Sequential(
                    # 64x64 -> 32x32
                    nn.Conv2d(c, 64, 4, stride=2, padding=1, bias=False),
                    nn.LeakyReLU(0.2),
                    nn.Dropout2d(0.3),

                    # 32x32 -> 16x16
                    nn.Conv2d(64, 128, 4, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(num_features=128, momentum=bn_momentum),
                    nn.LeakyReLU(0.2),
                    nn.Dropout2d(0.3),

                    # 16x16 -> 8x8
                    nn.Conv2d(128, 256, 4, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(num_features=256, momentum=bn_momentum),
                    nn.LeakyReLU(0.2),
                    nn.Dropout2d(0.3),

                    # 8x8 -> 4x4
                    nn.Conv2d(256, 512, 4, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(num_features=512, momentum=bn_momentum),
                    nn.LeakyReLU(0.2),
                    nn.Dropout2d(0.3),

                    # 4x4 -> 1x1 score, squashed to a probability for BCELoss
                    nn.Conv2d(512, 1, 4, stride=1, padding=0, bias=False),
                    nn.Sigmoid(),

                    nn.Flatten(),
                )

            def forward(self, x):
                return self.seq(x)

        class Generator(nn.Module):
            """(N, latent_dim) latent vectors -> (N, c, 64, 64) images in [-1, 1]."""

            def __init__(self):
                super().__init__()

                self.seq = nn.Sequential(
                    # 1x1 -> 4x4
                    nn.ConvTranspose2d(latent_dim, 512, 4, stride=1, padding=0, bias=False),
                    nn.BatchNorm2d(num_features=512, momentum=bn_momentum),
                    nn.LeakyReLU(0.2),

                    # 4x4 -> 8x8
                    nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(num_features=256, momentum=bn_momentum),
                    nn.LeakyReLU(0.2),

                    # 8x8 -> 16x16
                    nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(num_features=128, momentum=bn_momentum),
                    nn.LeakyReLU(0.2),

                    # 16x16 -> 32x32
                    nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(num_features=64, momentum=bn_momentum),
                    nn.LeakyReLU(0.2),

                    # 32x32 -> 64x64, Tanh matches the [-1, 1] normalisation of the data
                    nn.ConvTranspose2d(64, c, 4, stride=2, padding=1, bias=False),
                    nn.Tanh(),
                )

            def forward(self, x):
                # (N, latent_dim) -> (N, latent_dim, 1, 1) so the first transposed conv sees a 1x1 map.
                x = x.reshape(*x.size(), 1, 1)
                return self.seq(x)

        # Attribute and layer names are kept stable so saved state_dicts still load.
        self.generator = Generator()
        self.discriminator = Discriminator()

    def forward(self, x):
        """Unused: sampling goes through ``self.generator`` and training through ``fit``."""
        return None

    def fit(self, loader, epochs, loss_fn, g_optimizer, d_optimizer, device):
        """Train the GAN adversarially, rendering sample images after every epoch.

        Each batch does one discriminator update (real images vs. detached fakes, with
        NOISE_PARAM-jittered targets), then one generator update (fakes scored against
        clean "real" labels by the freshly updated discriminator).

        Args:
            loader: DataLoader yielding batches of real images (no labels).
            epochs: number of passes over ``loader``.
            loss_fn: loss applied to the discriminator's sigmoid outputs (BCELoss in this notebook).
            g_optimizer: optimizer over the generator's parameters.
            d_optimizer: optimizer over the discriminator's parameters.
            device: device the models live on; batches and latent vectors go there.
        """
        g_loss_metric = MeanMetric().to(device)
        d_loss_metric = MeanMetric().to(device)
        g_acc_metric = BinaryAccuracy().to(device)
        d_acc_metric = BinaryAccuracy().to(device)
        d_acc_real_metric = BinaryAccuracy().to(device)
        d_acc_fake_metric = BinaryAccuracy().to(device)
        all_metrics = (g_loss_metric, d_loss_metric, g_acc_metric,
                       d_acc_metric, d_acc_real_metric, d_acc_fake_metric)

        for epoch in range(epochs):
            # Report this epoch only; without a reset the metrics average over all epochs so far.
            for metric in all_metrics:
                metric.reset()

            self.generator.train()
            self.discriminator.train()

            # len(loader) is the real batch count, whatever batch size the loader was built with.
            with trange(len(loader), desc=f'Epoch {epoch + 1}', unit='batch', leave=True) as pbar:

                for real_images in loader:
                    real_images = real_images.to(device)
                    random_latent_vectors = torch.randn(real_images.size(0), self.latent_dim, device=device)
                    fake_images = self.generator(random_latent_vectors)

                    # --- Discriminator step: real -> ~1, fake -> ~0 ---
                    d_optimizer.zero_grad()

                    real_preds = self.discriminator(real_images)
                    real_labels = torch.ones_like(real_preds)
                    real_labels_noisy = real_labels - NOISE_PARAM * torch.rand_like(real_preds)
                    d_real_loss = loss_fn(real_preds, real_labels_noisy)
                    d_real_loss.backward()

                    # detach() keeps the discriminator loss from flowing back into the generator.
                    d_fake_preds = self.discriminator(fake_images.detach())
                    fake_labels = torch.zeros_like(d_fake_preds)
                    fake_labels_noisy = fake_labels + NOISE_PARAM * torch.rand_like(fake_labels)
                    d_fake_loss = loss_fn(d_fake_preds, fake_labels_noisy)
                    d_fake_loss.backward()

                    d_optimizer.step()

                    # --- Generator step: make the updated discriminator call the fakes real ---
                    g_optimizer.zero_grad()

                    g_fake_preds = self.discriminator(fake_images)
                    g_loss = loss_fn(g_fake_preds, real_labels)
                    g_loss.backward()

                    g_optimizer.step()

                    # --- Metrics (detached so metric state never holds an autograd graph) ---
                    real_preds = real_preds.detach()
                    d_fake_preds = d_fake_preds.detach()
                    g_fake_preds = g_fake_preds.detach()

                    g_loss_metric.update(g_loss.detach())
                    g_acc_metric.update(g_fake_preds, real_labels)

                    d_loss_metric.update(((d_real_loss + d_fake_loss) / 2.0).detach())
                    # Discriminator accuracy uses the same predictions its loss was computed on.
                    d_acc_metric.update(real_preds, real_labels)
                    d_acc_metric.update(d_fake_preds, fake_labels)
                    d_acc_real_metric.update(real_preds, real_labels)
                    d_acc_fake_metric.update(d_fake_preds, fake_labels)

                    d_loss_avg = d_loss_metric.compute()
                    d_acc = d_acc_metric.compute()
                    d_acc_real = d_acc_real_metric.compute()
                    d_acc_fake = d_acc_fake_metric.compute()
                    g_loss_avg = g_loss_metric.compute()
                    g_acc = g_acc_metric.compute()

                    postfix_str = f'd_loss: {d_loss_avg:0.4f}, d_acc: {d_acc:0.4f}, d_acc_real: {d_acc_real:0.4f}, d_acc_fake: {d_acc_fake:0.4f}, g_loss: {g_loss_avg:0.4f}, g_acc: {g_acc:0.4f}'
                    pbar.set_postfix_str(postfix_str)
                    pbar.update()

            self.generator.eval()
            with torch.no_grad():
                generate_sample_images(
                    epoch=epoch,
                    generator=self.generator,
                    device=device,
                    latent_dim=self.latent_dim,
                    num_images=10
                )
        



## 3. Train the GAN <a name="train"></a>

In [8]:
# Build the DCGAN on the GPU when CUDA is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
dcgan = DCGAN(latent_dim=Z_DIM, img_size=(CHANNELS, IMAGE_SIZE, IMAGE_SIZE)).to(device)

In [9]:
if LOAD_MODEL:
    # nn.Module has no Keras-style load_weights(); restore a state_dict instead.
    # NOTE(review): assumes the checkpoint was written with torch.save(dcgan.state_dict(), ...) — confirm.
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    # torch.load unpickles the file: only load checkpoints you trust.
    dcgan.load_state_dict(torch.load("./checkpoint/checkpoint.ckpt", map_location=device))

In [10]:
loss_fn = nn.BCELoss()  # the discriminator already ends in Sigmoid, so plain BCE on probabilities

# One Adam optimizer per network, so each is stepped only with its own loss
g_optimizer = torch.optim.Adam(
    dcgan.generator.parameters(),
    lr=LEARNING_RATE,
    betas=[ADAM_BETA_1, ADAM_BETA_2],
)
d_optimizer = torch.optim.Adam(
    dcgan.discriminator.parameters(),
    lr=LEARNING_RATE,
    betas=[ADAM_BETA_1, ADAM_BETA_2],
)

In [None]:
# Adversarial training; fit() also renders a sample of generated images after every epoch
dcgan.fit(
    loader=train_loader,
    epochs=EPOCHS,
    loss_fn=loss_fn,
    g_optimizer=g_optimizer,
    d_optimizer=d_optimizer,
    device=device,
)

In [17]:
# torch.save(dcgan.state_dict(), './models/dcgan')

In [None]:
# NOTE(review): this replaces the weights just trained above with the saved ones — run it only to resume.
# map_location lets weights saved on a GPU load on a CPU-only machine.
# torch.load unpickles the file: only load checkpoints you trust.
dcgan.load_state_dict(torch.load('./models/dcgan', map_location=device))

## 4. Generate new images <a name="decode"></a>

In [12]:
def showable(img_tensor):
    """Convert an image tensor in [-1, 1] into a PIL image for plotting.

    ``Normalize(mean=-1, std=2)`` computes ``(x + 1) / 2``, undoing the ``Normalize(0.5, 0.5)``
    applied to the training data so pixels are back in [0, 1] before ``ToPILImage``.
    """
    unnormalize = transforms.Normalize(-1, 2)
    to_pil = transforms.ToPILImage()
    return to_pil(unnormalize(img_tensor))

In [13]:
# Draw grid_width * grid_height latent points from a standard normal distribution
grid_width = 10
grid_height = 3
z_sample = torch.randn(grid_width * grid_height, Z_DIM)

In [14]:
# Decode the sampled latent points: eval mode (BatchNorm uses its running stats), no autograd graph
dcgan.generator.eval()
with torch.no_grad():
    z_on_device = z_sample.to(device)
    reconstructions = dcgan.generator(z_on_device)

In [None]:
# Plot the decoded bricks on a grid_height x grid_width grid
fig = plt.figure(figsize=(18, 5))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

n_cells = grid_width * grid_height
for idx, img in enumerate(reconstructions[:n_cells]):
    ax = fig.add_subplot(grid_height, grid_width, idx + 1)
    ax.axis("off")
    ax.imshow(showable(img), cmap="Greys")

In [16]:
def compare_images(img1, img2):
    """Mean absolute (L1) pixel difference between two image tensors, as a 0-dim tensor."""
    abs_diff = (img1 - img2).abs()
    return abs_diff.mean()

In [None]:
r, c = 3, 5
fig, axs = plt.subplots(r, c, figsize=(10, 6))
fig.suptitle("Generated images", fontsize=20)

# Fresh latent noise, decoded in eval mode without building a graph; kept on the CPU for plotting
noise = torch.randn(size=(r * c, Z_DIM))
dcgan.generator.eval()
with torch.no_grad():
    gen_imgs = dcgan.generator(noise.to(device)).cpu()

# axs.flat walks the grid row by row, matching gen_imgs order
for idx, ax in enumerate(axs.flat):
    ax.imshow(showable(gen_imgs[idx]), cmap="gray_r")
    ax.axis("off")

plt.show()

In [None]:
def find_closest_training_images(generated, loader):
    """For each generated image, find the training image with the smallest mean absolute difference.

    Same criterion as ``compare_images``, but the training set is read in a single pass and each
    batch is compared against all generated images at once (instead of one full pass per image).

    Args:
        generated: generated images, assumed (N, C, H, W) — like ``gen_imgs``.
        loader: iterable of training batches with the same per-image shape as ``generated``.

    Returns:
        List of N tensors; entry i is the closest training image to ``generated[i]``,
        or None if ``loader`` yielded nothing.
    """
    num_generated = len(generated)
    best_diffs = torch.full((num_generated,), float('inf'))
    best_imgs = [None] * num_generated
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(generated.device)
            # (N, B): mean |generated_i - batch_j| over every pixel of each image pair
            diffs = (generated.unsqueeze(1) - batch.unsqueeze(0)).abs().flatten(start_dim=2).mean(dim=2)
            batch_best, batch_idx = diffs.min(dim=1)
            for g in range(num_generated):
                # strict '<' keeps the earliest match on ties, like the original per-image scan
                if batch_best[g] < best_diffs[g]:
                    best_diffs[g] = batch_best[g]
                    best_imgs[g] = batch[batch_idx[g].item()].clone()
    return best_imgs


closest_imgs = find_closest_training_images(gen_imgs, train_loader)

fig, axs = plt.subplots(r, c, figsize=(10, 6))
fig.suptitle("Closest images in the training set", fontsize=20)

# axs.flat walks the grid row by row, matching gen_imgs / closest_imgs order
for ax, img in zip(axs.flat, closest_imgs):
    if img is not None:
        ax.imshow(showable(img), cmap="gray_r")
    ax.axis("off")

plt.show()