# This Cat Doesn't Exist
## Generative Adversarial Networks Version

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import IPython.display as ipd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import time
import glob
from PIL import Image

# Dataset Definitions

In [None]:
###### UNCOMMENT AND RUN THIS CELL ONLY IF YOU ARE ON GOOGLE COLAB! ########
###### OTHERWISE, LEAVE IT COMMENTED OUT OR DELETE IT.              ########
## !wget https://github.com/ursinus-cs477-f2023/HW7_GAN/archive/refs/heads/main.zip
## !unzip main.zip
## !mv HW7_GAN-main/cats .
##############################################################

In [None]:
def plot_samples(XGen):
    """
    Plot a set of examples from the dataset/generator on a square grid

    The grid side is the square root of the batch size rounded up, so any
    batch size fits. Pixel values are clipped to [0, 1] for display only;
    the input tensor is left unmodified.

    Parameters
    ----------
    XGen: torch.tensor(n_examples, 3, dim, dim)
        A batch of examples to plot
    """
    n_examples = XGen.shape[0]
    # Round up so a batch whose size is not a perfect square still fits
    # (e.g. 10 examples -> 4x4 grid rather than an invalid 3x3 grid).
    k = int(np.ceil(np.sqrt(n_examples)))
    for i in range(n_examples):
        plt.subplot(k, k, i+1)
        Xi = XGen[i, :, :, :].detach().cpu().numpy()
        # (C, H, W) -> (H, W, C), the layout imshow expects
        Xi = np.moveaxis(Xi, 0, 2)
        # np.clip returns a new array. For a CPU tensor, .numpy() and
        # moveaxis yield a view of the tensor's storage, so clipping in place
        # would silently modify the caller's data.
        Xi = np.clip(Xi, 0, 1)
        plt.imshow(Xi)
        plt.axis("off")

class CatData(Dataset):
    """
    Dataset of images read from a folder, returned as float tensors.

    Parameters
    ----------
    foldername: str
        Folder whose entries (everything matched by "foldername/*") are images
    imgres: int
        Width/height, in pixels, of the square tensors that are returned
    train: bool
        Unused; kept so existing calls keep working
    """
    def __init__(self, foldername, imgres=64, train=True):
        # Sort so the index -> file mapping is reproducible across runs and
        # machines (glob's order depends on the filesystem).
        self.images = sorted(glob.glob("{}/*".format(foldername)))
        self.preprocess = transforms.Compose([
            # Resize(int) only scales the shorter side to imgres. The center
            # crop then makes every output exactly imgres x imgres, so images
            # with different aspect ratios can still be batched together.
            transforms.Resize(imgres),
            transforms.CenterCrop(imgres),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """
        Returns
        -------
        torch.tensor(3, imgres, imgres)
            The image at position idx, with values in [0, 1]
        """
        # The context manager closes the file handle. convert("RGB") loads
        # the pixels and guarantees 3 channels, even for grayscale, palette
        # or RGBA files.
        with Image.open(self.images[idx]) as img:
            img = img.convert("RGB")
        img = self.preprocess(img)
        # Defensive rescale in case the tensor was not already scaled to [0, 1]
        if torch.max(img) > 1:
            img /= 255
        return img
    
# Build the dataset from the "cats" folder and preview one shuffled batch of 16
traindata = CatData("cats")

samples = DataLoader(dataset=traindata, batch_size=16, shuffle=True)
first_batch = next(iter(samples))
plot_samples(first_batch)

# Model Definitions

In [None]:
class Discriminator(nn.Module):
    def __init__(self, depth=4, dim_latent=64, dim_img=64, in_channels=3, start_channels=64):
        """
        Convolutional classifier mapping a batch of images to one
        real/fake logit per image.

        depth: int
            How many strided convolutional layers there are; each one halves
            the width/height of the feature maps and doubles the channel count
        dim_latent: int
            Dimension of the generator's latent space; stored for symmetry
            with Generator but not used by the discriminator
        dim_img: int
            Width/height of input image
        in_channels: int
            Number of channels of input image
        start_channels: int
            Number of channels out of the first convolutional layer

        Raises
        ------
        ValueError
            If dim_img is too small to be halved depth times
        """
        super().__init__()
        # Side length of the feature maps after all convolutions: each
        # kernel-4 / stride-2 / padding-1 conv maps a side of n to n // 2.
        final_res = dim_img // (2 ** depth)
        if final_res < 1:
            raise ValueError(
                "dim_img={} is too small to be halved {} times".format(dim_img, depth)
            )
        self.depth = depth
        self.dim_latent = dim_latent
        self.dim_img = dim_img
        self.in_channels = in_channels
        self.start_channels = start_channels

        layers = []
        channels = start_channels
        for _ in range(depth):
            layers.append(nn.Conv2d(in_channels, channels, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(channels))
            layers.append(nn.LeakyReLU(0.2))
            in_channels = channels
            channels *= 2

        self.conv_layers = nn.Sequential(*layers)
        self.flatten = nn.Flatten()
        # After the loop, in_channels holds the output channel count of the
        # last convolutional layer.
        self.linear = nn.Linear(in_channels * final_res ** 2, 1)

    def forward(self, X):
        """
        X: torch.tensor(n_examples, in_channels, dim_img, dim_img)
            Batch of images

        Returns torch.tensor(n_examples, 1) of raw logits. No sigmoid is
        applied, so pair with F.binary_cross_entropy_with_logits.
        """
        X = self.conv_layers(X)
        X = self.flatten(X)
        X = self.linear(X)
        return X
        

class Generator(nn.Module):
    def __init__(self, depth=4, dim_latent=64, dim_img=64, in_channels=3, end_channels=64):
        """
        Maps latent vectors to images, mirroring Discriminator. A linear
        layer produces a small stack of feature maps, which `depth`
        transposed convolutions each upsample by a factor of 2.

        depth: int
            How many transposed convolutional layers there are
        dim_latent: int
            Dimension of the latent space
        dim_img: int
            Width/height of the generated images; must be divisible by
            2**depth
        in_channels: int
            Number of channels of the generated images
        end_channels: int
            Number of channels out of the second to last convolutional layer;
            each earlier layer has twice as many channels as the next one

        Raises
        ------
        ValueError
            If depth < 1 or dim_img is not divisible by 2**depth
        """
        super().__init__()
        if depth < 1 or dim_img % (2 ** depth) != 0:
            raise ValueError(
                "dim_img={} must be divisible by 2**depth with depth >= 1 (depth={})".format(dim_img, depth)
            )
        self.depth = depth
        self.dim_latent = dim_latent
        self.dim_img = dim_img
        self.in_channels = in_channels
        self.end_channels = end_channels

        # Shape of the feature maps made by the linear layer. `depth` stride-2
        # upsamplings take start_res back up to dim_img. Halving the channels
        # depth-1 times reaches end_channels; the last layer then maps to
        # in_channels.
        start_res = dim_img // (2 ** depth)
        start_channels = end_channels * (2 ** (depth - 1))
        shape_latent = (start_channels, start_res, start_res)

        self.linear = nn.Sequential(
            nn.Linear(dim_latent, start_channels * start_res * start_res),
            nn.Unflatten(1, shape_latent),
            nn.LeakyReLU(0.2)
        )

        layers = []
        channels = start_channels
        for _ in range(depth - 1):
            # kernel 4 / stride 2 / padding 1 exactly doubles width and height
            layers.append(nn.ConvTranspose2d(channels, channels // 2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(channels // 2))
            layers.append(nn.LeakyReLU(0.2))
            channels = channels // 2
        # Output layer. No BatchNorm, which would renormalize the pixels. The
        # sigmoid puts values in [0, 1], the range plot_samples displays.
        layers.append(nn.ConvTranspose2d(channels, in_channels, kernel_size=4, stride=2, padding=1))
        layers.append(nn.Sigmoid())

        self.conv_layers = nn.Sequential(*layers)

    def forward(self, z):
        """
        z: torch.tensor(n_examples, dim_latent)
            Latent vectors. Any trailing dimensions, e.g. noise shaped
            (n_examples, dim_latent, 1, 1), are flattened into one.

        Returns torch.tensor(n_examples, in_channels, dim_img, dim_img)
        """
        if z.dim() > 2:
            z = z.flatten(start_dim=1)
        z = self.linear(z)
        z = self.conv_layers(z)
        return z

    def sample(self, n_examples, device):
        """
        Sample from the latent space and generate the images

        Parameters
        ----------
        n_examples: int
            Number of examples to generate
        device: string
            Device for model/tensors

        Returns
        -------
        torch.tensor(n_examples, in_channels, dim_img, dim_img)
            A batch of generated examples
        """
        # Use this model's own latent size rather than a notebook global
        z = torch.randn(n_examples, self.dim_latent, device=device)
        return self.forward(z)
    

## TODO: Test your models here with dummy data before you proceed to training

In [None]:
# Smoke test: push a batch of blank 64x64 RGB images through the discriminator
discriminator = Discriminator(depth=4, dim_latent=64, dim_img=64, in_channels=3, start_channels=64)
X = torch.zeros(16, 3, 64, 64)
scores = discriminator(X)
# Expect exactly one real/fake logit per image
assert scores.shape == (16, 1), scores.shape
scores

In [None]:
# Smoke test: decode 16 random latent vectors with an untrained generator
generator = Generator(depth=4, dim_latent=64, dim_img=64, in_channels=3, end_channels=64)
z = torch.randn(16, 64)
XEst = generator(z)
# Expect 16 RGB images at the requested 64x64 resolution, with pixel values
# in the [0, 1] range that plot_samples displays
assert XEst.shape == (16, 3, 64, 64), XEst.shape
assert float(XEst.min()) >= 0.0 and float(XEst.max()) <= 1.0
plot_samples(XEst)

# Training

In [None]:
device = 'cpu'
depth = 2
dim_latent = 32
channels = 32
lr = 3e-4

n_epochs = 40
batch_size = 16
# Must match the resolution of the training images; CatData("cats") above is
# built with its default imgres of 64.
dim_img = 64

# Fresh models built from the hyperparameters above. These replace the
# dummy-data test instances created in the previous cells.
generator = Generator(depth=depth, dim_latent=dim_latent, dim_img=dim_img,
                      in_channels=3, end_channels=channels).to(device)
discriminator = Discriminator(depth=depth, dim_latent=dim_latent, dim_img=dim_img,
                              in_channels=3, start_channels=channels).to(device)

# beta1 = 0.5 is the common choice for GANs (DCGAN); the default 0.9 tends to
# make adversarial training less stable.
generator_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# drop_last avoids a small leftover batch, whose BatchNorm statistics are noisy
loader = DataLoader(traindata, batch_size=batch_size, shuffle=True, drop_last=True)

for epoch in range(n_epochs):
    tic = time.time()
    generator_losses = []
    discriminator_losses = []
    for real_images in loader:
        real_images = real_images.to(device)
        n_batch = real_images.shape[0]

        ## Generator step: push the discriminator to label fakes as real (1)
        generator_optimizer.zero_grad()
        # Generate fake images
        noise = torch.randn(n_batch, dim_latent, device=device)
        fake_images = generator(noise)
        # Pass fake images through the discriminator
        fake_outputs = discriminator(fake_images)
        # Compute the generator loss
        generator_loss = F.binary_cross_entropy_with_logits(fake_outputs, torch.ones_like(fake_outputs))
        # Compute gradients and take a step in the generator optimizer
        generator_loss.backward()
        generator_optimizer.step()

        ## Discriminator step: real -> 1, fake -> 0
        # Zeroing here also discards the gradients that generator_loss.backward()
        # left on the discriminator's parameters.
        discriminator_optimizer.zero_grad()
        # Pass real images through the discriminator
        real_outputs = discriminator(real_images)
        real_loss = F.binary_cross_entropy_with_logits(real_outputs, torch.ones_like(real_outputs))
        # detach() keeps this loss from backpropagating into the generator
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = F.binary_cross_entropy_with_logits(fake_outputs, torch.zeros_like(fake_outputs))
        # Compute the total discriminator loss, then step its optimizer
        discriminator_loss = real_loss + fake_loss
        discriminator_loss.backward()
        discriminator_optimizer.step()

        generator_losses.append(generator_loss.item())
        discriminator_losses.append(discriminator_loss.item())

    mean_generator_loss = sum(generator_losses) / max(len(generator_losses), 1)
    mean_discriminator_loss = sum(discriminator_losses) / max(len(discriminator_losses), 1)

    # Output generated examples at the end of each epoch
    ipd.clear_output(wait=True)
    plt.figure(figsize=(8, 8))
    with torch.no_grad():
        plot_samples(generator.sample(16, device))
    plt.suptitle("Epoch {}/{}: G loss {:.3f}, D loss {:.3f} ({:.1f}s)".format(
        epoch + 1, n_epochs, mean_generator_loss, mean_discriminator_loss, time.time() - tic))
    plt.show()