# Imports & Setup

In [None]:
# Standard libraries
import os
import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid

# TensorBoard
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter

# Scikit-learn
from sklearn.model_selection import train_test_split

# Kaggle dataset helper (used below to download the pixel-art dataset)
import kagglehub  # Make sure this is actually used in your code

# Use the default CUDA device when one is available, otherwise fall back to CPU.
# Models and batches are moved here by PixelArtGAN.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download latest version of the Kaggle "ebrahimelgazar/pixel-art" dataset;
# the returned value is used as the local directory holding the files below.
dataset_path = kagglehub.dataset_download("ebrahimelgazar/pixel-art")

print("Path to dataset files:", dataset_path)
# NOTE(review): csv_path is not read anywhere in this notebook.
csv_path = os.path.join(dataset_path, "labels.csv")
# Image array and label array consumed by PixelArtDataset.
sprites_path = os.path.join(dataset_path, "sprites.npy")
sprite_label_path = os.path.join(dataset_path, "sprites_labels.npy")



# Generator Model

In [None]:
class Generator(nn.Module):
    """Conditional generator mapping (noise, class label) to a 16x16 RGB image.

    The label is embedded and concatenated to the noise vector. A dense layer
    then produces a 256x4x4 feature map, which is upsampled twice (4 -> 8 -> 16)
    and projected to 3 channels. A final Tanh keeps outputs in [-1, 1], matching
    the normalization used by PixelArtDataset.

    Args:
        latent_dim: size of each noise vector.
        num_classes: number of rows in the label embedding table.
        embed_dim: size of the label embedding concatenated to the noise.
    """

    def __init__(self, latent_dim=100, num_classes=5, embed_dim=10):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.label_embedding = nn.Embedding(num_classes, embed_dim)
        input_dim = latent_dim + embed_dim

        self.model = nn.Sequential(
            # Dense -> 4x4x256
            nn.Linear(input_dim, 4 * 4 * 256, bias=False),
            nn.BatchNorm1d(4 * 4 * 256),
            nn.LeakyReLU(0.2, inplace=True),

            # Reshape (batch_size, 4096) -> (batch_size, 256, 4, 4).
            # The built-in Unflatten replaces the custom View helper, which is
            # defined later in the file. It has no parameters, so it occupies the
            # same Sequential slot (model.3) and state_dict keys are unchanged.
            nn.Unflatten(1, (256, 4, 4)),

            # Upsampling 1: 4x4 -> 8x8
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            # Upsampling 2: 8x8 -> 16x16
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            # Output layer: 16x16x3 (stride 1 keeps the spatial size)
            nn.ConvTranspose2d(64, 3, kernel_size=5, stride=1, padding=2, bias=False),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        """Generate images.

        Args:
            noise: (batch_size, latent_dim) float tensor.
            labels: (batch_size,) integer class indices.

        Returns:
            (batch_size, 3, 16, 16) tensor with values in [-1, 1].
        """
        label_embed = self.label_embedding(labels)  # (batch_size, embed_dim)
        x = torch.cat([noise, label_embed], dim=1)  # concatenate along features
        return self.model(x)


# Helper module to reshape tensors
class View(nn.Module):
    """Parameter-free layer that reshapes its input with ``Tensor.view``.

    This lets a reshape sit inside an ``nn.Sequential``. ``shape`` is passed
    unchanged to ``view``, so it may contain a single ``-1`` (for example
    ``(-1, C, H, W)``) to keep the batch dimension flexible.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        target_shape = self.shape
        return x.view(target_shape)


# Example usage: smoke-test the generator's output shape (CPU tensors, no .to(DEVICE)).
# NOTE: `gen` created here is the same instance the training cell later hands to
# the Adam optimizer and PixelArtGAN, so this cell must run before training.
# A freshly built module is in training mode, so this forward pass also updates
# the BatchNorm running statistics of `gen` with one batch of random inputs.
latent_dim = 100
num_classes = 5
batch_size = 8
gen = Generator(latent_dim, num_classes)
z = torch.randn(batch_size, latent_dim)
labels = torch.randint(0, num_classes, (batch_size,))  # random class labels
fake_imgs = gen(z, labels)
print(fake_imgs.shape)  # should be [8, 3, 16, 16]

# Discriminator Model

In [None]:

class Discriminator(nn.Module):
    """Conditional discriminator scoring 16x16 RGB images as real or fake.

    The class label is embedded, broadcast over all 16x16 pixels and stacked
    onto the image channels. Three stride-2 convolutions follow
    (16 -> 8 -> 4 -> 2), then a linear head with a Sigmoid gives one
    probability per sample.

    Args:
        num_classes: number of rows in the label embedding table.
        embed_dim: size of the label embedding (extra input channels).
    """

    def __init__(self, num_classes=5, embed_dim=10):
        super().__init__()
        self.num_classes = num_classes
        self.label_embedding = nn.Embedding(num_classes, embed_dim)

        # (in_channels, out_channels) per downsampling stage. Each stage is
        # Conv -> LeakyReLU -> Dropout, so layer indices match a flat listing.
        stage_channels = [(3 + embed_dim, 64), (64, 128), (128, 256)]
        layers = []
        for c_in, c_out in stage_channels:
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=5, stride=2, padding=2),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.3),
            ]
        # 2x2x256 feature map -> single probability
        layers += [nn.Flatten(), nn.Linear(256 * 2 * 2, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x, labels):
        """Score images conditioned on their labels.

        Args:
            x: (batch_size, 3, 16, 16) images.
            labels: (batch_size,) integer class indices.

        Returns:
            (batch_size, 1) probabilities that each image is real.
        """
        n = x.size(0)
        # (n, embed_dim) -> (n, embed_dim, 1, 1) -> broadcast to (n, embed_dim, 16, 16)
        label_map = self.label_embedding(labels)[:, :, None, None].expand(n, -1, 16, 16)
        stacked = torch.cat([x, label_map], dim=1)
        return self.model(stacked)

# Example usage: smoke-test the discriminator's output shape on random CPU inputs.
# NOTE: `disc` created here is the same instance the training cell later hands to
# the Adam optimizer and PixelArtGAN, so this cell must run before training.
num_classes = 5
batch_size = 8
disc = Discriminator(num_classes)
x = torch.randn(batch_size, 3, 16, 16)
labels = torch.randint(0, num_classes, (batch_size,))
output = disc(x, labels)
print(output.shape)  # [8, 1]


# Loss Functions

In [None]:
# Binary Cross Entropy Loss shared by both players. nn.BCELoss expects
# probabilities in [0, 1], which the Discriminator's final Sigmoid provides.
bce_loss = nn.BCELoss()


def discriminator_loss(real_output, fake_output):
    """Return the discriminator's total BCE loss.

    Predictions on real images are pushed toward 1 and predictions on fake
    images toward 0. The two mean-reduced BCE terms are summed.

    Args:
        real_output: discriminator probabilities for real images.
        fake_output: discriminator probabilities for generated images.
    """
    loss_on_real = bce_loss(real_output, torch.ones_like(real_output))
    loss_on_fake = bce_loss(fake_output, torch.zeros_like(fake_output))
    return loss_on_real + loss_on_fake


def generator_loss(fake_output):
    """Return the generator's BCE loss (non-saturating form).

    The generator is rewarded when the discriminator labels its images as
    real, so the target for every prediction is 1.

    Args:
        fake_output: discriminator probabilities for generated images.
    """
    return bce_loss(fake_output, torch.ones_like(fake_output))


# Dataset Class

In [None]:
class PixelArtDataset(Dataset):
    """16x16 RGB pixel-art sprites with optional class labels.

    Images are validated, scaled from uint8 [0, 255] to float32 [-1, 1] (to
    match the generator's Tanh output) and stored as an (N, 3, 16, 16) tensor.

    Args:
        images_path: path to a .npy file of uint8 images shaped (N, 16, 16, 3).
        labels_path: optional path to a .npy label file. It may hold either
            per-class rows shaped (N, num_classes), e.g. one-hot, which are
            reduced with argmax, or integer class indices shaped (N,). When
            None, every sample gets the dummy label 0.

    Raises:
        ValueError: images are not (N, 16, 16, 3), or the label count differs
            from the image count.
        TypeError: images are not uint8.
    """

    def __init__(self, images_path, labels_path=None):
        images = np.load(images_path)
        # Validate images
        if images.shape[1:] != (16, 16, 3):
            raise ValueError(f"Images must be 16x16x3, but are {images.shape}")
        if images.dtype != np.uint8:
            raise TypeError(f"Images must be uint8, but are {images.dtype}")

        # Labels are optional; previously np.load(None) crashed before this
        # branch could ever produce the documented dummy-label behaviour.
        self.labels = None
        if labels_path is not None:
            raw_labels = np.load(labels_path)
            if raw_labels.ndim == 1:
                class_labels = raw_labels.astype(np.int64)  # already class indices
            else:
                class_labels = np.argmax(raw_labels, axis=1)  # one row per sample -> index
            if len(class_labels) != len(images):
                raise ValueError(
                    f"Got {len(class_labels)} labels for {len(images)} images"
                )
            self.labels = class_labels

        # Normalize images to [-1, 1], convert to float32, then NHWC -> NCHW
        normalized = (images.astype(np.float32) - 127.5) / 127.5
        self.images = torch.from_numpy(normalized).permute(0, 3, 1, 2)

        print(f"Dataset loaded: {len(self.images)} images")
        print(f"Image shape: {self.images.shape}")
        # min()/max() raise on an empty tensor, so only report the range when there is data
        if len(self.images) > 0:
            print(f"Image range: [{self.images.min():.3f}, {self.images.max():.3f}]")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)``; label is 0 when no labels were loaded."""
        if self.labels is not None:
            return self.images[idx], self.labels[idx]
        return self.images[idx], 0  # return dummy label for compatibility

# Usage example: build the real training data pipeline.
# `dataloader` is the loader passed to train_gan in the training cell; its batch
# size comes from `batch_size` here (not from the BATCH_SIZE hyperparameter).
batch_size = 32
dataset = PixelArtDataset(sprites_path, sprite_label_path)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [None]:
class PixelArtGAN:
    """Conditional GAN wrapper holding both networks, their optimizers and a monitor seed.

    Args:
        generator: module called as ``generator(noise, labels)``.
        discriminator: module called as ``discriminator(images, labels)``
            returning probabilities.
        generator_optimizer: optimizer over the generator's parameters.
        discriminator_optimizer: optimizer over the discriminator's parameters.
        latent_dim: size of the noise vectors fed to the generator.
        num_classes: number of classes; one monitoring sample is drawn per class.
        writer: optional TensorBoard SummaryWriter used for sample images.
    """

    def __init__(self, generator, discriminator, generator_optimizer, discriminator_optimizer, latent_dim=100, num_classes=5, writer=None):
        # Module.to moves parameters in place, so optimizers built beforehand stay valid.
        self.generator = generator.to(DEVICE)
        self.discriminator = discriminator.to(DEVICE)
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.discriminator_optimizer = discriminator_optimizer
        self.generator_optimizer = generator_optimizer
        self.writer = writer

        # Fixed noise for monitoring: one latent vector per class, reused every epoch
        self.seed = torch.randn(num_classes, latent_dim, device=DEVICE)

        # Fixed labels: class indices 0..num_classes-1 (one sample per class)
        self.seed_labels = torch.arange(num_classes, device=DEVICE)

    def train_step(self, real_images, labels):
        """Run one generator update followed by one discriminator update.

        Args:
            real_images: batch of real images (moved to DEVICE here).
            labels: matching class indices (moved to DEVICE here).

        Returns:
            ``(g_loss, d_loss)`` as Python floats.
        """
        batch_size = real_images.size(0)
        real_images = real_images.to(DEVICE)
        labels = labels.to(DEVICE)

        # Generate noise
        noise = torch.randn(batch_size, self.latent_dim, device=DEVICE)

        # Train Generator. g_loss.backward() also writes gradients into the
        # discriminator's parameters; they are cleared by the zero_grad below.
        self.generator_optimizer.zero_grad()
        generated_images = self.generator(noise, labels)
        fake_output = self.discriminator(generated_images, labels)
        g_loss = generator_loss(fake_output)
        g_loss.backward()
        self.generator_optimizer.step()

        # Train Discriminator on the images produced *before* the generator step.
        # detach() keeps this update from backpropagating into the generator.
        self.discriminator_optimizer.zero_grad()
        real_output = self.discriminator(real_images, labels)
        fake_output_detached = self.discriminator(generated_images.detach(), labels)
        d_loss = discriminator_loss(real_output, fake_output_detached)
        d_loss.backward()
        self.discriminator_optimizer.step()

        return g_loss.item(), d_loss.item()

    def generate_and_show_images(self, epoch):
        """Generate one image per class from the fixed seed and log them to TensorBoard.

        The generator is switched to eval mode for sampling and then restored to
        whatever mode it was in before, even if generation raises.

        Args:
            epoch: global step recorded with the TensorBoard image.

        Returns:
            CPU tensor of generated images rescaled to [0, 1].
        """
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                preds = self.generator(self.seed, self.seed_labels).cpu()  # generated images
        finally:
            self.generator.train(was_training)

        # Rescale from [-1,1] -> [0,1]
        preds = torch.clamp((preds + 1) / 2, 0, 1)

        if self.writer is not None:
            # One row holding one image per class
            grid = torchvision.utils.make_grid(preds, nrow=self.num_classes)
            self.writer.add_image("Generated Images", grid, epoch)

        return preds

In [None]:
def save_checkpoint(generator, discriminator, g_optimizer, d_optimizer, epoch, path):
    """Save both models, both optimizers and the epoch number to ``path``.

    The parent directory of ``path`` is created if it does not exist, so
    callers need not prepare it first.

    Args:
        generator: generator module (its state_dict is saved).
        discriminator: discriminator module (its state_dict is saved).
        g_optimizer: generator optimizer (its state_dict is saved).
        d_optimizer: discriminator optimizer (its state_dict is saved).
        epoch: epoch number stored under the ``"epoch"`` key.
        path: destination file for ``torch.save``.
    """
    parent_dir = os.path.dirname(path)
    if parent_dir:  # "" means the current directory, which always exists
        os.makedirs(parent_dir, exist_ok=True)

    checkpoint = {
        "generator": generator.state_dict(),
        "discriminator": discriminator.state_dict(),
        "g_optimizer": g_optimizer.state_dict(),
        "d_optimizer": d_optimizer.state_dict(),
        "epoch": epoch
    }
    torch.save(checkpoint, path)
    print(f"✅ Checkpoint saved at {path}")


# Training Function

In [None]:
def train_gan(
    gan,
    dataloader,
    epochs,
    save_interval=10,
    checkpoint_dir="checkpoints",
    writer=None
):
    """
    Train a PixelArtGAN for a fixed number of epochs.

    Each batch runs one ``gan.train_step``. Per-epoch mean losses are printed,
    optionally logged to TensorBoard, and checkpoints are saved periodically.
    The optimizers come from ``gan`` itself.

    Args:
        gan: PixelArtGAN instance (supplies models and optimizers).
        dataloader: iterable of ``(images, labels)`` batches.
        epochs: int, number of passes over ``dataloader``.
        save_interval: int, save a checkpoint every N epochs; a value <= 0
            disables checkpointing.
        checkpoint_dir: str, directory for ``epoch_<N>.pth`` files (created if missing).
        writer: optional TensorBoard SummaryWriter for the loss scalars. When
            given, sample images are also generated via
            ``gan.generate_and_show_images``, which logs to ``gan.writer``.

    Returns:
        ``(gen_losses, disc_losses)``: lists with one mean loss per epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    gen_losses, disc_losses = [], []
    print("🚀 Starting training...")

    for epoch in range(1, epochs + 1):  # 1-based epoch numbers for logs and files
        epoch_g_loss, epoch_d_loss = [], []

        loop = tqdm(dataloader, desc=f"Epoch {epoch}/{epochs}")
        for images, labels in loop:
            g_loss, d_loss = gan.train_step(images, labels)
            epoch_g_loss.append(g_loss)
            epoch_d_loss.append(d_loss)
            loop.set_postfix({'g_loss': g_loss, 'd_loss': d_loss})

        # Fail with a clear message instead of a ZeroDivisionError below
        if not epoch_g_loss:
            raise ValueError("dataloader yielded no batches; cannot compute epoch losses")

        avg_g_loss = sum(epoch_g_loss) / len(epoch_g_loss)
        avg_d_loss = sum(epoch_d_loss) / len(epoch_d_loss)
        gen_losses.append(avg_g_loss)
        disc_losses.append(avg_d_loss)

        # Log losses and sample images to TensorBoard
        if writer is not None:
            writer.add_scalar("Generator Loss", avg_g_loss, epoch)
            writer.add_scalar("Discriminator Loss", avg_d_loss, epoch)
            gan.generate_and_show_images(epoch)

        print(f"Epoch {epoch}/{epochs} - Gen Loss: {avg_g_loss:.4f}, Disc Loss: {avg_d_loss:.4f}")

        # Save checkpoints periodically (guard avoids modulo-by-zero when disabled)
        if save_interval > 0 and epoch % save_interval == 0:
            save_path = os.path.join(checkpoint_dir, f"epoch_{epoch}.pth")
            save_checkpoint(
                gan.generator,
                gan.discriminator,
                gan.generator_optimizer,
                gan.discriminator_optimizer,
                epoch,
                save_path
            )

    return gen_losses, disc_losses


In [None]:
# GAN hyperparameters
LATENT_DIM = 100  # noise size passed to PixelArtGAN; should match the Generator's latent_dim
LEARNING_RATE = 0.0002  # Adam learning rate for both networks
BETA_1 = 0.5  # Adam beta1 (first-moment decay)
EPOCHS = 200  # number of passes over the dataloader
# NOTE(review): BATCH_SIZE is not referenced by the DataLoader built earlier,
# which uses `batch_size = 32`; keep the two in sync or wire this one in.
BATCH_SIZE = 32


In [None]:
# TensorBoard writer: one log directory per run, keyed by start time
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
run_name = f"pixel_art_gan_{current_time}"

# Set the writer with a custom event name
writer = SummaryWriter(log_dir=f"runs/{run_name}")

# Optimizers: use the BETA_1 hyperparameter instead of a duplicated literal
g_optimizer = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(BETA_1, 0.999))
d_optimizer = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(BETA_1, 0.999))

# Initialize GAN wrapper (class count taken from the generator it wraps)
gan = PixelArtGAN(
    generator=gen,
    discriminator=disc,
    generator_optimizer=g_optimizer,
    discriminator_optimizer=d_optimizer,
    latent_dim=LATENT_DIM,
    num_classes=gen.num_classes,
    writer=writer,
)

# Train GAN; close the writer even if training fails or is interrupted so
# buffered TensorBoard events are flushed to disk.
try:
    gen_losses, disc_losses = train_gan(
        gan=gan,
        dataloader=dataloader,
        epochs=EPOCHS,
        save_interval=10,
        checkpoint_dir=f"checkpoints/GAN_{current_time}",
        writer=writer,
    )
finally:
    writer.close()