In [None]:
!pip -q install shap plotly==5.24.1

import os, io, warnings, numpy as np, pandas as pd, matplotlib.pyplot as plt, seaborn as sns, plotly.express as px
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import shap
from google.colab import files, drive
import requests

# Notebook-wide defaults: silence warnings, default figure size with grid
# lines, seaborn "whitegrid" theme (applied after the rcParams, as before).
warnings.filterwarnings("ignore")
plt.rcParams.update({"figure.figsize": (10, 5), "axes.grid": True})
sns.set_theme(style="whitegrid")

# Central configuration shared by every step of the notebook.
CFG = dict(
    # Column the model predicts.
    target_col="Crop_Yield_MT_per_HA",
    # Candidate numeric / categorical columns; only those actually present
    # in the loaded data are used.
    num_cols_guess=[
        "Average_Temperature_C", "Total_Precipitation_mm", "CO2_Emissions_MT",
        "Crop_Yield_MT_per_HA", "Extreme_Weather_Events", "Irrigation_Access_%",
        "Pesticide_Use_KG_per_HA", "Fertilizer_Use_KG_per_HA",
        "Soil_Health_Index", "Economic_Impact_Million_USD",
    ],
    cat_cols_guess=["Country", "Region", "Crop_Type", "Adaptation_Strategies", "Year"],
    # Maximum number of rows fed to the seaborn pairplot.
    pairplot_sample=1200,
    # Seed shared by row sampling, the train/validation split and the model.
    random_state=42,
    # Directory every saved figure is written to.
    output_dir="/content/outputs",
)
os.makedirs(CFG["output_dir"], exist_ok=True)

def load_data_colab(timeout=60):
    """Interactively load a CSV in Colab from an upload, Google Drive, or a URL.

    Prompts for the source (default "1") and returns the parsed DataFrame.

    Args:
        timeout: seconds to wait for the HTTP request when loading from a URL.

    Returns:
        pandas.DataFrame parsed from the chosen CSV.

    Raises:
        ValueError: on an unrecognised choice, or when no file was uploaded.
        requests.HTTPError: when the URL download returns an error status.
    """
    print("1) Upload CSV  2) Google Drive  3) URL")
    choice = input("Choose [1/2/3] (default=1): ").strip() or "1"
    if choice == "1":
        uploaded = files.upload()
        if not uploaded:
            # Cancelled upload dialog returns an empty dict.
            raise ValueError("No file uploaded")
        fname = next(iter(uploaded))  # only the first uploaded file is used
        return pd.read_csv(io.BytesIO(uploaded[fname]))
    if choice == "2":
        drive.mount('/content/drive', force_remount=True)
        path = input("CSV path: ").strip()
        return pd.read_csv(path)
    if choice == "3":
        url = input("Direct CSV URL: ").strip()
        # A timeout keeps an unresponsive server from hanging the notebook.
        r = requests.get(url, timeout=timeout)
        r.raise_for_status()
        # Hand pandas the raw bytes (same as the upload branch): r.text
        # decodes with the header charset, and requests falls back to
        # ISO-8859-1 for text/* responses without one, garbling UTF-8 CSVs.
        return pd.read_csv(io.BytesIO(r.content))
    raise ValueError("Invalid choice")

def robust_numeric_cast(df, cols):
    """Coerce each listed column that exists in *df* to a numeric dtype, in place.

    Values that cannot be parsed become NaN; names absent from *df* are
    ignored. Returns the same DataFrame object so calls can be chained.
    """
    present = [name for name in cols if name in df.columns]
    for name in present:
        df[name] = pd.to_numeric(df[name], errors="coerce")
    return df

def _normalize_col_name(name):
    """Return *name* stripped, with spaces and hyphens -> "_" and "%" -> "_pct"."""
    return str(name).strip().replace(" ", "_").replace("%", "_pct").replace("-", "_")


def clean_dataframe(df):
    """Tidy column names, coerce dtypes, clip outliers and impute missing values.

    Steps, in order:
      1. Normalise column names (assigns ``df.columns`` on the argument).
      2. Coerce CFG's numeric columns with ``pd.to_numeric`` (unparseable
         values -> NaN); "Year" becomes nullable Int64.
      3. Replace +/-inf with NaN and drop rows whose target is missing.
      4. Clip each numeric column to [Q1 - 3*IQR, Q3 + 3*IQR], then fill its
         NaNs with the (post-clip) median.
      5. Cast categorical columns to pandas "string" dtype, missing -> "Unknown".

    The column guesses in CFG are normalised with the same rule as the data's
    headers, so a guess like "Irrigation_Access_%" matches the renamed column
    "Irrigation_Access__pct" instead of being silently skipped.

    Args:
        df: raw DataFrame; pass a copy if the original must stay untouched.

    Returns:
        The cleaned DataFrame.
    """
    df.columns = [_normalize_col_name(c) for c in df.columns]
    num_cols = [c for c in map(_normalize_col_name, CFG["num_cols_guess"]) if c in df.columns]
    cat_cols = [c for c in map(_normalize_col_name, CFG["cat_cols_guess"]) if c in df.columns]
    df = robust_numeric_cast(df, num_cols)
    if "Year" in df.columns:
        df["Year"] = pd.to_numeric(df["Year"], errors="coerce").astype("Int64")
    df.replace([np.inf, -np.inf], np.nan, inplace=True)
    ycol = CFG["target_col"]
    if ycol in df.columns:
        # .copy() so the column assignments below act on an independent frame,
        # not a filtered slice of the caller's data.
        df = df[df[ycol].notna()].copy()
    # Columns are independent, so clipping and imputing one column at a time
    # is equivalent to doing all clips first and then all fills.
    for c in num_cols:
        q1, q3 = df[c].quantile([0.25, 0.75])
        iqr = q3 - q1
        df[c] = df[c].clip(lower=q1 - 3 * iqr, upper=q3 + 3 * iqr)
        df[c] = df[c].fillna(df[c].median())
    for c in cat_cols:
        df[c] = df[c].astype("string").fillna("Unknown")
    return df

def quick_glance(df):
    """Print the shape, dtypes and up to 20 largest missing-value counts, then display describe().T."""
    missing_top = df.isna().sum().sort_values(ascending=False).head(20)
    for summary in (df.shape, df.dtypes, missing_top):
        print(summary)
    # `display` is provided by the IPython/Colab runtime.
    display(df.describe().T)

def _save_and_show(filename, fig=None):
    """Lay out *fig* (default: current figure), save it as a 160-dpi PNG in
    CFG["output_dir"] and show it.

    bbox_inches="tight" grows the saved canvas to include artists placed
    outside the figure area (e.g. a suptitle at y=1.02), which a plain
    savefig would crop out of the PNG.
    """
    fig = plt.gcf() if fig is None else fig
    fig.tight_layout()
    fig.savefig(f'{CFG["output_dir"]}/{filename}', dpi=160, bbox_inches="tight")
    plt.show()


def plot_summaries(df):
    """Draw the exploratory charts that *df*'s columns support.

    Each chart is skipped unless the columns it needs are present. Every
    chart is saved as a PNG in CFG["output_dir"] and shown inline.
    """
    ycol = CFG["target_col"]
    cols = set(df.columns)

    if {"Region", "Average_Temperature_C"} <= cols:
        plt.figure(figsize=(12, 5))
        sns.barplot(x="Region", y="Average_Temperature_C", data=df, estimator="mean", errorbar=None)
        plt.title("Average Temperature by Region")
        plt.xticks(rotation=45, ha="right")
        _save_and_show("avg_temp_by_region.png")

    if {"Year", ycol} <= cols:
        plt.figure(figsize=(12, 5))
        sns.lineplot(x="Year", y=ycol, data=df, errorbar=None)
        plt.title("Crop Yield Over Years")
        _save_and_show("yield_over_years.png")

    if "Crop_Type" in cols:
        plt.figure(figsize=(7, 7))
        df["Crop_Type"].value_counts().plot.pie(autopct="%.1f%%", startangle=140)
        plt.title("Crop Type Share")
        plt.ylabel("")
        _save_and_show("crop_type_pie.png")

    if "CO2_Emissions_MT" in cols:
        plt.figure(figsize=(10, 5))
        sns.histplot(df["CO2_Emissions_MT"], bins=40, kde=True)
        plt.title("CO2 Emissions Distribution")
        _save_and_show("co2_hist.png")

    if {"Country", ycol} <= cols:
        top_countries = df["Country"].value_counts().head(20).index
        plt.figure(figsize=(12, 5))
        sns.boxplot(x="Country", y=ycol, data=df[df["Country"].isin(top_countries)])
        plt.title("Crop Yield by Country (Top 20)")
        plt.xticks(rotation=60, ha="right")
        _save_and_show("yield_by_country.png")

    # Correlations only over numeric columns with more than 5 distinct values.
    num_cols = [c for c in df.select_dtypes(include=[np.number]).columns if df[c].nunique() > 5]
    if len(num_cols) >= 2:
        plt.figure(figsize=(12, 10))
        sns.heatmap(df[num_cols].corr(), annot=True, cmap="coolwarm", fmt=".2f")
        plt.title("Correlation Matrix")
        _save_and_show("corr_heatmap.png")

    if {"CO2_Emissions_MT", ycol} <= cols:
        plt.figure(figsize=(10, 5))
        sns.scatterplot(x="CO2_Emissions_MT", y=ycol, data=df, alpha=0.6)
        plt.title("CO2 vs Yield")
        _save_and_show("co2_vs_yield.png")

    # Names here are the post-cleaning forms ("%" already mapped to "_pct").
    core_cols = ["Average_Temperature_C", "Total_Precipitation_mm", "CO2_Emissions_MT",
                 ycol, "Extreme_Weather_Events", "Irrigation_Access__pct",
                 "Pesticide_Use_KG_per_HA", "Fertilizer_Use_KG_per_HA", "Soil_Health_Index"]
    core_cols = [c for c in core_cols if c in cols]
    if len(core_cols) >= 3:
        sample = df[core_cols]
        # Pairplots are quadratic in columns and slow on many rows: sample.
        if len(sample) > CFG["pairplot_sample"]:
            sample = sample.sample(CFG["pairplot_sample"], random_state=CFG["random_state"])
        g = sns.pairplot(sample, diag_kind="hist", corner=True)
        # `.figure` replaces the deprecated `FacetGrid.fig` alias.
        g.figure.suptitle("Pairwise (sampled)", y=1.02)
        _save_and_show("pairplot.png", g.figure)

    if {"Region", "Economic_Impact_Million_USD"} <= cols:
        plt.figure(figsize=(12, 5))
        sns.boxplot(x="Region", y="Economic_Impact_Million_USD", data=df)
        plt.title("Economic Impact by Region")
        plt.xticks(rotation=45, ha="right")
        _save_and_show("economic_impact_by_region.png")

def quick_modeling(df, max_ohe_levels=30):
    """Fit a baseline random-forest regressor on CFG["target_col"] and report metrics.

    Non-numeric columns with at most *max_ohe_levels* distinct values are
    one-hot encoded (drop_first=True). Any other non-numeric column is
    dropped. An 80/20 split seeded by CFG["random_state"] is used for
    validation. R2/MAE/RMSE are printed, and the top-25 feature importances
    are plotted and saved to CFG["output_dir"].

    Args:
        df: cleaned DataFrame.
        max_ohe_levels: cardinality cutoff for one-hot encoding.

    Returns:
        (model, X_val) on success; None when the target column is missing
        or no usable feature column remains.
    """
    ycol = CFG["target_col"]
    if ycol not in df.columns:
        print("Target not found")
        return None

    non_numeric = [c for c in df.select_dtypes(exclude=[np.number]).columns if c != ycol]
    cat_for_ohe = [c for c in non_numeric if df[c].nunique() <= max_ohe_levels]
    data_enc = pd.get_dummies(df, columns=cat_for_ohe, drop_first=True) if cat_for_ohe else df.copy()

    y = data_enc[ycol].astype(float)
    # Keep only columns the forest can consume: numbers plus get_dummies' bool
    # indicators. Filtering on dtype == "object" alone misses pandas' "string"
    # dtype (what clean_dataframe produces), so high-cardinality text columns
    # used to reach fit() and crash it.
    X = data_enc.drop(columns=[ycol]).select_dtypes(include=[np.number, "bool"])
    if X.shape[1] == 0:
        print("No usable feature columns")
        return None

    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=CFG["random_state"])
    model = RandomForestRegressor(n_estimators=400, min_samples_leaf=2, random_state=CFG["random_state"], n_jobs=-1)
    model.fit(X_train, y_train)
    preds = model.predict(X_val)

    r2 = r2_score(y_val, preds)
    mae = mean_absolute_error(y_val, preds)
    # RMSE as sqrt(MSE): `mean_squared_error(..., squared=False)` was
    # deprecated in scikit-learn 1.4 and removed in 1.6.
    rmse = float(np.sqrt(mean_squared_error(y_val, preds)))
    print(f"R2: {r2:.3f}  MAE: {mae:.3f}  RMSE: {rmse:.3f}")

    fi = pd.Series(model.feature_importances_, index=X_train.columns).sort_values(ascending=False).head(25)
    plt.figure(figsize=(10, 8))
    sns.barplot(x=fi.values, y=fi.index)
    plt.title("Feature Importances")
    plt.tight_layout()
    plt.savefig(f'{CFG["output_dir"]}/feature_importance_rf.png', dpi=160)
    plt.show()
    return model, X_val

def explain_with_shap(model, X_sample, max_points=400):
    """Render and save SHAP bar and beeswarm summaries for a fitted tree model.

    At most *max_points* rows (a seeded random sample) are explained. This is
    best-effort: any failure is printed and swallowed so the notebook goes on.
    """
    try:
        subset = (
            X_sample
            if len(X_sample) <= max_points
            else X_sample.sample(max_points, random_state=CFG["random_state"])
        )
        values = shap.TreeExplainer(model).shap_values(subset)
        # (extra summary_plot kwargs, axes title, output file name)
        plot_specs = (
            ({"plot_type": "bar"}, "SHAP Summary (bar)", "shap_summary_bar.png"),
            ({}, "SHAP Summary (beeswarm)", "shap_summary_beeswarm.png"),
        )
        for extra_kwargs, title, fname in plot_specs:
            shap.summary_plot(values, subset, show=False, max_display=20, **extra_kwargs)
            plt.title(title)
            plt.tight_layout()
            plt.savefig(f'{CFG["output_dir"]}/{fname}', dpi=160)
            plt.show()
    except Exception as err:
        print("SHAP skipped:", err)

def plotly_extras(df):
    """Show interactive Plotly charts for whichever columns are present.

    * Mean target per (Year, Region), one line per region. Rows are averaged
      first, matching the seaborn lineplot's mean estimator; the data may hold
      several rows per year/region, and plotting them raw makes each line
      zig-zag through every row.
    * Target vs CO2 scatter coloured by country with per-country OLS
      trendlines. Plotly needs statsmodels for "ols", so without it the
      scatter is drawn without trendlines instead of raising.
    """
    ycol = CFG["target_col"]
    cols = set(df.columns)
    if {"Year", "Region", ycol} <= cols:
        # groupby sorts by its keys, so each region's points are in Year order.
        yearly = df.groupby(["Year", "Region"], as_index=False)[ycol].mean()
        fig = px.line(yearly, x="Year", y=ycol, color="Region", title="Yield over Years by Region")
        fig.show()
    if {"CO2_Emissions_MT", "Country", ycol} <= cols:
        scatter_kwargs = dict(x="CO2_Emissions_MT", y=ycol, color="Country", opacity=0.6,
                              title="CO2 vs Yield (by Country)")
        try:
            fig = px.scatter(df, trendline="ols", **scatter_kwargs)
        except ImportError:
            fig = px.scatter(df, **scatter_kwargs)
        fig.show()

# --- Pipeline: load -> clean -> inspect -> plot -> model -> explain ---
df_raw = load_data_colab()
# clean_dataframe renames columns on the frame it receives, so pass a copy
# to keep df_raw as loaded.
df = clean_dataframe(df_raw.copy())
print("Columns:", df.columns.tolist())
quick_glance(df)
plot_summaries(df)

out = quick_modeling(df)
if out is not None:
    # On success quick_modeling returns (fitted model, validation features).
    model, X_val = out
    explain_with_shap(model, X_val)

plotly_extras(df)
print(f"Saved to {CFG['output_dir']}")
