In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as T
import os

from DatasetCH import UpscaleDataset
from models import *
import Network

# Output locations for this run: `mdir` receives model checkpoints and
# `rdir` receives the per-epoch figures written by the training loop.
mdir = "./Model_unet/Test_1"
rdir = "./Results_unet/Test_1"
for out_dir in (mdir, rdir):
    # exist_ok=True keeps re-running this cell harmless.
    os.makedirs(out_dir, exist_ok=True)

# TensorBoard event writer shared by the training cells below.
writer = SummaryWriter("./Runs_unet/Test_1") # was runs_unet

In [2]:
# Append the ClimateDiffuse `src` directory to the module search path and
# wildcard-import DatasetCH, TrainDiffusion and TrainUnet (presumably the
# modules living in that directory — verify).
# NOTE(review): the first cell already runs `from DatasetCH import UpscaleDataset`
# before this path is appended. On a fresh kernel that only works if DatasetCH
# is importable some other way — confirm, or move this append ahead of it.
# NOTE(review): the wildcard imports hide where names such as `train_step`
# come from; explicit imports would make the dependencies visible.
import sys
# Hard-coded, machine-specific absolute path; re-running the cell appends it again.
sys.path.append('/home/mpyrina/Notebooks/ANEMOI/ClimateDiffuse/src/')
from DatasetCH import *
from TrainDiffusion import *
from TrainUnet import *

In [3]:
from Evaluation import *

### U-Net only (no diffusion)

In [None]:
# Input data locations: coarse IFS hindcasts and the high-resolution
# observations they are paired with.
ifs_dir = '/s2s/mpyrina/ECMWF_MCH/Europe_eval/s2s_hind_2022/all/'
obs_dir = '/net/cfc/s2s_nobackup/mpyrina/TABSD_ifs_like/'

# Training hyperparameters (short run: only a small number of epochs).
num_epochs = 12
batch_size = 8
learning_rate = 1e-5
accum = 8  # handed to train_step; presumably gradient-accumulation steps — confirm

# Keyword arguments common to the train and test datasets.
# NOTE(review): the meaning of month=815 is defined inside UpscaleDataset and
# is not visible from this notebook — confirm it selects the intended months.
_shared_dataset_args = dict(
    coarse_data_dir=ifs_dir,
    highres_data_dir=obs_dir,
    month=815,
    constant_variables=None,
    constant_variables_filename=None,
)

# NOTE(review): the training range ends in 2012 and the test range starts in
# 2012. Whether year_end is inclusive is not visible here — confirm that the
# two splits do not overlap.
dataset_train = UpscaleDataset(year_start=2002, year_end=2012, **_shared_dataset_args)
dataset_test = UpscaleDataset(year_start=2012, year_end=2015, **_shared_dataset_args)

# Both loaders use the same batching, shuffling and worker settings.
_loader_args = dict(batch_size=batch_size, shuffle=True, num_workers=4)
dataloader_train = torch.utils.data.DataLoader(dataset_train, **_loader_args)
dataloader_test = torch.utils.data.DataLoader(dataset_test, **_loader_args)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# U-Net with one input variable and one output variable, diffusion disabled.
unet_model = UNet((256, 128), 1, 1, label_dim=0, use_diffuse=False)
unet_model.to(device)


In [None]:
# Train the U-Net for `num_epochs` epochs. After every epoch this cell:
#   * saves the current weights,
#   * draws a sample figure from the test loader and stores it as a PNG,
#   * logs the baseline and prediction errors to TensorBoard,
#   * keeps an extra "best" checkpoint whenever the epoch's training loss is
#     the lowest seen so far.

# Gradient scaler handed to train_step. It is only enabled when running on
# CUDA; on CPU the scaler would disable itself anyway, after emitting a warning.
scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))

# define the optimiser
optimiser = torch.optim.AdamW(unet_model.parameters(), lr=learning_rate)

loss_fn = torch.nn.MSELoss()

# train the model
losses = []
for step in range(num_epochs):
    # Per-epoch output paths: regular checkpoint, sample figure, best checkpoint.
    model_save_path = f"{mdir}/unet_model_epoch_{step}.pt"
    fig_save_path = f"{rdir}/{step}.png"
    mbest = f"{mdir}/best_unet_model_epoch_{step}.pt"

    epoch_loss = train_step(
        unet_model, loss_fn, dataloader_train, optimiser,
        scaler, step, accum, writer, device=device)
    losses.append(epoch_loss)

    # Save the model weights
    torch.save(unet_model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    (fig, ax), (base_error, pred_error), predicted_numpy_array = sample_model(
        unet_model, dataloader_test, device=device)
    # Write the figure to disk before showing it, so the saved file does not
    # depend on what the active backend does to the figure in plt.show().
    fig.savefig(fig_save_path, dpi=300)
    plt.show()
    # Release the figure so figures do not accumulate over the epochs.
    plt.close(fig)

    writer.add_scalar("Error/base", base_error, step)
    writer.add_scalar("Error/pred", pred_error, step)

    # Keep a separate checkpoint when this epoch has the lowest training loss.
    # NOTE(review): selection uses the training loss, not the test-set error.
    if losses[-1] == min(losses):
        torch.save(unet_model.state_dict(), mbest)

# Push any buffered TensorBoard events to disk now that training is done.
writer.flush()


In [None]:
# Final scoring of the trained U-Net on the test loader.
# NOTE(review): the labels below say "MAE" while the loss_fn passed in is
# MSELoss. What evaluate_model actually computes is not visible here —
# confirm that the labels match the metric.
test_metrics = evaluate_model(unet_model, loss_fn, dataloader_test, device=device)
avg_pred_error, avg_base_error = test_metrics

print("\nFinal Evaluation on Test Set:")
for label, value in (
    ("Average Prediction Error (MAE)", avg_pred_error),
    ("Average Baseline Error (Coarse vs. Ground Truth MAE)", avg_base_error),
):
    print(f"{label}: {value:.4f}")