## V2 of the training notebook for GANs

In [None]:
# imports

import os
import pathlib
import torch
import torch.nn as nn
import numpy as np
from torchinfo import summary

from ml_benchmark_spategan.config import config
from ml_benchmark_spategan.dataloader import dataloader
from ml_benchmark_spategan.model.spagan2d import Generator, Discriminator, train_gan_step
from ml_benchmark_spategan.visualization.plot_train import plot_adversarial_losses, plot_predictions

from IPython.display import clear_output

### Load configuration and initialize run directory


In [None]:
# The project root is taken to be the parent of the current working
# directory (i.e. the notebook is run from a direct sub-folder of the project).
project_base = pathlib.Path.cwd().parent
# Create the run configuration object for this project root; all later cells
# read their settings from `cf`.
cf = config.set_up_run(project_base)

### Build dataloaders

In [None]:
# Dummy dataloaders are used in this notebook. The alternative builder,
# kept here for reference, also returns an updated config:
#   dataloader_train, test_dataloader, cf = dataloader.build_dataloaders(cf)
dataloader_train, test_dataloader = dataloader.build_dummy_dataloaders()

# Write the (possibly updated) configuration back to the run directory.
cf.save()

# Report the shapes of the training inputs (x) and targets (y) as given by
# the dataset's own helper.
print("Training data shapes:")
x_shape, y_shape = dataloader_train.dataset._get_shapes()
print(f"  x: {x_shape}", f"  y: {y_shape}", sep="\n")

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the generator from the model section of the config and place it on
# the selected device (`.to` returns the module itself).
generator = Generator(cf.model)
generator = generator.to(device)

# Layer-by-layer summary for one dummy input of size (1, 15, 16, 16).
# NOTE(review): this size is hard-coded; presumably (batch, channels, H, W) of
# the low-resolution input — confirm it matches the dataloader's x shape.
print("Generator architecture:")
summary(generator, input_size=(1, 15, 16, 16))


In [None]:

# Build the discriminator and place it on the same device as the generator.
# NOTE(review): the discriminator receives the whole config (`cf`) while the
# generator is built from `cf.model` — confirm this asymmetry is intended.
discriminator = Discriminator(cf)
discriminator = discriminator.to(device)

print("\nDiscriminator architecture:")
# The discriminator takes two inputs, in this order:
#   (high_res_target, low_res_input)
# so torchinfo is given one input size per argument.
summary(
    discriminator,
    input_size=[
        (1, 1, 128, 128),
        (1, 15, 16, 16),
    ],
)

In [None]:
# Setup training: device, adversarial loss, optimizers and the AMP grad scaler.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Make sure both networks live on the training device (no-op if already there).
generator = generator.to(device)
discriminator = discriminator.to(device)

# Adversarial loss operating on raw logits.
criterion = nn.BCEWithLogitsLoss()

# Both optimizers share the same hyper-parameters; keeping them in one place
# prevents the generator and discriminator settings from drifting apart.
# beta1 = 0.0 disables the first-moment (momentum) running average.
optimizer_kwargs = dict(
    lr=cf.training.learning_rate,
    betas=(0.0, 0.999),
    weight_decay=0.0001,
)
gen_opt = torch.optim.AdamW(generator.parameters(), **optimizer_kwargs)
disc_opt = torch.optim.AdamW(discriminator.parameters(), **optimizer_kwargs)

# Gradient scaler for mixed-precision training. It is only enabled when we
# actually train on CUDA; on CPU the scaler becomes a transparent pass-through
# instead of emitting a "CUDA is not available" warning and self-disabling.
scaler = torch.amp.GradScaler('cuda', enabled=(device == 'cuda'))

In [None]:
# GAN training loop.
#
# For every epoch:
#   1. run one adversarial update per training batch via `train_gan_step`,
#   2. evaluate the generator on the test dataloader with an L1 loss,
#   3. periodically print/plot the loss curves, save a checkpoint and plot
#      predictions for a fixed visualization batch.
# After the last epoch the final model and optimizer states are saved.
loss_gen_train = []
loss_disc_train = []
loss_gen_test = []

print(f"Starting GAN training for {cf.training.epochs} epochs...")

# Get a fixed batch for visualization (first batch of the test dataloader)
val_iter = iter(test_dataloader)
x_vis, y_vis = next(val_iter)

# Loop invariants, computed once.
steps_per_epoch = len(dataloader_train)
l1_loss = nn.L1Loss()
# Autocast is only meaningful on CUDA; gating it avoids a warning on CPU.
use_amp = torch.cuda.is_available()
run_dir = pathlib.Path(cf.logging.run_dir)

for epoch in range(cf.training.epochs):
    # The validation phase below puts the generator into eval mode; switch it
    # back so that every epoch (not only the first) trains in train mode.
    generator.train()

    # Training phase
    epoch_gen_losses = []
    epoch_disc_losses = []

    for batch_idx, (x_batch, y_batch) in enumerate(dataloader_train):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        # Reshape the targets to (N, 1, 128, 128) for the discriminator
        y_batch_2d = y_batch.view(-1, 1, 128, 128)

        # One adversarial update; `step` is the global batch counter.
        gen_loss, disc_loss = train_gan_step(
            config=cf,
            input_image=x_batch,
            target=y_batch_2d,
            step=epoch * steps_per_epoch + batch_idx,
            discriminator=discriminator,
            generator=generator,
            gen_opt=gen_opt,
            disc_opt=disc_opt,
            scaler=scaler,
            criterion=criterion,
        )

        epoch_gen_losses.append(gen_loss)
        epoch_disc_losses.append(disc_loss)

    # Calculate average training losses
    # NOTE(review): assumes train_gan_step returns host-side scalars that
    # np.mean can consume — confirm against its implementation.
    train_gen_loss = np.mean(epoch_gen_losses)
    train_disc_loss = np.mean(epoch_disc_losses)
    loss_gen_train.append(train_gen_loss)
    loss_disc_train.append(train_disc_loss)

    # Validation phase
    generator.eval()
    test_losses = []

    with torch.no_grad():
        for x_batch, y_batch in test_dataloader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)

            with torch.amp.autocast('cuda', enabled=use_amp):
                y_pred = generator(x_batch)
                # Bring the target to the prediction's shape (the training
                # phase reshapes its targets as well) so the L1 loss compares
                # element-wise instead of failing or silently broadcasting.
                loss = l1_loss(y_pred, y_batch.reshape(y_pred.shape))

            test_losses.append(loss.item())

    test_loss = np.mean(test_losses)
    loss_gen_test.append(test_loss)

    # Print progress and plot (always after the first epoch, then at the
    # configured frequency)
    if (epoch + 1) % cf.logging.save_frequency == 0 or epoch == 0:
        clear_output(wait=True)
        print(f'Epoch {epoch+1}/{cf.training.epochs}')
        print(f'  Generator Loss:     {train_gen_loss:.6f}')
        print(f'  Discriminator Loss: {train_disc_loss:.6f}')
        print(f'  Test Loss (L1):     {test_loss:.6f}')

        # Plot losses
        plot_adversarial_losses(loss_gen_train, loss_disc_train, loss_gen_test, cf)

    # Save checkpoint
    if (epoch + 1) % cf.logging.checkpoint_frequency == 0:
        torch.save({
            'epoch': epoch + 1,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'gen_optimizer_state_dict': gen_opt.state_dict(),
            'disc_optimizer_state_dict': disc_opt.state_dict(),
            'train_gen_loss': train_gen_loss,
            'train_disc_loss': train_disc_loss,
            'test_loss': test_loss,
        }, run_dir / f'checkpoint_epoch_{epoch+1}.pt')
        print('  Checkpoint saved')

    if (epoch + 1) % cf.logging.map_frequency == 0:
        # Visualize predictions on the fixed batch (generator is in eval mode
        # here because the validation phase ran just before).
        plot_predictions(generator, x_vis.to(device), y_vis.to(device), cf, epoch + 1, num_samples=3)

# Save final models
torch.save({
    'epoch': cf.training.epochs,
    'generator_state_dict': generator.state_dict(),
    'discriminator_state_dict': discriminator.state_dict(),
    'gen_optimizer_state_dict': gen_opt.state_dict(),
    'disc_optimizer_state_dict': disc_opt.state_dict(),
}, run_dir / 'final_models.pt')

print(f'\nGAN training complete! Models saved to {cf.logging.run_dir}')