In [13]:
"""
COMPREHENSIVE NATURAL RESOURCE & CLUSTERING ANALYSIS
- Production Maps (interactive)
- Train Once, Predict Always Clustering
- PCA Visualization (post-clustering)
- Temporal Tracking
"""

import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.decomposition import PCA
import plotly.express as px
import os

# ============================================================================
# CONFIGURATION
# ============================================================================

# Absolute, machine-specific locations of the two input CSVs and of the folder
# that receives every HTML/CSV artefact written by this cell.
input_file = "/Users/leoss/Desktop/GitHub/Capstone/MASTER/Master.csv"
production_file = "/Users/leoss/Desktop/GitHub/Capstone/MASTER/NaturalResource.csv"
output_dir = "/Users/leoss/Desktop/Portfolio/Website-/capstone_visualizations/individual_plots/cluster"

# Create the output folder up front so later writes cannot fail on a missing directory.
os.makedirs(output_dir, exist_ok=True)

print("="*70)
print("COMPREHENSIVE NATURAL RESOURCE ANALYSIS")
print("Production Maps + Clustering + PCA")
print("="*70)

# ============================================================================
# 1. DATA PREPARATION
# ============================================================================

print("\n1. Loading and preparing data...")

# df_prod: per-resource production rows; df_master: country-year indicators.
# NOTE(review): column schemas are only known from how later code indexes them — confirm against the CSVs.
df_prod = pd.read_csv(production_file)
df_master = pd.read_csv(input_file)

print(f"   Production data: {len(df_prod)} rows")
print(f"   Master data: {len(df_master)} rows")

# Categorize resources
def categorize_resource(resource):
    """Map a raw resource name onto one of four reporting categories.

    The three fuels keep their own name; every other value (including
    unrecognised or missing ones) is reported as 'Metals'.
    """
    for fuel in ('Oil', 'Natural Gas', 'Coal'):
        if resource == fuel:
            return fuel
    return 'Metals'

# Collapse every raw resource into Oil / Natural Gas / Coal / Metals.
df_prod['Resource_Category'] = df_prod['Resource'].apply(categorize_resource)

# Aggregate by category
prod_agg = df_prod.groupby(['Country Name', 'Year', 'Resource_Category'])['Production_TotalValue'].sum().reset_index()

# One row per country-year, one column per category; a category with no rows
# for that country-year becomes 0 rather than NaN.
prod_wide = prod_agg.pivot_table(
    index=['Country Name', 'Year'], 
    columns='Resource_Category', 
    values='Production_TotalValue', 
    fill_value=0
).reset_index()

# Every pivoted column except the two keys is a resource category.
resource_cols = [c for c in prod_wide.columns if c not in ['Country Name', 'Year']]
prod_wide['Total'] = prod_wide[resource_cols].sum(axis=1)

# Merge with master
master_data = df_master.copy()

# Inner join: only country-years present in both the production and master data survive.
map_data = prod_wide.merge(master_data, on=['Country Name', 'Year'], how='inner')

# Calculate derived metrics
# Total GDP is rebuilt from the per-capita figure and population.
map_data['GDP_total'] = map_data['GDP per capita (constant prices, PPP)'] * map_data['Population']

for res in ['Total', 'Oil', 'Natural Gas', 'Coal', 'Metals']:
    # A category column exists only if the production data contained that category.
    if res in map_data.columns:
        map_data[f'{res}_Per_Capita'] = map_data[res] / map_data['Population']
        # Expressed as a percentage of total GDP.
        # NOTE(review): a zero or missing Population/GDP yields inf/NaN here — confirm the inputs exclude that.
        map_data[f'{res}_GDP_Norm'] = (map_data[res] / map_data['GDP_total']) * 100

print(f"   Merged: {len(map_data)} country-years, {map_data['Country Code'].nunique()} countries")
print(f"   Years: {map_data['Year'].min()} - {map_data['Year'].max()}")

# ============================================================================
# 2. PRODUCTION MAP WITH SYNCED DROPDOWNS
# ============================================================================

print("\n2. Creating production map...")

# Trace order is resource-major, view-minor. The dropdown script embedded in
# the saved HTML picks a trace with `resource_index * 3 + view_index`, so the
# order of these two sequences must not change.
_MAP_RESOURCES = ['Total', 'Oil', 'Natural Gas', 'Coal', 'Metals']
_MAP_VIEWS = [
    ('', 'absolute', 'USD'),
    ('_Per_Capita', 'per_capita', 'USD/person'),
    ('_GDP_Norm', 'gdp_norm', '% GDP'),
]


def _format_hover(values, norm_type):
    """Return one hover label per value, formatted for the given view type."""
    if norm_type == 'absolute':
        # Scale to billions / millions so large dollar amounts stay readable.
        return [
            f"${v/1e9:.2f}B" if v >= 1e9
            else f"${v/1e6:.1f}M" if v >= 1e6
            else f"${v:,.0f}"
            for v in values
        ]
    if norm_type == 'per_capita':
        return [f"${v:,.0f}" for v in values]
    return [f"{v:.2f}%" for v in values]


def _production_layers(frame):
    """Yield (resource, units, z, hover) for each resource/view pair, in trace order."""
    for resource in _MAP_RESOURCES:
        for suffix, norm_type, units in _MAP_VIEWS:
            z = frame[f'{resource}{suffix}'].fillna(0)
            yield resource, units, z, _format_hover(z, norm_type)


_map_years = sorted(map_data['Year'].unique())

# The year slider is created with its handle on the last step, so the data
# drawn initially must be the latest available year (previously a literal 2019).
initial_data = map_data[map_data['Year'] == _map_years[-1]]

# One hidden choropleth per resource/view pair; only the first starts visible.
traces = [
    go.Choropleth(
        locations=initial_data['Country Code'],
        z=z,
        text=initial_data['Country Name'],
        customdata=hover,
        colorscale='YlOrRd',
        marker=dict(line=dict(color='#999999', width=0.5)),
        colorbar=dict(title=units, len=0.7),
        hovertemplate=f'<b>%{{text}}</b><br>{resource}: %{{customdata}}<extra></extra>',
        visible=False
    )
    for resource, units, z, hover in _production_layers(initial_data)
]

traces[0].visible = True

# Create slider steps
# Each step restyles every trace at once: the i-th entry of each list below is
# applied to the i-th trace.
slider_steps = []
for year in _map_years:
    year_data = map_data[map_data['Year'] == year]
    z_list, hover_list = [], []

    for _, _, z, hover in _production_layers(year_data):
        z_list.append(z.tolist())
        hover_list.append(hover)

    slider_steps.append({
        'method': 'restyle',
        'args': [{
            'z': z_list,
            'customdata': hover_list,
            'locations': [year_data['Country Code'].tolist()] * len(traces),
            'text': [year_data['Country Name'].tolist()] * len(traces)
        }],
        'label': str(year)
    })

# Assemble the figure from the pre-built choropleth traces.
fig_prod = go.Figure(data=traces)

fig_prod.update_layout(
    sliders=[{
        # Start with the handle on the last step (the most recent year in slider_steps).
        'active': len(slider_steps) - 1,
        'yanchor': 'top',
        'xanchor': 'left',
        'currentvalue': {
            'prefix': 'Year: ',
            'visible': True,
            'xanchor': 'center',
            'font': {'size': 18, 'color': '#002A54'}
        },
        'pad': {'b': 10, 't': 50},
        'len': 0.9,
        'x': 0.05,
        'y': 0,
        'steps': slider_steps,
        # No animated transition between years.
        'transition': {'duration': 0}
    }],
    title={
        'text': "Natural Resource Production",
        'x': 0.5,
        'font': {'size': 22, 'color': '#002A54'}
    },
    geo=dict(
        showframe=False,
        showcoastlines=True,
        coastlinecolor='#aaaaaa',
        projection_type='natural earth',
        bgcolor='#e3f2fd',
        showland=True,
        landcolor='#fafafa',
        showcountries=True,
        countrycolor='#999999',
        countrywidth=0.5
    ),
    height=700,
    # Extra top/bottom margin leaves room for the title and the year slider.
    margin={"r":50,"t":120,"l":50,"b":120}
)

# Add custom controls
# Render the figure to a standalone page (plotly.js loaded from the CDN) so the
# dropdown markup and script below can be spliced into its <body>.
fig_html = fig_prod.to_html(include_plotlyjs='cdn', config={'displayModeBar': False})

# Two fixed-position dropdowns plus a script that shows exactly one trace:
# index = resource * 3 + view, matching the order in which the traces were built.
controls_html = """
<div style="position: fixed; top: 20px; left: 20px; z-index: 1000; background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); border: 2px solid #002A54;">
    <label style="font-weight: 600; color: #002A54; margin-right: 10px;">Resource:</label>
    <select id="resourceSelect" style="padding: 8px; border: 2px solid #002A54; border-radius: 4px; font-size: 14px;">
        <option value="0">Total</option>
        <option value="1">Oil</option>
        <option value="2">Natural Gas</option>
        <option value="3">Coal</option>
        <option value="4">Metals</option>
    </select>
</div>
<div style="position: fixed; top: 20px; right: 20px; z-index: 1000; background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); border: 2px solid #E30613;">
    <label style="font-weight: 600; color: #E30613; margin-right: 10px;">View:</label>
    <select id="normSelect" style="padding: 8px; border: 2px solid #E30613; border-radius: 4px; font-size: 14px;">
        <option value="0">Absolute</option>
        <option value="1">Per Capita</option>
        <option value="2">% of GDP</option>
    </select>
</div>
<script>
let currentResource = 0;
let currentNorm = 0;

function updateMap() {
    const vis = Array(15).fill(false);
    vis[currentResource * 3 + currentNorm] = true;
    const plotDiv = document.getElementsByClassName('plotly-graph-div')[0];
    if (plotDiv) {
        Plotly.restyle(plotDiv, {visible: vis});
    }
}

setTimeout(function() {
    document.getElementById('resourceSelect').addEventListener('change', function() {
        currentResource = parseInt(this.value);
        updateMap();
    });
    
    document.getElementById('normSelect').addEventListener('change', function() {
        currentNorm = parseInt(this.value);
        updateMap();
    });
}, 100);
</script>
"""

# Inject the controls right after the opening <body> tag of the generated page.
fig_html = fig_html.replace('<body>', '<body>' + controls_html)

# Write explicitly as UTF-8: the platform default encoding may differ, which
# would corrupt or reject any non-ASCII text (e.g. accented country names).
with open(os.path.join(output_dir, 'map_production.html'), 'w', encoding='utf-8') as f:
    f.write(fig_html)

print("   ✓ Production map saved")

# ============================================================================
# 3. CLUSTERING - TRAIN ONCE ON REFERENCE YEAR
# ============================================================================

# The reference year and cluster count are defined first so that the log
# messages below always report the year actually used for training.
REFERENCE_YEAR = 2019
N_CLUSTERS = 6

print(f"\n3. Training clustering model on reference year ({REFERENCE_YEAR})...")

# Four resource-dependence shares (% of GDP) plus two development indicators.
feature_cols = [
    'Metals_GDP_Norm',
    'Oil_GDP_Norm',
    'Natural Gas_GDP_Norm',
    'Coal_GDP_Norm',
    'Economic Complexity Index',
    'Human capital index'
]

# Keep only reference-year countries that have every feature present.
ref_data = map_data[map_data['Year'] == REFERENCE_YEAR].copy()
ref_clean = ref_data[['Country Code', 'Country Name'] + feature_cols].dropna()

print(f"   Reference year countries: {len(ref_clean)}")

# Fit scaler and KMeans on reference year
# Both fitted objects are reused unchanged for every other year further down.
scaler_ref = StandardScaler()
X_ref_scaled = scaler_ref.fit_transform(ref_clean[feature_cols])

# Fixed random_state keeps the label numbering reproducible between runs.
kmeans_ref = KMeans(
    n_clusters=N_CLUSTERS,
    random_state=42,
    n_init=50,
    max_iter=500
)
kmeans_ref.fit(X_ref_scaled)

ref_clean['Cluster'] = kmeans_ref.labels_

sil_score = silhouette_score(X_ref_scaled, kmeans_ref.labels_)
print(f"   Silhouette score ({REFERENCE_YEAR}): {sil_score:.3f}")

# ============================================================================
# 4. CLUSTER NAMING
# ============================================================================

print("\n4. Analyzing cluster profiles...")

# Centroids are learned in standardized space; undo the scaling so the table
# reads in the features' original units.
centroids_scaled = kmeans_ref.cluster_centers_
centroids_original = scaler_ref.inverse_transform(centroids_scaled)
centroids_df = pd.DataFrame(centroids_original, columns=feature_cols).assign(
    Cluster=range(N_CLUSTERS)
)

print("\n   Cluster Centroids (original scale):")
print(centroids_df.round(3).to_string())

# Fixed cluster names based on centroid analysis
# NOTE(review): the id -> name pairing is only valid for the label numbering
# produced by the current data and random_state — re-check after any retrain.
cluster_names = {
    0: 'Petrostates',
    1: 'Low-Income Diversified',
    2: 'Advanced Economies',
    3: 'Coal & Metals Outlier',
    4: 'Mining-Dependent',
    5: 'Wealthy Hydrocarbon Exporters',
}

print("\n   Cluster Names:")
for c in sorted(cluster_names):
    name = cluster_names[c]
    count = (ref_clean['Cluster'] == c).sum()
    print(f"   Cluster {c}: {name} ({count} countries)")

print("\n   Sample countries per cluster:")
for c in range(N_CLUSTERS):
    # First five members in frame order, purely as a sanity check.
    members = ref_clean.loc[ref_clean['Cluster'] == c, 'Country Name']
    countries = members.head(5).tolist()
    print(f"   Cluster {c} ({cluster_names[c]}): {', '.join(countries)}")

# ============================================================================
# 5. PCA ANALYSIS (POST-CLUSTERING)
# ============================================================================

print("\n5. Performing PCA analysis...")

# Project the same standardized reference-year matrix used for clustering onto
# its first two principal components.
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_ref_scaled)

ref_clean['PC1'], ref_clean['PC2'] = X_pca[:, 0], X_pca[:, 1]
ref_clean['Cluster_Name'] = ref_clean['Cluster'].map(cluster_names)

_explained_ratio = pca.explained_variance_ratio_
print(f"   Variance explained: PC1={_explained_ratio[0]:.1%}, PC2={_explained_ratio[1]:.1%}")
print(f"   Total variance explained: {_explained_ratio.sum():.1%}")

# PCA Loadings
# Rows are features, columns are components (components_ is component-major, hence the transpose).
loadings = pd.DataFrame(pca.components_.T, index=feature_cols, columns=['PC1', 'PC2'])
print("\n   PCA Loadings:")
print(loadings.round(3).to_string())

# ============================================================================
# 6. PCA SCATTER PLOT (COLORED BY CLUSTER)
# ============================================================================

print("\n6. Creating PCA scatter plot...")

# One colour per cluster id; indexed by cluster number here and in later plots.
colors = ['#ef4444', '#22c55e', '#3b82f6', '#a855f7', '#f59e0b', '#06b6d4']

fig_pca = go.Figure()

# One scatter trace per cluster so each cluster gets its own legend entry.
for cluster_id in range(N_CLUSTERS):
    cluster_data = ref_clean[ref_clean['Cluster'] == cluster_id]
    
    fig_pca.add_trace(go.Scatter(
        x=cluster_data['PC1'],
        y=cluster_data['PC2'],
        mode='markers',  # markers only; the country name is shown on hover instead
        name=cluster_names[cluster_id],
        marker=dict(
            size=12,
            color=colors[cluster_id],
            line=dict(width=1, color='white')
        ),
        hovertemplate='<b>%{customdata}</b><br>PC1: %{x:.2f}<br>PC2: %{y:.2f}<extra>' + cluster_names[cluster_id] + '</extra>',
        customdata=cluster_data['Country Name']
    ))

# Add loading vectors
# Loadings are multiplied by this factor before drawing (value presumably tuned by eye).
scale_factor = 3
# Short labels for the arrow tips.
label_map = {
    'Metals_GDP_Norm': 'Metals',
    'Oil_GDP_Norm': 'Oil',
    'Natural Gas_GDP_Norm': 'Gas',
    'Coal_GDP_Norm': 'Coal',
    'Economic Complexity Index': 'ECI',
    'Human capital index': 'HCI'
}

for i, feature in enumerate(feature_cols):
    # Arrow from the origin (ax, ay) to the scaled loading, in data coordinates.
    fig_pca.add_annotation(
        x=loadings.loc[feature, 'PC1'] * scale_factor,
        y=loadings.loc[feature, 'PC2'] * scale_factor,
        ax=0, ay=0,
        xref='x', yref='y',
        axref='x', ayref='y',
        showarrow=True,
        arrowhead=2,
        arrowsize=1,
        arrowwidth=2,
        arrowcolor='black'  # black so arrows stand out from the coloured markers
    )
    
    # Feature label placed 15% beyond the arrow tip, on a light box for legibility.
    fig_pca.add_annotation(
        x=loadings.loc[feature, 'PC1'] * scale_factor * 1.15,
        y=loadings.loc[feature, 'PC2'] * scale_factor * 1.15,
        text=f"<b>{label_map.get(feature, feature)}</b>",  # bold short label
        showarrow=False,
        font=dict(size=11, color='black', family='Arial Black'),  # black text
        bgcolor='rgba(255,255,255,0.85)',  # near-opaque white background
        bordercolor='black',  # thin black border around the label
        borderwidth=1,
        borderpad=3
    )

# NOTE(review): the axis names below ("Economic Development", "Hydrocarbon
# Intensity") are fixed interpretations of PC1/PC2 — confirm they still match
# the printed loadings whenever the input data changes.
fig_pca.update_layout(
    title=dict(
        text=f'PCA: Countries by Resource & Development Profile<br><sup>PC1 ({pca.explained_variance_ratio_[0]:.0%}) vs PC2 ({pca.explained_variance_ratio_[1]:.0%}) - Arrows show feature loadings</sup>',
        x=0.5,
        font=dict(size=16)
    ),
    xaxis_title=f'PC1: Economic Development ({pca.explained_variance_ratio_[0]:.0%} variance)',
    yaxis_title=f'PC2: Hydrocarbon Intensity ({pca.explained_variance_ratio_[1]:.0%} variance)',
    # Horizontal legend centred below the plot; the bottom margin makes room for it.
    legend=dict(
        title='Cluster',
        orientation='h',
        yanchor='bottom',
        y=-0.2,
        xanchor='center',
        x=0.5
    ),
    height=700,
    template='plotly_white',
    margin=dict(b=120)
)

fig_pca.write_html(os.path.join(output_dir, 'pca_scatter_clusters.html'))
print("   ✓ PCA scatter plot saved")
# ============================================================================
# 7. PCA LOADINGS BAR CHARTS
# ============================================================================

print("\n7. Creating PCA loadings charts...")

# Human-readable axis labels for the feature columns.
feature_labels_clean = {
    'Metals_GDP_Norm': 'Metals (% GDP)',
    'Oil_GDP_Norm': 'Oil (% GDP)',
    'Natural Gas_GDP_Norm': 'Natural Gas (% GDP)',
    'Coal_GDP_Norm': 'Coal (% GDP)',
    'Economic Complexity Index': 'Economic Complexity',
    'Human capital index': 'Human Capital'
}


def _loadings_bar_figure(component_loadings, title_html):
    """Return a horizontal bar chart for one component's loadings.

    component_loadings: Series indexed by feature name, in display order.
    title_html: full title text (may contain Plotly-supported HTML).
    """
    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=[feature_labels_clean.get(f, f) for f in component_loadings.index],
        x=component_loadings.values,
        orientation='h',
        # Green for positive loadings, red for zero or negative ones.
        marker_color=['#22c55e' if v > 0 else '#ef4444' for v in component_loadings],
        text=[f"{v:.3f}" for v in component_loadings.values],
        textposition='outside'
    ))
    fig.add_vline(x=0, line_dash='dash', line_color='black')
    fig.update_layout(
        title=dict(text=title_html, x=0.5),
        xaxis_title="Loading",
        height=450,
        template='plotly_white',
        # Wide left margin so the feature labels are not clipped.
        margin=dict(l=150)
    )
    return fig


# PC1 Loadings
loadings_pc1 = loadings['PC1'].sort_values()
fig_pc1 = _loadings_bar_figure(
    loadings_pc1,
    f"PC1 Loadings ({pca.explained_variance_ratio_[0]:.1%} variance)<br><sup>Positive = Higher Development</sup>"
)
fig_pc1.write_html(os.path.join(output_dir, 'pca_loadings_pc1.html'))

# PC2 Loadings
loadings_pc2 = loadings['PC2'].sort_values()
fig_pc2 = _loadings_bar_figure(
    loadings_pc2,
    f"PC2 Loadings ({pca.explained_variance_ratio_[1]:.1%} variance)<br><sup>Positive = Higher Hydrocarbon Intensity</sup>"
)
fig_pc2.write_html(os.path.join(output_dir, 'pca_loadings_pc2.html'))

print("   ✓ PCA loadings charts saved")

# ============================================================================
# 8. APPLY CLUSTERING TO ALL YEARS
# ============================================================================

print("\n8. Applying model to all years...")

all_years_results = []
# Years whose cluster-size distribution is echoed to the console.
_report_years = (1995, 2000, 2005, 2010, 2015, 2019)
_id_cols = ['Country Code', 'Country Name', 'Year']

for year in sorted(map_data['Year'].unique()):
    year_data = map_data[map_data['Year'] == year].copy()
    year_clean = year_data[_id_cols + feature_cols].dropna()

    # Nothing to classify when no country has a complete feature row.
    if year_clean.empty:
        continue

    # Reuse the reference-year scaler, KMeans and PCA; nothing is refitted.
    X_year_scaled = scaler_ref.transform(year_clean[feature_cols])
    X_year_pca = pca.transform(X_year_scaled)

    year_clean['Cluster'] = kmeans_ref.predict(X_year_scaled)
    year_clean['Cluster_Name'] = year_clean['Cluster'].map(cluster_names)
    year_clean['PC1'] = X_year_pca[:, 0]
    year_clean['PC2'] = X_year_pca[:, 1]

    all_years_results.append(year_clean)

    if year in _report_years:
        dist = year_clean['Cluster'].value_counts().sort_index()
        print(f"   {year}: {dict(dist)}")

temporal_data = pd.concat(all_years_results, ignore_index=True)
print(f"\n   Total observations: {len(temporal_data)}")


# ============================================================================
# 10. TEMPORAL MOVEMENT ANALYSIS
# ============================================================================

print("\n10. Analyzing country movements over time...")

# First and last observed year per country, plus a display name.
first_last = temporal_data.groupby('Country Code').agg(
    First_Year=('Year', 'min'),
    Last_Year=('Year', 'max'),
    **{'Country Name': ('Country Name', 'first')}
).reset_index()
first_last.columns = ['Country Code', 'First_Year', 'Last_Year', 'Country Name']

movements = []
for _, row in first_last.iterrows():
    country = row['Country Code']

    # Narrow to this country once, then pick its earliest and latest rows.
    country_obs = temporal_data[temporal_data['Country Code'] == country]
    first_obs = country_obs[country_obs['Year'] == row['First_Year']]
    last_obs = country_obs[country_obs['Year'] == row['Last_Year']]

    if first_obs.empty or last_obs.empty:
        continue

    initial_cluster = first_obs.iloc[0]['Cluster']
    final_cluster = last_obs.iloc[0]['Cluster']
    movements.append({
        'Country Code': country,
        'Country Name': row['Country Name'],
        'First_Year': row['First_Year'],
        'Last_Year': row['Last_Year'],
        'Initial_Cluster': initial_cluster,
        'Final_Cluster': final_cluster,
        'Initial_Cluster_Name': cluster_names[initial_cluster],
        'Final_Cluster_Name': cluster_names[final_cluster]
    })

movements_df = pd.DataFrame(movements)
# A country "changed" when its first and last cluster ids differ; intermediate
# years are not considered.
movements_df['Changed'] = movements_df['Initial_Cluster'] != movements_df['Final_Cluster']

print(f"\n   Countries that changed clusters: {movements_df['Changed'].sum()} / {len(movements_df)}")

print("\n   Notable Transitions:")
changers = movements_df[movements_df['Changed']].copy()
for _, row in changers.head(15).iterrows():
    print(f"   {row['Country Name']}: {row['Initial_Cluster_Name']} → {row['Final_Cluster_Name']}")

# ============================================================================
# 11. SANKEY DIAGRAM
# ============================================================================

print("\n11. Creating Sankey diagram...")

# Rows: cluster in a country's first observed year; columns: cluster in its last.
transition_matrix = pd.crosstab(
    movements_df['Initial_Cluster_Name'],
    movements_df['Final_Cluster_Name']
)

print("\n   Transition Matrix:")
print(transition_matrix)

# Node list: every cluster once on the left (Initial) and once on the right (Final).
all_clusters = list(cluster_names.values())
source_labels = [f"{c} (Initial)" for c in all_clusters]
target_labels = [f"{c} (Final)" for c in all_clusters]
all_labels = source_labels + target_labels

color_map = {name: colors[idx] for idx, name in enumerate(all_clusters)}


def _hex_to_rgba(hex_color, alpha):
    """Convert '#rrggbb' plus an alpha value into a CSS rgba() string."""
    digits = hex_color.lstrip('#')
    red, green, blue = (int(digits[k:k + 2], 16) for k in (0, 2, 4))
    return f'rgba({red},{green},{blue},{alpha})'


source_indices, target_indices, values, link_colors = [], [], [], []

for i, src in enumerate(all_clusters):
    # A cluster that never appears as an initial state has no row in the crosstab.
    if src not in transition_matrix.index:
        continue
    for j, tgt in enumerate(all_clusters):
        if tgt not in transition_matrix.columns:
            continue
        val = transition_matrix.loc[src, tgt]
        if val > 0:
            source_indices.append(i)
            # Target nodes come after all source nodes in all_labels.
            target_indices.append(len(all_clusters) + j)
            values.append(val)
            # Links that stay in the same cluster are drawn more opaque.
            alpha = 0.8 if src == tgt else 0.4
            link_colors.append(_hex_to_rgba(color_map[src], alpha))

sankey_trace = go.Sankey(
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color='black', width=0.5),
        label=all_labels,
        # Same colour sequence for the Initial and the Final copy of each cluster.
        color=[color_map[c] for c in all_clusters] * 2
    ),
    link=dict(
        source=source_indices,
        target=target_indices,
        value=values,
        color=link_colors
    )
)
fig_sankey = go.Figure(data=[sankey_trace])

fig_sankey.update_layout(
    title=dict(
        text=f'Country Cluster Transitions ({movements_df["First_Year"].min()}-{movements_df["Last_Year"].max()})',
        x=0.5,
        font=dict(size=16)
    ),
    height=600,
    font=dict(size=12)
)

fig_sankey.write_html(os.path.join(output_dir, 'cluster_sankey_transitions.html'))
print("   ✓ Sankey diagram saved")

# ============================================================================
# 12. CLUSTER MAP
# ============================================================================

print("\n12. Creating cluster map...")

# Work on a copy that carries the readable cluster name next to the id.
ref_clean_with_names = ref_clean.assign(
    Cluster_Name=ref_clean['Cluster'].map(cluster_names)
)

fig_cluster_map = go.Figure()

# One choropleth per cluster, each painted with a single flat colour.
for cluster_id in range(N_CLUSTERS):
    df_cluster = ref_clean_with_names[ref_clean_with_names['Cluster'] == cluster_id]
    cluster_label = cluster_names[cluster_id]
    cluster_color = colors[cluster_id]

    cluster_trace = go.Choropleth(
        locations=df_cluster['Country Code'],
        z=[cluster_id] * len(df_cluster),
        text=df_cluster['Country Name'],
        name=cluster_label,
        # Both ends of the colourscale are the same colour, so z never changes the shade.
        colorscale=[[0, cluster_color], [1, cluster_color]],
        zmin=0, zmax=N_CLUSTERS-1,
        showscale=False,
        marker=dict(line=dict(color='white', width=0.5)),
        hovertemplate='<b>%{text}</b><br>' + cluster_label + '<extra></extra>'
    )
    fig_cluster_map.add_trace(cluster_trace)

_cluster_map_geo = dict(
    showframe=False,
    showcoastlines=True,
    coastlinecolor='#d1d5db',
    projection_type='natural earth',
    bgcolor='rgba(0,0,0,0)',
    landcolor='#f3f4f6',
    countrycolor='#d1d5db'
)
_cluster_map_legend = dict(
    title='Cluster', x=0.99, y=0.99, xanchor='right', yanchor='top',
    bgcolor='rgba(255,255,255,0.9)'
)

fig_cluster_map.update_layout(
    title=dict(text=f'Country Clusters ({REFERENCE_YEAR})', x=0.5, font=dict(size=16)),
    geo=_cluster_map_geo,
    legend=_cluster_map_legend,
    height=550,
    margin=dict(l=10, r=10, t=60, b=10)
)

fig_cluster_map.write_html(os.path.join(output_dir, 'map_clusters.html'))
print("   ✓ Cluster map saved")

# ============================================================================
# 13. ANIMATED CLUSTER MAP
# ============================================================================

print("\n13. Creating animated cluster map...")


def _cluster_layers(year_frame, named):
    """Return one flat-coloured Choropleth per cluster id for the given rows.

    year_frame: rows of a single year with 'Cluster', 'Country Code' and
        'Country Name' columns.
    named: when True each trace also carries its cluster name (used for the
        base traces); frame traces are built without it, as before.
    """
    layers = []
    for cluster_id in range(N_CLUSTERS):
        df_cluster = year_frame[year_frame['Cluster'] == cluster_id]
        extra = {'name': cluster_names[cluster_id]} if named else {}
        layers.append(go.Choropleth(
            locations=df_cluster['Country Code'],
            z=[cluster_id] * len(df_cluster),
            text=df_cluster['Country Name'],
            # Identical colour at both ends: the trace is a single solid colour.
            colorscale=[[0, colors[cluster_id]], [1, colors[cluster_id]]],
            zmin=0, zmax=N_CLUSTERS-1,
            showscale=False,
            marker=dict(line=dict(color='white', width=0.5)),
            hovertemplate='<b>%{text}</b><br>' + cluster_names[cluster_id] + '<extra></extra>',
            **extra
        ))
    return layers


years = sorted(temporal_data['Year'].unique())

# One animation frame per year; the frame name is what the slider steps refer to.
frames = [
    go.Frame(
        data=_cluster_layers(temporal_data[temporal_data['Year'] == year], named=False),
        name=str(year)
    )
    for year in years
]

# The figure starts on the earliest year, matching the slider's 'active': 0.
initial_year = years[0]
initial_data_anim = temporal_data[temporal_data['Year'] == initial_year]

fig_animated = go.Figure()
fig_animated.add_traces(_cluster_layers(initial_data_anim, named=True))

fig_animated.frames = frames

fig_animated.update_layout(
    title=dict(text='Cluster Evolution Over Time', x=0.5, font=dict(size=16)),
    geo=dict(
        showframe=False,
        showcoastlines=True,
        coastlinecolor='#d1d5db',
        projection_type='natural earth',
        bgcolor='rgba(0,0,0,0)',
        landcolor='#f3f4f6',
        countrycolor='#d1d5db'
    ),
    legend=dict(title='Cluster', x=0.99, y=0.99, xanchor='right', yanchor='top', bgcolor='rgba(255,255,255,0.9)'),
    updatemenus=[{
        'type': 'buttons',
        'showactive': False,
        'y': 0,
        'x': 0.1,
        'xanchor': 'right',
        'yanchor': 'top',
        'buttons': [
            # Play: step through all frames from the current one.
            {'label': '▶ Play', 'method': 'animate', 'args': [None, {'frame': {'duration': 500, 'redraw': True}, 'fromcurrent': True, 'transition': {'duration': 200}}]},
            # Pause: animate to "no frame" immediately, which halts playback.
            {'label': '⏸ Pause', 'method': 'animate', 'args': [[None], {'frame': {'duration': 0, 'redraw': False}, 'mode': 'immediate', 'transition': {'duration': 0}}]}
        ]
    }],
    sliders=[{
        'active': 0,
        'yanchor': 'top',
        'xanchor': 'left',
        'currentvalue': {'prefix': 'Year: ', 'visible': True, 'xanchor': 'center', 'font': {'size': 16}},
        'pad': {'b': 10, 't': 50},
        'len': 0.9,
        'x': 0.05,
        'y': 0,
        # Each step jumps straight to the frame named after its year.
        'steps': [{'args': [[str(year)], {'frame': {'duration': 0, 'redraw': True}, 'mode': 'immediate', 'transition': {'duration': 0}}], 'label': str(year), 'method': 'animate'} for year in years]
    }],
    height=650,
    margin=dict(l=10, r=10, t=60, b=80)
)

fig_animated.write_html(os.path.join(output_dir, 'cluster_map_animated.html'))
print("   ✓ Animated map saved")

# ============================================================================
# 15. SAVE DATA FILES
# ============================================================================

print("\n15. Saving data files...")


def _csv_path(filename):
    """Return the full path of a CSV inside the output folder."""
    return os.path.join(output_dir, filename)


# Long-format tables are written without their positional index.
temporal_data.to_csv(_csv_path('cluster_assignments_all_years.csv'), index=False)
movements_df.to_csv(_csv_path('cluster_movements_summary.csv'), index=False)

# Attach readable names to the centroid table before exporting it.
centroids_df['Cluster_Name'] = centroids_df['Cluster'].map(cluster_names)
centroids_df.to_csv(_csv_path('cluster_centroids.csv'), index=False)

# These two keep their index because it carries the row labels.
transition_matrix.to_csv(_csv_path('cluster_transition_matrix.csv'))
loadings.to_csv(_csv_path('pca_loadings.csv'))

print("   ✓ All data files saved")

# ============================================================================
# SUMMARY
# ============================================================================

# Final console recap of the run: methodology, cluster sizes, coverage and the
# list of artefacts written above.
print("\n" + "="*70)
print("✅ COMPREHENSIVE ANALYSIS COMPLETE!")
print("="*70)

print("\nMethodology:")
print(f"  • Reference year: {REFERENCE_YEAR}")
print(f"  • Clusters: {N_CLUSTERS}")
print(f"  • Silhouette score: {sil_score:.3f}")
print(f"  • PCA variance explained: {pca.explained_variance_ratio_.sum():.1%}")
print("  • Approach: Train Once, Predict Always")

print(f"\nCluster Definitions (frozen from {REFERENCE_YEAR}):")
for c, name in sorted(cluster_names.items()):
    count = (ref_clean['Cluster'] == c).sum()
    print(f"  Cluster {c}: {name} ({count} countries)")

print("\nTemporal Coverage:")
print(f"  • Years: {temporal_data['Year'].min()} - {temporal_data['Year'].max()}")
print(f"  • Countries tracked: {temporal_data['Country Code'].nunique()}")
print(f"  • Countries that changed: {movements_df['Changed'].sum()}")

print(f"\nOutputs saved to: {output_dir}")
# Only files this script actually writes are listed. The radar-profile and
# trajectory charts were dropped from the list because nothing here creates them.
print("""
  MAPS:
  1. map_production.html - Interactive resource production map
  2. map_clusters.html - Static cluster map (2019)
  3. cluster_map_animated.html - Animated cluster evolution
  
  PCA:
  4. pca_scatter_clusters.html - PCA scatter with cluster colors
  5. pca_loadings_pc1.html - PC1 loadings
  6. pca_loadings_pc2.html - PC2 loadings
  
  CLUSTER ANALYSIS:
  7. cluster_sankey_transitions.html - Flow diagram of movements
  
  DATA:
  8. cluster_assignments_all_years.csv
  9. cluster_movements_summary.csv
  10. cluster_centroids.csv
  11. cluster_transition_matrix.csv
  12. pca_loadings.csv
""")
print("="*70)

COMPREHENSIVE NATURAL RESOURCE ANALYSIS
Production Maps + Clustering + PCA

1. Loading and preparing data...
   Production data: 17166 rows
   Master data: 3150 rows
   Merged: 3069 country-years, 126 countries
   Years: 1995 - 2019

2. Creating production map...
   ✓ Production map saved

3. Training clustering model on reference year (2019)...
   Reference year countries: 125
   Silhouette score (2019): 0.344

4. Analyzing cluster profiles...

   Cluster Centroids (original scale):
   Metals_GDP_Norm  Oil_GDP_Norm  Natural Gas_GDP_Norm  Coal_GDP_Norm  Economic Complexity Index  Human capital index  Cluster
0            0.121         8.277                 0.493          0.001                     -0.729                2.497        0
1            0.526         0.404                 0.119          0.094                     -0.786                2.123        1
2            0.376         0.301                 0.122          0.195                      0.789                3.225        2
3  

In [25]:
"""
RESOURCE DIVERSITY & INTENSITY — TWO-FEATURE APPROACH
======================================================

Feature 1: Resource Diversity (Shannon entropy on category-level domestic shares)
Feature 2: Resource Intensity (Total production value / GDP)

Categories (5):
  - Fossil fuels: Oil, Natural Gas, Coal
  - Base metals: Copper, Lead, Zinc, Tin, Nickel, Aluminium
  - Precious metals: Gold, Silver
  - Battery/strategic metals: Lithium, Cobalt, Rare Earth, Vanadium
  - Industrial minerals: Bauxite, Manganese, Magnesium compounds, Cadmium, Natural Graphite

Threshold: Only count resources where country holds >0.5% global market share
"""

import pandas as pd
import numpy as np
import plotly.graph_objects as go
import os

# ============================================================================
# CONFIGURATION
# ============================================================================

# NOTE(review): absolute paths — presumably only valid on the author's machine;
# adjust before running elsewhere.
# Country master table; read below for GDP per capita, population and country codes.
input_file = "/Users/leoss/Desktop/GitHub/Capstone/MASTER/Master.csv"
# Resource production table; read below for per-country production values.
production_file = "/Users/leoss/Desktop/GitHub/Capstone/MASTER/NaturalResource.csv"
# Folder that receives every HTML chart and the CSV written by this cell.
output_dir = "/Users/leoss/Desktop/Portfolio/Website-/capstone_visualizations/individual_plots/cluster"

# Create the output folder (and any missing parents); exist_ok keeps re-runs from failing.
os.makedirs(output_dir, exist_ok=True)

# Year whose production and GDP rows feed every figure produced by this cell.
REFERENCE_YEAR = 2019
# Cut-off in percent of global output: country/resource pairs below it are ignored.
MARKET_SHARE_THRESHOLD = 0.5

# Resources grouped under their category. The key order is the category order
# used by the charts and the CSV export; the member order fixes the order of
# the flat lookup derived below.
_RESOURCES_BY_CATEGORY = {
    'Fossil fuels': ['Oil', 'Natural Gas', 'Coal'],
    'Base metals': ['Copper', 'Lead', 'Zinc', 'Tin', 'Nickel', 'Aluminium'],
    'Precious metals': ['Gold', 'Silver'],
    'Battery/strategic': ['Lithium', 'Cobalt', 'Rare Earth', 'Vanadium'],
    'Industrial minerals': ['Bauxite', 'Manganese', 'Magnesium compounds',
                            'Cadmium', 'Natural Graphite'],
}

# Flat resource -> category lookup.
CATEGORY_MAP = {
    resource: category
    for category, members in _RESOURCES_BY_CATEGORY.items()
    for resource in members
}

# Category names in display order.
CATEGORY_ORDER = list(_RESOURCES_BY_CATEGORY)

# Section banner.
print("=" * 70, "RESOURCE DIVERSITY & INTENSITY INDEX", "=" * 70, sep="\n")

# ============================================================================
# 1. LOAD DATA
# ============================================================================

print("\n1. Loading data...")

# Raw inputs: the production table and the country master table.
df_prod = pd.read_csv(production_file)
df_master = pd.read_csv(input_file)

# Independent copy of the production rows belonging to the reference year.
df_year = df_prod.loc[df_prod['Year'] == REFERENCE_YEAR].copy()

# Quick sanity read-out of what the reference year contains.
print(
    f"   Resources in data: {df_year['Resource'].nunique()}\n"
    f"   Countries in data: {df_year['Country Name'].nunique()}"
)

# ============================================================================
# 2. COMPUTE GLOBAL SHARES & APPLY THRESHOLD
# ============================================================================

print(f"\n2. Computing market shares (threshold: >{MARKET_SHARE_THRESHOLD}%)...")

# World total of each resource in the reference year.
global_prod = (
    df_year.groupby('Resource')['Production_TotalValue']
    .sum()
    .rename('Global_Production')
    .reset_index()
)

# One row per (country, resource), with that resource's world total alongside.
country_res = (
    df_year.groupby(['Country Name', 'Resource'])['Production_TotalValue']
    .sum()
    .reset_index()
    .merge(global_prod, on='Resource')
)

# The country's percentage of world output for the resource. Infinite or
# missing ratios (e.g. a zero world total) are treated as a 0% share.
country_res['Market_Share_Pct'] = (
    country_res['Production_TotalValue']
    .div(country_res['Global_Production'])
    .mul(100)
    .replace([np.inf, -np.inf], 0)
    .fillna(0)
)

# Keep only pairs at or above the cut-off (the comparison is inclusive).
n_before = len(country_res)
country_res = country_res.loc[
    country_res['Market_Share_Pct'] >= MARKET_SHARE_THRESHOLD
].copy()
print(f"   Pairs before threshold: {n_before}")
print(f"   Pairs after threshold:  {len(country_res)}")

# Tag each resource with its category; report and then drop any resource that
# has no entry in CATEGORY_MAP.
country_res['Category'] = country_res['Resource'].map(CATEGORY_MAP)
unmapped = country_res.loc[country_res['Category'].isna(), 'Resource'].unique()
if len(unmapped):
    print(f"   WARNING: Unmapped resources: {unmapped}")
country_res = country_res.dropna(subset=['Category'])

# ============================================================================
# 3. FEATURE 1 — RESOURCE DIVERSITY (Shannon Entropy at category level)
# ============================================================================

print("\n3. Computing Resource Diversity (Shannon entropy)...")

# Production value per (country, category), built from the thresholded pairs only.
country_cat = (
    country_res.groupby(['Country Name', 'Category'])['Production_TotalValue']
    .sum()
    .reset_index()
)

# Each country's total over its categories, attached to every category row.
country_totals = (
    country_cat.groupby('Country Name')['Production_TotalValue']
    .sum()
    .rename('Country_Total')
    .reset_index()
)
country_cat = country_cat.merge(country_totals, on='Country Name')

# p_i: the category's fraction of the country's (filtered) production.
country_cat['Domestic_Share'] = country_cat['Production_TotalValue'].div(
    country_cat['Country_Total']
)

# Per-category entropy term -p_i * ln(p_i); summing the terms per country
# gives H = -Σ p_i * ln(p_i).
country_cat['Entropy_Component'] = np.log(country_cat['Domestic_Share']).mul(
    -country_cat['Domestic_Share']
)

diversity = (
    country_cat.groupby('Country Name')
    .agg(
        Shannon_Entropy=('Entropy_Component', 'sum'),
        N_Categories=('Category', 'nunique'),
    )
    .reset_index()
)

# Scale by the largest attainable entropy — ln(number of categories), reached
# when production is spread evenly — so scores fall in [0, 1].
MAX_ENTROPY = np.log(len(CATEGORY_ORDER))
diversity['Diversity_Normalized'] = diversity['Shannon_Entropy'].div(MAX_ENTROPY)

print(f"   Max possible entropy: {MAX_ENTROPY:.3f} (= ln({len(CATEGORY_ORDER)}))")
print(f"   Countries with scores: {len(diversity)}")

# Wide table of category shares (one column per category, 0 where a country
# has none); the decomposition chart reads these columns.
cat_pivot = country_cat.pivot_table(
    values='Domestic_Share',
    index='Country Name',
    columns='Category',
    fill_value=0,
).reset_index()

diversity = diversity.merge(cat_pivot, how='left', on='Country Name')

# ============================================================================
# 4. FEATURE 2 — RESOURCE INTENSITY (Production Value / GDP)
# ============================================================================

print("\n4. Computing Resource Intensity (production / GDP)...")

# Reference-year inputs needed to reconstruct total GDP per country.
gdp_data = df_master.loc[
    df_master['Year'] == REFERENCE_YEAR,
    ['Country Name', 'GDP per capita (constant prices, PPP)', 'Population'],
].copy()

# GDP = GDP per capita × population; values that cannot be parsed become NaN.
gdp_data['GDP'] = pd.to_numeric(
    gdp_data['GDP per capita (constant prices, PPP)'], errors='coerce'
).mul(pd.to_numeric(gdp_data['Population'], errors='coerce'))

# Keep only countries with a usable (present and strictly positive) GDP.
gdp_data = gdp_data.dropna(subset=['GDP']).loc[
    lambda frame: frame['GDP'] > 0, ['Country Name', 'GDP']
]

print("   GDP computed as: GDP per capita (PPP) × Population")
print(f"   Countries with GDP data: {len(gdp_data)}")

# Total production value per country over the resources that passed the threshold.
country_total_prod = (
    country_res.groupby('Country Name')['Production_TotalValue']
    .sum()
    .rename('Total_Resource_Value')
    .reset_index()
)

# Inner join: a country needs both production and GDP to get an intensity.
intensity = country_total_prod.merge(gdp_data, how='inner', on='Country Name')
intensity['Resource_Intensity_Pct'] = (
    intensity['Total_Resource_Value'].div(intensity['GDP']).mul(100)
)

print(f"   Countries with intensity: {len(intensity)}")

# ============================================================================
# 5. MERGE FEATURES
# ============================================================================

print("\n5. Merging features...")

# Outer join: keep every country that has at least one of the two features.
result = pd.merge(
    diversity,
    intensity[['Country Name', 'Total_Resource_Value', 'GDP', 'Resource_Intensity_Pct']],
    how='outer',
    on='Country Name',
)

# Attach the country code from the master table (the maps use it as location key).
master_codes = df_master[['Country Name', 'Country Code']].drop_duplicates()
result = result.merge(master_codes, how='left', on='Country Name')

# Drop rows whose country name is the literal string '0'.
result = result.loc[result['Country Name'].ne('0')]

# Scores missing after the outer join (e.g. no intensity because the country
# had no usable GDP) are set to 0.
result['Shannon_Entropy'] = result['Shannon_Entropy'].fillna(0)
result['Diversity_Normalized'] = result['Diversity_Normalized'].fillna(0)
result['Resource_Intensity_Pct'] = result['Resource_Intensity_Pct'].fillna(0)
result['N_Categories'] = result['N_Categories'].fillna(0).astype(int)

print(f"   Final dataset: {len(result)} countries")
print(f"\n   Diversity range:  {result['Diversity_Normalized'].min():.3f} - {result['Diversity_Normalized'].max():.3f}")
print(f"   Intensity range:  {result['Resource_Intensity_Pct'].min():.2f}% - {result['Resource_Intensity_Pct'].max():.2f}%")

# Leaderboard: ten highest normalized-entropy scores.
print("\n   Top 10 by Diversity (normalized entropy):")
top_div = result.nlargest(10, 'Diversity_Normalized')
for i, (name, score, n_cats) in enumerate(
    zip(top_div['Country Name'], top_div['Diversity_Normalized'], top_div['N_Categories']),
    1,
):
    print(f"      {i:2d}. {name}: {score:.3f} ({n_cats} categories)")

# Leaderboard: ten highest production-to-GDP percentages.
print("\n   Top 10 by Resource Intensity (% of GDP):")
top_int = result.nlargest(10, 'Resource_Intensity_Pct')
for i, (name, pct) in enumerate(
    zip(top_int['Country Name'], top_int['Resource_Intensity_Pct']), 1
):
    print(f"      {i:2d}. {name}: {pct:.1f}%")

# ============================================================================
# 6. CHOROPLETH — DIVERSITY
# ============================================================================

print("\n6. Creating diversity map...")

# World map coloured by normalized entropy on a fixed 0–1 colour range.
fig_div = go.Figure(data=[
    go.Choropleth(
        locations=result['Country Code'],
        z=result['Diversity_Normalized'],
        text=result['Country Name'],
        # Hover payload per country: [raw entropy, category count, normalized entropy].
        customdata=np.stack(
            [
                result['Shannon_Entropy'],
                result['N_Categories'],
                result['Diversity_Normalized'],
            ],
            axis=1,
        ),
        colorscale='YlGnBu',
        marker={'line': {'color': 'white', 'width': 0.5}},
        zmin=0,
        zmax=1,
        colorbar={
            'title': {'text': 'Normalized<br>Entropy', 'font': {'size': 12}},
            'len': 0.7,
            'thickness': 15,
        },
        hovertemplate=(
            '<b>%{text}</b><br>'
            'Diversity: %{customdata[2]:.3f}<br>'
            'Categories: %{customdata[1]:.0f}/5<br>'
            'Raw entropy: %{customdata[0]:.3f}'
            '<extra></extra>'
        ),
    )
])

# Title, map styling and transparent backgrounds.
fig_div.update_layout(
    title={
        'text': (f'Resource Diversity Index ({REFERENCE_YEAR})<br>'
                 f'<sup>Shannon entropy across 5 resource categories '
                 f'(threshold: >{MARKET_SHARE_THRESHOLD}% global market share)</sup>'),
        'x': 0.5,
        'font': {'size': 18, 'color': '#1f2937'},
    },
    geo={
        'showframe': False,
        'showcoastlines': True,
        'coastlinecolor': '#d1d5db',
        'projection_type': 'natural earth',
        'bgcolor': 'rgba(0,0,0,0)',
        'landcolor': '#f3f4f6',
        'countrycolor': '#e5e7eb',
        'countrywidth': 0.3,
    },
    height=600,
    margin={'l': 10, 'r': 10, 't': 80, 'b': 10},
    paper_bgcolor='rgba(0,0,0,0)',
)

# Standalone HTML page; plotly.js is loaded from the CDN rather than embedded.
fig_div.write_html(
    os.path.join(output_dir, 'diversity_map.html'),
    include_plotlyjs='cdn',
    config={'displayModeBar': True, 'displaylogo': False},
)
print("   ✓ diversity_map.html")

# ============================================================================
# 7. CHOROPLETH — INTENSITY
# ============================================================================

print("\n7. Creating intensity map...")

# Colour by log(1 + intensity) so a few very large values do not flatten the
# rest of the scale; the hover still reports the untransformed percentage.
result['Intensity_Log'] = np.log1p(result['Resource_Intensity_Pct'])

fig_int = go.Figure(data=[
    go.Choropleth(
        locations=result['Country Code'],
        z=result['Intensity_Log'],
        text=result['Country Name'],
        customdata=result['Resource_Intensity_Pct'],
        colorscale='YlOrRd',
        marker={'line': {'color': 'white', 'width': 0.5}},
        colorbar={
            'title': {'text': 'Intensity<br>(log scale)', 'font': {'size': 12}},
            'len': 0.7,
            'thickness': 15,
        },
        hovertemplate=(
            '<b>%{text}</b><br>'
            'Resource production: %{customdata:.1f}% of GDP'
            '<extra></extra>'
        ),
    )
])

# Title, map styling and transparent backgrounds.
fig_int.update_layout(
    title={
        'text': (f'Resource Intensity ({REFERENCE_YEAR})<br>'
                 f'<sup>Total resource production value as % of GDP '
                 f'(log scale for visibility)</sup>'),
        'x': 0.5,
        'font': {'size': 18, 'color': '#1f2937'},
    },
    geo={
        'showframe': False,
        'showcoastlines': True,
        'coastlinecolor': '#d1d5db',
        'projection_type': 'natural earth',
        'bgcolor': 'rgba(0,0,0,0)',
        'landcolor': '#f3f4f6',
        'countrycolor': '#e5e7eb',
        'countrywidth': 0.3,
    },
    height=600,
    margin={'l': 10, 'r': 10, 't': 80, 'b': 10},
    paper_bgcolor='rgba(0,0,0,0)',
)

# Standalone HTML page; plotly.js is loaded from the CDN rather than embedded.
fig_int.write_html(
    os.path.join(output_dir, 'intensity_map.html'),
    include_plotlyjs='cdn',
    config={'displayModeBar': True, 'displaylogo': False},
)
print("   ✓ intensity_map.html")

# ============================================================================
# 8. DECOMPOSED BAR CHART — TOP 20 BY DIVERSITY
# ============================================================================

print("\n8. Creating decomposition chart...")

# The 20 highest diversity scores, sorted ascending: Plotly draws the first
# y category at the bottom of a horizontal bar chart, so the most diversified
# country ends up at the top.
top20 = (
    result.nlargest(20, 'Diversity_Normalized')
    .sort_values('Diversity_Normalized', ascending=True)
)

fig_bar = go.Figure()

# Fixed colour per resource category.
colors = {
    'Fossil fuels': '#3b82f6',
    'Base metals': '#6b7280',
    'Precious metals': '#f59e0b',
    'Battery/strategic': '#10b981',
    'Industrial minerals': '#ef4444'
}

# One stacked bar segment per category; a category that no country reached
# has no column in the table and is skipped.
for cat in CATEGORY_ORDER:
    if cat in top20.columns:
        fig_bar.add_trace(go.Bar(
            name=cat,
            y=top20['Country Name'],
            x=top20[cat],
            orientation='h',
            marker_color=colors[cat],
            hovertemplate=(
                f'<b>%{{y}}</b><br>{cat}: %{{x:.1%}}'
                f'<extra></extra>'
            )
        ))

fig_bar.update_layout(
    barmode='stack',
    title=dict(
        text=(f'Resource Portfolio Composition ({REFERENCE_YEAR})<br>'
              f'<sup>Top 20 most diversified countries — '
              f'share of domestic production by category</sup>'),
        x=0.5, font=dict(size=16)
    ),
    xaxis=dict(title='Share of domestic resource production', tickformat='.0%'),
    # Order the bars by diversity score. Ordering by stacked total would be
    # meaningless here: every bar is a set of shares that sums to 100%.
    yaxis=dict(
        categoryorder='array',
        categoryarray=top20['Country Name'].tolist()
    ),
    legend=dict(
        orientation='h', yanchor='bottom', y=1.02,
        xanchor='center', x=0.5
    ),
    height=700,
    margin=dict(l=150),
    template='plotly_white'
)

# Standalone HTML page; plotly.js is loaded from the CDN rather than embedded.
fig_bar.write_html(
    os.path.join(output_dir, 'diversity_decomposed.html'),
    config={'displayModeBar': True, 'displaylogo': False},
    include_plotlyjs='cdn'
)
print("   ✓ diversity_decomposed.html")

# ============================================================================
# 9. SCATTER — DIVERSITY vs INTENSITY
# ============================================================================

print("\n9. Creating diversity vs intensity scatter...")

# The y-axis is logarithmic, and a log axis has no position for zero. Missing
# intensities were filled with 0 earlier, so restrict the plot to strictly
# positive values and report how many countries are left out instead of
# letting them vanish silently.
scatter_rows = result[result['Resource_Intensity_Pct'] > 0]
n_omitted = len(result) - len(scatter_rows)
if n_omitted:
    print(f"   Note: {n_omitted} countries with zero intensity omitted (log axis)")

fig_scatter = go.Figure()

fig_scatter.add_trace(go.Scatter(
    x=scatter_rows['Diversity_Normalized'],
    y=scatter_rows['Resource_Intensity_Pct'],
    mode='markers+text',
    text=scatter_rows['Country Code'],
    textposition='top center',
    textfont=dict(size=8),
    marker=dict(
        size=8, color=scatter_rows['N_Categories'],
        colorscale='Viridis', showscale=True,
        colorbar=dict(title='N categories', len=0.7, thickness=15),
        line=dict(width=0.5, color='white')
    ),
    # Hover payload per point: [country name, category count, intensity %].
    customdata=np.stack([
        scatter_rows['Country Name'],
        scatter_rows['N_Categories'],
        scatter_rows['Resource_Intensity_Pct']
    ], axis=-1),
    hovertemplate=(
        '<b>%{customdata[0]}</b><br>'
        'Diversity: %{x:.3f}<br>'
        # Two decimals instead of the raw, full-precision float.
        'Intensity: %{customdata[2]:.2f}% of GDP<br>'
        'Categories: %{customdata[1]}/5'
        '<extra></extra>'
    )
))

# Log y-axis for readability
fig_scatter.update_yaxes(type='log', title='Resource Intensity (% of GDP, log scale)')
fig_scatter.update_xaxes(title='Resource Diversity (normalized entropy, 0–1)')

fig_scatter.update_layout(
    title=dict(
        text=(f'Resource Diversity vs Intensity ({REFERENCE_YEAR})<br>'
              f'<sup>Each dot = one country. Color = number of active categories.</sup>'),
        x=0.5, font=dict(size=16)
    ),
    height=650,
    template='plotly_white',
    margin=dict(l=80, r=30, t=80, b=60)
)

# Standalone HTML page; plotly.js is loaded from the CDN rather than embedded.
fig_scatter.write_html(
    os.path.join(output_dir, 'diversity_vs_intensity.html'),
    config={'displayModeBar': True, 'displaylogo': False},
    include_plotlyjs='cdn'
)
print("   ✓ diversity_vs_intensity.html")

# ============================================================================
# 10. SAVE
# ============================================================================

print("\n10. Saving data...")

# Identification and score columns first, followed by whichever category-share
# columns actually exist in the table (in category display order).
output_cols = [
    'Country Name', 'Country Code', 'N_Categories',
    'Shannon_Entropy', 'Diversity_Normalized',
    'Total_Resource_Value', 'GDP', 'Resource_Intensity_Pct',
] + [cat for cat in CATEGORY_ORDER if cat in result.columns]

# Write the selected columns without the row index.
scores_path = os.path.join(output_dir, 'diversity_intensity_scores.csv')
result[output_cols].to_csv(scores_path, index=False)
print("   ✓ diversity_intensity_scores.csv")

# ============================================================================
# SUMMARY
# ============================================================================

# Closing banner.
print("\n" + "=" * 70, "COMPLETE", "=" * 70, sep="\n")

# Methodology recap and list of generated files. The lines are joined with
# newlines; the two "    " entries are whitespace-only spacer lines.
print("\n".join([
    "",
    "METHODOLOGY:",
    "  Feature 1 — Resource Diversity",
    "    • Group 20 resources into 5 categories",
    f"    • Filter: only count resources where country holds >{MARKET_SHARE_THRESHOLD}% global share",
    "    • Compute each category's share of the country's total production",
    "    • Shannon entropy: H = -Σ p_i · ln(p_i)",
    "    • Normalize to [0, 1] by dividing by ln(5)",
    "    ",
    "  Feature 2 — Resource Intensity",
    "    • Sum production value of all resources passing the threshold",
    "    • Divide by GDP",
    "    ",
    "  Threshold filters out negligible production so that a country",
    "  producing trace amounts of many minerals doesn't appear diversified.",
    "",
    "OUTPUTS:",
    "  1. diversity_map.html          — Choropleth of normalized entropy",
    "  2. intensity_map.html          — Choropleth of production/GDP",
    "  3. diversity_decomposed.html   — Stacked bar: portfolio composition",
    "  4. diversity_vs_intensity.html — Scatter: diversity × intensity",
    "  5. diversity_intensity_scores.csv",
    "",
]))
print("=" * 70)

RESOURCE DIVERSITY & INTENSITY INDEX

1. Loading data...
   Resources in data: 20
   Countries in data: 126

2. Computing market shares (threshold: >0.5%)...
   Pairs before threshold: 641
   Pairs after threshold:  311

3. Computing Resource Diversity (Shannon entropy)...
   Max possible entropy: 1.609 (= ln(5))
   Countries with scores: 84

4. Computing Resource Intensity (production / GDP)...
   GDP computed as: GDP per capita (PPP) × Population
   Countries with GDP data: 126
   Countries with intensity: 83

5. Merging features...
   Final dataset: 83 countries

   Diversity range:  0.000 - 0.772
   Intensity range:  0.00% - 15.39%

   Top 10 by Diversity (normalized entropy):
       1. South Africa: 0.772 (5 categories)
       2. Australia: 0.766 (5 categories)
       3. Morocco: 0.673 (3 categories)
       4. Turkiye: 0.655 (4 categories)
       5. Congo, Dem. Rep.: 0.566 (3 categories)
       6. Uzbekistan: 0.562 (4 categories)
       7. China: 0.554 (5 categories)
       8. Mex