In [48]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from models.vanillann import YieldDataset, SimpleModel

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Specify use case through crop and country

In [33]:
# USER INPUTS
country = "BR" # one of ["US", "BR"]
crop = "wheat" # one of ["maize", "wheat"]

match country:
    case "US":
        match crop:
            case "wheat":
                ecmwf_path = "data/preprocessed/US/ecmwf_era_wheat_US.csv"
                yield_path = "data/CY-Bench/US/wheat/yield_wheat_US.csv"
                test_years = [2020, 2021, 2022]
                print(country, crop, ecmwf_path, yield_path, test_years)
            case "maize":
                ecmwf_path = "data/preprocessed/US/ecmwf_era_maize_US.csv"
                yield_path = "data/CY-Bench/US/maize/yield_maize_US.csv"
                test_years = [2020, 2021, 2022]
                print(country, crop, ecmwf_path, yield_path, test_years)
            case _:
                print("invalid crop, has to be wheat or maize")
    case "BR":
        match crop:
            case "wheat":
                ecmwf_path = "data/preprocessed/BR/ecmwf_era_wheat_BR.csv"
                yield_path = "data/CY-Bench/BR/wheat/yield_wheat_BR.csv"
                test_years = [2006, 2015, 2017]
                print(country, crop, ecmwf_path, yield_path, test_years)
            case "maize":
                ecmwf_path = "data/preprocessed/BR/ecmwf_era_maize_BR.csv"
                yield_path = "data/CY-Bench/BR/maize/yield_maize_BR.csv"
                test_years = [2020, 2021, 2022]
                print(country, crop, ecmwf_path, yield_path, test_years)
            case _:
                print("invalid crop, has to be wheat or maize")
    case _:
        print("invalid country, has to be US or BR")


BR wheat data/preprocessed/BR/ecmwf_era_wheat_BR.csv data/CY-Bench/BR/wheat/yield_wheat_BR.csv [2006, 2015, 2017]


### Read yield and SCM_ERA data

In [39]:
# Historical yield data
y = pd.read_csv(yield_path)
y = y.loc[y["harvest_year"].between(2003, 2023), ["adm_id", "harvest_year", "yield", "harvested_area"]].reset_index(drop=True)

# 8-day aggregated ECMWF and ERA data depending on month of initialization
x = pd.read_csv(ecmwf_path)
x_y = x.merge(y, on=["adm_id", "harvest_year"], how="inner")

# remove test years
train_df = x_y[~x_y['harvest_year'].isin(test_years)].reset_index(drop=True)
train_df.head()

Unnamed: 0,adm_id,harvest_year,init_month,init_time_step,tavg_16,tavg_17,tavg_18,tavg_19,tavg_20,tavg_21,...,tmin_40,tmin_41,tmax_39,tmax_40,tmax_41,prec_39,prec_40,prec_41,yield,harvested_area
0,BR2311504,2020,5,16,25.653409,25.939799,26.185234,26.433421,26.679016,26.861418,...,24.462146,24.512708,35.406395,35.381923,35.32483,0.217376,0.245209,0.310888,5.4,5.0
1,BR2311504,2020,6,19,26.782667,26.057738,26.166207,,25.261646,26.209929,...,24.313246,24.508361,35.201305,35.248513,35.284314,0.299057,0.232168,0.355174,5.4,5.0
2,BR2311504,2020,7,23,26.782667,26.057738,26.166207,26.364367,25.930489,26.029086,...,24.263344,24.303319,35.123788,35.311503,35.331876,0.182081,0.180712,0.199467,5.4,5.0
3,BR2311504,2020,8,27,26.782667,26.057738,26.166207,26.364367,25.930489,26.029086,...,24.127901,24.171361,35.140251,35.033899,34.994636,0.166692,0.349271,0.295164,5.4,5.0
4,BR2311504,2020,9,31,26.782667,26.057738,26.166207,26.364367,25.930489,26.029086,...,23.894624,24.096629,35.106095,35.052865,35.157488,0.136966,0.140238,0.163872,5.4,5.0


In [40]:
#for m in [5, 6, 7, 8, 9, 10, 11, 12]:
#for m in [9]:
#train_df_scm = train_df[train_df["init_month"] == 10].reset_index(drop=True)
#first_scm_time_step = train_df_scm["init_time_step"].max()
#train_df_truncated = train_df[train_df["init_month"] == 12].reset_index(drop=True)
#train_df_truncated[[c for c in [l for l in train_df_truncated.columns if ("tavg" in l) or ("tmax" in l) or ("tmin" in l) or ("prec" in l)] if int(c.split("_")[-1]) >= first_scm_time_step]] = (train_df_truncated
# .groupby(["adm_id"])
# [[c for c in [l for l in train_df_truncated.columns if ("tavg" in l) or ("tmax" in l) or ("tmin" in l) or ("prec" in l)] if int(c.split("_")[-1]) >= first_scm_time_step]]
# .transform("mean"))

#train_df_scm = train_df_scm.drop(columns=["init_month", "init_time_step", "production", "harvested_area", "planted_area"]).set_index("adm_id")
#train_df_truncated = train_df_truncated.drop(columns=["init_month", "init_time_step", "production", "harvested_area", "planted_area"]).set_index("adm_id")

end_of_season_df = train_df[train_df["init_month"] == 12].drop(columns=["init_month", "init_time_step", "harvested_area"]).set_index("adm_id")

In [50]:
def RMSELoss(yhat,y):
    return 100 * torch.sqrt(torch.mean((yhat-y)**2)) / torch.mean(y)

In [51]:
unique_years = end_of_season_df['harvest_year'].unique()
unique_years.sort()
results = dict.fromkeys(unique_years)

batch_size = 32

for year in unique_years:
    print(f'Validating on year {year}')
    
    # Create training and validation sets for this fold
    train_fold_df = end_of_season_df[end_of_season_df['harvest_year'] != year]
    val_fold_df = end_of_season_df[end_of_season_df['harvest_year'] == year]
    
    train_fold_features = train_fold_df[[c for c in train_fold_df.columns if ("tavg" in c) or ("tmax" in c) or ("tmin" in c) or ("prec" in c)]]
    train_fold_target = train_fold_df['yield']
    val_fold_features = val_fold_df[[c for c in val_fold_df.columns if ("tavg" in c) or ("tmax" in c) or ("tmin" in c) or ("prec" in c)]]
    val_fold_target = val_fold_df['yield']
    
    means = train_fold_features.mean()
    stds = train_fold_features.std()
    train_fold_features = (train_fold_features - means) / stds
    val_fold_features = (val_fold_features - means) / stds
    
    train_fold_dataset = YieldDataset(train_fold_features, train_fold_target)
    val_fold_dataset = YieldDataset(val_fold_features, val_fold_target)
    
    train_fold_loader = DataLoader(train_fold_dataset, batch_size=batch_size, shuffle=True)
    val_fold_loader = DataLoader(val_fold_dataset, batch_size=batch_size, shuffle=False)
    
    # Reset the model and optimizer
    model = SimpleModel()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = RMSELoss
    
    # Early stopping parameters
    num_epochs = 10  # Set the maximum number of epochs you want to train for
    patience = 3  # Number of epochs to wait for improvement before stopping
    best_val_loss = float('inf')  # Initialize the best validation loss
    epochs_no_improve = 0  # Counter for epochs without improvement
    
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        for features, target in train_fold_loader:
            optimizer.zero_grad()
            output = model(features)
            loss = criterion(output, target.unsqueeze(1))
            loss.backward()
            optimizer.step()
        
        # Validation loop
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for features, target in val_fold_loader:
                output = model(features)
                loss = criterion(output, target.unsqueeze(1))
                val_loss += loss.item()
        
        val_loss /= len(val_fold_loader)  # Compute the average validation loss
        print(f'Epoch {epoch + 1}, Validation Loss for year {year}: {val_loss}')
        
        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0  # Reset the counter if validation loss improves
        else:
            epochs_no_improve += 1  # Increment the counter if validation loss does not improve
        
        if epochs_no_improve >= patience:
            print(f'Early stopping at epoch {epoch + 1}')
            break  # Stop training if no improvement for specified number of epochs
    
    results[year] = best_val_loss
        
# Once cross-validation is done, you can test on the test dataset using test_loader


Validating on year 2003
Epoch 1, Validation Loss for year 2003: 26.678354835510255
Epoch 2, Validation Loss for year 2003: 29.258882649739583
Epoch 3, Validation Loss for year 2003: 24.806484858194988
Epoch 4, Validation Loss for year 2003: 27.530549621582033
Epoch 5, Validation Loss for year 2003: 27.721895249684653
Epoch 6, Validation Loss for year 2003: 26.311521816253663
Early stopping at epoch 6
Validating on year 2004
Epoch 1, Validation Loss for year 2004: 27.144367771763957
Epoch 2, Validation Loss for year 2004: 27.3316252616144
Epoch 3, Validation Loss for year 2004: 25.19689993704519
Epoch 4, Validation Loss for year 2004: 27.0602960278911
Epoch 5, Validation Loss for year 2004: 29.098012555030085
Epoch 6, Validation Loss for year 2004: 29.48936293202062
Early stopping at epoch 6
Validating on year 2005
Epoch 1, Validation Loss for year 2005: 42.0397554397583
Epoch 2, Validation Loss for year 2005: 49.843784459431966
Epoch 3, Validation Loss for year 2005: 38.166168721516925

KeyboardInterrupt: 