## Setting Up:

In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM

import pandas as pd
import warnings
from tqdm.notebook import tqdm
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import xarray as xr
import massbalancemachine as mbm
from collections import defaultdict
import logging
import torch.nn as nn
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
import pickle 

from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')
# IPython magics: load the autoreload extension and set it to mode 2 so that
# imported modules (e.g. the local `scripts` helpers) are reloaded before
# each cell is executed.
%load_ext autoreload
%autoreload 2

# Switzerland configuration object; `dataPath` is presumably the root folder
# of the local mass-balance data (it is prefixed to file names elsewhere).
cfg = mbm.SwitzerlandConfig(dataPath='/home/mburlet/scratch/data/DATA_MB/')

In [None]:
# Matplotlib style sheet and the colours reused by the plots further down.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

# Disabled (kept for reference): loading of the RGI id table.
# rgi_df = pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
# rgi_df.rename(columns=lambda x: x.strip(), inplace=True)
# rgi_df.sort_values(by='short_name', inplace=True)
# rgi_df.set_index('short_name', inplace=True)

# Climate variables of interest.
vois_climate = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
    'u10',
    'v10',
]

# Topographical variables of interest.
vois_topographical = [
    "aspect", "slope", "hugonnet_dhdt", "consensus_ice_thickness", "millan_v"
]

In [None]:
# Seed all random number generators with the configured seed.
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

# Report GPU availability; when a GPU is present its memory is freed first.
cuda_ready = torch.cuda.is_available()
print("CUDA is available" if cuda_ready else "CUDA is NOT available")
if cuda_ready:
    free_up_cuda()

    # # Try to limit CPU usage of random search
    # torch.set_num_threads(2)  # or 1
    # os.environ["OMP_NUM_THREADS"] = "1"
    # os.environ["MKL_NUM_THREADS"] = "1"


## Read GL data:

In [None]:
# Point mass-balance measurements (one row per stake measurement).
data_glamos = pd.read_csv(
    '/home/mburlet/scratch/data/DATA_MB/CH_wgms_dataset_all_04_06_oggm.csv')

# Remove the 'taelliboden' glacier when it is part of the dataset.
if (data_glamos['GLACIER'] == 'taelliboden').any():
    data_glamos = data_glamos[data_glamos['GLACIER'] != 'taelliboden']

# Summary of what was loaded.
print('-------------------')
print('Number of glaciers:', data_glamos['GLACIER'].nunique(dropna=False))
print('Number of winter and annual samples:', len(data_glamos))
print('Number of annual samples:',
      int((data_glamos.PERIOD == 'annual').sum()))
print('Number of winter samples:',
      int((data_glamos.PERIOD == 'winter').sum()))

## Input data:
### Input dataset:

In [None]:
# Timestamped INFO-level logging for the processing step.
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)

# Input/output locations handed to the monthly-format conversion.
paths = dict(
    csv_path='/home/mburlet/scratch/data/DATA_MB/',
    era5_climate_data=
    '/home/mburlet/scratch/data/ERA5Land/raw/era5_monthly_averaged_data.nc',
    geopotential_data=
    '/home/mburlet/scratch/data/ERA5Land/raw/era5_geopotential_pressure.nc',
    radiation_save_path='/home/mburlet/scratch/data/GLAMOS/topo/pcsr/',
)

# Transform data to monthly format (run or load data); RUN is forwarded as
# `run_flag` to process_or_load_data.
RUN = False
dataloader_gl = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_full_no_sgi.csv',
)
data_monthly = dataloader_gl.data

# One string id per (glacier, year) pair, derived from a hash of both.
data_monthly['GLWD_ID'] = data_monthly.apply(
    lambda row: mbm.data_processing.utils.get_hash(
        f"{row.GLACIER}_{row.YEAR}"),
    axis=1,
).astype(str)

# Re-wrap the frame (now carrying GLWD_ID) in a fresh data loader.
dataloader_gl = mbm.dataloader.DataLoader(
    cfg,
    data=data_monthly,
    random_seed=cfg.seed,
    meta_data_columns=cfg.metaData,
)


In [None]:
def balance_period_distribution(data, random_seed=None):
    """
    Balance the dataset so 'annual' and 'winter' rows are equally frequent.

    The larger of the two PERIOD groups is randomly downsampled to the size
    of the smaller one. Rows whose PERIOD is neither 'annual' nor 'winter'
    are not part of the result. A summary of the distribution before and
    after balancing is printed.

    NOTE(review): rows are sampled individually. If several rows belong to
    one measurement (e.g. a monthly-expanded table), this can split a
    measurement - confirm against the caller's data layout.

    Args:
        data: DataFrame with a PERIOD column.
        random_seed: Random seed for reproducibility. When given it is also
            used to seed NumPy's global RNG (kept for backward compatibility).

    Returns:
        A new, shuffled DataFrame with a fresh 0..n-1 index holding the same
        number of annual and winter rows. It is empty when either period is
        missing from ``data``.
    """
    if random_seed is not None:
        np.random.seed(random_seed)

    def _print_distribution(frame):
        # Print count and share of each period; guards against empty frames.
        counts = frame['PERIOD'].value_counts()
        total = len(frame)
        for label, key in (('Annual', 'annual'), ('Winter', 'winter')):
            count = counts.get(key, 0)
            share = count / total * 100 if total else 0.0
            print(f"{label}: {count} ({share:.1f}%)")

    print("Original distribution:")
    _print_distribution(data)

    # Separate the data by period
    annual_data = data[data['PERIOD'] == 'annual']
    winter_data = data[data['PERIOD'] == 'winter']

    # Size of the smaller group; both groups are reduced to this size.
    min_count = min(len(annual_data), len(winter_data))

    if min_count == 0:
        # One period is absent, so no balanced pairing exists.
        balanced_data = data.iloc[0:0].reset_index(drop=True)
    else:
        # Annual rows are only touched when they are the majority, so the
        # common winter-dominated case yields exactly the rows it did before.
        if len(annual_data) > min_count:
            annual_data = annual_data.sample(n=min_count,
                                             random_state=random_seed)
        winter_balanced = winter_data.sample(n=min_count,
                                             random_state=random_seed)

        # Combine and shuffle the balanced groups.
        balanced_data = pd.concat([annual_data, winter_balanced],
                                  ignore_index=True)
        balanced_data = balanced_data.sample(
            frac=1, random_state=random_seed).reset_index(drop=True)

    print("\nBalanced distribution:")
    _print_distribution(balanced_data)
    print(f"Total samples: {len(data)} -> {len(balanced_data)} "
          f"(dropped {len(data) - len(balanced_data)} samples)")

    return balanced_data

# Equalise the number of annual and winter rows.
data_monthly_balanced = balance_period_distribution(data_monthly,
                                                    random_seed=cfg.seed)

# Recompute the (glacier, year) hash id as a string column on the balanced frame.
data_monthly_balanced['GLWD_ID'] = data_monthly_balanced.apply(
    lambda row: mbm.data_processing.utils.get_hash(
        f"{row.GLACIER}_{row.YEAR}"),
    axis=1,
).astype(str)

# From here on dataloader_gl wraps the balanced data.
dataloader_gl = mbm.dataloader.DataLoader(
    cfg,
    data=data_monthly_balanced,
    random_seed=cfg.seed,
    meta_data_columns=cfg.metaData,
)

In [None]:
# Show the column names of the (unbalanced) monthly frame.
display(data_monthly.columns)


## Blocking on glaciers:

In [None]:
# Glaciers held out as the test set. An earlier selection is kept below for
# reference; with the empty list every glacier lands in the training pool.
# test_glaciers = [
#     'tortin', 'plattalva', 'sanktanna', 'schwarzberg', 'hohlaub', 'pizol',
#     'corvatsch', 'tsanfleuron', 'forno'
# ]
test_glaciers = []

# Report requested test glaciers that do not occur in the data.
existing_glaciers = set(dataloader_gl.data.GLACIER.unique())
missing_glaciers = [
    name for name in test_glaciers if name not in existing_glaciers
]
if len(missing_glaciers) > 0:
    print(
        f"Warning: The following test glaciers are not in the dataset: {missing_glaciers}"
    )

# Every glacier that is not a test glacier is used for training.
train_glaciers = [
    name for name in existing_glaciers if name not in test_glaciers
]

data_test = dataloader_gl.data.loc[dataloader_gl.data.GLACIER.isin(
    test_glaciers)]
print('Size of monthly test data:', len(data_test))

data_train = dataloader_gl.data.loc[dataloader_gl.data.GLACIER.isin(
    train_glaciers)]
print('Size of monthly train data:', len(data_train))

if len(data_train) > 0:
    test_perc = (len(data_test) / len(data_train)) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))
else:
    print("Warning: No training data available!")

# Annual versus winter rows in the training part.
print('-------------\nTrain:')
print('Number of monthly winter and annual samples:', len(data_train))
print('Number of monthly annual samples:',
      int((data_train.PERIOD == 'annual').sum()))
print('Number of monthly winter samples:',
      int((data_train.PERIOD == 'winter').sum()))

# Same breakdown for the test part.
data_test_annual = data_test.loc[data_test.PERIOD == 'annual']
data_test_winter = data_test.loc[data_test.PERIOD == 'winter']

print('Test:')
print('Number of monthly winter and annual samples:', len(data_test))
print('Number of monthly annual samples:', len(data_test_annual))
print('Number of monthly winter samples:', len(data_test_winter))

# And for the whole monthly table.
print('Total:')
print('Number of monthly rows:', len(dataloader_gl.data))
print('Number of annual rows:',
      int((dataloader_gl.data.PERIOD == 'annual').sum()))
print('Number of winter rows:',
      int((dataloader_gl.data.PERIOD == 'winter').sum()))

# Row counts of the original (non-monthly) measurement table.
print('-------------\nIn annual format:')
print('Number of annual train rows:',
      int(data_glamos.GLACIER.isin(train_glaciers).sum()))
print('Number of annual test rows:',
      int(data_glamos.GLACIER.isin(test_glaciers).sum()))


In [None]:
# Build the train/test sets (and CV splits), holding out whole glaciers.
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=test_glaciers,
                                            random_state=cfg.seed)

print('Test glaciers: ({}) {}'.format(len(test_set['splits_vals']),
                                      test_set['splits_vals']))
# Guard the ratio the same way as the glacier-blocking cell does, so an empty
# training set produces a warning instead of a ZeroDivisionError.
if len(train_set['df_X']) == 0:
    print("Warning: No training data available!")
else:
    test_perc = (len(test_set['df_X']) / len(train_set['df_X'])) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))
print('Size of test set:', len(test_set['df_X']))
print('Train glaciers: ({}) {}'.format(len(train_set['splits_vals']),
                                       train_set['splits_vals']))
print('Size of train set:', len(train_set['df_X']))

###### 80/20 val split

In [None]:
# Random 80/20 train/validation split of the training set.
# Note: data_train aliases train_set['df_X'], so the 'y' column is added to
# that frame in place.
data_train = train_set['df_X']
data_train['y'] = train_set['y']
dataloader = mbm.dataloader.DataLoader(cfg, data=data_train)

train_itr, val_itr = dataloader.set_train_test_split(test_size=0.2)

# Materialise the indices of the training and validation sets right away:
# the iterators are empty once they have been consumed.
train_indices = list(train_itr)
val_indices = list(val_itr)

# Training part
df_X_train = data_train.iloc[train_indices]
y_train = df_X_train.POINT_BALANCE.values

# Validation part
df_X_val = data_train.iloc[val_indices]
y_val = df_X_val.POINT_BALANCE.values

###### glacier based val split

In [None]:
# Selecting glaciers with kmeans clustering for validation set
# Glaciers are clustered by their mean measurement coordinates; validation
# glaciers are then drawn cluster by cluster until about 20% of the training
# rows are covered.

from sklearn.cluster import KMeans


# Mean longitude/latitude of the rows of each glacier.
glacier_coords = train_set['df_X'].groupby('GLACIER')[['POINT_LON', 'POINT_LAT']].mean()
# At most 20 clusters, fewer when there are fewer glaciers.
n_clusters = min(len(glacier_coords), 20)
# NOTE(review): the clustering seed is fixed to 1 rather than cfg.seed.
kmeans = KMeans(n_clusters=n_clusters, random_state=1)
labels = kmeans.fit_predict(glacier_coords)

# Number of rows per glacier, used to track how much data has been selected.
glacier_sizes = train_set['df_X'].groupby('GLACIER').size()
# cluster label -> glaciers in that cluster (in glacier_coords index order)
clusters = {i: [] for i in range(n_clusters)}
for glacier, label in zip(glacier_coords.index, labels):
    clusters[label].append(glacier)

val_glaciers = []
rows_selected = 0
# Target: 20% of the training rows.
target_rows = int(0.2 * len(train_set['df_X']))

# Round-robin over the clusters: each pass takes the first remaining glacier
# of every non-empty cluster and stops as soon as the target is reached.
while rows_selected < target_rows:
    for i in range(n_clusters):
        glaciers = clusters[i]
        if glaciers:
            # pop(0) removes the glacier from the cluster list, so it cannot
            # be selected twice.
            glacier = glaciers.pop(0)
            val_glaciers.append(glacier)
            rows_selected += glacier_sizes[glacier]
        if rows_selected >= target_rows:
            break

print(f"Selected {len(val_glaciers)} glaciers for validation, covering {rows_selected} rows ({rows_selected/len(train_set['df_X']):.2%} of data)")

display(val_glaciers)

# Coordinates of the selected glaciers, for the overview plot below.
val_coords = glacier_coords.loc[val_glaciers]

# Scatter plot of all glaciers (grey) with the validation glaciers
# highlighted in red and labelled by name.
plt.figure(figsize=(8, 8))
plt.scatter(glacier_coords['POINT_LON'], glacier_coords['POINT_LAT'], c='lightgray', label='All glaciers')
plt.scatter(val_coords['POINT_LON'], val_coords['POINT_LAT'], c='red', label='Validation glaciers')
for name, row in val_coords.iterrows():
    plt.text(row['POINT_LON'], row['POINT_LAT'], name, fontsize=8)

plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.title('Validation Glaciers on Map')
plt.legend()
plt.show()

In [None]:
# Glacier-wise train/val split: whole glaciers go to the validation set.
# This replaces df_X_train / df_X_val / y_train / y_val from the random split.
data_train = train_set['df_X']
data_train['y'] = train_set['y']

# Hard-coded validation glaciers.
val_glacier = [
    'gries', 'albigna', 'corbassiere', 'sardona', 'adler', 'joeri',
    'sexrouge', 'clariden', 'aletsch'
]
# Training glaciers are the remaining ones.
train_glaciers = list(
    filter(lambda name: name not in val_glacier, train_glaciers))

in_train = data_train['GLACIER'].isin(train_glaciers)
df_X_train = data_train.loc[in_train].copy()
y_train = df_X_train.POINT_BALANCE.values

in_val = data_train['GLACIER'].isin(val_glacier)
df_X_val = data_train.loc[in_val].copy()
y_val = df_X_val.POINT_BALANCE.values

print("Train data glacier distribution:", df_X_train.GLACIER.value_counts())
print("Val data glacier distribution:", df_X_val.GLACIER.value_counts())
print("Train data shape:", df_X_train.shape)
print("Val data shape:", df_X_val.shape)

## Neural Network:

In [None]:
def create_period_indicator(df):
    """Return a copy of ``df`` with a numerical PERIOD_INDICATOR column.

    'annual' is encoded as 0 and 'winter' as 1; any other PERIOD value maps
    to NaN. The input frame is left untouched.
    """
    period_codes = {'annual': 0, 'winter': 1}
    return df.assign(PERIOD_INDICATOR=df['PERIOD'].map(period_codes))

# Add the indicator column to the train, validation and test frames.
df_X_train, df_X_val = (create_period_indicator(df_X_train),
                        create_period_indicator(df_X_val))
test_set['df_X'] = create_period_indicator(test_set['df_X'])

# Quick check of the encoding on the training frame.
print("PERIOD_INDICATOR created:")
print("Annual (0):", df_X_train['PERIOD_INDICATOR'].eq(0).sum())
print("Winter (1):", df_X_train['PERIOD_INDICATOR'].eq(1).sum())
print("Original PERIOD column preserved:", df_X_train['PERIOD'].unique())

In [None]:
# Model inputs: elevation difference, topographical and climate variables.
features_topo = ['ELEVATION_DIFFERENCE', *vois_topographical]
feature_columns = [*features_topo, *vois_climate]  # + ['PERIOD_INDICATOR']

cfg.setFeatures(feature_columns)

# Features plus the non-feature fields listed by the configuration.
all_columns = feature_columns + cfg.fieldsNotFeatures

# Because CH has some extra columns, we need to cut those
df_X_train_subset = df_X_train[all_columns]
df_X_val_subset = df_X_val[all_columns]
df_X_test_subset = test_set['df_X'][all_columns]

for split_name, subset in (('training', df_X_train_subset),
                           ('validation', df_X_val_subset),
                           ('testing', df_X_test_subset)):
    print('Shape of ' + split_name + ' dataset:', subset.shape)
print('Running with features:', feature_columns)

# Sanity check: the stored targets equal the POINT_BALANCE column.
assert (train_set['df_X'].POINT_BALANCE == train_set['y']).all()

### Initialise network:

In [None]:
# Early stopping on the validation loss (patience of 15 epochs).
early_stop = EarlyStopping(
    threshold=1e-4,  # Optional: stop only when improvement is very small
    patience=15,
    monitor='valid_loss',
)

# Halve the learning rate when the validation loss plateaus.
lr_scheduler_cb = LRScheduler(
    policy=ReduceLROnPlateau,
    monitor='valid_loss',
    mode='min',
    factor=0.5,
    patience=5,
    threshold=0.01,
    threshold_mode='rel',
    verbose=True,
)

# Placeholders; the real datasets are assigned in a later cell.
dataset = None
dataset_val = None


def my_train_split(ds, y=None, **fit_params):
    """Return the module-level train/validation datasets, ignoring the arguments."""
    return dataset, dataset_val


# Device used for training (switch to {'device': 'cuda:0'} for GPU).
param_init = dict(device='cpu')  # Use CPU for training
nInp = len(feature_columns)

In [None]:
# The callbacks (early_stop, lr_scheduler_cb), the dataset placeholders,
# my_train_split, param_init and nInp are defined in the preceding cell.
# They used to be repeated here verbatim, which also reset `dataset` and
# `dataset_val` to None whenever this cell was re-run.

# Hyper-parameters of the network and optimiser.
params = {
    'lr': 0.001,
    'batch_size': 128,
    'optimizer': torch.optim.Adam,
    'optimizer__weight_decay': 1e-05,
    'module__hidden_layers': [128, 128, 64, 32],
    'module__dropout': 0.2,
    'module__use_batchnorm': True,
}

# Full keyword set for the regressor: the hyper-parameters above plus the
# fixed training options and callbacks.
args = {
    'module': FlexibleNetwork,
    'nbFeatures': nInp,
    'module__input_dim': nInp,
    'module__dropout': params['module__dropout'],
    'module__hidden_layers': params['module__hidden_layers'],
    'train_split': my_train_split,
    'batch_size': params['batch_size'],
    'verbose': 1,
    'iterator_train__shuffle': True,
    'lr': params['lr'],
    'max_epochs': 200,
    'optimizer': params['optimizer'],
    'optimizer__weight_decay': params['optimizer__weight_decay'],
    'module__use_batchnorm': params['module__use_batchnorm'],
    'callbacks': [
        ('early_stop', early_stop),
        ('lr_scheduler', lr_scheduler_cb),
    ]
}
custom_nn = mbm.models.CustomNeuralNetRegressor(cfg, **args, **param_init)

### Create datasets:

In [None]:
def _as_slice_binding(feat, meta, targets):
    """Wrap features/metadata/targets in an AggregatedDataset and bind its two slices."""
    aggregated = mbm.data_processing.AggregatedDataset(cfg,
                                                       features=feat,
                                                       metadata=meta,
                                                       targets=targets)
    return mbm.data_processing.SliceDatasetBinding(
        SliceDataset(aggregated, idx=0), SliceDataset(aggregated, idx=1))


# Split the frames into feature and metadata parts.
features, metadata = custom_nn._create_features_metadata(df_X_train_subset)
features_val, metadata_val = custom_nn._create_features_metadata(
    df_X_val_subset)

# Datasets for the NN; my_train_split hands exactly these to the regressor.
dataset = _as_slice_binding(features, metadata, y_train)
print("train:", dataset.X.shape, dataset.y.shape)

dataset_val = _as_slice_binding(features_val, metadata_val, y_val)
print("validation:", dataset_val.X.shape, dataset_val.y.shape)

### Train custom model:

In [None]:
# Set TRAIN to False to skip training and only load a saved model below.
TRAIN = True
if TRAIN:
    # Create the output folder up front so the parameter dump at the end
    # cannot fail after a full training run.
    os.makedirs("models", exist_ok=True)

    custom_nn.seed_all()

    print("Training the model...")
    print('Model parameters:')
    for key, value in args.items():
        print(f"{key}: {value}")
    custom_nn.fit(dataset.X, dataset.y)
    # The dataset provided in fit is not used as the datasets are overwritten in the provided train_split function

    # Model and parameter files share the same date stamp.
    current_date = datetime.now().strftime("%Y-%m-%d")
    model_filename = f"nn_model_{current_date}"

    plot_training_history(custom_nn, skip_first_n=5)

    # After Training: Best weights are already loaded
    # Save the model
    custom_nn.save_model(model_filename)

    # Save the argument dictionary next to the model. The file name follows
    # the `nn_params_<date>.pkl` pattern that the loading cell reads (it
    # previously carried a stray `_test_test` suffix).
    # NOTE(review): `args` contains callables (module class, train_split
    # function, callback instances); consider pickling `params` instead.
    params_filename = f"nn_params_{current_date}.pkl"

    with open(os.path.join("models", params_filename), "wb") as f:
        pickle.dump(args, f)

### Load model and make predictions:

In [None]:
# Load model and set to CPU
model_filename = "nn_model_2025-07-10.pt"  # Replace with actual date if needed
# model_filename = "nn_model_finetuned_winter_2025-07-08.pt"
# read pickle with params
# NOTE(security): pickle.load executes arbitrary code from the file; only
# load parameter files that were written by this notebook.
params_filename = "nn_params_2025-07-10.pkl"  # Replace with actual date if needed
with open(f"models/{params_filename}", "rb") as f:
    custom_params = pickle.load(f)

# The string below is disabled code: it would rebuild `args` from the loaded
# parameters. As written, `custom_params` is not used and the model is
# restored with the `args` defined in the network cell above.
"""
params = custom_params

args = {
    'module': FlexibleNetwork,
    'nbFeatures': nInp,
    'module__input_dim': nInp,
    'module__dropout': params['module__dropout'],
    'module__hidden_layers': params['module__hidden_layers'],
    'train_split': my_train_split,
    'batch_size': params['batch_size'],
    'verbose': 1,
    'iterator_train__shuffle': True,
    'lr': params['lr'],
    'max_epochs': 300,
    'optimizer': params['optimizer'],
    'optimizer__weight_decay': params['optimizer__weight_decay'],
    'module__use_batchnorm': params['module__use_batchnorm'],
    'callbacks': [
        ('early_stop', early_stop),
        ('lr_scheduler', lr_scheduler_cb),
    ]
}
"""

# Restore the saved model using the current `args` merged with the device
# setting from `param_init`.
loaded_model = mbm.models.CustomNeuralNetRegressor.load_model(
    cfg,
    model_filename,
    **{
        **args,
        **param_init
    },
)
# Make sure both the estimator setting and the model itself are on the CPU.
loaded_model = loaded_model.set_params(device='cpu')
loaded_model = loaded_model.to('cpu')

In [None]:
# Display the hyper-parameter dictionary currently in memory.
params

#### On test:

In [None]:
# Predict on the test set and group the predictions.
grouped_ids, scores_NN, ids_NN, y_pred_NN = (
    evaluate_model_and_group_predictions(loaded_model, df_X_test_subset,
                                         test_set['y'], cfg, mbm))

# Scores computed separately for annual and winter.
scores_annual, scores_winter = compute_seasonal_scores(
    grouped_ids,
    pred_col='pred',
    target_col='target',
)

# Summary figure of predictions against targets.
fig = plot_predictions_summary(
    grouped_ids=grouped_ids,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    predVSTruth=predVSTruth,
    plotMeanPred=plotMeanPred,
    color_annual=color_dark_blue,
    color_winter=color_pink,
    ax_xlim=(-8, 6),
    ax_ylim=(-8, 6),
)

# One prediction-vs-truth panel per glacier.
PlotIndividualGlacierPredVsTruth(grouped_ids, base_figsize=(20, 15))