In [14]:
import xgboost as xgb
import shap
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import os
import re
from matplotlib.colors import TwoSlopeNorm
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from skimage.transform import resize
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
import itertools 
from matplotlib.patches import Ellipse
from scipy import stats 

# --- Configuration ---
# Base directory where your experiment folders are located
base_data_dir = r"C:\Users\benra"

# Path to your basin mask NetCDF file
basin_mask_path = r"C:\Users\benra\output_combined_masked_basins.nc"

# Path to your model characteristics CSV file
model_characteristics_path = r"C:\Users\benra\ISMIP6_EXP_ALL - GIS_Model_Characteristics.csv"

# The specific experiment ID for which you want to calculate SHAP values across all basins
shap_experiment_id = 'expa03' 

# Variable name for ice thickness in your NetCDF files
variable_name = 'lithk'

# PCA and K-means parameters for outlier score calculation
n_components_pca = 4  # Number of PCA components to retain for clustering and outlier scoring
n_clusters_kmeans = 5 # Number of clusters for K-means

# Time window for delta calculation (using full time series for SHAP)
time_window_start_idx = 0   # Index for 2015
time_window_end_idx = 85    # Index for 2100
time_window_label = '2015-2100' # Label for plots and output folders

# Output directory for SHAP plots
output_plots_dir = "Research_Plots"
shap_output_dir = os.path.join(output_plots_dir, "SHAP_Plots_Per_Basin_Cluster")
os.makedirs(shap_output_dir, exist_ok=True)
exp_output_dir = os.path.join(shap_output_dir, f"{shap_experiment_id}")
os.makedirs(exp_output_dir, exist_ok=True)
pca_plot_output_dir = os.path.join(exp_output_dir, "PCA_KMeans_Plots")
os.makedirs(pca_plot_output_dir, exist_ok=True)

# --- Define Features for One-Hot Encoding ---
# This list specifies which columns from your model_characteristics_df
# should be treated as categorical features and one-hot encoded for SHAP analysis.
CATEGORICAL_FEATURES_TO_ENCODE = [
    'Initialisation',
    'Numerics',
    'Ice flow',
    'Initial SMB',
    'Bed',
    'GH F'
]

# --- Basin Geographic Names Mapping ---
basin_names = {
    1: 'SW',
    2: 'SE',
    3: 'CE',
    4: 'NE',
    5: 'NO',
    6: 'NW',
    7: 'CW'
}

# --- Define the standardization function for model names ---
def standardize_model_name(model_name_from_netcdf):
    """
    Transforms a model name from the NetCDF derived format
    to match the 'Model ID' format in the characteristics CSV.
    """
    # Specific rules for ILTS_PIK_SICOPOLIS models (example from your notebook)
    if model_name_from_netcdf == 'ILTS_PIK_SICOPOLIS1':
        return 'ILTSPIK- SICOPOLIS1'
    elif model_name_from_netcdf == 'ILTS_PIK_SICOPOLIS2':
        return 'ILTSPIK- SICOPOLIS2'
    
    # General rule: for all other models, replace underscores with dashes
    return model_name_from_netcdf.replace('_', '-')

# --- Load Basin Mask ---
try:
    basin_mask_data = xr.open_dataset(basin_mask_path)
    basin_id_var_name = None
    for var_key in basin_mask_data.data_vars.keys():
        if var_key.strip() == 'IDs':
            basin_id_var_name = var_key
            break
    
    if basin_id_var_name is None:
        raise KeyError("Could not find 'IDs' variable (even after stripping whitespace) in basin mask NetCDF.")

    basin_ids_grid = basin_mask_data[basin_id_var_name]
    unique_basin_ids = np.unique(basin_ids_grid.values[~np.isnan(basin_ids_grid.values)]).astype(int)
    print(f"Found {len(unique_basin_ids)} unique basins: {unique_basin_ids}")

    # Get grid coordinates and shape from the basin mask
    x_coords_grid = basin_ids_grid['x'].values if 'x' in basin_ids_grid.coords else basin_ids_grid.coords[basin_ids_grid.dims[1]].values
    y_coords_grid = basin_ids_grid['y'].values if 'y' in basin_ids_grid.coords else basin_ids_grid.coords[basin_ids_grid.dims[0]].values
    rows_grid, cols_grid = basin_ids_grid.shape
except FileNotFoundError:
    print(f"Error: Basin mask file not found at {basin_mask_path}")
    exit()
except KeyError as e:
    print(f"Error: {e}. Please check the variable name or file path for the basin mask.")
    exit()
except Exception as e:
    print(f"An unexpected error occurred while loading basin mask: {e}")
    exit()

# --- Load Model Characteristics CSV ---
try:
    model_characteristics_df = pd.read_csv(model_characteristics_path)
    # Ensure 'Model ID' column exists and is suitable for merging
    if 'Model ID' not in model_characteristics_df.columns:
        raise ValueError(f"'{model_characteristics_path}' must contain a 'Model ID' column.")
    print(f"Loaded model characteristics from {model_characteristics_path}. Shape: {model_characteristics_df.shape}")
except FileNotFoundError:
    print(f"Error: Model characteristics CSV file not found at {model_characteristics_path}")
    exit()
except Exception as e:
    print(f"An unexpected error occurred while loading model characteristics: {e}")
    exit()

# --- Main Analysis Loop: Per Basin ---
all_shap_results = [] # To store SHAP values for all basins if needed later

for basin_id in unique_basin_ids:
    basin_geo_name = basin_names.get(basin_id, f"Basin {basin_id}")
    print(f"\n--- Processing Basin: {basin_geo_name} (ID: {basin_id}) for SHAP Analysis ---")

    # Construct the actual folder name for the current experiment
    experiment_folder_name = f"{variable_name}_{shap_experiment_id}"
    netcdf_dir = os.path.join(base_data_dir, experiment_folder_name)

    if not os.path.exists(netcdf_dir):
        print(f"  Warning: Experiment directory not found for {shap_experiment_id} at {netcdf_dir}. Skipping this basin.")
        continue

    netcdf_files = [os.path.join(netcdf_dir, f) for f in os.listdir(netcdf_dir) if f.endswith('.nc')]

    current_basin_deltas = []
    current_basin_models_present = []

    # Loop through each model in the selected experiment's directory
    for i, file_path in enumerate(netcdf_files):
        base_filename = os.path.basename(file_path)
        
        # Extract raw model name and standardize
        model_id_raw = base_filename.replace('.nc', '')
        model_id_raw = model_id_raw.replace(f'{variable_name}_GIS_', '')
        if model_id_raw.endswith(f'_{shap_experiment_id}'):
            model_id_raw = model_id_raw.replace(f'_{shap_experiment_id}', '')
        model_id_raw = re.sub(r'_\d+$', '', model_id_raw) # Remove _<digits> if present
        standardized_model_id = standardize_model_name(model_id_raw)

        try:
            model_ds = xr.open_dataset(file_path, decode_times=False)
            if 'time' not in model_ds[variable_name].dims or len(model_ds['time']) < time_window_end_idx + 1:
                print(f"    Warning: '{variable_name}' in {standardized_model_id}/{shap_experiment_id} has insufficient time dimension for {time_window_label}. Skipping.")
                model_ds.close()
                continue

            delta_lithk = model_ds[variable_name].isel(time=time_window_end_idx) - model_ds[variable_name].isel(time=time_window_start_idx)
            model_ds.close()

            # Align delta_lithk to basin_ids_grid (important for correct masking)
            target_dims = basin_ids_grid.dims
            if len(delta_lithk.dims) == len(target_dims):
                dim_rename_map = {old_dim: new_dim for old_dim, new_dim in zip(delta_lithk.dims, target_dims)}
                delta_lithk_aligned = delta_lithk.rename(dim_rename_map)
            else:
                print(f"    Warning: Delta_lithk for {standardized_model_id} in {shap_experiment_id} has unexpected dimensions: {delta_lithk.dims}. Skipping alignment.")
                delta_lithk_aligned = delta_lithk

            delta_lithk_reindexed = delta_lithk_aligned.reindex_like(basin_ids_grid, method='nearest', tolerance=1e-6)

            basin_specific_delta = delta_lithk_reindexed.where(basin_ids_grid == basin_id)

            # Flatten and append
            flattened_delta = basin_specific_delta.values.flatten()
            current_basin_deltas.append(flattened_delta)
            current_basin_models_present.append(standardized_model_id)

        except Exception as e:
            print(f"    Error processing {standardized_model_id} in {shap_experiment_id}: {e}. Skipping.")
            if 'model_ds' in locals() and model_ds is not None:
                model_ds.close()
            continue

    if not current_basin_models_present:
        print(f"  No valid model data found for Experiment {shap_experiment_id} in Basin {basin_geo_name}. Skipping SHAP analysis for this basin.")
        continue

    # Stack all models' basin-specific deltas into a matrix
    all_models_basin_data = np.stack(current_basin_deltas)

    # Handle NaNs: Remove grid points that are NaN for any model in this basin/experiment
    mask_valid_points = ~np.any(np.isnan(all_models_basin_data), axis=0)
    data_matrix_for_clustering = all_models_basin_data[:, mask_valid_points]

    if data_matrix_for_clustering.shape[1] == 0:
        print(f"  No valid (non-NaN) data points for clustering in Experiment {shap_experiment_id}, Basin {basin_geo_name}. Skipping SHAP analysis.")
        continue
    # --- Perform PCA ---
    scaler = StandardScaler()
    data_matrix_scaled = scaler.fit_transform(data_matrix_for_clustering)

    pca = PCA(n_components=n_components_pca)
    pca_scores = pca.fit_transform(data_matrix_scaled)

    # --- Perform K-Means Clustering ---
    kmeans = KMeans(n_clusters=n_clusters_kmeans, random_state=0, n_init='auto')
    labels = kmeans.fit_predict(pca_scores)

    # Calculate outlier scores based on distance to the 'main' cluster centroid
    unique_labels, counts = np.unique(labels, return_counts=True)
    main_cluster_label = unique_labels[np.argmax(counts)]
    main_cluster_centroid = kmeans.cluster_centers_[main_cluster_label]
    outlier_scores = np.linalg.norm(pca_scores - main_cluster_centroid, axis=1)

    # --- Create 2D PCA Scatter Plot ---
    print(f"  Creating PCA scatter plot for Basin {basin_geo_name}")

    # Improved Cluster Color Selection
    if n_clusters_kmeans <= 9:
        cluster_colors = [plt.cm.get_cmap('Set1')(i) for i in range(n_clusters_kmeans)]
    elif n_clusters_kmeans <= 12:
        cluster_colors = [plt.cm.get_cmap('Set3')(i) for i in range(n_clusters_kmeans)]
    else:
        cluster_colors = [plt.cm.get_cmap('hsv')(i/n_clusters_kmeans) for i in range(n_clusters_kmeans)]

    # Define markers for different models
    import itertools
    markers = itertools.cycle(('o', '^', 's', 'D', 'v', 'P', 'X', '*', 'H', '<', '>'))
    model_to_marker = {model: next(markers) for model in current_basin_models_present}

    # Create 2D scatter plot
    fig_2d, ax_2d = plt.subplots(figsize=(12, 10))

    # Plot each model with its cluster color and unique marker
    for i, model_name in enumerate(current_basin_models_present):
        cluster_id = labels[i]
        color = cluster_colors[cluster_id]
        marker = model_to_marker[model_name]

        ax_2d.scatter(
            pca_scores[i, 0],
            pca_scores[i, 1],
            c=[color],
            marker=marker,
            s=250,  # Large marker size
            edgecolor='black',
            alpha=0.8,
            linewidth=1.5
        )

    # Add confidence ellipses for each cluster
    from scipy import stats
    from matplotlib.patches import Ellipse
    confidence_level = 0.95
    chi2_val = stats.chi2.ppf(confidence_level, df=2)

    for cluster_id in np.unique(labels):
        cluster_points = pca_scores[labels == cluster_id, :2]

        if len(cluster_points) < 2:
            continue

        centroid = np.mean(cluster_points, axis=0)
        covariance = np.cov(cluster_points, rowvar=False)

        eigenvalues, eigenvectors = np.linalg.eigh(covariance)
        order = eigenvalues.argsort()[::-1]
        eigenvalues = eigenvalues[order]
        eigenvectors = eigenvectors[:, order]

        angle = np.degrees(np.arctan2(*eigenvectors[:, 0][::-1]))
        width = 2 * np.sqrt(eigenvalues[0] * chi2_val)
        height = 2 * np.sqrt(eigenvalues[1] * chi2_val)

        ellipse_color = cluster_colors[cluster_id]
        ellipse = Ellipse(xy=centroid, width=width, height=height,
                         angle=angle, edgecolor=ellipse_color, fc='None',
                         lw=2, alpha=0.7, zorder=0)
        ax_2d.add_patch(ellipse)

    # Create cluster legend
    cluster_legend_handles = []
    for i in range(n_clusters_kmeans):
        cluster_legend_handles.append(plt.Line2D([0], [0], marker='o', color='w',
                                               markerfacecolor=cluster_colors[i], markersize=10,
                                               label=f'Cluster {i}', markeredgecolor='black'))

    # Create model legend
    model_legend_handles = []
    for i, model_name in enumerate(current_basin_models_present):
        cluster_id = labels[i]
        model_color = cluster_colors[cluster_id]
        model_legend_handles.append(plt.Line2D([0], [0], marker=model_to_marker[model_name], 
                                              color='w', markerfacecolor=model_color,
                                              markersize=10, label=model_name, 
                                              markeredgecolor='black'))

    # Add legends
    first_legend = ax_2d.legend(handles=cluster_legend_handles, loc='upper left', 
                               title='Clusters', bbox_to_anchor=(1.05, 1))
    ax_2d.add_artist(first_legend)
    ax_2d.legend(handles=model_legend_handles, loc='lower left', 
                title='Models', bbox_to_anchor=(1.05, 0))

    # Set labels and title
    ax_2d.set_xlabel(f'Principal Component 1 ({pca.explained_variance_ratio_[0]*100:.1f}% explained variance)')
    ax_2d.set_ylabel(f'Principal Component 2 ({pca.explained_variance_ratio_[1]*100:.1f}% explained variance)')
    plt.title(f'PCA: PC1 vs PC2 - K-Means Clusters\nExperiment: {shap_experiment_id}, Basin: {basin_geo_name}')
    plt.tight_layout(rect=[0, 0, 0.85, 1])

    # Save the scatter plot
    scatter_filename = f'PCA_2D_Scatter_Exp_{shap_experiment_id}_Basin_{basin_geo_name}.png'
    plt.savefig(os.path.join(pca_plot_output_dir, scatter_filename), bbox_inches='tight', dpi=300)
    plt.close()

    print(f"  PCA scatter plot saved: {scatter_filename}")

    # Create a DataFrame for this basin's results
    basin_results_df = pd.DataFrame({
        'model_name': current_basin_models_present,
        'experiment_name': shap_experiment_id,
        'basin_id': basin_id,
        'basin_geo_name': basin_geo_name,
        'time_window_label': time_window_label,
        'cluster_label': labels,
        'outlier_score': outlier_scores
    })

    # --- Merge with Model Characteristics ---
    shap_input_df = pd.merge(basin_results_df, model_characteristics_df, 
                             left_on='model_name', right_on='Model ID', how='left')

    # Drop columns that are not features for SHAP or are redundant
    columns_to_drop = [
        'model_name', 'experiment_name', 'basin_id', 'basin_geo_name',
        'time_window_label', 'cluster_label', 'Model ID',
        '#', 'exp_id', 'Unnamed: 5', 'Unnamed: 6', 'Velocity', 'Surface/ Thickness'
    ]
    columns_to_drop_existing = [col for col in columns_to_drop if col in shap_input_df.columns]
    X_shap_raw = shap_input_df.drop(columns=columns_to_drop_existing + ['outlier_score'])

    # Get cluster labels for SHAP analysis
    y_clusters = shap_input_df['cluster_label']

    # Handle NaNs in features
    for col in X_shap_raw.columns:
        if col in CATEGORICAL_FEATURES_TO_ENCODE:
            X_shap_raw[col] = X_shap_raw[col].fillna('Missing').astype('category')
        else:
            X_shap_raw[col] = X_shap_raw[col].fillna(0) 

    # One-hot encode specified categorical features
    X_shap_encoded = pd.get_dummies(X_shap_raw, columns=CATEGORICAL_FEATURES_TO_ENCODE, drop_first=True)

    # Drop rows with any remaining NaNs
    initial_samples = X_shap_encoded.shape[0]
    X_shap_cleaned = X_shap_encoded.dropna()
    y_clusters_cleaned = y_clusters[X_shap_cleaned.index]

    if X_shap_cleaned.shape[0] == 0:
        print(f"  No valid samples for SHAP analysis after cleaning for Basin {basin_geo_name}. Skipping.")
        continue
    if X_shap_cleaned.shape[0] < 2:
        print(f"  Too few samples ({X_shap_cleaned.shape[0]}) for SHAP analysis for Basin {basin_geo_name}. Skipping.")
        continue

    print(f"  Shape of X_shap_cleaned (features for SHAP): {X_shap_cleaned.shape}")
    print(f"  Cluster distribution: {np.bincount(y_clusters_cleaned)}")

    # Check clusters present in this basin
    unique_clusters = np.unique(y_clusters_cleaned)
    print(f"  Clusters present in Basin {basin_geo_name}: {unique_clusters}")

    # --- Create SHAP Summary Plot for Each Cluster ---
    for target_cluster in range(n_clusters_kmeans):  # Loop through all possible clusters (0 to 4)

        # Check if this cluster exists in this basin
        if target_cluster not in unique_clusters:
            print(f"    Cluster {target_cluster} not present in Basin {basin_geo_name}. Creating empty plot.")

            # Create empty plot for missing cluster
            plt.figure(figsize=(12, 8))
            plt.text(0.5, 0.5, f'Cluster {target_cluster}\nNo models in this cluster\nfor Basin {basin_geo_name}', 
                     ha='center', va='center', fontsize=16, transform=plt.gca().transAxes)
            plt.xlim(0, 1)
            plt.ylim(0, 1)
            plt.axis('off')
            plt.title(f'SHAP Summary Plot - Cluster {target_cluster}\nExperiment: {shap_experiment_id}, Basin: {basin_geo_name}')

            # Save empty plot
            plot_filename = f'SHAP_Summary_Cluster_{target_cluster}_Exp_{shap_experiment_id}_Basin_{basin_geo_name}.png'
            plt.savefig(os.path.join(exp_output_dir, plot_filename), dpi=300, bbox_inches='tight')
            plt.close()

            continue

        # Count models in this cluster
        cluster_count = np.sum(y_clusters_cleaned == target_cluster)
        print(f"    Processing Cluster {target_cluster} ({cluster_count} models)")

        # Create binary target: 1 if in this cluster, 0 otherwise
        y_binary = (y_clusters_cleaned == target_cluster).astype(int)

        # Check if we have enough samples and class balance
        positive_samples = np.sum(y_binary)
        negative_samples = len(y_binary) - positive_samples

        print(f"      Cluster {target_cluster}: {positive_samples} positive, {negative_samples} negative samples")

        # Skip if too few samples or perfect separation
        if positive_samples < 2 or negative_samples < 2:
            print(f"      Insufficient samples for meaningful classification. Skipping Cluster {target_cluster}.")

            # Create diagnostic plot
            plt.figure(figsize=(12, 8))
            plt.text(0.5, 0.5, f'Cluster {target_cluster}\nInsufficient samples for SHAP analysis\n'
                               f'{positive_samples} models in cluster, {negative_samples} models outside cluster\n'
                               f'Need at least 2 models in each group', 
                     ha='center', va='center', fontsize=14, transform=plt.gca().transAxes)
            plt.xlim(0, 1)
            plt.ylim(0, 1)
            plt.axis('off')
            plt.title(f'SHAP Summary Plot - Cluster {target_cluster}\nExperiment: {shap_experiment_id}, Basin: {basin_geo_name}')

            plot_filename = f'SHAP_Summary_Cluster_{target_cluster}_Exp_{shap_experiment_id}_Basin_{basin_geo_name}.png'
            plt.savefig(os.path.join(exp_output_dir, plot_filename), dpi=300, bbox_inches='tight')
            plt.close()
            continue

        # Train binary classifier for this cluster with adjusted parameters for small datasets
        model = xgb.XGBClassifier(
            objective='binary:logistic',
            n_estimators=50,  # Fewer trees for small datasets
            random_state=42,
            max_depth=2,      # Shallower trees to prevent overfitting
            min_child_weight=1,  # Lower threshold for small datasets
            learning_rate=0.3,   # Higher learning rate for faster convergence
            subsample=1.0,       # Use all samples
            colsample_bytree=1.0, # Use all features
            reg_alpha=0.1,       # Add some regularization
            reg_lambda=0.1,
            scale_pos_weight=negative_samples / positive_samples if positive_samples > 0 else 1
        )

        try:
            model.fit(X_shap_cleaned, y_binary)
            accuracy = model.score(X_shap_cleaned, y_binary)

            # Get prediction probabilities to check if model is actually discriminating
            y_pred_proba = model.predict_proba(X_shap_cleaned)[:, 1]
            prob_range = np.max(y_pred_proba) - np.min(y_pred_proba)

            print(f"      Binary classifier accuracy for Cluster {target_cluster}: {accuracy:.2f}")
            print(f"      Prediction probability range: {prob_range:.3f}")

            # If model isn't discriminating well, skip SHAP
            if prob_range < 0.1:
                print(f"      Model predictions too uniform (range < 0.1). Skipping SHAP for Cluster {target_cluster}.")

                plt.figure(figsize=(12, 8))
                plt.text(0.5, 0.5, f'Cluster {target_cluster}\nModel cannot distinguish cluster membership\n'
                                   f'Prediction probability range: {prob_range:.3f}\n'
                                   f'Features may not be informative for this cluster', 
                         ha='center', va='center', fontsize=14, transform=plt.gca().transAxes)
                plt.xlim(0, 1)
                plt.ylim(0, 1)
                plt.axis('off')
                plt.title(f'SHAP Summary Plot - Cluster {target_cluster}\nExperiment: {shap_experiment_id}, Basin: {basin_geo_name}')

                plot_filename = f'SHAP_Summary_Cluster_{target_cluster}_Exp_{shap_experiment_id}_Basin_{basin_geo_name}.png'
                plt.savefig(os.path.join(exp_output_dir, plot_filename), dpi=300, bbox_inches='tight')
                plt.close()
                continue

            # Calculate SHAP values
            explainer = shap.TreeExplainer(model)
            shap_values = explainer.shap_values(X_shap_cleaned)

            # Check if SHAP values are meaningful
            shap_range = np.max(np.abs(shap_values)) - np.min(np.abs(shap_values))
            mean_abs_shap = np.mean(np.abs(shap_values))

            print(f"      SHAP value range: {shap_range:.4f}, Mean absolute SHAP: {mean_abs_shap:.4f}")

            if mean_abs_shap < 0.001:
                print(f"      SHAP values too small to be meaningful. Creating diagnostic plot.")

                plt.figure(figsize=(12, 8))
                plt.text(0.5, 0.5, f'Cluster {target_cluster}\nSHAP values too small to interpret\n'
                                   f'Mean absolute SHAP value: {mean_abs_shap:.6f}\n'
                                   f'This suggests features don\'t strongly distinguish this cluster', 
                         ha='center', va='center', fontsize=14, transform=plt.gca().transAxes)
                plt.xlim(0, 1)
                plt.ylim(0, 1)
                plt.axis('off')
                plt.title(f'SHAP Summary Plot - Cluster {target_cluster}\nExperiment: {shap_experiment_id}, Basin: {basin_geo_name}')

                plot_filename = f'SHAP_Summary_Cluster_{target_cluster}_Exp_{shap_experiment_id}_Basin_{basin_geo_name}.png'
                plt.savefig(os.path.join(exp_output_dir, plot_filename), dpi=300, bbox_inches='tight')
                plt.close()
                continue

            # Create SHAP Summary Plot
            plt.figure(figsize=(12, 8))
            shap.summary_plot(shap_values, X_shap_cleaned, 
                              feature_names=X_shap_cleaned.columns, 
                              show=False, max_display=12)
            plt.title(f'SHAP Summary Plot - Cluster {target_cluster}\nExperiment: {shap_experiment_id}, Basin: {basin_geo_name}\n'
                      f'({cluster_count} models in cluster, Accuracy: {accuracy:.2f}, Mean |SHAP|: {mean_abs_shap:.4f})')
            plt.tight_layout()

            # Save the plot
            plot_filename = f'SHAP_Summary_Cluster_{target_cluster}_Exp_{shap_experiment_id}_Basin_{basin_geo_name}.png'
            plt.savefig(os.path.join(exp_output_dir, plot_filename), dpi=300, bbox_inches='tight')
            plt.close()

            print(f"      SHAP Summary Plot saved: {plot_filename}")

        except Exception as e:
            print(f"      Error creating SHAP plot for Cluster {target_cluster}: {str(e)}")

            # Create error plot
            plt.figure(figsize=(12, 8))
            plt.text(0.5, 0.5, f'Cluster {target_cluster}\nError in SHAP analysis\n{str(e)[:100]}...', 
                     ha='center', va='center', fontsize=12, transform=plt.gca().transAxes)
            plt.xlim(0, 1)
            plt.ylim(0, 1)
            plt.axis('off')
            plt.title(f'SHAP Summary Plot - Cluster {target_cluster} (Error)\nExperiment: {shap_experiment_id}, Basin: {basin_geo_name}')

            plot_filename = f'SHAP_Summary_Cluster_{target_cluster}_Exp_{shap_experiment_id}_Basin_{basin_geo_name}.png'
            plt.savefig(os.path.join(exp_output_dir, plot_filename), dpi=300, bbox_inches='tight')
            plt.close()

    print(f"\n--- SHAP Analysis Complete for Basin {basin_geo_name} ---")
    print(f"Created SHAP summary plots for all {n_clusters_kmeans} clusters")

    print("\n--- All Basin SHAP Analysis Complete ---")

Found 7 unique basins: [1 2 3 4 5 6 7]
Loaded model characteristics from C:\Users\benra\ISMIP6_EXP_ALL - GIS_Model_Characteristics.csv. Shape: (21, 12)

--- Processing Basin: SW (ID: 1) for SHAP Analysis ---


  cluster_colors = [plt.cm.get_cmap('Set1')(i) for i in range(n_clusters_kmeans)]
  height = 2 * np.sqrt(eigenvalues[1] * chi2_val)


  Creating PCA scatter plot for Basin SW
  PCA scatter plot saved: PCA_2D_Scatter_Exp_expa03_Basin_SW.png
  Shape of X_shap_cleaned (features for SHAP): (15, 21)
  Cluster distribution: [2 4 5 3 1]
  Clusters present in Basin SW: [0 1 2 3 4]
    Processing Cluster 0 (2 models)
      Cluster 0: 2 positive, 13 negative samples
      Binary classifier accuracy for Cluster 0: 1.00
      Prediction probability range: 0.865
      SHAP value range: 2.6232, Mean absolute SHAP: 0.1249
      SHAP Summary Plot saved: SHAP_Summary_Cluster_0_Exp_expa03_Basin_SW.png
    Processing Cluster 1 (4 models)
      Cluster 1: 4 positive, 11 negative samples
      Binary classifier accuracy for Cluster 1: 0.87
      Prediction probability range: 0.790
      SHAP value range: 2.3239, Mean absolute SHAP: 0.1528
      SHAP Summary Plot saved: SHAP_Summary_Cluster_1_Exp_expa03_Basin_SW.png
    Processing Cluster 2 (5 models)
      Cluster 2: 5 positive, 10 negative samples
      Binary classifier accuracy for Cl

  cluster_colors = [plt.cm.get_cmap('Set1')(i) for i in range(n_clusters_kmeans)]
  height = 2 * np.sqrt(eigenvalues[1] * chi2_val)


  Creating PCA scatter plot for Basin SE
  PCA scatter plot saved: PCA_2D_Scatter_Exp_expa03_Basin_SE.png
  Shape of X_shap_cleaned (features for SHAP): (15, 21)
  Cluster distribution: [2 5 4 3 1]
  Clusters present in Basin SE: [0 1 2 3 4]
    Processing Cluster 0 (2 models)
      Cluster 0: 2 positive, 13 negative samples
      Binary classifier accuracy for Cluster 0: 1.00
      Prediction probability range: 0.865
      SHAP value range: 2.6232, Mean absolute SHAP: 0.1249
      SHAP Summary Plot saved: SHAP_Summary_Cluster_0_Exp_expa03_Basin_SE.png
    Processing Cluster 1 (5 models)
      Cluster 1: 5 positive, 10 negative samples
      Binary classifier accuracy for Cluster 1: 1.00
      Prediction probability range: 0.760
      SHAP value range: 2.4044, Mean absolute SHAP: 0.1106
      SHAP Summary Plot saved: SHAP_Summary_Cluster_1_Exp_expa03_Basin_SE.png
    Processing Cluster 2 (4 models)
      Cluster 2: 4 positive, 11 negative samples
      Binary classifier accuracy for Cl

  cluster_colors = [plt.cm.get_cmap('Set1')(i) for i in range(n_clusters_kmeans)]


  Creating PCA scatter plot for Basin CE
  PCA scatter plot saved: PCA_2D_Scatter_Exp_expa03_Basin_CE.png
  Shape of X_shap_cleaned (features for SHAP): (15, 21)
  Cluster distribution: [2 6 4 2 1]
  Clusters present in Basin CE: [0 1 2 3 4]
    Processing Cluster 0 (2 models)
      Cluster 0: 2 positive, 13 negative samples
      Binary classifier accuracy for Cluster 0: 1.00
      Prediction probability range: 0.865
      SHAP value range: 2.6232, Mean absolute SHAP: 0.1249
      SHAP Summary Plot saved: SHAP_Summary_Cluster_0_Exp_expa03_Basin_CE.png
    Processing Cluster 1 (6 models)
      Cluster 1: 6 positive, 9 negative samples
      Binary classifier accuracy for Cluster 1: 0.87
      Prediction probability range: 0.866
      SHAP value range: 1.6722, Mean absolute SHAP: 0.1341
      SHAP Summary Plot saved: SHAP_Summary_Cluster_1_Exp_expa03_Basin_CE.png
    Processing Cluster 2 (4 models)
      Cluster 2: 4 positive, 11 negative samples
      Binary classifier accuracy for Clu

  cluster_colors = [plt.cm.get_cmap('Set1')(i) for i in range(n_clusters_kmeans)]
  height = 2 * np.sqrt(eigenvalues[1] * chi2_val)


  Creating PCA scatter plot for Basin NE
  PCA scatter plot saved: PCA_2D_Scatter_Exp_expa03_Basin_NE.png
  Shape of X_shap_cleaned (features for SHAP): (15, 21)
  Cluster distribution: [2 7 4 1 1]
  Clusters present in Basin NE: [0 1 2 3 4]
    Processing Cluster 0 (2 models)
      Cluster 0: 2 positive, 13 negative samples
      Binary classifier accuracy for Cluster 0: 1.00
      Prediction probability range: 0.865
      SHAP value range: 2.6232, Mean absolute SHAP: 0.1249
      SHAP Summary Plot saved: SHAP_Summary_Cluster_0_Exp_expa03_Basin_NE.png
    Processing Cluster 1 (7 models)
      Cluster 1: 7 positive, 8 negative samples
      Binary classifier accuracy for Cluster 1: 0.87
      Prediction probability range: 0.794
      SHAP value range: 1.7215, Mean absolute SHAP: 0.1049
      SHAP Summary Plot saved: SHAP_Summary_Cluster_1_Exp_expa03_Basin_NE.png
    Processing Cluster 2 (4 models)
      Cluster 2: 4 positive, 11 negative samples
      Binary classifier accuracy for Clu

  cluster_colors = [plt.cm.get_cmap('Set1')(i) for i in range(n_clusters_kmeans)]


  Creating PCA scatter plot for Basin NO
  PCA scatter plot saved: PCA_2D_Scatter_Exp_expa03_Basin_NO.png
  Shape of X_shap_cleaned (features for SHAP): (15, 21)
  Cluster distribution: [2 4 5 1 3]
  Clusters present in Basin NO: [0 1 2 3 4]
    Processing Cluster 0 (2 models)
      Cluster 0: 2 positive, 13 negative samples
      Binary classifier accuracy for Cluster 0: 1.00
      Prediction probability range: 0.865
      SHAP value range: 2.6232, Mean absolute SHAP: 0.1249
      SHAP Summary Plot saved: SHAP_Summary_Cluster_0_Exp_expa03_Basin_NO.png
    Processing Cluster 1 (4 models)
      Cluster 1: 4 positive, 11 negative samples
      Binary classifier accuracy for Cluster 1: 0.93
      Prediction probability range: 0.874
      SHAP value range: 2.1046, Mean absolute SHAP: 0.1253
      SHAP Summary Plot saved: SHAP_Summary_Cluster_1_Exp_expa03_Basin_NO.png
    Processing Cluster 2 (5 models)
      Cluster 2: 5 positive, 10 negative samples
      Binary classifier accuracy for Cl

  cluster_colors = [plt.cm.get_cmap('Set1')(i) for i in range(n_clusters_kmeans)]
  height = 2 * np.sqrt(eigenvalues[1] * chi2_val)


  Creating PCA scatter plot for Basin NW
  PCA scatter plot saved: PCA_2D_Scatter_Exp_expa03_Basin_NW.png
  Shape of X_shap_cleaned (features for SHAP): (15, 21)
  Cluster distribution: [2 7 2 3 1]
  Clusters present in Basin NW: [0 1 2 3 4]
    Processing Cluster 0 (2 models)
      Cluster 0: 2 positive, 13 negative samples
      Binary classifier accuracy for Cluster 0: 1.00
      Prediction probability range: 0.865
      SHAP value range: 2.6232, Mean absolute SHAP: 0.1249
      SHAP Summary Plot saved: SHAP_Summary_Cluster_0_Exp_expa03_Basin_NW.png
    Processing Cluster 1 (7 models)
      Cluster 1: 7 positive, 8 negative samples
      Binary classifier accuracy for Cluster 1: 0.93
      Prediction probability range: 0.773
      SHAP value range: 1.9363, Mean absolute SHAP: 0.1046
      SHAP Summary Plot saved: SHAP_Summary_Cluster_1_Exp_expa03_Basin_NW.png
    Processing Cluster 2 (2 models)
      Cluster 2: 2 positive, 13 negative samples
      Binary classifier accuracy for Clu

  cluster_colors = [plt.cm.get_cmap('Set1')(i) for i in range(n_clusters_kmeans)]


  Creating PCA scatter plot for Basin CW
  PCA scatter plot saved: PCA_2D_Scatter_Exp_expa03_Basin_CW.png
  Shape of X_shap_cleaned (features for SHAP): (15, 21)
  Cluster distribution: [2 6 4 2 1]
  Clusters present in Basin CW: [0 1 2 3 4]
    Processing Cluster 0 (2 models)
      Cluster 0: 2 positive, 13 negative samples
      Binary classifier accuracy for Cluster 0: 1.00
      Prediction probability range: 0.865
      SHAP value range: 2.6232, Mean absolute SHAP: 0.1249
      SHAP Summary Plot saved: SHAP_Summary_Cluster_0_Exp_expa03_Basin_CW.png
    Processing Cluster 1 (6 models)
      Cluster 1: 6 positive, 9 negative samples
      Binary classifier accuracy for Cluster 1: 0.93
      Prediction probability range: 0.757
      SHAP value range: 2.2423, Mean absolute SHAP: 0.0954
      SHAP Summary Plot saved: SHAP_Summary_Cluster_1_Exp_expa03_Basin_CW.png
    Processing Cluster 2 (4 models)
      Cluster 2: 4 positive, 11 negative samples
      Binary classifier accuracy for Clu