# Predicting Archetypes with Earthformer
Loading the data

In [None]:
import numpy as np
import h5py
import xarray as xr

# All inputs live in the shared data/ directory. Relative paths are resolved
# against the kernel's current working directory, so run the notebook from its
# own folder (otherwise open_dataset / h5py.File raise FileNotFoundError).
DATA_DIR = "../../data"

# Stream-function and tas fields (JJA, 2-degree grid, deseasonalised per the filenames)
stream_path = f"{DATA_DIR}/lentis_stream250_JJA_2deg_101_deseason_smsubd_sqrtcosw.nc"
dataset_stream = xr.open_dataset(stream_path)

tas_path = f"{DATA_DIR}/lentis_tas_JJA_2deg_101_deseason.nc"
dataset_tas = xr.open_dataset(tas_path)

# Archetype weight matrix S_PCHA from the archetypal-analysis (PCHA) run.
# Presumably shaped (n_archetypes, n_time) — the labeling step takes argmax over axis 0; confirm.
pcha_path = f"{DATA_DIR}/pcha_results_8a.hdf5"
with h5py.File(pcha_path, 'r') as f:
    S_PCHA = f['/S_PCHA'][:]  # [:] reads the data into memory before the file is closed

FileNotFoundError: [Errno 2] No such file or directory: '/net/sys/pscst001/export/BETA-IVM-BAZIS/mmi393/concurrent-heatwave-prediction/data/lentis_stream250_JJA_2deg_101_deseason_spatialsub.nc'

Join TAS and stream function data

In [2]:
# Label each time step with its dominant archetype: the index of the largest
# weight along axis 0 of S_PCHA.
arch_indices = np.argmax(S_PCHA, axis=0)

# Sanity check part 1: spot values of tas before the join (must match part 2).
print(dataset_tas.isel(time=123)['tas'].isel(lon=0, lat=0).values)
print(dataset_tas.isel(time=74)['tas'].isel(lon=4, lat=8).values)

# Add tas as a variable of the stream-function dataset. xarray aligns the new
# variable to the dataset's coordinates, so a grid/time mismatch between the two
# files could silently reindex it (introducing NaNs) instead of raising.
dataset_comb = dataset_stream.assign(tas=dataset_tas['tas'])

# Sanity check part 2: same spot values after the join.
print(dataset_comb.isel(time=123)['tas'].isel(lon=0, lat=0).values)
print(dataset_comb.isel(time=74)['tas'].isel(lon=4, lat=8).values)

# Full-field check: the join must leave every tas value (and the shape) unchanged.
np.testing.assert_array_equal(dataset_comb['tas'].values, dataset_tas['tas'].values)

-0.3480835
0.13922119
-0.3480835
0.13922119


Attach archetype labels from the archetypal analysis (AA) results to the combined dataset

In [3]:
# Sanity check 1: raw labels at a few time indices.
print(arch_indices[0], arch_indices[5], arch_indices[6], arch_indices[9119])

# Put the labels on the combined dataset's time axis. The constructor raises if the
# number of labels differs from the number of time steps.
arch_da = xr.DataArray(arch_indices, dims="time", coords={"time": dataset_comb.time})
# Sanity check 2: same labels, now indexed through the DataArray.
print(arch_da.isel(time=0).values, arch_da.isel(time=5).values, arch_da.isel(time=6).values, arch_da.isel(time=9119).values)

# Attach the per-time-step archetype label as a new variable (no aggregation happens here).
dataset_comb_labeled = dataset_comb.assign(archetype=arch_da)

# Sanity check 3: same labels, read back from the labeled dataset.
print(dataset_comb_labeled.isel(time=0)['archetype'].values,
      dataset_comb_labeled.isel(time=5)['archetype'].values,
      dataset_comb_labeled.isel(time=6)['archetype'].values,
      dataset_comb_labeled.isel(time=9119)['archetype'].values)

# Full check: every stored label equals the corresponding argmax label.
np.testing.assert_array_equal(dataset_comb_labeled['archetype'].values, arch_indices)

4 2 2 3
4 2 2 3
4 2 2 3


## Dataset construction

In [None]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


From xarray dataset to PyTorch tensors

In [6]:
# Extract both fields as NumPy arrays. The stream function carries a singleton
# 'plev' dimension, which is squeezed away.
stream = dataset_comb['stream'].squeeze('plev').values  # (T, H, W)
tas = dataset_comb['tas'].values                        # (T, H, W)

# Stack along a trailing channel axis (np.stack raises if the two shapes differ).
# Channel order: 0 = stream, 1 = tas.
x_np = np.stack([stream, tas], axis=-1)  # (T, H, W, C) with C = 2

# Convert to a float32 PyTorch tensor.
x_tensor = torch.from_numpy(x_np).float()

print(x_tensor.shape)  # (T=9200, H=29, W=170, C=2)

torch.Size([9200, 29, 170, 2])


Target construction

If day t+7 belongs to a different year (i.e. it is not exactly 7 days after day t in the time series), the example is excluded from labeling.

In [None]:
from datetime import timedelta

l = 7  # lead time in days: the input at day t is paired with the label at day t + l
time = dataset_comb['time'].values  # format: datetime64
arch_labels = arch_da.values        # (T,)

x_all = x_tensor  # shape: (T, H, W, C)

# Makes it so that examples from different years do not get combined:
# a pair (t, t + l) is kept only when the sample l positions later is exactly
# l days later, i.e. both days lie in the same contiguous season.
# TODO Add data from September to include last week of August?
candidate_idx = np.arange(len(time) - l)  # empty when len(time) <= l
same_season = time[candidate_idx + l] == time[candidate_idx] + np.timedelta64(l, 'D')
kept_idx = candidate_idx[same_season]
kept_time_indices = kept_idx.tolist()

# Gather inputs and lead-time targets in one indexing step (also handles zero kept examples,
# where torch.stack on an empty list would raise).
x_final = x_all[torch.as_tensor(kept_idx, dtype=torch.long)]                  # shape: (N, H, W, C)
y_final = torch.as_tensor(arch_labels[kept_idx + l], dtype=torch.long)        # shape: (N,)

print(f"x_final shape: {x_final.shape}") # approx. 8% of the dataset is cut
print(f"y_final shape: {y_final.shape}")
print(kept_time_indices[:100])

x_final shape: torch.Size([8500, 29, 170, 2])
y_final shape: torch.Size([8500])
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106]


Train/Test Split

In [None]:
# TODO split x&y into train/test sets

## Using Earthformer
Import Earthformer

In [None]:
import os
from earthformer.cuboid_transformer.cuboid_transformer import CuboidTransformerModel
from earthformer.utils.utils import download

save_dir = "./experiments"
os.makedirs(save_dir, exist_ok=True)

pretrained_checkpoint_url = "https://earthformer.s3.amazonaws.com/pretrained_checkpoints/earthformer_earthnet2021.pt"
local_checkpoint_path = os.path.join(save_dir, "earthformer_earthnet2021.pt")
download(url=pretrained_checkpoint_url, path=local_checkpoint_path)

# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load(local_checkpoint_path, map_location=torch.device("cpu"))
# The EarthNet Lightning module (earthformer.train_cuboid_earthnet) is not importable from the
# installed package, so there is no `pl_module` to load into. Apply `state_dict` to the
# CuboidTransformerModel once it has been built; shapes differ from EarthNet2021, so only
# matching tensors can be reused (e.g. load_state_dict(..., strict=False) after filtering
# out entries whose shapes do not match).

ModuleNotFoundError: No module named 'earthformer.train_cuboid_earthnet'

Initialize the Earthformer model with the correct shapes and load pretrained weights where applicable.

In [None]:
# TODO figure out the proper initialization
# Earthformer's input_shape / target_shape describe ONE sample as (T, H, W, C); the batch
# dimension is not part of them. x_final is (N, H, W, C), so x_final.shape[0] is the number
# of samples, not a time length. Each sample is a single frame, hence T = 1; batches must be
# given a singleton time axis, i.e. (B, 1, H, W, C), before being fed to the model.
_, n_lat, n_lon, n_channels = x_final.shape
model = CuboidTransformerModel(input_shape=(1, n_lat, n_lon, n_channels),
                               target_shape=(1, n_lat, n_lon, n_channels),
                               enc_depth=[4, 4],
                               dec_depth=[2, 2])


Adapt Earthformer to the classification task

In [17]:
# One output class per archetype: the first axis of S_PCHA indexes archetypes
# (it is the axis reduced by argmax when labeling time steps).
n_classes = len(S_PCHA)

In [None]:
from torch import nn


class EarthformerClassifier(nn.Module):
    """Classify samples into archetypes using an Earthformer backbone.

    The backbone output is average-pooled over time and space. A linear head then
    maps the pooled channel vector to one score per class.

    Args:
        earthformer_model: backbone module. It must expose ``target_shape``, whose
            last entry is used as the output channel count. It is assumed to return
            tensors shaped (B, T_out, H, W, C_out).
        num_classes: number of output classes (defaults to ``n_classes``).
    """

    def __init__(self, earthformer_model, num_classes=n_classes):
        super().__init__()
        self.model = earthformer_model
        self.pool = nn.AdaptiveAvgPool3d((1, 1, 1))  # Pool over T, H, W
        self.classifier = nn.Linear(self.model.target_shape[-1], num_classes)

    def forward(self, x, return_logits=False):
        """Return class probabilities (or raw logits) for a batch.

        Args:
            x: (B, T, H, W, C) tensor, or (B, H, W, C) for single-frame samples
                (a singleton time axis is inserted).
            return_logits: if True, return unnormalised logits. Use this for
                training with ``nn.CrossEntropyLoss``, which applies log-softmax itself.

        Returns:
            (B, num_classes) tensor. Sigmoid is applied when num_classes == 1,
            otherwise softmax over dim 1 (unless ``return_logits``).
        """
        if x.dim() == 4:
            x = x.unsqueeze(1)         # (B, H, W, C) -> (B, 1, H, W, C)
        x = self.model(x)              # (B, T_out, H, W, C_out)
        x = x.permute(0, 4, 1, 2, 3)   # -> (B, C_out, T_out, H, W)
        # flatten(1) rather than squeeze(): squeeze() would also drop the batch dim when B == 1
        x = self.pool(x).flatten(1)    # -> (B, C_out)
        logits = self.classifier(x)    # -> (B, num_classes)
        if return_logits:
            return logits
        if logits.shape[1] == 1:
            return torch.sigmoid(logits)
        return torch.softmax(logits, dim=1)