### Packages + Configuration

In [1]:
import gc
import json
import math
import runpy
from pathlib import Path

import numpy as np

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# Data selection: ten consecutive years of a single variable.
years = list(range(2001, 2011))
variable = "2m_temperature"

# Seed used for every random number generator in this notebook.
seed = 0

# Optimisation hyper-parameters.
epochs = 400
lr = 1e-4
batch_size = 32

# Use CUDA when it is available; memory pinning is switched on only in that case.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
pin_memory = device.type == "cuda"

# Directory holding the input/label archives, relative to the working directory.
dataset_dir = Path.cwd().joinpath("input_labels")

### Load data and check if normalized

In [2]:
def to_nchw(a: np.ndarray) -> np.ndarray:
    """Return ``a`` as a float32 array, rearranged towards an (N, C, H, W) layout.

    * 3-D input gets a singleton axis inserted at position 1.
    * 4-D input whose axis 1 has size 1 or 3 keeps its layout; otherwise, when
      the last axis has size 1 or 3, that axis is moved to position 1.
    * Every other input keeps its layout and is only cast to float32.
    """
    arr = np.asarray(a)
    channel_sizes = (1, 3)
    if arr.ndim == 3:
        arr = np.expand_dims(arr, axis=1)
    elif arr.ndim == 4 and arr.shape[1] not in channel_sizes and arr.shape[-1] in channel_sizes:
        # Channels-last input: bring the trailing axis forward to position 1.
        arr = np.moveaxis(arr, -1, 1)
    return arr.astype(np.float32)

# One archive per year, stored under the variable's own sub-directory.
files = [dataset_dir / f"{variable}" / f"{year}_{variable}.npz" for year in years]

# First pass: count the samples and read C, H, W from the first archive so that
# the full arrays can be allocated once.
N = 0
C = H = W = None
for npz_path in files:
    with np.load(npz_path) as archive:
        x_part = to_nchw(archive["X"])
        if C is None:
            C, H, W = x_part.shape[1:]
        N += len(x_part)
        del x_part
gc.collect()

# Inputs and labels share one shape and dtype.
X = np.empty((N, C, H, W), dtype=np.float32)
Y = np.empty_like(X)

# Second pass: copy each archive into its slice of the preallocated arrays.
offset = 0
for npz_path in files:
    with np.load(npz_path) as archive:
        x_part = to_nchw(archive["X"])
        y_part = to_nchw(archive["Y"])
    stop = offset + x_part.shape[0]
    X[offset:stop] = x_part
    Y[offset:stop] = y_part
    offset = stop
    del x_part, y_part
gc.collect()

def check_0_1(name: str, a: np.ndarray, tol: float = 1e-6):
    """Verify that the values of ``a`` lie in the interval [0, 1] within ``tol``.

    NaN entries are ignored when the minimum and maximum are computed
    (``np.nanmin`` / ``np.nanmax``).

    Args:
        name: Label of the array; included in the error message.
        a: Array (or array-like) to check.
        tol: Slack allowed below 0 and above 1.

    Returns:
        None. On success a confirmation line with the observed range is printed.

    Raises:
        ValueError: If the minimum is below ``-tol`` or the maximum is above
            ``1 + tol`` (this includes the case where the extrema are NaN).
            The message carries the observed range.
    """
    amin = float(np.nanmin(a))
    amax = float(np.nanmax(a))

    if amin >= -tol and amax <= 1.0 + tol:
        print(f"The data is properly normalized: min = {amin:.6f}, max = {amax:.6f}")
        return

    # Build the message as a string: print() returns None, so wrapping it in the
    # exception would lose the diagnostic.
    raise ValueError(
        f"{name}: The normalization step is missing or inconsistent: min = {amin:.6f}, max = {amax:.6f}"
    )

# Inputs and labels must both pass the [0, 1] range check before continuing.
for array_name, array in (("X", X), ("Y", Y)):
    check_0_1(array_name, array)

# Seed the global NumPy and PyTorch generators.
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# 80 % of the samples for training, 10 % for validation, the remainder for testing.
n = len(X)
train_n = int(0.8 * n)
val_n = int(0.1 * n)

# The shuffled sample order comes from a dedicated generator created from the same seed.
rng = np.random.default_rng(seed)
perm = rng.permutation(n)

# Cut the permutation at the two split boundaries.
train_idx, val_idx, test_idx = np.split(perm, [train_n, train_n + val_n])

The data is properly normalized: min = 0.010022, max = 0.988940
The data is properly normalized: min = 0.000000, max = 1.000000


### Prepare data

In [3]:
class FullDataset(Dataset):
    """Dataset exposing a subset of paired arrays ``X`` and ``Y``.

    ``indices`` selects which rows of the arrays belong to this dataset; item
    ``i`` is the pair of rows stored at position ``indices[i]``.
    """

    def __init__(self, X, Y, indices):
        # The arrays are referenced, not copied.
        self.X, self.Y = X, Y
        self.indices = np.asarray(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        # Translate the dataset position into a row of the underlying arrays.
        row = self.indices[i]
        sample_x = torch.from_numpy(self.X[row])
        sample_y = torch.from_numpy(self.Y[row])
        return sample_x, sample_y

# One dataset per split; all three reference the same X and Y arrays.
train_ds, val_ds, test_ds = (
    FullDataset(X, Y, split_idx) for split_idx in (train_idx, val_idx, test_idx)
)

# Settings shared by all loaders.
loader_kwargs = dict(num_workers=0, pin_memory=pin_memory)

# Training batches are shuffled; validation and test go through one sample at a time, in order.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, **loader_kwargs)

### Model + Evaluator

In [4]:
# SmCL (SoftMax Constraining Layer)
class SmCL(nn.Module):
    """Softmax constraining layer.

    Splits the image into non-overlapping ``scale`` x ``scale`` patches. Within
    each patch the scores ``z_hr`` are turned into softmax weights (which sum
    to 1), and every output pixel is ``weight * scale**2 * patch_mean(x_hr)``.
    The mean of the output over each patch therefore equals the mean of
    ``x_hr`` over that patch.

    Args:
        scale: Side length of the patches. Defaults to 2.
    """

    def __init__(self, scale: int = 2):
        super().__init__()
        scale = int(scale)
        if scale < 1:
            raise ValueError(f"scale must be a positive integer. Got {scale}")
        self.scale = scale

    def forward(self, z_hr: torch.Tensor, x_hr: torch.Tensor) -> torch.Tensor:
        """Redistribute the patch means of ``x_hr`` according to ``z_hr``.

        Args:
            z_hr: Scores of shape (B, C, H, W).
            x_hr: Tensor whose patch means are preserved; it must pool to
                (B, C, H // scale, W // scale).

        Returns:
            Tensor of shape (B, C, H, W).

        Raises:
            ValueError: If H or W is not divisible by ``scale``.
        """
        s = self.scale
        B, C, H, W = z_hr.shape
        if (H % s) != 0 or (W % s) != 0:
            raise ValueError(f"H,W must be divisible by {s}. Got {H}x{W}, scale = {s}")

        Hs, Ws = H // s, W // s

        # Patch means of the reference field.
        x_lr = F.avg_pool2d(x_hr, kernel_size=s, stride=s)

        # reshape (rather than view) also accepts non-contiguous inputs.
        z = z_hr.reshape(B, C, Hs, s, Ws, s)
        # Bring the two within-patch axes together and flatten them for the softmax.
        z_flat = z.permute(0, 1, 2, 4, 3, 5).contiguous().reshape(B, C, Hs, Ws, s * s)

        w_flat = torch.softmax(z_flat, dim=-1)
        # Undo the flattening and restore the (Hs, s, Ws, s) axis order.
        w = w_flat.reshape(B, C, Hs, Ws, s, s).permute(0, 1, 2, 4, 3, 5).contiguous()

        # Broadcast each patch mean over the pixels of its patch.
        x_lr_b = x_lr.reshape(B, C, Hs, 1, Ws, 1)
        y = w * (x_lr_b * (s * s))

        return y.reshape(B, C, H, W)

# simple SRCNN + SmCL tail
class SRCNN_SmCL(nn.Module):
    """Three convolutions whose single-channel output is passed through SmCL.

    The convolutions produce per-pixel scores; SmCL combines those scores with
    the network input to form the final output.
    """

    def __init__(self):
        super().__init__()
        # 9x9 -> 1x1 -> 5x5 kernels; each padding keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=5, padding=2)
        self.smcl = SmCL()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = torch.relu(self.conv1(x))
        features = torch.relu(self.conv2(features))
        # No activation on the last convolution: its raw output is handed to SmCL as scores.
        scores = self.conv3(features)
        return self.smcl(scores, x)

# metric for evaluating "on the run"
@torch.no_grad()
def eval_loss(loader, model, device):
    """Return the mean absolute error per element of ``model`` over ``loader``.

    The model is switched to evaluation mode. Absolute errors are summed over
    every batch and divided by the total number of target elements; an empty
    loader yields 0.0.
    """
    model.eval()
    abs_err_sum = 0.0
    elem_count = 0

    for inputs, targets in loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        batch_err = F.l1_loss(model(inputs), targets, reduction="sum")
        abs_err_sum += batch_err.item()
        elem_count += targets.numel()

    # max(..., 1) guards the division when the loader produced nothing.
    return abs_err_sum / max(elem_count, 1)

# Metric for final test set evaluation
@torch.no_grad()
def eval_mae_rmse(loader, model, device, baseline: bool = False):
    """Return ``(mae, rmse)`` over all elements of ``loader``.

    With ``baseline=True`` the input itself is scored as the prediction (the
    model is not called); otherwise ``model(x)`` is scored. The model is
    switched to evaluation mode in both cases. An empty loader yields
    ``(0.0, 0.0)``.
    """
    model.eval()
    abs_total, sq_total, count = 0.0, 0.0, 0

    for inputs, targets in loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if baseline:
            outputs = inputs
        else:
            outputs = model(inputs)

        err = outputs - targets
        abs_total += err.abs().sum().item()
        sq_total += (err * err).sum().item()
        count += targets.numel()

    # max(..., 1) guards the divisions when the loader produced nothing.
    denom = max(count, 1)
    return abs_total / denom, math.sqrt(sq_total / denom)

# define model
# Build the network first, then move its parameters to the selected device
# (Module.to works in place, so `model` keeps referring to the same object).
model = SRCNN_SmCL()
model.to(device)

### Training Loop

In [5]:
# Everything produced by this run is stored beneath runs/SRCNN_SmCL_<variable>
run_dir = Path("runs", f"SRCNN_SmCL_{variable}")
run_dir.mkdir(parents=True, exist_ok=True)

# report.json will store the important information about the run
report_path = run_dir.joinpath("report.json")

# Training configuration: known up front, independent of any training result
report = {
    "config": dict(
        model="SRCNN_SmCL",
        years=years,
        variable=variable,
        seed=seed,
        epochs=epochs,
        lr=lr,
        batch_size=batch_size,
        device=str(device),
    )
}

#=========================================================================================#

def train():
    """Train the module-level ``model`` and record its test metrics in ``report``.

    Steps:
      1. Score the identity baseline (input used as prediction) on the test set.
      2. Train for ``epochs`` epochs with Adam on the L1 loss; after every epoch
         compute the validation loss and save the weights whenever it improves.
      3. Save the train/validation loss curve as a PNG in ``run_dir``.
      4. Reload the best weights saved by this call and score them on the test set.

    Side effects: updates ``model`` in place, writes ``best_model.pt`` and the
    loss-curve image into ``run_dir``, prints progress, and fills the
    ``baseline_error``, ``trained_error`` and ``error_reduction`` entries of
    ``report``.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    # Store intermediate losses for plotting later
    train_hist = []
    val_hist = []
    epoch_hist = []

    # Here we will store the best checkpoint
    best_val = float("inf")
    ckpt_path_best = run_dir / "best_model.pt"
    # Becomes True once this call has written ckpt_path_best, so that a file
    # left behind by an earlier run is never mistaken for this run's result.
    saved_checkpoint = False

    #=========================================================================================#

    # Calculate the baseline metrics and add to report
    baseline_mae, baseline_rmse = eval_mae_rmse(test_loader, model, device, baseline=True)

    print(f"[BASELINE ERROR | IDENTITY MODEL] test_MAE={baseline_mae:.6f} | test_RMSE={baseline_rmse:.6f}")

    report["baseline_error"] = {
        "mae": baseline_mae,
        "rmse": baseline_rmse,
    }

    #=========================================================================================#

    # Start iterating across epochs
    for epoch in range(1, epochs + 1):
        model.train()
        total_train_loss = 0.0
        n = 0

        for x, y in train_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            pred = model(x)
            loss = F.l1_loss(pred, y)

            loss.backward()
            opt.step()

            # Summed absolute error for the per-element epoch average. The
            # prediction is detached because this value is only logged, so no
            # autograd graph needs to be built for it.
            total_train_loss += F.l1_loss(pred.detach(), y, reduction="sum").item()
            n += y.numel()

        train_loss = total_train_loss / max(n, 1)
        val_loss = eval_loss(val_loader, model, device)

        print(f"epoch {epoch:03d} | train_loss = {train_loss:.6f} | val_loss = {val_loss:.6f}")

        epoch_hist.append(epoch)
        train_hist.append(train_loss)
        val_hist.append(val_loss)

        # A NaN validation loss never compares as smaller, so it never triggers a save.
        if val_loss < best_val:
            best_val = val_loss
            torch.save(
                {"state_dict": model.state_dict(), "best_val": best_val, "epoch": epoch},
                ckpt_path_best,
            )
            saved_checkpoint = True

    #=========================================================================================#

    # Plot the train and val losses saved during training
    fig, ax = plt.subplots(figsize=(7, 4), constrained_layout=True)
    ax.plot(epoch_hist, train_hist, label="train_loss")
    ax.plot(epoch_hist, val_hist, label="val_loss")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.set_title("Loss curve")
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.savefig(run_dir / f"SRCNN_SmCL_{variable}_loss_curve.png", dpi=150)
    plt.close(fig)

    #=========================================================================================#

    # Load the best checkpoint of THIS run and evaluate it on the test set
    if saved_checkpoint and ckpt_path_best.exists():
        ckpt = torch.load(ckpt_path_best, map_location=device)
        model.load_state_dict(ckpt["state_dict"])
    else:
        print("[WARNING] No checkpoint was saved during this run; evaluating the model's current weights")

    # Calculate the trained metrics and add to report
    test_mae, test_rmse = eval_mae_rmse(test_loader, model, device, baseline=False)

    print(f"[TRAINED ERROR | SRCNN_SmCL] test_MAE={test_mae:.6f} | test_RMSE={test_rmse:.6f}")

    report["trained_error"] = {
        "mae": test_mae,
        "rmse": test_rmse,
    }

    #=========================================================================================#

    # Calculate how much the model performance improved (in percent of the baseline error)
    report["error_reduction"] = {
        "mae": 100.0 * (baseline_mae - test_mae) / baseline_mae if baseline_mae != 0 else float("nan"),
        "rmse": 100.0 * (baseline_rmse - test_rmse) / baseline_rmse if baseline_rmse != 0 else float("nan"),
    }

# Start training only when this code is executed as the main program
# (the call is skipped when the code is imported as a module).
if __name__ == "__main__":
    train()

[BASELINE ERROR | IDENTITY MODEL] test_MAE=0.005972 | test_RMSE=0.010435
epoch 001 | train_loss = 0.005304 | val_loss = 0.004764
epoch 002 | train_loss = 0.004556 | val_loss = 0.004375
epoch 003 | train_loss = 0.004283 | val_loss = 0.004205
epoch 004 | train_loss = 0.004165 | val_loss = 0.004127
epoch 005 | train_loss = 0.004108 | val_loss = 0.004086
epoch 006 | train_loss = 0.004076 | val_loss = 0.004061
epoch 007 | train_loss = 0.004055 | val_loss = 0.004043
epoch 008 | train_loss = 0.004038 | val_loss = 0.004029
epoch 009 | train_loss = 0.004026 | val_loss = 0.004017
epoch 010 | train_loss = 0.004015 | val_loss = 0.004007
epoch 011 | train_loss = 0.004006 | val_loss = 0.003999
epoch 012 | train_loss = 0.003998 | val_loss = 0.003991
epoch 013 | train_loss = 0.003990 | val_loss = 0.003985
epoch 014 | train_loss = 0.003983 | val_loss = 0.003977
epoch 015 | train_loss = 0.003977 | val_loss = 0.003973
epoch 016 | train_loss = 0.003972 | val_loss = 0.003967
epoch 017 | train_loss = 0.0039

### Check for physical violations

In [6]:
# Load ranges for denormalising: run ranges.py and take its RANGES object
RANGES_PATH = Path.cwd().joinpath("ranges.py")
ns = runpy.run_path(str(RANGES_PATH))
RANGES = ns["RANGES"]

# The first two entries of the variable's range are used as lower and upper bound
variable_range = RANGES[variable]
vmin, vmax = float(variable_range[0]), float(variable_range[1])

# Which variables have to be non-negative in physical units?
NONNEG_VARS = set(
    [
        "2m_temperature",
        "mean_surface_downward_long_wave_radiation_flux",
        "mean_surface_downward_short_wave_radiation_flux",
        "specific_humidity",
        "surface_pressure",
        "total_precipitation",
    ]
)
need_nonneg = variable in NONNEG_VARS

# Helper for calculating reduction percentages
def calculate_reduction(b, t):
    """Return by how many percent ``t`` is lower than the baseline ``b``.

    Returns None when either value is None or not finite, and when the
    baseline is zero while ``t`` is not. A zero baseline with a zero ``t``
    counts as 0.0 percent.
    """
    if any(value is None for value in (b, t)):
        return None

    base, trained = float(b), float(t)
    if not (math.isfinite(base) and math.isfinite(trained)):
        return None

    if base != 0.0:
        return 100.0 * (base - trained) / base

    # Zero baseline: a relative change is only defined when nothing changed.
    return 0.0 if trained == 0.0 else None

#=========================================================================================#

# Shared by the baseline and the trained evaluation below, so that both are
# measured by exactly the same code.
def _measure_violations(predict, loader):
    """Measure physical violations of ``predict`` over ``loader``.

    Each input batch ``x`` is moved to ``device`` and passed through
    ``predict``. Prediction and input are mapped to physical units with
    ``value * (vmax - vmin) + vmin``. Two quantities are accumulated:

    * the number of predicted pixels below zero (only when ``need_nonneg``);
    * the squared difference between the 2x2 average-pooled prediction and the
      2x2 average-pooled input.

    Args:
        predict: Callable mapping an input batch to a prediction of the same shape.
        loader: Iterable of ``(x, y)`` batches; ``y`` is ignored.

    Returns:
        ``(mean_consistency_rmse, negatives_per_million)``; the second entry is
        None when ``need_nonneg`` is False.
    """
    negative_pixels = 0
    pixel_count = 0
    sq_err_sum = 0.0
    sq_err_count = 0

    with torch.no_grad():
        for x, _ in loader:
            x = x.to(device, non_blocking=True)

            pred = predict(x)

            pred_phys = pred * (vmax - vmin) + vmin
            x_phys    = x    * (vmax - vmin) + vmin

            if need_nonneg:
                negative_pixels += (pred_phys < 0).sum().item()
                pixel_count += pred_phys.numel()

            pred_lr = F.avg_pool2d(pred_phys, kernel_size=2, stride=2)
            x_lr    = F.avg_pool2d(x_phys,    kernel_size=2, stride=2)

            diff = pred_lr - x_lr
            sq_err_sum += (diff * diff).sum().item()
            sq_err_count += diff.numel()

    # max(..., 1) guards the divisions when the loader produced nothing.
    mean_consistency_rmse = float(math.sqrt(sq_err_sum / max(sq_err_count, 1)))
    negatives_per_million = 1000000*(float(negative_pixels/max(pixel_count, 1))) if need_nonneg else None
    return mean_consistency_rmse, negatives_per_million

# Calculate the baseline violations (input used as prediction) and add to report
baseline_mean_consistency_rmse, baseline_negatives_per_million = _measure_violations(lambda x: x, test_loader)

bpm_str = f"{baseline_negatives_per_million:.6f}" if baseline_negatives_per_million is not None else "NA"
print(f"[BASELINE VIOLATION | IDENTITY MODEL] mean_consistency_rmse = {baseline_mean_consistency_rmse:.6f} | negatives_per_million = {bpm_str}")

report["baseline_violation"] = {
    "mean_consistency_rmse": baseline_mean_consistency_rmse,
    "negatives_per_million": baseline_negatives_per_million,
}

#=========================================================================================#

# Load the best checkpoint and evaluate violations of the test set
ckpt_path_best = run_dir / "best_model.pt"
ckpt = torch.load(ckpt_path_best, map_location=device)
model.load_state_dict(ckpt["state_dict"])
model.to(device)
model.eval()

# Calculate the trained violations and add to report
trained_mean_consistency_rmse, trained_negatives_per_million = _measure_violations(model, test_loader)

tpm_str = f"{trained_negatives_per_million:.6f}" if trained_negatives_per_million is not None else "NA"
print(f"[TRAINED VIOLATION | SRCNN_SmCL] mean_consistency_rmse = {trained_mean_consistency_rmse:.6f} | negatives_per_million = {tpm_str}")

report["trained_violation"] = {
    "mean_consistency_rmse": trained_mean_consistency_rmse,
    "negatives_per_million": trained_negatives_per_million,
}

#=========================================================================================#

# Calculate how much the model performance improved: (baseline, trained) pair per metric
violation_pairs = {
    "mean_consistency_rmse": (baseline_mean_consistency_rmse, trained_mean_consistency_rmse),
    "negatives_per_million": (baseline_negatives_per_million, trained_negatives_per_million),
}
report["violation_reduction"] = {
    metric: calculate_reduction(baseline_value, trained_value)
    for metric, (baseline_value, trained_value) in violation_pairs.items()
}

# Save the whole report in the working directory
def scrub_json(x):
    """Return ``x`` with every NaN or infinite float replaced by None.

    Dicts and lists are rebuilt recursively; any other value (including
    tuples) is returned unchanged.
    """
    if isinstance(x, dict):
        return {key: scrub_json(value) for key, value in x.items()}
    if isinstance(x, list):
        return list(map(scrub_json, x))
    if isinstance(x, float) and not math.isfinite(x):
        return None
    return x

# Write the report as indented JSON; non-finite floats are replaced by null first,
# and allow_nan=False makes the dump fail if any such value is still present.
with report_path.open(mode="w", encoding="utf-8") as report_file:
    cleaned_report = scrub_json(report)
    json.dump(cleaned_report, report_file, indent=2, sort_keys=False, allow_nan=False)

[BASELINE VIOLATION | IDENTITY MODEL] mean_consistency_rmse = 0.000000 | negatives_per_million = 0.000000
[TRAINED VIOLATION | SRCNN_SmCL] mean_consistency_rmse = 0.000014 | negatives_per_million = 0.000000
