# Plotting input datasets

Plot the input datasets, colored by whether each dataset was included (green shades) or excluded (red shades) when producing the results of the paper.

In [None]:
import os
import pandas as pd
import numpy as np
from pathlib import Path
import glob
from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import statsmodels.api as sm
from scipy import stats
from scipy.stats import pearsonr, spearmanr
import pymannkendall as mk
import warnings
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

# Suppress FutureWarning messages for everything that runs after this point.
warnings.filterwarnings("ignore", category=FutureWarning)

# Paths
# Root folders; every other location is derived from these two.
input_dir, output_dir = Path("input"), Path("output")

# Input locations
input_rel = input_dir.joinpath("glambie_reference")
input_maps = input_dir.joinpath("maps")
input_datasets = input_dir.joinpath("input_datasets")
glambie_runs = input_dir.joinpath("glambie_runs")

# Output locations
output_sensitivity = output_dir.joinpath("sensitivity")
output_rel = output_dir.joinpath("relative_change")
output_dir_datasets = output_dir.joinpath("datasets")

In [None]:
# Columns a CSV must contain to be treated as a plottable dataset.
columns = {'start_dates', 'end_dates', 'changes', 'errors'}

# Colour ramps, cycled per figure: greens for included datasets, reds for excluded ones.
included_colors = ['#2ecc71', '#27ae60', '#229954', '#1e8449', '#196f3d']
excluded_colors = ['#e74c3c', '#c0392b', '#a93226', '#922b21', '#7b241c']

# Alternative, longer palettes (disabled):
# included_colors = ['#2ecc71', '#00acc1', '#558b2f', '#26a69a', '#0277bd', '#7cb342', '#00695c', '#43a047', '#00838f', '#1b5e20']
# excluded_colors = ['#e53935', '#fb8c00', '#d81b60', '#8d6e63', '#f4511e', '#ad1457', '#c62828', '#ef6c00', '#e91e63', '#6d4c41']

# Load excluded list
excluded_df = pd.read_csv(str(input_datasets / 'excluded_datasets_list.csv'))
# Optional filter (disabled): plot only problematic datasets as "excluded"
# excluded_df = excluded_df[excluded_df['inclusion_possible'].str.strip().str.lower() == 'no']

# Lookup of (region, data_group, lower-cased dataset name) for every excluded dataset.
excluded_set = set()
for _, row in excluded_df.iterrows():
    region, data_group = row['region'], row['data_group']
    dataset = str(row['dataset']).lower()

    # A combined entry marks the dataset as excluded in both of its groups.
    if data_group == 'demdiff_and_glaciological':
        target_groups = ('demdiff', 'glaciological')
    else:
        target_groups = (data_group,)

    excluded_set.update((region, group, dataset) for group in target_groups)

# Datasets to completely remove from plots
skip_datasets = set()

# Load all datasets
# Every CSV found below input_datasets is a candidate. Region and data group
# are taken from the two directory levels directly above the file, i.e.
# .../<region>/<data_group>/<dataset>.csv.
datasets = []
base_dir = input_datasets
skip_files = {'excluded_datasets_list.csv'}
# Sorted so the load order, and therefore the colour each dataset receives in
# the plots, does not depend on the filesystem's directory listing order.
all_csv_files = sorted(f for f in base_dir.rglob('*.csv') if f.name not in skip_files)

for csv_file in all_csv_files:
    # Loading stays best-effort, but an unreadable file is reported instead of
    # disappearing silently from the figures.
    try:
        df = pd.read_csv(csv_file)
    except Exception as exc:
        warnings.warn(f"Skipping unreadable dataset file {csv_file}: {exc}")
        continue

    # Files without the required columns are not datasets; skip them.
    if not columns.issubset(df.columns):
        continue

    # Need at least <region>/<data_group>/<file> to derive the metadata.
    parts = csv_file.parts
    if len(parts) < 3:
        continue

    region, data_group = parts[-3], parts[-2]
    dataset_name = csv_file.stem

    if dataset_name in skip_datasets:
        continue

    datasets.append({
        'region': region,
        'data_group': data_group,
        'dataset': dataset_name,
        # Gravimetry is labelled in Gt, every other data group in m.
        'unit': 'Gt' if data_group == 'gravimetry' else 'm',
        # The exclusion list is matched on the lower-cased dataset name.
        'is_excluded': (region, data_group, dataset_name.lower()) in excluded_set,
        'data': df,
        'filepath': csv_file,
    })

def _plot_dataset_intervals(ax, df, color):
    """Draw one dataset's observation periods onto *ax*.

    Each row of *df* becomes a horizontal segment from ``start_dates`` to
    ``end_dates`` at height ``changes``, plus a translucent box covering
    ``changes`` +/- ``|errors|`` over the same period.

    Args:
        ax: Matplotlib axes to draw on.
        df: DataFrame with ``start_dates``, ``end_dates``, ``changes`` and
            ``errors`` columns. It is not modified.
        color: Matplotlib colour used for both the segments and the boxes.
    """
    if df.empty:
        return

    # Order periods by their midpoint; assign() works on a copy of df.
    ordered = df.assign(
        time=(df['start_dates'] + df['end_dates']) / 2
    ).sort_values('time')

    start = ordered['start_dates'].to_numpy()
    end = ordered['end_dates'].to_numpy()
    change = ordered['changes'].to_numpy()
    err = ordered['errors'].abs().to_numpy()

    # One plot call for all segments; the NaN after each (start, end) pair
    # breaks the line so consecutive periods are not joined.
    gap = np.full(len(ordered), np.nan)
    ax.plot(
        np.column_stack([start, end, gap]).ravel(),
        np.column_stack([change, change, gap]).ravel(),
        '-', linewidth=2, alpha=0.7, color=color,
    )

    # One uncertainty box per period.
    for x0, x1, low, high in zip(start, end, change - err, change + err):
        ax.fill([x0, x1, x1, x0], [low, low, high, high],
                color=color, alpha=0.2, edgecolor='none')


# savefig does not create missing directories, so make sure the target exists.
output_dir_datasets.mkdir(parents=True, exist_ok=True)

# Group by (region, unit)
grouped = defaultdict(list)
for ds in datasets:
    grouped[(ds['region'], ds['unit'])].append(ds)

# Plot per group
for (region, unit), group_datasets in sorted(grouped.items()):
    print(f"Plotting {region, unit}")

    fig, ax = plt.subplots(figsize=(8, 5))

    # Included datasets are drawn (and listed in the legend) before excluded
    # ones; each status cycles through its own colour ramp.
    status_groups = (
        ('Included', included_colors,
         [ds for ds in group_datasets if not ds['is_excluded']]),
        ('Excluded', excluded_colors,
         [ds for ds in group_datasets if ds['is_excluded']]),
    )

    legend_elements = []
    for status, palette, members in status_groups:
        for idx, ds in enumerate(members):
            color = palette[idx % len(palette)]
            _plot_dataset_intervals(ax, ds['data'], color)
            legend_elements.append(
                plt.Line2D([0], [0], color=color, lw=2,
                           label=f"{ds['dataset']} ({status})"))

    unit_label = 'Gravimetry (Gt)' if unit == 'Gt' else 'All Other Methods (m)'
    ax.set_xlabel('Time (year)')
    ax.set_ylabel(f'Change ({unit})')
    ax.set_title(
        f'{region.replace("_", " ").title()} ({unit_label})')
    ax.grid(True, alpha=0.3, linestyle='--')

    if legend_elements:
        ax.legend(handles=legend_elements, loc='best', fontsize=9, ncol=2)

    fig.tight_layout()

    # Keep path separators out of the file name.
    safe_region = region.replace('/', '_').replace('\\', '_')
    output_path = output_dir_datasets / f"{safe_region}_{unit}.png"
    fig.savefig(output_path, dpi=200, bbox_inches='tight')
    # Close this specific figure so figures do not accumulate across groups.
    plt.close(fig)