# Results Data Cleanup

## Setup

In [1]:
%load_ext lab_black

In [2]:
import os
import re
import sys
import glob
import numpy as np
import pandas as pd

from pyhere import here
from scipy.stats import pearsonr
from sklearn.metrics import r2_score


sys.path.append(str(here("code", "3_task_modeling")))

import task_modeling_utils

In [3]:
def get_group_R2(df, observed="demean_log_yield", predicted="demean_oos_prediction"):
    return r2_score(df[observed], df[predicted])


def get_group_r2(df, observed="demean_log_yield", predicted="demean_oos_prediction"):
    return pearsonr(df[observed], df[predicted])[0]


def extract_float_from_string(s):
    if isinstance(s, str):
        match = re.search(r"\((.*),\)", s)
        return float(match.group(1)) if match else s
    else:
        return s

## Global Experiment - General Model

### Model Selection

In [4]:
model_selection_file_pattern = str(
    here("data", "results", "2_sensor_10-splits_2023-05-*_anom-False.csv")
)
model_selection_files = glob.glob(pathname=model_selection_file_pattern)
sorted(model_selection_files)

['/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_10-splits_2023-05-11_1_rcf_climate-False_anom-False.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_10-splits_2023-05-11_2_rcf_climate-False_anom-False.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_10-splits_2023-05-11_3_rcf_climate-False_anom-False.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_10-splits_2023-05-11_4_rcf_climate-False_anom-False.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_10-splits_2023-05-11_5_rcf_climate-False_anom-False.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_10-splits_2023-05-11_6_rcf_climate-False_anom-False.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_10-splits_2023-05-11_7_rcf_climate-False_anom-False.csv',
 '/home/cmolitor/cro

In [5]:
model_selection_results = task_modeling_utils.merge_files(model_selection_files)
model_selection_results["anomaly"] = False
model_selection_results["variables"] = "rcf"
model_selection_results

Unnamed: 0,split,random_state,country,year_range,satellite_1,bands_1,num_features_1,points_1,month_range_1,limit_months_1,...,test_r,test_r2,demean_cv_R2,demean_cv_r,demean_cv_r2,demean_test_R2,demean_test_r,demean_test_r2,anomaly,variables
0,0,670487,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,1-12,False,...,0.875606,0.766687,0.117839,0.519227,0.269597,0.293401,0.561290,0.315046,False,rcf
1,1,116739,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,1-12,False,...,0.904602,0.818305,0.132072,0.494040,0.244076,0.511307,0.722383,0.521837,False,rcf
2,2,26225,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,1-12,False,...,0.932698,0.869926,0.125536,0.452668,0.204908,0.421614,0.668757,0.447236,False,rcf
3,3,777572,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,1-12,False,...,0.913001,0.833571,0.146011,0.523675,0.274235,0.131849,0.564733,0.318923,False,rcf
4,4,288389,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,1-12,False,...,0.919801,0.846033,-0.002118,0.398073,0.158462,0.610579,0.784178,0.614934,False,rcf
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1135,5,256787,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,4-9,True,...,0.841214,0.707641,-0.131221,0.466321,0.217455,-0.016187,0.292059,0.085298,False,rcf
1136,6,234053,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,4-9,True,...,0.915579,0.838286,0.144304,0.480349,0.230735,0.399199,0.667107,0.445031,False,rcf
1137,7,146316,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,4-9,True,...,0.912913,0.833411,0.065448,0.436761,0.190760,0.366838,0.608689,0.370502,False,rcf
1138,8,772246,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,4-9,True,...,0.942299,0.887927,0.160958,0.485971,0.236168,0.513391,0.718095,0.515660,False,rcf


In [6]:
model_selection_results.val_R2.describe()

count    1140.000000
mean        0.779307
std         0.026716
min         0.667332
25%         0.764163
50%         0.783825
75%         0.796886
max         0.839085
Name: val_R2, dtype: float64

In [7]:
model_selection_group_cols = [
    "variables",
    "anomaly",
    "country",
    "year_range",
    "satellite_1",
    "bands_1",
    "num_features_1",
    "points_1",
    "month_range_1",
    "limit_months_1",
    "crop_mask_1",
    "weighted_avg_1",
    "satellite_2",
    "bands_2",
    "num_features_2",
    "points_2",
    "month_range_2",
    "limit_months_2",
    "crop_mask_2",
    "weighted_avg_2",
    "hot_encode",
]

model_selection_results_summary = model_selection_results.groupby(
    model_selection_group_cols, as_index=False
).agg(
    {
        "val_R2": "mean",
        "val_r2": "mean",
        "test_R2": "mean",
        "test_r2": "mean",
    }
)
model_selection_results_summary = model_selection_results_summary.sort_values(
    "val_R2", ascending=False
)

### Top Model Results 

In [8]:
top_model = model_selection_results_summary.iloc[0:1, :]
top_model_dict = top_model.to_dict(orient="records")[0]
top_model_dict

{'variables': 'rcf',
 'anomaly': False,
 'country': 'ZMB',
 'year_range': '2016-2021',
 'satellite_1': 'landsat-c2-l2',
 'bands_1': 'r-g-b-nir-swir16-swir22',
 'num_features_1': 1024,
 'points_1': 20,
 'month_range_1': '4-9',
 'limit_months_1': True,
 'crop_mask_1': True,
 'weighted_avg_1': False,
 'satellite_2': 'sentinel-2-l2a',
 'bands_2': '2-3-4-8',
 'num_features_2': 1000,
 'points_2': 15,
 'month_range_2': '1-12',
 'limit_months_2': False,
 'crop_mask_2': True,
 'weighted_avg_2': False,
 'hot_encode': True,
 'val_R2': 0.8002404610219435,
 'val_r2': 0.8017879820931079,
 'test_R2': 0.8258186120850981,
 'test_r2': 0.8325650503230417}

In [9]:
# # Define the keys to remove
# keys_to_remove = ["val_R2", "val_r2", "test_R2", "test_r2"]

# # Remove keys
# for key in keys_to_remove:
#     if key in top_model_dict:
#         del top_model_dict[key]

# top_model_splits = model_selection_results.copy()

# for key, value in top_model_dict.items():
#     if key in model_selection_results.columns:
#         top_model_splits = top_model_splits[top_model_splits[key] == value]
# top_model_splits

### Top Model Splits

In [10]:
top_model_file_pattern = str(
    here("data", "results", "2_sensor_top-mod_10-splits_*_*.csv")
)
top_model_files = glob.glob(pathname=top_model_file_pattern)
sorted(top_model_files)

['/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_top-mod_10-splits_2023-07-05_rcf_climate-False_anom-True.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_top-mod_10-splits_2023-07-05_rcf_climate-True_anom-False.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_top-mod_10-splits_2023-07-06_rcf_climate-False_anom-False.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_top-mod_10-splits_2023-07-06_rcf_climate-True_anom-True.csv']

In [11]:
top_model_results = task_modeling_utils.merge_files(top_model_files)
# top_model_results["variables"] = top_model_results["variables"].apply(
#     lambda x: "rcf_" + x if "rcf" not in x else x
# )

top_model_results.head(10)

Unnamed: 0,split,random_state,variables,anomaly,country,year_range,satellite_1,bands_1,num_features_1,points_1,...,train_r2,test_R2,test_r,test_r2,demean_cv_R2,demean_cv_r,demean_cv_r2,demean_test_R2,demean_test_r,demean_test_r2
0,0,670487,rcf,True,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,...,0.975638,0.454223,0.73888,0.545944,,,,,,
1,1,116739,rcf,True,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,...,0.987626,0.646691,0.814142,0.662827,,,,,,
2,2,26225,rcf,True,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,...,0.972669,0.699289,0.8364,0.699565,,,,,,
3,3,777572,rcf,True,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,...,0.997239,0.65321,0.820834,0.673769,,,,,,
4,4,288389,rcf,True,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,...,0.95634,0.529977,0.74721,0.558323,,,,,,
5,5,256787,rcf,True,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,...,0.986114,0.555285,0.748132,0.559702,,,,,,
6,6,234053,rcf,True,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,...,0.974025,0.514385,0.747043,0.558073,,,,,,
7,7,146316,rcf,True,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,...,0.988091,0.381815,0.642125,0.412324,,,,,,
8,8,772246,rcf,True,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,...,0.987443,0.68844,0.830507,0.689742,,,,,,
9,9,107473,rcf,True,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,...,0.986473,0.708049,0.843981,0.712304,,,,,,


In [12]:
top_model_results.loc[
    (top_model_results["variables"] == "rcf") & (top_model_results.anomaly == False)
].iloc[:, 25:30]

Unnamed: 0,test_n,best_reg_param,mean_of_val_R2,val_R2,val_r
30,84,"[[10.0, 10.0, 0.009999999999999998]]","[[0.6648244030124295, 0.7101271931484556, 0.78...",0.792837,0.890883
31,84,"[[10.0, 10.0, 0.0009999999999999998]]","[[0.703574744293908, 0.7455820353086114, 0.795...",0.807981,0.899629
32,84,"[[100.0, 10.0, 0.09999999999999999]]","[[0.7351356731429453, 0.7583377050863769, 0.79...",0.797153,0.893292
33,84,"[[100.0, 10.0, 0.009999999999999998]]","[[0.6880542403771379, 0.7352708239111512, 0.79...",0.796673,0.892972
34,84,"[[10.0, 10.0, 0.009999999999999998]]","[[0.7286959537558308, 0.7641489361690736, 0.82...",0.825819,0.908855
35,84,"[[100.0, 10.0, 0.009999999999999998]]","[[0.7331774863223968, 0.7559086275465114, 0.81...",0.818719,0.905189
36,84,"[[10.0, 10.0, 0.009999999999999998]]","[[0.6964254355989031, 0.7361124201081983, 0.79...",0.803079,0.896493
37,84,"[[100.0, 10.0, 0.09999999999999999]]","[[0.7355409474919525, 0.7710290473858741, 0.81...",0.813752,0.902505
38,84,"[[10.0, 10.0, 0.009999999999999998]]","[[0.724555150847279, 0.7475335680402925, 0.781...",0.783923,0.886083
39,84,"[[100.0, 10.0, 0.009999999999999998]]","[[0.6889778140977343, 0.7415109445765998, 0.78...",0.798971,0.894057


In [13]:
round(0.009999999999999998, 4), round(0.0009999999999, 4)

(0.01, 0.001)

In [14]:
top_model_group_cols = [
    "variables",
    "anomaly",
]

top_model_results_summary = top_model_results.groupby(
    top_model_group_cols, as_index=False
).agg(
    {
        "val_R2": "mean",
        "val_r2": "mean",
        "test_R2": "mean",
        "test_r2": "mean",
        "demean_cv_R2": "mean",
        "demean_cv_r2": "mean",
        "demean_test_R2": "mean",
        "demean_test_r2": "mean",
        # "split": "count",
    }
)
top_model_results_summary = top_model_results_summary.sort_values(
    "val_R2", ascending=False
)
top_model_results_summary

Unnamed: 0,variables,anomaly,val_R2,val_r2,test_R2,test_r2,demean_cv_R2,demean_cv_r2,demean_test_R2,demean_test_r2
4,rcf_ndvi_tmp,False,0.820895,0.821212,0.846339,0.852371,0.249587,0.303363,0.465455,0.489834
2,rcf_ndvi,False,0.813084,0.813852,0.837648,0.844072,0.211704,0.265298,0.449108,0.478338
0,rcf,False,0.803891,0.804645,0.831973,0.837512,0.172972,0.233819,0.422436,0.452334
5,rcf_ndvi_tmp,True,0.567,0.590313,0.581847,0.608774,,,,
3,rcf_ndvi,True,0.558335,0.581892,0.585674,0.610538,,,,
1,rcf,True,0.537469,0.567692,0.583136,0.607257,,,,


In [15]:
top_model_group_cols = [
    "variables",
    "anomaly",
]

top_model_results_summary = top_model_results.groupby(
    top_model_group_cols, as_index=False
).agg(
    {
        "val_R2": "sem",
        "val_r2": "sem",
        "test_R2": "sem",
        "test_r2": "sem",
        "split": "count",
    }
)
top_model_results_summary = top_model_results_summary.sort_values(
    "val_R2", ascending=False
)
top_model_results_summary

Unnamed: 0,variables,anomaly,val_R2,val_r2,test_R2,test_r2,split
5,rcf_ndvi_tmp,True,0.012316,0.009095,0.041445,0.035144,10
3,rcf_ndvi,True,0.01228,0.008519,0.043679,0.037243,10
1,rcf,True,0.011523,0.008428,0.035685,0.030314,10
4,rcf_ndvi_tmp,False,0.004607,0.004645,0.010258,0.010288,10
0,rcf,False,0.004025,0.003972,0.010714,0.010837,10
2,rcf_ndvi,False,0.00381,0.003715,0.010226,0.01033,10


### Out of Sample Predictions

In [48]:
oos_prediction_file_pattern = str(
    here(
        "data", "results", "2_sensor_top-mod_oos_predictions_10-splits_*_anom-False.csv"
    )
)
oos_prediction_files = glob.glob(pathname=oos_prediction_file_pattern)
sorted(oos_prediction_files)

['/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_top-mod_oos_predictions_10-splits_2023-07-05_rcf_climate-True_anom-False.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_top-mod_oos_predictions_10-splits_2023-07-06_rcf_climate-False_anom-False.csv']

In [49]:
oos_prediction = task_modeling_utils.merge_files(oos_prediction_files)
# oos_prediction

In [50]:
oos_prediction_demean_groups = [
    # "data_fold",
    "district",
    # "val_fold",
    "split",
    "random_state",
    "variables",
]
oos_prediction = task_modeling_utils.demean_by_group(
    df=oos_prediction,
    observed="log_yield",
    predicted="oos_prediction",
    group=oos_prediction_demean_groups,
)
oos_prediction

Unnamed: 0,data_fold,district,year,yield_mt,log_yield,demean_log_yield,oos_prediction,val_fold,split,random_state,variables,anomaly,hot_encode,demean_oos_prediction
0,train,Kaoma,2019,0.523864,0.182946,-0.192782,0.279597,1,0,670487,rcf_ndvi,False,True,-0.083021
1,train,Kaoma,2017,1.540420,0.404906,0.029177,0.368818,1,0,670487,rcf_ndvi,False,True,0.006200
2,train,Sinazongwe,2020,0.589120,0.201157,-0.021142,0.205443,1,0,670487,rcf_ndvi,False,True,-0.040686
3,train,Kasempa,2020,3.349064,0.638396,0.077697,0.556458,1,0,670487,rcf_ndvi,False,True,0.005091
4,train,Shangombo,2020,0.741213,0.240852,0.071072,0.167303,1,0,670487,rcf_ndvi,False,True,-0.025521
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
12595,test,Chadiza,2018,1.299279,0.361592,-0.142270,0.514470,6,9,107473,rcf,False,True,-0.009694
12596,test,Mkushi,2020,2.819023,0.581952,-0.022751,0.605807,6,9,107473,rcf,False,True,-0.004002
12597,test,Kitwe,2020,2.771743,0.576542,0.027490,0.562126,6,9,107473,rcf,False,True,0.020002
12598,test,Nchelenge,2019,1.604964,0.415802,-0.090286,0.512967,6,9,107473,rcf,False,True,-0.024283


In [51]:
oos_prediction_r2_groups = [
    "data_fold",
    "split",
    "random_state",
    "variables",
]

oos_prediction_R2 = oos_prediction.groupby(
    oos_prediction_r2_groups, as_index=False
).apply(get_group_R2)
oos_prediction_R2.rename(columns={None: "demean_R2"}, inplace=True)

oos_prediction_r2 = oos_prediction.groupby(
    oos_prediction_r2_groups, as_index=False
).apply(get_group_r2)
oos_prediction_r2.rename(columns={None: "demean_r"}, inplace=True)

oos_prediction_r2["demean_r2"] = oos_prediction_r2["demean_r"] ** 2

oos_prediction_summary = (
    oos_prediction_R2.set_index(oos_prediction_r2_groups)
    .join(oos_prediction_r2.set_index(oos_prediction_r2_groups))
    .reset_index()
)

oos_prediction_summary.sort_values(["variables", "data_fold"])

Unnamed: 0,data_fold,split,random_state,variables,demean_R2,demean_r,demean_r2
0,test,0,670487,rcf,0.493625,0.714968,0.51118
3,test,1,116739,rcf,0.543294,0.743477,0.552759
6,test,2,26225,rcf,0.512101,0.716613,0.513534
9,test,3,777572,rcf,0.356677,0.612213,0.374805
12,test,4,288389,rcf,0.311237,0.634666,0.402801
15,test,5,256787,rcf,0.429404,0.66093,0.436828
18,test,6,234053,rcf,0.311245,0.610583,0.372811
21,test,7,146316,rcf,0.207081,0.522681,0.273196
24,test,8,772246,rcf,0.576371,0.762522,0.581439
27,test,9,107473,rcf,0.48333,0.709921,0.503988


In [52]:
oos_prediction_summary_r2_groups = [
    "data_fold",
    "variables",
]
oos_prediction_summary = (
    oos_prediction_summary.groupby(oos_prediction_summary_r2_groups, as_index=False)
    .agg(
        {
            "demean_R2": "mean",
            "demean_r2": "mean",
            "demean_R2": "sem",
            "demean_r2": "sem",
        }
    )
    .sort_values("variables")
)
oos_prediction_summary

Unnamed: 0,data_fold,variables,demean_R2,demean_r2
0,test,rcf,0.038093,0.030505
3,train,rcf,0.015641,0.011399
1,test,rcf_ndvi,0.041248,0.032474
4,train,rcf_ndvi,0.017272,0.012821
2,test,rcf_ndvi_tmp,0.038115,0.031493
5,train,rcf_ndvi_tmp,0.017062,0.013212


In [53]:
# Pivot the DataFrame
oos_prediction_summary = oos_prediction_summary.pivot_table(
    index="variables", columns="data_fold"
)

# Flatten the MultiIndex columns
oos_prediction_summary.columns = [
    "_".join(col) for col in oos_prediction_summary.columns
]

# Reset the index
oos_prediction_summary = oos_prediction_summary.reset_index()
oos_prediction_summary = oos_prediction_summary.set_index(
    [
        "variables",
        "demean_R2_train",
        "demean_r2_train",
        "demean_R2_test",
        "demean_r2_test",
    ]
)
oos_prediction_summary = oos_prediction_summary.reset_index()
oos_prediction_summary.sort_values("demean_R2_train", ascending=False)

Unnamed: 0,variables,demean_R2_train,demean_r2_train,demean_R2_test,demean_r2_test
1,rcf_ndvi,0.017272,0.012821,0.041248,0.032474
2,rcf_ndvi_tmp,0.017062,0.013212,0.038115,0.031493
0,rcf,0.015641,0.011399,0.038093,0.030505


### Anomaly Model

In [63]:
anomaly_file_pattern = str(
    here(
        "data", "results", "2_sensor_top-mod_oos_predictions_10-splits_*_anom-True.csv"
    )
)
anomaly_files = glob.glob(pathname=anomaly_file_pattern)
sorted(anomaly_files)

['/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_top-mod_oos_predictions_10-splits_2023-07-05_rcf_climate-False_anom-True.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_top-mod_oos_predictions_10-splits_2023-07-06_rcf_climate-True_anom-True.csv']

In [64]:
anomaly_predictions = task_modeling_utils.merge_files(anomaly_files)
anomaly_predictions

Unnamed: 0,data_fold,district,year,yield_mt,log_yield,demean_log_yield,oos_prediction,val_fold,split,random_state,variables,anomaly,hot_encode,demean_oos_prediction
0,train,Kaoma,2019,-0.898481,-0.192782,-0.192782,-0.056889,1,0,670487,rcf,True,True,-0.067769
1,train,Kaoma,2017,0.118076,0.029177,0.029177,0.017457,1,0,670487,rcf,True,True,0.006577
2,train,Sinazongwe,2020,-0.125243,-0.021142,-0.021142,-0.031136,1,0,670487,rcf,True,True,-0.022356
3,train,Kasempa,2020,0.694315,0.077697,0.077697,0.072680,1,0,670487,rcf,True,True,0.070367
4,train,Shangombo,2020,0.239461,0.071072,0.071072,0.046946,1,0,670487,rcf,True,True,0.052757
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
12595,test,Chadiza,2018,-0.939644,-0.142270,-0.142270,-0.065308,6,9,107473,rcf_ndvi_tmp,True,True,-0.070035
12596,test,Mkushi,2020,-0.230178,-0.022751,-0.022751,0.000990,6,9,107473,rcf_ndvi_tmp,True,True,0.008657
12597,test,Kitwe,2020,0.216271,0.027490,0.027490,0.067018,6,9,107473,rcf_ndvi_tmp,True,True,0.056572
12598,test,Nchelenge,2019,-0.618532,-0.090286,-0.090286,0.009173,6,9,107473,rcf_ndvi_tmp,True,True,0.018344


In [59]:
r2_groups = [
    "data_fold",
    "split",
    "random_state",
    "variables",
]

grouped_R2 = anomaly_predictions.groupby(r2_groups, as_index=False).apply(
    get_group_R2, "log_yield", "oos_prediction"
)
grouped_R2.rename(columns={None: "demean_R2"}, inplace=True)

grouped_r2 = anomaly_predictions.groupby(r2_groups, as_index=False).apply(
    get_group_r2, "log_yield", "oos_prediction"
)
grouped_r2.rename(columns={None: "demean_r"}, inplace=True)

grouped_r2["demean_r2"] = oos_prediction_r2["demean_r"] ** 2

oos_prediction_summary = (
    grouped_R2.set_index(r2_groups).join(grouped_r2.set_index(r2_groups)).reset_index()
)

# oos_prediction_summary.sort_values(["variables", "data_fold"])

Unnamed: 0,data_fold,split,random_state,variables,demean_R2,demean_r,demean_r2
0,test,0,670487,rcf,0.454223,0.73888,0.51118
3,test,1,116739,rcf,0.646691,0.814142,0.552759
6,test,2,26225,rcf,0.699289,0.8364,0.513534
9,test,3,777572,rcf,0.65321,0.820834,0.374805
12,test,4,288389,rcf,0.529977,0.74721,0.402801
15,test,5,256787,rcf,0.555285,0.748132,0.436828
18,test,6,234053,rcf,0.514385,0.747043,0.372811
21,test,7,146316,rcf,0.381815,0.642125,0.273196
24,test,8,772246,rcf,0.68844,0.830507,0.581439
27,test,9,107473,rcf,0.708049,0.843981,0.503988


In [62]:
demean_summary_r2_groups = [
    "data_fold",
    "variables",
]
oos_prediction_summary.groupby(demean_summary_r2_groups, as_index=False).agg(
    {
        "demean_R2": "mean",
        "demean_r2": "mean",
    }
).sort_values("variables")

Unnamed: 0,data_fold,variables,demean_R2,demean_r2
0,test,rcf,0.583136,0.452334
3,train,rcf,0.537469,0.233819
1,test,rcf_ndvi,0.585674,0.478338
4,train,rcf_ndvi,0.558335,0.265298
2,test,rcf_ndvi_tmp,0.581847,0.489834
5,train,rcf_ndvi_tmp,0.567,0.303363


## Anomaly Model

### Model Selection

In [25]:
model_selection_file_pattern = str(
    here("data", "results", "2_sensor_10-splits_*_anom-True.csv")
)
model_selection_files = glob.glob(pathname=model_selection_file_pattern)
sorted(model_selection_files)

['/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_10-splits_2023-06-28_1_rcf_climate-False_anom-True.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_10-splits_2023-06-28_2_rcf_climate-False_anom-True.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_10-splits_2023-06-28_3_rcf_climate-False_anom-True.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_10-splits_2023-06-28_4_rcf_climate-False_anom-True.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_10-splits_2023-06-28_5_rcf_climate-False_anom-True.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_10-splits_2023-06-28_6_rcf_climate-False_anom-True.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_10-splits_2023-06-29_10_rcf_climate-False_anom-True.csv',
 '/home/cmolitor/crop-mode

In [26]:
model_selection_results = task_modeling_utils.merge_files(model_selection_files)
model_selection_results = model_selection_results.dropna(axis=1, how="all")
model_selection_results

Unnamed: 0,split,random_state,variables,anomaly,country,year_range,satellite_1,bands_1,num_features_1,points_1,...,mean_of_val_R2,val_R2,val_r,val_r2,train_R2,train_r,train_r2,test_R2,test_r,test_r2
0,0,670487,rcf,True,ZMB,2014-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,...,"[[0.20962120884393354, 0.32375418573679704]]",0.350928,0.615195,0.378465,0.916672,0.964241,0.929761,0.314842,0.582788,0.339641
1,1,116739,rcf,True,ZMB,2014-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,...,"[[0.14964826236827716, 0.28861209159165735]]",0.341339,0.610577,0.372804,0.915503,0.963208,0.927770,0.281188,0.609699,0.371733
2,2,26225,rcf,True,ZMB,2014-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,...,"[[0.1642373513973112, 0.19659132452419187]]",0.228736,0.533863,0.285010,0.937866,0.974249,0.949161,0.556001,0.771852,0.595755
3,3,777572,rcf,True,ZMB,2014-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,...,"[[0.28997967568245275, 0.3625419520053873]]",0.380669,0.621001,0.385642,0.804948,0.914029,0.835449,0.309338,0.560592,0.314263
4,4,288389,rcf,True,ZMB,2014-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,...,"[[0.3376818293112775, 0.4178160858272063]]",0.454525,0.680078,0.462506,0.943372,0.976125,0.952821,0.177294,0.589068,0.347001
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
895,5,256787,rcf,True,ZMB,2016-2022,sentinel-2-l2a,2-3-4,1000,4,...,"[[0.2663388308746587, 0.4101177035726672]]",0.405602,0.642290,0.412536,0.740603,0.871393,0.759326,0.431489,0.661219,0.437210
896,6,234053,rcf,True,ZMB,2016-2022,sentinel-2-l2a,2-3-4,1000,4,...,"[[0.4313176357145322, 0.4613189019065264]]",0.473084,0.687932,0.473250,0.792304,0.901272,0.812291,0.453518,0.679583,0.461833
897,7,146316,rcf,True,ZMB,2016-2022,sentinel-2-l2a,2-3-4,1000,4,...,"[[0.3888074388033191, 0.4495030964331378]]",0.457212,0.676577,0.457756,0.763701,0.886123,0.785213,0.401973,0.655530,0.429719
898,8,772246,rcf,True,ZMB,2016-2022,sentinel-2-l2a,2-3-4,1000,4,...,"[[0.3775894990531866, 0.4570836198289194]]",0.468138,0.685936,0.470508,0.744719,0.874096,0.764044,0.370545,0.612061,0.374618


In [27]:
model_selection_group_cols = [
    "variables",
    "anomaly",
    "country",
    "year_range",
    "satellite_1",
    "bands_1",
    "num_features_1",
    "points_1",
    "month_range_1",
    "limit_months_1",
    "crop_mask_1",
    "weighted_avg_1",
    "satellite_2",
    "bands_2",
    "num_features_2",
    "points_2",
    "month_range_2",
    "limit_months_2",
    "crop_mask_2",
    "weighted_avg_2",
]

model_selection_results_summary = model_selection_results.groupby(
    model_selection_group_cols, as_index=False
).agg(
    {
        "val_R2": "mean",
        "val_r2": "mean",
        "test_R2": "mean",
        "test_r2": "mean",
    }
)
model_selection_results_summary = model_selection_results_summary.sort_values(
    "val_R2", ascending=False
)
model_selection_results_summary

Unnamed: 0,variables,anomaly,country,year_range,satellite_1,bands_1,num_features_1,points_1,month_range_1,limit_months_1,...,num_features_2,points_2,month_range_2,limit_months_2,crop_mask_2,weighted_avg_2,val_R2,val_r2,test_R2,test_r2
14,rcf,True,ZMB,2016-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,1-12,False,...,1000,4,1-12,False,True,False,0.625838,0.634117,0.694785,0.720626
46,rcf,True,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,1-12,False,...,1000,4,1-12,False,True,False,0.587637,0.601613,0.599024,0.632362
30,rcf,True,ZMB,2016-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,20,1-12,False,...,1000,4,1-12,False,True,False,0.577214,0.591571,0.573207,0.657712
54,rcf,True,ZMB,2016-2021,landsat-c2-l2,r-g-b-nir-swir16-swir22,1024,20,4-9,True,...,1000,4,1-12,False,True,False,0.575845,0.579485,0.576837,0.593267
38,rcf,True,ZMB,2016-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,20,4-9,True,...,1000,4,1-12,False,True,False,0.573925,0.580842,0.536521,0.584403
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
69,rcf,True,ZMB,2016-2022,sentinel-2-l2a,2-3-4,1000,15,4-9,True,...,1000,4,4-9,True,True,False,0.322453,0.351643,0.204375,0.333220
10,rcf,True,ZMB,2014-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,4-9,True,...,1000,20,1-12,False,True,False,0.307883,0.354537,0.349738,0.425035
5,rcf,True,ZMB,2014-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,1-12,False,...,1000,20,1-12,False,True,False,0.294726,0.353522,0.140393,0.303955
13,rcf,True,ZMB,2014-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,20,4-9,True,...,1000,20,1-12,False,True,False,0.280281,0.297467,0.080351,0.224114


In [28]:
top_model = model_selection_results_summary.iloc[0:1, :]
top_model_dict = top_model.to_dict(orient="records")[0]
top_model_dict

{'variables': 'rcf',
 'anomaly': True,
 'country': 'ZMB',
 'year_range': '2016-2021',
 'satellite_1': 'landsat-8-c2-l2',
 'bands_1': '1-2-3-4-5-6-7',
 'num_features_1': 1000,
 'points_1': 15,
 'month_range_1': '1-12',
 'limit_months_1': False,
 'crop_mask_1': True,
 'weighted_avg_1': False,
 'satellite_2': 'sentinel-2-l2a',
 'bands_2': '2-3-4',
 'num_features_2': 1000,
 'points_2': 4,
 'month_range_2': '1-12',
 'limit_months_2': False,
 'crop_mask_2': True,
 'weighted_avg_2': False,
 'val_R2': 0.6258380146218993,
 'val_r2': 0.6341167754625727,
 'test_R2': 0.694785309391216,
 'test_r2': 0.7206264487364556}

In [29]:
f1 = "landsat-8-c2-l2_bands-1-2-3-4-5-6-7_ZMB_15k-points_1000-features_yr-2014-2021_mn-1-12_lm-False_cm-True_wa-False_summary.feather"
f2 = "sentinel-2-l2a_bands-2-3-4_ZMB_4k-points_1000-features_yr-2016-2022_mn-1-12_lm-False_cm-True_wa-False_summary.feather"

### Top Over Time Model Splits

In [54]:
top_ot_model_file_pattern = str(
    here("data", "results", "2_sensor_top-ot-mod_10-splits_*_*.csv")
)
top_ot_model_files = glob.glob(pathname=top_ot_model_file_pattern)
sorted(top_ot_model_files)

['/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_top-ot-mod_10-splits_2023-06-30_rcf_climate-False_anom-True.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_top-ot-mod_10-splits_2023-06-30_rcf_climate-True_anom-True.csv']

In [62]:
top_ot_model_results = task_modeling_utils.merge_files(top_ot_model_files)

top_ot_model_results["variables"] = top_ot_model_results["variables"].apply(
    lambda x: "rcf_" + x if "rcf" not in x else x
)
top_ot_model_results.head(10)

Unnamed: 0,split,random_state,variables,anomaly,country,year_range,satellite_1,bands_1,num_features_1,points_1,...,train_r2,test_R2,test_r,test_r2,demean_cv_R2,demean_cv_r,demean_cv_r2,demean_test_R2,demean_test_r,demean_test_r2
0,0,670487,rcf_ndvi,True,ZMB,2016-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,...,0.996305,0.389761,0.737305,0.543618,,,,,,
1,1,116739,rcf_ndvi,True,ZMB,2016-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,...,0.9958,0.696153,0.835465,0.698001,,,,,,
2,2,26225,rcf_ndvi,True,ZMB,2016-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,...,0.991247,0.59051,0.801708,0.642736,,,,,,
3,3,777572,rcf_ndvi,True,ZMB,2016-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,...,0.993892,0.819431,0.905369,0.819694,,,,,,
4,4,288389,rcf_ndvi,True,ZMB,2016-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,...,0.99505,0.755194,0.870068,0.757018,,,,,,
5,5,256787,rcf_ndvi,True,ZMB,2016-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,...,0.995116,0.756721,0.873003,0.762134,,,,,,
6,6,234053,rcf_ndvi,True,ZMB,2016-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,...,0.995074,0.748927,0.885858,0.784745,,,,,,
7,7,146316,rcf_ndvi,True,ZMB,2016-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,...,0.995734,0.688746,0.835809,0.698576,,,,,,
8,8,772246,rcf_ndvi,True,ZMB,2016-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,...,0.999204,0.701865,0.840842,0.707015,,,,,,
9,9,107473,rcf_ndvi,True,ZMB,2016-2021,landsat-8-c2-l2,1-2-3-4-5-6-7,1000,15,...,0.998995,0.841393,0.918779,0.844155,,,,,,


In [63]:
top_model_group_cols = [
    "variables",
    "anomaly",
]

top_ot_model_results_summary = top_ot_model_results.groupby(
    top_model_group_cols, as_index=False
).agg(
    {
        "val_R2": "mean",
        "val_r2": "mean",
        "test_R2": "mean",
        "test_r2": "mean",
        "demean_cv_R2": "mean",
        "demean_cv_r2": "mean",
        "demean_test_R2": "mean",
        "demean_test_r2": "mean",
        # "split": "count",
    }
)
top_ot_model_results_summary = top_ot_model_results_summary.sort_values(
    "val_R2", ascending=False
)
top_ot_model_results_summary

Unnamed: 0,variables,anomaly,val_R2,val_r2,test_R2,test_r2,demean_cv_R2,demean_cv_r2,demean_test_R2,demean_test_r2
2,rcf_ndvi_tmp,True,0.649653,0.657136,0.702917,0.730501,,,,
1,rcf_ndvi,True,0.636982,0.645224,0.69887,0.725769,,,,
0,rcf,True,0.625838,0.634117,0.694785,0.720626,,,,


### Out of Sample Predictions

In [52]:
oos_ot_prediction_file_pattern = str(
    here(
        "data",
        "results",
        "2_sensor_top-ot-mod_oos_predictions_10-splits_*_anom-True.csv",
    )
)
oos_ot_prediction_files = glob.glob(pathname=oos_ot_prediction_file_pattern)
sorted(oos_ot_prediction_files)

['/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_top-ot-mod_oos_predictions_10-splits_2023-06-30_rcf_climate-False_anom-True.csv',
 '/home/cmolitor/crop-modeling/code/4_explore_results/../../data/results/2_sensor_top-ot-mod_oos_predictions_10-splits_2023-06-30_rcf_climate-True_anom-True.csv']

In [53]:
oos_prediction = task_modeling_utils.merge_files(oos_prediction_files)
oos_prediction

Unnamed: 0,data_fold,district,year,yield_mt,log_yield,demean_log_yield,oos_prediction,val_fold,split,random_state,variables,anomaly,hot_encode,demean_oos_prediction
0,train,Kaoma,2019,0.523864,0.182946,-0.192782,0.279597,1,0,670487,rcf_ndvi,False,True,-0.083021
1,train,Kaoma,2017,1.540420,0.404906,0.029177,0.368818,1,0,670487,rcf_ndvi,False,True,0.006200
2,train,Sinazongwe,2020,0.589120,0.201157,-0.021142,0.205443,1,0,670487,rcf_ndvi,False,True,-0.040686
3,train,Kasempa,2020,3.349064,0.638396,0.077697,0.556458,1,0,670487,rcf_ndvi,False,True,0.005091
4,train,Shangombo,2020,0.741213,0.240852,0.071072,0.167303,1,0,670487,rcf_ndvi,False,True,-0.025521
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
12595,test,Chadiza,2018,1.299279,0.361592,-0.142270,0.514470,6,9,107473,rcf,False,True,-0.009694
12596,test,Mkushi,2020,2.819023,0.581952,-0.022751,0.605807,6,9,107473,rcf,False,True,-0.004002
12597,test,Kitwe,2020,2.771743,0.576542,0.027490,0.562126,6,9,107473,rcf,False,True,0.020002
12598,test,Nchelenge,2019,1.604964,0.415802,-0.090286,0.512967,6,9,107473,rcf,False,True,-0.024283


## Benchmark - NDVI and Climate

### Maize Yield Levels

In [30]:
climate = pd.read_csv(here("data", "results", "climate_model_10-splits_2023-07-05.csv"))
climate

Unnamed: 0,split,random_state,variables,year_start,hot_encode,anomaly,total_n,train_n,test_n,best_reg_param,...,train_r2,test_R2,test_r,test_r2,demean_cv_R2,demean_cv_r,demean_cv_r2,demean_test_R2,demean_test_r,demean_test_r2
0,2,26225,ndvi,2016,False,False,432,345,87,[1.0],...,0.298571,0.280308,0.542938,0.294781,0.157889,0.416331,0.173331,0.111030,0.348220,0.121257
1,1,116739,pre,2016,False,True,432,345,87,[0.1],...,0.170344,0.102114,0.348568,0.121499,,,,,,
2,3,777572,pre,2016,False,False,432,345,87,[0.1],...,0.398278,0.324432,0.575593,0.331307,-0.281685,-0.151317,0.022897,-0.380065,-0.219479,0.048171
3,0,670487,ndvi,2016,False,True,432,345,87,[0.01],...,0.455168,0.570649,0.756902,0.572900,,,,,,
4,7,146316,pre,2016,False,False,432,345,87,[0.01],...,0.394274,0.368904,0.609873,0.371945,-0.354622,-0.085783,0.007359,-0.536602,-0.147994,0.021902
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
205,9,107473,pre_tmp_ndvi,2016,False,True,432,345,87,"[1000000000.0, 0.1, 0.1]",...,0.553652,0.632169,0.804545,0.647292,,,,,,
206,9,107473,pre_tmp_ndvi,2016,True,False,432,345,87,"[0.01, 0.01, 0.1, 0.01]",...,0.902052,0.848890,0.921905,0.849908,0.059943,0.409038,0.167312,0.565475,0.763606,0.583093
207,8,772246,pre_tmp_ndvi,2016,True,False,432,345,87,"[0.1, 0.01, 0.01, 0.01]",...,0.914142,0.839197,0.919882,0.846182,0.213436,0.510816,0.260933,0.468034,0.686457,0.471223
208,5,256787,pre_tmp_ndvi,2016,True,False,432,345,87,"[0.01, 0.01, 0.001, 0.01]",...,0.914419,0.866718,0.931686,0.868038,0.255763,0.550629,0.303192,0.221413,0.530530,0.281462


In [31]:
climate_model_selection_group_cols = [
    "variables",
    "year_start",
    "hot_encode",
    "anomaly",
]

climate_results_summary = climate.groupby(
    climate_model_selection_group_cols, as_index=False
).agg(
    {
        "val_R2": "mean",
        "val_r2": "mean",
        "test_R2": "mean",
        "test_r2": "mean",
    }
)
climate_results_summary = climate_results_summary.sort_values("val_R2", ascending=False)

climate_results_summary

Unnamed: 0,variables,year_start,hot_encode,anomaly,val_R2,val_r2,test_R2,test_r2
20,tmp_ndvi,2016,True,False,0.814797,0.815993,0.833366,0.838976
14,pre_tmp_ndvi,2016,True,False,0.814602,0.815777,0.831254,0.837263
8,pre_ndvi,2016,True,False,0.791105,0.793264,0.814343,0.823102
2,ndvi,2016,True,False,0.78458,0.786098,0.803736,0.813898
11,pre_tmp,2016,True,False,0.783405,0.783853,0.803051,0.808492
17,tmp,2016,True,False,0.78128,0.782786,0.807925,0.814704
12,pre_tmp_ndvi,2016,False,False,0.708487,0.708831,0.708322,0.714968
9,pre_tmp,2016,False,False,0.686328,0.686473,0.683861,0.689764
18,tmp_ndvi,2016,False,False,0.666576,0.667014,0.673387,0.679776
5,pre,2016,True,False,0.664567,0.671982,0.700099,0.717393


### Demeaned Predictions

In [32]:
oos_climate_preds = pd.read_csv(
    here("data", "results", "climate_model_oos_predictions_10-splits_2023-07-05.csv")
)
oos_climate_preds

Unnamed: 0,data_fold,year,district,yield_mt,log_yield,demean_log_yield,oos_prediction,val_fold,split,random_state,variables,anomaly,hot_encode,year_start,demean_oos_prediction
0,train,2016,Luangwa,0.060554,0.025533,-0.078689,0.447695,1,2,26225,ndvi,False,False,2016,-0.033873
1,train,2016,Chilubi,2.241493,0.510745,-0.046788,0.439148,1,2,26225,ndvi,False,False,2016,0.011935
2,train,2017,Sinazongwe,1.041549,0.309960,0.087661,0.422305,1,2,26225,ndvi,False,False,2016,0.042612
3,train,2021,Kalulushi,3.288288,0.632284,0.045053,0.507982,1,2,26225,ndvi,False,False,2016,-0.004902
4,train,2016,Nakonde,3.007245,0.602846,0.048458,0.462212,1,2,26225,ndvi,False,False,2016,-0.000142
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90715,test,2018,Isoka,3.400556,0.643508,0.002486,0.614625,6,6,234053,pre_tmp_ndvi,False,True,2016,-0.025282
90716,test,2017,Serenje,2.434624,0.535879,-0.040338,0.628334,6,6,234053,pre_tmp_ndvi,False,True,2016,0.035182
90717,test,2016,Mongu,0.923351,0.284059,0.091288,0.194624,6,6,234053,pre_tmp_ndvi,False,True,2016,0.026314
90718,test,2017,Gwembe,1.904668,0.463097,0.140944,0.436605,6,6,234053,pre_tmp_ndvi,False,True,2016,0.104396


In [33]:
oos_prediction_demean_groups = [
    # "data_fold",
    "district",
    # "val_fold",
    "split",
    "random_state",
    "variables",
    "anomaly",
    "hot_encode",
    "year_start",
]
oos_prediction = task_modeling_utils.demean_by_group(
    df=oos_climate_preds,
    observed="log_yield",
    predicted="oos_prediction",
    group=oos_prediction_demean_groups,
)
oos_prediction

Unnamed: 0,data_fold,year,district,yield_mt,log_yield,demean_log_yield,oos_prediction,val_fold,split,random_state,variables,anomaly,hot_encode,year_start,demean_oos_prediction
0,train,2016,Luangwa,0.060554,0.025533,-0.078689,0.447695,1,2,26225,ndvi,False,False,2016,-0.033873
1,train,2016,Chilubi,2.241493,0.510745,-0.046788,0.439148,1,2,26225,ndvi,False,False,2016,0.011935
2,train,2017,Sinazongwe,1.041549,0.309960,0.087661,0.422305,1,2,26225,ndvi,False,False,2016,0.042612
3,train,2021,Kalulushi,3.288288,0.632284,0.045053,0.507982,1,2,26225,ndvi,False,False,2016,-0.004902
4,train,2016,Nakonde,3.007245,0.602846,0.048458,0.462212,1,2,26225,ndvi,False,False,2016,-0.000142
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90715,test,2018,Isoka,3.400556,0.643508,0.002486,0.614625,6,6,234053,pre_tmp_ndvi,False,True,2016,-0.025282
90716,test,2017,Serenje,2.434624,0.535879,-0.040338,0.628334,6,6,234053,pre_tmp_ndvi,False,True,2016,0.035182
90717,test,2016,Mongu,0.923351,0.284059,0.091288,0.194624,6,6,234053,pre_tmp_ndvi,False,True,2016,0.026314
90718,test,2017,Gwembe,1.904668,0.463097,0.140944,0.436605,6,6,234053,pre_tmp_ndvi,False,True,2016,0.104396


In [34]:
oos_prediction_r2_groups = [
    "data_fold",
    "split",
    "random_state",
    "variables",
    "anomaly",
    "hot_encode",
]

oos_prediction_R2 = oos_prediction.groupby(
    oos_prediction_r2_groups, as_index=False
).apply(get_group_R2)
oos_prediction_R2.rename(columns={None: "demean_R2"}, inplace=True)

oos_prediction_r2 = oos_prediction.groupby(
    oos_prediction_r2_groups, as_index=False
).apply(get_group_r2)
oos_prediction_r2.rename(columns={None: "demean_r"}, inplace=True)

oos_prediction_r2["demean_r2"] = oos_prediction_r2["demean_r"] ** 2

oos_prediction_summary = (
    oos_prediction_R2.set_index(oos_prediction_r2_groups)
    .join(oos_prediction_r2.set_index(oos_prediction_r2_groups))
    .reset_index()
)

oos_prediction_summary.sort_values(["variables", "data_fold"])

Unnamed: 0,data_fold,split,random_state,variables,anomaly,hot_encode,demean_R2,demean_r,demean_r2
0,test,0,670487,ndvi,False,False,0.240765,0.581621,0.338284
1,test,0,670487,ndvi,False,True,0.519794,0.727101,0.528676
2,test,0,670487,ndvi,True,False,0.562309,0.751157,0.564237
21,test,1,116739,ndvi,False,False,0.185826,0.461670,0.213139
22,test,1,116739,ndvi,False,True,0.383004,0.619172,0.383374
...,...,...,...,...,...,...,...,...,...
397,train,8,772246,tmp_ndvi,False,True,0.235446,0.524526,0.275127
398,train,8,772246,tmp_ndvi,True,False,0.533611,0.730545,0.533696
417,train,9,107473,tmp_ndvi,False,False,0.262945,0.520624,0.271049
418,train,9,107473,tmp_ndvi,False,True,0.103748,0.421087,0.177314


In [35]:
oos_prediction_summary_r2_groups = [
    "data_fold",
    "anomaly",
    "hot_encode",
    "variables",
]
oos_prediction_summary = (
    oos_prediction_summary.groupby(oos_prediction_summary_r2_groups, as_index=False)
    .agg(
        {
            "demean_R2": "mean",
            "demean_r2": "mean",
            # "demean_R2": "sem",
            # "demean_r2": "sem",
        }
    )
    .sort_values("variables")
)
# oos_prediction_summary  # .sort_values(["variables", "data_fold"])

In [36]:
# Pivot the DataFrame
oos_prediction_summary = oos_prediction_summary.pivot_table(
    index=[
        "anomaly",
        "hot_encode",
        "variables",
    ],
    columns="data_fold",
)

# Flatten the MultiIndex columns
oos_prediction_summary.columns = [
    "_".join(col) for col in oos_prediction_summary.columns
]

# Reset the index
oos_prediction_summary = oos_prediction_summary.reset_index()
oos_prediction_summary = oos_prediction_summary.set_index(
    [
        "variables",
        "anomaly",
        "hot_encode",
        "demean_R2_train",
        "demean_r2_train",
        "demean_R2_test",
        "demean_r2_test",
    ]
)
oos_prediction_summary = oos_prediction_summary.reset_index()
oos_prediction_summary.sort_values("demean_R2_train", ascending=False)

Unnamed: 0,variables,anomaly,hot_encode,demean_R2_train,demean_r2_train,demean_R2_test,demean_r2_test
20,tmp_ndvi,True,False,0.530848,0.531399,0.485076,0.500791
18,pre_tmp_ndvi,True,False,0.529823,0.530509,0.476054,0.492671
16,pre_ndvi,True,False,0.462839,0.464957,0.400379,0.422695
14,ndvi,True,False,0.442356,0.443799,0.38617,0.406944
17,pre_tmp,True,False,0.427708,0.428512,0.35378,0.397824
19,tmp,True,False,0.419594,0.420349,0.365583,0.40901
6,tmp_ndvi,False,False,0.345402,0.358642,0.303817,0.351196
4,pre_tmp_ndvi,False,False,0.312575,0.349396,0.25119,0.328011
5,tmp,False,False,0.292189,0.300871,0.247677,0.305017
3,pre_tmp,False,False,0.278274,0.311435,0.239848,0.312614


In [37]:
oos_prediction_summary.loc[
    (oos_prediction_summary["anomaly"] == False)
    & (oos_prediction_summary["hot_encode"] == True)
].sort_values("demean_R2_train", ascending=False)

Unnamed: 0,variables,anomaly,hot_encode,demean_R2_train,demean_r2_train,demean_R2_test,demean_r2_test
11,pre_tmp_ndvi,False,True,0.192646,0.255906,0.411838,0.434075
13,tmp_ndvi,False,True,0.191583,0.253474,0.413384,0.437167
9,pre_ndvi,False,True,0.097641,0.186494,0.360786,0.383572
7,ndvi,False,True,0.069832,0.155312,0.330809,0.354433
10,pre_tmp,False,True,0.058785,0.155125,0.293347,0.337877
12,tmp,False,True,0.046382,0.150789,0.313956,0.356027
8,pre,False,True,-0.47187,0.051367,0.006051,0.079982
