<strong> Import all libraries that will be used </strong>

In [None]:
import os
import time
import random
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torchvision import transforms
from torchmetrics.image.fid import FrechetInceptionDistance

from tqdm.notebook import tqdm

from models import AttentionUNetDiscriminator, AttentionUNetGenerator
from utils import ImageDataset

<strong> Initializing the paintings dataset </strong>

In [None]:
# Building the dataset
image_dir = '../data/anime/images_2'
image_size = 64

# Quick sanity check: show the first directory entry and whether it is a regular file.
f = os.listdir(image_dir)[0]
first_entry_path = os.path.join(image_dir, f)
print(f)
print(os.path.isfile(first_entry_path))

# Preprocessing applied to every image before batching:
#   1. resize to a fixed square of side `image_size`,
#   2. convert to a CHW float tensor,
#   3. map each channel via (x - 0.5) / 0.5 so generator targets live around [-1, 1].
transform_steps = [
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(transform_steps)

# `limit` presumably caps how many images ImageDataset loads -- TODO confirm in utils.ImageDataset
training_dataset = ImageDataset(image_dir, transform, limit=10000)
print(f"Dataset contains {len(training_dataset)} images")

<strong> Sampling an element from the dataset and plotting it </strong>

In [None]:
# Sample one valid index uniformly from [0, n).
# (random.randint(0, n) is inclusive of n and could index one past the end.)
n = len(training_dataset)
integer = random.randrange(n)

# Sampled image, de-normalized from the Normalize(0.5, 0.5) range back to [0, 1]
image = training_dataset[integer].numpy() * 0.5 + 0.5

# Plot the image (matplotlib expects HWC, the tensor is CHW)
plt.figure(figsize=(3, 2))
plt.imshow(np.transpose(image, (1, 2, 0)))
plt.axis('off')  # Hide axes
plt.title('Sample Painting')
plt.show()

<strong> Initializing the DataLoader </strong>

In [None]:
batch_size: int = 64  # mini-batch size for this exploratory section; the training cell sets its own

<strong> Initializing the parameters of the model </strong>

In [None]:
# Model's parameters (these instances are only used for the preview below;
# the training cell builds its own generator/discriminator)
latent_dim = 100
channels_out = 3
channels_in = 3

# Initializing the models
G = AttentionUNetGenerator(latent_dim, channels_out)
D = AttentionUNetDiscriminator(channels_in)  # discriminator consumes images with `channels_in` channels

<strong> Sampling random noise to plot a fake image generated by the (untrained) Generator </strong>

In [None]:
# Latent space dimension; must match the latent_dim G was constructed with
latent_dim = 100

# The generator is called with spatial noise of shape (batch, latent_dim, H, W)
noise = torch.randn(16, latent_dim, image_size, image_size)

# Inference only: no_grad avoids building an autograd graph for the 16 forward passes.
with torch.no_grad():
    # Expected output (16, channels_out, image_size, image_size) -- TODO confirm against models.py
    fake_images = G(noise)

# De-normalize the first generated image from [-1, 1] to [0, 1]
image_generated = fake_images[0].detach().cpu().numpy()
image_generated = image_generated * 0.5 + 0.5

# Plot the image (CHW -> HWC for matplotlib)
plt.imshow(np.transpose(image_generated, (1, 2, 0)))
plt.axis('off')  # Hide axes
plt.title('Sample Generated Image')
plt.show()

<strong> Let's build the training loop </strong>

In [None]:
import torch.nn.functional as F
from torchvision.models import resnet18, ResNet18_Weights

def diversity_loss(fake_images, feature_extractor):
    """Mean pairwise cosine similarity between feature vectors of a generated batch.

    Lower values mean the samples are more dissimilar, so adding this term to the
    generator loss discourages mode collapse.

    Args:
        fake_images: batch of generated images passed straight to `feature_extractor`.
        feature_extractor: callable mapping the batch to per-sample features of shape
            (B, ...). Outputs with more than two dims are flattened per sample.
            NOTE(review): a torchvision ResNet used as-is returns class logits, and it was
            trained on ImageNet-normalized inputs, not [-1, 1] images -- confirm that is intended.

    Returns:
        Scalar tensor: the mean cosine similarity over all ordered pairs (i, j) with i != j.
        1.0 means a fully collapsed batch. For batches with fewer than two samples there
        are no pairs, and 0 is returned.
    """
    features = feature_extractor(fake_images)
    features = features.reshape(features.size(0), -1)  # (B, feature_dim)

    # Unit-normalize so the dot product below is cosine similarity
    features = F.normalize(features, p=2, dim=1)

    batch_size = features.size(0)
    if batch_size < 2:
        return features.new_zeros(())

    similarity_matrix = features @ features.T  # (B, B)

    # Average only the off-diagonal entries; self-similarity is always 1 and carries no signal.
    off_diagonal = ~torch.eye(batch_size, dtype=torch.bool, device=features.device)
    return similarity_matrix[off_diagonal].mean()

In [None]:
def gradient_penalty(D, real, fake, device=None):
    """WGAN-GP gradient penalty: mean((||grad_x D(x_hat)||_2 - 1)^2).

    x_hat is a random per-sample interpolation between `real` and `fake`.

    Args:
        D: critic / discriminator (any callable mapping a batch to scores).
        real: batch of real samples, shape (B, ...).
        fake: batch of generated samples, same shape as `real`.
        device: device for the interpolation weights. Defaults to `real.device`;
            passing it explicitly (as the training loop does) still works.

    Returns:
        Scalar tensor with the penalty. It is built with create_graph=True, so it can be
        back-propagated into D's parameters.
    """
    if device is None:
        device = real.device

    # One interpolation weight per sample, broadcast over every remaining dimension
    alpha_shape = (real.size(0),) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)

    # Detach the endpoints so the penalty's graph reaches only D, never the generator
    interpolated = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)

    prob_interpolated = D(interpolated)

    # Gradients of the critic output w.r.t. the interpolated inputs
    gradients = torch.autograd.grad(outputs=prob_interpolated, inputs=interpolated,
                                    grad_outputs=torch.ones_like(prob_interpolated),
                                    create_graph=True, retain_graph=True)[0]

    # reshape rather than view: the gradient tensor is not guaranteed to be contiguous
    gradients = gradients.reshape(gradients.size(0), -1)
    return ((gradients.norm(p=2, dim=1) - 1) ** 2).mean()


In [None]:

# Anomaly detection slows every backward pass considerably; switch on only to hunt NaNs.
detect_anomaly = False
torch.autograd.set_detect_anomaly(detect_anomaly)

loss_options = ["w","bce"]
loss_option = loss_options[0]

batch_size = 128
# Reshuffle every epoch so the critic/generator do not see batches in a fixed order
dataloader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)

image_size = 64
experiment = 3

output_dir = f"generated_samples/wgan/experiment_{experiment}"
os.makedirs(output_dir, exist_ok=True)

# Parameters
input_channels = 3
channels_out = input_channels
n_classes = 2
k = 5  # number of critic updates per generator update
latent_dim = 128 # Dimension of latent space
epochs = 1000 # Number of epochs
lambda_gp = 10  # Gradient penalty weight
lambda_div = 5  # Diversity loss weight

d_lr = 1e-4 if loss_option == "w" else 2e-4  # Base learning rate
g_lr = 2e-4 if loss_option == "w" else 2e-4  # Base learning rate


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Feature extractor used only to score batch diversity. Freeze it: otherwise its weights
# accumulate .grad on every generator step (it is never optimized or zero_grad'ed).
feature_extractor = resnet18(weights=ResNet18_Weights.DEFAULT)
feature_extractor.eval().to(device)
feature_extractor.requires_grad_(False)

# Initialize models (Use Attention U-Net GAN models)
G_attention = AttentionUNetGenerator(latent_dim, channels_out).to(device)  # Latent space maps to input

if loss_option == "w":
    D_attention = AttentionUNetDiscriminator(input_channels, features=64, use_sigmoid=False).to(device)  # Wasserstein critic (raw scores)
elif loss_option == "bce":
    D_attention = AttentionUNetDiscriminator(input_channels, features=64).to(device)  # PatchGAN discriminator

# Loss function and optimizers
if loss_option == "bce":
    criterion = torch.nn.BCELoss()  # Binary Cross-Entropy for GANs

adam_betas = (0.0, 0.9) if loss_option == "w" else (0.5, 0.999)
optimizer_g = torch.optim.Adam(G_attention.parameters(), lr=g_lr, betas=adam_betas)
optimizer_d = torch.optim.Adam(D_attention.parameters(), lr=d_lr, betas=adam_betas)

# FID Evaluation setup
epoch_fid = 100  # Evaluate FID every `epoch_fid` epochs
epoch_sampling = 10 # Save generated samples every `epoch_sampling` epochs


def _to_fid_range(images):
    """Map images from the [-1, 1] range used in this notebook to [0, 1].

    FrechetInceptionDistance(normalize=True) expects float images in [0, 1]. It converts
    them with (x * 255) to uint8, so negative inputs would be silently corrupted.
    """
    return (images * 0.5 + 0.5).clamp(0.0, 1.0)


fid = FrechetInceptionDistance(feature=192, reset_real_features=False, normalize=True)
n_samples = 1000

# Real-image statistics are computed once; reset_real_features=False keeps them across resets
N = len(training_dataset)
n_real = min(n_samples, N)  # random.sample fails if asked for more items than exist
indices = random.sample(range(N), n_real)
real_images_eval = torch.stack([training_dataset[idx] for idx in indices])
for i in range(0, n_real, batch_size):
    real_images_chunk = real_images_eval[i:i+batch_size].to('cpu')
    fid.update(_to_fid_range(real_images_chunk), real=True)
del real_images_eval

# Track losses, FID values and the critic's Wasserstein estimate (per epoch)
FID_values = []
D_loss = []
G_loss = []
W_distance = []

# Training loop
for epoch in tqdm(range(1, epochs+1)):

    start_time = time.time()

    generator_count = 0
    generator_loss = 0

    critic_count = 0
    critic_loss = 0
    critic_gap = 0  # running sum of E[D(real)] - E[D(fake)] (W mode only)

    for batch in tqdm(dataloader, leave=False):
        ################################################################
        # 1) CRITIC (DISCRIMINATOR) UPDATES
        ################################################################

        real_images = batch.to(device)
        curr_batch_size = real_images.size(0)

        optimizer_d.zero_grad()

        # ----- 1(a): Discriminator loss on real
        if loss_option == "w":
            output_real = D_attention(real_images).view(-1)
            loss_d_real = -output_real.mean()
        else:  # BCE, one-sided smoothed labels shaped like the output (works for patch outputs too)
            output_real = D_attention(real_images).reshape(-1, 1)
            loss_d_real = criterion(output_real, torch.full_like(output_real, 0.9))

        # ----- 1(b): Discriminator loss on fake (no graph through G is needed here)
        with torch.no_grad():
            noise = torch.randn(curr_batch_size, latent_dim, image_size, image_size, device=device)
            fake_images = G_attention(noise)

        if loss_option == "w":
            output_fake = D_attention(fake_images).view(-1)
            loss_d_fake = output_fake.mean()
        else:  # BCE
            output_fake = D_attention(fake_images).reshape(-1, 1)
            loss_d_fake = criterion(output_fake, torch.full_like(output_fake, 0.1))

        # ----- 1(c): Gradient Penalty (scalar)
        gp = gradient_penalty(D_attention, real_images, fake_images, device=device)

        # Combine D losses
        loss_d = loss_d_real + loss_d_fake + lambda_gp * gp

        # Backprop and step
        loss_d.backward()
        optimizer_d.step()

        # Track critic loss
        critic_loss += loss_d.item()
        critic_count += 1
        if loss_option == "w":
            # -(loss_d_real + loss_d_fake) = E[D(real)] - E[D(fake)]: the Wasserstein estimate,
            # without the gradient-penalty term
            critic_gap += -(loss_d_real + loss_d_fake).item()

        ################################################################
        # 2) GENERATOR UPDATE (every k steps of critic)
        ################################################################

        if critic_count % k == 0:
            # Freeze D so the generator step does not accumulate grads in it
            D_attention.requires_grad_(False)

            noise = torch.randn(batch_size, latent_dim, image_size, image_size, device=device)
            fake_images = G_attention(noise)

            optimizer_g.zero_grad()

            if loss_option == "w":
                output_fake_for_gen = D_attention(fake_images).view(-1)
                loss_g = -output_fake_for_gen.mean()
            else:  # BCE: generator wants its samples labelled (smoothed) real
                output_fake_for_gen = D_attention(fake_images).reshape(-1, 1)
                loss_g = criterion(output_fake_for_gen, torch.full_like(output_fake_for_gen, 0.9))

            # Diversity penalty discourages near-identical samples within a batch
            diversity_pen = diversity_loss(fake_images, feature_extractor)
            loss_g = loss_g + lambda_div * diversity_pen

            loss_g.backward()
            optimizer_g.step()

            # Re-enable D grads for the next critic step
            D_attention.requires_grad_(True)

            generator_loss += loss_g.item()
            generator_count += 1


    # Evaluate FID every `epoch_fid` epochs
    if epoch % epoch_fid == 0:
        G_attention.eval()
        with torch.no_grad():
            # Generate noise chunk by chunk: one (n_samples, latent_dim, H, W) tensor is ~2 GB
            for i in range(0, n_samples, batch_size):
                chunk_size = min(batch_size, n_samples - i)
                noise_eval = torch.randn(chunk_size, latent_dim, image_size, image_size, device=device)
                fake_images_chunk = G_attention(noise_eval).to('cpu')
                fid.update(_to_fid_range(fake_images_chunk), real=False)

            fid_value = fid.compute().item()
        FID_values.append(fid_value)
        print(f'FID after epoch {epoch}: {fid_value}')

        torch.cuda.empty_cache()
        fid.reset()  # clears fake statistics only (reset_real_features=False)
        G_attention.train()

    if epoch % epoch_sampling == 0:
        G_attention.eval()

        with torch.no_grad():
            noise_eval = torch.randn(64, latent_dim, image_size, image_size, device=device)
            fake_images_eval = G_attention(noise_eval).to(torch.device('cpu'))

        # Save a grid of generated samples
        grid = vutils.make_grid(fake_images_eval, normalize=True, scale_each=True)
        sample_path = os.path.join(output_dir, f"epoch_{epoch:03d}.png")
        vutils.save_image(grid, sample_path)
        print(f"Sample images saved to {sample_path}")

        del fake_images_eval
        torch.cuda.empty_cache()
        G_attention.train()

    # Per-epoch averages; NaN when an epoch had no step of that kind (e.g. fewer than k batches)
    generator_loss = generator_loss / generator_count if generator_count else float('nan')
    critic_loss = critic_loss / critic_count if critic_count else float('nan')

    D_loss.append(critic_loss)
    G_loss.append(generator_loss)

    end_time = time.time()

    message = f'Epoch [{epoch}/{epochs}] | Loss D: {critic_loss} | Loss G: {generator_loss}'
    if loss_option == "w":
        wasserstein_estimate = critic_gap / critic_count if critic_count else float('nan')
        W_distance.append(wasserstein_estimate)
        message += f' | Wasserstein Distance: {wasserstein_estimate}'
    print(message)
    print(f'generator_count: {generator_count}, critic_count: {critic_count}')
    print(f'Epoch duration: {end_time - start_time:.2f}s')
    


In [None]:
model_dir = f"weights/wgan/experiment_{experiment}"
os.makedirs(model_dir, exist_ok=True)

# Persist the parameters (state_dicts, not whole modules) of both trained networks.
checkpoints = {"d_att.pth": D_attention, "g_att.pth": G_attention}
for filename, network in checkpoints.items():
    torch.save(network.state_dict(), f"{model_dir}/{filename}")

In [None]:
experiment = 3
latent_dim = 128
channels_out = 3
# Recompute the checkpoint directory so this cell also works on a fresh kernel.
model_dir = f"weights/wgan/experiment_{experiment}"
G_unet_test = AttentionUNetGenerator(latent_dim, channels_out)
# map_location='cpu': weights saved from a CUDA run still load on a CPU-only machine.
G_unet_test.load_state_dict(torch.load(f"{model_dir}/g_att.pth", map_location="cpu"))

<strong> Generating some examples using the trained generator </strong>

In [None]:
image_size = 64
noise_eval = torch.randn(64, latent_dim, image_size, image_size, device=torch.device('cpu'))
G_unet_test.eval()  # Set generator to eval mode for sampling

# Inference only: avoids building an autograd graph for 64 samples, and guarantees the
# tensors reaching `.numpy()` below do not require grad.
with torch.no_grad():
    fake_images_eval = G_unet_test(noise_eval).to(torch.device('cpu'))

# Arrange the samples into a single grid image (each image rescaled to [0, 1])
grid = vutils.make_grid(fake_images_eval, normalize=True, scale_each=True)
np_grid = grid.permute(1, 2, 0).numpy()  # CHW -> HWC for matplotlib

# Display the images
plt.figure(figsize=(8, 8))
plt.imshow(np_grid)
plt.axis('off')  # Turn off axis
plt.show()

<strong> Plotting Losses </strong>

In [None]:
# G_loss / D_loss hold one averaged value per training epoch; the x axis is the list index.
fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(range(len(G_loss)), G_loss, label='Generator Loss', color='darkorange', linestyle='-', linewidth=2)
ax.plot(range(len(D_loss)), D_loss, label='Discriminator Loss', color='royalblue', linestyle='--', linewidth=2)

# Grid, labels, title and legend
ax.grid(True, linestyle='--', alpha=0.6)
ax.set_xlabel('Time Steps', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('Loss Evolution of Generator and Discriminator', fontsize=14)
ax.legend(loc='upper right', fontsize=11)

plt.show()

<strong> Plotting FID values </strong>

In [None]:
# One FID value per evaluation (the training loop evaluates every `epoch_fid` epochs)
n = len(FID_values)
time_steps = list(range(n))

plt.figure(figsize=(10, 6))

# Plot FID values
plt.plot(time_steps, FID_values, label='Generator FID', color='darkorange', linestyle='-', linewidth=2)

# Add grid, labels, and title
plt.grid(True, linestyle='--', alpha=0.6)
plt.xlabel('Time Steps', fontsize=12)
plt.ylabel('FID', fontsize=12)
plt.title('FID Evolution through training', fontsize=14)

# Without a legend the 'Generator FID' label above is never shown
plt.legend(loc='upper right', fontsize=11)
plt.show()

In [None]:
# Persist the tracked training curves as plain-text files, one value per line.
plots_dir = f'plots/wgan/experiment_{experiment}'
os.makedirs(plots_dir, exist_ok=True)

series_to_save = {
    'g_loss': G_loss,          # generator loss per epoch
    'd_loss': D_loss,          # critic loss per epoch
    'fid_values': FID_values,  # FID per evaluation
}
for series_name, values in series_to_save.items():
    np.savetxt(f'{plots_dir}/{series_name}.txt', np.array(values))