In [None]:
import pandas as pd
import os
import warnings
from tqdm.notebook import tqdm
import re
import geopandas as gpd
from shapely.geometry import Point
import pyproj
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns
import xarray as xr
import glob
from geopy.distance import geodesic
import numpy as np

from scripts.plots import *

from pathlib import Path
#from cmcrameri import cm
from calendar import monthrange


warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2


In [None]:
# Reproducibility and GPU housekeeping.
# NOTE(review): `seed_all`, `free_up_cuda` and `cfg` are not defined in this notebook;
# presumably they come from `from scripts.plots import *` — TODO confirm.
seed_all(cfg.seed)
free_up_cuda()

# Project matplotlib style sheet.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)

# NOTE(review): `from cmcrameri import cm` is commented out in the import cell, so `cm`
# must be provided by the star import above — TODO confirm.
cmap = cm.devon

# Colours for bars and lines.
color_diff_xgb = '#4d4d4d'
# Ten hex colours sampled from the batlow colormap. `colors` is later rebound to a
# region->colour dict in the overview-map cell.
colors = get_cmap_hex(cm.batlow, 10)
color_1, color_2 = colors[0], '#c51b7d'

## 1. Load in available datasets

In [None]:
# Root directory of all stake-measurement CSVs; change this one constant to run elsewhere.
DATA_DIR = Path('/home/mburlet/scratch/data/DATA_MB')

# Point mass-balance datasets, one DataFrame per region.
data_CH = pd.read_csv(DATA_DIR / 'CH_wgms_dataset_all_04_06_oggm.csv')
data_FR = pd.read_csv(DATA_DIR / 'GLACIOCLIM' / 'csv' / 'FR_wgms_dataset_all_oggm.csv')
data_IT_AT = pd.read_csv(DATA_DIR / 'WGMS' / 'IT_AT' / 'csv' / 'IT_AT_wgms_dataset_all_oggm.csv')
# Italy and Austria share one file; POINT_IDs carry an '_IT' / '_AT' suffix.
# na=False keeps rows with a missing POINT_ID out of both subsets instead of breaking the mask.
data_IT = data_IT_AT[data_IT_AT['POINT_ID'].str.endswith('_IT', na=False)]
data_AT = data_IT_AT[data_IT_AT['POINT_ID'].str.endswith('_AT', na=False)]
data_NOR = pd.read_csv(DATA_DIR / 'WGMS' / 'Norway' / 'csv' / 'Nor_dataset_all_oggm.csv')
data_ICE = pd.read_csv(DATA_DIR / 'WGMS' / 'Iceland' / 'csv' / 'ICE_dataset_all_oggm.csv')

display(data_NOR.head(2))

###### Some Plots to try

In [None]:
# Set global matplotlib parameters for larger fonts
plt.rcParams.update({
    'font.size': 16,
    'axes.labelsize': 16,
    'axes.titlesize': 18,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'legend.fontsize': 14
})

fontsize = 16  # Increased from 12

### Basic dataset size comparison
dataset_sizes = {
    'Switzerland (CH)': len(data_CH),
    'France (FR)': len(data_FR),
    'Italy (IT)': len(data_IT),
    'Austria (AT)': len(data_AT),
    'Norway (NOR)': len(data_NOR),
    'Iceland (ICE)': len(data_ICE)
}

fig, ax = plt.subplots(figsize=(12, 7))  # Slightly larger figure
bars = ax.bar(dataset_sizes.keys(), dataset_sizes.values(), color=['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#66c2a5'])

for bar in bars:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 0.1,
            f'{int(height):,}', ha='center', va='bottom', fontsize=14)

ax.set_title('Total Number of Stake Measurements by Region', fontsize=18)
ax.set_ylabel('Number of Measurements', fontsize=16)
plt.xticks(rotation=45, fontsize=14)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()

### Number of unique glaciers per region and their measurement counts
glacier_counts = pd.DataFrame()

for name, df in {'CH': data_CH, 'FR': data_FR, 'IT': data_IT, 
                'AT': data_AT, 'NOR': data_NOR, 'ICE': data_ICE}.items():
    # Count measurements per glacier
    gl_counts = df.groupby('GLACIER').size().reset_index()
    gl_counts.columns = ['Glacier', 'Count']
    gl_counts['Region'] = name
    glacier_counts = pd.concat([glacier_counts, gl_counts])

plt.figure(figsize=(16, 9))  # Larger figure
ax = sns.barplot(data=glacier_counts, x='Glacier', y='Count', hue='Region',
                palette=['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#66c2a5'])
plt.title('Measurement Count by Glacier and Region', fontsize=18)
plt.xlabel('Glacier', fontsize=16)
plt.ylabel('Count', fontsize=16)
plt.xticks(rotation=90, fontsize=12)  # Smaller font for many x-labels
plt.yticks(fontsize=14)
plt.legend(fontsize=14)
plt.grid(axis='y', linestyle='--', alpha=0.3)
plt.tight_layout()

### Period distribution (annual, winter, summer)
period_dist = pd.DataFrame()

for name, df in {'CH': data_CH, 'FR': data_FR, 'IT': data_IT, 
                'AT': data_AT, 'NOR': data_NOR, 'ICE': data_ICE}.items():
    pd_counts = df.groupby('PERIOD').size().reset_index()
    pd_counts.columns = ['Period', 'Count']
    pd_counts['Region'] = name
    period_dist = pd.concat([period_dist, pd_counts])

plt.figure(figsize=(14, 8))
ax = sns.barplot(data=period_dist, x='Region', y='Count', hue='Period',
                palette=['#0072B2', '#D55E00', '#009E73'])
plt.title('Measurement Periods by Region', fontsize=18)
plt.xlabel('Region', fontsize=16)
plt.ylabel('Count', fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=14)
plt.grid(axis='y', linestyle='--', alpha=0.3)

for i, p in enumerate(ax.patches):
    width = p.get_width()
    height = p.get_height()
    x, y = p.get_xy() 
    if height > 0:
        ax.annotate(f'{int(height)}', (x + width/2, y + height*0.5), ha='center', fontsize=14)

plt.tight_layout()

###### Summary data

In [None]:
def _period_mean(frame, period):
    """Mean POINT_BALANCE over the rows of `frame` whose PERIOD equals `period`.

    An empty selection yields NaN (the mean of an empty Series), so regions
    without e.g. summer stakes show 'nan' in the table.
    """
    return frame.loc[frame['PERIOD'] == period, 'POINT_BALANCE'].mean()


def _summarize_region(region, frame):
    """Build one row (as a dict) of the per-region summary table."""
    elevation = frame['POINT_ELEVATION']
    return {
        'Region': region,
        'Total Measurements': len(frame),
        'Unique Glaciers': frame['GLACIER'].nunique(),
        'Year Range': f"{frame['YEAR'].min()}-{frame['YEAR'].max()}",
        'Elevation Range (m)': f"{int(elevation.min())}-{int(elevation.max())}",
        'Mean Annual Balance (m w.e.)': _period_mean(frame, 'annual'),
        'Mean Winter Balance (m w.e.)': _period_mean(frame, 'winter'),
        'Mean Summer Balance (m w.e.)': _period_mean(frame, 'summer'),
    }


summary_data = [
    _summarize_region(region, frame)
    for region, frame in {'CH': data_CH, 'FR': data_FR, 'IT': data_IT,
                          'AT': data_AT, 'NOR': data_NOR, 'ICE': data_ICE}.items()
]
summary_df = pd.DataFrame(summary_data)

# Balances get two decimals; every other column keeps the default formatting.
balance_format = {
    column: '{:.2f}'
    for column in ('Mean Annual Balance (m w.e.)',
                   'Mean Winter Balance (m w.e.)',
                   'Mean Summer Balance (m w.e.)')
}
display(summary_df.style.format(balance_format))

In [None]:
import folium
from folium import plugins

def add_dataset_to_map(data, dataset_name, color, m):
    
    unique_glaciers = data['GLACIER'].unique()
    print(f"Found {len(unique_glaciers)} unique glaciers in {dataset_name}")
    
    for glacier_name in unique_glaciers:
        glacier_data = data[data['GLACIER'] == glacier_name]
        
        max_points = 30  # Maximum number of points to show per glacier
        if len(glacier_data) > max_points:
            glacier_data = glacier_data.sample(max_points, random_state=42)
        
        # Add points to map
        for idx, row in glacier_data.iterrows():
            if 'POINT_ELEVATION' in row:
                altitude = row['POINT_ELEVATION']
            else:
                altitude = 'N/A'
                
            if 'POINT_ID' in row:
                stake_id = row['POINT_ID'].split('_')[-1][:5]
            else:
                stake_id = str(idx)
                
            # Get coordinates
            lat = row['POINT_LAT'] #if 'POINT_LAT' in row else row['lat']
            lon = row['POINT_LON'] #if 'POINT_LON' in row else row['lon']
                
            popup_text = f"""
            Dataset: {dataset_name}<br>
            Glacier: {glacier_name}<br>
            ID: {stake_id}<br>
            Altitude: {altitude}m<br>
            """
            
            # Add marker
            folium.CircleMarker(
                location=[lat, lon],
                radius=5,
                popup=popup_text,
                tooltip=f"{dataset_name} - {glacier_name}",
                color=color,
                fill=True,
                fill_color=color
            ).add_to(m)

# Region code -> marker and legend colour (rebinds the earlier `colors` list).
colors = {
    'CH': '#e41a1c',   # red
    'FR': '#377eb8',   # blue
    'IT': '#4daf4a',   # green
    'AT': '#984ea3',   # purple
    'NOR': '#ff7f00',  # orange
    'ICE': '#66c2a5'   # teal
}

datasets = {
    'CH': data_CH,
    'FR': data_FR,
    'IT': data_IT,
    'AT': data_AT,
    'NOR': data_NOR,
    'ICE': data_ICE
}

# Gather every stake coordinate so the map can be framed around all regions.
# The datasets include Norway and Iceland, far outside a single Alpine view.
all_lats = []
all_lons = []

for name, dataset in datasets.items():
    lat_col = 'POINT_LAT' if 'POINT_LAT' in dataset.columns else 'lat'
    lon_col = 'POINT_LON' if 'POINT_LON' in dataset.columns else 'lon'

    if not dataset.empty:
        all_lats.extend(dataset[lat_col].dropna().tolist())
        all_lons.extend(dataset[lon_col].dropna().tolist())

# Initial view (Alps); replaced by fit_bounds below whenever coordinates exist.
center_lat = 46.0
center_lon = 8.0

m = folium.Map(location=[center_lat, center_lon], zoom_start=7)

for name, dataset in datasets.items():
    add_dataset_to_map(dataset, name, colors[name], m)

if all_lats and all_lons:
    # Bounding box [[south, west], [north, east]] of all stakes.
    m.fit_bounds([[min(all_lats), min(all_lons)], [max(all_lats), max(all_lons)]])

# Fixed-position HTML legend: one coloured dot per region.
legend_html = '''
<div style="position: fixed; bottom: 50px; left: 50px; z-index: 1000; background-color: white; padding: 10px; border-radius: 5px;">
    <p><strong>Datasets</strong></p>
'''

for name, color in colors.items():
    legend_html += f'<p><span style="color: {color};">●</span> {name}</p>'

legend_html += '</div>'
m.get_root().html.add_child(folium.Element(legend_html))

plugins.Fullscreen().add_to(m)
folium.LayerControl().add_to(m)

m

# Save the map as an HTML file
#m.save('/home/mburlet/MBM/MassBalanceMachine/regions/Overview/glacier_overview_map.html')

In [None]:
import folium
from folium import plugins
import matplotlib.colors as mcolors

data_to_plot = data_ICE

year_to_plot = 2015

# Keep rows whose TO_DATE begins with the chosen year.
# Assumes TO_DATE starts with a 4-digit year (e.g. YYYYMMDD) — TODO confirm.
data_year = data_to_plot[data_to_plot['TO_DATE'].astype(str).str[:4] == str(year_to_plot)]

unique_glaciers = data_year['GLACIER'].unique()

# Colour pool: matplotlib Tableau colours plus extra hex codes, cycled when
# there are more glaciers than colours.
color_options = list(mcolors.TABLEAU_COLORS.values())
color_options.extend(['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#ffff33',
                      '#a65628', '#f781bf', '#999999', '#66c2a5', '#fc8d62'])

glacier_colors = {glacier: color_options[i % len(color_options)]
                  for i, glacier in enumerate(unique_glaciers)}

# Centre the map on the selected stakes. A fixed Alpine centre (46 N, 8 E) opens far
# from non-Alpine selections such as data_ICE, so it is used only when no coordinates exist.
stake_coords = data_year[['POINT_LAT', 'POINT_LON']].dropna()
if stake_coords.empty:
    center_lat, center_lon = 46.0, 8.0
else:
    center_lat = float(stake_coords['POINT_LAT'].mean())
    center_lon = float(stake_coords['POINT_LON'].mean())

m = folium.Map(location=[center_lat, center_lon], zoom_start=10)

# Function to add markers for a glacier's stakes
def add_glacier_markers(glacier_data, glacier_name, region_suffixes=('IT', 'AT')):
    """Add one circle marker per row of `glacier_data` to the module-level folium map `m`.

    Reads the module-level `glacier_colors` (marker colour per glacier) and
    `year_to_plot` (shown in the popup).

    Parameters
    ----------
    glacier_data : pandas.DataFrame
        Stake rows of a single glacier; needs POINT_LAT and POINT_LON. Altitude,
        stake id and POINT_BALANCE are added to the popup when available.
    glacier_name : str
        Key into `glacier_colors`; also shown in the popup and tooltip.
    region_suffixes : tuple of str, optional
        Trailing POINT_ID tokens that are region tags rather than stake ids.
        IT/AT POINT_IDs end in '_IT' / '_AT'.
    """
    lat_col = 'POINT_LAT'
    lon_col = 'POINT_LON'

    color = glacier_colors[glacier_name]

    for idx, row in glacier_data.iterrows():
        # Elevation: prefer POINT_ELEVATION, fall back to an 'altitude' column.
        if 'POINT_ELEVATION' in row:
            altitude = row['POINT_ELEVATION']
        elif 'altitude' in row:
            altitude = row['altitude']
        else:
            altitude = 'N/A'

        if 'stake_number' in row:
            stake_id = row['stake_number']
        elif 'POINT_ID' in row:
            id_parts = str(row['POINT_ID']).split('_')
            # Drop a trailing region tag, otherwise every IT/AT stake is labelled 'IT'/'AT'.
            if len(id_parts) > 1 and id_parts[-1] in region_suffixes:
                id_parts = id_parts[:-1]
            stake_id = id_parts[-1]
        else:
            stake_id = f"Stake {idx}"

        popup_text = f"""
        Glacier: {glacier_name}<br>
        Year: {year_to_plot}<br>
        Stake: {stake_id}<br>
        Altitude: {altitude}m<br>
        """

        if 'POINT_BALANCE' in row:
            popup_text += f"Balance: {row['POINT_BALANCE']} m w.e.<br>"

        folium.CircleMarker(
            location=[row[lat_col], row[lon_col]],
            radius=5,
            popup=popup_text,
            tooltip=f"{glacier_name} - {stake_id}",
            color=color,
            fill=True,
            fill_color=color,
            fill_opacity=0.7,
            weight=1
        ).add_to(m)

# Draw every glacier's stakes for the selected year.
for glacier in unique_glaciers:
    add_glacier_markers(data_year[data_year['GLACIER'] == glacier], glacier)

# Map title.
title_html = f'''
<h3 align="center" style="font-size:16px"><b>Glacier Stakes in {year_to_plot}</b></h3>
'''
m.get_root().html.add_child(folium.Element(title_html))

# Legend: one coloured dot per glacier, with its number of rows in the selected year.
stakes_per_glacier = {
    glacier_key: len(data_year[data_year['GLACIER'] == glacier_key])
    for glacier_key in glacier_colors
}
legend_html = '''
<div style="position: fixed; bottom: 50px; left: 50px; z-index: 1000; background-color: white; padding: 10px; border-radius: 5px;">
    <p><strong>Glaciers</strong></p>
''' + ''.join(
    f'<p><span style="color: {dot_color};">●</span> {glacier_key} ({stakes_per_glacier[glacier_key]} stakes)</p>'
    for glacier_key, dot_color in glacier_colors.items()
) + '</div>'
m.get_root().html.add_child(folium.Element(legend_html))

plugins.Fullscreen().add_to(m)
folium.LayerControl().add_to(m)

m

### 2. Individual dataset stats

##### France

In [None]:
display(data_FR.head(2))

# Period -> bar colour, looked up by name. A positional colour list would shift
# colours whenever a region lacks measurements for some period.
period_colors_paired = {'annual': '#1f78b4', 'summer': '#33a02c', 'winter': '#e31a1c'}

# Number of measurements per year, stacked by period (grey for unexpected periods).
counts_FR = data_FR.groupby(['YEAR', 'PERIOD']).size().unstack()
fig, ax = plt.subplots(figsize=(20, 3.2))
counts_FR.plot(
    kind='bar',
    stacked=True,
    color=[period_colors_paired.get(period, '#999999') for period in counts_FR.columns],
    ax=ax)
ax.set_title('France', fontsize=24)
ax.set_xlabel('', fontsize=18)
ax.set_ylabel('Number of measurements', fontsize=18)
ax.tick_params(axis='both', labelsize=16)
ax.legend(fontsize=16)

# One plotHeatmap figure per measurement period. A fresh list is passed each
# time, as in the original per-call list(...).
glaciers_FR = list(data_FR['GLACIER'].unique())
for period in ('annual', 'winter', 'summer'):
    plotHeatmap(list(glaciers_FR), data_FR, period=period, figsize=(20, 5))

# Mean stake elevation per glacier.
gl_per_el = data_FR.groupby(['GLACIER'])['POINT_ELEVATION'].mean()
plot_glacier_elevations(gl_per_el)



In [None]:
# Period -> bar colour, looked up by name. The previous positional lists only
# lined up when every region had exactly the assumed set of periods.
period_colors_paired = {'annual': '#1f78b4', 'summer': '#33a02c', 'winter': '#e31a1c'}

# Measurements per year, stacked by period, one figure per region.
for region_df, region_title in ((data_FR, 'France'),
                                (data_IT, 'Italy'),
                                (data_AT, 'Austria'),
                                (data_NOR, 'Norway'),
                                (data_ICE, 'Iceland')):
    period_counts = region_df.groupby(['YEAR', 'PERIOD']).size().unstack()
    fig, ax = plt.subplots(figsize=(20, 3.2))
    period_counts.plot(
        kind='bar',
        stacked=True,
        # Grey for any period value not in the mapping.
        color=[period_colors_paired.get(period, '#999999') for period in period_counts.columns],
        ax=ax)
    ax.set_title(region_title, fontsize=24)
    ax.set_xlabel('', fontsize=18)
    ax.set_ylabel('Number of measurements', fontsize=18)
    ax.tick_params(axis='both', labelsize=16)
    ax.legend(fontsize=16)

#### Italy

In [None]:
# Period -> bar colour, looked up by name so colours stay fixed even if a period is missing.
period_colors_paired = {'annual': '#1f78b4', 'summer': '#33a02c', 'winter': '#e31a1c'}

# Number of measurements per year, stacked by period (grey for unexpected periods).
counts_IT = data_IT.groupby(['YEAR', 'PERIOD']).size().unstack()
fig, ax = plt.subplots(figsize=(20, 3.2))
counts_IT.plot(
    kind='bar',
    stacked=True,
    color=[period_colors_paired.get(period, '#999999') for period in counts_IT.columns],
    ax=ax)
ax.set_title('Italy', fontsize=24)
ax.set_xlabel('', fontsize=18)
ax.set_ylabel('Number of measurements', fontsize=18)
ax.tick_params(axis='both', labelsize=16)
ax.legend(fontsize=16)

# One plotHeatmap figure per measurement period.
glaciers_IT = list(data_IT['GLACIER'].unique())
for period in ('annual', 'winter', 'summer'):
    plotHeatmap(list(glaciers_IT), data_IT, period=period, figsize=(20, 5))

# Mean stake elevation per glacier.
gl_per_el = data_IT.groupby(['GLACIER'])['POINT_ELEVATION'].mean()
plot_glacier_elevations(gl_per_el, figsize=(10, 5))

In [None]:
# Show every Italian row for GRAND ETRET.
data_IT.loc[data_IT['GLACIER'].eq('GRAND ETRET')]
# Two rows kept as a note, presumably copied from this cell's output (column headers not shown):
#2700	-1.279	2008	GRAND ETRET_2008_32398_IT	-18.242163	annual	GRAND ETRET
#2701	-2.427	2008	GRAND ETRET_2008_32396_IT	-23.507818	annual	GRAND ETRET

#### Austria

In [None]:
# Period -> bar colour, looked up by name. With a positional list, 'winter' received
# the summer colour whenever a region had no summer column.
period_colors_okabe = {'annual': '#0072B2', 'summer': '#D55E00', 'winter': '#009E73'}

# Number of measurements per year, stacked by period (grey for unexpected periods).
counts_AT = data_AT.groupby(['YEAR', 'PERIOD']).size().unstack()
fig, ax = plt.subplots(figsize=(20, 3.2))
counts_AT.plot(
    kind='bar',
    stacked=True,
    color=[period_colors_okabe.get(period, '#999999') for period in counts_AT.columns],
    ax=ax)
ax.set_title('Austria', fontsize=24)
ax.set_xlabel('', fontsize=18)
ax.set_ylabel('Number of measurements', fontsize=18)
ax.tick_params(axis='both', labelsize=16)
ax.legend(fontsize=16)

# One plotHeatmap figure per period (annual and winter only for Austria).
glaciers_AT = list(data_AT['GLACIER'].unique())
for period in ('annual', 'winter'):
    plotHeatmap(list(glaciers_AT), data_AT, period=period, figsize=(20, 5))

# Mean stake elevation per glacier.
gl_per_el = data_AT.groupby(['GLACIER'])['POINT_ELEVATION'].mean()
plot_glacier_elevations(gl_per_el, figsize=(10, 5))

display(data_AT[data_AT['GLACIER'] == 'VERNAGT F.'])

#### Norway

###### Summer stakes have been preemptively removed, as this dataset is only used for transfer learning CH to NOR

In [None]:
# Period -> bar colour, looked up by name. With a positional list, 'winter' received
# the summer colour because this dataset has no summer column.
period_colors_okabe = {'annual': '#0072B2', 'summer': '#D55E00', 'winter': '#009E73'}

# Number of measurements per year, stacked by period (grey for unexpected periods).
counts_NOR = data_NOR.groupby(['YEAR', 'PERIOD']).size().unstack()
fig, ax = plt.subplots(figsize=(20, 3.2))
counts_NOR.plot(
    kind='bar',
    stacked=True,
    color=[period_colors_okabe.get(period, '#999999') for period in counts_NOR.columns],
    ax=ax)
ax.set_title('Norway', fontsize=24)
ax.set_xlabel('', fontsize=18)
ax.set_ylabel('Number of measurements', fontsize=18)
ax.tick_params(axis='both', labelsize=16)
ax.legend(fontsize=16)

# One plotHeatmap figure per period (summer stakes were removed from this dataset).
glaciers_NOR = list(data_NOR['GLACIER'].unique())
for period in ('annual', 'winter'):
    plotHeatmap(list(glaciers_NOR), data_NOR, period=period, figsize=(20, 8))

# Mean stake elevation per glacier.
gl_per_el = data_NOR.groupby(['GLACIER'])['POINT_ELEVATION'].mean()
plot_glacier_elevations(gl_per_el, figsize=(10, 5))

#### Iceland

###### Summer stakes have been preemptively removed, as this dataset is only used for transfer-learning experiments (CH to ICE, and Iceland west to east)

In [None]:
# Period -> bar colour, looked up by name. With a positional list, 'winter' received
# the summer colour because this dataset has no summer column.
period_colors_okabe = {'annual': '#0072B2', 'summer': '#D55E00', 'winter': '#009E73'}

# Number of measurements per year, stacked by period (grey for unexpected periods).
counts_ICE = data_ICE.groupby(['YEAR', 'PERIOD']).size().unstack()
fig, ax = plt.subplots(figsize=(20, 3.2))
counts_ICE.plot(
    kind='bar',
    stacked=True,
    color=[period_colors_okabe.get(period, '#999999') for period in counts_ICE.columns],
    ax=ax)
ax.set_title('Iceland', fontsize=24)
ax.set_xlabel('', fontsize=18)
ax.set_ylabel('Number of measurements', fontsize=18)
ax.tick_params(axis='both', labelsize=16)
ax.legend(fontsize=16)

# One plotHeatmap figure per period (summer stakes were removed from this dataset).
glaciers_ICE = list(data_ICE['GLACIER'].unique())
for period in ('annual', 'winter'):
    plotHeatmap(list(glaciers_ICE), data_ICE, period=period, figsize=(20, 8))

# Mean stake elevation per glacier.
gl_per_el = data_ICE.groupby(['GLACIER'])['POINT_ELEVATION'].mean()
plot_glacier_elevations(gl_per_el, figsize=(10, 5))

### 3. Results XGBoost split

##### FR split tested on Gebroulaz and Leschaux

In [None]:
# Saved result figures for the French split, one image per test glacier.
# Each title must name the glacier shown in its image file.
for image_path, title in (
    ('/home/mburlet/scratch/data/DATA_MB/Outputs/Result_Gebroulaz.png', 'FR split tested on Gebroulaz'),
    ('/home/mburlet/scratch/data/DATA_MB/Outputs/Result_Leschaux.png', 'FR split tested on Leschaux'),
):
    img = mpimg.imread(image_path)
    plt.figure(figsize=(10, 8))
    plt.imshow(img)
    plt.axis('off')
    plt.title(title)
    plt.tight_layout()
    plt.show()


### 4. Results XGBoost transfer

##### CH train FR Test

In [None]:
# Overall result for training on CH and testing on FR (the file name says "no_summer").
img = mpimg.imread('/home/mburlet/scratch/data/DATA_MB/Outputs/Result_CH_FR_no_summer.png')
fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(img)
ax.set_axis_off()
ax.set_title('CH train FR Test')
fig.tight_layout()
plt.show()

# Per-glacier panels for the same experiment.
img2 = mpimg.imread('/home/mburlet/scratch/data/DATA_MB/Outputs/individual_glaciers_CH_FR_no_summer.png')
fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(img2)
ax.set_axis_off()
ax.set_title('Individual Glaciers: CH train FR Test')
fig.tight_layout()
plt.show()



##### IT train AT test

In [None]:
# Overall result for training on IT and testing on AT.
img = mpimg.imread('/home/mburlet/scratch/data/DATA_MB/Outputs/Result_train_IT_test_AT.png')
fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(img)
ax.set_axis_off()
ax.set_title('IT train AT test')
fig.tight_layout()
plt.show()

# Per-glacier panels for the same experiment.
img2 = mpimg.imread('/home/mburlet/scratch/data/DATA_MB/Outputs/individual_glaciers_train_IT_test_AT.png')
fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(img2)
ax.set_axis_off()
ax.set_title('Individual Glaciers: IT train AT test')
fig.tight_layout()
plt.show()

##### CH train IT & AT test

In [None]:
# Overall result for training on CH and testing on IT and AT.
img = mpimg.imread('/home/mburlet/scratch/data/DATA_MB/Outputs/Result_train_CH_test_IT_AT.png')
fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(img)
ax.set_axis_off()
ax.set_title('CH train IT & AT test')
fig.tight_layout()
plt.show()

# Per-glacier panels for the same experiment.
img2 = mpimg.imread('/home/mburlet/scratch/data/DATA_MB/Outputs/individual_glaciers_train_CH_test_IT_AT.png')
fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(img2)
ax.set_axis_off()
ax.set_title('Individual Glaciers: CH train IT & AT test')
fig.tight_layout()
plt.show()

##### CH train NOR test

In [None]:
# Overall result for training on CH and testing on NOR.
img = mpimg.imread('/home/mburlet/scratch/data/DATA_MB/Outputs/Result_CH_NOR.png')
fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(img)
ax.set_axis_off()
ax.set_title('CH train NOR test')
fig.tight_layout()
plt.show()

# Per-glacier panels (larger canvas). The path spelling "galciers" matches the file on disk.
img2 = mpimg.imread('/home/mburlet/scratch/data/DATA_MB/Outputs/individual_galciers_CH_NOR.png')
fig, ax = plt.subplots(figsize=(20, 15))
ax.imshow(img2)
ax.set_axis_off()
ax.set_title('Individual Glaciers: CH train NOR test')
fig.tight_layout()
plt.show()

Possible next steps:
- Clean up the regions code and rerun all results
- Iceland data and transfer-learning test
- NN transfer-learning comparison and potential improvements
- The WGMS dataset has many more potential regions (worth looking at?)


### Iceland

##### Iceland west train east test

In [None]:
# Overall result for training on west Iceland and testing on east Iceland.
img = mpimg.imread('/home/mburlet/scratch/data/DATA_MB/Outputs/Result_ICE.png')
fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(img)
ax.set_axis_off()
ax.set_title('Iceland west train Iceland east test')
fig.tight_layout()
plt.show()

# Per-glacier panels (larger canvas). The path spelling "galciers" matches the file on disk.
img2 = mpimg.imread('/home/mburlet/scratch/data/DATA_MB/Outputs/individual_galciers_ICE.png')
fig, ax = plt.subplots(figsize=(20, 15))
ax.imshow(img2)
ax.set_axis_off()
ax.set_title('Individual Glaciers: Iceland west train Iceland east test')
fig.tight_layout()
plt.show()

##### CH train ICE test

In [None]:
# Training on CH and testing on ICE: overall result, summary, and per-glacier panels.
# The summary image gets its own title so it is not confused with the main result.
for image_path, title, figsize in (
    ('/home/mburlet/scratch/data/DATA_MB/Outputs/Result_train_CH_test_ICE.png',
     'CH train ICE test', (10, 8)),
    ('/home/mburlet/scratch/data/DATA_MB/Outputs/Result_summary_train_CH_test_ICE.png',
     'Summary: CH train ICE test', (10, 8)),
    ('/home/mburlet/scratch/data/DATA_MB/Outputs/individual_glaciers_train_CH_test_ICE.png',
     'Individual Glaciers: CH train ICE test', (20, 25)),
):
    img = mpimg.imread(image_path)
    plt.figure(figsize=figsize)
    plt.imshow(img)
    plt.axis('off')
    plt.title(title)
    plt.tight_layout()
    plt.show()

### Climate variables comparison

In [None]:
# Root directory of all data files; change this one constant to run elsewhere.
DATA_DIR = Path('/home/mburlet/scratch/data/DATA_MB')

# Monthly datasets per region, used for the feature-distribution comparison.
data_CH_monthly = pd.read_csv(DATA_DIR / 'CH_wgms_dataset_monthly_full_04_06.csv')
data_FR_monthly = pd.read_csv(DATA_DIR / 'GLACIOCLIM' / 'csv' / 'FR_wgms_dataset_monthly_full.csv')
data_IT_AT_monthly = pd.read_csv(DATA_DIR / 'WGMS' / 'IT_AT' / 'csv' / 'IT_AT_wgms_dataset_monthly_full.csv')
# Italy and Austria share one file; split on the POINT_ID suffix.
# na=False keeps rows with a missing POINT_ID out of both subsets.
data_IT_monthly = data_IT_AT_monthly[data_IT_AT_monthly['POINT_ID'].str.endswith('_IT', na=False)]
data_AT_monthly = data_IT_AT_monthly[data_IT_AT_monthly['POINT_ID'].str.endswith('_AT', na=False)]
data_NOR_monthly = pd.read_csv(DATA_DIR / 'WGMS' / 'Norway' / 'csv' / 'NOR_dataset_monthly_full.csv')
data_ICE_monthly = pd.read_csv(DATA_DIR / 'WGMS' / 'Iceland' / 'csv' / 'ICE_dataset_monthly_full.csv')

In [None]:
data_CH_monthly.keys()  # DataFrame.keys() is the column Index

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from matplotlib.gridspec import GridSpec

# Define the climate features to plot
climate_features = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'u10', 'v10'
]

# Define the topographical features to plot
topo_features = [
    "aspect",
    "slope",
    "hugonnet_dhdt", 
    "consensus_ice_thickness",
    "millan_v",
    "ELEVATION_DIFFERENCE",
    "ALTITUDE_CLIMATE" 
]

# Names and units for better labels
feature_labels = {
    # Climate
    't2m': 'Temperature [°C]',
    'tp': 'Precipitation [m]',
    'slhf': 'Surface Latent Heat Flux [W/m²]',
    'sshf': 'Surface Sensible Heat Flux [W/m²]',
    'ssrd': 'Surface Solar Radiation Downward [W/m²]',
    'fal': 'Forecast Albedo [-]',
    'str': 'Surface Thermal Radiation [W/m²]',
    'u10': 'Wind U-component [m/s]',
    'v10': 'Wind V-component [m/s]',
    # Topographical
    'aspect': 'Aspect [degrees]',
    'slope': 'Slope [degrees]',
    'hugonnet_dhdt': 'Ice Thickness Change Rate [m/yr]',
    'consensus_ice_thickness': 'Ice Thickness [m]',
    'millan_v': 'Ice Velocity [m/yr]',
    'ELEVATION_DIFFERENCE': 'Elevation Difference [m]',
    "ALTITUDE_CLIMATE": 'Climate Altitude [m]',
}

# Define region colors (using a colorblind-friendly palette)
region_colors = {
    'CH': '#0072B2',  # blue
    'FR': '#D55E00',  # orange-red
    'IT': '#009E73',  # green
    'AT': '#CC79A7',  # pink
    'NOR': '#F0E442', # yellow
    'ICE': '#880808'  # dark red
}

def prepare_data_for_plotting(datasets, kelvin_threshold=100.0):
    """Return per-region copies of `datasets` with the 't2m' column expressed in °C.

    Parameters
    ----------
    datasets : dict[str, pandas.DataFrame]
        Region code -> DataFrame. The input frames are never modified.
    kelvin_threshold : float, optional
        A 't2m' column whose median exceeds this value is treated as Kelvin and
        converted to °C. Near-surface air temperatures in °C never approach 100.
        Columns that already look like °C, or are entirely NaN, are left untouched.

    Returns
    -------
    dict[str, pandas.DataFrame]
        Same keys as `datasets`, each mapped to a (possibly converted) copy.
    """
    processed_data = {}

    for region, df in datasets.items():
        # Work on a copy so the caller's DataFrames stay unchanged.
        df_copy = df.copy()

        if 't2m' in df_copy.columns:
            # Kelvin -> Celsius is a subtraction. Data already in °C is not shifted.
            if df_copy['t2m'].median() > kelvin_threshold:
                df_copy['t2m'] = df_copy['t2m'] - 273.15

        processed_data[region] = df_copy

    return processed_data

def plot_feature_distributions(datasets, features, labels, colors, title, stat='count'):
    """Draw one histogram+KDE panel per feature, overlaying every region.

    Parameters
    ----------
    datasets : dict[str, pandas.DataFrame]
        Region code -> data. Passed through `prepare_data_for_plotting` first.
    features : list[str]
        Column names to plot, three panels per row. Regions lacking a column are skipped.
    labels : dict[str, str]
        Optional x-axis label per feature; the raw feature name is used otherwise.
    colors : dict[str, str]
        Region code -> colour; must contain every key of `datasets`.
    title : str
        Figure super-title.
    stat : str, optional
        Forwarded to `seaborn.histplot`. The default 'count' lets large regions
        dominate; 'density' compares distribution shapes across regions of very
        different size.

    Returns
    -------
    matplotlib.figure.Figure
    """
    n_features = len(features)
    n_cols = 3
    # At least one row so GridSpec never receives 0 rows.
    n_rows = max(1, (n_features + n_cols - 1) // n_cols)

    fig = plt.figure(figsize=(16, 4 * n_rows))
    gs = GridSpec(n_rows, n_cols, figure=fig)

    processed_data = prepare_data_for_plotting(datasets)

    for i, feature in enumerate(features):
        row, col = divmod(i, n_cols)
        ax = fig.add_subplot(gs[row, col])

        for region, df in processed_data.items():
            if feature not in df.columns:
                continue
            data = df[feature].dropna()
            if data.empty:
                continue

            # Keep only the 1st-99th percentile range so extreme outliers don't stretch the axis.
            lower, upper = np.percentile(data, [1, 99])
            data_range = data[(data >= lower) & (data <= upper)]

            sns.histplot(
                data_range,
                kde=True,
                alpha=0.4,
                stat=stat,
                label=region,
                color=colors[region],
                ax=ax
            )

        ax.set_xlabel(labels.get(feature, feature))
        ax.set_title(f"{feature} Distribution")
        # Draw a legend only if at least one region contributed a histogram.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()

        ax.grid(alpha=0.3, linestyle='--')

    plt.tight_layout()
    plt.suptitle(title, fontsize=16, y=1.02)
    return fig

# Monthly datasets per region, compared feature by feature.
datasets = dict(
    CH=data_CH_monthly,
    FR=data_FR_monthly,
    IT=data_IT_monthly,
    AT=data_AT_monthly,
    NOR=data_NOR_monthly,
    ICE=data_ICE_monthly,
)

# One figure for the climate variables, one for the topographical variables (in that order).
climate_fig, topo_fig = [
    plot_feature_distributions(datasets, feature_list, feature_labels, region_colors, fig_title)
    for feature_list, fig_title in (
        (climate_features, 'Climate Feature Distributions Across Regions'),
        (topo_features, 'Topographical Feature Distributions Across Regions'),
    )
]