<h1><strong>Weather Prediction Experiment</strong></h1>
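<p>In this experiment, we condition the trained score-based model on a noisy observation of the first time step only (both variables, T2m and U10m). We then sample the entire 12-step trajectory from it.</p>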

In [None]:
import torch
import os
import h5py
import preprocess
import numpy as np
import matplotlib.pyplot as plt
from utils import SequenceDataset, plot_sample
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch import nn, optim
from pathlib import Path
# Import the necessary classes
from score import ScoreUNet
from score import VPSDE
from score import GaussianScore
import importlib
import score
importlib.reload(score)

<h2>Load Model</h2>

In [None]:
# Checkpoint to evaluate. The run name and training epoch are separate constants,
# so another run/epoch can be selected without hand-editing the long path.
RUN_NAME = "attention_config_spatial_T2m_U10m_2000_2014"
CHECKPOINT_EPOCH = 310
checkpoint_path = f"slurm/checkpoints/{RUN_NAME}/{RUN_NAME}_{CHECKPOINT_EPOCH}.pth"

In [None]:
import importlib
import score
importlib.reload(score)

# Seed torch's global RNG (CPU and CUDA). This makes the shuffled test batches,
# the observation noise and the diffusion sampling reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

PATH_DATA = Path('./data/processed')

# Load the spatial mask. unsqueeze(0) adds a leading axis -> (1, H, W), so it
# broadcasts against (..., C, H, W) tensors.
with h5py.File(PATH_DATA / "mask.h5", "r") as f:
    mask = torch.tensor(f["dataset"][:], dtype=torch.float32, device=device).unsqueeze(0)
if torch.isnan(mask).any():
    raise ValueError("Mask contains NaN values!")
mask_cpu = mask.detach().clone().cpu()

# Sequence length (time steps) per sample passed to SequenceDataset. With flatten=True
# the channel axis presumably holds window * n_variables entries (24 = 12 x 2 in the run below).
window = 12
# Load dataset to get dimensions
testset = SequenceDataset(PATH_DATA / "test.h5", window=window, flatten=True)
channels, y_dim, x_dim = testset[0][0].shape
print(f"Channels : {channels}")
# The mask multiplies every (..., H, W) tensor below, so its grid must match the data grid.
if tuple(mask.shape[-2:]) != (y_dim, x_dim):
    raise ValueError(
        f"Mask spatial shape {tuple(mask.shape[-2:])} does not match data shape {(y_dim, x_dim)}"
    )

TRAIN_CONFIG = {
    "epochs": 10000,
    "batch_size": 5,
    "learning_rate": 2e-4,
    "weight_decay": 1e-4,
    "scheduler": "cosine",
    "embedding": 64,
    "activation": "SiLU",
    "eta": 5e-3,
}
MODEL_CONFIG = {
    'hidden_channels': [64, 128, 128, 256],
    'attention_levels': [2],
    'hidden_blocks': [2, 3, 3, 3],
    'spatial': 2,
    'channels': channels,
    'context': 4,
    'embedding': 64,
}

In [None]:
batch_size = TRAIN_CONFIG['batch_size']
testloader = DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=1, persistent_workers=True)

# Build the score network and wrap it in the VP-SDE, using the training-time settings.
score_unet = ScoreUNet(**MODEL_CONFIG).to(device)
vpsde = VPSDE(score_unet, shape=(channels, y_dim, x_dim), eta=TRAIN_CONFIG["eta"]).to(device)

# Restore the trained weights.
checkpoint = torch.load(checkpoint_path, map_location=device)
vpsde.load_state_dict(checkpoint['model_state_dict'])

# From here on the model is used for inference only. Switch any dropout/normalisation
# layers to eval behaviour. Gradients are deliberately NOT disabled (no torch.no_grad):
# DPS-style guidance presumably differentiates through the model
# (TODO confirm against score.DPSGaussianScore).
vpsde.eval()
print(f"Model restored from {checkpoint_path}, trained until epoch {checkpoint['epoch']}")

<h2>Define the Observation Operator A(x): First Time Step Only</h2>

In [None]:
def first_timestep(batch, mask, num_variables=2):
    """Keep only the first time step of each sample and copy it to every time step.

    Channels are treated as time-major, ``[t0_v0, t0_v1, ..., t1_v0, t1_v1, ...]``,
    as implied by the ``(window, num_variables)`` split below.

    Args:
        batch: Tensor of shape ``(B, C, H, W)`` or ``(S, B, C, H, W)`` with
            ``C = window * num_variables``. A 4-D input gets a leading ``S = 1`` axis.
        mask: Tensor broadcastable to ``(S, B, C, H, W)`` (e.g. ``(1, H, W)``).
            It is multiplied into the result.
        num_variables: Number of variables stacked per time step. The default of 2
            corresponds to T2m and U10m.

    Returns:
        Tensor of shape ``(S, B, C, H, W)``. It is always 5-D, even for 4-D input.

    Raises:
        ValueError: If ``C`` is not a multiple of ``num_variables``.
    """
    if batch.ndim == 4:
        batch = batch.unsqueeze(0)
    S, B, C, H, W = batch.shape
    if C % num_variables != 0:
        raise ValueError(
            f"Channel count {C} is not divisible by num_variables={num_variables}"
        )
    window = C // num_variables
    # Split channels into (time, variable). The slice [:1] keeps the time axis
    # so the first step can be broadcast over all `window` steps.
    first_step = batch.reshape(S, B, window, num_variables, H, W)[:, :, :1]
    expanded = first_step.expand(-1, -1, window, -1, -1, -1)
    return expanded.reshape(S, B, C, H, W) * mask

def A(x):
    """Observation operator passed to the DPS score: first time step only, masked.

    Binds the notebook-level ``mask``, so the operator takes a single argument
    where it is passed on as ``A=A``.
    """
    observed = first_timestep(x, mask)
    return observed

In [None]:

# Visual sanity check of A: show each test sample next to its first-timestep-only version.
batch, _ = next(iter(testloader))  # the second item (dict with 'context') is not needed here
batch = batch.to(device)
first_step_data = A(batch)  # A adds a leading singleton axis to 4-D input

# Interleave along dim 0: [orig_0, first_0, orig_1, first_1, ...].
original_vs_first = torch.stack((batch.cpu(), first_step_data.squeeze(0).cpu()), dim=1).flatten(0, 1)

path_unnorm = PATH_DATA / "train.h5"
var_names = ['T2m', 'U10m']
# Reuse the dataset window instead of repeating the literal 12.
info = {'var_index': var_names, 'channels': len(var_names), 'window': window}
fig = plot_sample(original_vs_first, info, mask_cpu, samples=4, step=3, unnormalize=True, path_unnorm=path_unnorm)
plt.suptitle("Original Data vs First Timestep Only", fontsize=16)
plt.tight_layout()

<h2>Prediction Experiment</h2>

In [None]:
# Build a noisy, masked observation of the first time step for a fresh test batch.
x_star, c_star = next(iter(testloader))
x_star = x_star.to(device)
c_star = c_star['context'].to(device)

# y* = A(x*) plus Gaussian noise (std 1e-2), zeroed outside the mask.
y_star = torch.normal(A(x_star), 1e-2) * mask

# Defined here so this cell does not depend on the optional demo cell above.
path_unnorm = PATH_DATA / "train.h5"
var_names = ['T2m', 'U10m']
info = {'var_index': var_names, 'channels': len(var_names), 'window': window}
comparison = torch.stack((x_star.cpu(), y_star.squeeze(0).cpu()), dim=1).flatten(0, 1)
fig = plot_sample(comparison, info, mask_cpu, samples=2, step=2, unnormalize=True, path_unnorm=path_unnorm)
plt.suptitle("Ground Truth vs First Timestep Input", fontsize=16)
plt.tight_layout()

In [None]:
# Single-sample conditioning: keep only the first element of a fresh test batch.
# `score` is imported in the first cell. It is not reloaded here, because a reload
# after `vpsde` is built would mix class objects from different module versions.
x_star, c_star = next(iter(testloader))
x_star, c_star = x_star[:1], c_star['context'][:1]  # [:1] keeps the batch axis
print(x_star.shape, c_star.shape)
x_star = x_star.to(device)
c_star = c_star.to(device)

# Noisy observation of the first time step only (std 1e-2), zeroed outside the mask.
y_star = torch.normal(A(x_star), 1e-2) * mask
print(y_star.shape)

# Guided SDE wrapping the trained `vpsde` with a DPS Gaussian likelihood on y* = A(x).
# zeta presumably scales the guidance term (see score.DPSGaussianScore).
DPS_ZETA = 15.0
sde = VPSDE(score.DPSGaussianScore(y_star, mask, A=A, sde=vpsde, zeta=DPS_ZETA), shape=x_star.shape).to(device)



In [None]:

# Draw num_samples guided trajectories conditioned on y* and the context c*.
num_samples = 3
# detach() guards against sample() returning a tensor still attached to the guidance graph.
x_preds = sde.sample(mask, shape=(num_samples,), c=c_star, steps=512, corrections=8, tau=0.5).detach().cpu()

# Panels: ground truth, observed first step, then each prediction.
panels = [x_star.detach().cpu(), y_star.squeeze(0).detach().cpu()] + [x_preds[i] for i in range(num_samples)]
new_tensor = torch.stack(panels, dim=1).flatten(0, 1)

path_unnorm = PATH_DATA / "train.h5"
var_names = ['T2m', 'U10m']
info = {'var_index': var_names, 'channels': len(var_names), 'window': window}
# `samples` follows the panel count rather than a hard-coded 5, so changing num_samples
# still plots every prediction. This assumes plot_sample draws `samples` rows from dim 0.
fig = plot_sample(new_tensor, info, mask_cpu, samples=len(panels), step=3, unnormalize=True, path_unnorm=path_unnorm)
plt.suptitle("Ground Truth vs First Timestep Input vs Prediction", fontsize=16)
plt.tight_layout()

In [None]:
import metrics

# x_preds is (num_samples, 1, C, H, W). Drop the singleton batch axis -> (num_samples, C, H, W).
preds = x_preds[:num_samples].squeeze(1)
# Repeat the single ground-truth sample once per prediction. The count is taken from
# preds so it tracks num_samples instead of a hard-coded 3.
gt = x_star.detach().cpu().repeat(preds.shape[0], 1, 1, 1)
print(preds.shape, gt.shape, mask_cpu.shape)
assert preds.shape == gt.shape, "predictions and ground truth must have matching shapes"

rmse_per_var_time, overall_rmse = metrics.calculate_rmse(preds, gt, mask_cpu)
metrics_results = metrics.calculate_metrics(preds, gt, mask_cpu)
print(f"Overall RMSE: {overall_rmse:.3f}")
print(f"RMSE per variable and time step: {rmse_per_var_time}")
print(f"Metrics: {metrics_results}")
metrics.plot_metric_comparison(metrics_results)

torch.Size([3, 24, 64, 64]) torch.Size([3, 24, 64, 64]) torch.Size([1, 64, 64])
torch.Size([3, 12, 2, 64, 64])
torch.Size([3, 12, 2, 64, 64])
Overall RMSE: 2.354
RMSE per variable and time step: tensor([[0.4675, 0.8796, 0.6094, 1.0897, 1.9756, 2.6577, 2.0941, 1.7199, 1.7108,
         1.3281, 1.0398, 1.1285],
        [1.5686, 3.3839, 9.9891, 5.7674, 5.1942, 1.4577, 9.7733, 2.4632, 2.6650,
         3.3576, 5.9067, 6.3296]])
Metrics: {'rmse': {'per_var_time': tensor([[0.4675, 0.8796, 0.6094, 1.0897, 1.9756, 2.6577, 2.0941, 1.7199, 1.7108,
         1.3281, 1.0398, 1.1285],
        [1.5686, 3.3839, 9.9891, 5.7674, 5.1942, 1.4577, 9.7733, 2.4632, 2.6650,
         3.3576, 5.9067, 6.3296]]), 'overall': tensor(2.3537)}, 'mae': {'per_var_time': tensor([[ 0.6297,  1.2771,  0.8202,  1.7304,  3.3362,  4.5289,  3.5767,  2.9183,
          2.8928,  2.2204,  1.7019,  1.8055],
        [ 2.2078,  5.2255, 16.9514,  9.0744,  8.6560,  1.8788, 16.4512,  3.9273,
          4.0796,  5.5121, 10.0151, 10.5642]]),