# In [9]:
"""
ML ANALYSIS FOR ECONOMIC COMPLEXITY
- Panel data (all years, not just 2019)
- Temporal train/test split (train: 1995-2013, test: 2014-2019)
- Three model specifications: RC Baseline, RC + Interactions, Full Structural
- High Resource Country interactions
- SHAP for interpretability
"""

import pandas as pd
import numpy as np
from sklearn.model_selection import cross_val_score, KFold
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Ridge, Lasso, ElasticNet
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import xgboost as xgb
import lightgbm as lgb
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings
import os

# Suppress every warning for the rest of the run, including those raised by
# the libraries imported above.
warnings.filterwarnings('ignore')

# ============================================================================
# CONFIGURATION
# ============================================================================

# NOTE(review): absolute, machine-specific paths — the script only runs
# unchanged on a machine with this directory layout.
input_file = "/Users/leoss/Desktop/GitHub/Capstone/MASTER/Master.csv"
production_file = "/Users/leoss/Desktop/GitHub/Capstone/MASTER/NaturalResource.csv"
output_dir = "/Users/leoss/Desktop/Portfolio/Website-/capstone_visualizations/individual_plots/ml"

# Create the output folder up front so later HTML/CSV writes find it.
os.makedirs(output_dir, exist_ok=True)

# Temporal cutoff: rows with Year <= 2013 are used for training, every later
# year present in the data is used for testing (no upper bound is enforced).
TRAIN_END_YEAR = 2013

# Three-letter country codes flagged as "high resource"; matched against the
# 'Country Code' column to build the High_Resource dummy.
HIGH_RESOURCE_COUNTRIES = [
    'AGO', 'ARE', 'AZE', 'BFA', 'BHR', 'BOL', 'CHL', 'CIV', 'CMR',
    'COD', 'COG', 'DZA', 'ECU', 'EGY', 'ETH', 'GAB', 'GHA', 'GIN',
    'GNQ', 'IDN', 'IRN', 'IRQ', 'KAZ', 'KEN', 'KWT', 'LAO', 'LBR',
    'LBY', 'MDG', 'MLI', 'MMR', 'MNG', 'MOZ', 'MWI', 'MYS', 'NER',
    'NGA', 'OMN', 'PNG', 'QAT', 'RUS', 'RWA', 'SAU', 'TCD', 'TGO',
    'TTO', 'TZA', 'UGA', 'UZB', 'VEN', 'VNM', 'YEM', 'ZMB', 'ZWE'
]

# ============================================================================
# UNIFIED STYLE — same palette as clustering/diversity scripts & site CSS
# ============================================================================

# Shared styling constants read by base_layout(), styled_axis() and the chart
# sections below. Colour comments give the CSS variable each value mirrors.
STYLE = {
    # Typography
    'font_family': 'IBM Plex Sans, -apple-system, BlinkMacSystemFont, sans-serif',
    'title_size': 18,
    'subtitle_size': 13,
    'axis_title_size': 13,
    'tick_size': 11,
    'legend_size': 12,
    'title_color': '#1a2744',       # --navy

    # Figure frame: fully transparent backgrounds, default sizes and margins
    'template': 'plotly_white',
    'bg_color': 'rgba(0,0,0,0)',
    'plot_bg': 'rgba(0,0,0,0)',
    'chart_height': 650,
    'chart_height_small': 500,
    'margin': dict(l=60, r=50, t=30, b=80),
    # Wider left margin for horizontal bar charts with long feature labels
    'margin_bar': dict(l=200, r=80, t=30, b=50),

    # Positive/negative — site success/accent
    'pos_color': '#2e7d4a',         # --success
    'neg_color': '#c23a3a',         # --accent

    # Model specification colors (3 specs → 3 distinct, site-compatible tones)
    'spec_colors': {
        'RC Baseline':      '#3d4f5f',   # slate
        'RC + Interactions': '#4a6fa5',   # steel blue
        'Full Structural':   '#c23a3a',   # accent red
    },

    # Single-series bar color
    'bar_color': '#1a2744',          # navy
    'bar_color_alt': '#4a6fa5',      # steel blue (for SHAP vs importance distinction)

    'gridline_color': '#dde1e7',     # --border
    'zero_line_color': '#c9cfd6',    # --grey-300
}

# Passed to fig.write_html(config=...) by save_html(); hides Plotly's mode bar.
WRITE_CONFIG = {'displayModeBar': False}


# --- Helpers ----------------------------------------------------------------

def styled_title(main: str = None, sub: str = None) -> dict:
    """Return an empty, horizontally centred Plotly title.

    ``main`` and ``sub`` are accepted so call sites can keep their intended
    captions next to the chart code, but both are ignored: titles are
    handled in the HTML page, not in the charts.
    """
    return {'text': '', 'x': 0.5}


def base_layout(**overrides) -> dict:
    """Build the layout keyword dict shared by the charts.

    Starts from the STYLE defaults (template, body font, backgrounds,
    default height and margin, hover-label styling). Any keyword given in
    ``overrides`` replaces the default of the same name, or is added if no
    such default exists.
    """
    family = STYLE['font_family']
    hover_font = {'family': family, 'size': 13, 'color': '#1a2744'}
    defaults = {
        'template': STYLE['template'],
        'font': {'family': family, 'size': STYLE['tick_size'],
                 'color': '#4b5563'},
        'paper_bgcolor': STYLE['bg_color'],
        'plot_bgcolor': STYLE['plot_bg'],
        'height': STYLE['chart_height'],
        'margin': STYLE['margin'],
        'hoverlabel': {
            'bgcolor': 'white',
            'bordercolor': '#dde1e7',
            'font': hover_font,
        },
    }
    # Later keys win, so caller overrides take precedence over the defaults.
    return {**defaults, **overrides}


def styled_axis(title_text: str) -> dict:
    """Return axis settings with ``title_text`` as title and STYLE fonts."""
    family = STYLE['font_family']
    title_font = {'size': STYLE['axis_title_size'], 'family': family}
    tick_font = {'size': STYLE['tick_size'], 'family': family}
    return {
        'title': {'text': title_text, 'font': title_font},
        'tickfont': tick_font,
    }


def save_html(fig, filename: str):
    """Write ``fig`` as HTML to ``output_dir/filename`` and print a tick line.

    The file references plotly.js from the CDN (``include_plotlyjs='cdn'``)
    and is written with the shared WRITE_CONFIG.
    """
    destination = os.path.join(output_dir, filename)
    fig.write_html(destination, config=WRITE_CONFIG, include_plotlyjs='cdn')
    print(f"   ✓ {filename}")


# Run banner: rule, title, split description, interaction note, rule —
# one output line each.
print(
    "=" * 70,
    "ML ANALYSIS: PREDICTING ECONOMIC COMPLEXITY",
    f"Panel Data with Temporal Split (Train: ≤{TRAIN_END_YEAR}, "
    f"Test: {TRAIN_END_YEAR + 1}+)",
    "Including High Resource Country Interactions",
    "=" * 70,
    sep="\n",
)

# ============================================================================
# 1. HELPER FUNCTIONS AND MAPPINGS
# ============================================================================


# Character map for clean_name: em dash -> hyphen, '%' -> 'pct', and
# parentheses / commas removed.
_NAME_TRANSLATION = str.maketrans({
    '—': '-',
    '(': None,
    ')': None,
    '%': 'pct',
    ',': None,
})


def clean_name(s):
    """Normalise a column label.

    Replaces the em dash with '-', spells '%' as 'pct', and strips
    parentheses and commas. All other characters are kept unchanged.
    """
    return s.translate(_NAME_TRANSLATION)


# Column name -> label shown on charts and in the printed summary.
# Keys are written in their already-cleaned form (the form clean_name()
# produces); lookups elsewhere fall back to the raw name when a key is absent.
# NOTE(review): the '*_x_HighRes' keys and 'Political corruption index' have
# no matching entry in the three feature lists defined in this file —
# presumably left over from an earlier specification; confirm before removing.
feature_names_display = {
    'Oil_GDP_Pct': 'Oil (% GDP)',
    'Natural Gas_GDP_Pct': 'Natural Gas (% GDP)',
    'Coal_GDP_Pct': 'Coal (% GDP)',
    'Metals_GDP_Pct': 'Metals (% GDP)',
    'Human capital index': 'Human Capital',
    'Rule of law index': 'Rule of Law',
    'Property rights': 'Property Rights',
    'Political corruption index': 'Political Corruption',
    'Political stability - estimate': 'Political Stability',
    'Landlocked': 'Landlocked',
    'Manufacturing': 'Manufacturing (% GDP)',
    'Agriculture': 'Agriculture (% GDP)',
    'Trade pct of GDP': 'Trade Openness',
    'Gross fixed capital formation all Constant prices Percent of GDP': 'Investment (% GDP)',
    'Access to electricity pct of population': 'Electricity Access',
    'Urban population pct of total population': 'Urbanization',
    'Domestic credit to private sector pct of GDP': 'Private Credit',
    'Inflation consumer prices annual pct': 'Inflation',
    'High_Resource': 'High Resource Country',
    'Oil_GDP_Pct_x_HighRes': 'Oil × High Resource',
    'NatGas_GDP_Pct_x_HighRes': 'Nat Gas × High Resource',
    'Coal_GDP_Pct_x_HighRes': 'Coal × High Resource',
    'Metals_GDP_Pct_x_HighRes': 'Metals × High Resource',
    'Total_Resources_x_HighRes': 'Total Resources × High Resource',
    'HCI_x_TotalResources': 'Human Capital × Total Resources',
}

# ============================================================================
# 2. LOAD AND PREPARE DATA
# ============================================================================

print("\n1. Loading data...")

# Master panel: every column label is passed through clean_name on load.
df_master = pd.read_csv(input_file).rename(columns=clean_name)
# Production data keeps its original column labels.
df_prod = pd.read_csv(production_file)

print(f"   Master data: {df_master.shape[0]} rows, "
      f"{df_master['Country Code'].nunique()} countries")
print(f"   Years: {df_master['Year'].min()} – {df_master['Year'].max()}")


def categorize_resource(resource):
    """Map a raw resource label to one of four categories.

    'Oil', 'Natural Gas' and 'Coal' map to themselves; every other value
    (including missing ones) is grouped under 'Metals'.
    """
    for energy_label in ('Oil', 'Natural Gas', 'Coal'):
        if resource == energy_label:
            return energy_label
    return 'Metals'


# Collapse the raw resource labels into the four categories used downstream.
df_prod['Resource_Category'] = df_prod['Resource'].apply(categorize_resource)

# Total production value per country / year / category.
prod_agg = df_prod.groupby(
    ['Country Name', 'Year', 'Resource_Category']
)['Production_TotalValue'].sum().reset_index()

# One row per country-year with one column per category; combinations that
# do not occur for a country-year are filled with 0.
prod_wide = prod_agg.pivot_table(
    index=['Country Name', 'Year'],
    columns='Resource_Category',
    values='Production_TotalValue',
    fill_value=0,
).reset_index()

# Left join keeps every master row; country-years without production rows
# come back as NaN in the category columns and are zero-filled below.
# NOTE(review): the join key is 'Country Name', so spelling differences
# between the two files silently yield zero production — verify the names.
df = df_master.merge(prod_wide, on=['Country Name', 'Year'], how='left')

for col in ['Oil', 'Natural Gas', 'Coal', 'Metals']:
    if col in df.columns:
        df[col] = df[col].fillna(0)
    else:
        # A category with no rows at all in the production file produces no
        # pivot column. Create it as zero so the share columns and the total
        # below always exist (previously this ended in a KeyError), and say
        # so, because it may also indicate a label mismatch in the input.
        print(f"   ⚠ No '{col}' production rows found; treating as 0")
        df[col] = 0.0

# NOTE(review): assumes Production_TotalValue is in the same currency units
# as GDP per capita × population — TODO confirm.
df['GDP_total'] = df['GDP per capita constant prices PPP'] * df['Population']
for res in ['Oil', 'Natural Gas', 'Coal', 'Metals']:
    df[f'{res}_GDP_Pct'] = (df[res] / df['GDP_total']) * 100
    # Division by a zero GDP gives ±inf; treat those as missing.
    df[f'{res}_GDP_Pct'] = df[f'{res}_GDP_Pct'].replace([np.inf, -np.inf], np.nan)

# Row-wise sum of the four shares (NaN entries are skipped by sum()).
df['Total_Resources_GDP_Pct'] = df[[
    'Oil_GDP_Pct', 'Natural Gas_GDP_Pct', 'Coal_GDP_Pct', 'Metals_GDP_Pct'
]].sum(axis=1)

# ============================================================================
# 3. HIGH RESOURCE DUMMY AND INTERACTIONS
# ============================================================================

print("\n2. Creating high resource country dummy and interactions...")

# 1 when the row's country code is on the HIGH_RESOURCE_COUNTRIES list, else 0.
df['High_Resource'] = df['Country Code'].isin(HIGH_RESOURCE_COUNTRIES).astype(int)
# Interaction term: human capital index times total resource share of GDP.
df['HCI_x_TotalResources'] = (
    df['Total_Resources_GDP_Pct'] * df['Human capital index']
)

# Distinct countries on each side of the dummy.
n_high_res = df.loc[df['High_Resource'] == 1, 'Country Code'].nunique()
n_other = df.loc[df['High_Resource'] == 0, 'Country Code'].nunique()
print(
    f"   High resource countries: {n_high_res}",
    f"   Other countries: {n_other}",
    f"   Merged data: {len(df)} rows",
    sep="\n",
)

# ============================================================================
# 4. DEFINE FEATURE SETS
# ============================================================================

print("\n3. Defining feature sets...")

# Column predicted by every model.
target = 'Economic Complexity Index'

# Spec 1 — RC Baseline: the four resource shares plus human capital,
# two institutional measures and the landlocked indicator.
features_resource_curse_raw = [
    'Oil_GDP_Pct', 'Natural Gas_GDP_Pct', 'Coal_GDP_Pct', 'Metals_GDP_Pct',
    'Human capital index', 'Rule of law index', 'Property rights', 'Landlocked',
]

# Spec 2 — RC + Interactions: baseline plus the High_Resource dummy and the
# human-capital × total-resources term. These names contain none of the
# characters clean_name rewrites, so the list is used as written.
features_rc_interactions = [
    'Oil_GDP_Pct', 'Natural Gas_GDP_Pct', 'Coal_GDP_Pct', 'Metals_GDP_Pct',
    'High_Resource', 'Human capital index', 'HCI_x_TotalResources',
    'Rule of law index', 'Property rights', 'Landlocked',
]

# Spec 3 — Full Structural. Written with the un-cleaned labels (parentheses,
# '%', commas, em dash); cleaned below to match the df column names.
features_full_raw = [
    'Oil_GDP_Pct', 'Natural Gas_GDP_Pct', 'Coal_GDP_Pct', 'Metals_GDP_Pct',
    'Manufacturing', 'Agriculture', 'Trade (% of GDP)',
    'Gross fixed capital formation, all, Constant prices, Percent of GDP',
    'Human capital index', 'Access to electricity (% of population)',
    'Urban population (% of total population)', 'Rule of law index',
    'Property rights', 'Political stability — estimate',
    'Domestic credit to private sector (% of GDP)',
    'Inflation, consumer prices (annual %)', 'Landlocked',
]

# Apply the same normalisation that was applied to the master columns on load.
features_resource_curse = [clean_name(f) for f in features_resource_curse_raw]
features_full = [clean_name(f) for f in features_full_raw]

# Display-label lookup keyed by cleaned feature name.
feature_names_clean = {clean_name(k): v for k, v in feature_names_display.items()}

print("\nChecking feature availability...")
all_features = set(features_resource_curse + features_rc_interactions + features_full)
# Order of `missing` follows set iteration order, so it can vary between runs.
missing = [f for f in all_features if f not in df.columns]

# Only reports the problem; execution continues, so a missing column will
# still fail later when the feature columns are selected from df.
if missing:
    print(f"   ⚠ Missing: {missing}")
else:
    print("   ✓ All features found")

print(f"\n   RC Baseline: {len(features_resource_curse)} features")
print(f"   RC + Interactions: {len(features_rc_interactions)} features")
print(f"   Full Structural: {len(features_full)} features")

# ============================================================================
# 5. TEMPORAL SPLIT
# ============================================================================

print("\n4. Preparing train/test split...")


def prepare_split(df_src, features, label):
    """Build the temporal train/test split for one feature specification.

    Keeps the identifier columns, the target and ``features``, drops every
    row with a missing value in any of them, then splits on
    ``TRAIN_END_YEAR`` (train: Year <= cutoff, test: Year > cutoff). A
    StandardScaler is fitted on the training features only and applied to
    both parts. Prints the resulting row counts prefixed with ``label``.

    Returns a 9-tuple:
    ``(X_train, X_test, y_train, y_test, X_train_scaled, X_test_scaled,
    train_frame, test_frame, scaler)``.
    """
    keep_cols = ['Country Code', 'Country Name', 'Year', target] + features
    complete = df_src[keep_cols].dropna()

    years = complete['Year']
    train = complete[years <= TRAIN_END_YEAR]
    test = complete[years > TRAIN_END_YEAR]

    X_tr = train[features]
    X_te = test[features]

    # Scaling statistics come from the training rows only.
    scaler = StandardScaler()
    X_tr_s = scaler.fit_transform(X_tr)
    X_te_s = scaler.transform(X_te)

    print(f"\n   {label}: Train {len(train)}, Test {len(test)}")
    return (X_tr, X_te, train[target], test[target],
            X_tr_s, X_te_s, train, test, scaler)


# One call per specification. Each unpacks prepare_split's 9-tuple into
# module-level names carrying the spec suffix (rc / ri / fu).
(X_train_rc, X_test_rc, y_train_rc, y_test_rc,
 X_train_rc_s, X_test_rc_s, train_rc, test_rc, scaler_rc) = prepare_split(
    df, features_resource_curse, "RC BASELINE")

(X_train_ri, X_test_ri, y_train_ri, y_test_ri,
 X_train_ri_s, X_test_ri_s, train_ri, test_ri, scaler_ri) = prepare_split(
    df, features_rc_interactions, "RC + INTERACTIONS")

(X_train_fu, X_test_fu, y_train_fu, y_test_fu,
 X_train_fu_s, X_test_fu_s, train_fu, test_fu, scaler_fu) = prepare_split(
    df, features_full, "FULL STRUCTURAL")

# ============================================================================
# 6. TRAIN MODELS
# ============================================================================

print("\n5. Training models...")


def _fit_and_score(model, X_fit, X_eval, y_train, y_test, *,
                   scaled, attr, result_key):
    """Fit one model and return its metrics, predictions and fitted weights.

    ``attr`` names the fitted attribute to copy into the result under
    ``result_key`` (skipped when the model does not expose it).
    """
    model.fit(X_fit, y_train)
    yp_tr = model.predict(X_fit)
    yp_te = model.predict(X_eval)
    entry = {
        'model': model, 'scaled': scaled,
        'train_r2': r2_score(y_train, yp_tr),
        'test_r2': r2_score(y_test, yp_te),
        'test_rmse': np.sqrt(mean_squared_error(y_test, yp_te)),
        'test_mae': mean_absolute_error(y_test, yp_te),
        'predictions': yp_te,
    }
    if hasattr(model, attr):
        entry[result_key] = getattr(model, attr)
    return entry


def train_and_evaluate(X_train, X_test, y_train, y_test,
                       X_train_scaled, X_test_scaled):
    """Fit seven regressors on one specification and score them.

    The three linear models (Ridge, Lasso, ElasticNet) are trained on the
    standardized matrices; the four tree ensembles (Random Forest, Gradient
    Boosting, XGBoost, LightGBM) on the unscaled features.

    Returns a dict keyed by model name. Each value holds the fitted
    ``model``, the ``scaled`` flag, ``train_r2``, ``test_r2``,
    ``test_rmse``, ``test_mae``, the test ``predictions`` and either
    ``coefficients`` (linear models) or ``feature_importance`` (trees).
    """
    models_scaled = {
        'Ridge': Ridge(alpha=1.0, random_state=42),
        'Lasso': Lasso(alpha=0.1, random_state=42),
        'ElasticNet': ElasticNet(alpha=0.1, l1_ratio=0.5, random_state=42),
    }

    models_unscaled = {
        'Random Forest': RandomForestRegressor(
            n_estimators=200, max_depth=10, min_samples_split=10,
            min_samples_leaf=5, random_state=42, n_jobs=-1),
        'Gradient Boosting': GradientBoostingRegressor(
            n_estimators=200, learning_rate=0.05, max_depth=5,
            min_samples_split=10, min_samples_leaf=5, random_state=42),
        'XGBoost': xgb.XGBRegressor(
            n_estimators=200, learning_rate=0.05, max_depth=5,
            min_child_weight=5, subsample=0.8, colsample_bytree=0.8,
            random_state=42, objective='reg:squarederror', verbosity=0),
        # LightGBM only applies `subsample` when `subsample_freq` is
        # non-zero (its default is 0, i.e. bagging disabled). Resample every
        # iteration so the 0.8 row subsample actually takes effect, as it
        # does for the XGBoost model above.
        'LightGBM': lgb.LGBMRegressor(
            n_estimators=200, learning_rate=0.05, max_depth=5,
            num_leaves=31, min_child_samples=10, subsample=0.8,
            subsample_freq=1,
            colsample_bytree=0.8, random_state=42, verbose=-1),
    }

    results = {}

    # Linear models: standardized inputs, keep the fitted coefficients.
    for name, model in models_scaled.items():
        results[name] = _fit_and_score(
            model, X_train_scaled, X_test_scaled, y_train, y_test,
            scaled=True, attr='coef_', result_key='coefficients')

    # Tree ensembles: raw inputs, keep the built-in feature importances.
    for name, model in models_unscaled.items():
        results[name] = _fit_and_score(
            model, X_train, X_test, y_train, y_test,
            scaled=False, attr='feature_importances_',
            result_key='feature_importance')

    return results


# Fit all seven models once per specification; each results_* dict is keyed
# by model name.
print("   Training RC Baseline...")
results_rc = train_and_evaluate(
    X_train_rc, X_test_rc, y_train_rc, y_test_rc,
    X_train_rc_s, X_test_rc_s)

print("   Training RC + Interactions...")
results_ri = train_and_evaluate(
    X_train_ri, X_test_ri, y_train_ri, y_test_ri,
    X_train_ri_s, X_test_ri_s)

print("   Training Full Structural...")
results_fu = train_and_evaluate(
    X_train_fu, X_test_fu, y_train_fu, y_test_fu,
    X_train_fu_s, X_test_fu_s)

# ============================================================================
# 7. CROSS-VALIDATION
# ============================================================================

print("\n6. Cross-validation (5-fold)...")

# Shared 5-fold splitter used by run_cv.
# NOTE(review): folds are shuffled at row level, so the same country (and
# later years) can appear on both sides of a fold — CV R² is likely more
# optimistic than the temporal Test R²; consider a grouped or time-ordered
# splitter.
kf = KFold(n_splits=5, shuffle=True, random_state=42)


def run_cv(results, X_train, X_train_scaled, y_train):
    """Cross-validate every model in ``results`` on the training data.

    Models flagged ``scaled`` are scored on ``X_train_scaled``, the others
    on ``X_train``. Uses the module-level ``kf`` splitter and R² scoring.

    Returns ``{model_name: {'mean': ..., 'std': ...}}`` of the fold scores.

    NOTE(review): ``X_train_scaled`` was standardized on the full training
    set before folding, so validation rows contribute to the scaling
    statistics of the linear models.
    """
    def _fold_summary(entry):
        design = X_train_scaled if entry['scaled'] else X_train
        fold_r2 = cross_val_score(entry['model'], design, y_train,
                                  cv=kf, scoring='r2')
        return {'mean': fold_r2.mean(), 'std': fold_r2.std()}

    return {name: _fold_summary(entry) for name, entry in results.items()}


# Cross-validated R² (mean / std over folds) per model, one dict per spec.
cv_rc = run_cv(results_rc, X_train_rc, X_train_rc_s, y_train_rc)
cv_ri = run_cv(results_ri, X_train_ri, X_train_ri_s, y_train_ri)
cv_fu = run_cv(results_fu, X_train_fu, X_train_fu_s, y_train_fu)

# ============================================================================
# 8. RESULTS TABLES
# ============================================================================


def build_comparison(results, cv_results):
    """Tabulate per-model metrics, best Test R² first.

    ``results`` maps model name to its metrics dict; ``cv_results`` maps
    the same names to a dict with a ``'mean'`` CV score. The returned
    DataFrame has one row per model with the columns Model, Train R²,
    Test R², CV R², Test RMSE and Overfit (Train R² minus Test R²).
    """
    records = [
        {
            'Model': model_name,
            'Train R²': metrics['train_r2'],
            'Test R²': metrics['test_r2'],
            'CV R²': cv_results[model_name]['mean'],
            'Test RMSE': metrics['test_rmse'],
            'Overfit': metrics['train_r2'] - metrics['test_r2'],
        }
        for model_name, metrics in results.items()
    ]
    table = pd.DataFrame(records)
    return table.sort_values(by='Test R²', ascending=False)


# One comparison table per specification, built in spec order.
df_comp_rc, df_comp_ri, df_comp_fu = (
    build_comparison(spec_results, spec_cv)
    for spec_results, spec_cv in (
        (results_rc, cv_rc),
        (results_ri, cv_ri),
        (results_fu, cv_fu),
    )
)

# Print each table under a ruled heading.
for label, df_c in (("RC BASELINE", df_comp_rc),
                    ("RC + INTERACTIONS", df_comp_ri),
                    ("FULL STRUCTURAL", df_comp_fu)):
    print(f"\n{'=' * 70}\n{label}\n{'=' * 70}")
    print(df_c.to_string(index=False))

# ============================================================================
# 9. INTERACTION EFFECTS ANALYSIS
# ============================================================================

print(
    "\n" + "=" * 70,
    "INTERACTION EFFECTS (RIDGE COEFFICIENTS)",
    "=" * 70,
    sep="\n",
)

# Ridge was fitted on the standardized features, so these are per-standard-
# deviation coefficients. Feature -> coefficient lookup for the lines below.
coef_dict = dict(zip(features_rc_interactions,
                     results_ri['Ridge']['coefficients']))

# Same coefficients as a table, largest magnitude first.
ridge_coefs_int = pd.DataFrame({
    'Feature': features_rc_interactions,
    'Coefficient': results_ri['Ridge']['coefficients'],
}).sort_values(by='Coefficient', key=lambda col: col.abs(), ascending=False)

print(ridge_coefs_int.to_string(index=False))

print(f"\n   Base Oil effect:       {coef_dict.get('Oil_GDP_Pct', 0):.4f}")
print(f"   Base Nat Gas effect:   {coef_dict.get('Natural Gas_GDP_Pct', 0):.4f}")
print(f"   Base Metals effect:    {coef_dict.get('Metals_GDP_Pct', 0):.4f}")
print(f"   High Resource dummy:   {coef_dict.get('High_Resource', 0):.4f}")
hci_x = coef_dict.get('HCI_x_TotalResources', 0)
print(f"   HCI × Total Resources: {hci_x:.4f}")
# The sign of the interaction coefficient decides which reading is printed.
print(
    "   → Human capital returns amplified in resource-rich contexts"
    if hci_x > 0 else
    "   → Human capital returns diminished in resource-rich contexts"
)

# ============================================================================
# 10. FEATURE IMPORTANCE / COEFFICIENTS
# ============================================================================

tree_models = ['Random Forest', 'Gradient Boosting', 'XGBoost', 'LightGBM']
# The comparison tables are sorted by Test R² (descending), so the first
# tree-model row of each table is that spec's best tree model.
best_tree_rc, best_tree_ri, best_tree_fu = (
    comp.loc[comp['Model'].isin(tree_models), 'Model'].iloc[0]
    for comp in (df_comp_rc, df_comp_ri, df_comp_fu)
)

print(f"\n   Best tree (RC baseline): {best_tree_rc}")
print(f"   Best tree (RC + int.):   {best_tree_ri}")
print(f"   Best tree (Full):        {best_tree_fu}")

# Baseline Ridge coefficients, largest magnitude first.
ridge_coefs_rc = pd.DataFrame({
    'Feature': features_resource_curse,
    'Coefficient': results_rc['Ridge']['coefficients'],
}).sort_values(by='Coefficient', key=lambda col: col.abs(), ascending=False)

# ============================================================================
# 11. VISUALIZATIONS
# ============================================================================

print("\n7. Creating visualizations...")

# ---------- A. MODEL COMPARISON — grouped horizontal bar chart ---------------
# Each algorithm gets one row holding three bars (one per specification), so
# the same algorithm can be compared across specs at a glance.

algorithms = list(results_rc)  # the same seven model names exist in every spec
# Mean Test R² over the three specs, then ascending order for the y axis.
algo_avg = [
    (alg, np.mean([spec[alg]['test_r2']
                   for spec in (results_rc, results_ri, results_fu)]))
    for alg in algorithms
]
algorithms_sorted = [
    name for name, _ in sorted(algo_avg, key=lambda pair: pair[1])
]

spec_colors = STYLE['spec_colors']

fig_comp = go.Figure()

# One bar trace per specification, all sharing the sorted algorithm axis.
for spec_name, results_dict in (('RC Baseline', results_rc),
                                ('RC + Interactions', results_ri),
                                ('Full Structural', results_fu)):
    r2_vals = [results_dict[alg]['test_r2'] for alg in algorithms_sorted]
    fig_comp.add_trace(go.Bar(
        x=r2_vals,
        y=algorithms_sorted,
        orientation='h',
        name=spec_name,
        marker_color=spec_colors[spec_name],
        text=[f"{v:.3f}" for v in r2_vals],
        textposition='outside',
        textfont={'size': STYLE['tick_size'], 'family': STYLE['font_family']},
    ))

# NOTE(review): the x range is fixed to [0, 1]; a negative Test R² would be
# cut off at the axis.
fig_comp.update_layout(
    **base_layout(
        height=STYLE['chart_height_small'],
        margin={'l': 160, 'r': 80, 't': 30, 'b': 50},
    ),
    barmode='group',
    title=styled_title(
        'Model Performance Comparison',
        f'Test R² on held-out data ({TRAIN_END_YEAR + 1}–2019)',
    ),
    xaxis={**styled_axis('Test R²'), 'range': [0, 1]},
    yaxis={'tickfont': {'size': STYLE['tick_size'],
                        'family': STYLE['font_family']}},
    legend={
        'orientation': 'h', 'yanchor': 'top', 'y': -0.12,
        'xanchor': 'center', 'x': 0.5,
        'font': {'size': STYLE['legend_size'],
                 'family': STYLE['font_family']},
    },
)

save_html(fig_comp, 'ml_model_comparison.html')

# ---------- B. RIDGE COEFFICIENTS — RC + Interactions -----------------------
# Ridge was fitted on standardized inputs; bars run from the most negative
# coefficient to the most positive one.

coef_int = pd.DataFrame({
    'Feature': [feature_names_clean.get(f, f) for f in features_rc_interactions],
    'Coefficient': results_ri['Ridge']['coefficients'],
}).sort_values(by='Coefficient')

fig_coef = go.Figure(data=[go.Bar(
    x=coef_int['Coefficient'],
    y=coef_int['Feature'],
    orientation='h',
    # Positive coefficients use the "positive" colour, all others the "negative" one.
    marker_color=[STYLE['pos_color'] if c > 0 else STYLE['neg_color']
                  for c in coef_int['Coefficient']],
    text=[f"{x:.3f}" for x in coef_int['Coefficient']],
    textposition='outside',
    textfont={'size': STYLE['tick_size'], 'family': STYLE['font_family']},
)])
# Dashed reference line at zero.
fig_coef.add_vline(x=0, line_dash='dash', line_color=STYLE['zero_line_color'])

fig_coef.update_layout(
    **base_layout(margin=STYLE['margin_bar']),
    title=styled_title(
        'Ridge Coefficients — RC + Interactions',
        'Standardized coefficients, positive = higher predicted ECI',
    ),
    xaxis=styled_axis('Coefficient (standardized)'),
    yaxis={'tickfont': {'size': STYLE['tick_size'],
                        'family': STYLE['font_family']}},
)

save_html(fig_coef, 'ml_coefficients_interactions.html')

# ---------- C / C2 / C3. FEATURE IMPORTANCE (best tree model per spec) -------
# The three charts differ only in feature list, importances, bar colour and
# caption, so they share one builder.


def _importance_bar_chart(features, importances, bar_color, main_title,
                          sub_title, **layout_overrides):
    """Return ``(table, figure)`` for one horizontal importance bar chart.

    ``table`` has the display label and importance of each feature, sorted
    ascending by importance; ``figure`` plots that table. Extra keyword
    arguments are forwarded to base_layout().
    """
    table = pd.DataFrame({
        'Feature': [feature_names_clean.get(f, f) for f in features],
        'Importance': importances,
    }).sort_values(by='Importance', ascending=True)

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=table['Importance'],
        y=table['Feature'],
        orientation='h',
        marker_color=bar_color,
        text=[f"{x:.3f}" for x in table['Importance']],
        textposition='outside',
        textfont={'size': STYLE['tick_size'], 'family': STYLE['font_family']},
    ))
    fig.update_layout(
        **base_layout(margin=STYLE['margin_bar'], **layout_overrides),
        title=styled_title(main_title, sub_title),
        xaxis=styled_axis('Importance'),
        yaxis={'tickfont': {'size': STYLE['tick_size'],
                            'family': STYLE['font_family']}},
    )
    return table, fig


# C. RC + Interactions
importance_int, fig_imp = _importance_bar_chart(
    features_rc_interactions,
    results_ri[best_tree_ri]['feature_importance'],
    STYLE['bar_color'],
    f'Feature Importance — RC + Interactions',
    f'{best_tree_ri}, split-based importance',
)
save_html(fig_imp, 'ml_feature_importance_interactions.html')

# C2. RC Baseline
importance_rc, fig_imp_rc = _importance_bar_chart(
    features_resource_curse,
    results_rc[best_tree_rc]['feature_importance'],
    STYLE['spec_colors']['RC Baseline'],
    'Feature Importance — RC Baseline',
    f'{best_tree_rc}, split-based importance',
)
save_html(fig_imp_rc, 'ml_feature_importance_resource_curse.html')

# C3. Full Structural (height passed explicitly, equal to the default)
importance_fu, fig_imp_fu = _importance_bar_chart(
    features_full,
    results_fu[best_tree_fu]['feature_importance'],
    STYLE['spec_colors']['Full Structural'],
    'Feature Importance — Full Structural',
    f'{best_tree_fu}, split-based importance',
    height=STYLE['chart_height'],
)
save_html(fig_imp_fu, 'ml_feature_importance_full.html')

# ---------- E. PREDICTED VS ACTUAL — best interaction model ------------------

print("\n   Creating predicted vs actual scatter...")

# First row of the RC + Interactions table, i.e. the model with the highest
# Test R² of all seven (linear or tree) — not restricted to tree models.
best_ri_name = df_comp_ri['Model'].iloc[0]
best_ri_preds = results_ri[best_ri_name]['predictions']

# Span of the 45-degree reference line: the joint range of actual and
# predicted values, padded by 0.2 on each side.
eci_min = min(y_test_ri.min(), best_ri_preds.min()) - 0.2
eci_max = max(y_test_ri.max(), best_ri_preds.max()) + 0.2

fig_pva = go.Figure(data=[
    # One marker per test-set country-year.
    go.Scatter(
        x=y_test_ri.values,
        y=best_ri_preds,
        mode='markers',
        marker={
            'size': 6,
            'color': STYLE['bar_color_alt'],
            'line': {'width': 0.5, 'color': 'white'},
            'opacity': 0.7,
        },
        text=test_ri['Country Name'].values,
        customdata=test_ri['Year'].values,
        hovertemplate=(
            '<b>%{text}</b> (%{customdata})<br>'
            'Actual: %{x:.2f}<br>Predicted: %{y:.2f}<extra></extra>'
        ),
    ),
    # Perfect-prediction diagonal.
    go.Scatter(
        x=[eci_min, eci_max],
        y=[eci_min, eci_max],
        mode='lines',
        line={'dash': 'dash', 'color': STYLE['zero_line_color'], 'width': 1.5},
        showlegend=False,
    ),
])

fig_pva.update_layout(
    **base_layout(),
    title=styled_title(
        'Predicted vs Actual ECI',
        f'{best_ri_name} (RC + Interactions), test set {TRAIN_END_YEAR + 1}–2019',
    ),
    xaxis=styled_axis('Actual ECI'),
    yaxis=styled_axis('Predicted ECI'),
)

save_html(fig_pva, 'ml_predicted_vs_actual.html')

# ---------- F. RESIDUALS BY COUNTRY -----------------------------------------

print("   Creating residuals by country chart...")

# One row per test-set observation of the best RC + Interactions model.
residuals_df = pd.DataFrame({
    'Country Name': test_ri['Country Name'].values,
    'Country Code': test_ri['Country Code'].values,
    'Year': test_ri['Year'].values,
    'Actual': y_test_ri.values,
    'Predicted': best_ri_preds,
}).assign(Residual=lambda d: d['Actual'] - d['Predicted'])

# Per-country signed mean error, mean absolute error and row count,
# ordered by mean absolute error (largest first).
country_residuals = (
    residuals_df
    .groupby('Country Name')
    .agg(
        Mean_Residual=('Residual', 'mean'),
        MAE=('Residual', lambda s: s.abs().mean()),
        N=('Residual', 'count'),
    )
    .reset_index()
    .sort_values(by='MAE', ascending=False)
)

# Keep the 25 countries with the largest MAE, then order bars by signed error.
top_residuals = country_residuals.iloc[:25].sort_values(by='Mean_Residual')

fig_res = go.Figure(data=[go.Bar(
    x=top_residuals['Mean_Residual'],
    y=top_residuals['Country Name'],
    orientation='h',
    marker_color=[STYLE['pos_color'] if r > 0 else STYLE['neg_color']
                  for r in top_residuals['Mean_Residual']],
    text=[f"{x:+.2f}" for x in top_residuals['Mean_Residual']],
    textposition='outside',
    textfont={'size': STYLE['tick_size'], 'family': STYLE['font_family']},
)])
fig_res.add_vline(x=0, line_dash='dash', line_color=STYLE['zero_line_color'])

fig_res.update_layout(
    **base_layout(height=STYLE['chart_height'],
                  margin={'l': 160, 'r': 80, 't': 30, 'b': 50}),
    title=styled_title(
        'Mean Prediction Error by Country',
        f'Top 25 by absolute error — positive = model underpredicts ECI',
    ),
    xaxis=styled_axis('Mean Residual (Actual − Predicted)'),
    yaxis={'tickfont': {'size': STYLE['tick_size'],
                        'family': STYLE['font_family']}},
)

save_html(fig_res, 'ml_residuals_by_country.html')

# ---------- D. SHAP ANALYSIS (RC + Interactions, LightGBM) ------------------
# Optional step: any failure here is reported and the script carries on.

print("\n8. SHAP analysis...")

try:
    # Imported here so the rest of the script runs without shap installed.
    import shap

    # Always explains the LightGBM model of the RC + Interactions spec,
    # regardless of which model scored best.
    shap_model = results_ri['LightGBM']['model']
    explainer = shap.TreeExplainer(shap_model)
    # SHAP values are computed on the test-set features.
    shap_values = explainer.shap_values(X_test_ri.astype(float))

    # Mean absolute SHAP value per feature.
    # NOTE(review): assumes shap_values is a 2-D (rows, features) array so
    # axis=0 averages over observations — TODO confirm for the installed
    # shap version.
    shap_df = pd.DataFrame({
        'Feature': [feature_names_clean.get(f, f) for f in features_rc_interactions],
        'Mean_SHAP': np.abs(shap_values).mean(axis=0),
    }).sort_values('Mean_SHAP', ascending=True)

    fig_shap = go.Figure()
    fig_shap.add_trace(go.Bar(
        y=shap_df['Feature'],
        x=shap_df['Mean_SHAP'],
        orientation='h',
        marker_color=STYLE['bar_color_alt'],
        text=[f"{x:.3f}" for x in shap_df['Mean_SHAP']],
        textposition='outside',
        textfont=dict(size=STYLE['tick_size'], family=STYLE['font_family']),
    ))

    fig_shap.update_layout(
        **base_layout(margin=STYLE['margin_bar']),
        title=styled_title(
            'SHAP Feature Importance — RC + Interactions',
            'Mean |SHAP value| on test set, LightGBM',
        ),
        xaxis=styled_axis('Mean |SHAP Value|'),
        yaxis=dict(tickfont=dict(size=STYLE['tick_size'],
                                 family=STYLE['font_family'])),
    )

    save_html(fig_shap, 'ml_shap_importance.html')

except ImportError:
    print("   ⚠ SHAP not installed.")
except Exception as e:
    # Deliberately broad: the SHAP chart is best-effort, so any other error
    # is printed instead of aborting the remaining export steps.
    print(f"   ⚠ SHAP error: {e}")

# ============================================================================
# 12. SAVE RESULTS
# ============================================================================

print("\n9. Saving results...")

# Every result table goes to output_dir as CSV without the index column;
# files are written in the order listed.
for _frame, _csv_name in (
    (df_comp_rc, 'ml_comparison_resource_curse.csv'),
    (df_comp_ri, 'ml_comparison_rc_interactions.csv'),
    (df_comp_fu, 'ml_comparison_full.csv'),
    (ridge_coefs_rc, 'ml_ridge_coefficients_rc.csv'),
    (ridge_coefs_int, 'ml_ridge_coefficients_interactions.csv'),
    (importance_int, 'ml_feature_importance_interactions.csv'),
    (importance_rc, 'ml_feature_importance_rc.csv'),
    (importance_fu, 'ml_feature_importance_full.csv'),
    (residuals_df, 'ml_residuals.csv'),
):
    _frame.to_csv(os.path.join(output_dir, _csv_name), index=False)

print("   ✓ All results saved")

# ============================================================================
# COMPREHENSIVE RESULTS SUMMARY
# (copy-paste this output to update the HTML page)
# ============================================================================

print("\n" + "=" * 70, "ML COMPREHENSIVE RESULTS SUMMARY", "=" * 70, sep="\n")


def _print_section(title):
    """Print *title* framed above and below by a 50-character em-dash rule."""
    rule = '—' * 50
    print(f"\n{rule}")
    print(title)
    print(rule)


# --- Data overview ---
_print_section("DATA")
print(f"  Train period: ≤{TRAIN_END_YEAR}")
print(f"  Test period: {TRAIN_END_YEAR + 1}–2019")
print(f"  High resource countries: {n_high_res}")
print(f"  Other countries: {n_other}")
for _spec, _train, _test in (
    ("RC Baseline", X_train_rc, X_test_rc),
    ("RC + Interactions", X_train_ri, X_test_ri),
    ("Full Structural", X_train_fu, X_test_fu),
):
    print(f"  {_spec} — Train: {len(_train)}, Test: {len(_test)}")

# --- Full model comparison tables ---
for label, df_c in [("RC BASELINE", df_comp_rc),
                     ("RC + INTERACTIONS", df_comp_ri),
                     ("FULL STRUCTURAL", df_comp_fu)]:
    _print_section(label)
    for _, row in df_c.iterrows():
        metrics = "  ".join([
            f"Train R²={row['Train R²']:.3f}",
            f"Test R²={row['Test R²']:.3f}",
            f"CV R²={row['CV R²']:.3f}",
            f"RMSE={row['Test RMSE']:.3f}",
            f"Overfit={row['Overfit']:.3f}",
        ])
        print(f"  {row['Model']:22s}  {metrics}")

# --- Best models (first row of each comparison table) ---
_print_section("BEST MODELS")
_top_rc = df_comp_rc.iloc[0]
_top_ri = df_comp_ri.iloc[0]
_top_fu = df_comp_fu.iloc[0]
baseline_r2 = _top_rc['Test R²']
interaction_r2 = _top_ri['Test R²']
full_r2 = _top_fu['Test R²']
improvement = interaction_r2 - baseline_r2

print(f"  RC Baseline:      {_top_rc['Model']:22s}  Test R²={baseline_r2:.3f}")
print(f"  RC + Interactions: {_top_ri['Model']:22s}  Test R²={interaction_r2:.3f}")
print(f"  Full Structural:   {_top_fu['Model']:22s}  Test R²={full_r2:.3f}")
print(f"  Interaction improvement over baseline: {improvement:+.3f}")

# --- Ridge coefficients (baseline, then interactions) ---
for _heading, _coef_frame in (
    ("RIDGE COEFFICIENTS — RC BASELINE", ridge_coefs_rc),
    ("RIDGE COEFFICIENTS — RC + INTERACTIONS", ridge_coefs_int),
):
    _print_section(_heading)
    for _, row in _coef_frame.iterrows():
        _raw = row['Feature']
        # Fall back to the raw column name when no display name is registered.
        display = feature_names_clean.get(_raw, _raw)
        print(f"  {display:40s}  {row['Coefficient']:+.4f}")

# --- Interaction decomposition ---
_print_section("INTERACTION DECOMPOSITION")
coef_dict = {
    feat: coef
    for feat, coef in zip(features_rc_interactions,
                          results_ri['Ridge']['coefficients'])
}

_resource_labels = {
    'Oil_GDP_Pct': 'Oil',
    'Natural Gas_GDP_Pct': 'Natural Gas',
    'Coal_GDP_Pct': 'Coal',
    'Metals_GDP_Pct': 'Metals',
}
for res_feat, res_label in _resource_labels.items():
    base = coef_dict.get(res_feat, 0)
    print(f"\n  {res_label}:")
    # Both lines report the same base coefficient.
    print(f"    Base effect (all countries):    {base:+.4f}")
    print(f"    Total for high-resource:        {base:+.4f}")

print(f"\n  High Resource dummy (intercept):   {coef_dict.get('High_Resource', 0):+.4f}")
hci_x = coef_dict.get('HCI_x_TotalResources', 0)
print(f"  HCI × Total Resources:             {hci_x:+.4f}")
_effect = "amplified" if hci_x > 0 else "diminished"
print(f"  → Human capital returns {_effect} in resource-rich contexts")

# --- Feature importance (all 3 specs) ---
for _title, _tree, _imp in (
    ("RC BASELINE", best_tree_rc, importance_rc),
    ("RC + INTERACTIONS", best_tree_ri, importance_int),
    ("FULL STRUCTURAL", best_tree_fu, importance_fu),
):
    _print_section(f"FEATURE IMPORTANCE — {_title} ({_tree})")
    _ranked = _imp.sort_values('Importance', ascending=False)
    for _, row in _ranked.iterrows():
        print(f"  {row['Feature']:40s}  {row['Importance']:.3f}")

# --- Worst predictions (first 15 rows of country_residuals) ---
_print_section("LARGEST PREDICTION ERRORS (top 15 by MAE)")
for _, row in country_residuals.head(15).iterrows():
    # A positive mean residual is labelled as underprediction.
    if row['Mean_Residual'] > 0:
        direction = "underpredicted"
    else:
        direction = "overpredicted"
    print(f"  {row['Country Name']:30s}  mean error={row['Mean_Residual']:+.2f}  "
          f"MAE={row['MAE']:.2f}  ({direction})")

print(f"\n{'=' * 70}", "END OF ML SUMMARY", "=" * 70, sep="\n")

ML ANALYSIS: PREDICTING ECONOMIC COMPLEXITY
Panel Data with Temporal Split (Train: ≤2013, Test: 2014+)
Including High Resource Country Interactions

1. Loading data...
   Master data: 3150 rows, 126 countries
   Years: 1995 – 2019

2. Creating high resource country dummy and interactions...
   High resource countries: 54
   Other countries: 72
   Merged data: 3150 rows

3. Defining feature sets...

Checking feature availability...
   ✓ All features found

   RC Baseline: 8 features
   RC + Interactions: 10 features
   Full Structural: 17 features

4. Preparing train/test split...

   RC BASELINE: Train 2394, Test 756

   RC + INTERACTIONS: Train 2394, Test 756

   FULL STRUCTURAL: Train 2394, Test 756

5. Training models...
   Training RC Baseline...
   Training RC + Interactions...
   Training Full Structural...

6. Cross-validation (5-fold)...

RC BASELINE
            Model  Train R²  Test R²    CV R²  Test RMSE  Overfit
          XGBoost  0.964305 0.803927 0.925624   0.455613 0.1603

In [10]:
"""
PANEL FIXED EFFECTS REGRESSIONS - TWO APPROACHES
==================================================
Approach A: Full sample (126 countries), PanelOLS, 4 resource categories / GDP
            Country + year FE, entity-clustered SE
Approach B: High-resource subsample (54 countries), OLS + dummies, log-log spec
            SE clustered by country

Dependent variable:
  - Approach A: Economic Complexity Index (levels)
  - Approach B: log(ECI + shift) where shift ensures positivity

Specifications:
  A1: RC Contemporaneous
  A2: RC Lag-1
  A3: RC + Interactions Lag-1
  A4: Full Structural Lag-1
  A5: Dynamic Lag-1 (includes lagged ECI)

  B1: Log-log (country FE)
  B2: + Interactions (HCI x Production, GFCF x Production)
  B3: + Lagged ECI
  B4: Two-way FE (levels, not log-log)
  B5: Fully lagged L1 (log-log)
  B6: Fully lagged L1 (two-way FE, levels)
"""

import pandas as pd
import numpy as np
import statsmodels.api as sm
from linearmodels.panel import PanelOLS
import warnings
import os

warnings.filterwarnings('ignore')

# ============================================================================
# CONFIGURATION
# ============================================================================

# Both input CSVs live in the same directory.
_MASTER_DIR = "/Users/leoss/Desktop/GitHub/Capstone/MASTER"
input_file = f"{_MASTER_DIR}/Master.csv"
production_file = f"{_MASTER_DIR}/NaturalResource.csv"
output_dir = (
    "/Users/leoss/Desktop/Portfolio/Website-/"
    "capstone_visualizations/individual_plots/ml"
)

os.makedirs(output_dir, exist_ok=True)

# Country codes flagged as high-resource; used to build the High_Resource dummy.
HIGH_RESOURCE_COUNTRIES = (
    "AGO ARE AZE BFA BHR BOL CHL CIV CMR "
    "COD COG DZA ECU EGY ETH GAB GHA GIN "
    "GNQ IDN IRN IRQ KAZ KEN KWT LAO LBR "
    "LBY MDG MLI MMR MNG MOZ MWI MYS NER "
    "NGA OMN PNG QAT RUS RWA SAU TCD TGO "
    "TTO TZA UGA UZB VEN VNM YEM ZMB ZWE"
).split()

print("=" * 70, "PANEL FIXED EFFECTS REGRESSIONS - TWO APPROACHES", "=" * 70,
      sep="\n")

# ============================================================================
# 1. LOAD AND PREPARE DATA
# ============================================================================

print("\n1. Loading data...")

df_master, df_prod = (
    pd.read_csv(csv_path) for csv_path in (input_file, production_file)
)


def categorize_resource(resource):
    """Map a raw resource name onto one of four reporting categories.

    'Oil', 'Natural Gas' and 'Coal' keep their own category; any other
    value is grouped under 'Metals'.
    """
    fuel_categories = ('Oil', 'Natural Gas', 'Coal')
    return next(
        (fuel for fuel in fuel_categories if resource == fuel),
        'Metals',
    )


# Collapse every raw resource into Oil / Natural Gas / Coal / Metals.
df_prod['Resource_Category'] = df_prod['Resource'].apply(categorize_resource)

# Total production value per country-year-category (long format).
prod_agg = df_prod.groupby(
    ['Country Name', 'Year', 'Resource_Category']
)['Production_TotalValue'].sum().reset_index()

# Reshape to one row per country-year with one column per category;
# category/country-year combinations with no production become 0.
prod_wide = prod_agg.pivot_table(
    index=['Country Name', 'Year'],
    columns='Resource_Category',
    values='Production_TotalValue',
    fill_value=0,
).reset_index()

# Left merge keeps every master row, including country-years without any
# production record.
df = df_master.merge(prod_wide, on=['Country Name', 'Year'], how='left')

# Master rows that found no production match get NaN from the merge;
# treat those as zero production.
for col in ['Oil', 'Natural Gas', 'Coal', 'Metals']:
    if col in df.columns:
        df[col] = df[col].fillna(0)

# --- GDP-normalized resource shares (Approach A) ---
# Total GDP is reconstructed as GDP per capita times population.
df['GDP_total'] = (
    df['GDP per capita (constant prices, PPP)'] * df['Population']
)
for res in ['Oil', 'Natural Gas', 'Coal', 'Metals']:
    if res in df.columns:
        df[f'{res}_GDP_Pct'] = (df[res] / df['GDP_total']) * 100
        # Division by a zero GDP yields +/-inf; turn those into missing values.
        df[f'{res}_GDP_Pct'] = df[f'{res}_GDP_Pct'].replace(
            [np.inf, -np.inf], np.nan
        )

# Row-wise sum of the four shares.
# NOTE(review): unlike the loops above, this selection is not guarded by an
# `in df.columns` check, so all four *_GDP_Pct columns must exist here.
df['Total_Resources_GDP_Pct'] = df[[
    'Oil_GDP_Pct', 'Natural Gas_GDP_Pct', 'Coal_GDP_Pct', 'Metals_GDP_Pct'
]].sum(axis=1)

# --- High resource dummy ---
# 1 when the country code is in HIGH_RESOURCE_COUNTRIES, else 0.
df['High_Resource'] = df['Country Code'].isin(HIGH_RESOURCE_COUNTRIES).astype(int)

# --- Approach B: per-capita production value ---
# 'Total_Production_Value' is not created in this script; presumably it comes
# from the master file -- confirm against Master.csv.
df['Total_Production_Value_Per_Capita'] = df['Total_Production_Value'] / df['Population']

n_full = df['Country Code'].nunique()
n_high = df[df['High_Resource'] == 1]['Country Code'].nunique()
print(f"   Full sample: {len(df)} rows, {n_full} countries")
print(f"   High-resource subsample: {n_high} countries")

# --- Log transforms for Approach B ---
# When the sample minimum ECI is <= 0, shift by |min| + 0.1 so the smallest
# shifted value is 0.1 and the log is defined for every row.
eci_min = df['Economic Complexity Index'].min()
ECI_SHIFT = abs(eci_min) + 0.1 if eci_min <= 0 else 0
print(f"   ECI shift for log: {ECI_SHIFT:.3f}")

df['log_ECI'] = np.log(df['Economic Complexity Index'] + ECI_SHIFT)
# Inputs are floored at 1e-6 before the log so zero or negative values do not
# produce -inf or NaN.
df['log_HCI'] = np.log(df['Human capital index'].clip(lower=1e-6))
df['log_GFCF'] = np.log(
    df['Gross fixed capital formation, all, Constant prices, Percent of GDP'].clip(lower=1e-6)
)
df['log_Production_Value'] = np.log(
    df['Total_Production_Value_Per_Capita'].clip(lower=1e-6)
)

# --- Interaction terms (Approach A: HCI x total resources) ---
df['HCI_x_TotalResources'] = (
    df['Human capital index'] * df['Total_Resources_GDP_Pct']
)

# --- Interaction terms (Approach B: raw level products) ---
# NOTE: these are plain products of the level variables; they are neither
# mean-centered nor in logs. The log-based interactions are built separately
# from the lagged log columns.
df['HCI_x_ResourceProduction'] = (
    df['Human capital index'] * df['Total_Production_Value_Per_Capita']
)
df['GFCF_x_ResourceProduction'] = (
    df['Gross fixed capital formation, all, Constant prices, Percent of GDP']
    * df['Total_Production_Value_Per_Capita']
)

# ============================================================================
# 2. CREATE LAGGED VARIABLES
# ============================================================================

print("\n2. Creating lagged variables...")

# shift(1) lags by row position within a country, so rows must be in year
# order before lagging.
df = df.sort_values(['Country Code', 'Year'])

lag_vars = [
    # Outcome (levels and log)
    'Economic Complexity Index',
    'log_ECI',
    # Resource shares of GDP
    'Oil_GDP_Pct',
    'Natural Gas_GDP_Pct',
    'Coal_GDP_Pct',
    'Metals_GDP_Pct',
    # Human capital and institutions
    'Human capital index',
    'Rule of law index',
    'Property rights',
    'HCI_x_TotalResources',
    'Total_Resources_GDP_Pct',
    # Structural controls
    'Manufacturing',
    'Agriculture',
    'Trade (% of GDP)',
    'Gross fixed capital formation, all, Constant prices, Percent of GDP',
    'Access to electricity (% of population)',
    'Urban population (% of total population)',
    'Political stability \u2014 estimate',
    'Domestic credit to private sector (% of GDP)',
    'Inflation, consumer prices (annual %)',
    # Approach B variables
    'log_HCI',
    'log_GFCF',
    'log_Production_Value',
    'Total_Production_Value_Per_Capita',
    'HCI_x_ResourceProduction',
    'GFCF_x_ResourceProduction',
    'Political corruption index',
]

for var in lag_vars:
    if var not in df.columns:
        continue  # variables absent from the merged data are skipped silently
    df[f'{var}_L1'] = df.groupby('Country Code')[var].shift(1)

# Short-named lagged outcomes used by the dynamic specifications.
for _alias, _source in (('ECI_L1', 'Economic Complexity Index'),
                        ('log_ECI_L1', 'log_ECI')):
    df[_alias] = df.groupby('Country Code')[_source].shift(1)

# Products of lagged logs for Approach B.
for _left in ('log_HCI', 'log_GFCF'):
    df[f'{_left}_x_log_Prod_L1'] = (
        df[f'{_left}_L1'] * df['log_Production_Value_L1']
    )

print("   Done.")

# ============================================================================
# 3. ESTIMATION FUNCTIONS
# ============================================================================


def run_panel_ols(panel_df, target, features, spec_name,
                  entity_effects=True, time_effects=True):
    """Fit a PanelOLS model (Approach A) with entity-clustered standard errors.

    Rows with a missing value in the target or any feature are dropped.
    Returns the fitted result, or ``None`` (after printing a warning) when a
    requested column is absent or fewer than 50 complete rows remain.
    """
    wanted = [target] + features
    absent = [name for name in wanted if name not in panel_df.columns]
    if absent:
        print(f"\n   WARNING {spec_name}: missing {absent}")
        return None

    sample = panel_df[wanted].dropna()
    n_rows = len(sample)
    if n_rows < 50:
        print(f"\n   WARNING {spec_name}: only {n_rows} obs")
        return None

    fitted = PanelOLS(
        sample[target],
        sample[features],
        entity_effects=entity_effects,
        time_effects=time_effects,
        check_rank=False,
    ).fit(cov_type='clustered', cluster_entity=True)

    # The first index level is counted as the country identifier.
    country_count = sample.index.get_level_values(0).nunique()
    print(f"   \u2713 {spec_name}: N={int(fitted.nobs)}, countries={country_count}, "
          f"R\u00b2(within)={fitted.rsquared_within:.3f}")
    return fitted


def run_ols_dummies(df_sub, target, features, spec_name,
                    country_fe=True, year_fe=True,
                    cluster_by='Country Code'):
    """Fit OLS with explicit fixed-effect dummies (Approach B), clustered SE.

    Country and/or year dummies (first level dropped) are appended to the
    feature matrix, a constant is added, and standard errors are clustered on
    ``cluster_by``. Returns the fitted result, or ``None`` (after printing a
    warning) when a needed column is absent or fewer than 50 complete rows
    remain.
    """
    needed = [target, 'Country Code', 'Year'] + features
    if cluster_by not in needed:
        needed.append(cluster_by)

    absent = [name for name in needed if name not in df_sub.columns]
    if absent:
        print(f"\n   WARNING {spec_name}: missing {absent}")
        return None

    sample = df_sub[needed].dropna().copy()
    n_rows = len(sample)
    if n_rows < 50:
        print(f"\n   WARNING {spec_name}: only {n_rows} obs")
        return None

    # (source column, dummy prefix) for each requested fixed effect, country first.
    fe_blocks = []
    if country_fe:
        fe_blocks.append(('Country Code', 'C'))
    if year_fe:
        fe_blocks.append(('Year', 'Y'))

    design = sample[features].copy()
    for column, prefix in fe_blocks:
        dummies = pd.get_dummies(
            sample[column], prefix=prefix, drop_first=True
        ).astype(float)
        design = design.join(dummies)
    design = sm.add_constant(design)

    fitted = sm.OLS(sample[target], design).fit(
        cov_type='cluster',
        cov_kwds={'groups': sample[cluster_by]},
    )

    country_count = sample['Country Code'].nunique()
    print(f"   \u2713 {spec_name}: N={int(fitted.nobs)}, countries={country_count}, "
          f"R\u00b2={fitted.rsquared:.3f}, adj R\u00b2={fitted.rsquared_adj:.3f}")
    return fitted


def print_coefs(res, features, use_panel=True):
    """Print one line per estimated feature, largest absolute coefficient first.

    ``use_panel`` selects the attribute names for standard errors and
    t-statistics: ``std_errors``/``tstats`` when true, ``bse``/``tvalues``
    otherwise. Features not present in ``res.params`` are skipped.
    """
    if use_panel:
        se_vals, t_vals = res.std_errors, res.tstats
    else:
        se_vals, t_vals = res.bse, res.tvalues

    estimated = [
        (name, res.params[name], se_vals[name], t_vals[name], res.pvalues[name])
        for name in features
        if name in res.params.index
    ]
    by_magnitude = sorted(estimated, key=lambda entry: abs(entry[1]), reverse=True)

    # Significance thresholds, checked from strictest to loosest.
    star_levels = ((0.01, '***'), (0.05, '**'), (0.1, '*'))
    for raw, coef, se, t, p in by_magnitude:
        stars = next((mark for cutoff, mark in star_levels if p < cutoff), '')
        print(f"  {raw:55s} {coef:+.4f}{stars:4s}  SE={se:.4f}  t={t:+.2f}  p={p:.3f}")


# ============================================================================
# 4. APPROACH A: FULL SAMPLE, PanelOLS
# ============================================================================

print(
    "\n" + "=" * 70,
    "APPROACH A: FULL SAMPLE, PanelOLS",
    "Country + year FE, clustered SE at entity level",
    "=" * 70,
    sep="\n",
)

# PanelOLS is given a (country, year) MultiIndex.
panel = df.set_index(['Country Code', 'Year'])

target_a = 'Economic Complexity Index'

# Building blocks shared by the lagged specifications.
_resources_l1 = [
    'Oil_GDP_Pct_L1', 'Natural Gas_GDP_Pct_L1', 'Coal_GDP_Pct_L1',
    'Metals_GDP_Pct_L1',
]
_institutions_l1 = ['Rule of law index_L1', 'Property rights_L1']

specs_a = {
    'A1: RC Contemp.': [
        'Oil_GDP_Pct', 'Natural Gas_GDP_Pct', 'Coal_GDP_Pct', 'Metals_GDP_Pct',
        'Human capital index', 'Rule of law index', 'Property rights',
    ],
    'A2: RC Lag-1': (
        _resources_l1 + ['Human capital index_L1'] + _institutions_l1
    ),
    'A3: RC + Interact L1': (
        _resources_l1
        + ['Human capital index_L1', 'HCI_x_TotalResources_L1']
        + _institutions_l1
    ),
    'A4: Full Structural L1': _resources_l1 + [
        'Manufacturing_L1', 'Agriculture_L1', 'Trade (% of GDP)_L1',
        'Gross fixed capital formation, all, Constant prices, Percent of GDP_L1',
        'Human capital index_L1',
        'Access to electricity (% of population)_L1',
        'Urban population (% of total population)_L1',
        'Rule of law index_L1', 'Property rights_L1',
        'Political stability \u2014 estimate_L1',
        'Domestic credit to private sector (% of GDP)_L1',
        'Inflation, consumer prices (annual %)_L1',
    ],
    'A5: Dynamic L1': (
        ['ECI_L1'] + _resources_l1 + ['Human capital index_L1'] + _institutions_l1
    ),
}

# Fit every specification first; specs that could not be estimated are left out.
results_a = {}
for name, feats in specs_a.items():
    res = run_panel_ols(panel, target_a, feats, name)
    if res is None:
        continue
    results_a[name] = (res, feats)

# Then print the coefficient table and fit statistics per specification.
_rule_a = '-' * 60
for name, (res, feats) in results_a.items():
    print(f"\n{_rule_a}\n{name}\n{_rule_a}")
    print_coefs(res, feats, use_panel=True)
    _fit_stats = "  ".join([
        f"R\u00b2(within)={res.rsquared_within:.3f}",
        f"R\u00b2(between)={res.rsquared_between:.3f}",
        f"R\u00b2(overall)={res.rsquared_overall:.3f}",
    ])
    print(f"\n  {_fit_stats}")

# ============================================================================
# 5. APPROACH B: HIGH-RESOURCE SUBSAMPLE, OLS + dummies, log-log
# ============================================================================

# Build the subsample first so the banner can report its actual size instead
# of a hard-coded country count.
df_hr = df[df['High_Resource'] == 1].copy()
n_hr_countries = df_hr['Country Code'].nunique()

print("\n" + "=" * 70)
print(f"APPROACH B: HIGH-RESOURCE SUBSAMPLE ({n_hr_countries} countries)")
print("OLS with manual country/year dummies, log-log spec")
print("SE clustered by country")
print("=" * 70)

target_b = 'log_ECI'

# Regressor groups shared across specifications.
_loglog_core = [
    'log_HCI', 'log_GFCF', 'log_Production_Value',
    'Political stability \u2014 estimate', 'Rule of law index',
    'Trade (% of GDP)',
]
_level_interactions = ['HCI_x_ResourceProduction', 'GFCF_x_ResourceProduction']
_levels_full = [
    'Human capital index', 'Total_Production_Value_Per_Capita',
    'Gross fixed capital formation, all, Constant prices, Percent of GDP',
    'Political corruption index', 'Political stability \u2014 estimate',
    'Rule of law index', 'Property rights', 'Trade (% of GDP)',
] + _level_interactions

specs_b = {
    'B1: Log-log (country FE)': {
        'features': list(_loglog_core),
        'country_fe': True, 'year_fe': False,
    },
    'B2: + Interactions': {
        'features': _loglog_core + _level_interactions,
        'country_fe': True, 'year_fe': False,
    },
    'B3: + Lagged ECI': {
        'features': _loglog_core + _level_interactions + ['log_ECI_L1'],
        'country_fe': True, 'year_fe': False,
    },
    'B4: Two-way FE (levels)': {
        'features': list(_levels_full),
        'country_fe': True, 'year_fe': True,
    },
    'B5: Fully lagged L1 (log-log)': {
        'features': [f'{col}_L1' for col in _loglog_core] + [
            'log_HCI_x_log_Prod_L1', 'log_GFCF_x_log_Prod_L1',
        ],
        'country_fe': True, 'year_fe': False,
    },
    'B6: Fully lagged L1 (two-way FE, levels)': {
        # Same regressors as B4, each lagged one period.
        'features': [f'{col}_L1' for col in _levels_full],
        'country_fe': True, 'year_fe': True,
    },
}

results_b = {}
for name, spec in specs_b.items():
    # Every Approach B specification uses log_ECI as the dependent variable.
    # In B4 and B6, "levels" refers to the regressors only.
    tgt = target_b
    res = run_ols_dummies(
        df_hr, tgt, spec['features'], name,
        country_fe=spec['country_fe'], year_fe=spec['year_fe'],
    )
    if res is not None:
        results_b[name] = (res, spec['features'])

for name, (res, feats) in results_b.items():
    print(f"\n{'-' * 60}\n{name}\n{'-' * 60}")
    print_coefs(res, feats, use_panel=False)
    print(f"\n  R\u00b2={res.rsquared:.3f}  adj R\u00b2={res.rsquared_adj:.3f}  N={int(res.nobs)}")

# ============================================================================
# 6. FORMATTED COMPARISON TABLES
# ============================================================================


def print_formatted_table(results_dict, label, use_panel=True):
    """Print a side-by-side regression table across specifications.

    Each feature gets two lines: coefficients with significance stars, then
    the standard errors in parentheses. Fit statistics and observation counts
    follow below a rule.

    Parameters
    ----------
    results_dict : dict
        Maps a specification name (e.g. ``"A1: RC Contemp."``) to a
        ``(result, features)`` tuple. The text after the first ``": "`` is
        the column header; a name without ``": "`` is used whole.
    label : str
        Heading printed above the table.
    use_panel : bool
        True reads ``std_errors`` and ``rsquared_within`` from each result;
        False reads ``bse``, ``rsquared`` and ``rsquared_adj``.

    Column widths default to 55 (variable names) and 20 (specifications) and
    grow when a name is longer, so header, rows and footer stay aligned.
    """
    print(f"\n{'=' * 70}")
    print(f"{label} - FORMATTED TABLE")
    print(f"{'=' * 70}")

    # Unique features across all specs, in first-seen order.
    all_feats = list(dict.fromkeys(
        feat for _, feats in results_dict.values() for feat in feats
    ))

    spec_names = list(results_dict)
    short_names = [name.split(': ', 1)[-1] for name in spec_names]
    fitted = [res for res, _ in results_dict.values()]

    # Widen the columns when a label would otherwise overflow and shift the row.
    name_w = max([55] + [len(feat) + 1 for feat in all_feats])
    col_w = max([20] + [len(short) + 2 for short in short_names])
    coef_w = col_w - 4  # the last 4 characters of each cell hold the stars
    blank_cell = ' ' * col_w
    se_attr = 'std_errors' if use_panel else 'bse'

    # Header
    header = f"{'Variable':{name_w}s}" + "".join(
        f"{short:>{col_w}s}" for short in short_names
    )
    print(header)
    print("-" * len(header))

    # Rows: one coefficient line and one standard-error line per feature.
    for feat in all_feats:
        coef_cells = []
        se_cells = []
        for res in fitted:
            if feat not in res.params.index:
                coef_cells.append(blank_cell)
                se_cells.append(blank_cell)
                continue
            coef = res.params[feat]
            p = res.pvalues[feat]
            se_text = f"({getattr(res, se_attr)[feat]:.4f})"
            stars = '***' if p < 0.01 else '**' if p < 0.05 else '*' if p < 0.1 else ''
            coef_cells.append(f"{coef:>{coef_w}.4f}{stars:4s}")
            se_cells.append(f"{se_text:>{col_w}s}")
        print(f"{feat:{name_w}s}" + "".join(coef_cells))
        print(' ' * name_w + "".join(se_cells))

    # Footer: fit statistics
    print("-" * len(header))
    if use_panel:
        print(f"{'R2 (within)':{name_w}s}" + "".join(
            f"{res.rsquared_within:>{col_w}.3f}" for res in fitted
        ))
    else:
        print(f"{'R2':{name_w}s}" + "".join(
            f"{res.rsquared:>{col_w}.3f}" for res in fitted
        ))
        print(f"{'Adj R2':{name_w}s}" + "".join(
            f"{res.rsquared_adj:>{col_w}.3f}" for res in fitted
        ))

    print(f"{'Observations':{name_w}s}" + "".join(
        f"{int(res.nobs):>{col_w}d}" for res in fitted
    ))


# Render both tables. use_panel selects which result attributes are read:
# std_errors / rsquared_within for Approach A, bse / rsquared for Approach B.
print_formatted_table(results_a, "APPROACH A", use_panel=True)
print_formatted_table(results_b, "APPROACH B", use_panel=False)

# ============================================================================
# 7. SAVE TABLES
# ============================================================================

# Section banner: blank line, rule, title, rule.
print("\n" + "=" * 70, "SAVING", "=" * 70, sep="\n")


def save_table(results_dict, filename, use_panel=True):
    """Write a wide coefficient table for all specifications to a CSV.

    The output has one row per feature (first-seen order across specs) and,
    for each specification that estimated the feature, three columns:
    ``<spec>_coef``, ``<spec>_se`` and ``<spec>_pval``. The file is written to
    ``output_dir`` and a confirmation line is printed.
    """
    se_attr = 'std_errors' if use_panel else 'bse'

    ordered_feats = list(dict.fromkeys(
        feat for _, feats in results_dict.values() for feat in feats
    ))

    records = []
    for feat in ordered_feats:
        record = {'Variable': feat}
        for spec, (res, _) in results_dict.items():
            if feat not in res.params.index:
                continue  # this spec did not estimate the feature
            record.update({
                f'{spec}_coef': res.params[feat],
                f'{spec}_se': getattr(res, se_attr)[feat],
                f'{spec}_pval': res.pvalues[feat],
            })
        records.append(record)

    destination = os.path.join(output_dir, filename)
    pd.DataFrame(records).to_csv(destination, index=False)
    print(f"   \u2713 {filename}")


# Export both regression tables to output_dir. Approach B results expose
# their standard errors as `bse`, hence use_panel=False.
save_table(results_a, 'fe_approach_a_table.csv', use_panel=True)
save_table(results_b, 'fe_approach_b_table.csv', use_panel=False)

# Closing banner for the regression cell.
print("\n" + "=" * 70)
print("DONE")
print("=" * 70)

PANEL FIXED EFFECTS REGRESSIONS - TWO APPROACHES

1. Loading data...
   Full sample: 3150 rows, 126 countries
   High-resource subsample: 54 countries
   ECI shift for log: 2.990

2. Creating lagged variables...
   Done.

APPROACH A: FULL SAMPLE, PanelOLS
Country + year FE, clustered SE at entity level
   ✓ A1: RC Contemp.: N=3150, countries=126, R²(within)=0.008
   ✓ A2: RC Lag-1: N=3024, countries=126, R²(within)=0.003
   ✓ A3: RC + Interact L1: N=3024, countries=126, R²(within)=0.001
   ✓ A4: Full Structural L1: N=3024, countries=126, R²(within)=0.045
   ✓ A5: Dynamic L1: N=3024, countries=126, R²(within)=0.314

------------------------------------------------------------
A1: RC Contemp.
------------------------------------------------------------
  Human capital index                                     +0.2216*     SE=0.1253  t=+1.77  p=0.077
  Property rights                                         -0.1472      SE=0.2890  t=-0.51  p=0.611
  Coal_GDP_Pct                           

In [11]:
# ============================================================================
# VISUALIZATIONS (append to regression script)
# ============================================================================

import plotly.graph_objects as go

FONT = "Helvetica Neue, Helvetica, Arial, sans-serif"
# Font sizes: T_SZ is used for the chart title, TK_SZ for the global font and
# L_SZ for the legend. A_SZ is not used by base_layout (presumably axis titles).
T_SZ = 18
A_SZ = 13
TK_SZ = 11
L_SZ = 11


def base_layout(title, height=600, width=950):
    """Return the shared Plotly layout settings as a plain dict.

    ``title`` is centred (x=0.5); ``height`` and ``width`` are passed through
    unchanged.
    """
    title_block = {
        'text': title,
        'font': {'size': T_SZ, 'family': FONT, 'color': '#1a1a2e'},
        'x': 0.5,
    }
    legend_block = {
        'font': {'size': L_SZ},
        'bgcolor': 'rgba(255,255,255,0.9)',
        'bordercolor': '#e0e0e0',
        'borderwidth': 1,
    }
    return {
        'title': title_block,
        'font': {'family': FONT, 'size': TK_SZ, 'color': '#4a4a6a'},
        'template': 'plotly_white',
        'height': height,
        'width': width,
        'margin': {'l': 60, 'r': 40, 't': 80, 'b': 60},
        'plot_bgcolor': '#fafbfc',
        'paper_bgcolor': '#ffffff',
        'legend': legend_block,
    }


# --- Clean variable names for display ---
# Raw regression column name -> short display label for charts. Contemporaneous
# and lagged (_L1) versions of a variable share one label. Each key is listed
# exactly once.
LABEL_MAP = {
    # Resource shares of GDP
    'Oil_GDP_Pct': 'Oil (% GDP)', 'Oil_GDP_Pct_L1': 'Oil (% GDP)',
    'Natural Gas_GDP_Pct': 'Nat. Gas (% GDP)', 'Natural Gas_GDP_Pct_L1': 'Nat. Gas (% GDP)',
    'Coal_GDP_Pct': 'Coal (% GDP)', 'Coal_GDP_Pct_L1': 'Coal (% GDP)',
    'Metals_GDP_Pct': 'Metals (% GDP)', 'Metals_GDP_Pct_L1': 'Metals (% GDP)',
    # Human capital and institutions
    'Human capital index': 'Human Capital', 'Human capital index_L1': 'Human Capital',
    'Rule of law index': 'Rule of Law', 'Rule of law index_L1': 'Rule of Law',
    'Property rights': 'Property Rights', 'Property rights_L1': 'Property Rights',
    'HCI_x_TotalResources_L1': 'HCI x Resources',
    # Structural controls
    'Manufacturing_L1': 'Manufacturing', 'Agriculture_L1': 'Agriculture',
    'Trade (% of GDP)_L1': 'Trade Openness',
    'Gross fixed capital formation, all, Constant prices, Percent of GDP_L1': 'Investment (GFCF)',
    'Access to electricity (% of population)_L1': 'Electricity Access',
    'Urban population (% of total population)_L1': 'Urbanization',
    'Political stability — estimate_L1': 'Pol. Stability',
    'Political stability — estimate': 'Pol. Stability',
    'Domestic credit to private sector (% of GDP)_L1': 'Private Credit',
    'Inflation, consumer prices (annual %)_L1': 'Inflation',
    # Lagged outcome
    'ECI_L1': 'Lagged ECI',
    'log_ECI_L1': 'Lagged log(ECI)',
    # Approach B variables
    'log_HCI': 'log(HCI)', 'log_HCI_L1': 'log(HCI)',
    'log_GFCF': 'log(GFCF)', 'log_GFCF_L1': 'log(GFCF)',
    'log_Production_Value': 'log(Prod. Value)', 'log_Production_Value_L1': 'log(Prod. Value)',
    'Total_Production_Value_Per_Capita': 'Prod. Value p.c.',
    'Total_Production_Value_Per_Capita_L1': 'Prod. Value p.c.',
    'HCI_x_ResourceProduction': 'HCI x Prod.', 'HCI_x_ResourceProduction_L1': 'HCI x Prod.',
    'GFCF_x_ResourceProduction': 'GFCF x Prod.', 'GFCF_x_ResourceProduction_L1': 'GFCF x Prod.',
    'log_HCI_x_log_Prod_L1': 'log(HCI) x log(Prod)',
    'log_GFCF_x_log_Prod_L1': 'log(GFCF) x log(Prod)',
    'Political corruption index': 'Pol. Corruption',
    'Political corruption index_L1': 'Pol. Corruption',
    'Trade (% of GDP)': 'Trade Openness',
    'Gross fixed capital formation, all, Constant prices, Percent of GDP': 'Investment (GFCF)',
}


def clean_name(raw):
    """Return the display label for *raw*, or *raw* itself when unmapped."""
    return LABEL_MAP.get(raw, raw)


def get_coef_data(res, feats, use_panel=True):
    """Extract coefficient, standard error and p-value for each feature.

    Returns a list of dicts with keys ``var`` (raw name), ``label`` (from
    ``clean_name``), ``coef``, ``se`` and ``p``, in the order of *feats*.
    Features missing from ``res.params`` are skipped. ``use_panel`` selects
    the standard-error attribute: ``std_errors`` when true, ``bse`` otherwise.
    """
    se_attr = 'std_errors' if use_panel else 'bse'
    return [
        {
            'var': feat,
            'label': clean_name(feat),
            'coef': res.params[feat],
            'se': getattr(res, se_attr)[feat],
            'p': res.pvalues[feat],
        }
        for feat in feats
        if feat in res.params.index
    ]


# ============================================================================
# CHART 1: Approach A — coefficient plot (resource vars + human capital)
# ============================================================================

def make_coef_plot_a(results_a, output_dir):
    """Write the Approach A coefficient plot to ``fe_coef_approach_a.html``.

    For each of four specifications (A1, A2, A4, A5) found in ``results_a``,
    the four resource-share coefficients and the human-capital coefficient
    are drawn as a marker with a +/-1.96*SE horizontal bar. Specifications
    are offset vertically so they sit side by side on each variable's row.

    Parameters
    ----------
    results_a : mapping
        Spec name -> ``(result, features)``; the result must expose
        ``params`` and ``std_errors`` indexed by variable name.
    output_dir : str
        Directory the HTML file is written into.
    """
    # Raw column names per spec; every lag-1 spec uses the ``_L1`` variants.
    contemp_vars = ['Oil_GDP_Pct', 'Natural Gas_GDP_Pct', 'Coal_GDP_Pct',
                    'Metals_GDP_Pct', 'Human capital index']
    lagged_vars = [f'{name}_L1' for name in contemp_vars]
    target_vars_per_spec = {
        'A1: RC Contemp.': contemp_vars,
        'A2: RC Lag-1': lagged_vars,
        'A4: Full Structural L1': lagged_vars,
        'A5: Dynamic L1': lagged_vars,
    }
    # Shared y-axis labels, one per variable row.
    display_labels = ['Oil (% GDP)', 'Nat. Gas (% GDP)', 'Coal (% GDP)',
                      'Metals (% GDP)', 'Human Capital']
    palette = ['#2563eb', '#f59e0b', '#ef4444', '#8b5cf6']
    spec_total = len(target_vars_per_spec)

    fig = go.Figure()
    for pos, (spec_key, raw_vars) in enumerate(target_vars_per_spec.items()):
        if spec_key not in results_a:
            continue
        res, _ = results_a[spec_key]
        shade = palette[pos]
        short_label = spec_key.split(': ')[1]
        # Vertical offset that centres the group of specs on each row.
        shift = -(spec_total - 1) / 2 * 0.15 + pos * 0.15

        points = []
        for row, raw in enumerate(raw_vars):
            if raw not in res.params.index:
                continue
            est = res.params[raw]
            err = res.std_errors[raw]
            lower = est - 1.96 * err
            upper = est + 1.96 * err
            hover = (f'{short_label}<br>{display_labels[row]}<br>'
                     f'Coef: {est:.4f}<br>SE: {err:.4f}<br>'
                     f'95% CI: [{lower:.4f}, {upper:.4f}]')
            points.append((row + shift, est, lower, upper, hover))

        # One line trace per interval, then a single marker trace per spec.
        for y_at, _est, lower, upper, _hover in points:
            fig.add_trace(go.Scatter(
                x=[lower, upper], y=[y_at, y_at],
                mode='lines', line=dict(color=shade, width=2),
                showlegend=False, hoverinfo='skip'))
        fig.add_trace(go.Scatter(
            x=[p[1] for p in points], y=[p[0] for p in points], mode='markers',
            marker=dict(size=9, color=shade, line=dict(width=1.5, color=shade)),
            name=short_label,
            hovertemplate='%{text}<extra></extra>', text=[p[4] for p in points]))

    fig.add_vline(x=0, line=dict(color='#94a3b8', width=1.5, dash='dash'))
    layout = base_layout(
        'Approach A: Resource Variable Coefficients Across Specifications<br>'
        '<span style="font-size:12px;color:#6b7280">Full sample, PanelOLS, country + year FE, entity-clustered SE</span>',
        height=480, width=900)
    fig.update_layout(**layout)
    fig.update_yaxes(tickvals=list(range(len(display_labels))), ticktext=display_labels,
                     tickfont=dict(size=12))
    fig.update_xaxes(title_text='Coefficient estimate (95% CI)', title_font_size=A_SZ, zeroline=False)
    fig.update_layout(legend=dict(x=0.98, y=0.02, xanchor='right', yanchor='bottom'))
    out_path = os.path.join(output_dir, 'fe_coef_approach_a.html')
    fig.write_html(out_path, include_plotlyjs='cdn')
    print("   ✓ fe_coef_approach_a.html")


# ============================================================================
# CHART 2: Approach A — R² (within) bar chart
# ============================================================================

def make_r2_chart_a(results_a, output_dir):
    """Write a bar chart of within-R² per Approach A spec to ``fe_r2_approach_a.html``.

    Parameters
    ----------
    results_a : mapping
        Spec name (``'<id>: <label>'``) -> ``(result, features)``; the result
        must expose ``rsquared_within``.
    output_dir : str
        Directory the HTML file is written into.

    Returns
    -------
    None
        When ``results_a`` is empty the chart is skipped with a warning
        instead of failing on ``max()`` of an empty sequence.
    """
    if not results_a:
        print("   ⚠ No Approach A results, skipping fe_r2_approach_a.html")
        return

    colors_list = ['#2563eb', '#f59e0b', '#10b981', '#ef4444', '#8b5cf6']
    specs, vals, cols = [], [], []
    for i, (name, (res, _)) in enumerate(results_a.items()):
        specs.append(name.split(': ')[1])
        vals.append(res.rsquared_within)
        # Cycle through the palette if there are more specs than colours.
        cols.append(colors_list[i % len(colors_list)])

    fig = go.Figure(go.Bar(
        x=specs, y=vals,
        marker=dict(color=cols, line=dict(width=1, color='#e2e8f0')),
        text=[f'{v:.3f}' for v in vals], textposition='outside',
        textfont=dict(size=12, color='#374151'),
        hovertemplate='%{x}<br>R²(within) = %{y:.3f}<extra></extra>'))
    fig.update_layout(**base_layout(
        'Approach A: Within-R² Across Specifications<br>'
        '<span style="font-size:12px;color:#6b7280">Only the dynamic spec (lagged ECI) achieves substantial within-R²</span>',
        height=440, width=750))
    # 25% headroom above the tallest bar so the outside text labels fit.
    fig.update_yaxes(title_text='R² (within)', title_font_size=A_SZ,
                     range=[0, max(vals) * 1.25])
    fig.update_xaxes(title_text='Specification', title_font_size=A_SZ)
    fig.write_html(os.path.join(output_dir, 'fe_r2_approach_a.html'), include_plotlyjs='cdn')
    print("   ✓ fe_r2_approach_a.html")


# ============================================================================
# CHART 3: Approach B — coefficient plot (key variables)
# ============================================================================

def make_coef_plot_b(results_b, output_dir):
    """Write the Approach B coefficient plot to ``fe_coef_approach_b.html``.

    Five common "slots" (production value, human capital, investment,
    political stability, rule of law) are compared across the B1, B2, B3 and
    B5 specifications present in ``results_b``. Each coefficient is drawn as
    a marker with a +/-1.96*SE horizontal bar.

    Parameters
    ----------
    results_b : mapping
        Spec name -> ``(result, features)``; the result must expose
        ``params`` and ``bse`` indexed by variable name.
    output_dir : str
        Directory the HTML file is written into.
    """
    slot_labels = ['log(Prod. Value)', 'log(HCI)', 'log(GFCF)', 'Pol. Stability', 'Rule of Law']
    # B1-B3 use contemporaneous regressors; B5 uses their ``_L1`` lags.
    contemp_vars = ['log_Production_Value', 'log_HCI', 'log_GFCF',
                    'Political stability — estimate', 'Rule of law index']
    lagged_vars = [f'{name}_L1' for name in contemp_vars]
    slot_vars = {
        'B1: Log-log (country FE)': contemp_vars,
        'B2: + Interactions': contemp_vars,
        'B3: + Lagged ECI': contemp_vars,
        'B5: Fully lagged L1 (log-log)': lagged_vars,
    }
    palette = ['#2563eb', '#f59e0b', '#10b981', '#8b5cf6']
    spec_total = len(slot_vars)

    fig = go.Figure()
    for pos, (spec_key, raw_vars) in enumerate(slot_vars.items()):
        if spec_key not in results_b:
            continue
        res, _ = results_b[spec_key]
        shade = palette[pos]
        short = spec_key.split(': ')[1]
        std_errs = res.bse
        # Vertical offset that centres the group of specs on each slot row.
        shift = -(spec_total - 1) / 2 * 0.15 + pos * 0.15

        points = []
        for row, raw in enumerate(raw_vars):
            if raw not in res.params.index:
                continue
            est = res.params[raw]
            err = std_errs[raw]
            hover = f'{short}<br>{slot_labels[row]}<br>Coef: {est:.4f}<br>SE: {err:.4f}'
            points.append((row + shift, est, est - 1.96 * err, est + 1.96 * err, hover))

        # One line trace per interval, then a single marker trace per spec.
        for y_at, _est, lower, upper, _hover in points:
            fig.add_trace(go.Scatter(
                x=[lower, upper], y=[y_at, y_at],
                mode='lines', line=dict(color=shade, width=2),
                showlegend=False, hoverinfo='skip'))
        fig.add_trace(go.Scatter(
            x=[p[1] for p in points], y=[p[0] for p in points], mode='markers',
            marker=dict(size=9, color=shade, line=dict(width=1.5, color=shade)),
            name=short, hovertemplate='%{text}<extra></extra>',
            text=[p[4] for p in points]))

    fig.add_vline(x=0, line=dict(color='#94a3b8', width=1.5, dash='dash'))
    layout = base_layout(
        'Approach B: Coefficient Estimates, High-Resource Subsample<br>'
        '<span style="font-size:12px;color:#6b7280">54 countries, log-log spec, country FE, clustered SE</span>',
        height=450, width=900)
    fig.update_layout(**layout)
    fig.update_yaxes(tickvals=list(range(len(slot_labels))), ticktext=slot_labels,
                     tickfont=dict(size=12))
    fig.update_xaxes(title_text='Coefficient estimate (95% CI)', title_font_size=A_SZ, zeroline=False)
    fig.update_layout(legend=dict(x=0.98, y=0.02, xanchor='right', yanchor='bottom'))
    out_path = os.path.join(output_dir, 'fe_coef_approach_b.html')
    fig.write_html(out_path, include_plotlyjs='cdn')
    print("   ✓ fe_coef_approach_b.html")


# ============================================================================
# CHART 4: Approach B — R² bar chart
# ============================================================================

def make_r2_chart_b(results_b, output_dir):
    """Write grouped R² / adjusted-R² bars per Approach B spec to ``fe_r2_approach_b.html``.

    Parameters
    ----------
    results_b : mapping
        Spec name (``'<id>: <label>'``) -> ``(result, features)``; the result
        must expose ``rsquared`` and ``rsquared_adj``.
    output_dir : str
        Directory the HTML file is written into.
    """
    fits = [(name.split(': ')[1], res.rsquared, res.rsquared_adj)
            for name, (res, _) in results_b.items()]
    specs = [label for label, _r2, _adj in fits]
    # (legend name, bar colour, values) for the two bar groups, in draw order.
    series = [
        ('R²', '#2563eb', [r2 for _label, r2, _adj in fits]),
        ('Adj. R²', '#93c5fd', [adj for _label, _r2, adj in fits]),
    ]

    fig = go.Figure()
    for trace_name, fill, values in series:
        fig.add_trace(go.Bar(
            x=specs, y=values, name=trace_name,
            marker=dict(color=fill, line=dict(width=1, color='#e2e8f0')),
            text=[f'{v:.3f}' for v in values], textposition='outside',
            textfont=dict(size=11)))

    layout = base_layout(
        'Approach B: Model Fit Across Specifications<br>'
        '<span style="font-size:12px;color:#6b7280">High-resource subsample (54 countries), OLS with country dummies</span>',
        height=440, width=750)
    fig.update_layout(**layout, barmode='group')
    fig.update_yaxes(title_text='R²', title_font_size=A_SZ, range=[0, 0.75])
    fig.update_xaxes(title_text='Specification', title_font_size=A_SZ)
    out_path = os.path.join(output_dir, 'fe_r2_approach_b.html')
    fig.write_html(out_path, include_plotlyjs='cdn')
    print("   ✓ fe_r2_approach_b.html")


# ============================================================================
# CHART 5: FE vs ML comparison
# ============================================================================

def make_fe_vs_ml(results_a, results_b, output_dir):
    """Compare FE R² with ML test R² and write ``fe_vs_ml_comparison.html``.

    The ML numbers are hard-coded because they come from a different script.
    For each ML feature set, the chart shows the within-R² of the mapped
    Approach A spec, the R² of the mapped Approach B spec (only the first
    two feature sets have one) and the best ML test R².

    A spec that is absent from ``results_a`` / ``results_b`` is treated as
    "no data" (no bar, no label). Approach A used to fall back to ``0``,
    which drew a bar labelled "0.000" that looked like a real zero fit.

    Parameters
    ----------
    results_a, results_b : mapping
        Spec name -> ``(result, features)``. Approach A results must expose
        ``rsquared_within``; Approach B results must expose ``rsquared``.
    output_dir : str
        Directory the HTML file is written into.
    """
    # ML best test R² from ML output — update these if ML results change
    ml_best = {'RC Baseline': 0.804, 'RC + Interactions': 0.816, 'Full Structural': 0.891}

    # Map FE specs to comparable ML feature sets
    fe_a_map = {
        'RC Baseline': 'A1: RC Contemp.',
        'RC + Interactions': 'A3: RC + Interact L1',
        'Full Structural': 'A4: Full Structural L1',
    }
    fe_b_map = {
        'RC Baseline': 'B1: Log-log (country FE)',
        'RC + Interactions': 'B2: + Interactions',
    }

    def _fmt(value):
        # Missing values get an empty label rather than a fake number.
        return '' if value is None else f'{value:.3f}'

    cats = list(ml_best.keys())
    fe_a_vals = [results_a[fe_a_map[c]][0].rsquared_within if fe_a_map[c] in results_a else None
                 for c in cats]
    fe_b_vals = [results_b[fe_b_map[c]][0].rsquared if c in fe_b_map and fe_b_map[c] in results_b
                 else None for c in cats]
    ml_vals = [ml_best[c] for c in cats]

    fig = go.Figure()
    # Keep every category in x (None leaves a gap) so the category order on
    # the axis is fixed by this first trace even when a spec is missing.
    fig.add_trace(go.Bar(x=cats, y=fe_a_vals, name='FE Approach A (within-R²)',
        marker=dict(color='#ef4444'),
        text=[_fmt(v) for v in fe_a_vals], textposition='outside', textfont=dict(size=11)))

    # Approach B only has comparable specs for first two
    b_cats = [c for c, v in zip(cats, fe_b_vals) if v is not None]
    b_vals = [v for v in fe_b_vals if v is not None]
    fig.add_trace(go.Bar(x=b_cats, y=b_vals, name='FE Approach B (R²)',
        marker=dict(color='#f59e0b'),
        text=[_fmt(v) for v in b_vals], textposition='outside', textfont=dict(size=11)))

    fig.add_trace(go.Bar(x=cats, y=ml_vals, name='ML (best test R²)',
        marker=dict(color='#2563eb'),
        text=[_fmt(v) for v in ml_vals], textposition='outside', textfont=dict(size=11)))

    lay = base_layout(
        'Fixed Effects vs. Machine Learning: Explanatory Power<br>'
        '<span style="font-size:12px;color:#6b7280">Linear FE captures minimal within-country variation; '
        'tree models capture nonlinearities</span>',
        height=480, width=850)
    lay['barmode'] = 'group'
    lay['legend'] = dict(x=0.02, y=0.98, xanchor='left', yanchor='top',
                         font=dict(size=L_SZ), bgcolor='rgba(255,255,255,0.9)',
                         bordercolor='#e0e0e0', borderwidth=1)
    fig.update_layout(**lay)
    fig.update_yaxes(title_text='R²', title_font_size=A_SZ, range=[0, 1.05])
    fig.update_xaxes(title_text='Feature Set', title_font_size=A_SZ)
    fig.write_html(os.path.join(output_dir, 'fe_vs_ml_comparison.html'), include_plotlyjs='cdn')
    print("   ✓ fe_vs_ml_comparison.html")


# ============================================================================
# CHART 6: A4 Full Structural — all coefficients horizontal bar
# ============================================================================

def make_a4_full_coef(results_a, output_dir):
    """Write a horizontal bar chart of every A4 coefficient to ``fe_coef_a4_full.html``.

    Bars are sorted by absolute coefficient size, carry +/-1.96*SE error
    bars, and are shaded by significance: ``#2563eb`` for p < 0.05,
    ``#93c5fd`` for p < 0.10 and ``#cbd5e1`` otherwise. The in-chart legend
    uses the same three colours.

    Parameters
    ----------
    results_a : mapping
        Spec name -> ``(result, features)``. Only the
        ``'A4: Full Structural L1'`` entry is used.
    output_dir : str
        Directory the HTML file is written into.

    Returns
    -------
    None
        If the A4 spec is missing, a warning is printed and nothing is
        written.
    """
    spec_key = 'A4: Full Structural L1'
    if spec_key not in results_a:
        print("   ⚠ A4 not found, skipping")
        return
    res, feats = results_a[spec_key]
    rows = get_coef_data(res, feats, use_panel=True)
    # Ascending |coef|: the largest effects end up at the top of the chart.
    rows.sort(key=lambda r: abs(r['coef']))

    names = [r['label'] for r in rows]
    coefs = [r['coef'] for r in rows]
    ses = [r['se'] for r in rows]
    pvals = [r['p'] for r in rows]
    colors = ['#2563eb' if p < 0.05 else ('#93c5fd' if p < 0.1 else '#cbd5e1') for p in pvals]

    fig = go.Figure(go.Bar(
        y=names, x=coefs, orientation='h',
        marker=dict(color=colors, line=dict(width=1, color='#e2e8f0')),
        error_x=dict(type='data', array=[1.96*s for s in ses], color='#94a3b8', thickness=1.5),
        hovertemplate='%{y}<br>Coef: %{x:.4f}<extra></extra>'))
    fig.add_vline(x=0, line=dict(color='#475569', width=1.5, dash='dash'))
    # Hand-built significance legend. The first swatch is coloured explicitly
    # so it matches the p<0.05 bars instead of the annotation font colour.
    fig.add_annotation(x=0.98, y=0.12, xref='paper', yref='paper',
        text='<b style="color:#2563eb">&#9632;</b> p<0.05  <b style="color:#93c5fd">&#9632;</b> p<0.10  '
             '<b style="color:#cbd5e1">&#9632;</b> n.s.',
        showarrow=False, font=dict(size=11, color='#374151'),
        bgcolor='rgba(255,255,255,0.9)', bordercolor='#e0e0e0', borderwidth=1,
        borderpad=6, xanchor='right')
    fig.update_layout(**base_layout(
        'A4: Full Structural Specification - All Coefficients<br>'
        '<span style="font-size:12px;color:#6b7280">PanelOLS, lagged IVs, country + year FE</span>',
        height=550, width=850))
    fig.update_xaxes(title_text='Coefficient (with 95% CI)', title_font_size=A_SZ)
    fig.update_yaxes(tickfont=dict(size=11))
    # Wide left margin leaves room for the variable labels.
    fig.update_layout(margin=dict(l=160, r=40, t=80, b=60))
    fig.write_html(os.path.join(output_dir, 'fe_coef_a4_full.html'), include_plotlyjs='cdn')
    print("   ✓ fe_coef_a4_full.html")


# ============================================================================
# RUN ALL CHARTS
# ============================================================================

# Banner announcing the chart-generation stage.
print("\n" + "=" * 70)
print("GENERATING CHARTS")
print("=" * 70)

# Build the six fixed-effects charts in order. Each function writes its HTML
# file into output_dir and prints a confirmation line (make_a4_full_coef
# prints a warning and writes nothing when the A4 spec is missing).
make_coef_plot_a(results_a, output_dir)
make_r2_chart_a(results_a, output_dir)
make_coef_plot_b(results_b, output_dir)
make_r2_chart_b(results_b, output_dir)
make_fe_vs_ml(results_a, results_b, output_dir)
make_a4_full_coef(results_a, output_dir)

# Closing banner.
print("\n" + "=" * 70)
print("ALL DONE")
print("=" * 70)


GENERATING CHARTS
   ✓ fe_coef_approach_a.html
   ✓ fe_r2_approach_a.html
   ✓ fe_coef_approach_b.html
   ✓ fe_r2_approach_b.html
   ✓ fe_vs_ml_comparison.html
   ✓ fe_coef_a4_full.html

ALL DONE


In [12]:
"""
ADDITIONAL APPENDIX VISUALIZATIONS
===================================
1. Coefficient Stability Plot — Ridge coefficients for key resource variables
   across the 3 ML specifications (RC Baseline, RC + Interactions, Full Structural)
2. ECI vs Resource Dependence time-series for 6 case-study countries
3. SHAP Dependence Scatter — Human Capital vs SHAP value, colored by High Resource

Uses same data pipeline and styling as cap_ml.ipynb.
Outputs: 3 interactive Plotly HTML files.
"""

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings
import os
import sys

warnings.filterwarnings('ignore')

# ============================================================================
# CONFIGURATION — mirrors cap_ml.ipynb
# ============================================================================

# NOTE(review): absolute, machine-specific paths — adjust before running on
# another machine.
input_file = "/Users/leoss/Desktop/GitHub/Capstone/MASTER/Master.csv"
production_file = "/Users/leoss/Desktop/GitHub/Capstone/MASTER/NaturalResource.csv"
output_dir = "/Users/leoss/Desktop/Portfolio/Website-/capstone_visualizations/individual_plots/ml"

# Create the output directory up front so later write_html calls cannot fail
# on a missing folder.
os.makedirs(output_dir, exist_ok=True)

# Temporal split boundary: rows with Year <= this value are training data,
# later years are test data (see prepare_split).
TRAIN_END_YEAR = 2013

# Country codes flagged as high-resource; matched against the 'Country Code'
# column to build the High_Resource indicator.
HIGH_RESOURCE_COUNTRIES = [
    'AGO', 'ARE', 'AZE', 'BFA', 'BHR', 'BOL', 'CHL', 'CIV', 'CMR',
    'COD', 'COG', 'DZA', 'ECU', 'EGY', 'ETH', 'GAB', 'GHA', 'GIN',
    'GNQ', 'IDN', 'IRN', 'IRQ', 'KAZ', 'KEN', 'KWT', 'LAO', 'LBR',
    'LBY', 'MDG', 'MLI', 'MMR', 'MNG', 'MOZ', 'MWI', 'MYS', 'NER',
    'NGA', 'OMN', 'PNG', 'QAT', 'RUS', 'RWA', 'SAU', 'TCD', 'TGO',
    'TTO', 'TZA', 'UGA', 'UZB', 'VEN', 'VNM', 'YEM', 'ZMB', 'ZWE'
]

# Case study countries for time-series (diverse trajectories)
# Maps country code -> display name.
CASE_COUNTRIES = {
    'NOR': 'Norway',
    'NGA': 'Nigeria',
    'CHL': 'Chile',
    'ARE': 'UAE',
    'MYS': 'Malaysia',
    'BWA': 'Botswana',
}

# ============================================================================
# STYLE — same as cap_ml.ipynb
# ============================================================================

# Central style table read by base_layout / styled_axis and the chart code.
STYLE = {
    # Typography
    'font_family': 'IBM Plex Sans, -apple-system, BlinkMacSystemFont, sans-serif',
    'title_size': 18,
    'subtitle_size': 13,
    'axis_title_size': 13,
    'tick_size': 11,
    'legend_size': 12,
    'title_color': '#1a2744',

    # Canvas: transparent paper/plot backgrounds, default sizes and margins
    'template': 'plotly_white',
    'bg_color': 'rgba(0,0,0,0)',
    'plot_bg': 'rgba(0,0,0,0)',
    'chart_height': 650,
    'chart_height_small': 500,
    'margin': dict(l=60, r=50, t=30, b=80),
    'margin_bar': dict(l=200, r=80, t=30, b=50),

    # Sign colours
    'pos_color': '#2e7d4a',
    'neg_color': '#c23a3a',

    # One colour per ML specification, keyed by specification name
    'spec_colors': {
        'RC Baseline':       '#3d4f5f',
        'RC + Interactions': '#4a6fa5',
        'Full Structural':   '#c23a3a',
    },

    'bar_color': '#1a2744',
    'bar_color_alt': '#4a6fa5',

    'gridline_color': '#dde1e7',
    'zero_line_color': '#c9cfd6',
}

# Passed as Plotly's `config` when writing HTML (see save_html); hides the mode bar.
WRITE_CONFIG = {'displayModeBar': False}


def styled_title(main=None, sub=None):
    """Return an empty, horizontally centred title spec.

    ``main`` and ``sub`` are accepted but ignored: titles are handled in the
    HTML page, not in the charts.
    """
    return {'text': '', 'x': 0.5}


def base_layout(**overrides):
    """Build the shared Plotly layout dict from ``STYLE``.

    Keyword arguments are merged on top of the defaults, replacing any
    default that has the same key.
    """
    family = STYLE['font_family']
    defaults = {
        'template': STYLE['template'],
        'font': {'family': family, 'size': STYLE['tick_size'], 'color': '#4b5563'},
        'paper_bgcolor': STYLE['bg_color'],
        'plot_bgcolor': STYLE['plot_bg'],
        'height': STYLE['chart_height'],
        'margin': STYLE['margin'],
        'hoverlabel': {
            'bgcolor': 'white',
            'bordercolor': '#dde1e7',
            'font': {'family': family, 'size': 13, 'color': '#1a2744'},
        },
    }
    return {**defaults, **overrides}


def styled_axis(title_text):
    """Return an axis spec with ``title_text`` and the fonts from ``STYLE``."""
    family = STYLE['font_family']
    title_font = {'size': STYLE['axis_title_size'], 'family': family}
    tick_font = {'size': STYLE['tick_size'], 'family': family}
    return {
        'title': {'text': title_text, 'font': title_font},
        'tickfont': tick_font,
    }


def save_html(fig, filename):
    """Write ``fig`` to ``output_dir/filename`` as HTML and print a confirmation.

    The file is written with ``WRITE_CONFIG`` as the Plotly config and
    ``include_plotlyjs='cdn'``.
    """
    destination = os.path.join(output_dir, filename)
    fig.write_html(destination, config=WRITE_CONFIG, include_plotlyjs='cdn')
    print(f"   Saved: {filename}")


# ============================================================================
# DATA LOADING — same pipeline as cap_ml.ipynb
# ============================================================================

def clean_name(s):
    """Normalise a column name to the plain form used for feature keys.

    Em dashes become hyphens, parentheses and commas are removed, and ``%``
    becomes ``pct``.
    """
    # All substitutions are single-character and independent, so one
    # translate pass is equivalent to chaining the replacements.
    substitutions = str.maketrans({
        '\u2014': '-',
        '(': '',
        ')': '',
        '%': 'pct',
        ',': '',
    })
    return s.translate(substitutions)


# Display labels keyed by feature name in its cleaned form (the form
# clean_name produces, e.g. 'Trade pct of GDP').
# NOTE(review): the *_x_HighRes keys are not built in the visible part of this
# cell — presumably carried over from cap_ml.ipynb; confirm they are needed.
feature_names_clean = {
    'Oil_GDP_Pct': 'Oil (% GDP)',
    'Natural Gas_GDP_Pct': 'Natural Gas (% GDP)',
    'Coal_GDP_Pct': 'Coal (% GDP)',
    'Metals_GDP_Pct': 'Metals (% GDP)',
    'Human capital index': 'Human Capital',
    'Rule of law index': 'Rule of Law',
    'Property rights': 'Property Rights',
    'Political corruption index': 'Political Corruption',
    'Political stability - estimate': 'Political Stability',
    'Landlocked': 'Landlocked',
    'Manufacturing': 'Manufacturing (% GDP)',
    'Agriculture': 'Agriculture (% GDP)',
    'Trade pct of GDP': 'Trade Openness',
    'Gross fixed capital formation all Constant prices Percent of GDP': 'Investment (% GDP)',
    'Access to electricity pct of population': 'Electricity Access',
    'Urban population pct of total population': 'Urbanization',
    'Domestic credit to private sector pct of GDP': 'Private Credit',
    'Inflation consumer prices annual pct': 'Inflation',
    'High_Resource': 'High Resource Country',
    'Oil_GDP_Pct_x_HighRes': 'Oil x High Resource',
    'NatGas_GDP_Pct_x_HighRes': 'Nat Gas x High Resource',
    'Coal_GDP_Pct_x_HighRes': 'Coal x High Resource',
    'Metals_GDP_Pct_x_HighRes': 'Metals x High Resource',
    'Total_Resources_x_HighRes': 'Total Resources x High Resource',
    'HCI_x_TotalResources': 'Human Capital x Total Resources',
}


def categorize_resource(resource):
    """Map a resource name to one of four categories.

    ``'Oil'``, ``'Natural Gas'`` and ``'Coal'`` map to themselves; every
    other value falls into ``'Metals'``.
    """
    for fuel in ('Oil', 'Natural Gas', 'Coal'):
        if resource == fuel:
            return fuel
    return 'Metals'


print("=" * 70)
print("GENERATING ADDITIONAL APPENDIX VISUALIZATIONS")
print("=" * 70)

print("\n1. Loading data...")
df_master = pd.read_csv(input_file)
# Normalise master column names so they match the cleaned feature names
# used further down.
df_master.columns = [clean_name(c) for c in df_master.columns]
df_prod = pd.read_csv(production_file)

# Collapse individual resources into Oil / Natural Gas / Coal / Metals.
df_prod['Resource_Category'] = df_prod['Resource'].apply(categorize_resource)

# Total production value per country-year-category.
prod_agg = df_prod.groupby(
    ['Country Name', 'Year', 'Resource_Category']
)['Production_TotalValue'].sum().reset_index()

# One column per resource category; categories with no rows for a
# country-year are filled with 0.
prod_wide = prod_agg.pivot_table(
    index=['Country Name', 'Year'],
    columns='Resource_Category',
    values='Production_TotalValue',
    fill_value=0,
).reset_index()

# Left join keeps every master row, including country-years without any
# production record.
df = df_master.merge(prod_wide, on=['Country Name', 'Year'], how='left')

# Country-years without a production record come out of the merge as NaN;
# treat them as zero production.
for col in ['Oil', 'Natural Gas', 'Coal', 'Metals']:
    if col in df.columns:
        df[col] = df[col].fillna(0)

# Total GDP = GDP per capita * population; resource values are expressed as
# a percentage of it.
df['GDP_total'] = df['GDP per capita constant prices PPP'] * df['Population']
for res in ['Oil', 'Natural Gas', 'Coal', 'Metals']:
    if res in df.columns:
        df[f'{res}_GDP_Pct'] = (df[res] / df['GDP_total']) * 100
        # Division by a zero GDP yields +/-inf; turn those into NaN.
        df[f'{res}_GDP_Pct'] = df[f'{res}_GDP_Pct'].replace([np.inf, -np.inf], np.nan)

# NOTE(review): this requires all four *_GDP_Pct columns to exist, although the
# loop above only creates one when its resource column is present.
df['Total_Resources_GDP_Pct'] = df[[
    'Oil_GDP_Pct', 'Natural Gas_GDP_Pct', 'Coal_GDP_Pct', 'Metals_GDP_Pct'
]].sum(axis=1)

# 1 when the country code is in HIGH_RESOURCE_COUNTRIES, else 0.
df['High_Resource'] = df['Country Code'].isin(HIGH_RESOURCE_COUNTRIES).astype(int)
# Interaction term: human capital x total resource share.
df['HCI_x_TotalResources'] = (
    df['Human capital index'] * df['Total_Resources_GDP_Pct']
)

print(f"   Data loaded: {len(df)} rows, {df['Country Code'].nunique()} countries")

# ============================================================================
# FEATURE DEFINITIONS — same as cap_ml.ipynb
# ============================================================================

# Dependent variable for every specification.
target = 'Economic Complexity Index'

# RC Baseline: the four resource shares plus human capital, institutions and
# geography.
features_resource_curse_raw = [
    'Oil_GDP_Pct', 'Natural Gas_GDP_Pct', 'Coal_GDP_Pct', 'Metals_GDP_Pct',
    'Human capital index', 'Rule of law index', 'Property rights', 'Landlocked',
]

# RC + Interactions: baseline plus the High_Resource flag and the
# HCI x total-resources interaction. Used as-is (not passed through
# clean_name).
features_rc_interactions = [
    'Oil_GDP_Pct', 'Natural Gas_GDP_Pct', 'Coal_GDP_Pct', 'Metals_GDP_Pct',
    'High_Resource', 'Human capital index', 'HCI_x_TotalResources',
    'Rule of law index', 'Property rights', 'Landlocked',
]

# Full Structural: raw (uncleaned) column names; cleaned below.
features_full_raw = [
    'Oil_GDP_Pct', 'Natural Gas_GDP_Pct', 'Coal_GDP_Pct', 'Metals_GDP_Pct',
    'Manufacturing', 'Agriculture', 'Trade (% of GDP)',
    'Gross fixed capital formation, all, Constant prices, Percent of GDP',
    'Human capital index', 'Access to electricity (% of population)',
    'Urban population (% of total population)', 'Rule of law index',
    'Property rights', 'Political stability \u2014 estimate',
    'Domestic credit to private sector (% of GDP)',
    'Inflation, consumer prices (annual %)', 'Landlocked',
]

# Apply the same cleaning that was applied to the DataFrame's column names so
# the lists can be used to select columns.
features_resource_curse = [clean_name(f) for f in features_resource_curse_raw]
features_full = [clean_name(f) for f in features_full_raw]

# ============================================================================
# TRAIN / TEST SPLIT + RIDGE MODELS (needed for chart 1 and 3)
# ============================================================================

print("\n2. Training Ridge models across specifications...")


def prepare_split(df_src, features, label):
    """Split ``df_src`` by year and standardise the feature columns.

    Rows with a missing value in the identifier columns, the target or any
    feature are dropped first. Rows with ``Year <= TRAIN_END_YEAR`` form the
    training set and later years the test set. The scaler is fitted on the
    training features only and then applied to the test features.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test, X_train_scaled, X_test_scaled,
        train_frame, test_frame, scaler)``.
    """
    keep_cols = ['Country Code', 'Country Name', 'Year', target] + features
    complete = df_src[keep_cols].dropna()

    train = complete[complete['Year'] <= TRAIN_END_YEAR]
    test = complete[complete['Year'] > TRAIN_END_YEAR]

    X_train, X_test = train[features], test[features]
    y_train, y_test = train[target], test[target]

    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    print(f"   {label}: Train {len(train)}, Test {len(test)}")
    return (X_train, X_test, y_train, y_test,
            X_train_scaled, X_test_scaled, train, test, scaler)


# One temporal split per specification. Each call drops incomplete rows for
# its own feature list, so the three specifications can have different
# sample sizes.
(X_train_rc, X_test_rc, y_train_rc, y_test_rc,
 X_train_rc_s, X_test_rc_s, train_rc, test_rc, scaler_rc) = \
    prepare_split(df, features_resource_curse, "RC BASELINE")

(X_train_ri, X_test_ri, y_train_ri, y_test_ri,
 X_train_ri_s, X_test_ri_s, train_ri, test_ri, scaler_ri) = \
    prepare_split(df, features_rc_interactions, "RC + INTERACTIONS")

(X_train_fu, X_test_fu, y_train_fu, y_test_fu,
 X_train_fu_s, X_test_fu_s, train_fu, test_fu, scaler_fu) = \
    prepare_split(df, features_full, "FULL STRUCTURAL")

# Train Ridge models
# Fitted on the standardised training features, so the coefficients are in
# per-standard-deviation units.
ridge_rc = Ridge(alpha=1.0, random_state=42).fit(X_train_rc_s, y_train_rc)
ridge_ri = Ridge(alpha=1.0, random_state=42).fit(X_train_ri_s, y_train_ri)
ridge_fu = Ridge(alpha=1.0, random_state=42).fit(X_train_fu_s, y_train_fu)

# ============================================================================
# CHART 1: COEFFICIENT STABILITY — ML Ridge across 3 specifications
# ============================================================================

print("\n3. Creating coefficient stability plot...")

# Track the 4 resource variables across specs
# Each spec has different feature lists, so map by semantic meaning
# Display labels, in the same order as the columns looked up by
# get_resource_indices (oil, natural gas, coal, metals).
resource_vars = ['Oil (% GDP)', 'Natural Gas (% GDP)', 'Coal (% GDP)', 'Metals (% GDP)']
# Specification names; these match the keys of STYLE['spec_colors'].
spec_names = ['RC Baseline', 'RC + Interactions', 'Full Structural']

# Build coefficient + bootstrap CI for each
def bootstrap_ridge_ci(X_train_s, y_train, alpha=1.0, n_boot=500, seed=42):
    """Percentile-bootstrap 95% intervals for Ridge coefficients.

    Rows are resampled with replacement ``n_boot`` times using a
    ``RandomState`` seeded with ``seed``, and a Ridge model is refitted on
    each resample. The interval bounds are the 2.5th and 97.5th percentiles
    of the resampled coefficients. The point estimate comes from a fit on
    the full training data.

    ``X_train_s`` is indexed with an integer array and ``y_train`` through
    ``.values``.

    Returns
    -------
    tuple
        ``(point, lower, upper)``, each with one entry per feature.
    """
    rng = np.random.RandomState(seed)
    n_obs = len(y_train)

    def _resampled_coefs():
        # Draw row indices with replacement and refit on that resample.
        rows = rng.choice(n_obs, n_obs, replace=True)
        fitted = Ridge(alpha=alpha, random_state=42)
        fitted.fit(X_train_s[rows], y_train.values[rows])
        return fitted.coef_

    samples = np.array([_resampled_coefs() for _ in range(n_boot)])
    lower = np.percentile(samples, 2.5, axis=0)
    upper = np.percentile(samples, 97.5, axis=0)

    full_fit = Ridge(alpha=alpha, random_state=42).fit(X_train_s, y_train)
    return full_fit.coef_, lower, upper


# Point estimates and 95% bootstrap bounds per specification, using the
# function defaults (alpha=1.0, 500 resamples, seed 42).
print("   Bootstrapping CIs for RC Baseline...")
coef_rc, lo_rc, hi_rc = bootstrap_ridge_ci(X_train_rc_s, y_train_rc)
print("   Bootstrapping CIs for RC + Interactions...")
coef_ri, lo_ri, hi_ri = bootstrap_ridge_ci(X_train_ri_s, y_train_ri)
print("   Bootstrapping CIs for Full Structural...")
coef_fu, lo_fu, hi_fu = bootstrap_ridge_ci(X_train_fu_s, y_train_fu)

# Map indices for resource vars in each feature list
def get_resource_indices(features):
    """Locate the four resource GDP-share columns inside a feature list.

    Parameters
    ----------
    features : list of str
        Feature names; must contain all four resource columns.

    Returns
    -------
    list of int
        Positions of the oil, natural gas, coal and metals columns, in that
        order. For a repeated name the first position is used.

    Raises
    ------
    ValueError
        If any of the four columns is absent from ``features``.
    """
    resource_columns = (
        'Oil_GDP_Pct',
        'Natural Gas_GDP_Pct',
        'Coal_GDP_Pct',
        'Metals_GDP_Pct',
    )
    return list(map(features.index, resource_columns))


# Positions of the four resource columns within each spec's feature list
idx_rc, idx_ri, idx_fu = map(
    get_resource_indices,
    (features_resource_curse, features_rc_interactions, features_full),
)

# Per specification: label, point estimates, CI bounds and column positions
spec_results = (
    ('RC Baseline', coef_rc, lo_rc, hi_rc, idx_rc),
    ('RC + Interactions', coef_ri, lo_ri, hi_ri, idx_ri),
    ('Full Structural', coef_fu, lo_fu, hi_fu, idx_fu),
)

# One record per (resource variable, specification) pair, variable-major so
# the three specs of a variable are adjacent.
stability_data = [
    {
        'Variable': var_label,
        'Specification': spec_label,
        'Coefficient': point[positions[var_pos]],
        'CI_Lower': lower[positions[var_pos]],
        'CI_Upper': upper[positions[var_pos]],
    }
    for var_pos, var_label in enumerate(resource_vars)
    for spec_label, point, lower, upper, positions in spec_results
]

df_stab = pd.DataFrame(stability_data)

# Figure layout: a single row with one panel per resource variable; within a
# panel the specifications are the x-axis categories.
fig_stab = make_subplots(
    rows=1,
    cols=4,
    subplot_titles=resource_vars,
    shared_yaxes=True,
    horizontal_spacing=0.06,
)

for panel, var in enumerate(resource_vars, start=1):
    panel_rows = df_stab.loc[df_stab['Variable'] == var]

    # One marker with an asymmetric error bar per specification
    for record in panel_rows.itertuples(index=False):
        spec_label = record.Specification
        estimate = record.Coefficient
        ci_low = record.CI_Lower
        ci_high = record.CI_Upper
        spec_color = STYLE['spec_colors'][spec_label]

        # Plotly expects error-bar lengths relative to the marker, not
        # absolute interval bounds.
        error_bar = dict(
            type='data',
            symmetric=False,
            array=[ci_high - estimate],
            arrayminus=[estimate - ci_low],
            color=spec_color,
            thickness=1.5,
            width=6,
        )
        hover_text = (
            f"<b>{var}</b><br>"
            f"{spec_label}<br>"
            f"Coef: {estimate:.4f}<br>"
            f"95% CI: [{ci_low:.4f}, {ci_high:.4f}]"
            "<extra></extra>"
        )
        fig_stab.add_trace(
            go.Scatter(
                x=[spec_label],
                y=[estimate],
                error_y=error_bar,
                mode='markers',
                marker=dict(size=10, color=spec_color, symbol='circle'),
                showlegend=False,
                hovertemplate=hover_text,
            ),
            row=1,
            col=panel,
        )

    # Dashed reference line at zero in this panel
    fig_stab.add_hline(
        y=0,
        line_dash='dash',
        line_color=STYLE['zero_line_color'],
        line_width=1,
        row=1,
        col=panel,
    )

fig_stab.update_layout(
    **base_layout(height=420, margin=dict(l=60, r=30, t=50, b=100)),
)

# Angle the specification labels on every panel's x-axis
for panel in range(1, 5):
    fig_stab.update_xaxes(
        tickangle=-35,
        tickfont=dict(size=10, family=STYLE['font_family']),
        row=1,
        col=panel,
    )

# Only the first panel's y-axis gets the title and tick styling
fig_stab.update_yaxes(
    title=dict(
        text='Standardized Coefficient',
        font=dict(
            size=STYLE['axis_title_size'],
            family=STYLE['font_family'],
        ),
    ),
    tickfont=dict(size=STYLE['tick_size'], family=STYLE['font_family']),
    row=1,
    col=1,
)

# Restyle the subplot titles (held as layout annotations)
for title_annotation in fig_stab.layout.annotations:
    title_annotation['font'] = dict(
        size=12,
        family=STYLE['font_family'],
        color=STYLE['title_color'],
    )

save_html(fig_stab, 'ml_coefficient_stability.html')


# ============================================================================
# CHART 2: ECI vs RESOURCE DEPENDENCE — time-series for case countries
# ============================================================================

print("\n4. Creating ECI vs resource dependence time-series...")

# Keep only the case-study countries and the columns the chart needs; rows
# missing any of those columns are dropped, then ordered by country and year.
case_codes = list(CASE_COUNTRIES)
df_case = (
    df.loc[
        df['Country Code'].isin(case_codes),
        ['Country Code', 'Country Name', 'Year',
         'Economic Complexity Index', 'Total_Resources_GDP_Pct'],
    ]
    .dropna()
    .sort_values(['Country Code', 'Year'])
)

# Line colour per country code (six case countries)
case_colors = dict(
    NOR='#2e7d4a',   # green
    NGA='#c23a3a',   # red
    CHL='#4a6fa5',   # blue
    ARE='#d4a017',   # gold
    MYS='#7c3aed',   # purple
    BWA='#e07b39',   # orange
)

# Two stacked panels with a common year axis: ECI above, resources below
fig_ts = make_subplots(
    rows=2,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.08,
    row_heights=[0.5, 0.5],
)

for code, name in CASE_COUNTRIES.items():
    country_rows = df_case[df_case['Country Code'] == code]
    line_color = case_colors[code]

    # Settings common to both of this country's traces; the shared
    # legendgroup ties the two lines to a single legend entry.
    shared_trace_kwargs = dict(
        x=country_rows['Year'],
        mode='lines+markers',
        name=name,
        line=dict(color=line_color, width=2),
        marker=dict(size=4, color=line_color),
        legendgroup=name,
    )

    # Top panel: ECI (this trace provides the legend entry)
    fig_ts.add_trace(
        go.Scatter(
            y=country_rows['Economic Complexity Index'],
            hovertemplate=f"<b>{name}</b> (%{{x}})<br>ECI: %{{y:.2f}}<extra></extra>",
            **shared_trace_kwargs,
        ),
        row=1,
        col=1,
    )

    # Bottom panel: resource dependence (legend entry suppressed)
    fig_ts.add_trace(
        go.Scatter(
            y=country_rows['Total_Resources_GDP_Pct'],
            showlegend=False,
            hovertemplate=f"<b>{name}</b> (%{{x}})<br>Resources/GDP: %{{y:.1f}}%<extra></extra>",
            **shared_trace_kwargs,
        ),
        row=2,
        col=1,
    )

fig_ts.update_layout(
    **base_layout(height=580, margin=dict(l=65, r=30, t=20, b=50)),
    legend=dict(
        orientation='h',
        yanchor='top',
        y=-0.10,
        xanchor='center',
        x=0.5,
        font=dict(size=STYLE['legend_size'], family=STYLE['font_family']),
    ),
)

# Title and tick styling for each panel's y-axis
for panel_row, axis_title in (
    (1, 'Economic Complexity Index'),
    (2, 'Resource Production (% GDP)'),
):
    fig_ts.update_yaxes(
        title=dict(
            text=axis_title,
            font=dict(
                size=STYLE['axis_title_size'],
                family=STYLE['font_family'],
            ),
        ),
        tickfont=dict(size=STYLE['tick_size'], family=STYLE['font_family']),
        row=panel_row,
        col=1,
    )

# Tick styling for the bottom panel's x-axis
fig_ts.update_xaxes(
    tickfont=dict(size=STYLE['tick_size'], family=STYLE['font_family']),
    row=2,
    col=1,
)

save_html(fig_ts, 'eci_vs_resources_timeseries.html')


# ============================================================================
# CHART 3: SHAP DEPENDENCE — Human Capital vs SHAP value (colored High Resource)
# ============================================================================

print("\n5. Creating SHAP dependence plots...")


def _shap_dependence_figure(x_values, shap_column, high_res_mask, meta,
                            hover_label, axis_title):
    """Build a SHAP dependence scatter split by high-resource status.

    Parameters
    ----------
    x_values : numpy.ndarray
        Feature value of every observation (x-axis).
    shap_column : numpy.ndarray
        SHAP value of that feature for every observation (y-axis).
    high_res_mask : numpy.ndarray of bool
        True where the observation belongs to a high-resource country.
    meta : pandas.DataFrame
        Row-aligned frame providing the 'Country Name' and 'Year' columns
        shown in the hover text.
    hover_label : str
        Feature label used inside the hover text.
    axis_title : str
        Title of the x-axis.

    Returns
    -------
    plotly.graph_objects.Figure
        Figure with the "other countries" trace first, the high-resource
        trace second, and a dashed zero line.
    """
    country_names = meta['Country Name'].values
    years = meta['Year'].values

    # (row mask, legend name, marker style); list order fixes the trace order
    groups = [
        (
            ~high_res_mask,
            'Other Countries',
            dict(
                size=6, color=STYLE['bar_color_alt'], opacity=0.5,
                line=dict(width=0.3, color='white'),
            ),
        ),
        (
            high_res_mask,
            'High Resource Countries',
            dict(
                size=7, color=STYLE['neg_color'], opacity=0.7,
                line=dict(width=0.3, color='white'),
                symbol='diamond',
            ),
        ),
    ]

    fig = go.Figure()
    for mask, legend_name, marker in groups:
        fig.add_trace(go.Scatter(
            x=x_values[mask],
            y=shap_column[mask],
            mode='markers',
            name=legend_name,
            marker=marker,
            text=country_names[mask],
            customdata=years[mask],
            hovertemplate=(
                '<b>%{text}</b> (%{customdata})<br>'
                f'{hover_label}: %{{x:.2f}}<br>'
                'SHAP Value: %{y:.3f}<extra></extra>'
            ),
        ))

    # Zero line
    fig.add_hline(y=0, line_dash='dash',
                  line_color=STYLE['zero_line_color'], line_width=1)

    fig.update_layout(
        **base_layout(
            height=520,
            margin=dict(l=65, r=30, t=20, b=70),
        ),
        xaxis=styled_axis(axis_title),
        yaxis=styled_axis('SHAP Value (impact on predicted ECI)'),
        legend=dict(
            orientation='h', yanchor='top', y=-0.12, xanchor='center', x=0.5,
            font=dict(size=STYLE['legend_size'], family=STYLE['font_family']),
        ),
    )
    return fig


# Best-effort section: a missing optional dependency or any failure while
# explaining/plotting is reported as a warning and the script carries on.
try:
    import shap
    import lightgbm as lgb

    # Train LightGBM on RC + Interactions (same as cap_ml.ipynb)
    lgbm_model = lgb.LGBMRegressor(
        n_estimators=200, learning_rate=0.05, max_depth=5,
        num_leaves=31, min_child_samples=10, subsample=0.8,
        colsample_bytree=0.8, random_state=42, verbose=-1
    )
    lgbm_model.fit(X_train_ri, y_train_ri)

    explainer = shap.TreeExplainer(lgbm_model)
    shap_values = explainer.shap_values(X_test_ri.astype(float))

    # Group membership of every test row, shared by both plots
    high_res_test = test_ri['Country Code'].isin(HIGH_RESOURCE_COUNTRIES).values
    mask_other = ~high_res_test
    mask_hr = high_res_test

    # --- Dependence plot 1: Human Capital Index ---
    # NOTE(review): column positions in shap_values are taken from
    # features_rc_interactions, which assumes X_test_ri's columns follow that
    # same order — confirm against prepare_split.
    hci_idx = features_rc_interactions.index('Human capital index')
    hci_values = X_test_ri['Human capital index'].values
    hci_shap = shap_values[:, hci_idx]

    fig_shap_dep = _shap_dependence_figure(
        hci_values, hci_shap, high_res_test, test_ri,
        hover_label='Human Capital',
        axis_title='Human Capital Index',
    )
    save_html(fig_shap_dep, 'ml_shap_dependence_hci.html')

    # --- Dependence plot 2: oil production as a share of GDP ---
    oil_idx = features_rc_interactions.index('Oil_GDP_Pct')
    oil_values = X_test_ri['Oil_GDP_Pct'].values
    oil_shap = shap_values[:, oil_idx]

    fig_shap_oil = _shap_dependence_figure(
        oil_values, oil_shap, high_res_test, test_ri,
        hover_label='Oil (% GDP)',
        axis_title='Oil Production (% GDP)',
    )
    save_html(fig_shap_oil, 'ml_shap_dependence_oil.html')

except ImportError:
    print("   WARNING: shap or lightgbm not installed. Skipping SHAP dependence plots.")
    print("   Install with: pip install shap lightgbm")
except Exception as e:
    print(f"   WARNING: SHAP error: {e}")


# ============================================================================
# SUMMARY
# ============================================================================

# Closing banner listing every file this section writes
summary_rule = "=" * 70
summary_lines = (
    "  1. ml_coefficient_stability.html    — Ridge coef stability across 3 ML specs",
    "  2. eci_vs_resources_timeseries.html  — ECI vs resource dependence for 6 countries",
    "  3. ml_shap_dependence_hci.html       — SHAP dependence: Human Capital (if shap installed)",
    "  4. ml_shap_dependence_oil.html       — SHAP dependence: Oil % GDP (if shap installed)",
)

print("\n" + summary_rule)
print("OUTPUTS GENERATED")
print(summary_rule)
for summary_line in summary_lines:
    print(summary_line)
print(f"\n  Output directory: {output_dir}")
print(summary_rule)

GENERATING ADDITIONAL APPENDIX VISUALIZATIONS

1. Loading data...
   Data loaded: 3150 rows, 126 countries

2. Training Ridge models across specifications...
   RC BASELINE: Train 2394, Test 756
   RC + INTERACTIONS: Train 2394, Test 756
   FULL STRUCTURAL: Train 2394, Test 756

3. Creating coefficient stability plot...
   Bootstrapping CIs for RC Baseline...
   Bootstrapping CIs for RC + Interactions...
   Bootstrapping CIs for Full Structural...
   Saved: ml_coefficient_stability.html

4. Creating ECI vs resource dependence time-series...
   Saved: eci_vs_resources_timeseries.html

5. Creating SHAP dependence plots...
   Saved: ml_shap_dependence_hci.html
   Saved: ml_shap_dependence_oil.html

OUTPUTS GENERATED
  1. ml_coefficient_stability.html    — Ridge coef stability across 3 ML specs
  2. eci_vs_resources_timeseries.html  — ECI vs resource dependence for 6 countries
  3. ml_shap_dependence_hci.html       — SHAP dependence: Human Capital (if shap installed)
  4. ml_shap_dependenc