# 06 — Scenario Analysis

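**Purpose:** Train a chosen model (Ridge by default, or a random forest) and run what-if scenarios in two modes — local sensitivity at each country's latest year *t*, and a next-year (*t+1*) projection — **recomputing logs/squares and interactions** so that deltas reflect each country's own baseline (heterogeneity across countries).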

**Inputs:** `../data/processed/countries_features.csv` (relative to the notebook's working directory)

**Outputs (CSV only; `<model>` is `ridge` or `rf`):**
- Local: `./reports/scenarios/scenario_<name>_<model>.csv`
- t+1:   `./reports/scenarios/scenario_<name>_tplus1_<model>.csv`

**Assumptions:** Countries only; same feature set as notebook 04.

## 1) Setup & load

In [None]:
from pathlib import Path
import numpy as np, pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Ridge
from sklearn.preprocessing import StandardScaler

# Input and output locations, both resolved against the current working directory.
INP = Path("../data/processed/countries_features.csv")
OUT = Path("./reports/scenarios")
OUT.mkdir(parents=True, exist_ok=True)

df = pd.read_csv(INP)
if "year" in df.columns:
    # Nullable integer: unparseable years become <NA> instead of forcing a float column.
    df["year"] = pd.to_numeric(df["year"], errors="coerce").astype("Int64")

TARGET = "cereal_yield"
id_cols = [col for col in ("Country Name", "Country Code") if col in df.columns]

# Candidate predictors are all numeric columns (numpy or pandas-nullable dtypes) ...
numeric_like = list(df.select_dtypes(include=[np.number, "Int64", "Float64"]).columns)
# ... minus the target, its log, the time index, identifiers, and any
# "lag0_*" / "*_future" columns (presumably leakage-prone; see notebook 04).
exclude = {TARGET, "log_cereal_yield", "year", *id_cols}
exclude.update(
    col for col in numeric_like
    if col.startswith("lag0_") or col.endswith("_future")
)
FEATS = [col for col in numeric_like if col not in exclude]

print("n_features:", len(FEATS))

# Fit only on rows where the target is observed.
mask = df[TARGET].notna()
df_train = df.loc[mask].copy()

MODEL = "ridge"  # default; set to "rf" for richer heterogeneity
# Median imputation always comes first; the estimator tail depends on MODEL.
pipe = Pipeline(
    [("imputer", SimpleImputer(strategy="median"))]
    + (
        [("rf", RandomForestRegressor(n_estimators=300, random_state=42, n_jobs=-1))]
        if MODEL == "rf"
        else [
            ("scaler", StandardScaler(with_mean=True, with_std=True)),
            ("ridge", Ridge(alpha=1.0, solver="auto")),
        ]
    )
)

pipe.fit(df_train[FEATS], df_train[TARGET])
print(f"Trained baseline {MODEL.upper()} on labeled history.")

## 2) Scenario settings

In [None]:
SCENARIO_NAME = "plus0p5C_plus5pct_precip_plus10pct_fert"
SELECT_COUNTRIES = []  # country codes or names to keep; leave empty for every country

# Per-column changes. "abs" is an additive shift; "pct" is a relative change
# (0.05 means +5%). When both are given, the shift is applied before the scaling.
DELTAS = dict(
    temp_anomaly=dict(abs=0.5),
    precipitation=dict(pct=0.05),
    fertilizer_use=dict(pct=0.10),
)

print("Scenario:", SCENARIO_NAME)
print("Deltas:", DELTAS)

## 3) Helpers for transforms & interactions

In [None]:
import numpy as np, pandas as pd

def apply_change(series, change):
    """Return a float copy of ``series`` with one scenario change applied.

    ``change`` may contain ``'abs'`` (an additive shift, applied first) and/or
    ``'pct'`` (a relative change, applied as a ``1 + pct`` multiplier). Keys
    that are absent or set to ``None`` are skipped. The input is not modified.
    """
    result = series.astype(float).copy()

    offset = change.get('abs')
    if offset is not None:
        result = result + float(offset)

    ratio = change.get('pct')
    if ratio is not None:
        result = result * (1.0 + float(ratio))

    return result

def recompute_transforms(X_base, X_scn, changed_cols):
    """Rebuild derived feature columns so they agree with the raw columns.

    For each column in ``changed_cols`` that exists in the frame, the
    ``log_<col>`` companion (``log1p`` of the value clipped at 0) and the
    ``<col>_sq`` companion are rebuilt in both frames, when present. The fixed
    temperature/precipitation/fertilizer interaction columns are then rebuilt
    whenever the interaction column and both inputs are present.

    Parameters
    ----------
    X_base, X_scn : pandas.DataFrame
        Baseline and scenario feature frames. They are expected to share
        columns; presence is checked on ``X_base``.
    changed_cols : iterable of str
        Raw columns whose log/square companions should be rebuilt.

    Returns
    -------
    tuple of (X_base, X_scn)
        The same frames, modified in place.
    """
    def has(col):
        return col in X_base.columns

    # Logs & squares
    for col in changed_cols:
        # Without the raw column there is nothing to rebuild from. Indexing it
        # would raise KeyError, so leave any existing derived columns as they are.
        if not has(col):
            continue
        base_raw = X_base[col].astype(float)
        scn_raw = X_scn[col].astype(float)
        if has(f"log_{col}"):
            X_base[f"log_{col}"] = np.log1p(np.clip(base_raw, 0, None))
            X_scn[f"log_{col}"] = np.log1p(np.clip(scn_raw, 0, None))
        if has(f"{col}_sq"):
            X_base[f"{col}_sq"] = base_raw ** 2
            X_scn[f"{col}_sq"] = scn_raw ** 2

    # Interactions: recomputed from the (possibly shifted) raw inputs in each frame.
    inter_defs = [
        ("temp_anomaly", "precipitation", "tempXprecip"),
        ("temp_anomaly", "fertilizer_use", "tempXfertilizer"),
        ("precipitation", "fertilizer_use", "precipXfertilizer"),
    ]
    for a, b, name in inter_defs:
        if has(name) and has(a) and has(b):
            X_base[name] = X_base[a].astype(float) * X_base[b].astype(float)
            X_scn[name] = X_scn[a].astype(float) * X_scn[b].astype(float)

    return X_base, X_scn

## 4) Local sensitivity (t)

In [None]:
df_latest = df.copy()
if SELECT_COUNTRIES:
    # Match codes or names against whichever ID columns exist. Indexing a
    # missing ID column directly would raise KeyError.
    mask_sel = pd.Series(False, index=df_latest.index)
    for id_col in ('Country Code', 'Country Name'):
        if id_col in df_latest.columns:
            mask_sel |= df_latest[id_col].isin(SELECT_COUNTRIES)
    df_latest = df_latest.loc[mask_sel].copy()

# Keep one row per country: its last row after sorting by year. Group by name
# when available, otherwise by code.
grp_key = 'Country Name' if 'Country Name' in df_latest.columns else 'Country Code'
df_latest = (df_latest
             .sort_values([grp_key, 'year'])
             .groupby(grp_key, as_index=False)
             .tail(1)
             .reset_index(drop=True))

X_base = df_latest[FEATS].copy()
X_scn = X_base.copy()

# Shift only the scenario copy; the baseline keeps the observed values.
for col, ch in DELTAS.items():
    if col in X_scn.columns:
        X_scn[col] = apply_change(X_scn[col], ch)

# Rebuild log/square/interaction columns on both copies so baseline and
# scenario derive them the same way from their raw inputs.
X_base, X_scn = recompute_transforms(X_base, X_scn, changed_cols=list(DELTAS.keys()))

meta_cols = [c for c in ['Country Name','Country Code','year'] if c in df_latest.columns]
meta = df_latest[meta_cols].copy()

y_base = pipe.predict(X_base); y_scn = pipe.predict(X_scn)
res_local = meta.copy()
res_local['y_pred_baseline'] = y_base.astype(float)
res_local['y_pred_scenario'] = y_scn.astype(float)
res_local['delta_abs'] = (res_local['y_pred_scenario'] - res_local['y_pred_baseline']).astype(float)
# Relative change; NaN where the baseline prediction is exactly zero.
res_local['delta_pct'] = np.where(res_local['y_pred_baseline'] != 0, res_local['delta_abs']/res_local['y_pred_baseline'], np.nan)

out_local = OUT / f'scenario_{SCENARIO_NAME}_{MODEL}.csv'
res_local.to_csv(out_local, index=False)
print('Saved (local):', out_local)

## 5) Next‑year (t+1)

In [None]:
def build_tplus1_matrices(df, FEATS, DELTAS, select=None, roll_window=3):
    """Build next-year (t+1) baseline and scenario feature matrices, one row per country.

    Each country's last row (year t, after sorting by year) seeds the feature
    vector. The lagged features are then rolled forward:

    * ``<b>_lag1`` is set to the latest observed value of ``b``;
    * ``<b>_roll3`` is set to the mean of the last ``roll_window`` observed
      values of ``b``;
    * ``<b>_lag1_sq`` is set to the square of the new lag.

    The scenario copy applies ``DELTAS`` to the raw columns. Logs, selected
    squares and interactions present in ``FEATS`` are then recomputed for both
    copies, so the delta reflects each country's own baseline.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel with a ``year`` column and ``Country Name`` and/or ``Country Code``.
    FEATS : list of str
        Model feature columns, in the order the fitted pipeline expects.
    DELTAS : dict
        ``{column: {"abs": shift, "pct": fraction}}``. ``abs`` is added first,
        then the value is multiplied by ``1 + pct``; ``None`` entries are skipped.
    select : list of str, optional
        Country codes or names to keep. Falsy means all countries.
    roll_window : int
        Number of trailing observed values averaged for ``*_roll3`` features.

    Returns
    -------
    meta_df : pandas.DataFrame
        Columns ``Country Name``, ``Country Code`` and ``year`` (t + 1, or None
        when t is missing).
    Xb, Xs : pandas.DataFrame
        Baseline and scenario matrices with columns exactly ``FEATS``.
    """
    import numpy as np, pandas as pd

    grpkey = "Country Name" if "Country Name" in df.columns else "Country Code"
    feat_set = set(FEATS)

    def bases_with_suffix(sfx):
        return sorted({c[:-len(sfx)] for c in FEATS if c.endswith(sfx)})

    lag1_bases = bases_with_suffix("_lag1")
    roll3_bases = bases_with_suffix("_roll3")
    # Slice affixes off rather than str.replace, which would also strip inner matches.
    lag1_sq_bases = [c[:-len("_lag1_sq")] for c in FEATS if c.endswith("_lag1_sq")]
    log_raws = sorted({c[len("log_"):] for c in FEATS if c.startswith("log_")})
    sq_raws = [r for r in ("temp_anomaly", "precipitation", "fertilizer_use")
               if f"{r}_sq" in feat_set]

    inter_defs = [
        ("temp_anomaly", "precipitation", "tempXprecip"),
        ("temp_anomaly", "fertilizer_use", "tempXfertilizer"),
        ("precipitation", "fertilizer_use", "precipXfertilizer"),
    ]
    inter_defs = [(a, b, n) for a, b, n in inter_defs if n in feat_set]

    def square(v):
        # Missing stays missing; otherwise square as a float.
        return np.nan if pd.isna(v) else float(v) ** 2

    def log1p_clipped(v):
        # Negatives are clipped to 0 before log1p. This matches the local-mode
        # transform np.log1p(np.clip(x, 0, None)), so both modes produce the
        # same log feature from the same raw value.
        return np.nan if pd.isna(v) else float(np.log1p(max(float(v), 0.0)))

    def shift(v, ch):
        # Scalar version of the scenario change: additive shift, then relative scaling.
        if pd.isna(v):
            return v
        v = float(v)
        if ch.get("abs") is not None:
            v += float(ch["abs"])
        if ch.get("pct") is not None:
            v *= 1.0 + float(ch["pct"])
        return v

    data = df.copy()
    if select:
        # Match codes or names against whichever ID columns exist. A missing
        # column must not raise.
        keep = pd.Series(False, index=data.index)
        for id_col in ("Country Code", "Country Name"):
            if id_col in data.columns:
                keep |= data[id_col].isin(select)
        data = data.loc[keep].copy()

    meta_cols = ["Country Name", "Country Code", "year"]
    rows_meta, rows_base, rows_scn = [], [], []
    # Rows are sorted by year before grouping; groupby keeps that order within each group.
    for _, gd in data.sort_values([grpkey, "year"]).groupby(grpkey):
        last = gd.iloc[-1]
        year_t = int(last["year"]) if "year" in gd.columns and not pd.isna(last["year"]) else None
        meta = {k: last.get(k, None) for k in ("Country Name", "Country Code")}
        meta["year"] = None if year_t is None else year_t + 1

        x_base = last.reindex(FEATS).to_dict()

        # Roll lag/rolling features forward one period from the observed history.
        for b in lag1_bases:
            if b in gd.columns:
                observed = gd[b].dropna()
                x_base[f"{b}_lag1"] = observed.iloc[-1] if len(observed) else np.nan
        for b in roll3_bases:
            if b in gd.columns:
                vals = gd[b].dropna().tail(roll_window).to_numpy()
                x_base[f"{b}_roll3"] = float(np.mean(vals)) if len(vals) > 0 else np.nan
        for b in lag1_sq_bases:
            lag_key = f"{b}_lag1"
            if lag_key in x_base:
                x_base[f"{b}_lag1_sq"] = square(x_base[lag_key])

        # Scenario: shift only the raw columns named in DELTAS.
        x_scn = dict(x_base)
        for col, ch in DELTAS.items():
            if col in x_scn:
                x_scn[col] = shift(x_scn[col], ch)

        # Recompute logs/squares where present.
        for raw in log_raws:
            logk = f"log_{raw}"
            if raw in x_scn and logk in feat_set:
                x_base[logk] = log1p_clipped(x_base.get(raw, np.nan))
                x_scn[logk] = log1p_clipped(x_scn.get(raw, np.nan))

        for raw in sq_raws:
            if raw in x_scn:
                x_base[f"{raw}_sq"] = square(x_base.get(raw, np.nan))
                x_scn[f"{raw}_sq"] = square(x_scn.get(raw, np.nan))

        # Interactions present in FEATS.
        for a, b, name in inter_defs:
            avb, bvb = x_base.get(a, np.nan), x_base.get(b, np.nan)
            avs, bvs = x_scn.get(a, np.nan), x_scn.get(b, np.nan)
            x_base[name] = np.nan if (pd.isna(avb) or pd.isna(bvb)) else float(avb) * float(bvb)
            x_scn[name] = np.nan if (pd.isna(avs) or pd.isna(bvs)) else float(avs) * float(bvs)

        rows_meta.append(meta)
        rows_base.append({k: x_base.get(k, np.nan) for k in FEATS})
        rows_scn.append({k: x_scn.get(k, np.nan) for k in FEATS})

    # Explicit columns keep the expected schema even when no country is selected.
    meta_df = pd.DataFrame(rows_meta, columns=meta_cols)
    Xb = pd.DataFrame(rows_base, columns=FEATS)
    Xs = pd.DataFrame(rows_scn, columns=FEATS)
    return meta_df, Xb, Xs

meta_t1, Xb_t1, Xs_t1 = build_tplus1_matrices(
    df=df,
    FEATS=FEATS,
    DELTAS=DELTAS,
    select=SELECT_COUNTRIES,
    roll_window=3,
)

# Score both matrices with the same fitted pipeline; they differ only by the scenario changes.
y_base_t1 = pipe.predict(Xb_t1)
y_scn_t1 = pipe.predict(Xs_t1)

res_t1 = meta_t1.assign(
    y_pred_baseline=y_base_t1.astype(float),
    y_pred_scenario=y_scn_t1.astype(float),
)
res_t1['delta_abs'] = (res_t1['y_pred_scenario'] - res_t1['y_pred_baseline']).astype(float)
# Relative change; NaN where the baseline prediction is exactly zero.
res_t1['delta_pct'] = np.where(
    res_t1['y_pred_baseline'] != 0,
    res_t1['delta_abs'] / res_t1['y_pred_baseline'],
    np.nan,
)

from pathlib import Path
OUT = Path("./reports/scenarios")
OUT.mkdir(parents=True, exist_ok=True)
out_t1 = OUT / f'scenario_{SCENARIO_NAME}_tplus1_{MODEL}.csv'
res_t1.to_csv(out_t1, index=False)
print('Saved (t+1):', out_t1)

## 6) Environment

In [None]:
import sys, platform, numpy, pandas, sklearn

# Print interpreter, OS and library versions so a run can be reproduced.
for label, version in (
    ("Python:", sys.version.split()[0]),
    ("Platform:", platform.platform()),
    ("NumPy:", numpy.__version__),
    ("Pandas:", pandas.__version__),
    ("scikit-learn:", sklearn.__version__),
):
    print(label, version)