In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import statsmodels.api as sm
import statsmodels.formula.api as smf
import scipy.stats as st

from pkg import detrend_group


In [2]:
yields = pd.read_csv("./data/yield_comparison.csv")
yields= yields.loc[:, ~yields.columns.str.contains("_lag|_lead|fao_idx|gridcells", regex=True)]

sm_tmax = pd.read_csv("./data/sm_tmax.csv")
yields_clim = yields.merge(sm_tmax, how="left", on = ["year", "cropname", "country"])

yields_clim = detrend_group(yields_clim, "sm_og", "sm_dt")
yields_clim = detrend_group(yields_clim, "tmax_og", "tmax_dt")

yields_clim = yields_clim[yields_clim.notna()]

In [3]:
def safe_regress(data, formula=None, yvar=None, xvars=None, ci=None):
    if formula:
        yvar = formula.split("~")[0].strip()
        xvars = formula.split("~")[1].strip().split("+")
        xvars = [x.strip() for x in xvars]


    if formula is None and yvar is None or xvars is None:
        raise ValueError("Provide either formula or yvar + xvars")

    n_obs = len(data.dropna(subset=[yvar] + xvars))
    n_preds = len(xvars) + 1  # +1 for intercept

    if n_obs <= n_preds:
        return pd.DataFrame([{"r2": None, "adj_r2": None}])

    try:
        if formula:
            model = smf.ols(formula=formula, data=data).fit()
        else:
            formula = f"{yvar} ~ {' + '.join(xvars)}"
            model = smf.ols(formula=formula, data=data).fit()

        out = {"r2": model.rsquared,
            "adj_r2": model.rsquared_adj,
            "ftest_pval": model.f_pvalue}
        
        out.update({f"coef_{k}": v for k, v in model.params.items()})
        out.update({f"pval_{k}": v for k, v in model.pvalues.items()})
        
        if ci is not None:
            ci = model.conf_int()
            for k in model.params.index:
                out[f"cilow_{k}"] = ci.loc[k, 0]
                out[f"cihigh_{k}"] = ci.loc[k, 1]

        return pd.DataFrame([out])

    except Exception as e:
        return pd.DataFrame([{"r2": None, "adj_r2": None, "error": str(e)}])


In [None]:
# test = yields_clim[["country", "cropname", "year", "sm_dt", "tmax_dt", "yield_log_dt", "csif_log_dt"]].sample(10000).dropna(how="any")
# counts = test["country"].value_counts()
# test = test[test["country"].isin(counts[counts > 10].index)].reset_index(drop=True)

# res = test.groupby(['cropname', 'country']).apply(
#     lambda group: safe_regress(group, formula= "yield_log_dt ~ sm_dt + tmax_dt")
#     ).reset_index(level=[0,1])
# res = res.iloc[:, ~res.columns.str.contains("Intercept", regex=True)]
# res.drop({"country"}, axis=1).groupby("cropname").quantile([0.25, 0.5, 0.75])



In [4]:
counts = yields_clim[["country", "cropname"]].value_counts()
counts_idx = counts[counts>10].index
yields_clim10 = yields_clim.set_index(['country', 'cropname'])
yields_clim10 =  yields_clim10.loc[counts_idx].reset_index()

In [5]:
res_surv = yields_clim10.groupby(['cropname', 'country']).apply(
    lambda group: safe_regress(group, formula= "yield_log_dt ~ sm_dt + tmax_dt")
    ).reset_index(level=[0,1])
res_surv = res_surv.iloc[:, ~res_surv.columns.str.contains("Intercept", regex=True)]
res_surv['model']="Survey"

In [6]:
res_sat = yields_clim10.groupby(['cropname', 'country']).apply(
    lambda group: safe_regress(group, formula= "csif_log_dt ~ sm_dt + tmax_dt")
    ).reset_index(level=[0,1])
res_sat = res_sat.iloc[:, ~res_sat.columns.str.contains("Intercept", regex=True)]
res_sat['model'] = "Satellite"

In [114]:
res_comb= pd.concat([res_sat, res_surv]).reset_index(drop=True)

def pval_is_sig(data, column, threshold = 0.05):
    data[f'{column}_pass'] = np.where(data[column] <0.05, True, False )
    return data

[pval_is_sig(res_comb, col) for col in ['pval_sm_dt', 'pval_tmax_dt', 'ftest_pval']]
res_comb= res_comb.dropna(how="any")
res_comb.to_csv("./data/yields_sm_tmax_reg.csv", index=False)

In [96]:
wanted = res_comb[res_comb['cropname'].isin(["Maize", "Sorghum", "Wheat", "Potatoes", "Cassava"]) ]

nice_tab1 = wanted.groupby(['cropname', 'model'])[['r2', 'adj_r2', 'coef_sm_dt', 'coef_tmax_dt']].quantile([0.25, 0.5, 0.75]).round(2).reset_index()
nice_tab2= wanted.groupby(['cropname', 'model'])[['pval_sm_dt_pass', 'pval_tmax_dt_pass', 'ftest_pval_pass']].agg(lambda x: str(round((sum(x)/ len(x)*100),1)))
nice_tab2 = nice_tab2 + "%"

In [97]:
def make_pretty_tab_multi(df, cols):
    out = {}
    for col in cols:
        vals = {'low': df.loc[df['level_2']==0.25, col].item(),
            'med': df.loc[df['level_2']==0.5,  col].item(),
            'high': df.loc[df['level_2']==0.75, col].item()}
        out[col] = f"{vals['med']} ({vals['low']}, {vals['high']})"
    return pd.Series(out)

cols = ["adj_r2", "r2", "coef_sm_dt", "coef_tmax_dt"]

nice_tab1 = nice_tab1.groupby(['cropname', 'model']).apply(lambda x: make_pretty_tab_multi(x, cols))


In [98]:
nice_tab1

Unnamed: 0_level_0,Unnamed: 1_level_0,adj_r2,r2,coef_sm_dt,coef_tmax_dt
cropname,model,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Cassava,Satellite,"0.13 (-0.02, 0.29)","0.21 (0.09, 0.36)","1.04 (0.35, 2.68)","-0.0 (-0.01, 0.01)"
Cassava,Survey,"0.01 (-0.04, 0.13)","0.11 (0.06, 0.21)","0.24 (-0.51, 2.11)","-0.0 (-0.05, 0.04)"
Maize,Satellite,"0.32 (0.11, 0.51)","0.38 (0.21, 0.55)","2.54 (1.36, 5.27)","-0.01 (-0.03, 0.01)"
Maize,Survey,"0.07 (-0.05, 0.21)","0.16 (0.05, 0.28)","1.65 (-0.83, 7.11)","-0.02 (-0.04, 0.03)"
Potatoes,Satellite,"0.18 (0.01, 0.44)","0.26 (0.1, 0.49)","3.13 (1.25, 6.0)","0.01 (-0.01, 0.02)"
Potatoes,Survey,"-0.01 (-0.07, 0.09)","0.09 (0.03, 0.18)","0.63 (-0.86, 2.93)","-0.0 (-0.02, 0.02)"
Sorghum,Satellite,"0.3 (0.13, 0.56)","0.36 (0.2, 0.6)","2.96 (1.65, 6.42)","-0.0 (-0.03, 0.02)"
Sorghum,Survey,"0.05 (-0.03, 0.14)","0.13 (0.06, 0.23)","1.66 (-1.57, 5.65)","-0.01 (-0.06, 0.04)"
Wheat,Satellite,"0.23 (0.08, 0.49)","0.3 (0.16, 0.53)","3.26 (1.25, 7.01)","0.01 (-0.01, 0.03)"
Wheat,Survey,"0.05 (-0.02, 0.19)","0.14 (0.07, 0.27)","0.92 (-1.7, 6.0)","-0.01 (-0.05, 0.01)"


In [99]:
nice_tab2

Unnamed: 0_level_0,Unnamed: 1_level_0,pval_sm_dt_pass,pval_tmax_dt_pass,ftest_pval_pass
cropname,model,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Cassava,Satellite,33.8%,13.8%,38.8%
Cassava,Survey,11.2%,15.0%,16.2%
Maize,Satellite,52.3%,17.4%,68.5%
Maize,Survey,18.8%,9.4%,26.8%
Potatoes,Satellite,44.8%,13.8%,49.7%
Potatoes,Survey,11.0%,5.5%,9.7%
Sorghum,Satellite,60.6%,17.3%,68.3%
Sorghum,Survey,17.3%,7.7%,19.2%
Wheat,Satellite,51.7%,17.5%,55.0%
Wheat,Survey,20.8%,10.8%,26.7%


In [110]:
nice_table =  pd.merge(nice_tab1, nice_tab2, left_on=(['cropname', "model"]), right_on=(['cropname', 'model']))
nice_table = nice_table.iloc[: , [1,0, 6, 2,4, 3,5]].loc[order,:]
nice_table.columns=['R2', "Adj R2", "F-test pass %", "SM coefficient", "SM p-val pass", "Tmax coefficient", "Tmax p-val pass"]
nice_table


Unnamed: 0_level_0,Unnamed: 1_level_0,R2,Adj R2,F-test pass %,SM coefficient,SM p-val pass,Tmax coefficient,Tmax p-val pass
cropname,model,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Maize,Satellite,"0.38 (0.21, 0.55)","0.32 (0.11, 0.51)",68.5%,"2.54 (1.36, 5.27)",52.3%,"-0.01 (-0.03, 0.01)",17.4%
Maize,Survey,"0.16 (0.05, 0.28)","0.07 (-0.05, 0.21)",26.8%,"1.65 (-0.83, 7.11)",18.8%,"-0.02 (-0.04, 0.03)",9.4%
Sorghum,Satellite,"0.36 (0.2, 0.6)","0.3 (0.13, 0.56)",68.3%,"2.96 (1.65, 6.42)",60.6%,"-0.0 (-0.03, 0.02)",17.3%
Sorghum,Survey,"0.13 (0.06, 0.23)","0.05 (-0.03, 0.14)",19.2%,"1.66 (-1.57, 5.65)",17.3%,"-0.01 (-0.06, 0.04)",7.7%
Wheat,Satellite,"0.3 (0.16, 0.53)","0.23 (0.08, 0.49)",55.0%,"3.26 (1.25, 7.01)",51.7%,"0.01 (-0.01, 0.03)",17.5%
Wheat,Survey,"0.14 (0.07, 0.27)","0.05 (-0.02, 0.19)",26.7%,"0.92 (-1.7, 6.0)",20.8%,"-0.01 (-0.05, 0.01)",10.8%
Cassava,Satellite,"0.21 (0.09, 0.36)","0.13 (-0.02, 0.29)",38.8%,"1.04 (0.35, 2.68)",33.8%,"-0.0 (-0.01, 0.01)",13.8%
Cassava,Survey,"0.11 (0.06, 0.21)","0.01 (-0.04, 0.13)",16.2%,"0.24 (-0.51, 2.11)",11.2%,"-0.0 (-0.05, 0.04)",15.0%
Potatoes,Satellite,"0.26 (0.1, 0.49)","0.18 (0.01, 0.44)",49.7%,"3.13 (1.25, 6.0)",44.8%,"0.01 (-0.01, 0.02)",13.8%
Potatoes,Survey,"0.09 (0.03, 0.18)","-0.01 (-0.07, 0.09)",9.7%,"0.63 (-0.86, 2.93)",11.0%,"-0.0 (-0.02, 0.02)",5.5%


In [111]:
nice_table.to_latex

<bound method NDFrame.to_latex of                                    R2               Adj R2 F-test pass %  \
cropname model                                                             
Maize    Satellite  0.38 (0.21, 0.55)    0.32 (0.11, 0.51)         68.5%   
         Survey     0.16 (0.05, 0.28)   0.07 (-0.05, 0.21)         26.8%   
Sorghum  Satellite    0.36 (0.2, 0.6)     0.3 (0.13, 0.56)         68.3%   
         Survey     0.13 (0.06, 0.23)   0.05 (-0.03, 0.14)         19.2%   
Wheat    Satellite   0.3 (0.16, 0.53)    0.23 (0.08, 0.49)         55.0%   
         Survey     0.14 (0.07, 0.27)   0.05 (-0.02, 0.19)         26.7%   
Cassava  Satellite  0.21 (0.09, 0.36)   0.13 (-0.02, 0.29)         38.8%   
         Survey     0.11 (0.06, 0.21)   0.01 (-0.04, 0.13)         16.2%   
Potatoes Satellite   0.26 (0.1, 0.49)    0.18 (0.01, 0.44)         49.7%   
         Survey     0.09 (0.03, 0.18)  -0.01 (-0.07, 0.09)          9.7%   

                        SM coefficient SM p-val pass 