In [None]:
import os
from PIL import Image
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [None]:
# HYPERPARAMETERS
# Step size handed to both Adam optimizers inside training_loop().
LEARNING_RATE = 1e-3
# Number of full passes over the dataloader performed by training_loop().
NUM_EPOCHS = 100

Part 1: Create Grumpy Cat Dataset

In [None]:
class CustomDataset(Dataset):
    """Dataset that eagerly loads every image file found in one directory.

    Args:
        dir: Directory holding the image files (not searched recursively).
        transform: Optional callable applied to a PIL image each time an item
            is fetched. When ``None`` the RGB PIL image is returned unchanged.

    Attributes:
        total_imgs: The loaded PIL images, in sorted file-name order.
        transform: The callable passed to the constructor (or ``None``).
    """

    # File extensions treated as images. Anything else in the directory
    # (sub-folders, ``.DS_Store``, notebook checkpoints, ...) is skipped
    # instead of crashing ``Image.open``.
    IMG_EXTENSIONS = (
        ".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp",
    )

    def __init__(self, dir, transform = None):
        self.total_imgs: list[Image.Image] = []
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> image mapping stable from run to run.
        for file in sorted(os.listdir(dir)):
            path = os.path.join(dir, file)
            if not (file.lower().endswith(self.IMG_EXTENSIONS) and os.path.isfile(path)):
                continue
            # The context manager closes the file handle, while convert()
            # returns a fully loaded copy. Forcing RGB guarantees three
            # channels, which the 3-channel Normalize in get_transform() needs
            # (PNG files are frequently RGBA or grayscale).
            with Image.open(path) as img:
                self.total_imgs.append(img.convert("RGB"))
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of loaded images."""
        return len(self.total_imgs)

    def __getitem__(self, idx):
        """Return image ``idx``, passed through ``transform`` when one is set."""
        img = self.total_imgs[idx]
        if self.transform is None:
            return img
        return self.transform(img)
# A transform is required here: without one the dataset cannot hand tensors to
# the DataLoader, whose default collate function does not batch PIL images.
# get_transform() is defined in a later cell, so the same steps as its
# "simple" mode are spelled out inline to keep "Restart & Run All" working.
dataset = CustomDataset(
    "grumpifyCat",
    transform=transforms.Compose([
        transforms.Resize((64, 64), transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)
# Shuffle so that every epoch visits the images in a different order.
dataloader = DataLoader(dataset, shuffle=True)

Part 2: Data Augmentation

In [None]:
def get_transform(mode: str) -> nn.Module:
    """Build the image preprocessing pipeline for ``mode``.

    Both modes resize to 64x64 with bicubic interpolation, convert to a
    tensor and apply ``Normalize`` with 0.5 for every mean and std entry.
    ``"deluxe"`` additionally inserts ``RandomGrayscale(0.2)`` and
    ``RandomRotation(180)`` between the resize and the tensor conversion.

    Args:
        mode: Either ``"simple"`` or ``"deluxe"``.

    Returns:
        A ``transforms.Compose`` holding the steps described above.

    Raises:
        NotImplementedError: If ``mode`` is any other value.
    """
    # Guard clause: reject unknown modes before building anything.
    if mode not in ("simple", "deluxe"):
        raise NotImplementedError

    # Augmentation steps that only the "deluxe" pipeline contains.
    augmentations = []
    if mode == "deluxe":
        augmentations = [
            transforms.RandomGrayscale(0.2),
            transforms.RandomRotation(180),
        ]

    resize = transforms.Resize((64, 64), transforms.InterpolationMode.BICUBIC)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    return transforms.Compose([resize, *augmentations, to_tensor, normalize])

Part 3: Implement the Discriminator of the DCGAN

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator: maps 3x64x64 images to one raw score per image.

    Four stride-2 convolutions halve the resolution down to 4x4 and a final
    convolution collapses it to a single value. No sigmoid is applied, so the
    output is an unbounded logit.
    """

    def __init__(self):
        super().__init__()
        # size_out = (size_in + 2 * padding - kernel) / stride + 1
        #
        # The channel counts are known statically, so regular BatchNorm2d is
        # used instead of LazyBatchNorm2d: every parameter exists as soon as
        # the module is constructed, and an optimizer can be created right
        # away without a dry-run forward pass.
        self.layers = nn.Sequential(
            # INPUT 3x64x64
            nn.Conv2d(  3,  32, kernel_size=4, stride=2, padding=1), # 32x32x32
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d( 32,  64, kernel_size=4, stride=2, padding=1), # 64x16x16
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d( 64, 128, kernel_size=4, stride=2, padding=1), # 128x8x8
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # 256x4x4
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 1, kernel_size=6, stride=2, padding=1), # 1x1x1
        )

    def forward(self, z) -> torch.Tensor:
        """Score a batch of images.

        Args:
            z: Image batch of shape ``(N, 3, 64, 64)``.

        Returns:
            Tensor of shape ``(N,)`` with one logit per image.
        """
        # flatten() instead of squeeze(): squeeze() would also remove the
        # batch dimension when N == 1 and return a 0-d tensor.
        return self.layers(z).flatten()

Part 4: Implement the Generator of the DCGAN

In [None]:
class Generator(nn.Module):
    """DCGAN generator: maps noise vectors to 3x64x64 images in ``[-1, 1]``.

    Args:
        noise_size: Number of channels of the input noise (default 100).

    Attributes:
        noise_size: The value passed to the constructor, kept so callers can
            sample noise of the matching size.
    """

    def __init__(self, noise_size=100):
        super().__init__()
        self.noise_size = noise_size
        # size_out = (size_in - 1) * stride - 2 * padding + kernel
        #
        # The channel counts are known statically, so regular BatchNorm2d is
        # used instead of LazyBatchNorm2d: every parameter exists as soon as
        # the module is constructed, and an optimizer can be created right
        # away without a dry-run forward pass.
        self.layers = nn.Sequential(
            # INPUT nsx1x1
            nn.ConvTranspose2d(noise_size, 256, kernel_size=6, stride=2, padding=1),    # 256x4x4
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),           # 128x8x8
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),            # 64x16x16
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d( 64, 32, kernel_size=4, stride=2, padding=1),            # 32x32x32
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d( 32,  3, kernel_size=4, stride=2, padding=1),            # 3x64x64
            nn.Tanh(),
        )

    def forward(self, z) -> torch.Tensor:
        """Generate images from noise.

        Args:
            z: Noise of shape ``(N, noise_size, 1, 1)``; a flat
                ``(N, noise_size)`` tensor is accepted as well.

        Returns:
            Tensor of shape ``(N, 3, 64, 64)``; the final Tanh bounds the
            values to ``[-1, 1]``.
        """
        # Convenience: add the two singleton spatial dimensions to flat noise.
        if z.dim() == 2:
            z = z[:, :, None, None]
        return self.layers(z)

Part 5: Training Loop

In [None]:
def training_loop(dataloader: DataLoader, noise_size: int = 100):
    """Train a generator/discriminator pair on the images from ``dataloader``.

    Both networks are optimised with Adam at ``LEARNING_RATE`` for
    ``NUM_EPOCHS`` epochs. The discriminator is updated on every second
    iteration of an epoch, the generator on every iteration. Both losses are
    printed every 200 iterations of an epoch.

    Args:
        dataloader: Yields batches of real images, either as a tensor or as a
            list/tuple whose first element is the image tensor.
        noise_size: Number of noise channels fed to the generator.

    Returns:
        tuple: ``(generator, discriminator)`` after training.
    """
    # Create generators and discriminators
    generator = Generator(noise_size)
    discriminator = Discriminator()

    # Create optimizers for the generators and discriminators
    g_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE)
    d_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE)

    # The discriminator ends in a convolution (no sigmoid), so its outputs are
    # raw logits; BCEWithLogitsLoss applies the sigmoid internally and reduces
    # the batch to a scalar that backward() and the log line can use.
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(NUM_EPOCHS):
        # enumerate() advances the counter; it restarts at 0 every epoch.
        for iteration, data in enumerate(dataloader):
            # Load real images (the whole batch, not just its first sample)
            real_images = data[0] if isinstance(data, (list, tuple)) else data
            batch_size = real_images.size(0)

            #
            # TRAIN DISCRIMINATOR
            #

            d_optimizer.zero_grad()

            # 1. Compute the discriminator loss on real images (target: 1)
            D_real_pred = discriminator(real_images)
            D_real_loss = criterion(D_real_pred, torch.ones_like(D_real_pred))

            # 2. Sample noise
            noise = torch.randn(batch_size, noise_size, 1, 1)

            # 3. Generate fake images from the noise. detach() keeps the
            #    discriminator update from back-propagating into the generator.
            fake_images = generator(noise).detach()

            # 4. Compute the discriminator loss on the fake images (target: 0)
            D_fake_pred = discriminator(fake_images)
            D_fake_loss = criterion(D_fake_pred, torch.zeros_like(D_fake_pred))

            # 5. Compute total loss; the weights only move on even iterations
            D_total_loss = D_real_loss + D_fake_loss
            if iteration % 2 == 0:
                D_total_loss.backward()
                d_optimizer.step()

            #
            # TRAIN GENERATOR
            #

            g_optimizer.zero_grad()

            # 1. Sample noise
            noise = torch.randn(batch_size, noise_size, 1, 1)

            # 2. Generate fake images from the noise
            fake_images = generator(noise)

            # 3. Compute the generator loss: the generator wants its fakes to
            #    be classified as real (target: 1)
            D_gen_pred = discriminator(fake_images)
            G_loss = criterion(D_gen_pred, torch.ones_like(D_gen_pred))

            G_loss.backward()
            g_optimizer.step()

            # Print each loss every 200 iterations
            if iteration % 200 == 0:
                print(f"Loss - Discriminator: {D_total_loss.item():<8.2f} Generator: {G_loss.item():<8.2f}")

    return generator, discriminator

# Keep the trained networks so later cells can sample from the generator.
generator, discriminator = training_loop(dataloader)