In [None]:
### Install requirements
from os.path import isfile

repository   = "https://github.com/lmingari/olot-course.git"
requirements = "requirements-section3-2.txt"

if not isfile(requirements):
    !git clone {repository}
    %cd olot-course
    !pip install -r {requirements}

# 3.2 Autoencoder (I): Training a CNN-based Autoencoder
***

## General configuration

In [None]:
### Configuration ###
# Hyper-parameters and output settings read by the cells below.
config = dict(
    BATCH_SIZE=16,                 # samples per mini-batch in the DataLoaders
    LATENT_DIM=2,                  # dimension of the autoencoder latent space
    LEARNING_RATE=1e-3,            # learning rate passed to the optimizer
    NUM_EPOCHS=150,                # number of passes of the training loop
    RANDOM_SEED=43,                # seed set before the train/validation split
    FNAME_MODEL='autoencoder.pt',  # file the trained model is saved to
)

## Main functions

#### Importing modules

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchsummary import summary

import xarray as xr
import matplotlib.pyplot as plt

#### Training function

In [None]:
def train_epoch(model, loader, criterion, optimizer):
    """
    Performs one complete training epoch over the dataset.

    The model is trained to reproduce its own input: every batch is used both
    as the model input and as the target passed to the loss function.

    Args:
        model: The neural network model to be trained

        loader: Iterable (typically a DataLoader) that provides batches of
            training data. Each iteration yields a batch of inputs

        criterion: Loss function used to compute the training loss (e.g., nn.MSELoss).
            Takes model predictions and targets as input

        optimizer (torch.optim.Optimizer): Optimization algorithm used to update
            model parameters (e.g., Adam, SGD)

    Returns:
        float: Average loss per sample over the samples actually processed in
        this epoch, or NaN if the loader yielded no samples.
    """

    # Set training mode
    model.train()

    total_loss = 0.0
    n_samples = 0

    # Mini-batch loop
    for batch in loader:
        # Model prediction
        prediction = model(batch)
        # Compute loss (the input itself is the reconstruction target)
        loss = criterion(prediction, batch)

        # Update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update metrics: weight the batch loss by the batch size so that a
        # smaller last batch does not bias the average.
        # NOTE(review): this weighting assumes `criterion` returns a per-batch mean.
        current_batch_size = batch.size(0)
        total_loss += loss.item() * current_batch_size
        n_samples += current_batch_size

    # Normalize by the samples seen rather than len(loader.dataset): both agree
    # when the loader covers the whole dataset, but only the former is right
    # when samples are skipped (e.g. drop_last=True or a subsampling sampler).
    if n_samples == 0:
        return float("nan")
    return total_loss / n_samples

#### Evaluation function

In [None]:
def evaluate_epoch(model, loader, criterion):
    """
    Evaluates the model over one complete pass of the given data.

    No parameters are updated and gradient tracking is disabled. As in
    training, each batch is both the model input and the loss target.

    Args:
        model: The neural network model to be evaluated

        loader: Iterable (typically a DataLoader) that provides batches of
            evaluation data. Each iteration yields a batch of inputs

        criterion: Loss function used to compute the loss (e.g., nn.MSELoss).
            Takes model predictions and targets as input

    Returns:
        float: Average loss per sample over the samples actually processed,
        or NaN if the loader yielded no samples.
    """
    # Set inference mode
    model.eval()

    total_loss = 0.0
    n_samples = 0

    with torch.no_grad():
        for batch in loader:
            # Model prediction
            prediction = model(batch)
            # Compute loss (the input itself is the reconstruction target)
            loss = criterion(prediction, batch)

            # Update metrics: weight the batch loss by the batch size so that
            # a smaller last batch does not bias the average.
            # NOTE(review): this weighting assumes `criterion` returns a per-batch mean.
            current_batch_size = batch.size(0)
            total_loss += loss.item() * current_batch_size
            n_samples += current_batch_size

    # Normalize by the samples seen rather than len(loader.dataset): both agree
    # when the loader covers the whole dataset, but only the former is right
    # when samples are skipped (e.g. drop_last=True or a subsampling sampler).
    if n_samples == 0:
        return float("nan")
    return total_loss / n_samples

## 1. Loading raw data and normalization

In [None]:
from os.path import isfile

## Get a FALL3D ensemble run output
## The file is downloaded into ./data only when it is not already present,
## so re-running this cell does not fetch it again
fname = "data/tephra_col_mass.ens.nc"
if not isfile(fname):
    !wget -P ./data https://saco.csic.es/s/wFpKYG5bHfTwKbi/download/tephra_col_mass.ens.nc

In [None]:
## Open the downloaded file with xarray and select the "tephra_col_mass" variable
ds = xr.open_dataset(fname)
da = ds["tephra_col_mass"]

In [None]:
## Percentile-based scaling
## Print the maximum and the 98th percentile of the data; these values
## guide the choice of the normalization range used further below
print(f"Maximum value: {da.max().item()}")
print(f"98th percentile: {da.quantile(0.98).item()}")

In [None]:
## Histogram of the values in the data array, using 30 bins
da.plot.hist(bins=30)

## 2. Create a custom Dataset and split it into training and validation sets

In [None]:
from helper import EnsembleDataset, MinMaxScale

## Min-max scaling with the range [min_value, max_value] = [0, 20]
## (intended so that about 98% of the scaled data falls between 0 and 1;
## compare with the 98th percentile printed above)
min_value = 0
max_value = 20
transform = MinMaxScale(min_value, max_value)

## Create a Dataset object for the full dataset (training + validation)
dataset = EnsembleDataset(da, transform)

## Random split into training and validation datasets
n_total = len(dataset)
n_train = int(0.8 * n_total)   # 80% train
n_val   = n_total - n_train    # 20% val (the remainder, so the sizes add up to n_total)

## Seed the global torch RNG right before splitting so the split is reproducible
torch.manual_seed(config['RANDOM_SEED'])
train_dataset, val_dataset = random_split(dataset, [n_train, n_val])

## 3. Create a DataLoader

In [None]:
## Both loaders use the configured batch size; only the training data is shuffled,
## the validation data keeps a fixed order
train_loader = DataLoader(train_dataset, batch_size=config['BATCH_SIZE'], shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=config['BATCH_SIZE'], shuffle=False)

In [None]:
## Load data by mini-batches with dimensions:
## (nbatch,nchannels,nlat,nlon)
## Only the first mini-batch is inspected: the loop stops after one iteration
for batch in train_loader:
    print("Batch dimensions: (nbatch,nchannels,nlat,nlon)")
    print(batch.shape)
    break

## 4. Define a model

In [None]:
from helper import Autoencoder

## Build the autoencoder with the configured latent dimension
model = Autoencoder(config['LATENT_DIM'])

## Print a layer-by-layer summary for an input of size (nchannels,nlat,nlon).
## The model is never moved to a GPU in this notebook, so the summary is built
## on the CPU explicitly: torchsummary defaults to device="cuda" and would feed
## CUDA tensors to the CPU model whenever a GPU is available.
summary(model, (1,101,121), device="cpu")

## 5. Loss function
Create a criterion that measures the mean squared error (MSE) between the reconstruction and the input:

In [None]:
## reduction="mean" averages the squared error over all elements of a batch
criterion = nn.MSELoss(reduction="mean")

## 6. Optimizer

In [None]:
## Adam optimizer over all model parameters, using the configured learning rate
optimizer = optim.Adam(model.parameters(), lr=config['LEARNING_RATE'])

## Training loop

In [None]:
# Per-epoch history of the averaged training and validation losses
train_losses = []
val_losses   = []

for epoch in range(config['NUM_EPOCHS']):
    # One optimization pass over the training set, then one validation pass
    train_loss = train_epoch(model, train_loader, criterion, optimizer)
    val_loss   = evaluate_epoch(model, val_loader, criterion)
    # Record the losses of this epoch
    train_losses += [train_loss]
    val_losses   += [val_loss]
    # Report only on every 10th epoch and on the final one
    if epoch % 10 != 0 and epoch != config['NUM_EPOCHS'] - 1:
        continue
    print(f"Epoch {epoch+1:02d} -> Train loss {train_loss:.4f} | Validation loss: {val_loss:.4f}")
print("Done!")

In [None]:
## Learning curves: averaged loss per epoch for the training and validation sets
## (the x axis is the zero-based epoch index)
plt.plot(train_losses, label = 'Training loss')
plt.plot(val_losses, label = 'Validation loss')
plt.xlabel('Epoch')
plt.ylabel('Averaged Loss')
plt.legend()

## State reconstruction using the validation dataset

In [None]:
## Switch the model to evaluation mode before running inference
model.eval()

yp_list = [] # List of reconstructions
xb_list = [] # List of original inputs

## Iterate over the validation dataset
with torch.no_grad():
    for xb in val_loader:
        prediction = model(xb)
        ## squeeze(1) removes dimension 1 (the channel axis) when it has size 1
        yp = prediction.squeeze(1)
        xb = xb.squeeze(1)
        ## NOTE(review): transform.invert presumably undoes the min-max scaling
        ## applied by the dataset -- confirm in helper.MinMaxScale
        yp_list.append(transform.invert(yp))
        xb_list.append(transform.invert(xb))
    ## Concatenate the per-batch results along the sample dimension
    reconstructions = torch.cat(yp_list, dim=0)
    inputs = torch.cat(xb_list, dim=0)

In [None]:
## Plotting reconstructions for validation dataset
## Left column: original inputs; right column: autoencoder reconstructions
plot_conf = {
    'cmap': 'RdYlBu_r',
    'vmin': 0,
    'vmax': 30,
}

## Show at most 12 validation samples, one per row
n=min(12,n_val)
## squeeze=False keeps axs two-dimensional even when n == 1, so axs[i,0] is
## always valid; the figure height scales with the number of rows
## (38 for the full 12 rows)
fig, axs = plt.subplots(nrows = n, ncols = 2, figsize=(6,38*n/12), squeeze=False)

for i in range(n):
    cs1=axs[i,0].pcolormesh(da.lon,da.lat,inputs[i], **plot_conf)
    cs2=axs[i,1].pcolormesh(da.lon,da.lat,reconstructions[i], **plot_conf)

## Hide the ticks on every panel
for ax in axs.flat:
    ax.set_xticks([])
    ax.set_yticks([])

axs[0,0].set_title('Original model output')
axs[0,1].set_title('Reconstructed model output')

## A single colorbar is shared by all panels (they use the same color limits)
cbar = fig.colorbar(cs2,
             ax=axs,
             orientation='horizontal',
             fraction=0.05,
             pad=0.02,
             aspect=30
            )
cbar.set_label('Column mass [g/m2]')

## Latent Space Visualization

In [None]:
fig, ax = plt.subplots()

ax.set(title = 'Latent space (z1,z2)', ylabel = 'z2', xlabel = 'z1')

## Each entry: (loader, marker color, legend label)
splits = (
    (train_loader, 'red',  'Training'),
    (val_loader,   'blue', 'Validation'),
)

## Make sure the model is in evaluation mode even if this cell is run on its own
model.eval()

with torch.no_grad():
    for loader, color, label in splits:
        ## Encode every batch of the split into the latent space
        codes = [model.encode(batch) for batch in loader]
        if not codes:
            continue  # nothing to plot for an empty loader
        z = torch.cat(codes, dim=0)
        ## One labelled scatter per split so the legend explains the colors
        ax.scatter(z[:,0],z[:,1], color=color, label=label)

ax.legend();

## Save trained model

In [None]:
## Bundle the trained weights with the metadata saved alongside them
## (latent dimension and normalization range), then write a single file
checkpoint = {
    'model_state_dict': model.state_dict(),  # Trained Model parameters
    'LATENT_DIM': config['LATENT_DIM'],      # Dimension of the latent space (=2)
    'MINVAL': min_value,                     # Min value used for normalization (=0)
    'MAXVAL': max_value,                     # Max value used for normalization (=20)
}
torch.save(checkpoint, config['FNAME_MODEL'])