# Meta-Learner Weights Visualization

This notebook visualizes the weights generated by the Historical Performance-Weighted Meta-Learning Framework for the Snow_HistMeta model.

## Overview
The meta-learner assigns weights to different base models based on their historical performance for each basin-period combination. This notebook explores these weights to understand:
- Which models perform best in different periods
- How weights vary across basins
- Seasonal patterns in model performance
- Overall ensemble composition

In [None]:
# Import required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings

# Suppress every Python warning for the rest of the session so that warning
# text does not clutter cell output.
warnings.filterwarnings("ignore")

# Global plot appearance: matplotlib's seaborn-v0_8 style sheet, the "husl"
# seaborn palette, and larger default figure and font sizes.
plt.style.use("seaborn-v0_8")
sns.set_palette("husl")
plt.rcParams.update({"figure.figsize": (12, 8), "font.size": 12})

print("Libraries imported successfully!")

## 1. Load Meta-Learner Weights and Performance Data

In [None]:
# Locations of the meta-learner artefacts, relative to the notebook's working
# directory.
weights_path = "../../monthly_forecasting_models/SnowMapper_Based/Snow_HistMeta/Snow_HistMeta_weights.parquet"
performance_path = "../../monthly_forecasting_models/SnowMapper_Based/Snow_HistMeta/Snow_HistMeta_performance.parquet"

# The weights table is used by every later cell, so a missing file must stop
# execution here (with the hint below) instead of surfacing in the next cell
# as an unrelated NameError on `weights_df`.
try:
    weights_df = pd.read_parquet(weights_path)
except FileNotFoundError as e:
    print(f"‚ùå Error loading data: {e}")
    print(
        "Please ensure the Snow_HistMeta model has been trained and weights are available."
    )
    raise

print(f"‚úì Loaded weights data: {weights_df.shape}")

# Performance data is only needed by the performance-vs-weight section, which
# checks `len(performance_df) > 0` before using it. Fall back to an empty
# frame so the weight-only analysis can still run.
try:
    performance_df = pd.read_parquet(performance_path)
    print(f"‚úì Loaded performance data: {performance_df.shape}")
except FileNotFoundError as e:
    print(f"‚ùå Error loading data: {e}")
    performance_df = pd.DataFrame()

In [None]:
# Print the column list and shape of both loaded tables, then show their
# first rows using the notebook's rich display.
for heading, frame in (
    ("=== WEIGHTS DATA STRUCTURE ===", weights_df),
    ("\n=== PERFORMANCE DATA STRUCTURE ===", performance_df),
):
    print(heading)
    print(f"Columns: {list(frame.columns)}")
    print(f"Shape: {frame.shape}")
    print("\nFirst few rows:")
    display(frame.head())

In [None]:
# Every column other than the two identifier columns ('code', 'period') is
# treated as a base-model weight column.
identifier_columns = ("code", "period")
model_columns = [
    name for name in weights_df.columns if name not in identifier_columns
]
print(f"Base models identified: {model_columns}")
print(f"Number of base models: {len(model_columns)}")

# Basic statistics on the identifier columns
code_values = weights_df["code"]
period_values = weights_df["period"]
print(f"\nNumber of basins: {code_values.nunique()}")
print(f"Number of periods: {period_values.nunique()}")
print(f"Period range: {period_values.min()} to {period_values.max()}")
print(f"Basin codes: {sorted(code_values.unique())}")

## 2. Overall Weight Distribution Analysis

In [None]:
# Descriptive statistics of every model's weight over all rows of weights_df
model_weight_table = weights_df[model_columns]
weight_stats = model_weight_table.describe()
print("=== OVERALL WEIGHT STATISTICS ===")
display(weight_stats)

# Mean weight per model, listed from largest to smallest
mean_weights = model_weight_table.mean()
print("\n=== MEAN WEIGHTS ACROSS ALL BASINS AND PERIODS ===")
ranked_mean_weights = mean_weights.sort_values(ascending=False)
for model, weight in ranked_mean_weights.items():
    print(f"{model}: {weight:.4f} ({weight * 100:.2f}%)")

In [None]:
# Plot overall weight distribution as a 2x2 panel figure
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# 1. Box plot of weights by model
# Long format: one row per (code, period, Model) with its Weight.
weights_melted = weights_df.melt(
    id_vars=["code", "period"],
    value_vars=model_columns,
    var_name="Model",
    value_name="Weight",
)

sns.boxplot(data=weights_melted, x="Model", y="Weight", ax=axes[0, 0])
axes[0, 0].set_title("Weight Distribution by Model")
axes[0, 0].tick_params(axis="x", rotation=45)

# 2. Mean weights bar plot
mean_weights.plot(kind="bar", ax=axes[0, 1], color="skyblue")
axes[0, 1].set_title("Mean Weights by Model")
axes[0, 1].set_ylabel("Mean Weight")
axes[0, 1].tick_params(axis="x", rotation=45)

# 3. Weight correlation heatmap
correlation_matrix = weights_df[model_columns].corr()
sns.heatmap(correlation_matrix, annot=True, cmap="coolwarm", center=0, ax=axes[1, 0])
axes[1, 0].set_title("Model Weight Correlations")

# 4. Histogram of weight values
# DataFrame.hist() draws one subplot per column; when it is handed a single
# Axes for several columns it clears the whole figure, which would erase the
# three panels above. DataFrame.plot(kind="hist") overlays all models on the
# one Axes instead.
weights_df[model_columns].plot(kind="hist", bins=30, ax=axes[1, 1], alpha=0.7)
axes[1, 1].set_title("Distribution of All Weights")
axes[1, 1].set_xlabel("Weight Value")
axes[1, 1].set_ylabel("Frequency")

plt.tight_layout()
plt.show()

## 3. Seasonal Patterns in Model Weights

In [None]:
# Average each model's weight over all rows that share the same period value
rows_by_period = weights_df.groupby("period")
weights_by_period = rows_by_period[model_columns].mean()

print("=== SEASONAL WEIGHT PATTERNS ===")
display(weights_by_period.round(4))

# Month abbreviations; position i holds the name of month number i + 1.
month_names = [
    "Jan", "Feb", "Mar", "Apr", "May", "Jun",
    "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"
]

def parse_period_to_month(period_str: object) -> str:
    """Build a display label that appends the month name to a period value.

    Args:
        period_str: Either a number (Python or NumPy int/float) whose integer
            part is taken as the month number, or a value whose string form
            is "<month>-<day part>", e.g. "1-10" or "1-end".

    Returns:
        "<period> (<Mon>)" for numeric periods in 1..12,
        "<period> (<Mon> <day part>)" for "<month>-<day part>" strings whose
        month is in 1..12, and the plain string form of the input otherwise
        (including NaN/inf, unknown formats and out-of-range months).
    """
    # np.integer is listed explicitly: NumPy integers do not subclass int,
    # so they would otherwise fall through to the string branch and lose
    # their month label.
    if isinstance(period_str, (int, float, np.integer, np.floating)):
        try:
            month_idx = int(period_str) - 1
        except (ValueError, OverflowError):
            # NaN raises ValueError, +/-inf raises OverflowError.
            return str(period_str)
        if 0 <= month_idx < 12:
            return f"{period_str} ({month_names[month_idx]})"
        return str(period_str)

    label = str(period_str)
    # Split on the first dash only, so the day part may itself contain dashes.
    month_part, dash, day_part = label.partition("-")
    if not dash:
        return label

    try:
        month_num = int(month_part)
    except ValueError:
        return label

    if not 1 <= month_num <= 12:
        return label
    return f"{label} ({month_names[month_num - 1]} {day_part})"

# Report what the period index looks like before relabelling it
period_index = weights_by_period.index
print(f"\nPeriod index type: {type(period_index[0])}")
print(f"Sample periods: {list(period_index[:5])}")
print(f"Total periods: {len(period_index)}")

# Build a copy of the per-period table whose index carries month names, then
# summarise which month numbers and day parts occur in the period values.
# Any failure is reported and the original labels stay in use.
try:
    weights_by_period_display = weights_by_period.copy()
    weights_by_period_display.index = list(
        map(parse_period_to_month, weights_by_period_display.index)
    )

    print("\n=== SEASONAL PATTERNS WITH ENHANCED LABELS ===")
    display(weights_by_period_display.round(4))

    # Collect the distinct pieces of every "<month>-<day part>" period
    periods = list(weights_by_period.index)
    unique_months = set()
    unique_day_parts = set()

    for raw_period in periods:
        month_text, dash, day_text = str(raw_period).partition("-")
        if not dash:
            continue
        try:
            month_number = int(month_text)
        except ValueError:
            # Non-numeric month part: leave this period out of the summary
            continue
        unique_months.add(month_number)
        unique_day_parts.add(day_text)

    print("\n=== PERIOD STRUCTURE ANALYSIS ===")
    print(f"Unique months found: {sorted(unique_months)}")
    print(f"Unique day parts: {sorted(unique_day_parts)}")
    print("Period format appears to be: month-daypart (e.g., '1-10' = January 10th)")

except Exception as e:
    print(f"\n‚ùå Error processing periods: {e}")
    print("Using original period labels without enhancement")

In [None]:
# Four-panel figure of the per-period mean weights
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
period_weights = weights_by_period[model_columns]

# 1. One line per model across the periods
line_ax = axes[0, 0]
for model in model_columns:
    line_ax.plot(
        weights_by_period.index,
        weights_by_period[model],
        marker="o",
        label=model,
        linewidth=2,
    )
line_ax.set_title("Seasonal Weight Patterns")
line_ax.set_xlabel("Period")
line_ax.set_ylabel("Mean Weight")
line_ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
line_ax.grid(True, alpha=0.3)

# 2. Heatmap with models as rows and periods as columns
heat_ax = axes[0, 1]
sns.heatmap(period_weights.T, annot=True, cmap="viridis", ax=heat_ax)
heat_ax.set_title("Weight Heatmap by Period")
heat_ax.set_xlabel("Period")
heat_ax.set_ylabel("Model")

# 3. Stacked bars: each period's bar is split by model
stack_ax = axes[1, 0]
period_weights.plot(kind="bar", stacked=True, ax=stack_ax)
stack_ax.set_title("Stacked Weights by Period")
stack_ax.set_xlabel("Period")
stack_ax.set_ylabel("Cumulative Weight")
stack_ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
stack_ax.tick_params(axis="x", rotation=45)

# 4. How many periods each model has the largest mean weight in
dominant_ax = axes[1, 1]
dominant_model = period_weights.idxmax(axis=1)
dominant_counts = dominant_model.value_counts()
dominant_counts.plot(kind="bar", ax=dominant_ax, color="lightcoral")
dominant_ax.set_title("Dominant Model Frequency by Period")
dominant_ax.set_xlabel("Model")
dominant_ax.set_ylabel("Number of Periods Dominated")
dominant_ax.tick_params(axis="x", rotation=45)

plt.tight_layout()
plt.show()

## 4. Basin-Specific Weight Analysis

In [None]:
# Average each model's weight over all rows that share the same basin code
rows_by_basin = weights_df.groupby("code")
weights_by_basin = rows_by_basin[model_columns].mean()

print("=== BASIN-SPECIFIC WEIGHT PATTERNS ===")
display(weights_by_basin.round(4))

# For each basin, the model with the largest mean weight
dominant_by_basin = weights_by_basin.idxmax(axis=1)
print("\n=== DOMINANT MODEL BY BASIN ===")
for basin, model in dominant_by_basin.items():
    weight = weights_by_basin.at[basin, model]
    print(f"Basin {basin}: {model} (weight: {weight:.4f})")

In [None]:
# Four-panel figure of the per-basin mean weights
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. Heatmap with models as rows and basins as columns
heat_ax = axes[0, 0]
sns.heatmap(weights_by_basin.T, annot=True, cmap="plasma", ax=heat_ax)
heat_ax.set_title("Weight Heatmap by Basin")
heat_ax.set_xlabel("Basin Code")
heat_ax.set_ylabel("Model")

# 2. Stacked bars: each basin's bar is split by model
stack_ax = axes[0, 1]
weights_by_basin.plot(kind="bar", stacked=True, ax=stack_ax)
stack_ax.set_title("Stacked Weights by Basin")
stack_ax.set_xlabel("Basin Code")
stack_ax.set_ylabel("Cumulative Weight")
stack_ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
stack_ax.tick_params(axis="x", rotation=45)

# 3. Box plot of weights by basin
# Long format: one row per (code, period, Model) with its Weight.
weights_melted_basin = weights_df.melt(
    id_vars=["code", "period"],
    value_vars=model_columns,
    var_name="Model",
    value_name="Weight",
)

# Keep the plot readable: with more than 10 basins only the first 10 (in
# sorted order) are drawn, and the title says so.
unique_basins = sorted(weights_df["code"].unique())
weights_subset = weights_melted_basin
title_suffix = ""
if len(unique_basins) > 10:
    selected_basins = unique_basins[:10]
    in_selection = weights_melted_basin["code"].isin(selected_basins)
    weights_subset = weights_melted_basin[in_selection]
    title_suffix = f" (First 10 of {len(unique_basins)} basins)"

box_ax = axes[1, 0]
sns.boxplot(data=weights_subset, x="code", y="Weight", ax=box_ax)
box_ax.set_title(f"Weight Distribution by Basin{title_suffix}")
box_ax.set_xlabel("Basin Code")
box_ax.tick_params(axis="x", rotation=45)

# 4. How many basins each model has the largest mean weight in
dominant_ax = axes[1, 1]
dominant_by_basin.value_counts().plot(kind="bar", ax=dominant_ax, color="lightgreen")
dominant_ax.set_title("Dominant Model Distribution Across Basins")
dominant_ax.set_xlabel("Model")
dominant_ax.set_ylabel("Number of Basins Dominated")
dominant_ax.tick_params(axis="x", rotation=45)

plt.tight_layout()
plt.show()

## 5. Performance vs. Weight Analysis

In [None]:
# Join weights and performance on (code, period). Columns present in both
# tables are disambiguated with the "_weight" / "_perf" suffixes.
has_performance = "performance_df" in locals() and len(performance_df) > 0

if not has_performance:
    print("‚ùå Performance data not available for comparison")
    merged_df = None
else:
    merged_df = weights_df.merge(
        performance_df, on=["code", "period"], suffixes=("_weight", "_perf")
    )

    print(f"‚úì Merged weights and performance data: {merged_df.shape}")
    print(f"Columns: {list(merged_df.columns)}")

    # List the suffixed columns produced by the merge
    weight_cols = [name for name in merged_df.columns if name.endswith("_weight")]
    perf_cols = [name for name in merged_df.columns if name.endswith("_perf")]

    print(f"\nWeight columns: {weight_cols}")
    print(f"Performance columns: {perf_cols}")

In [None]:
# Performance vs Weight scatter plots: one panel per model, at most 3 per row
if merged_df is not None:
    # A model can only be compared when the merge produced BOTH suffixed
    # columns, i.e. its name occurs in the weights table and in the
    # performance table. Falling back to the bare model name would pair the
    # weight column with itself.
    model_pairs = []
    for model in model_columns:
        weight_col = f"{model}_weight"
        perf_col = f"{model}_perf"
        if weight_col in merged_df.columns and perf_col in merged_df.columns:
            model_pairs.append((model, weight_col, perf_col))

    if model_pairs:
        n_models = len(model_pairs)
        n_cols = min(3, n_models)
        n_rows = (n_models + n_cols - 1) // n_cols  # ceiling division

        # squeeze=False always returns a 2-D array of Axes. Without it a
        # single model yields a bare Axes object, which has no .reshape().
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False
        )

        for i, (model, weight_col, perf_col) in enumerate(model_pairs):
            ax = axes[i // n_cols, i % n_cols]

            # Clean data (remove rows where either value is NaN)
            clean_data = merged_df[[weight_col, perf_col]].dropna()

            if len(clean_data) > 0:
                ax.scatter(
                    clean_data[perf_col], clean_data[weight_col], alpha=0.6, s=50
                )
                ax.set_xlabel(f"{model} Performance")
                ax.set_ylabel(f"{model} Weight")
                ax.set_title(f"{model}: Performance vs Weight")
                ax.grid(True, alpha=0.3)

                # Annotate the panel with the correlation coefficient
                corr = clean_data[perf_col].corr(clean_data[weight_col])
                ax.text(
                    0.05,
                    0.95,
                    f"r = {corr:.3f}",
                    transform=ax.transAxes,
                    bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
                )
            else:
                ax.text(
                    0.5,
                    0.5,
                    "No data available",
                    transform=ax.transAxes,
                    ha="center",
                    va="center",
                )

        # Hide the unused panels at the end of the last row
        for unused_ax in axes.flat[len(model_pairs):]:
            unused_ax.set_visible(False)

        plt.tight_layout()
        plt.show()
    else:
        print("‚ùå No matching performance and weight columns found")
else:
    print("‚ùå Cannot create performance vs weight plots - merged data not available")

## 6. Weight Stability Analysis

In [None]:
# Per basin: standard deviation of each model's weight over that basin's rows
rows_per_basin = weights_df.groupby("code")
weight_variability = rows_per_basin[model_columns].std()

print("=== WEIGHT VARIABILITY BY BASIN ===")
print("(Standard deviation of weights across periods)")
display(weight_variability.round(4))

# Per model: the basin-level standard deviations averaged over all basins
mean_variability = weight_variability.mean()
print("\n=== AVERAGE WEIGHT VARIABILITY BY MODEL ===")
most_stable_first = mean_variability.sort_values()
for model, var in most_stable_first.items():
    print(f"{model}: {var:.4f} (lower = more stable)")

In [None]:
# Plot weight stability as a 2x2 panel figure
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. Heatmap of weight variability by basin (models as rows, basins as columns)
sns.heatmap(weight_variability.T, annot=True, cmap="Reds", ax=axes[0, 0])
axes[0, 0].set_title("Weight Variability by Basin\n(Standard Deviation)")
axes[0, 0].set_xlabel("Basin Code")
axes[0, 0].set_ylabel("Model")

# 2. Mean variability by model
mean_variability.plot(kind="bar", ax=axes[0, 1], color="orange")
axes[0, 1].set_title("Average Weight Variability by Model")
axes[0, 1].set_ylabel("Standard Deviation")
axes[0, 1].tick_params(axis="x", rotation=45)

# 3. Distribution of weight variability
# DataFrame.hist() draws one subplot per column; when it is handed a single
# Axes for several columns it clears the whole figure, which would erase the
# two panels above. DataFrame.plot(kind="hist") overlays all models on the
# one Axes instead.
weight_variability.plot(kind="hist", bins=20, ax=axes[1, 0], alpha=0.7)
axes[1, 0].set_title("Distribution of Weight Variability")
axes[1, 0].set_xlabel("Standard Deviation")
axes[1, 0].set_ylabel("Frequency")

# 4. Coefficient of variation (CV) for each model:
#    mean per-basin standard deviation relative to the overall mean weight.
mean_weights_overall = weights_df[model_columns].mean()
cv_by_model = (mean_variability / mean_weights_overall) * 100
cv_by_model.plot(kind="bar", ax=axes[1, 1], color="purple")
axes[1, 1].set_title("Coefficient of Variation by Model\n(CV = std/mean * 100)")
axes[1, 1].set_ylabel("Coefficient of Variation (%)")
axes[1, 1].tick_params(axis="x", rotation=45)

plt.tight_layout()
plt.show()

## 7. Summary and Insights

In [None]:
# Generate comprehensive summary.
# Reads results computed in earlier cells: weights_df, model_columns,
# mean_weights, mean_variability, weights_by_period, dominant_by_basin and
# merged_df.
print("=" * 60)
print("           SNOW_HISTMETA WEIGHTS ANALYSIS SUMMARY")
print("=" * 60)

print(f"\nüìä Dataset Overview:")
print(f"   ‚Ä¢ {len(model_columns)} base models analyzed")
print(f"   ‚Ä¢ {weights_df['code'].nunique()} basins")
print(f"   ‚Ä¢ {weights_df['period'].nunique()} periods")
print(f"   ‚Ä¢ {len(weights_df)} total weight combinations")

print(f"\nüèÜ Model Performance Ranking (by mean weight):")
for i, (model, weight) in enumerate(
    mean_weights.sort_values(ascending=False).items(), 1
):
    print(f"   {i}. {model}: {weight:.4f} ({weight * 100:.1f}%)")

print(f"\nüìà Most Stable Models (lowest variability):")
for i, (model, var) in enumerate(mean_variability.sort_values().items(), 1):
    print(f"   {i}. {model}: œÉ = {var:.4f}")

print(f"\nüóìÔ∏è Seasonal Insights:")
# Find periods where each model dominates (has the largest mean weight)
dominant_periods = weights_by_period[model_columns].idxmax(axis=1)
for model in model_columns:
    periods = dominant_periods[dominant_periods == model].index.tolist()
    if periods:
        print(f"   ‚Ä¢ {model} dominates in periods: {periods}")

print(f"\nüèûÔ∏è Basin-Specific Insights:")
dominant_model_counts = dominant_by_basin.value_counts()
for model, count in dominant_model_counts.items():
    percentage = (count / len(dominant_by_basin)) * 100
    print(f"   ‚Ä¢ {model} dominates {count} basins ({percentage:.1f}%)")

print(f"\nüìä Weight Distribution:")
# Row-wise sum over all model weights; reported as mean and standard deviation
total_weight_check = weights_df[model_columns].sum(axis=1)
print(
    f"   ‚Ä¢ Weight sum check: {total_weight_check.mean():.4f} ¬± {total_weight_check.std():.4f}"
)
print(f"   ‚Ä¢ Min weight across all models: {weights_df[model_columns].min().min():.4f}")
print(f"   ‚Ä¢ Max weight across all models: {weights_df[model_columns].max().max():.4f}")

print(f"\nüí° Key Findings:")
best_model = mean_weights.idxmax()
best_weight = mean_weights.max()
most_stable = mean_variability.idxmin()
most_stable_var = mean_variability.min()

print(f"   ‚Ä¢ Best overall model: {best_model} (weight: {best_weight:.4f})")
print(f"   ‚Ä¢ Most stable model: {most_stable} (variability: {most_stable_var:.4f})")
# A top mean weight below 0.5 is reported as a balanced ensemble
print(
    f"   ‚Ä¢ Weight balance: {'Well-balanced' if best_weight < 0.5 else 'Dominated by one model'}"
)

if merged_df is not None:
    print(f"\nüîó Performance-Weight Correlation:")
    # Only models for which the merge produced BOTH suffixed columns can be
    # correlated. Falling back to the bare model name would correlate the
    # weight column with itself.
    for model in model_columns:
        weight_col = f"{model}_weight"
        perf_col = f"{model}_perf"

        if weight_col in merged_df.columns and perf_col in merged_df.columns:
            clean_data = merged_df[[weight_col, perf_col]].dropna()
            if len(clean_data) > 0:
                corr = clean_data[perf_col].corr(clean_data[weight_col])
                print(f"   ‚Ä¢ {model}: r = {corr:.3f}")

print("\n" + "=" * 60)

## 8. Export Results

In [None]:
# Results are written to a sibling "analysis_output" directory (created if it
# does not exist yet).
output_dir = Path("../analysis_output")
output_dir.mkdir(exist_ok=True)

# Per-model summary table: one row per model, one column per statistic.
# Models that dominate no basin get a count of 0.
basins_dominated = dominant_by_basin.value_counts().reindex(
    model_columns, fill_value=0
)
summary_stats = pd.DataFrame(
    {
        "mean_weight": mean_weights,
        "weight_variability": mean_variability,
        "coefficient_of_variation": cv_by_model,
        "dominant_basins": basins_dominated,
    }
)

# File name -> table, written in this order
exports = {
    "snow_histmeta_weight_summary.csv": summary_stats,
    "snow_histmeta_weights_by_period.csv": weights_by_period,
    "snow_histmeta_weights_by_basin.csv": weights_by_basin,
}
for file_name, table in exports.items():
    table.to_csv(output_dir / file_name)

print(f"‚úì Results exported to {output_dir}")
print("Files created:")
for file_name in exports:
    print(f"  ‚Ä¢ {file_name}")

## Conclusions

This analysis examines how the Historical Performance-Weighted Meta-Learning Framework assigns weights to the base models in the Snow_HistMeta ensemble. The sections above cover:

1. **Model Importance**: The ranking of models by average weight across all basins and periods
2. **Seasonal Patterns**: How model preferences change throughout the year
3. **Basin Specificity**: Which models work best for different basins
4. **Weight Stability**: How consistent the weights are across different conditions
5. **Performance Correlation**: The relationship between historical performance and assigned weights

These insights can be used to:
- Understand the meta-learner's decision-making process
- Identify potential improvements to base models
- Validate the meta-learning approach
- Guide future model development