In [None]:
import os
import io
import contextlib
import json

from VAE.VAE import VAE
from VAE.GAN import GAN
from tools import generate_config, mount_drive
from tools.dataset import imagenette_loaders

In [None]:
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torchsummary import summary

In [None]:
# Flip to False when running outside Google Colab (no Drive to mount).
is_colab = True
if is_colab:
    mount_drive()

In [None]:
# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

In [None]:
# Experiment configuration; keyword order is kept identical to the original call.
config = generate_config(
    # Experiment bookkeeping
    exp_number=1,
    name='basic_implementation',
    output_path='drive/MyDrive/latent_diffusion/2025_01_13/models/VAE',
    dataset='imagenette2-160',
    description='Initial implementation of VAE with GAN training',
    # Image / latent geometry
    image_resolution=160,
    latent_resolution=40,
    latent_channel=3,
    # Optimisation
    batch_size=64,
    learning_rate=1e-3,
    num_epochs=20,
)

In [None]:
# Directory for this experiment's artifacts, plus the path of the JSON record
# written after training.
# The f-strings use double quotes outside and single quotes inside: reusing the
# same quote character inside an f-string is only legal from Python 3.12 (PEP 701)
# and is a SyntaxError on older interpreters.
experiment_dir = f"{config['output_path']}/experiment_{config['experiment_number']}"
# exist_ok avoids the separate exists() check (and its check-then-create race).
os.makedirs(experiment_dir, exist_ok=True)

filename = f"{config['output_path']}/experiment_{config['experiment_number']}.json"

In [None]:
# Build the train / test loaders for the dataset named in `config`.
(train_loader, test_loader) = imagenette_loaders(config=config)

In [None]:
def kl_divergence(mu, logvar):
    """KL divergence between N(mu, exp(logvar)) and a standard normal.

    The closed-form term ``-0.5 * (1 + logvar - mu**2 - exp(logvar))`` is
    summed over dimension 1, then averaged over all remaining positions.

    Args:
        mu: mean tensor of the approximate posterior.
        logvar: log-variance tensor, same shape as ``mu``.

    Returns:
        Scalar tensor.
    """
    # NOTE(review): for a (B, C, H, W) latent this sums only over channels and
    # averages over batch *and* spatial positions — confirm that is intended.
    summed_terms = (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1)
    return (-0.5 * summed_terms).mean()

def reconstruction_loss(recon_x, x):
    """Mean-squared-error reconstruction loss.

    Args:
        recon_x: reconstructed batch.
        x: target batch; must have exactly the same shape as ``recon_x``.

    Returns:
        Scalar tensor: squared error averaged over every element.

    Raises:
        AssertionError: if the shapes differ (only when assertions are enabled).
    """
    shapes_match = recon_x.shape == x.shape
    assert shapes_match, f"Shape mismatch: recon_x shape {recon_x.shape}, x shape {x.shape}"
    return F.mse_loss(recon_x, x, reduction='mean')

In [None]:
def train_discriminator(model_gan, optimizer_discriminator, data):
    """Run one optimisation step of the discriminator on a single batch.

    The VAE is switched to eval mode and run under ``torch.no_grad()`` to
    produce reconstructions, which act as the "fake" samples; ``data`` is the
    "real" sample. The discriminator is trained with binary cross-entropy
    (real -> 1, fake -> 0).

    Args:
        model_gan: wrapper exposing ``model_vae``, ``model_discriminator`` and
            a ``discriminator`` callable.
        optimizer_discriminator: optimiser over the discriminator parameters.
        data: batch of real images, already on the training device.

    Returns:
        The scalar discriminator loss tensor (real BCE + fake BCE).
    """
    vae = model_gan.model_vae
    vae.eval()
    model_gan.model_discriminator.train()

    optimizer_discriminator.zero_grad()

    # Reconstructions are fixed inputs here: no gradient flows back into the VAE.
    with torch.no_grad():
        fake_batch, _, _ = vae(data)

    real_scores = model_gan.discriminator(data)
    fake_scores = model_gan.discriminator(fake_batch)

    # NOTE(review): F.binary_cross_entropy expects probabilities in [0, 1],
    # while train_vae applies binary_cross_entropy_with_logits to the `y`
    # returned by model_gan(...). Confirm that `discriminator` outputs
    # probabilities and the forward pass outputs logits; otherwise one of the
    # two losses is mis-specified.
    target_real = torch.ones_like(real_scores)
    target_fake = torch.zeros_like(fake_scores)
    loss = (F.binary_cross_entropy(real_scores, target_real)
            + F.binary_cross_entropy(fake_scores, target_fake))

    loss.backward()
    optimizer_discriminator.step()

    return loss

In [None]:
def train_vae(model_gan, optimizer_vae, data, kl_weight=1e-6, adversarial_weight=1.0):
    """Run one optimisation step of the VAE (generator side) on a single batch.

    The discriminator is put in eval mode and its parameters are frozen
    (``requires_grad = False``) for the duration of the step, so only the VAE
    receives gradients. The minimised loss is::

        recon_loss + adversarial_weight * adversarial_loss + kl_weight * kl

    Args:
        model_gan: wrapper whose forward returns ``(recon, mu, logvar, y)`` and
            which exposes ``model_vae`` and ``model_discriminator``.
        optimizer_vae: optimiser over the VAE parameters.
        data: batch of images, already on the training device.
        kl_weight: weight of the KL term (default 1e-6, the previous hard-coded value).
        adversarial_weight: weight of the adversarial term (default 1.0).

    Returns:
        The scalar total VAE loss tensor.
    """
    model_gan.model_vae.train()
    model_gan.model_discriminator.eval()

    optimizer_vae.zero_grad()

    discriminator_params = list(model_gan.model_discriminator.parameters())
    # Freeze the discriminator so backward() does not compute its gradients.
    for param in discriminator_params:
        param.requires_grad = False

    try:
        # presumably `y` is the discriminator's raw score (logits) for the
        # reconstructions — confirm against GAN.forward
        recon_data, mu, logvar, y = model_gan(data)

        kl = kl_divergence(mu, logvar)
        recon_loss = reconstruction_loss(recon_data, data)
        # Generator objective: make the discriminator label reconstructions as
        # real (target 1). No label smoothing is applied.
        adversarial_loss = F.binary_cross_entropy_with_logits(y, torch.ones_like(y))

        loss = recon_loss + adversarial_weight * adversarial_loss + kl_weight * kl

        loss.backward()
        optimizer_vae.step()
    finally:
        # Always unfreeze, so a failed step cannot leave the discriminator
        # frozen for the rest of the notebook session.
        for param in discriminator_params:
            param.requires_grad = True

    return loss

In [None]:
def train(model_gan, epoch, optimizer_vae, optimizer_discriminator, data_loader):
    """Train the VAE-GAN for one epoch.

    For every batch, one discriminator step is taken, followed by one VAE step.

    Args:
        model_gan: the GAN wrapper being trained.
        epoch: epoch number, used only for logging.
        optimizer_vae: optimiser over the VAE parameters.
        optimizer_discriminator: optimiser over the discriminator parameters.
        data_loader: iterable of image batches.

    Returns:
        Python float: the mean over batches of (VAE loss + discriminator loss).
        Returns 0.0 for an empty loader.
    """
    total_loss = 0.0
    num_batches = 0
    for data in data_loader:
        # assumes the loader yields image tensors only (no labels) — TODO confirm
        # against imagenette_loaders
        data = data.to(device)

        discriminator_loss = train_discriminator(model_gan, optimizer_discriminator, data)
        vae_loss = train_vae(model_gan, optimizer_vae, data)

        # .item() detaches to a Python float. Summing the loss tensors
        # themselves would keep extending an autograd graph across the epoch,
        # and return a tensor that json.dump cannot serialise.
        total_loss += vae_loss.item() + discriminator_loss.item()
        num_batches += 1

    # Each per-batch loss is already a batch mean, so average over batches,
    # not over individual samples.
    avg_loss = total_loss / max(num_batches, 1)

    print(f'====> Epoch: {epoch}')
    print(f'Average loss: {avg_loss:.5f}')
    return avg_loss

In [None]:
if is_train := True:
    model_gan = GAN(config)

    print('==================================')
    print(f'Training model {config["name"]}')
    print('==================================')

    # Capture the VAE architecture summary as text for the experiment record.
    # torchsummary's `device` argument defaults to "cuda" and only accepts
    # "cuda" or "cpu". Running it on CPU before moving the model therefore works
    # on CUDA, MPS and CPU-only machines.
    # (assumes GAN(config) builds its parameters on CPU — the nn.Module default)
    summary_str = io.StringIO()
    with contextlib.redirect_stdout(summary_str):
        summary(
            model_gan.model_vae,
            input_size=(3, config['image_resolution'], config['image_resolution']),
            device='cpu',
        )
    config['summary'] = summary_str.getvalue()

    # Move the model before building the optimisers, as the PyTorch docs recommend.
    model_gan.to(device)

    optimizer_vae = optim.Adam(model_gan.model_vae.parameters(), lr=config['learning_rate'])
    optimizer_discriminator = optim.Adam(model_gan.model_discriminator.parameters(), lr=config['learning_rate'])

    for epoch in range(1, config['num_epochs'] + 1):
        # float() keeps the value JSON-serialisable even if train() returns a 0-dim tensor.
        config['loss'] = float(train(model_gan, epoch, optimizer_vae, optimizer_discriminator, train_loader))

    # Persist the experiment configuration (plus summary and final loss).
    with open(filename, 'w') as file:
        json.dump(config, file, indent=4)

    # torch.save does not create missing directories, so create the model folder first.
    model_dir = f"{config['output_path']}/{config['name']}"
    os.makedirs(model_dir, exist_ok=True)
    # Saves the whole pickled module (not a state_dict); the test cell reloads it with torch.load.
    torch.save(model_gan.model_vae, f"{model_dir}/{config['name']}.pth")

# Test

In [None]:
import torch

import matplotlib.pyplot as plt

# Load the VAE model saved by the training cell.
model_path = f'{config["output_path"]}/{config["name"]}/{config["name"]}.pth'
# The checkpoint is a whole pickled nn.Module, not a state_dict. PyTorch >= 2.6
# defaults torch.load to weights_only=True, which refuses such files, so opt out
# explicitly. Only do this for checkpoints you created yourself: unpickling can
# execute arbitrary code. map_location puts the weights on the current device.
model = torch.load(model_path, map_location=device, weights_only=False)

# Set the model to evaluation mode
model.eval()

# Get one batch of test data and move it to the device.
data = next(iter(test_loader)).to(device)

# Pass the data through the VAE model.
with torch.no_grad():
    outputs = model(data)
# The training code unpacks the VAE forward pass as (recon, mu, logvar). Taking the
# first three keeps this working even if the model returns extra values.
recon_data, mu, logvar = outputs[:3]

# Number of images to show (change the 3 to configure); capped at the batch size.
num_images = min(3, data.shape[0])

# One row per image: original, latent mean, reconstruction.
# squeeze=False keeps `axes` 2-D even when num_images == 1.
fig, axes = plt.subplots(nrows=num_images, ncols=3, figsize=(12, 4 * num_images), squeeze=False)

for i in range(num_images):
    # Original image: CHW -> HWC for matplotlib (assumes 3-channel input).
    axes[i, 0].imshow(data[i].permute(1, 2, 0).cpu())
    axes[i, 0].set_title('Original')
    axes[i, 0].axis('off')

    # Latent mean. A multi-dimensional latent is collapsed to a 2-D map (mean over
    # the leading/channel dims). Flattening it to a 1 x N row would render as a
    # 1-pixel-high strip under imshow's default equal aspect.
    latent = mu[i].cpu()
    if latent.dim() >= 2:
        latent_map = latent.reshape(-1, *latent.shape[-2:]).mean(dim=0)
        axes[i, 1].imshow(latent_map, cmap='hot')
    else:
        axes[i, 1].imshow(latent.view(1, -1), cmap='hot', aspect='auto')
    axes[i, 1].set_title('Latent Space')
    axes[i, 1].axis('off')

    # Reconstruction, clamped to [0, 1] (matplotlib's valid float RGB range).
    axes[i, 2].imshow(recon_data[i].permute(1, 2, 0).cpu().clamp(0, 1))
    axes[i, 2].set_title('Reconstruction')
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()