# In [None]:
import os
import imageio
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms as transforms
from torchvision.utils import save_image

# Ensure the output directories exist
output_dir = 'output'
checkpoint_dir = 'checkpoint'
os.makedirs(output_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)


# Tuning notes (ld / lg = discriminator / generator learning rate):
#   ld = lg = 0.003  -> generated samples are very blurry
#   ld = lg = 0.0003 -> generated samples are blurry

# Training hyperparameters
epochs = 300
batch_size = 64
learning_rate_d = 0.0004
# NOTE(review): the generator learning rate is 10x the discriminator's;
# confirm this asymmetry is intentional.
learning_rate_g = 0.004
b1 = 0.5    # Adam beta1 (used by both optimizers)
b2 = 0.999  # Adam beta2 (used by both optimizers)
latent_dim = 100  # length of the noise vector fed to the generator
img_size = 64  # Adjust based on your dataset
channels = 3  # Adjust based on your dataset

img_shape = (channels, img_size, img_size)

# Use the GPU when one is available (is_available() already returns a bool).
cuda_is_present = torch.cuda.is_available()

class Generator(nn.Module):
    """Fully-connected generator mapping a noise vector to an image tensor.

    Layers: latent -> 128 -> 256 -> 512 -> 1024 -> prod(image_shape).
    Every hidden layer uses LeakyReLU(0.2). All hidden layers except the
    first also use BatchNorm1d. A final Tanh keeps outputs in [-1, 1].

    Args:
        latent_size: Length of the input noise vector. Defaults to the
            module-level ``latent_dim`` setting.
        image_shape: ``(channels, height, width)`` of the produced images.
            Defaults to the module-level ``img_shape`` setting.
    """

    def __init__(self, latent_size=None, image_shape=None):
        super().__init__()
        # Fall back to the notebook-level configuration so ``Generator()``
        # behaves exactly as before.
        latent_size = latent_dim if latent_size is None else latent_size
        self.img_shape = tuple(img_shape if image_shape is None else image_shape)

        def layer_block(input_size, output_size, normalize=True):
            layers = [nn.Linear(input_size, output_size)]
            if normalize:
                # NOTE(review): 0.8 is BatchNorm1d's *eps* (its second
                # positional parameter), not its momentum. The value is kept
                # so training behaviour is unchanged. An eps this large is
                # unusual and may have been meant as momentum; confirm
                # before changing it.
                layers.append(nn.BatchNorm1d(output_size, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Module order (and therefore state_dict keys) is unchanged, so
        # previously saved checkpoints still load.
        self.model = nn.Sequential(
            *layer_block(latent_size, 128, normalize=False),
            *layer_block(128, 256),
            *layer_block(256, 512),
            *layer_block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map a ``(batch, latent_size)`` noise tensor to ``(batch, *image_shape)``."""
        img = self.model(z)
        return img.view(img.size(0), *self.img_shape)

class Discriminator(nn.Module):
    """Fully-connected discriminator scoring images as real (1) or fake (0).

    The image is flattened and passed through 512 -> 256 -> 1 linear layers
    with LeakyReLU(0.2). A final Sigmoid makes the output a probability,
    suitable for ``nn.BCELoss``.

    Args:
        image_shape: ``(channels, height, width)`` of the input images.
            Defaults to the module-level ``img_shape`` setting.
    """

    def __init__(self, image_shape=None):
        super().__init__()
        in_features = int(np.prod(img_shape if image_shape is None else image_shape))

        self.model = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return a ``(batch, 1)`` tensor of real-probabilities for ``img``."""
        # flatten(1) also accepts non-contiguous inputs, where .view() raises.
        img_flat = img.flatten(1)
        verdict = self.model(img_flat)
        return verdict

# Build the two adversarial networks (generator first, then discriminator)
# plus the binary cross-entropy criterion both of them are trained with.
generator = Generator()
discriminator = Discriminator()
adversarial_loss = torch.nn.BCELoss()  # consumes the discriminator's Sigmoid probabilities

# Relocate everything onto the GPU when CUDA was detected. nn.Module.cuda()
# moves the module in place and returns it, so rebinding the name is equivalent.
if cuda_is_present:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    adversarial_loss = adversarial_loss.cuda()

# Custom dataset for cartoon images
class CartoonDataset(Dataset):
    """Dataset over the image files found directly inside one folder.

    Args:
        folder_path: Directory to scan (not recursive).
        transform: Optional callable applied to each loaded PIL image.
        extensions: File suffixes to include. The comparison is
            case-insensitive, so ``.JPG`` matches ``.jpg``.
        mode: PIL mode every image is converted to before ``transform``.
            The default ``'RGB'`` guarantees 3 channels, even for grayscale
            or CMYK JPEGs. Pass ``None`` to keep each file's own mode.
    """

    def __init__(self, folder_path, transform=None, extensions=('.jpg',), mode='RGB'):
        suffixes = tuple(ext.lower() for ext in extensions)
        # Sorted so the index -> file mapping is deterministic across runs.
        self.file_paths = sorted(
            os.path.join(folder_path, name)
            for name in os.listdir(folder_path)
            if name.lower().endswith(suffixes)
        )
        self.transform = transform
        self.mode = mode

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        # Image.open is lazy and keeps the file open. convert()/copy() force the
        # pixel data into memory so the handle can be closed by the with-block.
        with Image.open(self.file_paths[idx]) as img:
            img = img.convert(self.mode) if self.mode else img.copy()
        if self.transform:
            img = self.transform(img)
        return img

# Define transformations: resize to the training resolution, convert to a
# [0, 1] tensor, then normalise every channel to [-1, 1] so real images
# share the range of the generator's Tanh output.
# NOTE: Resize((h, w)) does not preserve aspect ratio; non-square images are stretched.
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * channels, (0.5,) * channels),
])

# Load cartoon dataset. Set CARTOON_DATASET_PATH to point at another folder
# instead of editing this machine-specific default.
dataset_path = os.environ.get(
    'CARTOON_DATASET_PATH',
    '/Users/chutszkan/Desktop/Files/HKU/COMP7502/Test/datasets/simpsons_dataset_archive/simpsons_dataset/homer_simpson/',
)
data_loader = DataLoader(
    CartoonDataset(folder_path=dataset_path, transform=transform),
    batch_size=batch_size, shuffle=True
)

# Legacy tensor-type alias. Kept because later code converts batches with
# `.type(Tensor)` to move them onto the GPU when one is present.
Tensor = torch.cuda.FloatTensor if cuda_is_present else torch.FloatTensor

# Separate Adam optimizers (and learning rates) for the two networks.
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=learning_rate_g, betas=(b1, b2))
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=learning_rate_d, betas=(b1, b2))

losses = []          # one (generator, discriminator) epoch-mean loss pair per epoch
images_for_gif = []  # sample grids captured at every snapshot epoch
snapshot_every = 10  # save sample grid + checkpoints every N epochs (and at the last one)

for epoch in range(1, epochs + 1):
    # Make sure BatchNorm uses batch statistics even if this cell is re-run
    # after the generator was switched to eval mode.
    generator.train()
    discriminator.train()
    epoch_g_loss = 0.0
    epoch_d_loss = 0.0
    num_batches = 0

    for i, images in enumerate(data_loader):
        # `.type(Tensor)` casts to float and moves to the GPU when CUDA is in use.
        real_images = images.type(Tensor)
        batch = real_images.size(0)
        device = real_images.device
        # Target labels: 1 = "real", 0 = "fake". Plain tensors do not require grad.
        real_output = torch.ones(batch, 1, device=device)
        fake_output = torch.zeros(batch, 1, device=device)

        # Training Generator: push D to label freshly generated images as real.
        optimizer_generator.zero_grad()
        z = torch.randn(batch, latent_dim, device=device)
        generated_images = generator(z)
        generator_loss = adversarial_loss(discriminator(generated_images), real_output)
        generator_loss.backward()
        optimizer_generator.step()

        # Training Discriminator: real -> 1, fakes -> 0. detach() keeps this
        # step from back-propagating into the generator. zero_grad() also
        # clears the D gradients accumulated by the generator's backward pass.
        optimizer_discriminator.zero_grad()
        discriminator_loss_real = adversarial_loss(discriminator(real_images), real_output)
        discriminator_loss_fake = adversarial_loss(discriminator(generated_images.detach()), fake_output)
        discriminator_loss = (discriminator_loss_real + discriminator_loss_fake) / 2
        discriminator_loss.backward()
        optimizer_discriminator.step()

        epoch_g_loss += generator_loss.item()
        epoch_d_loss += discriminator_loss.item()
        num_batches += 1

        print(f"[Epoch {epoch:=4d}/{epochs}] [Batch {i:=4d}/{len(data_loader)}] ---> "
              f"[D Loss: {discriminator_loss.item():.6f}] [G Loss: {generator_loss.item():.6f}]")

    if num_batches == 0:
        raise RuntimeError(f"data_loader yielded no batches (dataset_path={dataset_path!r})")

    # Record the mean over the epoch rather than whichever batch happened to be last.
    losses.append((epoch_g_loss / num_batches, epoch_d_loss / num_batches))

    # Also snapshot at the final epoch so the final-generator checkpoint
    # always exists, even when `epochs` is not a multiple of `snapshot_every`.
    if epoch % snapshot_every == 0 or epoch == epochs:
        image_filename = f'{output_dir}/images/epoch_{epoch}.png'
        os.makedirs(f'{output_dir}/images', exist_ok=True)
        # Grid of (up to) 25 samples from the epoch's last batch.
        save_image(generated_images.detach()[:25], image_filename, nrow=5, normalize=True)
        images_for_gif.append(imageio.imread(image_filename))
        # Save model checkpoints
        checkpoint_path = os.path.join(checkpoint_dir, f'generator_epoch_{epoch}.pth')
        torch.save(generator.state_dict(), checkpoint_path)
        checkpoint_path = os.path.join(checkpoint_dir, f'discriminator_epoch_{epoch}.pth')
        torch.save(discriminator.state_dict(), checkpoint_path)

# Visualize the per-epoch losses recorded during training.
losses = np.array(losses)
fig, ax = plt.subplots()  # fresh figure, so re-running the cell does not overlay old curves
if len(losses):
    epoch_axis = np.arange(1, len(losses) + 1)  # epochs are numbered from 1
    ax.plot(epoch_axis, losses[:, 0], label='Generator')
    ax.plot(epoch_axis, losses[:, 1], label='Discriminator')
    ax.legend()
ax.set_title("Training Losses")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
fig.savefig(f'{output_dir}/loss_plot.png')

# Create a ~5-second GIF from the sample grids captured at the snapshot epochs.
if images_for_gif:
    imageio.mimwrite(f'{output_dir}/generated_images.gif', images_for_gif, fps=len(images_for_gif) / 5)
else:
    print('No sample grids were captured; skipping GIF creation.')

# Reload the generator weights saved at the last epoch when that checkpoint
# exists. Otherwise keep the in-memory weights from the end of training.
final_checkpoint = os.path.join(checkpoint_dir, f'generator_epoch_{epochs}.pth')
if os.path.exists(final_checkpoint):
    generator.load_state_dict(torch.load(final_checkpoint))
else:
    print(f'{final_checkpoint} not found; using the in-memory generator weights.')
generator.eval()  # BatchNorm switches to its running statistics

# Generate some images. no_grad() avoids building an autograd graph for inference.
num_images = 10
with torch.no_grad():
    z = torch.randn(num_images, latent_dim, device=next(generator.parameters()).device)
    generated_images = generator(z)

# Save generated images
generated_image_filename = f'{output_dir}/final_generated_images.png'
save_image(generated_images, generated_image_filename, nrow=num_images, normalize=True)
print(f'Generated images saved to {generated_image_filename}')
