## Setting Up:

In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM

import pandas as pd
import warnings
from tqdm.notebook import tqdm
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import xarray as xr
import massbalancemachine as mbm
from collections import defaultdict
import logging
import torch.nn as nn
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
import pickle

from scripts.helpers import *
from scripts.iceland_preprocess import *
from scripts.plots import *
from scripts.config_ICE import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.NN_networks import *

warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

cfg = mbm.IcelandConfig(dataPath='/home/mburlet/scratch/data/DATA_MB/WGMS/Iceland/')

In [None]:
# Seed the RNGs via seed_all and call free_up_cuda (presumably releases cached CUDA memory).
seed_all(cfg.seed)
free_up_cuda()

# Plot style shared by every figure (path relative to the notebook's working directory)
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)

# Climate variables passed to the monthly transform together with the ERA5 files
vois_climate = 't2m tp slhf sshf ssrd fal str u10 v10'.split()

# Topographical variables
vois_topographical = [
    'aspect',                   # OGGM
    'slope',                    # OGGM
    'hugonnet_dhdt',            # OGGM
    'consensus_ice_thickness',  # OGGM
    'millan_v',
]

In [None]:
seed_all(cfg.seed)

cuda_available = torch.cuda.is_available()
if not cuda_available:
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    free_up_cuda()

    # Optional: limit CPU usage of the random search
    # torch.set_num_threads(2)  # or 1
    # os.environ["OMP_NUM_THREADS"] = "1"
    # os.environ["MKL_NUM_THREADS"] = "1"


## A. Read glacier (WGMS) data:

In [None]:
data_wgms = pd.read_csv(cfg.dataPath + path_PMB_WGMS_csv +
                        'ICE_dataset_all_oggm_with_hugonnetdhdt.csv')

# Drop every row that has a NaN in ANY column except DATA_MODIFICATION.
# (The aim is to remove rows with missing hugonnet_dhdt, but rows with NaNs in
# any other column are removed as well.)
data_wgms = data_wgms.dropna(subset=data_wgms.columns.drop('DATA_MODIFICATION'))

# Count samples per PERIOD once instead of filtering the frame repeatedly.
period_counts = data_wgms['PERIOD'].value_counts()
n_annual = int(period_counts.get('annual', 0))
n_winter = int(period_counts.get('winter', 0))
n_summer = int(period_counts.get('summer', 0))

print('Number of glaciers:', data_wgms['GLACIER'].nunique(dropna=False))
print('Number of winter, summer and annual samples:', n_annual + n_winter + n_summer)
print('Number of annual samples:', n_annual)
print('Number of winter samples:', n_winter)
print('Number of summer samples:', n_summer)

data_wgms.columns


## B. Progressive Transfer

In [None]:
data_ICE_test = data_wgms.copy()

# Transform data to monthly format (run or load data).
# Inputs: the WGMS csv directory plus the ERA5 climate and geopotential files.
era5_dir = cfg.dataPath + path_ERA5_raw
paths = dict(
    csv_path=cfg.dataPath + path_PMB_WGMS_csv,
    era5_climate_data=era5_dir + 'era5_monthly_averaged_data_ICECH.nc',
    geopotential_data=era5_dir + 'era5_geopotential_pressure_ICECH.nc',
)

# RUN toggles between recomputing and loading (see process_or_load_data).
RUN = False
dataloader_gl = process_or_load_data(
    run_flag=RUN,
    df=data_ICE_test,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='ICE_dataset_monthly_full_with_hugonnetdhdt.csv',
)
data_monthly = dataloader_gl.data

display(data_monthly.head(2))

In [None]:
# Samples per glacier and overall shape of the monthly dataset.
# (The former trailing no-op `display()` call has been removed.)
display(data_monthly['GLACIER'].value_counts(), data_monthly.shape)

In [None]:
display(data_monthly.loc[data_monthly['GLACIER'].eq('RGI60-06.00478')])

###### 6 glaciers, 50% train set

In [None]:
# TRANSFER LEARNING SETUP 50%
# Fine-tuning glaciers (6 Iceland glaciers used to adapt the Swiss model)
train_glaciers = ['Bruarjoekull', 'Skeidararjoekull', 'Koeldukvislarjoekull',
                  'Slettjoekull West', 'RGI60-06.00238', 'Hagafellsjoekull West']

# Test glaciers (all remaining Iceland glaciers)
all_iceland_glaciers = list(data_wgms['GLACIER'].unique())
test_glaciers = [g for g in all_iceland_glaciers if g not in train_glaciers]

print(f"Fine-tuning glaciers ({len(train_glaciers)}): {train_glaciers}")
print(f"Test glaciers ({len(test_glaciers)}): {test_glaciers}")

# Ensure all glaciers exist in the monthly dataset
existing_glaciers = set(dataloader_gl.data.GLACIER.unique())
missing_fine_tune = [g for g in train_glaciers if g not in existing_glaciers]
missing_test = [g for g in test_glaciers if g not in existing_glaciers]

if missing_fine_tune:
    print(f"Warning: Fine-tuning glaciers not in dataset: {missing_fine_tune}")
if missing_test:
    print(f"Warning: Test glaciers not in dataset: {missing_test}")


## CV Splits
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=test_glaciers,
                                            random_state=cfg.seed)

print('Train glaciers: ({}) {}'.format(len(train_set['splits_vals']),
                                      train_set['splits_vals']))
print('Test glaciers: ({}) {}'.format(len(test_set['splits_vals']),
                                      test_set['splits_vals']))
display('length train set', len(train_set['df_X']))
display('length test set', len(test_set['df_X']))


###### Split east/west

In [None]:
# Longitude of each glacier (first POINT_LON value recorded for it).
# Named glacier_lon: the column used is a longitude, not a latitude.
glacier_lon = data_wgms.groupby('GLACIER')['POINT_LON'].first()

# Fixed, hand-picked longitude threshold splitting Iceland into an eastern and a
# western group (this is not the median longitude).
lon_threshold = -18.25

east_glaciers = glacier_lon[glacier_lon >= lon_threshold].index.tolist()
west_glaciers = glacier_lon[glacier_lon < lon_threshold].index.tolist()

print(f"East glaciers ({len(east_glaciers)}): {east_glaciers}")
print(f"West glaciers ({len(west_glaciers)}): {west_glaciers}")

In [None]:
# East/west split at lon_threshold: fine-tune on the eastern glaciers
train_glaciers = east_glaciers

# Test glaciers (all remaining Iceland glaciers, i.e. the western ones)
all_iceland_glaciers = list(data_wgms['GLACIER'].unique())
test_glaciers = [g for g in all_iceland_glaciers if g not in train_glaciers]

print(f"Fine-tuning glaciers ({len(train_glaciers)}): {train_glaciers}")
print(f"Test glaciers ({len(test_glaciers)}): {test_glaciers}")

# Ensure all glaciers exist in the monthly dataset
existing_glaciers = set(dataloader_gl.data.GLACIER.unique())
missing_fine_tune = [g for g in train_glaciers if g not in existing_glaciers]
missing_test = [g for g in test_glaciers if g not in existing_glaciers]

if missing_fine_tune:
    print(f"Warning: Fine-tuning glaciers not in dataset: {missing_fine_tune}")
if missing_test:
    print(f"Warning: Test glaciers not in dataset: {missing_test}")


## CV Splits
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=test_glaciers,
                                            random_state=cfg.seed)

print('Train glaciers: ({}) {}'.format(len(train_set['splits_vals']),
                                      train_set['splits_vals']))
print('Test glaciers: ({}) {}'.format(len(test_set['splits_vals']),
                                      test_set['splits_vals']))

display('length train set', len(train_set['df_X']))
display('length test set', len(test_set['df_X']))

###### 5-10% train set

In [None]:
# TRANSFER LEARNING SETUP 5-10%
# Fine-tuning glaciers
train_glaciers = [
    'Mulajoekull',
    'Slettjoekull West',
    'Hagafellsjoekull East (Langjoekull S Dome)',
    'Tungnaarjoekull',
    'RGI60-06.00478',
]

# Test glaciers: every Iceland glacier not used for fine-tuning
all_iceland_glaciers = list(data_wgms['GLACIER'].unique())
train_glacier_set = set(train_glaciers)
test_glaciers = [g for g in all_iceland_glaciers if g not in train_glacier_set]

print(f"Fine-tuning glaciers ({len(train_glaciers)}): {train_glaciers}")
print(f"Test glaciers ({len(test_glaciers)}): {test_glaciers}")

# Sanity check: every selected glacier must be present in the monthly data
existing_glaciers = set(dataloader_gl.data.GLACIER.unique())
missing_fine_tune = [g for g in train_glaciers if g not in existing_glaciers]
missing_test = [g for g in test_glaciers if g not in existing_glaciers]
for group_label, missing in (("Fine-tuning glaciers", missing_fine_tune),
                             ("Test glaciers", missing_test)):
    if missing:
        print(f"Warning: {group_label} not in dataset: {missing}")


## CV Splits
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=test_glaciers,
                                            random_state=cfg.seed)

for split_label, split in (('Train', train_set), ('Test', test_set)):
    print(f"{split_label} glaciers: ({len(split['splits_vals'])}) {split['splits_vals']}")

display('length train set', len(train_set['df_X']))
display('length test set', len(test_set['df_X']))

In [None]:
# Show both CV split dictionaries (test first, then train).
display(test_set, train_set)

###### Random 80/20 train/validation split

In [None]:
# Validation and train split (test_size=0.2 via the mbm DataLoader).
# Work on a copy: adding the 'y' column must not mutate train_set['df_X'] in
# place, so re-running this cell (or the glacier-wise one) stays safe.
data_train = train_set['df_X'].copy()
data_train['y'] = train_set['y']
dataloader = mbm.dataloader.DataLoader(cfg, data=data_train)

train_itr, val_itr = dataloader.set_train_test_split(test_size=0.2)

# Get all indices of the training and validation dataset at once from the
# iterators. Once consumed, the iterators are empty.
train_indices, val_indices = list(train_itr), list(val_itr)

df_X_train = data_train.iloc[train_indices]
y_train = df_X_train['POINT_BALANCE'].values

# Get val set
df_X_val = data_train.iloc[val_indices]
y_val = df_X_val['POINT_BALANCE'].values


print("Train data glacier distribution:", df_X_train['GLACIER'].value_counts().head())
print("Val data glacier distribution:", df_X_val['GLACIER'].value_counts().head())
print("Train data shape:", df_X_train.shape)
print("Val data shape:", df_X_val.shape)

###### Train/val split on specific glacier

In [None]:
# Glacier-wise train/val split: validate on `val_glacier`, fine-tune on the
# remaining glaciers of the training set.
# Work on a copy so train_set['df_X'] is not mutated in place.
data_train = train_set['df_X'].copy()
data_train['y'] = train_set['y']

# NOTE: 'Engabreen' is not in any of the hard-coded train_glaciers lists above;
# pick one of the current training glaciers instead.
val_glacier = ['Engabreen']

# Fail fast instead of silently building an empty validation set.
available_glaciers = set(data_train['GLACIER'].unique())
missing_val = [g for g in val_glacier if g not in available_glaciers]
if missing_val:
    raise ValueError(
        f"Validation glacier(s) {missing_val} not in the training set; "
        f"choose from {sorted(available_glaciers)}")

train_glaciers = [g for g in train_glaciers if g not in val_glacier]

df_X_train = data_train[data_train['GLACIER'].isin(train_glaciers)].copy()
y_train = df_X_train['POINT_BALANCE'].values

df_X_val = data_train[data_train['GLACIER'].isin(val_glacier)].copy()
y_val = df_X_val['POINT_BALANCE'].values

print("Train data glacier distribution:", df_X_train['GLACIER'].value_counts().head())
print("Val data glacier distribution:", df_X_val['GLACIER'].value_counts().head())
print("Train data shape:", df_X_train.shape)
print("Val data shape:", df_X_val.shape)

## Neural Network:

In [None]:
def create_period_indicator(df, mapping=None):
    """Return a copy of ``df`` with a numerical ``PERIOD_INDICATOR`` column.

    Args:
        df: DataFrame with a ``PERIOD`` column.
        mapping: dict from PERIOD label to integer code. Defaults to
            ``{'annual': 0, 'winter': 1, 'summer': 2}``. Labels absent from
            the mapping become NaN.

    Returns:
        A new DataFrame; the input frame is not modified.
    """
    if mapping is None:
        # 'summer' used to be missing, so summer rows silently became NaN.
        mapping = {'annual': 0, 'winter': 1, 'summer': 2}
    df = df.copy()
    df['PERIOD_INDICATOR'] = df['PERIOD'].map(mapping)
    return df

# Add PERIOD_INDICATOR to the train, validation and test frames
df_X_train = create_period_indicator(df_X_train)
df_X_val = create_period_indicator(df_X_val)
test_set['df_X'] = create_period_indicator(test_set['df_X'])

train_indicator = df_X_train['PERIOD_INDICATOR']
print("PERIOD_INDICATOR created:")
print("Annual (0):", train_indicator.eq(0).sum())
print("Winter (1):", train_indicator.eq(1).sum())
print("Original PERIOD column preserved:", df_X_train['PERIOD'].unique())

In [None]:
features_topo = ['ELEVATION_DIFFERENCE', *vois_topographical]

# PERIOD_INDICATOR is not used as a feature at the moment (append it here to enable it).
feature_columns = [*features_topo, *vois_climate]

cfg.setFeatures(feature_columns)

all_columns = feature_columns + cfg.fieldsNotFeatures

df_X_train_subset, df_X_val_subset, df_X_test_subset = (
    frame[all_columns] for frame in (df_X_train, df_X_val, test_set['df_X']))

for split_label, frame in (('training', df_X_train_subset),
                           ('validation', df_X_val_subset),
                           ('testing', df_X_test_subset)):
    print(f'Shape of {split_label} dataset:', frame.shape)
print('Running with features:', feature_columns)

# The CV split's targets must equal the POINT_BALANCE column
assert all(train_set['df_X']['POINT_BALANCE'] == train_set['y']), \
    "POINT_BALANCE and train_set['y'] disagree"

### Initialise network:

In [None]:
# Stop after 15 epochs without a (threshold=1e-4) improvement of valid_loss.
early_stop = EarlyStopping(monitor='valid_loss', threshold=1e-4, patience=15)

# Halve the learning rate after 5 epochs without a 1% relative improvement.
lr_scheduler_cb = LRScheduler(
    policy=ReduceLROnPlateau,
    monitor='valid_loss',
    mode='min',
    factor=0.5,
    patience=5,
    threshold=0.01,
    threshold_mode='rel',
    verbose=True,
)

# Placeholders read by my_train_split; assigned for real before training.
dataset, dataset_val = None, None

def my_train_split(ds, y=None, **fit_params):
    """skorch ``train_split`` that ignores its inputs and returns the globals.

    The module-level ``dataset`` and ``dataset_val`` are looked up at call
    time, so they must be assigned before ``fit`` is called.
    """
    train_ds, valid_ds = dataset, dataset_val
    return train_ds, valid_ds

param_init = dict(device='cpu')
nInp = len(feature_columns)

# Save only the parameters (to best_model.pt) whenever valid_loss_best is hit,
# and reload the best ones when training ends.
checkpoint_cb = Checkpoint(
    monitor='valid_loss_best',
    load_best=True,
    f_params='best_model.pt',
    f_optimizer=None,     # do not save optimizer state
    f_history=None,       # do not save training history
    f_criterion=None,     # do not save criterion state
)


# Reference configuration for training from scratch, kept as an inert string
# literal (it is NOT executed).
"""
params = {
    'lr': 0.001,
    'batch_size': 128,
    'optimizer': torch.optim.Adam,
    'optimizer__weight_decay': 1e-05,
    'module__hidden_layers': [128, 128, 64, 32],
    'module__dropout': 0.2,
    'module__use_batchnorm': True,
}

args = {
    'module': FlexibleNetwork,
    'nbFeatures': nInp,
    'module__input_dim': nInp,
    'module__dropout': params['module__dropout'],
    'module__hidden_layers': params['module__hidden_layers'],
    'train_split': my_train_split,
    'batch_size': params['batch_size'],
    'verbose': 1,
    'iterator_train__shuffle': True,
    'lr': params['lr'],
    'max_epochs': 200,
    'optimizer': params['optimizer'],
    'optimizer__weight_decay': params['optimizer__weight_decay'],
    'module__use_batchnorm': params['module__use_batchnorm'],
    'callbacks': [
        ('early_stop', early_stop),
        ('lr_scheduler', lr_scheduler_cb),
    ]
}

custom_nn = mbm.models.CustomNeuralNetRegressor(cfg, **args, **param_init)
"""

### Create datasets:

In [None]:
"""
features, metadata = custom_nn._create_features_metadata(df_X_train_subset)

features_val, metadata_val = custom_nn._create_features_metadata(
    df_X_val_subset)

# Define the dataset for the NN
dataset = mbm.data_processing.AggregatedDataset(cfg,
                                                features=features,
                                                metadata=metadata,
                                                targets=y_train)
dataset = mbm.data_processing.SliceDatasetBinding(SliceDataset(dataset, idx=0),
                                                  SliceDataset(dataset, idx=1))
print("train:", dataset.X.shape, dataset.y.shape)

dataset_val = mbm.data_processing.AggregatedDataset(cfg,
                                                    features=features_val,
                                                    metadata=metadata_val,
                                                    targets=y_val)
dataset_val = mbm.data_processing.SliceDatasetBinding(
    SliceDataset(dataset_val, idx=0), SliceDataset(dataset_val, idx=1))
print("validation:", dataset_val.X.shape, dataset_val.y.shape)
"""
# Don't create datasets here, create them after loading the Swiss model
features = features_val = None
metadata = metadata_val = None
dataset = dataset_val = None
print("Datasets will be created after loading Swiss model...")

### Train custom model:

###### Fine-tuning with frozen layers

In [None]:
from skorch.callbacks import Callback

class SaveBestAtEpochs(Callback):
    """Write the best-so-far module weights to disk at selected epochs.

    After every epoch the callback tracks the lowest ``valid_loss`` seen so far
    and keeps a CPU copy of the matching ``net.module_`` state dict. When the
    current epoch number (``net.history[-1]['epoch']``) is in ``epochs``, that
    best-so-far state is written with ``torch.save`` to
    ``f"{prefix}_{epoch}.pt"`` (relative to the working directory).

    Args:
        epochs: iterable of epoch numbers at which to write a snapshot.
        prefix: filename prefix for the saved state dicts.
    """

    def __init__(self, epochs, prefix="nn_model_best_epoch"):
        # skorch convention: keep constructor args as given and put run state in
        # initialize(), so the tracked best is reset whenever skorch
        # (re)initializes the callbacks instead of leaking across runs.
        self.epochs = epochs
        self.prefix = prefix
        self.initialize()

    def initialize(self):
        """(Re)set the tracked best score/state; returns self as skorch expects."""
        self.epochs_ = set(self.epochs)
        self.best_score = float('inf')
        self.best_state = None
        return self

    def on_epoch_end(self, net, **kwargs):
        last = net.history[-1]
        epoch = last['epoch']
        valid_loss = last['valid_loss']
        if valid_loss < self.best_score:
            self.best_score = valid_loss
            # Clone to CPU so later optimizer steps do not alter the snapshot.
            self.best_state = {k: v.cpu().clone()
                               for k, v in net.module_.state_dict().items()}
        if epoch in self.epochs_ and self.best_state is not None:
            filename = f"{self.prefix}_{epoch}.pt"
            torch.save(self.best_state, filename)
            print(f"Best model up to epoch {epoch} saved as {filename}")

# Fine-tune the pre-trained Swiss model on the Iceland training split, keeping
# only the parameters in the allow-list below trainable (all others frozen).

# Writes the best-so-far weights to nn_model_best_epoch_<N>.pt at these epochs
# (evaluated in a later cell).
save_best_epochs_cb = SaveBestAtEpochs([10, 15, 20, 30, 50, 100])

TRAIN = True  # Set to True to actually train
if TRAIN:
    # STEP 1: Load the pre-trained Swiss model FIRST
    print("Loading pre-trained Swiss model...")
    model_filename = "nn_model_2025-07-14_CH_flexible.pt"
    
    # NOTE(review): architecture args presumably must match the Swiss checkpoint
    # (hidden_layers / dropout / batchnorm) — confirm against its training run.
    swiss_args = {
        'module': FlexibleNetwork,
        'nbFeatures': nInp,
        'module__input_dim': nInp,
        'module__dropout': 0.2,
        'module__hidden_layers': [128, 128, 64, 32],
        'module__use_batchnorm': True,
        'warm_start': True, # Important!!! warm_start=True: fit() does not re-initialize the loaded weights.
        'train_split': my_train_split,
        'batch_size': 128,
        'verbose': 1,
        'iterator_train__shuffle': True,
        'lr': 0.001,
        'max_epochs': 200,
        'optimizer': torch.optim.Adam,
        'optimizer__weight_decay': 1e-05,
        'callbacks': [
            ('early_stop', early_stop),
            ('lr_scheduler', lr_scheduler_cb),
            ('checkpoint', checkpoint_cb),
            ('save_best_at_epochs', save_best_epochs_cb)
        ]
    }
    
    loaded_model = mbm.models.CustomNeuralNetRegressor.load_model(
        cfg, model_filename, **{**swiss_args, **param_init}
    )

    print("✓ Swiss model loaded successfully!")
    
    # STEP 2: Create datasets using the loaded Swiss model
    print("Creating datasets with Swiss model...")
    features, metadata = loaded_model._create_features_metadata(df_X_train_subset)
    features_val, metadata_val = loaded_model._create_features_metadata(df_X_val_subset)
    
    # Create global datasets: my_train_split returns these module-level names,
    # so they must be assigned before fit() is called.
    dataset = mbm.data_processing.AggregatedDataset(cfg,
                                                    features=features,
                                                    metadata=metadata,
                                                    targets=y_train)
    dataset = mbm.data_processing.SliceDatasetBinding(SliceDataset(dataset, idx=0),
                                                      SliceDataset(dataset, idx=1))
    
    dataset_val = mbm.data_processing.AggregatedDataset(cfg,
                                                        features=features_val,
                                                        metadata=metadata_val,
                                                        targets=y_val)
    dataset_val = mbm.data_processing.SliceDatasetBinding(
        SliceDataset(dataset_val, idx=0), SliceDataset(dataset_val, idx=1))
    
    print("train:", dataset.X.shape, dataset.y.shape)
    print("validation:", dataset_val.X.shape, dataset_val.y.shape)


    # STEP 2.5: Freeze layers
    # Only the parameter names in the allow-list stay trainable.
    # NOTE(review): if FlexibleNetwork is a Sequential of [Linear, BatchNorm,
    # activation, Dropout] blocks, model.1/5/9/13 are the BatchNorm layers, i.e.
    # only normalization affine params are tuned — confirm against FlexibleNetwork.
    for name, param in loaded_model.module_.named_parameters():
        # Freeze every parameter that is not in the allow-list
        if name not in [#'model.0.weight', 'model.0.bias',
                        'model.1.weight', 'model.1.bias',
                        #'model.4.weight', 'model.4.bias',
                        'model.5.weight', 'model.5.bias',
                        #'model.8.weight', 'model.8.bias',
                        'model.9.weight', 'model.9.bias',
                        #'model.12.weight', 'model.12.bias',
                        'model.13.weight', 'model.13.bias',
                        #'model.16.weight', 'model.16.bias'
                        ]:
            param.requires_grad = False
    
    # STEP 3: Update for fine-tuning (lower LR, shorter schedule)
    print("Updating model for fine-tuning...")
    loaded_model = loaded_model.set_params(
        lr=0.0005,
        max_epochs=100,
    )
    
    # STEP 4: Fine-tune
    # features / y_train only satisfy the skorch fit() API: my_train_split
    # ignores them and returns the global dataset / dataset_val.
    print("Starting fine-tuning...")
    loaded_model.fit(features, y_train)
    
    # STEP 5: Save
    # NOTE: the progressive-unfreezing cell uses the same filename, so running
    # both cells on the same day overwrites this model.
    current_date = datetime.now().strftime("%Y-%m-%d")
    finetuned_model_filename = f"nn_model_finetuned_{current_date}"
    loaded_model.save_model(finetuned_model_filename)
    print(f"✓ Fine-tuned model saved as: {finetuned_model_filename}")



###### Progressive unfreezing

In [None]:
# Fine-tune the Swiss model by freezing every parameter and then unfreezing
# parameter groups stage by stage, each stage with its own max_epochs and lr.
TRAIN = True  # Set to True to actually train
if TRAIN:
    # STEP 1: Load the pre-trained Swiss model FIRST
    print("Loading pre-trained Swiss model...")
    model_filename = "nn_model_2025-07-14_CH_flexible.pt"
    
    swiss_args = {
        'module': FlexibleNetwork,
        'nbFeatures': nInp,
        'module__input_dim': nInp,
        'module__dropout': 0.2,
        'module__hidden_layers': [128, 128, 64, 32],
        'module__use_batchnorm': True,
        'warm_start': True, # Important!!! warm_start=True: fit() does not re-initialize the loaded weights.
        'train_split': my_train_split,
        'batch_size': 128,
        'verbose': 1,
        'iterator_train__shuffle': True,
        'lr': 0.001,
        'max_epochs': 200,
        'optimizer': torch.optim.Adam,
        'optimizer__weight_decay': 1e-05,
        'callbacks': [
            ('early_stop', early_stop),
            ('lr_scheduler', lr_scheduler_cb),
            ('checkpoint', checkpoint_cb),
        ]
    }
    
    loaded_model = mbm.models.CustomNeuralNetRegressor.load_model(
        cfg, model_filename, **{**swiss_args, **param_init}
    )

    print("✓ Swiss model loaded successfully!")
    
    # STEP 2: Create datasets using the loaded Swiss model
    print("Creating datasets with Swiss model...")
    features, metadata = loaded_model._create_features_metadata(df_X_train_subset)
    features_val, metadata_val = loaded_model._create_features_metadata(df_X_val_subset)
    
    # Create global datasets: my_train_split returns these module-level names.
    dataset = mbm.data_processing.AggregatedDataset(cfg,
                                                    features=features,
                                                    metadata=metadata,
                                                    targets=y_train)
    dataset = mbm.data_processing.SliceDatasetBinding(SliceDataset(dataset, idx=0),
                                                      SliceDataset(dataset, idx=1))
    
    dataset_val = mbm.data_processing.AggregatedDataset(cfg,
                                                        features=features_val,
                                                        metadata=metadata_val,
                                                        targets=y_val)
    dataset_val = mbm.data_processing.SliceDatasetBinding(
        SliceDataset(dataset_val, idx=0), SliceDataset(dataset_val, idx=1))
    
    print("train:", dataset.X.shape, dataset.y.shape)
    print("validation:", dataset_val.X.shape, dataset_val.y.shape)


    # STEP 2.5: Freeze layers
    
    # Helper to freeze/unfreeze layers
    def set_requires_grad(layer_names, requires_grad=True):
        """Set ``requires_grad`` on the parameters of ``loaded_model.module_``
        whose names are in ``layer_names`` (looks up the global
        ``loaded_model`` at call time)."""
        for name, param in loaded_model.module_.named_parameters():
            if name in layer_names:
                param.requires_grad = requires_grad

    # Stages as (parameter names, max_epochs, lr). Unfreezing is cumulative:
    # groups unfrozen in earlier stages stay trainable in later ones.
    # NOTE(review): model.16 presumably is the output Linear and model.12 /
    # model.8 the preceding Linear layers — confirm against FlexibleNetwork.
    layer_groups = [
        #(
            #[
                #'model.1.weight', 'model.1.bias',
                #'model.5.weight', 'model.5.bias',
                #'model.9.weight', 'model.9.bias',
                #'model.13.weight', 'model.13.bias'
            #],200,  0.1
        #),
        
        (['model.16.weight', 'model.16.bias'], 30, 0.01),
        (['model.12.weight', 'model.12.bias'], 20, 0.005),
        (['model.8.weight', 'model.8.bias'], 10, 0.001)
    ]

    # Freeze all layers first
    for name, param in loaded_model.module_.named_parameters():
        param.requires_grad = False

    # Progressive unfreezing loop with custom learning rates.
    # The same callback instances are passed again for every stage.
    for layers, epochs, lr in layer_groups:
        set_requires_grad(layers, True)
        print(f"Fine-tuning layers: {layers} for {epochs} epochs with lr={lr}...")
        loaded_model = loaded_model.set_params(
            lr=lr,
            max_epochs=epochs,
            callbacks=[
                ('early_stop', early_stop),
                ('lr_scheduler', lr_scheduler_cb),
                ('checkpoint', checkpoint_cb),
                ]
            )
        # my_train_split ignores features / y_train and returns dataset / dataset_val.
        loaded_model.fit(features, y_train)

        val_score = loaded_model.score(dataset_val.X, dataset_val.y)
        print("Validation score:", val_score)
    
    # STEP 3: Save
    # NOTE: same filename as the frozen fine-tuning cell; same-day runs overwrite each other.
    current_date = datetime.now().strftime("%Y-%m-%d")
    finetuned_model_filename = f"nn_model_finetuned_{current_date}"
    loaded_model.save_model(finetuned_model_filename)
    print(f"✓ Fine-tuned model saved as: {finetuned_model_filename}")



In [None]:
# Evaluate the current `loaded_model` on the held-out test glaciers.
eval_outputs = evaluate_model_and_group_predictions(
    loaded_model, df_X_test_subset, test_set['y'], cfg, mbm)
grouped_ids, scores_NN, ids_NN, y_pred_NN = eval_outputs
display(scores_NN)

In [None]:
# Re-evaluate the snapshots written by SaveBestAtEpochs (only those present on disk).
# NOTE: this rebinds the global `loaded_model`; later cells use the last snapshot loaded here.
epochs = [10, 15, 20, 30, 50, 100]
model_prefix = "nn_model_best_epoch"

for epoch in epochs:
    model_name = f"{model_prefix}_{epoch}.pt"
    if os.path.exists(model_name):
        # Fresh net with the fine-tuning configuration, placed on the CPU.
        loaded_model = mbm.models.CustomNeuralNetRegressor(cfg, **swiss_args, **param_init)
        loaded_model = loaded_model.set_params(device='cpu').to('cpu')

        # Build module_ first, then overwrite its weights with the saved state dict.
        loaded_model.initialize()
        state_dict = torch.load(model_name, map_location='cpu')
        loaded_model.module_.load_state_dict(state_dict)

        # 4. Evaluate
        grouped_ids, scores_NN, ids_NN, y_pred_NN = evaluate_model_and_group_predictions(
            loaded_model, df_X_test_subset, test_set['y'], cfg, mbm)
        print(f"Scores for epoch {epoch}:")
        display(scores_NN)

In [None]:
# Score whichever model `loaded_model` currently holds on the validation split.
X_val_ds, y_val_ds = dataset_val.X, dataset_val.y
val_score = loaded_model.score(X_val_ds, y_val_ds)
print("Validation score (higher is better):", val_score)

In [None]:
# Parameter names (the targets used for freezing/unfreezing), then the module layout.
net_module = loaded_model.module_
for name, param in net_module.named_parameters():
    print(name)
print(net_module)

In [None]:
# Per-parameter sum, e.g. to compare weights between models/runs.
for name, param in loaded_model.module_.named_parameters():
    param_sum = param.data.cpu().numpy().sum()
    print(name, param_sum)

### Load model and make predictions:

In [None]:
# Create features and metadata for the held-out test glaciers
features_test, metadata_test = loaded_model._create_features_metadata(
    df_X_test_subset)

# Move torch tensors to the CPU (no-op for numpy/pandas inputs)
if hasattr(features_test, 'cpu'):
    features_test = features_test.cpu()

targets_test = test_set['y']
if hasattr(targets_test, 'cpu'):
    targets_test = targets_test.cpu()

# Create the dataset and split it into (X, y) slices
dataset_test = mbm.data_processing.AggregatedDataset(cfg,
                                                     features=features_test,
                                                     metadata=metadata_test,
                                                     targets=targets_test)
dataset_test = [
    SliceDataset(dataset_test, idx=0),
    SliceDataset(dataset_test, idx=1)
]

# Make predictions; aggrPredict gives one prediction per measurement ID
y_pred = loaded_model.predict(dataset_test[0])
y_pred_agg = loaded_model.aggrPredict(dataset_test[0])

batchIndex = np.arange(len(y_pred_agg))
y_true = np.array([e for e in dataset_test[1][batchIndex]])

# Calculate scores
score = loaded_model.score(dataset_test[0], dataset_test[1])
mse, rmse, mae, pearson = loaded_model.evalMetrics(y_pred, y_true)

# Aggregate predictions per measurement ID.
# `ids` (not `id`) so the built-in id() is not shadowed in the notebook namespace.
ids = dataset_test[0].dataset.indexToId(batchIndex)
grouped_ids = pd.DataFrame({
    'target': [e[0] for e in dataset_test[1]],
    'ID': ids,
    'pred': y_pred_agg,
})

# Attach PERIOD, GLACIER and YEAR (first value per ID) in a single groupby/merge;
# the inner merge keeps the row order of grouped_ids.
meta_per_id = df_X_test_subset.groupby('ID')[['PERIOD', 'GLACIER', 'YEAR']].first()
grouped_ids = grouped_ids.merge(meta_per_id, on='ID')

In [None]:
# Test-set figures: overview, predicted vs. observed (MAE/RMSE passed in), per glacier.
PlotPredictions_NN(grouped_ids)
predVSTruth_all(grouped_ids, mae, rmse, title='NN on test')
PlotIndividualGlacierPredVsTruth(
    grouped_ids,
    base_figsize=(20, 15),
)