# Benchmarks Sandbox

**Author**: Ivan Zvonkov

**Last Modified**: Feb 6, 2024

**Description**: Code for benchmarking against different variations in models.

In [1]:
import sys
import pandas as pd

from tqdm.notebook import tqdm

sys.path.append("..")

from datasets import datasets
from src.bboxes import bboxes

## 1. Load Datasets

In [2]:
# Takes a minute and a half
dfs = []
for d in tqdm(datasets):
    df = d.load_df(to_np=True, disable_tqdm=True)
    df["name"] = d.name
    dfs.append(df)
df = pd.concat(dfs)
df["is_crop"] = df["class_probability"] > 0.5

  0%|          | 0/45 [00:00<?, ?it/s]

  df = d.load_df(to_np=True, disable_tqdm=True)


## 2. Compute stats for candidate test sets

In [3]:
# Compute stats for each candidate test set
test_set_names = df[df["subset"] == "testing"]["name"].unique()
stats_list = []
for name in test_set_names:        
    test_df = df[(df["subset"] == "testing") & (df["name"] == name)]
    stats = {
        "Name": name , 
        "Total": len(test_df), 
        "Crop Amount": test_df["is_crop"].sum()
    }
    stats["Crop Rate"] = round(stats["Crop Amount"] / stats["Total"], 3)
    stats_list.append(stats)
stats = pd.DataFrame(stats_list)
stats

Unnamed: 0,Name,Total,Crop Amount,Crop Rate
0,Kenya,822,566,0.689
1,Mali_lower_CEO_2019,271,94,0.347
2,Mali_upper_CEO_2019,323,17,0.053
3,Togo,310,107,0.345
4,Rwanda,555,191,0.344
5,Uganda,456,52,0.114
6,Ethiopia_Tigray_2020,507,139,0.274
7,Ethiopia_Tigray_2021,367,120,0.327
8,Ethiopia_Bure_Jimma_2019,498,161,0.323
9,Ethiopia_Bure_Jimma_2020,455,129,0.284


## 3. Select test sets for benchmark

In [4]:
# Test set candidates to be filtered out
not_representative = ["Kenya"]
too_few_crops = ["Mali_upper_CEO_2019", "Zambia_CEO_2019", "Namibia_CEO_2020", "Hawaii_CEO_2020",
                  "KenyaCEO2019", "MaliStratifiedCEO2019", "NamibiaNorthStratified2020"]

In [5]:
# Finalized benchmark test sets
benchmark = stats[~stats["Name"].isin(not_representative + too_few_crops)].copy()
benchmark

Unnamed: 0,Name,Total,Crop Amount,Crop Rate
1,Mali_lower_CEO_2019,271,94,0.347
3,Togo,310,107,0.345
4,Rwanda,555,191,0.344
5,Uganda,456,52,0.114
6,Ethiopia_Tigray_2020,507,139,0.274
7,Ethiopia_Tigray_2021,367,120,0.327
8,Ethiopia_Bure_Jimma_2019,498,161,0.323
9,Ethiopia_Bure_Jimma_2020,455,129,0.284
10,Malawi_CEO_2020,457,67,0.147
12,Tanzania_CEO_2019,2037,626,0.307


In [6]:
# Associated bbox name
dataset_name_to_bbox_name = {
    "Mali_lower_CEO_2019": "Mali_lower",
    "Togo": "Togo",
    "Rwanda": "Rwanda",
    "Uganda": "Uganda",
    "Ethiopia_Tigray_2020": "Ethiopia_Tigray",
    "Ethiopia_Tigray_2021": "Ethiopia_Tigray",
    "Ethiopia_Bure_Jimma_2019": "Ethiopia_Bure_Jimma",
    "Ethiopia_Bure_Jimma_2020": "Ethiopia_Bure_Jimma",
    "Malawi_CEO_2020": "Malawi",
    "Tanzania_CEO_2019": "Tanzania",
    "Sudan_Blue_Nile_CEO_2019": "Sudan_Blue_Nile",
    "SudanBlueNileCEO2020": "Sudan_Blue_Nile",
    "Senegal_CEO_2022": "Senegal",
    "SudanAlGadarefCEO2019": "Sudan_Al_Gadaref",
    "SudanAlGadarefCEO2020": "Sudan_Al_Gadaref",
    "SudanGedarefDarfurAlJazirah2022": "Sudan_South",
    "Uganda_NorthCEO2022": "Uganda"
}

In [7]:
amounts = []

for dataset_name, bbox_name in tqdm(dataset_name_to_bbox_name.items()):
    is_local_lat = (df.lat >= bboxes[bbox_name].min_lat) & (df.lat <= bboxes[bbox_name].max_lat)
    is_local_lon = (df.lon >= bboxes[bbox_name].min_lon) & (df.lon <= bboxes[bbox_name].max_lon)
    is_test = (df["name"] == dataset_name) & (df["subset"] == "testing")
    test_df = df[is_test]
    is_val = (df["name"] == dataset_name) & (df["subset"] == "validation")    
    train_df = df[~is_test & ~is_val & is_local_lat & is_local_lon]
    
    amounts.append({
        "Name": dataset_name,
        "Validation": is_val.sum(),
        "Training": len(train_df)
    })
    
pd.DataFrame(amounts)

  0%|          | 0/17 [00:00<?, ?it/s]

Unnamed: 0,Name,Validation,Training
0,Mali_lower_CEO_2019,275,1096
1,Togo,268,1062
2,Rwanda,520,676
3,Uganda,454,19943
4,Ethiopia_Tigray_2020,520,4916
5,Ethiopia_Tigray_2021,351,5225
6,Ethiopia_Bure_Jimma_2019,488,1597
7,Ethiopia_Bure_Jimma_2020,483,1645
8,Malawi_CEO_2020,490,6934
9,Tanzania_CEO_2019,2044,24035


## Baselines Playground
Code to run experiments including: 1) withholding bands, 2) different bbox sizes, 3) catboost

In [8]:
!pip install catboost -q

In [8]:
from openmapflow.bands import BANDS
from openmapflow.engineer import calculate_ndvi

from catboost import CatBoostClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
from src.bboxes import bboxes

import numpy as np

In [9]:
# Recompute NDVI
df["eo_data"] = df["eo_data"].apply(lambda x: calculate_ndvi(x[:, : len(BANDS) - 1]))

In [10]:
ADD_BY = np.zeros(18)
ADD_BY[0:2] = [25.0, 25.0]     # Sentinel-1 VV, VH (range from -50 to 1)
ADD_BY[13] = -272.15           # ERA5 Celcius

DIVIDE_BY = np.ones(18)
DIVIDE_BY[0:2] = [25.0, 25.0]   # Sentinel-1 VV, VH (range from -50 to 1)
DIVIDE_BY[2:13] = [10000.0] * 11 # Sentinel-2 high band values
DIVIDE_BY[13] = 35.0            # ERA5 high celcius value
DIVIDE_BY[14] = 0.03            # ERA5 high precipitation value
DIVIDE_BY[15] = 2000.0          # SRTM elevation high value
DIVIDE_BY[16] = 50.0            # Slope high value

def normalize(x):
    keep_indices = [idx for idx, val in enumerate(BANDS) if val != "B9"] # remove the b9 band
    normalized = ((x + ADD_BY) / DIVIDE_BY).astype(np.float32)
    return normalized[:, keep_indices]

In [11]:
start_month = 2
end_month = start_month + 12

def generate_X_y(df):    
    X = df["eo_data"].apply(lambda x: x[start_month:end_month].flatten()).to_list()
    y = df["is_crop"].astype(int).to_list()
    return X, y

def generate_X_y_NDVI(df):    
    X = df["eo_data"].apply(lambda x: x[start_month:end_month, -1]).to_list()
    y = df["is_crop"].to_list()
    return X, y

def generate_X_y_S2(df):    
    X = df["eo_data"].apply(lambda x: x[start_month:end_month, 2:13].flatten()).to_list()
    y = df["is_crop"].to_list()
    return X, y

def generate_X_y_S1_S2(df):    
    X = df["eo_data"].apply(lambda x: x[start_month:end_month, :13].flatten()).to_list()
    y = df["is_crop"].to_list()
    return X, y

def generate_X_y_S1_S2_SRTM(df):    
    X = df["eo_data"].apply(lambda x: x[start_month:end_month, np.r_[:13, 15:17]].flatten()).to_list()
    y = df["is_crop"].to_list()
    return X, y

def generate_X_y_S1_S2_SRTM_ERA5(df):    
    X = df["eo_data"].apply(lambda x: x[start_month:end_month, :17].flatten()).to_list()
    y = df["is_crop"].to_list()
    return X, y

def generate_X_y_normalized(df):    
    X = df["eo_data"].apply(lambda x: normalize(x[start_month:end_month]).flatten()).to_list()
    y = df["is_crop"].astype(int).to_list()
    return X, y

In [23]:
f1_scores = {}
for dataset_name, bbox_name in tqdm(dataset_name_to_bbox_name.items()):
    buf = 0
    is_local_lat = (df.lat >= (bboxes[bbox_name].min_lat - buf)) & (df.lat <= (bboxes[bbox_name].max_lat + buf))
    is_local_lon = (df.lon >= (bboxes[bbox_name].min_lon - buf)) & (df.lon <= (bboxes[bbox_name].max_lon + buf))
    
    is_test = (df["name"] == dataset_name) & (df["subset"] == "testing")
    test_df = df[is_test]
    
    # is_val = (df["name"] == dataset_name) & (df["subset"] == "validation")
    
    train_df = df[~is_test & is_local_lat & is_local_lon]
    
    X_train, y_train = generate_X_y(train_df)
    X_test, y_test = generate_X_y(test_df)
    
    model = RandomForestClassifier(random_state=0) # CatBoostClassifier(random_state=0)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    f1_scores[dataset_name] = f1_score(y_test, y_pred)

  0%|          | 0/17 [00:00<?, ?it/s]

In [24]:
for dataset, f1 in f1_scores.items():
    benchmark.loc[benchmark["Name"] == dataset, "RF Feb-Feb F1 Score"] = f1

In [25]:
benchmark

Unnamed: 0,Name,Total,Crop Amount,Crop Rate,RF Feb-Feb F1 Score
1,Mali_lower_CEO_2019,271,94,0.347,0.613757
3,Togo,310,107,0.345,0.756098
4,Rwanda,555,191,0.344,0.549575
5,Uganda,456,52,0.114,0.44186
6,Ethiopia_Tigray_2020,507,139,0.274,0.638132
7,Ethiopia_Tigray_2021,367,120,0.327,0.641148
8,Ethiopia_Bure_Jimma_2019,498,161,0.323,0.796296
9,Ethiopia_Bure_Jimma_2020,455,129,0.284,0.857143
10,Malawi_CEO_2020,457,67,0.147,0.20202
12,Tanzania_CEO_2019,2037,626,0.307,0.844167


## Presto Benchmark

In [39]:
%load_ext autoreload
%autoreload 2

In [12]:
presto_benchmark = stats[~stats["Name"].isin(not_representative + too_few_crops)].copy()

In [96]:
from src.single_file_presto_v2 import Presto, DEVICE, Aggregate

import numpy as np
import torch

from copy import deepcopy
from torch.utils.data import Dataset, DataLoader

### Presto Encodings

In [97]:
dw_mask = (torch.ones(12) * 9).long()

class PrestoDataset(Dataset):
    def __init__(self, arg_df, start_month=1):
        xs_list = [normalize(x[start_month:start_month+12]) for x in arg_df["eo_data"].to_list()]
        self.xs_tensors = [torch.from_numpy(x).to(DEVICE).float() for x in xs_list]

        self.latlons = [np.stack([lat, lon], axis=-1) for lat, lon in zip(arg_df["eo_lat"].to_list(), arg_df["eo_lon"].to_list())]
        self.latlons_tensors = [torch.from_numpy(latlon).to(DEVICE).float() for latlon in self.latlons]
        
        self.is_crop_tensors = [torch.tensor(is_crop, dtype=torch.float32) for is_crop in arg_df["is_crop"].astype(int).to_list()]
        self.start_month = start_month
        
    def __len__(self):
        return len(self.xs_tensors)
    
    def __getitem__(self, idx):
        x = self.xs_tensors[idx]
        latlons = self.latlons_tensors[idx]
        is_crop = self.is_crop_tensors[idx]
        return x, latlons, dw_mask, self.start_month, is_crop

In [98]:
DEFAULT_SEED = 42

In [99]:
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

In [100]:
encoder_decoder = Presto.load_pretrained("../data/presto/default_model_v2.pt")
pretrained_model = encoder_decoder.encoder.eval()

In [101]:
def generate_encodings(dataset, aggregate):
    dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=False)
    feature_list = []
    for (x, latlons, dw, start_month, _) in tqdm(dataloader, desc="Encodings", leave=False):
        with torch.no_grad():
            encodings = (pretrained_model(
                x, dynamic_world=dw, latlons=latlons, month=start_month, aggregate=aggregate
            ).cpu().numpy())
            feature_list.append(encodings)
    return np.concatenate(feature_list)

In [102]:
# Use Sklearn scaling of encodings
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

## Presto Encodings Benchmark

In [104]:
f1_scores = {}
for dataset_name, bbox_name in tqdm(dataset_name_to_bbox_name.items()):
    buf = 0
    is_local_lat = (df.lat >= (bboxes[bbox_name].min_lat - buf)) & (df.lat <= (bboxes[bbox_name].max_lat + buf))
    is_local_lon = (df.lon >= (bboxes[bbox_name].min_lon - buf)) & (df.lon <= (bboxes[bbox_name].max_lon + buf))
    
    is_test = (df["name"] == dataset_name) & (df["subset"] == "testing")
    test_df = df[is_test]    
    train_df = df[~is_test & is_local_lat & is_local_lon]
 
    train_dataset = PrestoDataset(train_df, start_month=1)
    test_dataset = PrestoDataset(test_df, start_month=1) 
    X_train = generate_encodings(train_dataset, Aggregate.BAND_GROUPS_MEAN)
    X_test = generate_encodings(test_dataset, Aggregate.BAND_GROUPS_MEAN)
    
    y_train = train_df["is_crop"].to_list() 
    y_test = test_df["is_crop"].to_list()
    
    model = LogisticRegression(class_weight="balanced", max_iter=1000, random_state=DEFAULT_SEED)
    #pipe = make_pipeline(StandardScaler(), model)
    #pipe.fit(X_train, y_train)
    #model = RandomForestClassifier(class_weight="balanced", random_state=DEFAULT_SEED)
    model.fit(X_train, y_train)
    
    #y_pred = model.predict(X_test)
    y_pred_proba = model.predict_proba(X_test)[:, 1]
    y_pred = (y_pred_proba > 0.5).astype(int)
    
    f1_scores[dataset_name] = f1_score(y_test, y_pred)
    print(f"{dataset_name}: {f1_scores[dataset_name]}")

  0%|          | 0/17 [00:00<?, ?it/s]

Encodings:   0%|          | 0/22 [00:00<?, ?it/s]

Encodings:   0%|          | 0/5 [00:00<?, ?it/s]

Mali_lower_CEO_2019: 0.6198830409356726


Encodings:   0%|          | 0/21 [00:00<?, ?it/s]

Encodings:   0%|          | 0/5 [00:00<?, ?it/s]

Togo: 0.7317073170731708


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Encodings:   0%|          | 0/19 [00:00<?, ?it/s]

Encodings:   0%|          | 0/9 [00:00<?, ?it/s]

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Rwanda: 0.6847290640394088


Encodings:   0%|          | 0/319 [00:00<?, ?it/s]

Encodings:   0%|          | 0/8 [00:00<?, ?it/s]

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Uganda: 0.5098039215686275


Encodings:   0%|          | 0/85 [00:00<?, ?it/s]

Encodings:   0%|          | 0/8 [00:00<?, ?it/s]

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Ethiopia_Tigray_2020: 0.671480144404332


Encodings:   0%|          | 0/88 [00:00<?, ?it/s]

Encodings:   0%|          | 0/6 [00:00<?, ?it/s]

Ethiopia_Tigray_2021: 0.7222222222222223


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Encodings:   0%|          | 0/33 [00:00<?, ?it/s]

Encodings:   0%|          | 0/8 [00:00<?, ?it/s]

Ethiopia_Bure_Jimma_2019: 0.8571428571428571


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Encodings:   0%|          | 0/34 [00:00<?, ?it/s]

Encodings:   0%|          | 0/8 [00:00<?, ?it/s]

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Ethiopia_Bure_Jimma_2020: 0.8673835125448028


Encodings:   0%|          | 0/116 [00:00<?, ?it/s]

Encodings:   0%|          | 0/8 [00:00<?, ?it/s]

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Malawi_CEO_2020: 0.4079601990049751


Encodings:   0%|          | 0/408 [00:00<?, ?it/s]

Encodings:   0%|          | 0/32 [00:00<?, ?it/s]

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Tanzania_CEO_2019: 0.8313155770782888


Encodings:   0%|          | 0/39 [00:00<?, ?it/s]

Encodings:   0%|          | 0/9 [00:00<?, ?it/s]

Sudan_Blue_Nile_CEO_2019: 0.9201101928374655


Encodings:   0%|          | 0/39 [00:00<?, ?it/s]

Encodings:   0%|          | 0/9 [00:00<?, ?it/s]

SudanBlueNileCEO2020: 0.7789473684210527


Encodings:   0%|          | 0/19 [00:00<?, ?it/s]

Encodings:   0%|          | 0/10 [00:00<?, ?it/s]

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Senegal_CEO_2022: 0.6244343891402715


Encodings:   0%|          | 0/54 [00:00<?, ?it/s]

Encodings:   0%|          | 0/9 [00:00<?, ?it/s]

SudanAlGadarefCEO2019: 0.5892857142857143


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Encodings:   0%|          | 0/54 [00:00<?, ?it/s]

Encodings:   0%|          | 0/9 [00:00<?, ?it/s]

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


SudanAlGadarefCEO2020: 0.7209775967413442


Encodings:   0%|          | 0/140 [00:00<?, ?it/s]

Encodings:   0%|          | 0/6 [00:00<?, ?it/s]

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


SudanGedarefDarfurAlJazirah2022: 0.7615658362989324


Encodings:   0%|          | 0/321 [00:00<?, ?it/s]

Encodings:   0%|          | 0/5 [00:00<?, ?it/s]

Uganda_NorthCEO2022: 0.4333333333333333


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [105]:
benchmark_name = "Presto LR Feb-Feb F1 Score (no DW, band group encodings, per group LayerNorm)"
for dataset, f1 in f1_scores.items():
    presto_benchmark.loc[presto_benchmark["Name"] == dataset, benchmark_name] = f1
    

In [106]:
presto_benchmark

Unnamed: 0,Name,Total,Crop Amount,Crop Rate,"Presto LR Mar-Mar F1 Score (no DW, band group encodings, no norm)","Presto LR Mar-Mar F1 Score (no DW, band group encodings, per group norm)","Presto LR Mar-Mar F1 Score (no DW, band group encodings, sklearn StandardScaler)","Presto LR Feb-Feb F1 Score (no DW, band group encodings, per group LayerNorm)"
1,Mali_lower_CEO_2019,271,94,0.347,0.602273,0.609195,0.576087,0.619883
3,Togo,310,107,0.345,0.742857,0.741935,0.742857,0.731707
4,Rwanda,555,191,0.344,0.666667,0.682927,0.663342,0.684729
5,Uganda,456,52,0.114,0.47205,0.465409,0.465409,0.509804
6,Ethiopia_Tigray_2020,507,139,0.274,0.666667,0.654412,0.654275,0.67148
7,Ethiopia_Tigray_2021,367,120,0.327,0.712963,0.694836,0.712329,0.722222
8,Ethiopia_Bure_Jimma_2019,498,161,0.323,0.868263,0.867257,0.861446,0.857143
9,Ethiopia_Bure_Jimma_2020,455,129,0.284,0.827338,0.845878,0.826568,0.867384
10,Malawi_CEO_2020,457,67,0.147,0.340659,0.381443,0.352273,0.40796
12,Tanzania_CEO_2019,2037,626,0.307,0.8272,0.843875,0.844051,0.831316


## Presto Finetuning Benchmark

In [43]:
from torch import nn
from torch.optim import Adam

In [75]:
lr = 3e-4
batch_size = 64
max_epochs = 10

In [84]:
dataset_name = "Togo"
bbox_name = "Togo"

is_local_lat = (df.lat >= (bboxes[bbox_name].min_lat - buf)) & (df.lat <= (bboxes[bbox_name].max_lat + buf))
is_local_lon = (df.lon >= (bboxes[bbox_name].min_lon - buf)) & (df.lon <= (bboxes[bbox_name].max_lon + buf))

is_test = (df["name"] == dataset_name) & (df["subset"] == "testing")
is_val = (df["name"] == dataset_name) & (df["subset"] == "validation")
test_df = df[is_test]    
val_df = df[is_val]
train_df = df[~is_test & ~is_val & is_local_lat & is_local_lon]

train_dataset = PrestoDataset(train_df, start_month=1)
val_dataset = PrestoDataset(val_df, start_month=1)
test_dataset = PrestoDataset(test_df, start_month=1) 

train_dl = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
test_dl = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [125]:
def finetune_w_early_stopping(train_dl, val_dl, test_dl):
    model = encoder_decoder.construct_finetuning_model(num_outputs=1, regression=False)

    optimizer = Adam(model.parameters(), lr=lr) # TODO: Consider AdamW
    loss_fn = nn.BCELoss()

    train_loss = []
    val_loss = []
    best_loss = None
    best_model_dict = None
    epochs_since_improvement = 0
    patience = 3

    for _ in (pbar := tqdm(range(max_epochs), desc="Finetuning")):
        model.train()
        epoch_train_loss = 0.0
        for x, latlons, dw, month, y in tqdm(train_dl, desc="Training", leave=False):
            x, dw, latlons, y, month = [t.to(DEVICE) for t in (x, dw, latlons, y, month)]
            optimizer.zero_grad()
            preds = model(x, dynamic_world=dw, mask=None, latlons=latlons, month=month)
            loss = loss_fn(preds, y.unsqueeze(dim=1))
            epoch_train_loss += loss.item()
            loss.backward()
            optimizer.step()
        train_loss.append(epoch_train_loss / len(train_dl))

        model.eval()
        all_preds, all_y = [], []
        for x, latlons, dw, month, y in val_dl:
            x, dw, latlons, y, month = [t.to(DEVICE) for t in (x, dw, latlons, y, month)]
            with torch.no_grad():
                preds = model(x, dynamic_world=dw, mask=None, latlons=latlons, month=month)
                all_preds.append(preds)
                all_y.append(y.unsqueeze(dim=1))

        val_loss.append(loss_fn(torch.cat(all_preds), torch.cat(all_y)))
        pbar.set_description(f"Train metric: {train_loss[-1]}, Val metric: {val_loss[-1]}")
        if best_loss is None:
            best_loss = val_loss[-1]
            best_model_dict = deepcopy(model.state_dict())
        else:
            if val_loss[-1] < best_loss:
                best_loss = val_loss[-1]
                best_model_dict = deepcopy(model.state_dict())
                epochs_since_improvement = 0
            else:
                epochs_since_improvement += 1
                if epochs_since_improvement >= patience:
                    print("Early stopping!")
                    break
    assert best_model_dict is not None
    model.load_state_dict(best_model_dict)
    model.eval()

    test_preds, test_y = [], []
    for x, latlons, dw, month, y in test_dl:
        x, dw, latlons, y, month = [t.to(DEVICE) for t in (x, dw, latlons, y, month)]
        with torch.no_grad():
            preds = model(x, dynamic_world=dw, mask=None, latlons=latlons, month=month)
            test_preds += (preds.flatten().numpy() > 0.50).astype(int).tolist()
            test_y += y.numpy().astype(int).tolist()

    return f1_score(test_preds, test_y)

In [127]:
f1_scores = {}
for dataset_name, bbox_name in tqdm(dataset_name_to_bbox_name.items()):

    is_local_lat = (df.lat >= bboxes[bbox_name].min_lat) & (df.lat <= bboxes[bbox_name].max_lat)
    is_local_lon = (df.lon >= bboxes[bbox_name].min_lon) & (df.lon <= bboxes[bbox_name].max_lon)
    
    is_test = (df["name"] == dataset_name) & (df["subset"] == "testing")
    is_val = (df["name"] == dataset_name) & (df["subset"] == "validation")
    test_df = df[is_test]    
    val_df = df[is_val]
    train_df = df[~is_test & ~is_val & is_local_lat & is_local_lon]

    train_dataset = PrestoDataset(train_df, start_month=1)
    val_dataset = PrestoDataset(val_df, start_month=1)
    test_dataset = PrestoDataset(test_df, start_month=1) 

    train_dl = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    val_dl = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
    test_dl = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
    
    f1_scores[dataset_name] = finetune_w_early_stopping(train_dl, val_dl, test_dl)

  0%|          | 0/17 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/18 [00:00<?, ?it/s]

Training:   0%|          | 0/18 [00:00<?, ?it/s]

Training:   0%|          | 0/18 [00:00<?, ?it/s]

Training:   0%|          | 0/18 [00:00<?, ?it/s]

Training:   0%|          | 0/18 [00:00<?, ?it/s]

Early stopping!


Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/17 [00:00<?, ?it/s]

Training:   0%|          | 0/17 [00:00<?, ?it/s]

Training:   0%|          | 0/17 [00:00<?, ?it/s]

Training:   0%|          | 0/17 [00:00<?, ?it/s]

Training:   0%|          | 0/17 [00:00<?, ?it/s]

Training:   0%|          | 0/17 [00:00<?, ?it/s]

Training:   0%|          | 0/17 [00:00<?, ?it/s]

Training:   0%|          | 0/17 [00:00<?, ?it/s]

Training:   0%|          | 0/17 [00:00<?, ?it/s]

Training:   0%|          | 0/17 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/11 [00:00<?, ?it/s]

Training:   0%|          | 0/11 [00:00<?, ?it/s]

Training:   0%|          | 0/11 [00:00<?, ?it/s]

Training:   0%|          | 0/11 [00:00<?, ?it/s]

Training:   0%|          | 0/11 [00:00<?, ?it/s]

Training:   0%|          | 0/11 [00:00<?, ?it/s]

Training:   0%|          | 0/11 [00:00<?, ?it/s]

Training:   0%|          | 0/11 [00:00<?, ?it/s]

Early stopping!


Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/312 [00:00<?, ?it/s]

Training:   0%|          | 0/312 [00:00<?, ?it/s]

Training:   0%|          | 0/312 [00:00<?, ?it/s]

Training:   0%|          | 0/312 [00:00<?, ?it/s]

Training:   0%|          | 0/312 [00:00<?, ?it/s]

Early stopping!


Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/77 [00:00<?, ?it/s]

Training:   0%|          | 0/77 [00:00<?, ?it/s]

Training:   0%|          | 0/77 [00:00<?, ?it/s]

Training:   0%|          | 0/77 [00:00<?, ?it/s]

Training:   0%|          | 0/77 [00:00<?, ?it/s]

Early stopping!


Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/82 [00:00<?, ?it/s]

Training:   0%|          | 0/82 [00:00<?, ?it/s]

Training:   0%|          | 0/82 [00:00<?, ?it/s]

Training:   0%|          | 0/82 [00:00<?, ?it/s]

Training:   0%|          | 0/82 [00:00<?, ?it/s]

Training:   0%|          | 0/82 [00:00<?, ?it/s]

Early stopping!


Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/25 [00:00<?, ?it/s]

Training:   0%|          | 0/25 [00:00<?, ?it/s]

Training:   0%|          | 0/25 [00:00<?, ?it/s]

Training:   0%|          | 0/25 [00:00<?, ?it/s]

Training:   0%|          | 0/25 [00:00<?, ?it/s]

Training:   0%|          | 0/25 [00:00<?, ?it/s]

Early stopping!


Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/26 [00:00<?, ?it/s]

Training:   0%|          | 0/26 [00:00<?, ?it/s]

Training:   0%|          | 0/26 [00:00<?, ?it/s]

Training:   0%|          | 0/26 [00:00<?, ?it/s]

Training:   0%|          | 0/26 [00:00<?, ?it/s]

Early stopping!


Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/109 [00:00<?, ?it/s]

Training:   0%|          | 0/109 [00:00<?, ?it/s]

Training:   0%|          | 0/109 [00:00<?, ?it/s]

Training:   0%|          | 0/109 [00:00<?, ?it/s]

Early stopping!


Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/376 [00:00<?, ?it/s]

Training:   0%|          | 0/376 [00:00<?, ?it/s]

Training:   0%|          | 0/376 [00:00<?, ?it/s]

Training:   0%|          | 0/376 [00:00<?, ?it/s]

Training:   0%|          | 0/376 [00:00<?, ?it/s]

Training:   0%|          | 0/376 [00:00<?, ?it/s]

Training:   0%|          | 0/376 [00:00<?, ?it/s]

Early stopping!


Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/31 [00:00<?, ?it/s]

Training:   0%|          | 0/31 [00:00<?, ?it/s]

Training:   0%|          | 0/31 [00:00<?, ?it/s]

Training:   0%|          | 0/31 [00:00<?, ?it/s]

Training:   0%|          | 0/31 [00:00<?, ?it/s]

Training:   0%|          | 0/31 [00:00<?, ?it/s]

Training:   0%|          | 0/31 [00:00<?, ?it/s]

Early stopping!


Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/31 [00:00<?, ?it/s]

Training:   0%|          | 0/31 [00:00<?, ?it/s]

Training:   0%|          | 0/31 [00:00<?, ?it/s]

Training:   0%|          | 0/31 [00:00<?, ?it/s]

Training:   0%|          | 0/31 [00:00<?, ?it/s]

Training:   0%|          | 0/31 [00:00<?, ?it/s]

Training:   0%|          | 0/31 [00:00<?, ?it/s]

Training:   0%|          | 0/31 [00:00<?, ?it/s]

Training:   0%|          | 0/31 [00:00<?, ?it/s]

Early stopping!


Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/10 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/45 [00:00<?, ?it/s]

Training:   0%|          | 0/45 [00:00<?, ?it/s]

Training:   0%|          | 0/45 [00:00<?, ?it/s]

Training:   0%|          | 0/45 [00:00<?, ?it/s]

Early stopping!


Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/45 [00:00<?, ?it/s]

Training:   0%|          | 0/45 [00:00<?, ?it/s]

Training:   0%|          | 0/45 [00:00<?, ?it/s]

Training:   0%|          | 0/45 [00:00<?, ?it/s]

Training:   0%|          | 0/45 [00:00<?, ?it/s]

Training:   0%|          | 0/45 [00:00<?, ?it/s]

Training:   0%|          | 0/45 [00:00<?, ?it/s]

Training:   0%|          | 0/45 [00:00<?, ?it/s]

Training:   0%|          | 0/45 [00:00<?, ?it/s]

Early stopping!


Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/135 [00:00<?, ?it/s]

Training:   0%|          | 0/135 [00:00<?, ?it/s]

Training:   0%|          | 0/135 [00:00<?, ?it/s]

Training:   0%|          | 0/135 [00:00<?, ?it/s]

Training:   0%|          | 0/135 [00:00<?, ?it/s]

Training:   0%|          | 0/135 [00:00<?, ?it/s]

Training:   0%|          | 0/135 [00:00<?, ?it/s]

Training:   0%|          | 0/135 [00:00<?, ?it/s]

Early stopping!


Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/317 [00:00<?, ?it/s]

Training:   0%|          | 0/317 [00:00<?, ?it/s]

Training:   0%|          | 0/317 [00:00<?, ?it/s]

Training:   0%|          | 0/317 [00:00<?, ?it/s]

Training:   0%|          | 0/317 [00:00<?, ?it/s]

Training:   0%|          | 0/317 [00:00<?, ?it/s]

Training:   0%|          | 0/317 [00:00<?, ?it/s]

Training:   0%|          | 0/317 [00:00<?, ?it/s]

Early stopping!


In [130]:
for dataset, f1 in f1_scores.items():
    presto_benchmark.loc[presto_benchmark["Name"] == dataset, "Presto Finetuning w Early Stopping Feb-Feb F1 Score (no DW)"] = f1

In [131]:
presto_benchmark

Unnamed: 0,Name,Total,Crop Amount,Crop Rate,Presto RF Mar-Mar F1 Score (no DW),Presto RF Feb-Feb F1 Score (no DW),Presto Finetuning w Early Stopping Feb-Feb F1 Score (no DW)
1,Mali_lower_CEO_2019,271,94,0.347,0.588235,0.641711,0.592593
3,Togo,310,107,0.345,0.756303,0.742857,0.731092
4,Rwanda,555,191,0.344,0.627346,0.576819,0.604534
5,Uganda,456,52,0.114,0.464646,0.4375,0.5
6,Ethiopia_Tigray_2020,507,139,0.274,0.641221,0.651341,0.666667
7,Ethiopia_Tigray_2021,367,120,0.327,0.669903,0.68932,0.682692
8,Ethiopia_Bure_Jimma_2019,498,161,0.323,0.819355,0.825806,0.808642
9,Ethiopia_Bure_Jimma_2020,455,129,0.284,0.880309,0.849421,0.838951
10,Malawi_CEO_2020,457,67,0.147,0.142857,0.117647,0.113475
12,Tanzania_CEO_2019,2037,626,0.307,0.80322,0.809991,0.770306


In [132]:
def finetune(train_dl, test_dl):
    model = encoder_decoder.construct_finetuning_model(num_outputs=1, regression=False)

    optimizer = Adam(model.parameters(), lr=lr) # TODO: Consider AdamW
    loss_fn = nn.BCELoss()

    train_loss = []

    for _ in (pbar := tqdm(range(max_epochs), desc="Finetuning")):
        model.train()
        epoch_train_loss = 0.0
        for x, latlons, dw, month, y in tqdm(train_dl, desc="Training", leave=False):
            x, dw, latlons, y, month = [t.to(DEVICE) for t in (x, dw, latlons, y, month)]
            optimizer.zero_grad()
            preds = model(x, dynamic_world=dw, mask=None, latlons=latlons, month=month)
            loss = loss_fn(preds, y.unsqueeze(dim=1))
            epoch_train_loss += loss.item()
            loss.backward()
            optimizer.step()
        train_loss.append(epoch_train_loss / len(train_dl))

    model.eval()

    test_preds, test_y = [], []
    for x, latlons, dw, month, y in test_dl:
        x, dw, latlons, y, month = [t.to(DEVICE) for t in (x, dw, latlons, y, month)]
        with torch.no_grad():
            preds = model(x, dynamic_world=dw, mask=None, latlons=latlons, month=month)
            test_preds += (preds.flatten().numpy() > 0.50).astype(int).tolist()
            test_y += y.numpy().astype(int).tolist()

    return f1_score(test_preds, test_y)

In [134]:
f1_scores = {}
for dataset_name, bbox_name in tqdm(dataset_name_to_bbox_name.items()):

    is_local_lat = (df.lat >= bboxes[bbox_name].min_lat) & (df.lat <= bboxes[bbox_name].max_lat)
    is_local_lon = (df.lon >= bboxes[bbox_name].min_lon) & (df.lon <= bboxes[bbox_name].max_lon)
    
    is_test = (df["name"] == dataset_name) & (df["subset"] == "testing")
    test_df = df[is_test]    
    train_df = df[~is_test & is_local_lat & is_local_lon]

    train_dataset = PrestoDataset(train_df, start_month=1)
    test_dataset = PrestoDataset(test_df, start_month=1) 

    train_dl = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_dl = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
    
    f1_scores[dataset_name] = finetune(train_dl, test_dl)

  0%|          | 0/17 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/22 [00:00<?, ?it/s]

Training:   0%|          | 0/22 [00:00<?, ?it/s]

Training:   0%|          | 0/22 [00:00<?, ?it/s]

Training:   0%|          | 0/22 [00:00<?, ?it/s]

Training:   0%|          | 0/22 [00:00<?, ?it/s]

Training:   0%|          | 0/22 [00:00<?, ?it/s]

Training:   0%|          | 0/22 [00:00<?, ?it/s]

Training:   0%|          | 0/22 [00:00<?, ?it/s]

Training:   0%|          | 0/22 [00:00<?, ?it/s]

Training:   0%|          | 0/22 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/21 [00:00<?, ?it/s]

Training:   0%|          | 0/21 [00:00<?, ?it/s]

Training:   0%|          | 0/21 [00:00<?, ?it/s]

Training:   0%|          | 0/21 [00:00<?, ?it/s]

Training:   0%|          | 0/21 [00:00<?, ?it/s]

Training:   0%|          | 0/21 [00:00<?, ?it/s]

Training:   0%|          | 0/21 [00:00<?, ?it/s]

Training:   0%|          | 0/21 [00:00<?, ?it/s]

Training:   0%|          | 0/21 [00:00<?, ?it/s]

Training:   0%|          | 0/21 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/319 [00:00<?, ?it/s]

Training:   0%|          | 0/319 [00:00<?, ?it/s]

Training:   0%|          | 0/319 [00:00<?, ?it/s]

Training:   0%|          | 0/319 [00:00<?, ?it/s]

Training:   0%|          | 0/319 [00:00<?, ?it/s]

Training:   0%|          | 0/319 [00:00<?, ?it/s]

Training:   0%|          | 0/319 [00:00<?, ?it/s]

Training:   0%|          | 0/319 [00:00<?, ?it/s]

Training:   0%|          | 0/319 [00:00<?, ?it/s]

Training:   0%|          | 0/319 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/85 [00:00<?, ?it/s]

Training:   0%|          | 0/85 [00:00<?, ?it/s]

Training:   0%|          | 0/85 [00:00<?, ?it/s]

Training:   0%|          | 0/85 [00:00<?, ?it/s]

Training:   0%|          | 0/85 [00:00<?, ?it/s]

Training:   0%|          | 0/85 [00:00<?, ?it/s]

Training:   0%|          | 0/85 [00:00<?, ?it/s]

Training:   0%|          | 0/85 [00:00<?, ?it/s]

Training:   0%|          | 0/85 [00:00<?, ?it/s]

Training:   0%|          | 0/85 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/88 [00:00<?, ?it/s]

Training:   0%|          | 0/88 [00:00<?, ?it/s]

Training:   0%|          | 0/88 [00:00<?, ?it/s]

Training:   0%|          | 0/88 [00:00<?, ?it/s]

Training:   0%|          | 0/88 [00:00<?, ?it/s]

Training:   0%|          | 0/88 [00:00<?, ?it/s]

Training:   0%|          | 0/88 [00:00<?, ?it/s]

Training:   0%|          | 0/88 [00:00<?, ?it/s]

Training:   0%|          | 0/88 [00:00<?, ?it/s]

Training:   0%|          | 0/88 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/33 [00:00<?, ?it/s]

Training:   0%|          | 0/33 [00:00<?, ?it/s]

Training:   0%|          | 0/33 [00:00<?, ?it/s]

Training:   0%|          | 0/33 [00:00<?, ?it/s]

Training:   0%|          | 0/33 [00:00<?, ?it/s]

Training:   0%|          | 0/33 [00:00<?, ?it/s]

Training:   0%|          | 0/33 [00:00<?, ?it/s]

Training:   0%|          | 0/33 [00:00<?, ?it/s]

Training:   0%|          | 0/33 [00:00<?, ?it/s]

Training:   0%|          | 0/33 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/34 [00:00<?, ?it/s]

Training:   0%|          | 0/34 [00:00<?, ?it/s]

Training:   0%|          | 0/34 [00:00<?, ?it/s]

Training:   0%|          | 0/34 [00:00<?, ?it/s]

Training:   0%|          | 0/34 [00:00<?, ?it/s]

Training:   0%|          | 0/34 [00:00<?, ?it/s]

Training:   0%|          | 0/34 [00:00<?, ?it/s]

Training:   0%|          | 0/34 [00:00<?, ?it/s]

Training:   0%|          | 0/34 [00:00<?, ?it/s]

Training:   0%|          | 0/34 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/116 [00:00<?, ?it/s]

Training:   0%|          | 0/116 [00:00<?, ?it/s]

Training:   0%|          | 0/116 [00:00<?, ?it/s]

Training:   0%|          | 0/116 [00:00<?, ?it/s]

Training:   0%|          | 0/116 [00:00<?, ?it/s]

Training:   0%|          | 0/116 [00:00<?, ?it/s]

Training:   0%|          | 0/116 [00:00<?, ?it/s]

Training:   0%|          | 0/116 [00:00<?, ?it/s]

Training:   0%|          | 0/116 [00:00<?, ?it/s]

Training:   0%|          | 0/116 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/408 [00:00<?, ?it/s]

Training:   0%|          | 0/408 [00:00<?, ?it/s]

Training:   0%|          | 0/408 [00:00<?, ?it/s]

Training:   0%|          | 0/408 [00:00<?, ?it/s]

Training:   0%|          | 0/408 [00:00<?, ?it/s]

Training:   0%|          | 0/408 [00:00<?, ?it/s]

Training:   0%|          | 0/408 [00:00<?, ?it/s]

Training:   0%|          | 0/408 [00:00<?, ?it/s]

Training:   0%|          | 0/408 [00:00<?, ?it/s]

Training:   0%|          | 0/408 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Training:   0%|          | 0/39 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Training:   0%|          | 0/19 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Training:   0%|          | 0/54 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/140 [00:00<?, ?it/s]

Training:   0%|          | 0/140 [00:00<?, ?it/s]

Training:   0%|          | 0/140 [00:00<?, ?it/s]

Training:   0%|          | 0/140 [00:00<?, ?it/s]

Training:   0%|          | 0/140 [00:00<?, ?it/s]

Training:   0%|          | 0/140 [00:00<?, ?it/s]

Training:   0%|          | 0/140 [00:00<?, ?it/s]

Training:   0%|          | 0/140 [00:00<?, ?it/s]

Training:   0%|          | 0/140 [00:00<?, ?it/s]

Training:   0%|          | 0/140 [00:00<?, ?it/s]

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/321 [00:00<?, ?it/s]

Training:   0%|          | 0/321 [00:00<?, ?it/s]

Training:   0%|          | 0/321 [00:00<?, ?it/s]

Training:   0%|          | 0/321 [00:00<?, ?it/s]

Training:   0%|          | 0/321 [00:00<?, ?it/s]

Training:   0%|          | 0/321 [00:00<?, ?it/s]

Training:   0%|          | 0/321 [00:00<?, ?it/s]

Training:   0%|          | 0/321 [00:00<?, ?it/s]

Training:   0%|          | 0/321 [00:00<?, ?it/s]

Training:   0%|          | 0/321 [00:00<?, ?it/s]

In [135]:
for dataset, f1 in f1_scores.items():
    presto_benchmark.loc[presto_benchmark["Name"] == dataset, "Presto Finetuning Feb-Feb F1 Score (no DW)"] = f1

In [136]:
presto_benchmark

Unnamed: 0,Name,Total,Crop Amount,Crop Rate,Presto RF Mar-Mar F1 Score (no DW),Presto RF Feb-Feb F1 Score (no DW),Presto Finetuning w Early Stopping Feb-Feb F1 Score (no DW),Presto Finetuning Feb-Feb F1 Score (no DW)
1,Mali_lower_CEO_2019,271,94,0.347,0.588235,0.641711,0.592593,0.642487
3,Togo,310,107,0.345,0.756303,0.742857,0.731092,0.736842
4,Rwanda,555,191,0.344,0.627346,0.576819,0.604534,0.571429
5,Uganda,456,52,0.114,0.464646,0.4375,0.5,0.533333
6,Ethiopia_Tigray_2020,507,139,0.274,0.641221,0.651341,0.666667,0.706714
7,Ethiopia_Tigray_2021,367,120,0.327,0.669903,0.68932,0.682692,0.756757
8,Ethiopia_Bure_Jimma_2019,498,161,0.323,0.819355,0.825806,0.808642,0.839117
9,Ethiopia_Bure_Jimma_2020,455,129,0.284,0.880309,0.849421,0.838951,0.844106
10,Malawi_CEO_2020,457,67,0.147,0.142857,0.117647,0.113475,0.217687
12,Tanzania_CEO_2019,2037,626,0.307,0.80322,0.809991,0.770306,0.808394
