In [5]:
import os
from PIL import Image
from torch import nn
from torch.utils.data import Dataset
from torchvision import transforms

Part 1: Create Grumpy Cat Dataset

In [None]:
class CustomDataset(Dataset):
    """Image dataset loaded eagerly from a single directory.

    Every regular, non-hidden file in ``dir`` is opened with PIL, fully decoded
    into memory and converted to RGB. Items are indexed in sorted filename
    order, so indices are stable across runs and platforms.

    Args:
        dir: Directory containing the image files.
        transform: Optional callable (e.g. the pipeline returned by
            ``get_transform``) applied to each image in ``__getitem__``.
            When ``None``, the PIL image itself is returned.
    """

    def __init__(self, dir, transform=None):
        self.total_imgs: list[Image.Image] = []
        # os.listdir order is arbitrary; sort so dataset indices are reproducible.
        for name in sorted(os.listdir(dir)):
            # Must join with `dir`: a bare filename only resolves when the
            # notebook's working directory happens to be `dir`.
            path = os.path.join(dir, name)
            # Skip subdirectories (e.g. .ipynb_checkpoints) and hidden files
            # (e.g. .DS_Store), which PIL cannot decode.
            if name.startswith(".") or not os.path.isfile(path):
                continue
            with Image.open(path) as img:
                # convert() forces a full decode and returns a detached image,
                # so the file handle is closed here instead of leaking one
                # descriptor per image. RGB matches the 3-channel Normalize in
                # the transform pipeline and the 3-channel first conv layer.
                self.total_imgs.append(img.convert("RGB"))
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of images loaded from the directory."""
        return len(self.total_imgs)

    def __getitem__(self, idx):
        """Return image ``idx``, passed through ``self.transform`` if one was given.

        Returns:
            The transform's output (typically a tensor) when a transform is set,
            otherwise the stored RGB ``PIL.Image.Image``.
        """
        img = self.total_imgs[idx]
        if self.transform is not None:
            return self.transform(img)
        return img

Part 2: Data Augmentation

In [6]:
def get_transform(mode: str, image_size: int = 64) -> transforms.Compose:
    """Build the preprocessing / augmentation pipeline for training images.

    Both modes resize to ``image_size`` x ``image_size`` with bicubic
    interpolation, convert to a tensor and normalize each channel with
    mean 0.5 / std 0.5. This maps ToTensor's [0, 1] range to [-1, 1].

    Args:
        mode: ``"simple"`` applies only resize + normalize. ``"deluxe"``
            additionally applies, before tensor conversion, a random grayscale
            (p=0.2) and a random rotation uniformly in [-180, 180] degrees.
        image_size: Side length of the square output image (default 64).

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a normalized float
        tensor (3 x image_size x image_size for an RGB input).
        ``Compose`` is a plain callable, not an ``nn.Module``.

    Raises:
        NotImplementedError: if ``mode`` is not ``"simple"`` or ``"deluxe"``.
    """
    if mode == "simple":
        augment = []
    elif mode == "deluxe":
        augment = [
            transforms.RandomGrayscale(0.2),
            # NOTE(review): RandomRotation with default fill/expand leaves black
            # corners and allows upside-down faces; confirm this is intended.
            transforms.RandomRotation(180),
        ]
    else:
        raise NotImplementedError(
            f"Unknown transform mode {mode!r}; expected 'simple' or 'deluxe'"
        )

    return transforms.Compose([
        transforms.Resize((image_size, image_size), transforms.InterpolationMode.BICUBIC),
        *augment,
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

Part 3: Implement the Discriminator of the DCGAN

In [None]:
class Discriminator(nn.Module):
    def __init__(self, conv_dim=64):
        """Build the convolutional feature stack of the discriminator.

        Args:
            conv_dim: NOTE(review): not referenced by any layer defined in these
                lines; the channel widths are hard-coded as 32/64/128/256.
                Confirm whether it was meant to scale them.
        """
        super(Discriminator, self).__init__()
        # Four Conv2d(kernel_size=4, stride=2, padding=1) layers, each halving the
        # spatial size. A 64x64 RGB input (the size get_transform resizes to)
        # therefore becomes a 256-channel 4x4 feature map.
        # NOTE(review): no real/fake output layer (e.g. a final conv down to 1
        # channel) and no normalization are defined here, and plain ReLU is used.
        # DCGAN discriminators conventionally use LeakyReLU; confirm against the
        # rest of the class / assignment spec.
        self.layers = nn.Sequential(
            nn.Conv2d(  3,  32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d( 32,  64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d( 64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )