#### spatio-temporal downscaling cGAN training script following
https://doi.org/10.48550/arXiv.2411.16098 \
https://agupubs.onlinelibrary.wiley.com/doi/10.1029/2023EA002906 \
the models, training and validation routines deviate slightly from \
the papers, since this notebook aims to make it easy to adapt the code to a specific downscaling problem.

In [None]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
from models import Generator, Discriminator
from types import SimpleNamespace
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [None]:
# Training hyper-parameters; `config` exposes the same values as attributes.
config_dict = dict(
    batch_size=2,
    num_epochs=100,
    save_step=1000,  # save the models every `save_step` training steps
    eval_step=50,  # run the validation every `eval_step` training steps
    lam=1,  # scaling factor of the L1 loss
    ensemble_size=2,  # number of ensemble members generated during training
    Discriminator_filter=16,  # 128 in the paper, must be divisible by 4
    Generator_filter=16,  # 128 in the paper, can be adjusted freely
)
config = SimpleNamespace(**config_dict)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

Load some example data to check that the model runs

In [None]:
import xarray as xr
import numpy as np
from einops import rearrange

In [None]:
# Read both example files completely into memory. `xr.load_dataset` closes the
# underlying file after loading, whereas `xr.open_dataset(...).load()` keeps
# the file handle open for the lifetime of the dataset object.
ds_test_y = xr.load_dataset('data/y_test.nc')
ds_test_x = xr.load_dataset('data/x_test.nc')

In [None]:
# Adjust the example data to match the model input and target.
# Input: `cp` + `lsp` (total precipitation) with 7 grid cells trimmed from
# every side of the two trailing axes, then two leading singleton axes added.
total_precip = ds_test_x.cp.values + ds_test_x.lsp.values
x_test = rearrange(total_precip[:, 7:-7, 7:-7], 't h w -> 1 1 t h w')
# Target: `rainfall_amount` with the same two leading singleton axes.
rainfall = ds_test_y.rainfall_amount.values
y_test = rearrange(rainfall, 't h w -> 1 1 t h w')

In [None]:
# Convert both arrays to float32 tensors that live on the selected device.
x_test = torch.tensor(x_test, dtype=torch.float32, device=device)
y_test = torch.tensor(y_test, dtype=torch.float32, device=device)

In [None]:
# Stop early if either tensor contains NaN entries.
assert not x_test.isnan().any(), "x_test contains NaN values"
assert not y_test.isnan().any(), "y_test contains NaN values"

In [None]:
# Display the shapes of the input and target tensors as the cell output.
x_test.size(), y_test.size()

In [None]:
# Turn the continuous time axis (dim 2) into batches: cut it into fixed-length
# windows and stack the windows along the batch axis (dim 0).
# Replace this with a dataset that already provides samples of the right shape.
x_chunks = x_test.split(8, dim=2)  # input windows of 8 time steps
x_batched = torch.cat(x_chunks)

y_chunks = y_test.split(8 * 6, dim=2)  # target windows of 8*6 = 48 time steps
y_batched = torch.cat(y_chunks)

In [None]:
# This model takes samples with t=8 and w,h=14 and downscales them to t=48 and w,h=168.
# The samples and the model architecture can be adjusted to fit different data.
# The generator upsampling and the discriminator downsampling can be adjusted in models.py.

In [None]:
# Apply normalization to the data if necessary. spateGAN uses normalized data in the discriminator --> implement this in the train function.

In [None]:
# Display the shapes of the batched input and target tensors as the cell output.
x_batched.size(), y_batched.size()

In [None]:
# Pair every input sample with its target and serve them in shuffled
# mini-batches. The DataLoader can also be configured for multi-process loading.
dm_train = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_batched, y_batched),
    batch_size=config.batch_size,
    shuffle=True,
)

In [None]:

def generator_optimizer(model):
    """Build the AdamW optimizer used for the generator.

    Args:
        model: Module whose ``parameters()`` are handed to the optimizer.

    Returns:
        A ``torch.optim.AdamW`` instance with lr=1e-4, betas=(0.0, 0.999)
        and weight_decay=1e-4.
    """
    hyper_params = {"lr": 1e-4, "betas": (0.0, 0.999), "weight_decay": 0.0001}
    return torch.optim.AdamW(model.parameters(), **hyper_params)

def discriminator_optimizer(model):
    """Build the AdamW optimizer used for the discriminator.

    Args:
        model: Module whose ``parameters()`` are handed to the optimizer.

    Returns:
        A ``torch.optim.AdamW`` instance with lr=2e-4, betas=(0.5, 0.999)
        and weight_decay=1e-4.
    """
    hyper_params = {"lr": 2e-4, "betas": (0.5, 0.999), "weight_decay": 0.0001}
    return torch.optim.AdamW(model.parameters(), **hyper_params)

In [None]:
def train_step(config, input_image, target, discriminator, generator, gen_opt, disc_opt, scaler, criterion):
    """Run one adversarial training iteration: generator first, then discriminator.

    Args:
        config: Object providing ``ensemble_size`` (number of generator
            passes per step) and ``lam`` (weight of the L1 term).
        input_image: Batch passed to the generator and, together with a
            sample, to the discriminator.
        target: Reference batch; compared with the ensemble mean in the L1
            term and shown to the discriminator as the real sample.
        discriminator: Called as ``discriminator(input_image, sample)``; its
            output is fed to ``criterion``.
        generator: Called as ``generator(input_image)``.
        gen_opt: Optimizer stepping the generator parameters.
        disc_opt: Optimizer stepping the discriminator parameters.
        scaler: Gradient scaler shared by both optimizers.
        criterion: Adversarial loss called as ``criterion(output, labels)``
            with all-ones / all-zeros labels.

    Returns:
        dict with the detached loss tensors of this step: ``gen_loss``
        (total generator loss), ``l1_loss``, ``gan_loss`` (adversarial part
        of the generator loss) and ``disc_loss``. The values are left as
        tensors so that no host/device synchronisation is forced here.
    """
    generator.train()
    discriminator.train()

    # Mixed precision is only used when CUDA is present. Disabling autocast
    # explicitly avoids relying on PyTorch's warn-and-disable fallback on
    # CPU-only machines.
    use_amp = torch.cuda.is_available()

    ##################################
    ########Generator:############
    ##################################

    gen_opt.zero_grad(set_to_none=True)

    with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
        # Generate several ensemble members and stack them along dim 1.
        gen_ensemble = torch.cat(
            [generator(input_image) for _ in range(config.ensemble_size)], dim=1
        )

        # Only the first member is shown to the discriminator.
        pred = gen_ensemble[:, 0:1]

        # Ensemble mean over the stacked members.
        ensemble_mean = torch.sum(gen_ensemble, dim=1, keepdim=True) / config.ensemble_size

        # The generator is rewarded when the discriminator labels its sample as real.
        disc_fake_output = discriminator(input_image, pred)
        gen_gan_loss = criterion(disc_fake_output, torch.ones_like(disc_fake_output))

        # Can also be changed to a CRPS loss using the individual ensemble members.
        l1loss = nn.L1Loss()(ensemble_mean, target)

        loss = l1loss * config.lam + gen_gan_loss

    scaler.scale(loss).backward()

    # Optional gradient norm clipping (unscale the gradients first when enabling it):
    # scaler.unscale_(gen_opt)
    # nn.utils.clip_grad_norm_(generator.parameters(), max_norm=2.0, norm_type=2)

    # No scaler.update() here: with one scaler shared by two optimizers the
    # scale factor must be updated only once per iteration, after both steps.
    scaler.step(gen_opt)

    ##################################
    ########Discriminator:############
    ##################################

    # NOTE: spateGAN feeds log-transformed, min-max normalised data to the
    # discriminator; spateGAN-ERA5 applies no normalisation. To enable it,
    # compute min/max from the training data, store them in the config and
    # transform input_image, target and pred at this point, e.g.:
    #     field = torch.log(field + 1e-6)
    #     field = (field - config.min) / (config.max - config.min)

    # This also drops the discriminator gradients accumulated by the
    # generator's backward pass above.
    disc_opt.zero_grad(set_to_none=True)

    # Cut the graph so the discriminator loss does not reach the generator.
    pred = pred.detach()

    with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
        # Real pair: should be labelled with ones.
        disc_real_output = discriminator(input_image, target)
        disc_real = criterion(disc_real_output, torch.ones_like(disc_real_output))

        # Generated pair: should be labelled with zeros.
        disc_fake_output = discriminator(input_image, pred)
        disc_fake = criterion(disc_fake_output, torch.zeros_like(disc_fake_output))

        disc_loss = disc_fake + disc_real

    scaler.scale(disc_loss).backward()

    # Optional gradient norm clipping (unscale the gradients first when enabling it):
    # scaler.unscale_(disc_opt)
    # nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=2.0, norm_type=2)
    scaler.step(disc_opt)

    # Single scale update for the whole iteration, after both optimizer steps.
    scaler.update()

    return {
        'gen_loss': loss.detach(),
        'l1_loss': l1loss.detach(),
        'gan_loss': gen_gan_loss.detach(),
        'disc_loss': disc_loss.detach(),
    }
    

In [None]:
# This is a really simple validation function that plots some predictions to check whether the model is learning.
# Inputs and targets should come from a separate dataset, not the training data.
# Scores that I tracked but that are not included here: discriminator loss, generator loss, L1 loss, GAN loss (mainly to check model stability),
# FSS, RAPSD, CRPS, RMSE, etc. FSS and RAPSD are suitable for selecting the best-performing model state.

def validation(config, generator, discriminator, input_image, target, step, plot_folder):
    """Plot input, target and generator output side by side.

    The generator is switched to eval mode and run without gradient
    tracking, one sample at a time. Each panel shows the slice at index
    ``[0, 0, 0, :, :]`` of the respective tensor.

    Args:
        config: Unused; kept so existing callers keep working.
        generator: Called as ``generator(batch)``; left in eval mode on return.
        discriminator: Unused; kept so existing callers keep working.
        input_image: Tensor passed to the generator, sliced along dim 0.
        target: Reference tensor shown in the middle panel.
        step: Only used in the file name when the figure is saved.
        plot_folder: When non-empty, the figure is also written to
            ``<plot_folder>/validation_step_<step>.png`` (the folder is
            created if needed). An empty value disables saving.
    """
    ensemble_size = 1
    batch_split = 1

    generator.eval()

    with torch.no_grad():
        members = [
            torch.cat(
                [
                    generator(input_image[start:start + batch_split])
                    for start in range(0, input_image.size(0), batch_split)
                ],
                dim=0,
            )
            for _ in range(ensemble_size)
        ]

    # Ensemble members are stacked along dim 1.
    ensemble_output = torch.cat(members, dim=1)

    panels = (
        ('input', input_image.cpu().detach().numpy()),
        ('target', target.cpu().detach().numpy()),
        ('output', ensemble_output.cpu().detach().numpy()),
    )

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    for axis, (title, field) in zip(ax, panels):
        axis.imshow(field[0, 0, 0, :, :], cmap='turbo')
        axis.set_title(title)

    if plot_folder:
        import os

        os.makedirs(plot_folder, exist_ok=True)
        # Save before show(): some backends close the figure when showing it.
        fig.savefig(os.path.join(plot_folder, 'validation_step_{}.png'.format(step)))

    plt.show()
    # Release the figure explicitly; on backends where show() does not close
    # it, figures would otherwise pile up over a long training run.
    plt.close(fig)
   
  

In [None]:
def _save_checkpoint(model, optimizer, training_step, path):
    """Write model weights, optimizer state and the step counter to ``path``.

    A model wrapped in ``nn.DataParallel`` is unwrapped first, so the stored
    ``state_dict`` is that of the wrapped module itself.
    """
    module = model.module if isinstance(model, torch.nn.DataParallel) else model
    checkpoint = {
        'state_dict': module.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'training_step': training_step,
    }
    torch.save(checkpoint, path)


def fit(config, dm_train):
    """Train generator and discriminator on the batches of ``dm_train``.

    Builds both networks and their optimizers, then loops over
    ``config.num_epochs`` epochs. Every ``config.eval_step`` global steps the
    current training batch is plotted via ``validation``; every
    ``config.save_step`` global steps (except step 0) both networks are
    checkpointed to ``model_save/``.

    Args:
        config: Object providing ``Generator_filter``, ``Discriminator_filter``,
            ``num_epochs``, ``eval_step`` and ``save_step`` (plus whatever
            ``train_step`` reads).
        dm_train: Sized iterable yielding ``(input_image, target)`` tensor pairs.
    """
    import os

    generator = Generator(filter_size=config.Generator_filter).to(device)
    discriminator = Discriminator(filter_size=config.Discriminator_filter).to(device)

    # Wrap in DataParallel when several GPUs are visible.
    if torch.cuda.device_count() > 1:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)

    gen_opt = generator_optimizer(generator)
    disc_opt = discriminator_optimizer(discriminator)

    scaler = torch.amp.GradScaler()
    criterion = torch.nn.BCEWithLogitsLoss()

    checkpoint_dir = 'model_save'
    steps_per_epoch = len(dm_train)

    ############################################################
    ################MODEL TRAINING##############################
    ############################################################

    for epoch in tqdm(range(config.num_epochs)):

        for step, (input_image, target) in enumerate(dm_train):

            # Ensure the batch lives on the same device as the models; this is
            # a no-op when the data loader already yields device tensors.
            input_image = input_image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            train_step(config, input_image, target, discriminator, generator, gen_opt, disc_opt, scaler, criterion)

            full_step = epoch * steps_per_epoch + step

            if full_step % config.eval_step == 0:
                # Change to a validation dataset, e.g.:
                # input_val, target_val = next(iter(dm_val))
                # The global step is passed so it stays unique across epochs.
                validation(config, generator, discriminator, input_image, target, full_step, '')

            ######## save model ########

            if full_step != 0 and full_step % config.save_step == 0:
                # Create the target folder on demand, so a missing directory
                # cannot abort the run at the first checkpoint.
                os.makedirs(checkpoint_dir, exist_ok=True)

                _save_checkpoint(
                    discriminator,
                    disc_opt,
                    full_step,
                    'model_save/discriminator_{}.pt'.format(full_step),
                )
                print("saved model, discrimnator model:")

                _save_checkpoint(
                    generator,
                    gen_opt,
                    full_step,
                    'model_save/generator_{}.pt'.format(full_step),
                )
                print("saved model, generator model:")

In [None]:
# Start the training run with the configuration and data loader defined in this notebook.
fit(config=config, dm_train=dm_train)

In [None]:
# Currently missing but important: selecting the best model state based on validation scores.
# GANs do not converge the way a model optimized with a regression loss does. One way to deal with this is to let training run, track whether it is stable (e.g. via the discriminator loss),
# and select the best model from this long run based on different scores.
# Depending on the training time, the scores can be computed inside the validation function or after training using the saved model states.
# For precipitation, FSS and RAPSD are valuable scores for judging model performance, but of course many more scores can be used.