In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as T
import os

from DatasetCH import UpscaleDataset
import Network

In [2]:
# Make dirs
mdir="/s2s_nobackup/mpyrina/Downscaling_output/Model_dif/Test_2"
rdir="/s2s_nobackup/mpyrina/Downscaling_output/Results_dif/Test_2"
os.makedirs(mdir, exist_ok=True)
os.makedirs(rdir, exist_ok=True)

# Define the tensorboard writer
writer = SummaryWriter(mdir) # was runs_unet

In [3]:
import sys
sys.path.append('/home/mpyrina/Notebooks/Diffusion_Downscaling/src_norm')
from DatasetCH import *
from TrainDiffusion import *
#from TrainUnet import *

In [4]:
import Evaluation

### TRAIN DIFFUSION

In [5]:
batch_size = 16
learning_rate = 1e-5
num_epochs = 1
accum = 4

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# a tensor of shape [B, C, H, W] mean that c=8, image resol=(H, W) 

network = Network.EDMPrecond(
        img_resolution=(256, 128),
        in_channels=2,
        out_channels=1,
        label_dim=1
    ).to(device)

# define the datasets
ifs_dir = '/s2s/mpyrina/ECMWF_MCH/Europe_eval/s2s_hind_2022/all/'
obs_dir = '/net/cfc/s2s_nobackup/mpyrina/TABSD_ifs_like/'
mask_dir = '/net/cfc/s2s_nobackup/mpyrina/TABSD_ifs_like/TabsD_mask_static.nc'

dataset_train = UpscaleDataset(coarse_data_dir = ifs_dir, highres_data_dir = obs_dir,
year_start=2002, year_end=2012, month=815,  
constant_variables=None, constant_variables_filename=None, mask_path=mask_dir)

dataset_test = UpscaleDataset(coarse_data_dir = ifs_dir, highres_data_dir = obs_dir,
year_start=2012, year_end=2015, month=815,  
constant_variables=None, constant_variables_filename=None, mask_path=mask_dir)

dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=True, num_workers=4)

Test - new upscale
Loaded coarse data shape: (460, 11, 16, 32)
Loaded high-resolution data shape: (460, 128, 256)
Final coarse shape: torch.Size([5060, 1, 16, 32])
Final fine shape: torch.Size([5060, 1, 128, 256])
Input shape (should be [N, 1, H, W]): torch.Size([5060, 1, 128, 256])
Loading static mask...
Test - new upscale
Loaded coarse data shape: (138, 11, 16, 32)
Loaded high-resolution data shape: (138, 128, 256)
Final coarse shape: torch.Size([1518, 1, 16, 32])
Final fine shape: torch.Size([1518, 1, 128, 256])
Input shape (should be [N, 1, H, W]): torch.Size([1518, 1, 128, 256])
Loading static mask...


In [6]:
# Train
scaler = torch.cuda.amp.GradScaler()
optimiser = torch.optim.AdamW(network.parameters(), lr=learning_rate)

loss_fn = EDMLoss()
losses = []


for step in range(num_epochs):
    # model_save
    model_save_path = f"{mdir}/dif_model_epoch_{step}.pt"
    # fig_save
    fig_save_path = f"{rdir}/dif_model_{step}.png"
    # best modes
    mbest = f"{mdir}/best_dif_model_epoch_{step}.pt"

    epoch_loss = training_step(network, loss_fn, optimiser,
                                   dataloader_train, scaler, step,
                                   accum, writer, device=device)
    losses.append(epoch_loss)
    
    # Save the model weights
    torch.save(network.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")
    
    if losses[-1] == min(losses):
        torch.save(network.state_dict(), mbest)
        
    # Plot and save
    (fig, ax), (base_error, pred_error), predicted_numpy_array = sample_model_dif(network, dataloader_test, device=device)
    plt.show()
    fig.savefig(fig_save_path, dpi=300)
    plt.close(fig)
    writer.add_scalar("Error/base", base_error, step)
    writer.add_scalar("Error/pred", pred_error, step)


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():


Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   1%|          | 2/317 [01:10<3:37:53, 41.50s/it, Loss: 1.1231]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   1%|          | 3/317 [02:18<4:40:28, 53.59s/it, Loss: 1.0688]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   1%|▏         | 4/317 [03:28<5:14:01, 60.20s/it, Loss: 1.0551]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   2%|▏         | 5/317 [04:36<5:27:01, 62.89s/it, Loss: 1.0469]

Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 128, 256])
Input shape to conv0: torch.Size([16, 128, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 64, 128])
Input shape to conv0: torch.Size([16, 256, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 32, 64])
Input shape to conv0: torch.Size([16, 512, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 1024, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 1024, 16, 32]), skips[-1].shape: torch.Size([16, 1024, 16, 32])
Input shape to conv0: torch.Size([16, 2048, 16, 32])
x.shape: torch.Size([16, 102

Train :: Epoch: 0:   2%|▏         | 5/317 [05:27<5:40:14, 65.43s/it, Loss: 1.0469]


KeyboardInterrupt: 

In [None]:
(fig, ax), (base_error, pred_error), predicted_numpy_array = sample_model_dif(network, dataloader_test, device=device)
plt.show()

In [12]:
model_save_path = f"./Model_dif/dif_model_epoch_{step}.pt"
torch.save(network.state_dict(), model_save_path)
(f"Model saved to {model_save_path}")

'Model saved to ./Model_dif/dif_model_epoch_0.pt'

In [17]:
plt.show()

### unet only

In [None]:
# define the datasets
ifs_dir = '/s2s/mpyrina/ECMWF_MCH/Europe_eval/s2s_hind_2022/all/'
obs_dir = '/net/cfc/s2s_nobackup/mpyrina/TABSD_ifs_like/'

# Run training for small number of epochs 
num_epochs = 1
## Select hyperparameters of training
batch_size = 8
learning_rate = 1e-5
accum = 8

dataset_train = UpscaleDataset(coarse_data_dir = ifs_dir, highres_data_dir = obs_dir,
year_start=2005, year_end=2008, month=815,  
constant_variables=None, constant_variables_filename=None)

dataset_test = UpscaleDataset(coarse_data_dir = ifs_dir, highres_data_dir = obs_dir,
year_start=2009, year_end=2010, month=815,  
constant_variables=None, constant_variables_filename=None)

dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=True, num_workers=4)

# Define device
device =  'cuda' if torch.cuda.is_available() else 'cpu'

# define the ml model : 1, 1, : 1 input var, one output
unet_model = UNet((256, 128), 1, 1, label_dim=0, use_diffuse=True)
unet_model.to(device)


In [None]:
#
scaler = torch.cuda.amp.GradScaler()

# define the optimiser
optimiser = torch.optim.AdamW(unet_model.parameters(), lr=learning_rate)

# Define the tensorboard writer
writer = SummaryWriter("./runs_unet")

loss_fn = torch.nn.MSELoss()

# train the model
losses = []
for step in range(num_epochs):
    epoch_loss = train_step(
        unet_model, loss_fn, dataloader_train, optimiser,
        scaler, step, accum, writer, device=device)
    losses.append(epoch_loss)

    # Save the model weights
    model_save_path = f"./Model/dif_model_epoch_{step}.pt"
    torch.save(unet_model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    (fig, ax), (base_error, pred_error), predicted_numpy_array = sample_model(
        unet_model, dataloader_test, device=device)
    plt.show()
    fig.savefig(f"./results_unet/{step}.png", dpi=300)
    plt.close(fig)


    writer.add_scalar("Error/base", base_error, step)
    writer.add_scalar("Error/pred", pred_error, step)

    # save the model
    if losses[-1] == min(losses):
        torch.save(unet_model.state_dict(), f"./Model/Models_dif/best_unet_model_epoch_{step}.pt")
