## V2 of the training notebook for GANs

In [None]:
# imports

import os
import pathlib
import torch
import torch.nn as nn
import numpy as np
from torchinfo import summary

from ml_benchmark_spategan.config import config
from ml_benchmark_spategan.dataloader import dataloader
from ml_benchmark_spategan.model.spagan2d import Generator
from ml_benchmark_spategan.visualization.plot_train import plot_losses

from IPython.display import clear_output

### Load configuration and initialize run directory


In [None]:
# The project root is taken to be the parent of the current working directory
# (i.e. the notebook is expected to be launched one level below the root).
project_base = pathlib.Path.cwd().parent

# Create the run configuration object for this project root
cf = config.set_up_run(project_base)

### Build dataloaders

In [None]:
# Build the train/test loaders with the dummy builder. The real builder is
# kept here for reference; note that it also returns an updated config:
#   dataloader_train, test_dataloader, cf = dataloader.build_dataloaders(cf)
dataloader_train, test_dataloader = dataloader.build_dummy_dataloaders()

# Write the (possibly updated) configuration back to the run directory
cf.save()

# Report the x/y shapes exposed by the training dataset
print("Training data shapes:")
x_shape, y_shape = dataloader_train.dataset._get_shapes()
print(f"  x: {x_shape}\n  y: {y_shape}")

In [None]:
# Build the generator from the model section of the run configuration
model = Generator(cf.model)

# Summarise the network for a dummy input of this size
# (presumably (batch, channels, height, width) — confirm against the dataloader shapes)
summary_input_size = (1, 15, 16, 16)
summary(model, input_size=summary_input_size)

In [None]:
# Objective: mean squared error between the model output and the target
loss_function = nn.MSELoss()

# AdamW hyper-parameters (first-moment decay disabled via beta1 = 0)
learning_rate = 2e-5
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=(0.0, 0.999),
    weight_decay=1e-4,
)

# Number of training epochs comes from the run configuration
num_epochs = cf.training.epochs

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Move the model's parameters and buffers onto the selected device
model = model.to(device)

In [None]:
# Supervised training loop: for every epoch, fit the model on the training
# loader, evaluate it on the test loader, record both mean losses, and
# periodically refresh the loss plot. The final weights are saved at the end.

# Mean loss per epoch, kept for plotting
loss_train = []
loss_test = []

# Mixed precision is only enabled when the selected device is CUDA. With
# enabled=False both autocast and GradScaler become no-ops, so the CPU
# fallback runs in full precision without "CUDA is not available" warnings.
use_amp = torch.device(device).type == 'cuda'
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

for epoch in range(num_epochs):
    # ---- training pass ----
    # Make sure train-time behaviour (dropout, batch-norm updates) is active,
    # since the evaluation pass below switches the model to eval mode.
    model.train()
    epoch_loss = []

    for batch_x_data, batch_y_data in dataloader_train:
        # Move data to the device
        batch_x_data = batch_x_data.to(device)
        batch_y_data = batch_y_data.to(device)

        # Gradients accumulate by default, so clear them before each step
        optimizer.zero_grad()

        # Forward pass and loss under (optional) autocast
        with torch.amp.autocast('cuda', enabled=use_amp):
            outputs = model(batch_x_data)
            loss_batch = loss_function(outputs, batch_y_data)

        epoch_loss.append(loss_batch.item())

        # Backward pass on the scaled loss, then an unscaled optimizer step;
        # update() adjusts the scale factor for the next iteration.
        scaler.scale(loss_batch).backward()
        scaler.step(optimizer)
        scaler.update()

    # Mean training loss of this epoch
    epoch_loss = np.mean(epoch_loss)
    loss_train.append(epoch_loss)

    # ---- evaluation pass ----
    # Switch to eval mode so dropout is disabled and batch-norm uses its
    # running statistics instead of being updated by the test data.
    # NOTE(review): if the generator intentionally keeps stochastic layers
    # active at inference time, remove this call.
    model.eval()
    test_loss = []
    with torch.no_grad():
        for batch_x, batch_y_data in test_dataloader:
            # Move data to the device
            batch_x = batch_x.to(device)
            batch_y_data = batch_y_data.to(device)
            with torch.amp.autocast('cuda', enabled=use_amp):
                outputs = model(batch_x)
                test_loss.append(loss_function(outputs, batch_y_data).item())

    # Mean test loss of this epoch
    test_epoch_loss = np.mean(test_loss)
    loss_test.append(test_epoch_loss)

    # Report on the first epoch and then every 10th epoch
    if (epoch + 1) % 10 == 0 or epoch == 0:
        clear_output(wait=True)
        print(f'Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.6f} - Test Loss: {test_epoch_loss:.6f}')
        plot_losses(loss_train, loss_test, cf)

# At the end of training, save the model weights into the run directory
model_name = 'model.pt'
torch.save(model.state_dict(), f'{cf.logging.run_dir}/{model_name}')