# 1. Project Overview

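This notebook trains a classifier that predicts which weather archetype occurs a fixed lead time ahead (7 days by default) from a single day's streamfunction (`stream`) and temperature (`tas`) fields. It builds on Earthformer, a space-time transformer for Earth-system forecasting, initialized from its EarthNet2021-pretrained checkpoint. The pipeline covers data preprocessing, model adaptation, transfer learning, and evaluation.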

# 2. Environment & Imports

In [None]:
# Standard library
import json
import os
from datetime import datetime
from datetime import timedelta

# Third-party
import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import xarray as xr
from torch.utils.data import DataLoader, TensorDataset

# Earthformer
from earthformer.cuboid_transformer.cuboid_transformer import CuboidTransformerModel
from earthformer.utils.utils import download

# 3. Data Loading & Preprocessing

## 3.1 Load raw input data & interpolate

In [None]:
# Input data directory. Paths are relative to the notebook's working directory
# (the original note says to run from the mmi393 directory); change DATA_DIR if
# the data lives elsewhere.
DATA_DIR = "../../data/deseason_smsub"

stream_path = f"{DATA_DIR}/lentis_stream250_JJA_2deg_101_deseason_spatialsub.nc"
tas_path = f"{DATA_DIR}/lentis_tas_JJA_2deg_101_deseason.nc"
pcha_path = f"{DATA_DIR}/pcha_results_8a.hdf5"

# Fail early with an actionable message instead of a low-level open() error.
for required_path in (stream_path, tas_path, pcha_path):
    if not os.path.exists(required_path):
        raise FileNotFoundError(
            f"{required_path} not found (cwd: {os.getcwd()}); "
            "run from the mmi393 directory or update DATA_DIR."
        )

# import both nc's
dataset_stream = xr.open_dataset(stream_path)
dataset_tas = xr.open_dataset(tas_path)

# get S_PCHA from archetypes file
with h5py.File(pcha_path, 'r') as f:
    S_PCHA = f['/S_PCHA'][:]

# Label each time step with the archetype of maximal weight.
# NOTE(review): argmax over axis 0 assumes S_PCHA is laid out as
# (n_archetypes, T) -- confirm against how pcha_results_8a.hdf5 was written.
arch_indices = np.argmax(S_PCHA, axis=0)

# join the nc's together
dataset_comb = dataset_stream.assign(tas=dataset_tas['tas'])

# One label per time step is required to attach labels along "time".
n_time = dataset_comb.sizes["time"]
if arch_indices.shape[0] != n_time:
    raise ValueError(
        f"S_PCHA yields {arch_indices.shape[0]} labels but the dataset has {n_time} time steps"
    )

# add labels from archetypes into the dataset
arch_da = xr.DataArray(arch_indices, dims="time", coords={"time": dataset_comb.time})
dataset_comb_labeled = dataset_comb.assign(archetype=arch_da)

Convert the xarray variables into a single channels-last PyTorch tensor of shape (T, H, W, C).

In [None]:
# Pull both variables out of xarray as NumPy arrays; `stream` carries a
# singleton pressure-level dimension ('plev') that is squeezed away.
stream = dataset_comb['stream'].squeeze('plev').values  # (T, H, W)
tas = dataset_comb['tas'].values                        # (T, H, W)

# Both fields must live on the same (T, H, W) grid to be stacked as channels.
assert stream.shape == tas.shape, f"stream {stream.shape} vs tas {tas.shape}"

# Stack the variables along a trailing channel axis: (T, H, W, C) with C = 2.
x_np = np.stack([stream, tas], axis=-1)

# Convert to a float32 PyTorch tensor
x_tensor = torch.from_numpy(x_np).float()

Bilinearly interpolate each field from its native (H, W) grid to 128 × 128, the spatial size the Earthformer model below is configured for.

In [None]:
# F.interpolate works channels-first, so move C in front of the spatial axes:
# (T, H, W, C) -> (T, C, H, W).
channels_first = x_tensor.permute(0, 3, 1, 2)

# Bilinear resize of every frame and channel onto a 128x128 grid.
channels_first_128 = F.interpolate(
    channels_first,
    size=(128, 128),
    mode='bilinear',
    align_corners=False,
)

# Back to channels-last: (T, 128, 128, C).
x_tensor_resized = channels_first_128.permute(0, 2, 3, 1)
print("Final resized shape:", x_tensor_resized.shape)

# Downstream cells read the resized data through `x_tensor`.
x_tensor = x_tensor_resized

## 3.2 Construct target labels

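NB: Some samples are dropped during target construction because of the lead time. A day is kept only if the entry `l` steps later in the dataset is exactly `l` calendar days later. Pairs that would span the gap between two summers (JJA) are therefore discarded.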

In [None]:
l = 7  # lead time in days
time = dataset_comb['time'].values  # format: datetime64
arch_labels = arch_da.values        # archetype label per time step

x_all = x_tensor  # shape: (T, H, W, C)

# Pair the input at index t with the label at index t + l, but only when that
# entry is exactly l days later. This keeps examples from different years from
# being combined (the data contain only summer days).
# TODO Add data from September to include last week of August?
n_candidates = max(len(time) - l, 0)
lead_ok = time[l:l + n_candidates] == time[:n_candidates] + np.timedelta64(l, 'D')
keep_idx = np.flatnonzero(lead_ok)            # input indices t that survive
kept_time_indices = keep_idx.tolist()

# Vectorized gather; also yields empty tensors (instead of failing) if no pair survives.
x_final = x_all[torch.from_numpy(keep_idx)]   # shape: (N, H, W, C)
# TODO change y into one-hot vector encoding?
y_final = torch.as_tensor(arch_labels[keep_idx + l], dtype=torch.long)  # shape: (N,)

assert x_final.shape[0] == y_final.shape[0]

print(f"x_final shape: {x_final.shape}") # approx. 8% of the dataset is cut
print(f"y_final shape: {y_final.shape}")
dropped = 1 - len(kept_time_indices) / max(len(time), 1)
print(f"Kept {len(kept_time_indices)} of {len(time)} time steps ({dropped:.1%} dropped)")

## 3.3 Train/Test split

In [None]:
# Chronological split without shuffling: the first 80% of samples form the
# training set and the remaining, later samples the test set.
split = 0.8
data_length = x_final.shape[0]
split_idx = int(data_length * split)  # floor of a non-negative value

x_train, x_test = x_final[:split_idx], x_final[split_idx:]
y_train, y_test = y_final[:split_idx], y_final[split_idx:]

# Loaders consumed by the training loop. The held-out split doubles as the
# validation set. Earthformer is built with input_shape [T=1, H, W, C], so a
# singleton time axis is added: (N, H, W, C) -> (N, 1, H, W, C).
batch_size = 32
train_loader = DataLoader(
    TensorDataset(x_train.unsqueeze(1), y_train),
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),  # reproducible shuffling
)
val_loader = DataLoader(
    TensorDataset(x_test.unsqueeze(1), y_test),
    batch_size=batch_size,
    shuffle=False,
)

# 4. Earthformer Model Setup

## 4.1 Load Earthformer model config

Load state dict

In [None]:
# Fetch the EarthNet2021-pretrained Earthformer checkpoint into ./experiments
# and load its tensors onto the CPU.
save_dir = "./experiments"
pretrained_checkpoint_url = (
    "https://earthformer.s3.amazonaws.com/pretrained_checkpoints/"
    "earthformer_earthnet2021.pt"
)
local_checkpoint_path = os.path.join(save_dir, "earthformer_earthnet2021.pt")

download(url=pretrained_checkpoint_url, path=local_checkpoint_path)

# NOTE(review): torch.load unpickles the file -- only point it at trusted checkpoints.
state_dict = torch.load(local_checkpoint_path, map_location="cpu")

Load the EarthNet2021 model config, sourced from the [Earthformer repository](https://github.com/amazon-science/earth-forecasting-transformer/blob/main/scripts/cuboid_transformer/earthnet_w_meso/earthformer_earthnet_v1.yaml)

In [4]:
# Earthformer architecture hyperparameters copied from the EarthNet2021 config,
# so that as many pretrained tensors as possible match by name and shape.
earthformer_config = dict(
    # model width
    base_units=256,
    block_units=None,
    scale_alpha=1.0,
    # encoder / decoder depth
    enc_depth=[1, 1],
    dec_depth=[1, 1],
    enc_use_inter_ffn=True,
    dec_use_inter_ffn=True,
    dec_hierarchical_pos_embed=False,
    # hierarchical down/up-sampling
    downsample=2,
    downsample_type="patch_merge",
    upsample_type="upsample",
    # global vectors
    num_global_vectors=2,
    use_dec_self_global=False,
    dec_self_update_global=True,
    use_dec_cross_global=False,
    use_global_vector_ffn=False,
    use_global_self_attn=True,
    separate_global_qkv=True,
    global_dim_ratio=1,
    # attention heads and dropout
    attn_drop=0.1,
    proj_drop=0.1,
    ffn_drop=0.1,
    num_heads=4,
    # layer choices
    ffn_activation="gelu",
    gated_ffn=False,
    norm_layer="layer_norm",
    padding_type="zeros",
    pos_embed_type="t+hw",
    use_relative_pos=True,
    self_attn_use_final_proj=True,
    # checkpointing
    checkpoint_level=0,
    # initial stacked-conv downsampling stem
    initial_downsample_type="stack_conv",
    initial_downsample_activation="leaky",
    initial_downsample_stack_conv_num_layers=2,
    initial_downsample_stack_conv_dim_list=[64, 256],
    initial_downsample_stack_conv_downscale_list=[2, 2],
    initial_downsample_stack_conv_num_conv_list=[2, 2],
    # weight initialisation modes
    attn_linear_init_mode="0",
    ffn_linear_init_mode="0",
    conv_init_mode="0",
    down_up_linear_init_mode="0",
    norm_init_mode="0",
    # decoder cross-attention window
    dec_cross_last_n_frames=None,
)

## 4.2 Initialize model & load pretrained weights

Initialize Earthformer model

In [None]:
# Earthformer backbone with a single-frame, 128x128, two-channel input and
# a target of the same shape; remaining hyperparameters come from the config.
EFmodel = CuboidTransformerModel(
    input_shape=[1, 128, 128, 2],
    target_shape=[1, 128, 128, 2],
    **earthformer_config,
)

Filter and log matching pretrained weights from state_dict

In [None]:
# Keep only pretrained tensors whose name AND shape match the freshly built
# model; log every decision so skipped layers are visible.
model_state_dict = EFmodel.state_dict()
compatible_state_dict = {}

for key, pretrained_tensor in state_dict.items():
    present = key in model_state_dict
    if present and model_state_dict[key].shape == pretrained_tensor.shape:
        compatible_state_dict[key] = pretrained_tensor
        print(f"Loading: {key} | with shape: {pretrained_tensor.shape}")
        continue

    model_shape = model_state_dict[key] if present else 'MISSING'
    if isinstance(model_shape, torch.Tensor):
        model_shape = model_shape.shape
    print(f"Skipping: {key} | pretrained shape: {pretrained_tensor.shape} vs model shape: {model_shape}")

Load compatible keys

In [None]:
# Load only the name- and shape-compatible tensors. strict=False tolerates
# missing/unexpected keys, but load_state_dict still raises on size mismatches,
# so the unfiltered `state_dict` (which may contain mismatched tensors) must not be used.
load_result = EFmodel.load_state_dict(compatible_state_dict, strict=False)
print(f"Loaded {len(compatible_state_dict)} of {len(state_dict)} pretrained tensors")
print("Missing keys:")
print(load_result.missing_keys)
print("Unexpected keys:")
print(load_result.unexpected_keys)

# 5. Classifier Head Construction

## 5.1 Wrap Earthformer into classification model

Define classifier

In [None]:
class EarthformerClassifier(nn.Module):
    """Linear classification head on top of an Earthformer backbone.

    The backbone output is averaged over its time and spatial axes and passed
    to a linear layer that produces one unnormalized score (logit) per class.

    Args:
        earthformer_model: Backbone module. Its ``target_shape[-1]`` is used as
            the linear layer's input size, so it must equal the channel
            dimension of the backbone output.
        num_classes: Number of output classes.
    """

    def __init__(self, earthformer_model, num_classes=8):
        super().__init__()
        self.model = earthformer_model
        self.pool = nn.AdaptiveAvgPool3d((1, 1, 1))  # Pool over T, H, W
        self.classifier = nn.Linear(self.model.target_shape[-1], num_classes)

    def forward(self, x):
        """Return raw class logits of shape (B, num_classes).

        ``x`` is (B, T, H, W, C); a 4-D (B, H, W, C) batch is treated as T = 1.
        Logits (not probabilities) are returned because nn.CrossEntropyLoss
        applies log-softmax itself; use ``predict_proba`` for probabilities.
        """
        if x.dim() == 4:
            x = x.unsqueeze(1)            # add singleton time axis
        x = self.model(x)                 # (B, T_out, H, W, C_out)
        x = x.permute(0, 4, 1, 2, 3)      # (B, C_out, T_out, H, W)
        x = self.pool(x).flatten(1)       # (B, C_out); keeps B even when B == 1
        return self.classifier(x)         # (B, num_classes)

    def predict_proba(self, x):
        """Class probabilities: sigmoid for a single output, softmax otherwise."""
        logits = self(x)
        if logits.shape[1] == 1:
            return torch.sigmoid(logits)
        return torch.softmax(logits, dim=1)

Instantiate classifier with EF model from previous section

In [None]:
# Number of archetype classes the head predicts.
n_classes = 8
EFClassifier = EarthformerClassifier(EFmodel, num_classes=n_classes)

## 5.2 Freeze pretrained layers

Freeze everything except classifier head:

In [None]:
# Freeze the whole Earthformer backbone...
for param in EFmodel.parameters():
    param.requires_grad = False

# ...then re-enable gradients for the linear head only. Re-enabling all of
# EFClassifier.parameters() would also unfreeze the backbone, because
# EFClassifier wraps EFmodel as `.model`.
for param in EFClassifier.classifier.parameters():
    param.requires_grad = True

n_trainable = sum(p.numel() for p in EFClassifier.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_trainable}")

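Alternatively, freeze only Earthformer's encoder and decoder blocks. Note that `requires_grad` flags set by earlier cells persist, so this cell adds to any freezing already applied rather than replacing it: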

In [None]:
# Stop gradients through Earthformer's encoder and decoder blocks.
# (The model variable is `EFmodel`; `EFModel` was a NameError.)
for frozen_block in (EFmodel.encoder, EFmodel.decoder):
    for param in frozen_block.parameters():
        param.requires_grad = False

# 6. Training Setup

## 6.1 Define loss & optimizer

In [None]:
# Cross Entropy Loss for classification
loss_func = nn.CrossEntropyLoss()

# Optimizer
lr = 5e-4  # needs to be adjusted if finetuning
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)

## 6.2 Helper functions

In [None]:
def training_step(model, batch_x, batch_y):
    """Forward one training batch and return its loss tensor.

    Puts ``model`` in train mode; the caller handles zero_grad, backward and
    the optimizer step.
    """
    model.train()
    return loss_func(model(batch_x), batch_y)

def validation_step(model, batch_x, batch_y):
    """Evaluate one batch without tracking gradients.

    Returns:
        ``(loss, accuracy)`` as Python floats, where accuracy is the fraction
        of samples whose argmax prediction equals ``batch_y``.
    """
    model.eval()
    with torch.no_grad():
        scores = model(batch_x)
        batch_loss = loss_func(scores, batch_y)
        hits = torch.argmax(scores, dim=1).eq(batch_y)
        batch_acc = hits.float().mean()
    return batch_loss.item(), batch_acc.item()

def compute_accuracy(model, dataloader):
    """Fraction of samples whose argmax prediction matches the label.

    Args:
        model: Module returning per-class scores of shape (B, num_classes).
        dataloader: Iterable of ``(x, y)`` batches with integer labels ``y``.

    Returns:
        Accuracy in [0, 1], or ``float('nan')`` if the loader yields no samples
        (instead of raising ZeroDivisionError).
    """
    model.eval()
    total_correct = 0
    total = 0
    with torch.no_grad():
        for x, y in dataloader:
            preds = torch.argmax(model(x), dim=1)
            total_correct += (preds == y).sum().item()
            total += y.size(0)
    return total_correct / total if total else float("nan")

# 7. Training Loop

Train the classifier for `num_epochs` epochs, logging the training loss plus the validation loss and accuracy after each epoch.

In [None]:
num_epochs = 20
train_losses, val_losses, val_accuracies = [], [], []

for epoch in range(num_epochs):
    # --- training ---
    epoch_loss = 0.0
    for x_batch, y_batch in train_loader:
        loss = training_step(EFClassifier, x_batch, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    train_losses.append(epoch_loss / len(train_loader))

    # --- validation ---
    # validation_step returns per-batch means; weight them by batch size so a
    # smaller final batch does not skew the epoch averages.
    val_loss_sum, val_correct, n_val = 0.0, 0.0, 0
    for x_batch, y_batch in val_loader:
        batch_loss, batch_acc = validation_step(EFClassifier, x_batch, y_batch)
        batch_n = y_batch.size(0)
        val_loss_sum += batch_loss * batch_n
        val_correct += batch_acc * batch_n
        n_val += batch_n

    val_losses.append(val_loss_sum / n_val if n_val else float("nan"))
    val_accuracies.append(val_correct / n_val if n_val else float("nan"))

    print(f"Epoch {epoch+1} | Train Loss: {train_losses[-1]:.4f} | Val Loss: {val_losses[-1]:.4f} | Val Acc: {val_accuracies[-1]:.4f}")


# 8. Results & Visualization

In [12]:
# Loss and accuracy vs. epoch, recorded by the training loop.
epochs = range(1, len(train_losses) + 1)
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(epochs, train_losses, label="Train Loss")
ax_loss.plot(epochs, val_losses, label="Val Loss")
ax_loss.set(xlabel="Epoch", ylabel="Loss", title="Training & Validation Loss")
ax_loss.legend()
ax_loss.grid(True)

ax_acc.plot(epochs, val_accuracies, label="Val Accuracy", color="tab:green")
ax_acc.set(xlabel="Epoch", ylabel="Accuracy", title="Validation Accuracy")
ax_acc.legend()
ax_acc.grid(True)

fig.tight_layout()
plt.show()

# 9. Save Model & Export Artifacts

In [None]:
# Create output directory
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir = f"outputs/{timestamp}"
os.makedirs(output_dir, exist_ok=True)

# Save model and optimizer states
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}, os.path.join(output_dir, "checkpoint.pt"))

# Save loss history
with open(os.path.join(output_dir, "train_loss.json"), "w") as f:
    json.dump(train_loss_history, f)
with open(os.path.join(output_dir, "val_loss.json"), "w") as f:
    json.dump(val_loss_history, f)

# Plot and save loss curves
plt.figure(figsize=(8, 5))
plt.plot(train_loss_history, label='Train Loss')
plt.plot(val_loss_history, label='Val Loss')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training & Validation Loss")
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(output_dir, "loss_curve.png"))
plt.close()