# In [None]:
"""
IMPROVED CLUSTERING ANALYSIS
- Train Once, Predict Always (temporal consistency)
- Automated Cluster Profiling (radar charts)
- Country Movement Tracking Over Time
"""

import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import os
import plotly.express as px


# ============================================================================
# CONFIGURATION - UPDATE THESE PATHS
# ============================================================================

# Input datasets and the directory that receives every generated artifact.
# NOTE: these are absolute, machine-specific paths; edit them before running
# the script anywhere else.
input_file = "/Users/leoss/Desktop/GitHub/Capstone/MASTER/Master.csv"
production_file = "/Users/leoss/Desktop/GitHub/Capstone/MASTER/NaturalResource.csv"
output_dir = "/Users/leoss/Desktop/Portfolio/Website-/capstone_visualizations/individual_plots"

# Create the output folder up front; exist_ok keeps re-runs from failing when
# the folder is already there.
os.makedirs(output_dir, exist_ok=True)

# Start-of-run banner, emitted as a single multi-line print.
print(
    "=" * 70,
    "IMPROVED CLUSTERING ANALYSIS",
    "Train Once, Predict Always + Temporal Tracking",
    "=" * 70,
    sep="\n",
)

# ============================================================================
# 1. DATA PREPARATION (same as your original)
# ============================================================================

print("\n1. Loading and preparing data...")

# Read the production table first, then the master table (same order as the
# variables on the left-hand side).
df_prod, df_master = (
    pd.read_csv(csv_path) for csv_path in (production_file, input_file)
)

# Categorize resources
def categorize_resource(resource):
    """Map a raw resource label onto one of four reporting categories.

    'Oil', 'Natural Gas' and 'Coal' are returned unchanged; any other value
    (including unrecognised or missing labels) is reported as 'Metals'.
    """
    for fuel in ('Oil', 'Natural Gas', 'Coal'):
        if resource == fuel:
            return fuel
    return 'Metals'

# Tag every production record with its coarse resource category.
df_prod['Resource_Category'] = df_prod['Resource'].apply(categorize_resource)

# Total production value per country, year and category.
category_totals = df_prod.groupby(
    ['Country Name', 'Year', 'Resource_Category']
)['Production_TotalValue'].sum()
prod_agg = category_totals.reset_index()

# Reshape to one row per country-year with one column per category;
# combinations with no production are filled with 0.
merge_keys = ['Country Name', 'Year']
prod_wide = pd.pivot_table(
    prod_agg,
    index=merge_keys,
    columns='Resource_Category',
    values='Production_TotalValue',
    fill_value=0,
).reset_index()

# 'Total' is the row-wise sum over every category column.
resource_cols = [name for name in prod_wide.columns if name not in merge_keys]
prod_wide['Total'] = prod_wide[resource_cols].sum(axis=1)

# Keep only the master-table columns needed downstream.
master_columns = [
    'Country Name', 'Year', 'Country Code', 'Population',
    'GDP per capita (constant prices, PPP)',
    'Economic Complexity Index', 'Human capital index', 'Manufacturing',
]
master_data = df_master[master_columns].copy()

# Inner join: only country-years present in both tables survive.
map_data = pd.merge(prod_wide, master_data, on=merge_keys, how='inner')

# Total GDP = per-capita GDP x population.
gdp_per_capita = map_data['GDP per capita (constant prices, PPP)']
map_data['GDP_total'] = gdp_per_capita * map_data['Population']

# Express each available production column as a percentage of total GDP.
for res in ('Total', 'Oil', 'Natural Gas', 'Coal', 'Metals'):
    if res not in map_data.columns:
        continue
    share_of_gdp = map_data[res] / map_data['GDP_total']
    map_data[f'{res}_GDP_Norm'] = share_of_gdp * 100

print(f"   Data loaded: {len(map_data)} country-years")
print(f"   Countries: {map_data['Country Code'].nunique()}")
print(f"   Years: {map_data['Year'].min()} - {map_data['Year'].max()}")

# ============================================================================
# 2. TRAIN ONCE ON REFERENCE YEAR (2019)
# ============================================================================

REFERENCE_YEAR = 2019
N_CLUSTERS = 6

# The log lines interpolate REFERENCE_YEAR so they cannot drift from the year
# that is actually used for fitting.
print(f"\n2. Training clustering model on reference year ({REFERENCE_YEAR})...")

# Clustering features: the four resource columns expressed as % of GDP plus
# the two development indices taken from the master table.
feature_cols = [
    'Metals_GDP_Norm',
    'Oil_GDP_Norm',
    'Natural Gas_GDP_Norm',
    'Coal_GDP_Norm',
    'Economic Complexity Index',
    'Human capital index'
]

# Reference-year rows with every feature present. The explicit .copy() after
# dropna() makes ref_clean an independent frame, so adding the 'Cluster'
# column below does not raise pandas' SettingWithCopyWarning.
ref_data = map_data[map_data['Year'] == REFERENCE_YEAR].copy()
ref_clean = ref_data[['Country Code', 'Country Name'] + feature_cols].dropna().copy()

print(f"   Reference year countries: {len(ref_clean)}")

# Fail early with an actionable message: KMeans needs at least N_CLUSTERS
# samples and silhouette_score needs more samples than clusters.
if len(ref_clean) <= N_CLUSTERS:
    raise ValueError(
        f"Reference year {REFERENCE_YEAR} has only {len(ref_clean)} complete "
        f"rows; more than {N_CLUSTERS} are required to fit and score "
        f"{N_CLUSTERS} clusters."
    )

# Fit scaler on reference year ONLY
scaler_ref = StandardScaler()
X_ref_scaled = scaler_ref.fit_transform(ref_clean[feature_cols])

# Fit KMeans on reference year ONLY
kmeans_ref = KMeans(
    n_clusters=N_CLUSTERS,
    random_state=42,
    n_init=50,  # More initializations for stability
    max_iter=500
)
kmeans_ref.fit(X_ref_scaled)

# Assign clusters to reference year (labels_ is in the same row order as
# ref_clean, which is what was passed to fit)
ref_clean['Cluster'] = kmeans_ref.labels_

# Calculate silhouette score for reference year
sil_score = silhouette_score(X_ref_scaled, kmeans_ref.labels_)
print(f"   Silhouette score ({REFERENCE_YEAR}): {sil_score:.3f}")

# ============================================================================
# 3. ANALYZE CLUSTER CENTROIDS TO ASSIGN MEANINGFUL NAMES
# ============================================================================

print("\n3. Analyzing cluster profiles...")

# Undo the standardisation so the centroids read in the features' own units.
centroids_scaled = kmeans_ref.cluster_centers_
centroids_original = scaler_ref.inverse_transform(centroids_scaled)

# One row per centroid, tagged with its cluster id (row position == id).
centroids_df = pd.DataFrame(
    centroids_original, columns=feature_cols
).assign(Cluster=range(N_CLUSTERS))

print(
    "\n   Cluster Centroids (original scale):",
    centroids_df.round(3).to_string(),
    sep="\n",
)

# Automatic naming based on centroid characteristics
def auto_name_cluster(row):
    """Derive a descriptive label for a cluster from its centroid values.

    `row` must support lookup by the six feature names. The rules are checked
    in order and the first one that matches decides the label; if none
    matches, 'Low-Income Resource Producers' is returned.
    """
    complexity = row['Economic Complexity Index']
    human_capital = row['Human capital index']
    oil_share = row['Oil_GDP_Norm']
    gas_share = row['Natural Gas_GDP_Norm']
    metals_share = row['Metals_GDP_Norm']
    coal_share = row['Coal_GDP_Norm']

    hydrocarbon_share = oil_share + gas_share
    resource_share = oil_share + gas_share + metals_share + coal_share

    # High complexity and human capital win regardless of resource shares.
    if complexity > 0.5 and human_capital > 2.5:
        return 'Advanced Economies'

    # Heavy oil/gas dependence.
    if hydrocarbon_share > 15:
        return 'Petrostates'

    # Heavy mining dependence.
    if metals_share > 5:
        return 'Mining-Dependent'

    # Moderate complexity with a limited combined resource share.
    if complexity > -0.5 and resource_share < 10:
        return 'Emerging Markets'

    return 'Low-Income Resource Producers'

# Heuristic name for every cluster, derived from its centroid. These act as
# the fallback for any cluster id that has no hand-curated label below.
auto_cluster_names = {
    int(row['Cluster']): auto_name_cluster(row)
    for _, row in centroids_df.iterrows()
}

# Disambiguate repeated heuristic names: the second occurrence becomes
# "<name> 2", the third "<name> 3", and so on.
name_counts = {}
for c, name in auto_cluster_names.items():
    name_counts[name] = name_counts.get(name, 0) + 1
    if name_counts[name] > 1:
        auto_cluster_names[c] = f"{name} {name_counts[name]}"

# Hand-curated labels, keyed by KMeans label id.
# NOTE(review): these ids are only meaningful for the current fit (same data,
# same random_state); re-check them against the printed centroids whenever
# the inputs or the model settings change.
manual_cluster_names = {
    0: 'Advanced Economies',
    1: 'Low-Income Diversified',
    2: 'Petrostates',
    3: 'Coal & Metals Outlier',
    4: 'Mining-Dependent',
    5: 'Wealthy Hydrocarbon Exporters'
}

# Manual labels take precedence; ids without one keep their heuristic name,
# and manual entries for ids that do not exist in this fit are ignored. This
# keeps cluster_names in step with N_CLUSTERS instead of raising KeyError.
cluster_names = {
    c: manual_cluster_names.get(c, auto_name)
    for c, auto_name in auto_cluster_names.items()
}

print("\n   Cluster Names (manual labels; auto-assigned where none is defined):")
for c, name in sorted(cluster_names.items()):
    count = (ref_clean['Cluster'] == c).sum()
    print(f"   Cluster {c}: {name} ({count} countries)")

# Verify with sample countries. The sample size is its own constant rather
# than being tied to the number of clusters.
N_SAMPLE_COUNTRIES = 6
print("\n   Sample countries per cluster:")
for c in range(N_CLUSTERS):
    countries = ref_clean[ref_clean['Cluster'] == c]['Country Name'].head(N_SAMPLE_COUNTRIES).tolist()
    print(f"   Cluster {c} ({cluster_names[c]}): {', '.join(countries)}")

# ============================================================================
# 4. PREDICT CLUSTERS FOR ALL YEARS
# ============================================================================

print("\n4. Applying model to all years...")

all_years_results = []
id_cols = ['Country Code', 'Country Name', 'Year']

for year in sorted(map_data['Year'].unique()):
    # Rows for this year with every feature present. The .copy() after
    # dropna() makes year_clean an independent frame, so the column
    # assignments below do not raise pandas' SettingWithCopyWarning.
    year_clean = (
        map_data.loc[map_data['Year'] == year, id_cols + feature_cols]
        .dropna()
        .copy()
    )

    if year_clean.empty:
        continue

    # Transform using REFERENCE scaler (not fit!)
    X_year_scaled = scaler_ref.transform(year_clean[feature_cols])

    # Predict using REFERENCE model
    year_clean['Cluster'] = kmeans_ref.predict(X_year_scaled)
    year_clean['Cluster_Name'] = year_clean['Cluster'].map(cluster_names)

    all_years_results.append(year_clean)

    # Print cluster distribution for select years
    if year in [1995, 2000, 2005, 2010, 2015, 2019]:
        dist = year_clean['Cluster'].value_counts().sort_index()
        print(f"   {year}: {dict(dist)}")

# pd.concat raises an opaque "No objects to concatenate" on an empty list;
# raise the same exception type with a message that names the real cause.
if not all_years_results:
    raise ValueError(
        "No year has any row with a complete feature set; "
        "cannot build the temporal cluster assignments."
    )

# Combine all years
temporal_data = pd.concat(all_years_results, ignore_index=True)
print(f"\n   Total observations: {len(temporal_data)}")

# ============================================================================
# 5. RADAR CHART - CLUSTER PROFILES
# ============================================================================

print("\n5. Creating cluster profile radar chart...")

# Mean of every feature per cluster, computed on the reference-year rows.
profile_data = ref_clean.groupby('Cluster')[feature_cols].mean()

# Rescale each feature to [0, 1] across clusters so all radar axes share one
# radial scale. A feature that is identical in every cluster is drawn at 0.5.
profile_normalized = profile_data.copy()
for col in feature_cols:
    lowest = profile_data[col].min()
    highest = profile_data[col].max()
    profile_normalized[col] = (
        (profile_data[col] - lowest) / (highest - lowest)
        if highest > lowest
        else 0.5
    )

# One colour per cluster id; the same palette is indexed by cluster id below.
colors = ['#3b82f6', '#22c55e', '#ef4444', '#a855f7', '#f59e0b', '#06b6d4']

# Axis labels, in the same order as feature_cols.
feature_labels = [
    'Metals<br>(% GDP)',
    'Oil<br>(% GDP)',
    'Natural Gas<br>(% GDP)',
    'Coal<br>(% GDP)',
    'Economic<br>Complexity',
    'Human<br>Capital'
]

fig_radar = go.Figure()

# Repeat the first axis at the end so each polygon is drawn closed.
closed_labels = feature_labels + [feature_labels[0]]

for cluster_id in range(N_CLUSTERS):
    axis_values = profile_normalized.loc[cluster_id].tolist()
    closed_values = [*axis_values, axis_values[0]]

    # Translucent fill in the cluster's own colour (hex -> rgba with 0.2 alpha).
    hex_digits = colors[cluster_id].lstrip("#")
    red, green, blue = (int(hex_digits[pos:pos + 2], 16) for pos in (0, 2, 4))

    fig_radar.add_trace(go.Scatterpolar(
        r=closed_values,
        theta=closed_labels,
        fill='toself',
        fillcolor=f'rgba({red}, {green}, {blue}, 0.2)',
        line={'color': colors[cluster_id], 'width': 2},
        name=cluster_names[cluster_id],
    ))

radial_axis = {
    'visible': True,
    'range': [0, 1],
    'tickvals': [0, 0.25, 0.5, 0.75, 1],
    'ticktext': ['Low', '', 'Med', '', 'High'],
}
radar_title = {
    'text': 'Cluster Profiles: Resource Dependency vs Development',
    'x': 0.5,
    'font': {'size': 16},
}
# Horizontal legend centred underneath the plot.
radar_legend = {
    'orientation': 'h',
    'yanchor': 'bottom',
    'y': -0.2,
    'xanchor': 'center',
    'x': 0.5,
}

fig_radar.update_layout(
    polar={'radialaxis': radial_axis},
    title=radar_title,
    legend=radar_legend,
    height=600,
    margin={'t': 80, 'b': 100},
)

radar_path = os.path.join(output_dir, 'cluster_radar_profiles.html')
fig_radar.write_html(radar_path)
print("   ✓ Radar chart saved")

# ============================================================================
# 6. TEMPORAL MOVEMENT ANALYSIS
# ============================================================================

print("\n6. Analyzing country movements over time...")

# One row per (country, year), ordered chronologically within each country.
# Duplicated country-years keep their first occurrence, which matches taking
# the first matching observation. Sorting once replaces the former pattern of
# re-filtering the whole table twice for every country.
country_obs = (
    temporal_data
    .drop_duplicates(subset=['Country Code', 'Year'], keep='first')
    .sort_values(['Country Code', 'Year'])
)

# Earliest and latest observation per country. head/tail keep the sorted row
# order, so both frames list the countries in the same (sorted) order; the
# reindex makes that alignment explicit.
by_country = country_obs.groupby('Country Code', sort=False)
first_obs = by_country.head(1).set_index('Country Code')
last_obs = by_country.tail(1).set_index('Country Code').reindex(first_obs.index)

# Cluster at first and last observation, one row per country.
movements_df = pd.DataFrame({
    'Country Code': first_obs.index.to_numpy(),
    'Country Name': first_obs['Country Name'].to_numpy(),
    'First_Year': first_obs['Year'].to_numpy(),
    'Last_Year': last_obs['Year'].to_numpy(),
    'Initial_Cluster': first_obs['Cluster'].to_numpy(),
    'Final_Cluster': last_obs['Cluster'].to_numpy(),
})
movements_df['Initial_Cluster_Name'] = movements_df['Initial_Cluster'].map(cluster_names)
movements_df['Final_Cluster_Name'] = movements_df['Final_Cluster'].map(cluster_names)
movements_df['Changed'] = movements_df['Initial_Cluster'] != movements_df['Final_Cluster']

# Per-country first/last year summary, kept under its original name and
# column order for any later code that reads it.
first_last = movements_df[['Country Code', 'First_Year', 'Last_Year', 'Country Name']].copy()

print(f"\n   Countries that changed clusters: {movements_df['Changed'].sum()} / {len(movements_df)}")

# Show interesting transitions (first 15 countries whose cluster differs
# between their first and last observation)
print("\n   Notable Transitions:")
changers = movements_df[movements_df['Changed']].copy()
for _, row in changers.head(15).iterrows():
    print(f"   {row['Country Name']}: {row['Initial_Cluster_Name']} → {row['Final_Cluster_Name']}")

# ============================================================================
# 7. SANKEY DIAGRAM - CLUSTER FLOWS
# ============================================================================

print("\n7. Creating Sankey diagram of cluster movements...")

# Count of countries for every (initial cluster, final cluster) pair.
transition_matrix = pd.crosstab(
    movements_df['Initial_Cluster_Name'],
    movements_df['Final_Cluster_Name'],
)

print("\n   Transition Matrix:", transition_matrix, sep="\n")

# Node layout: all "Initial" nodes first, then all "Final" nodes, both in
# cluster_names order. A final node's index is therefore offset by the number
# of clusters.
all_clusters = list(cluster_names.values())
source_labels = [f"{cluster} (Initial)" for cluster in all_clusters]
target_labels = [f"{cluster} (Final)" for cluster in all_clusters]
all_labels = source_labels + target_labels

# Cluster name -> hex colour, using the shared palette by position.
color_map = {name: colors[pos] for pos, name in enumerate(all_clusters)}

source_indices, target_indices, values, link_colors = [], [], [], []
final_offset = len(all_clusters)

for src_pos, src in enumerate(all_clusters):
    if src not in transition_matrix.index:
        continue
    for tgt_pos, tgt in enumerate(all_clusters):
        if tgt not in transition_matrix.columns:
            continue
        flow = transition_matrix.loc[src, tgt]
        if not flow > 0:
            continue

        source_indices.append(src_pos)
        target_indices.append(final_offset + tgt_pos)
        values.append(flow)

        # Links take the source cluster's colour; countries that stayed put
        # are drawn more opaque than those that moved.
        opacity = 0.8 if src == tgt else 0.4
        hex_digits = color_map[src].lstrip('#')
        red, green, blue = (int(hex_digits[pos:pos + 2], 16) for pos in (0, 2, 4))
        link_colors.append(f'rgba({red},{green},{blue},{opacity})')

# Strip the "(Initial)"/"(Final)" suffix to look up each node's cluster colour.
node_colors = [
    color_map.get(label.replace(' (Initial)', '').replace(' (Final)', ''), '#999')
    for label in all_labels
]

sankey_trace = go.Sankey(
    node={
        'pad': 15,
        'thickness': 20,
        'line': {'color': 'black', 'width': 0.5},
        'label': all_labels,
        'color': node_colors,
    },
    link={
        'source': source_indices,
        'target': target_indices,
        'value': values,
        'color': link_colors,
    },
)
fig_sankey = go.Figure(data=[sankey_trace])

earliest_year = movements_df["First_Year"].min()
latest_year = movements_df["Last_Year"].max()
fig_sankey.update_layout(
    title={
        'text': f'Country Cluster Transitions ({earliest_year}-{latest_year})',
        'x': 0.5,
        'font': {'size': 16},
    },
    height=600,
    font={'size': 12},
)

sankey_path = os.path.join(output_dir, 'cluster_sankey_transitions.html')
fig_sankey.write_html(sankey_path)
print("   ✓ Sankey diagram saved")

# ============================================================================
# 8. ANIMATED CLUSTER MAP OVER TIME
# ============================================================================

print("\n8. Creating animated cluster map...")

# Make sure every row carries its cluster's display name.
temporal_data['Cluster_Name'] = temporal_data['Cluster'].map(cluster_names)


def _cluster_choropleths(year_rows, with_names):
    """Build one Choropleth trace per cluster id from a single year's rows.

    Every cluster gets a trace (possibly with no locations) so that the trace
    count and order are the same in every frame. `with_names` adds the
    cluster name as the trace name, as used by the initial figure.
    """
    traces = []
    for cluster_id in range(N_CLUSTERS):
        members = year_rows[year_rows['Cluster'] == cluster_id]
        flat_color = colors[cluster_id]
        trace_args = {
            'locations': members['Country Code'],
            'z': [cluster_id] * len(members),
            'text': members['Country Name'],
            # Same colour at both ends: each trace is a single flat colour.
            'colorscale': [[0, flat_color], [1, flat_color]],
            'zmin': 0,
            'zmax': N_CLUSTERS - 1,
            'showscale': False,
            'marker': {'line': {'color': 'white', 'width': 0.5}},
            'hovertemplate': '<b>%{text}</b><br>' + cluster_names[cluster_id] + '<extra></extra>',
        }
        if with_names:
            trace_args['name'] = cluster_names[cluster_id]
        traces.append(go.Choropleth(**trace_args))
    return traces


# One animation frame per year, named by the year so slider steps can target it.
years = sorted(temporal_data['Year'].unique())
frames = [
    go.Frame(
        data=_cluster_choropleths(temporal_data[temporal_data['Year'] == year], False),
        name=str(year),
    )
    for year in years
]

# The figure opens on the first available year.
initial_year = years[0]
initial_data = temporal_data[temporal_data['Year'] == initial_year]

fig_animated = go.Figure()
for initial_trace in _cluster_choropleths(initial_data, True):
    fig_animated.add_trace(initial_trace)

fig_animated.frames = frames

# Animation controls: Play runs through the frames from the current one,
# Pause switches to an immediate, zero-duration animation.
play_button = {
    'label': '▶ Play',
    'method': 'animate',
    'args': [None, {
        'frame': {'duration': 500, 'redraw': True},
        'fromcurrent': True,
        'transition': {'duration': 200},
    }],
}
pause_button = {
    'label': '⏸ Pause',
    'method': 'animate',
    'args': [[None], {
        'frame': {'duration': 0, 'redraw': False},
        'mode': 'immediate',
        'transition': {'duration': 0},
    }],
}

# One slider step per year; each jumps straight to the frame of that name.
slider_steps = [
    {
        'args': [[str(year)], {
            'frame': {'duration': 0, 'redraw': True},
            'mode': 'immediate',
            'transition': {'duration': 0},
        }],
        'label': str(year),
        'method': 'animate',
    }
    for year in years
]
year_slider = {
    'active': 0,
    'yanchor': 'top',
    'xanchor': 'left',
    'currentvalue': {
        'prefix': 'Year: ',
        'visible': True,
        'xanchor': 'center',
        'font': {'size': 16},
    },
    'pad': {'b': 10, 't': 50},
    'len': 0.9,
    'x': 0.05,
    'y': 0,
    'steps': slider_steps,
}

fig_animated.update_layout(
    title={
        'text': 'Cluster Evolution Over Time',
        'x': 0.5,
        'font': {'size': 16},
    },
    geo={
        'showframe': False,
        'showcoastlines': True,
        'coastlinecolor': '#d1d5db',
        'projection_type': 'natural earth',
        'bgcolor': 'rgba(0,0,0,0)',
        'landcolor': '#f3f4f6',
        'countrycolor': '#d1d5db',
    },
    legend={
        'title': 'Cluster',
        'x': 0.99, 'y': 0.99,
        'xanchor': 'right', 'yanchor': 'top',
        'bgcolor': 'rgba(255,255,255,0.9)',
    },
    updatemenus=[{
        'type': 'buttons',
        'showactive': False,
        'y': 0,
        'x': 0.1,
        'xanchor': 'right',
        'yanchor': 'top',
        'buttons': [play_button, pause_button],
    }],
    sliders=[year_slider],
    height=650,
    margin={'l': 10, 'r': 10, 't': 60, 'b': 80},
)

animated_path = os.path.join(output_dir, 'cluster_map_animated.html')
fig_animated.write_html(animated_path)
print("   ✓ Animated map saved")

# ============================================================================
# 9. COUNTRY TRAJECTORY PLOT (SELECT COUNTRIES)
# ============================================================================

print("\n9. Creating country trajectory visualization...")

# Country codes to follow, with the author's expected role for each.
interesting_countries = [
    'CHN',  # China - emerging market
    'SAU',  # Saudi Arabia - petrostate
    'NOR',  # Norway - advanced but oil
    'KOR',  # South Korea - advanced
    'BRA',  # Brazil - emerging
    'NGA',  # Nigeria - oil dependent
    'CHL',  # Chile - mining
    'IND',  # India - emerging
    'DEU',  # Germany - advanced
    'ZAF'   # South Africa - mining
]

# Drop any code that has no cluster assignment at all.
available_countries = temporal_data['Country Code'].unique()
plot_countries = [code for code in interesting_countries if code in available_countries]

# All observations for the countries that will be drawn.
is_plotted = temporal_data['Country Code'].isin(plot_countries)
trajectory_data = temporal_data[is_plotted].copy()

fig_trajectory = go.Figure()

# One qualitative colour per plotted country.
country_colors = px.colors.qualitative.D3[:len(plot_countries)]

for country, line_color in zip(plot_countries, country_colors):
    rows = trajectory_data[trajectory_data['Country Code'] == country]
    country_data = rows.sort_values('Year')
    country_name = country_data['Country Name'].iloc[0]

    # The y value is the numeric cluster id; the readable cluster name
    # travels as customdata so the hover text can show it.
    hover_text = f'<b>{country_name}</b><br>Year: %{{x}}<br>Cluster: %{{customdata}}<extra></extra>'
    trace = go.Scatter(
        x=country_data['Year'],
        y=country_data['Cluster'],
        mode='lines+markers',
        name=country_name,
        line={'color': line_color, 'width': 2},
        marker={'size': 8},
        hovertemplate=hover_text,
        customdata=country_data['Cluster_Name'],
    )
    fig_trajectory.add_trace(trace)

# Label each y tick with "<id>: <cluster name>" instead of the bare id.
cluster_ticks = list(range(N_CLUSTERS))
cluster_tick_labels = [f"{i}: {cluster_names[i]}" for i in range(N_CLUSTERS)]

fig_trajectory.update_layout(
    title={
        'text': 'Country Cluster Trajectories Over Time',
        'x': 0.5,
        'font': {'size': 16},
    },
    xaxis_title='Year',
    yaxis={
        'title': 'Cluster',
        'tickmode': 'array',
        'tickvals': cluster_ticks,
        'ticktext': cluster_tick_labels,
    },
    legend={
        'orientation': 'h',
        'yanchor': 'bottom',
        'y': -0.25,
        'xanchor': 'center',
        'x': 0.5,
    },
    height=600,
    margin={'b': 120},
    template='plotly_white',
    hovermode='x unified',
)

trajectory_path = os.path.join(output_dir, 'cluster_trajectories.html')
fig_trajectory.write_html(trajectory_path)
print("   ✓ Trajectory plot saved")

# ============================================================================
# 10. SAVE DATA FILES
# ============================================================================

print("\n10. Saving data files...")

# Cluster assignment of every country in every year.
assignments_csv = os.path.join(output_dir, 'cluster_assignments_all_years.csv')
temporal_data.to_csv(assignments_csv, index=False)
print("   ✓ All years cluster assignments saved")

# First vs. last cluster per country.
movements_csv = os.path.join(output_dir, 'cluster_movements_summary.csv')
movements_df.to_csv(movements_csv, index=False)
print("   ✓ Movement summary saved")

# Centroids, with the display name attached just before writing.
centroids_df['Cluster_Name'] = centroids_df['Cluster'].map(cluster_names)
centroids_csv = os.path.join(output_dir, 'cluster_centroids.csv')
centroids_df.to_csv(centroids_csv, index=False)
print("   ✓ Centroids saved")

# The transition matrix keeps its index: the row labels are the initial clusters.
transitions_csv = os.path.join(output_dir, 'cluster_transition_matrix.csv')
transition_matrix.to_csv(transitions_csv)
print("   ✓ Transition matrix saved")

# ============================================================================
# SUMMARY
# ============================================================================

rule = "=" * 70

print("\n" + rule)
print("✅ IMPROVED CLUSTERING ANALYSIS COMPLETE!")
print(rule)

# How the model was built.
print("\nMethodology:")
for fact in (
    f"Reference year: {REFERENCE_YEAR}",
    f"Clusters: {N_CLUSTERS}",
    f"Silhouette score: {sil_score:.3f}",
    "Approach: Train Once, Predict Always",
):
    print(f"  • {fact}")

# Cluster labels together with their reference-year membership counts.
print(f"\nCluster Definitions (frozen from {REFERENCE_YEAR}):")
for c, name in sorted(cluster_names.items()):
    count = (ref_clean['Cluster'] == c).sum()
    print(f"  Cluster {c}: {name} ({count} countries)")

# Extent of the temporal assignments.
print("\nTemporal Coverage:")
for fact in (
    f"Years: {temporal_data['Year'].min()} - {temporal_data['Year'].max()}",
    f"Countries tracked: {temporal_data['Country Code'].nunique()}",
    f"Countries that changed: {movements_df['Changed'].sum()}",
):
    print(f"  • {fact}")

# Numbered list of every artifact written to the output directory.
print(f"\nOutputs saved to: {output_dir}")
artifacts = (
    ("cluster_radar_profiles.html", "Cluster characteristic profiles"),
    ("cluster_sankey_transitions.html", "Flow diagram of movements"),
    ("cluster_map_animated.html", "Animated map over time"),
    ("cluster_trajectories.html", "Individual country paths"),
    ("cluster_assignments_all_years.csv", "Full temporal data"),
    ("cluster_movements_summary.csv", "Which countries changed"),
    ("cluster_centroids.csv", "Cluster definitions"),
    ("cluster_transition_matrix.csv", "Movement counts"),
)
for number, (filename, description) in enumerate(artifacts, start=1):
    print(f"  {number}. {filename} - {description}")

print(rule)