In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import imageio
import matplotlib.pyplot as plt

# Folders that receive generated samples/plots and model checkpoints.
# Both are created up front so later save calls cannot fail on a missing path.
output_dir = 'output'
checkpoint_dir = 'checkpoint'
for _directory in (output_dir, checkpoint_dir):
    os.makedirs(_directory, exist_ok=True)

# Training hyperparameters
epochs = 1000            # number of passes over the dataset
batch_size = 64          # samples per optimisation step
learning_rate = 0.001    # Adam step size (shared by both networks)
b1, b2 = 0.9, 0.999      # Adam beta coefficients
latent_dim = 100         # length of the generator's noise vector
img_size = 64            # images are square: img_size x img_size
channels = 3             # RGB
num_classes = 2          # Number of classes (0: Bart, 1: Lisa)

# (C, H, W) shape of a single image tensor
img_shape = (channels, img_size, img_size)

# True when a CUDA device can be used for training
cuda_is_present = torch.cuda.is_available()

# Define custom dataset class
class CartoonDataset(Dataset):
    """Labelled image dataset read from per-character sub-directories.

    ``root_dir`` is scanned for the sub-directories named in ``CLASS_LABELS``;
    every file inside them whose extension is listed in ``IMAGE_EXTENSIONS``
    becomes one sample. Sub-directories that do not exist are skipped, so a
    wrong ``root_dir`` yields an empty dataset rather than an error.

    Args:
        root_dir: Directory containing the class sub-directories.
        transform: Optional callable applied to each RGB ``PIL.Image`` before
            it is returned. When ``None`` the PIL image is returned as is.

    Attributes:
        image_paths: File path of every sample, grouped by class and sorted by
            file name within each class.
        labels: Integer class label of every sample, parallel to
            ``image_paths``.
    """

    # Sub-directory name -> integer class label (0: Bart, 1: Lisa)
    CLASS_LABELS = {'bart_simpson': 0, 'lisa_simpson': 1}
    # Accepted file extensions; compared against the lower-cased file name
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        for sub_dir, label in self.CLASS_LABELS.items():
            sub_dir_path = os.path.join(root_dir, sub_dir)
            if not os.path.isdir(sub_dir_path):  # Missing class folder: nothing to add
                continue
            # os.listdir order is arbitrary; sorting makes sample indices reproducible
            for img_name in sorted(os.listdir(sub_dir_path)):
                # Lower-case first so '.JPG' / '.PNG' files are not silently dropped
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(sub_dir_path, img_name))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of samples found under ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded from disk, converted to RGB and, if a transform
        was supplied, passed through it. ``label`` is the integer class label.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # convert() returns an independent in-memory copy, so the file handle
        # can be released immediately instead of lingering until GC.
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# Preprocessing applied to every image:
#   1. resize to img_size x img_size so all samples share one shape,
#   2. convert the PIL image to a tensor,
#   3. normalise each of the three channels with mean 0.5 and std 0.5.
_preprocessing_steps = [
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_preprocessing_steps)

# Build the dataset from disk and wrap it in a shuffling mini-batch loader
_dataset_root = './datasets/simpsons_dataset_archive/simpsons_dataset/'
dataset = CartoonDataset(_dataset_root, transform=transform)
data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# Define Generator
class Generator(nn.Module):
    """Label-conditioned generator built from transposed convolutions.

    A class label is embedded into a vector of length ``latent_dim`` and
    concatenated with the noise vector along the channel axis. The resulting
    ``2 * latent_dim`` x 1 x 1 tensor is upsampled by five transposed
    convolutions (1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels per side) and squashed
    with ``tanh``, giving a ``channels`` x 64 x 64 image with values in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # One learned vector per class, as wide as the noise vector
        self.label_emb = nn.Embedding(num_classes, latent_dim)

        # Stem: project the 1x1 input to a 4x4 feature map with 512 channels
        stack = [
            nn.ConvTranspose2d(2 * latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
        ]
        # Each stage doubles the spatial size and halves the channel count
        for in_ch, out_ch in ((512, 256), (256, 128), (128, 64)):
            stack.extend((
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ))
        # Head: final doubling to the image channels, bounded by tanh
        stack.append(nn.ConvTranspose2d(64, channels, 4, 2, 1, bias=False))
        stack.append(nn.Tanh())
        self.model = nn.Sequential(*stack)

    def forward(self, z, labels):
        """Generate one image per row of ``z``, conditioned on ``labels``.

        Args:
            z: Noise tensor holding ``latent_dim`` values per sample.
            labels: Integer class indices, one per sample.

        Returns:
            The output of the transposed-convolution stack for the batch.
        """
        # Append two singleton spatial axes so the embedding looks like a 1x1 map
        label_map = self.label_emb(labels)[:, :, None, None]
        noise_map = z.view(z.size(0), latent_dim, 1, 1)
        conditioned = torch.cat((noise_map, label_map), 1)
        return self.model(conditioned)

# Define Discriminator
class Discriminator(nn.Module):
    """Label-conditioned convolutional discriminator.

    The class label is embedded into ``img_size * img_size`` values, reshaped
    into a single extra image plane and stacked onto the input image's
    channels. Four stride-2 convolutions then halve the spatial size each time
    and a final 4x4 convolution followed by a sigmoid produces the score.
    """

    def __init__(self):
        super().__init__()
        # One learned img_size x img_size plane per class
        self.label_emb = nn.Embedding(num_classes, img_size * img_size)

        # First stage has no batch norm: image + label plane -> 64 channels
        blocks = [
            nn.Conv2d(channels + 1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining downsampling stages double the channel count each time
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            blocks += [
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Head: collapse to a single-channel score squashed into (0, 1)
        blocks += [
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*blocks)

    def forward(self, img, labels):
        """Score a batch of images given their class labels.

        Args:
            img: Batch of images with ``channels`` channels and
                ``img_size`` x ``img_size`` pixels.
            labels: Integer class indices, one per image.

        Returns:
            The sigmoid scores flattened to two dimensions, one row per image.
        """
        label_plane = self.label_emb(labels).view(-1, 1, img_size, img_size)
        conditioned = torch.cat((img, label_plane), 1)
        scores = self.model(conditioned)
        batch = scores.size(0)
        return scores.view(batch, -1)

# Build both networks and the binary cross-entropy criterion used for the
# adversarial objective (the discriminator already ends in a sigmoid).
generator, discriminator = Generator(), Discriminator()
adversarial_loss = nn.BCELoss()

# Move every module to the GPU when one is available
if cuda_is_present:
    for _module in (generator, discriminator, adversarial_loss):
        _module.cuda()

# One Adam optimizer per network, sharing the same step size and betas
_adam_settings = dict(lr=learning_rate, betas=(b1, b2))
optimizer_generator = torch.optim.Adam(generator.parameters(), **_adam_settings)
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), **_adam_settings)

# Training
#
# Each step first updates the generator (it tries to make the discriminator
# output "real" for generated images), then updates the discriminator on the
# real batch and on the detached generated batch.
# Every 10th epoch a grid of generated samples and both state dicts are saved.

# Single device handle instead of per-tensor legacy type casts
device = torch.device('cuda' if cuda_is_present else 'cpu')

# Sample grids go here; created once rather than on every checkpoint epoch
images_dir = f'{output_dir}/images'
os.makedirs(images_dir, exist_ok=True)

losses = []          # (generator_loss, discriminator_loss) of the last batch of each epoch
images_for_gif = []  # saved sample grids, read back as arrays for the final GIF
for epoch in range(1, epochs + 1):
    for i, (images, labels) in enumerate(data_loader):
        current_batch = images.size(0)  # the last batch may be smaller than batch_size

        real_images = images.to(device=device, dtype=torch.float32)
        labels = labels.to(device=device, dtype=torch.long)
        # BCE targets: 1 for "real", 0 for "fake"
        real_output = torch.ones(current_batch, 1, dtype=torch.float32, device=device)
        fake_output = torch.zeros(current_batch, 1, dtype=torch.float32, device=device)

        # Training Generator
        optimizer_generator.zero_grad()
        # Noise and labels are drawn from the CPU generator and then moved, so
        # a run seeded with torch.manual_seed gives the same samples on CPU and GPU.
        z = torch.randn(current_batch, latent_dim).to(device=device, dtype=torch.float32)
        gen_labels = torch.randint(0, num_classes, (current_batch,)).to(device=device, dtype=torch.long)
        generated_images = generator(z, gen_labels)
        generator_loss = adversarial_loss(discriminator(generated_images, gen_labels), real_output)
        generator_loss.backward()
        optimizer_generator.step()

        # Training Discriminator
        # zero_grad also clears the gradients that the generator's backward
        # pass left on the discriminator parameters.
        optimizer_discriminator.zero_grad()
        discriminator_loss_real = adversarial_loss(discriminator(real_images, labels), real_output)
        # detach(): the discriminator update must not back-propagate into the generator
        discriminator_loss_fake = adversarial_loss(discriminator(generated_images.detach(), gen_labels), fake_output)
        discriminator_loss = (discriminator_loss_real + discriminator_loss_fake) / 2
        discriminator_loss.backward()
        optimizer_discriminator.step()

        print(f"[Epoch {epoch:4d}/{epochs}] [Batch {i:4d}/{len(data_loader)}] ---> "
              f"[D Loss: {discriminator_loss.item():.6f}] [G Loss: {generator_loss.item():.6f}]")

    losses.append((generator_loss.item(), discriminator_loss.item()))
    if epoch % 10 == 0:
        # Save up to 25 images of the last generated batch as a 5-column grid
        image_filename = f'{images_dir}/epoch_{epoch}.png'
        save_image(generated_images.detach()[:25], image_filename, nrow=5, normalize=True)
        images_for_gif.append(imageio.imread(image_filename))
        # Save model checkpoints
        checkpoint_path = os.path.join(checkpoint_dir, f'generator_epoch_{epoch}.pth')
        torch.save(generator.state_dict(), checkpoint_path)
        checkpoint_path = os.path.join(checkpoint_dir, f'discriminator_epoch_{epoch}.pth')
        torch.save(discriminator.state_dict(), checkpoint_path)

# Plot the losses recorded once per epoch (column 0: generator,
# column 1: discriminator) and write the figure to the output directory.
losses = np.array(losses)
for _column, _curve_name in enumerate(('Generator', 'Discriminator')):
    plt.plot(losses[:, _column], label=_curve_name)
plt.title("Training Losses")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.savefig(f'{output_dir}/loss_plot.png')

# Combine the saved sample grids into one GIF. The frame rate scales with the
# number of frames so the whole animation plays in roughly five seconds.
_gif_fps = len(images_for_gif) / 5
imageio.mimwrite(f'{output_dir}/generated_images.gif', images_for_gif, fps=_gif_fps)
