In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, random_split

import numpy as np
import matplotlib.pyplot as plt
import os

from vgg import VGG11
from gan import Generator, Discriminator

In [None]:
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)

latent_size = 100     # length of the noise vector fed to the generator
image_channels = 1    # MNIST images are single-channel (grayscale)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (weight init, GAN noise, DataLoader shuffling) and the
# dataset split below so that Restart & Run All reproduces the same run.
SEED = 0
torch.manual_seed(SEED)

GAN_FRACTION = 0.8    # share of the MNIST training set reserved for GAN training
BATCH_SIZE = 256

# Resize to 32x32 (NOTE(review): presumably the input size expected by VGG11 and
# the GAN modules — confirm against vgg.py / gan.py) and map pixels to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),                        # pixel values -> [0, 1]
    transforms.Normalize(mean=[0.5], std=[0.5]),  # [0, 1] -> [-1, 1]
])

# Split the MNIST training set into a GAN-training subset and a disjoint subset
# used to train cnn_real; the fixed generator makes the split reproducible.
train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
gan_data_size = int(GAN_FRACTION * len(train_data))
real_data_size = len(train_data) - gan_data_size
gan_data, real_data = random_split(
    train_data,
    [gan_data_size, real_data_size],
    generator=torch.Generator().manual_seed(SEED),
)

gan_loader = DataLoader(gan_data, batch_size=BATCH_SIZE, shuffle=True)
real_loader = DataLoader(real_data, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
def write_log(epoch, num_epochs, stats, log_dir="logs"):
    # Unpack the dictionary
    g_loss = stats['g_loss']
    d_loss = stats['d_loss']
    synthetic_loss = stats['synthetic_loss']
    real_loss = stats['real_loss']
    g_losses = stats['g_losses']
    d_losses = stats['d_losses']
    synthetic_losses = stats['synthetic_losses']
    real_losses = stats['real_losses']

    # Append the losses to the corresponding lists
    g_losses.append(g_loss)
    d_losses.append(d_loss)
    synthetic_losses.append(synthetic_loss)
    real_losses.append(real_loss)

    log_message = f"Epoch: {epoch+1}/{num_epochs}\n"
    log_message += f"Generator loss: {g_loss:.4f}, Discriminator loss: {d_loss:.4f}\n"
    log_message += f"Synthetic CNN loss: {synthetic_loss:.4f}, Real CNN loss: {real_loss:.4f}\n"
    print(log_message)
    
    with open(os.path.join(log_dir, "training_logs.txt"), "a") as log_file:
        log_file.write(log_message)

    if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
        # Save the loss curves
        fig, axes = plt.subplots(2, 2, figsize=(10, 10))

        axes[0, 0].plot(g_losses)
        axes[0, 0].set_title("Generator Loss")
        axes[0, 0].set_xlabel("Epoch")
        axes[0, 0].set_ylabel("Loss")

        axes[0, 1].plot(d_losses)
        axes[0, 1].set_title("Discriminator Loss")
        axes[0, 1].set_xlabel("Epoch")
        axes[0, 1].set_ylabel("Loss")

        axes[1, 0].plot(synthetic_losses)
        axes[1, 0].set_title("Synthetic CNN Loss")
        axes[1, 0].set_xlabel("Epoch")
        axes[1, 0].set_ylabel("Loss")

        axes[1, 1].plot(real_losses)
        axes[1, 1].set_title("Real CNN Loss")
        axes[1, 1].set_xlabel("Epoch")
        axes[1, 1].set_ylabel("Loss")

        plt.tight_layout()
        plt.savefig(os.path.join(log_dir, f"loss_curves_epoch_{epoch+1}.png"), dpi=300)
        plt.close(fig)

In [None]:
def _gan_step(generator, discriminator, g_optimizer, d_optimizer, criterion_DIS, real_images):
    """Run one discriminator update then one generator update on a batch of real images.

    Returns ``(d_loss, g_loss, fake_images)``: the two losses as Python floats and
    the generator output of this step, detached from the graph.
    """
    batch_size = real_images.size(0)
    # (batch, 1) targets; assumes the discriminator emits one probability per image
    # in that shape (BCELoss needs values in [0, 1]) — TODO confirm in gan.py.
    real_targets = torch.ones(batch_size, 1, device=device)
    fake_targets = torch.zeros(batch_size, 1, device=device)

    # Discriminator: real -> 1, fake -> 0. detach() keeps this loss out of the generator.
    real_loss = criterion_DIS(discriminator(real_images), real_targets)
    noise = torch.randn(batch_size, latent_size, device=device)
    fake_images = generator(noise)
    fake_loss = criterion_DIS(discriminator(fake_images.detach()), fake_targets)
    d_loss = real_loss + fake_loss
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    # Generator: push the freshly updated discriminator to label the fakes as real.
    g_loss = criterion_DIS(discriminator(fake_images), real_targets)
    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()

    return d_loss.item(), g_loss.item(), fake_images.detach()


def _cnn_step(model, optimizer, criterion, inputs, targets):
    """One supervised update of ``model`` on ``(inputs, targets)``; returns the loss as a float."""
    loss = criterion(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


def _weight_similarity(model, reference):
    """Sum of cosine similarities between corresponding parameters with ``dim() > 1``.

    Only ``model`` receives gradients: ``reference``'s parameters are detached, so
    the backward pass neither touches nor saves the other network's weights (which
    would break once either optimizer updates them in place). Returns None when no
    multi-dimensional parameter pairs exist.
    """
    terms = [
        torch.nn.functional.cosine_similarity(p.flatten(), q.detach().flatten(), dim=0)
        for p, q in zip(model.parameters(), reference.parameters())
        if p.dim() > 1  # skip 1-D parameters such as biases, as before
    ]
    return torch.stack(terms).sum() if terms else None


def _distance_reg_step(model, optimizer, reference, coeff):
    """Take one optimizer step on ``coeff * _weight_similarity(model, reference)`` for ``model`` only."""
    similarity = _weight_similarity(model, reference)
    if similarity is None:
        return
    optimizer.zero_grad()  # drop the classification gradient, which was already applied
    (coeff * similarity).backward()
    optimizer.step()


def _save_sample_images(real_image, fake_image, epoch):
    """Save a side-by-side figure of one real and one generated image into ``log_dir``."""
    fig, (ax_real, ax_fake) = plt.subplots(1, 2, figsize=(10, 5))
    for ax, image, title in ((ax_real, real_image, "Real Image"),
                             (ax_fake, fake_image, "Fake Image")):
        ax.imshow(image.detach().cpu().numpy().squeeze(), cmap='gray')
        ax.set_title(title)
        ax.axis("off")
    fig.savefig(os.path.join(log_dir, f"generated_image_epoch_{epoch+1}.png"), dpi=300)
    plt.close(fig)


def train_gan_cnn_with_distance_reg(generator, discriminator, cnn_synthetic, cnn_real, 
                                    gan_loader, real_loader, 
                                    g_optimizer, d_optimizer, cnn_synthetic_optimizer, cnn_real_optimizer, 
                                    criterion_CNN, criterion_DIS,
                                    distance_reg_weight, num_epochs):
    """Jointly train a GAN and two classifiers, coupling the classifiers' weights.

    Per batch pair drawn from ``gan_loader`` and ``real_loader``:
      1. the GAN (discriminator, then generator) trains on the ``gan_loader`` batch;
      2. ``cnn_synthetic`` trains on fresh generator samples, ``cnn_real`` on the
         ``real_loader`` batch, both against the real batch's labels;
      3. a cosine-similarity regularizer on the two CNNs' multi-dimensional weights
         is applied as separate optimizer steps: ``cnn_synthetic`` minimizes
         ``-distance_reg_weight * similarity`` (pulled toward ``cnn_real``) and
         ``cnn_real`` minimizes ``+distance_reg_weight * similarity`` (pushed away).
         NOTE(review): these are the signs the original code used — confirm the
         push-away on cnn_real is intended.

    After each epoch the epoch-mean losses are passed to ``write_log`` and a
    real/fake sample image is saved to ``log_dir``.

    Raises:
        ValueError: if the two loaders yield no batch pairs.
    """
    stats = {'g_losses': [], 'd_losses': [], 'synthetic_losses': [], 'real_losses': []}
    models = (generator, discriminator, cnn_synthetic, cnn_real)
    # Module.to() moves parameters in place, so the optimizers stay valid.
    for model in models:
        model.to(device)

    for epoch in range(num_epochs):
        for model in models:
            model.train()

        epoch_sums = dict.fromkeys(('g_loss', 'd_loss', 'synthetic_loss', 'real_loss'), 0.0)
        num_batches = 0
        real_batch = fake_batch = None

        # NOTE: zip() stops at the shorter loader; with the notebook's 80/20 split only
        # about a quarter of gan_loader is visited per epoch.
        for (gan_batch, _), (real_batch, real_labels) in zip(gan_loader, real_loader):
            gan_batch = gan_batch.to(device)
            real_batch = real_batch.to(device)
            real_labels = real_labels.to(device)

            # The GAN learns from its own split, not from cnn_real's data.
            d_loss, g_loss, fake_batch = _gan_step(
                generator, discriminator, g_optimizer, d_optimizer, criterion_DIS, gan_batch)

            # One synthetic sample per real label. no_grad: the CNN loss must not
            # backpropagate into the generator.
            # NOTE(review): the generator is called with noise only, so these labels
            # are not tied to the generated digits — confirm this is intended.
            with torch.no_grad():
                noise = torch.randn(real_batch.size(0), latent_size, device=device)
                synthetic_batch = generator(noise)
            synthetic_loss = _cnn_step(cnn_synthetic, cnn_synthetic_optimizer, criterion_CNN,
                                       synthetic_batch, real_labels)
            real_loss = _cnn_step(cnn_real, cnn_real_optimizer, criterion_CNN,
                                  real_batch, real_labels)

            _distance_reg_step(cnn_synthetic, cnn_synthetic_optimizer, cnn_real, -distance_reg_weight)
            _distance_reg_step(cnn_real, cnn_real_optimizer, cnn_synthetic, distance_reg_weight)

            epoch_sums['g_loss'] += g_loss
            epoch_sums['d_loss'] += d_loss
            epoch_sums['synthetic_loss'] += synthetic_loss
            epoch_sums['real_loss'] += real_loss
            num_batches += 1

        if num_batches == 0:
            raise ValueError("gan_loader and real_loader produced no batch pairs")

        stats.update({key: total / num_batches for key, total in epoch_sums.items()})
        write_log(epoch, num_epochs, stats, log_dir)

        _save_sample_images(real_batch[0], fake_batch[0], epoch)

In [None]:
# Networks: a GAN (the generator is driven by noise only during training) and two
# VGG11 classifiers, one trained on real MNIST and one on generated samples.
# Construction order is kept fixed because weight initialisation consumes RNG.
generator = Generator(latent_size, image_channels)
discriminator = Discriminator(image_channels)
cnn_real = VGG11()
cnn_synthetic = VGG11()

# Losses: cross-entropy for digit classification; binary cross-entropy for the
# real/fake discriminator (BCELoss expects probabilities in [0, 1]).
criterion_CNN = nn.CrossEntropyLoss()
criterion_DIS = nn.BCELoss()

# One Adam optimizer per network.
g_optimizer, d_optimizer = (
    optim.Adam(generator.parameters(), lr=1e-4),
    optim.Adam(discriminator.parameters(), lr=1e-3),
)
cnn_real_optimizer, cnn_synthetic_optimizer = (
    optim.Adam(classifier.parameters(), lr=1e-4) for classifier in (cnn_real, cnn_synthetic)
)

In [None]:
# Train for 20 epochs with unit weight on the CNN weight-similarity regulariser.
train_gan_cnn_with_distance_reg(
    generator=generator,
    discriminator=discriminator,
    cnn_synthetic=cnn_synthetic,
    cnn_real=cnn_real,
    gan_loader=gan_loader,
    real_loader=real_loader,
    g_optimizer=g_optimizer,
    d_optimizer=d_optimizer,
    cnn_synthetic_optimizer=cnn_synthetic_optimizer,
    cnn_real_optimizer=cnn_real_optimizer,
    criterion_CNN=criterion_CNN,
    criterion_DIS=criterion_DIS,
    distance_reg_weight=1,
    num_epochs=20,
)