In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from models.vanillann import YieldDataset, SimpleModel
from sklearn.feature_selection  import mutual_info_regression

ecmwf_path = "data/preprocessed/US/ecmwf_era_wheat_US.csv"
predictor_path = "data/preprocessed/US/ndvi_soil_soil_moisture_meteo_fpar_wheat_US.csv"
yield_path = "data/CY-Bench/US/wheat/yield_wheat_US.csv"
test_years = [2015, 2018, 2022]

%load_ext autoreload
%autoreload 2

In [2]:
def get_yield_and_predictors(yield_path, predictor_path, ecmwf_path, test_years):
    y = pd.read_csv(yield_path)
    y = y.loc[y["harvest_year"].between(2003, 2023), ["adm_id", "harvest_year", "yield"]].reset_index(drop=True)

    # Merge predictor data
    x_1 = pd.read_csv(ecmwf_path)
    x_2 = pd.read_csv(predictor_path)
    x = x_1.merge(x_2, on=["adm_id", "harvest_year"], how="left").dropna().reset_index(drop=True)

    # Merge predictor and yield data
    x_y = x.merge(y, on=["adm_id", "harvest_year"], how="inner")

    train_df = x_y[~x_y['harvest_year'].isin(test_years)].reset_index(drop=True)
    test_df = x_y[x_y['harvest_year'].isin(test_years)].reset_index(drop=True)
    
    return train_df, test_df

In [3]:
train_df, test_df = get_yield_and_predictors(yield_path, predictor_path, ecmwf_path, test_years)

In [4]:
pd.to_datetime(train_df["init_date"]).dt.month.max()

9

In [11]:
end_of_season_df = train_df[pd.to_datetime(train_df["init_date"]).dt.month == 9].drop(columns=["init_date"]).set_index("adm_id")

In [12]:
def RMSELoss(yhat,y):
    return 100 * torch.sqrt(torch.mean((yhat-y)**2)) / torch.mean(y)

In [13]:
end_of_season_df = end_of_season_df.set_index(["yield", "harvest_year"], append=True)

In [14]:
end_of_season_df = (end_of_season_df[[c for c in end_of_season_df.columns if "tmax" in c]].rolling(window=4, step=4, axis=1).mean().dropna(axis=1)
 .join(
     end_of_season_df[[c for c in end_of_season_df.columns if "tmin" in c]].rolling(window=4, step=4, axis=1).mean().dropna(axis=1)
 ).join(
    end_of_season_df[[c for c in end_of_season_df.columns if "prec" in c]].rolling(window=4, step=4, axis=1).mean().dropna(axis=1)
 ).join(
     end_of_season_df[["awc", "bulk_density", "drainage_class_1", "drainage_class_2", "drainage_class_3", "drainage_class_4", "drainage_class_5", "drainage_class_6"]]
 ).reset_index().set_index("adm_id")
)

In [15]:
end_of_season_df

Unnamed: 0_level_0,yield,harvest_year,tmax_8,tmax_12,tmax_16,tmax_20,tmax_24,tmax_28,tmin_8,tmin_12,...,prec_24,prec_28,awc,bulk_density,drainage_class_1,drainage_class_2,drainage_class_3,drainage_class_4,drainage_class_5,drainage_class_6
adm_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
US-01-001,3.9707,2003,17.270446,20.937684,26.825161,27.870006,29.791414,30.027308,6.340036,10.441741,...,5.243151,5.574161,11.436359,1.342325,0.0,0.0,0.0,1.0,0.0,0.0
US-01-003,3.1631,2003,18.446471,21.386636,26.738730,28.084772,29.090873,29.618936,9.253883,12.043147,...,9.176004,6.633327,14.818747,1.424239,0.0,0.0,0.0,0.0,1.0,0.0
US-01-003,3.3650,2004,17.260808,22.149763,25.288360,29.233394,30.445200,30.656131,8.757897,11.588240,...,5.335587,4.157435,14.818747,1.424239,0.0,0.0,0.0,0.0,1.0,0.0
US-01-003,3.4323,2006,19.049926,22.990201,26.594140,31.007885,32.452603,32.025328,9.436357,12.178329,...,1.656067,3.771095,14.818747,1.424239,0.0,0.0,0.0,0.0,1.0,0.0
US-01-003,3.6342,2007,18.125727,22.861901,27.243667,29.355054,31.427483,32.693986,7.138254,13.386155,...,3.823193,2.309769,14.818747,1.424239,0.0,0.0,0.0,0.0,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
US-56-033,2.0863,2003,0.233752,10.902196,13.228545,20.573137,27.212852,32.217816,-12.818909,-1.986878,...,1.811108,0.582011,13.133602,1.559058,0.0,0.0,0.0,0.0,1.0,0.0
US-56-033,1.6152,2004,5.824528,13.768752,16.756334,19.084380,24.497442,25.945786,-6.548714,0.661033,...,1.836529,2.263215,13.133602,1.559058,0.0,0.0,0.0,0.0,1.0,0.0
US-56-033,3.1631,2005,7.939766,11.219689,12.008681,19.030050,28.913044,27.521570,-4.799025,-2.013270,...,1.241215,2.053632,13.133602,1.559058,0.0,0.0,0.0,0.0,1.0,0.0
US-56-033,2.0190,2007,6.420172,8.795595,18.362682,19.661393,30.077782,31.287195,-6.789862,-2.890519,...,0.430370,0.953817,13.133602,1.559058,0.0,0.0,0.0,0.0,1.0,0.0


In [16]:
unique_years = end_of_season_df['harvest_year'].unique()
unique_years.sort()
results = dict.fromkeys(unique_years)

batch_size = 64

for year in unique_years:
    print(f'Validating on year {year}')
    
    # Create training and validation sets for this fold
    train_fold_df = end_of_season_df[end_of_season_df['harvest_year'] != year]
    val_fold_df = end_of_season_df[end_of_season_df['harvest_year'] == year]
    
    train_fold_features = train_fold_df.drop(columns=['yield', 'harvest_year'])
    train_fold_target = train_fold_df['yield']
    val_fold_features = val_fold_df.drop(columns=['yield', 'harvest_year'])
    val_fold_target = val_fold_df['yield']
    
    means = train_fold_features.mean()
    stds = train_fold_features.std()
    train_fold_features = (train_fold_features - means) / stds
    val_fold_features = (val_fold_features - means) / stds
    
    
    train_fold_dataset = YieldDataset(train_fold_features, train_fold_target)
    val_fold_dataset = YieldDataset(val_fold_features, val_fold_target)
    
    train_fold_loader = DataLoader(train_fold_dataset, batch_size=batch_size, shuffle=True)
    val_fold_loader = DataLoader(val_fold_dataset, batch_size=batch_size, shuffle=False)
    
    # Reset the model and optimizer
    model = SimpleModel()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = RMSELoss
    
    # Early stopping parameters
    num_epochs = 10  # Set the maximum number of epochs you want to train for
    patience = 4  # Number of epochs to wait for improvement before stopping
    best_val_loss = float('inf')  # Initialize the best validation loss
    epochs_no_improve = 0  # Counter for epochs without improvement
    
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        for features, target in train_fold_loader:
            optimizer.zero_grad()
            output = model(features)
            loss = criterion(output, target.unsqueeze(1))
            loss.backward()
            optimizer.step()
        
        # Validation loop
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for features, target in val_fold_loader:
                output = model(features)
                loss = criterion(output, target.unsqueeze(1))
                val_loss += loss.item()
        
        val_loss /= len(val_fold_loader)  # Compute the average validation loss
        print(f'Epoch {epoch + 1}, Validation Loss for year {year}: {val_loss}')
        
        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0  # Reset the counter if validation loss improves
        else:
            epochs_no_improve += 1  # Increment the counter if validation loss does not improve
        
        if epochs_no_improve >= patience:
            print(f'Early stopping at epoch {epoch + 1}')
            break  # Stop training if no improvement for specified number of epochs
    
    results[year] = best_val_loss
        
# Once cross-validation is done, you can test on the test dataset using test_loader


Validating on year 2003
Epoch 1, Validation Loss for year 2003: 43.57643826802572
Epoch 2, Validation Loss for year 2003: 36.37793646918403
Epoch 3, Validation Loss for year 2003: 34.92403881638138
Epoch 4, Validation Loss for year 2003: 33.85168983318187
Epoch 5, Validation Loss for year 2003: 33.79119788275825
Epoch 6, Validation Loss for year 2003: 33.34014613540084
Epoch 7, Validation Loss for year 2003: 32.63258859846327
Epoch 8, Validation Loss for year 2003: 32.77996197453252
Epoch 9, Validation Loss for year 2003: 32.512628625940394
Epoch 10, Validation Loss for year 2003: 32.75680167586715
Validating on year 2004
Epoch 1, Validation Loss for year 2004: 37.70468572469858
Epoch 2, Validation Loss for year 2004: 33.60248455634484
Epoch 3, Validation Loss for year 2004: 34.17192598489615
Epoch 4, Validation Loss for year 2004: 34.08541811429537
Epoch 5, Validation Loss for year 2004: 34.36715419475849
Epoch 6, Validation Loss for year 2004: 32.787856542147125
Epoch 7, Validation L

KeyboardInterrupt: 