In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import math
import time

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm
from sklearn.preprocessing import StandardScaler

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/users/marron31/repos/futurecrop'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/the-future-crop-challenge/pr_wheat_train.parquet
/kaggle/input/the-future-crop-challenge/tasmax_maize_train.parquet
/kaggle/input/the-future-crop-challenge/sample_submission.csv
/kaggle/input/the-future-crop-challenge/soil_co2_wheat_train.parquet
/kaggle/input/the-future-crop-challenge/tas_wheat_train.parquet
/kaggle/input/the-future-crop-challenge/rsds_maize_train.parquet
/kaggle/input/the-future-crop-challenge/tasmin_wheat_train.parquet
/kaggle/input/the-future-crop-challenge/tasmax_wheat_train.parquet
/kaggle/input/the-future-crop-challenge/rsds_maize_test.parquet
/kaggle/input/the-future-crop-challenge/soil_co2_maize_test.parquet
/kaggle/input/the-future-crop-challenge/train_solutions_maize.parquet
/kaggle/input/the-future-crop-challenge/pr_maize_test.parquet
/kaggle/input/the-future-crop-challenge/tas_wheat_test.parquet
/kaggle/input/the-future-crop-challenge/tasmax_maize_test.parquet
/kaggle/input/the-future-crop-challenge/pr_maize_train.parquet
/kaggle/input/the-fu

In [None]:
DATA_DIR = r'/users/marron31/repos/futurecrop'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Running on {device}')

Running on cuda:0


In [38]:
# EALSTM
import torch
import torch.nn as nn
from torch import Tensor

class EALSTMCell(nn.Module):
    """
    Entity-Aware LSTM cell:
      i_t = σ( W_xi x_t + W_hi h_{t-1} + W_si s + b_i )
      f_t = σ( W_xf x_t + W_hf h_{t-1} + b_f )
      o_t = σ( W_xo x_t + W_ho h_{t-1} + b_o )
      g_t = tanh( W_xg x_t + W_hg h_{t-1} + b_g )
      c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t
      h_t = o_t ⊙ tanh(c_t)
    Static features s only affect the input gate.
    """
    def __init__(self, input_size: int, hidden_size: int, static_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.static_size = static_size

        # Dynamic projections (x_t and h_{t-1}) for all 4 gates at once
        self.lin_x = nn.Linear(input_size, 4 * hidden_size, bias=True)
        self.lin_h = nn.Linear(hidden_size, 4 * hidden_size, bias=True)

        # Static projection ONLY into the input gate (no bias by convention)
        self.lin_s = nn.Linear(static_size, hidden_size, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        # Xavier for weights; small biases (forget gate bias to ~1.0 is common)
        for m in [self.lin_x, self.lin_h]:
            nn.init.xavier_uniform_(m.weight)
            nn.init.zeros_(m.bias)

        nn.init.xavier_uniform_(self.lin_s.weight)

        # Set forget gate bias to positive (helps with long-term memory)
        with torch.no_grad():
            # gates layout: [i, f, o, g] in chunks of hidden_size
            self.lin_x.bias[self.hidden_size:2*self.hidden_size].fill_(1.0)
            self.lin_h.bias[self.hidden_size:2*self.hidden_size].fill_(0.0)

    def forward(self, x_t: Tensor, s: Tensor, h_prev: Tensor, c_prev: Tensor):
        """
        x_t: (B, input_size)
        s:   (B, static_size)  -- same across all t for a given sample
        h_prev, c_prev: (B, hidden_size)
        returns: h_t, c_t
        """
        gates = self.lin_x(x_t) + self.lin_h(h_prev)  # (B, 4H)

        # Split gates
        i, f, o, g = torch.chunk(gates, 4, dim=-1)

        # Inject static features ONLY into input gate
        i = i + self.lin_s(s)

        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        o = torch.sigmoid(o)
        g = torch.tanh(g)

        c_t = f * c_prev + i * g
        h_t = o * torch.tanh(c_t)
        return h_t, c_t


class EALSTM(nn.Module):
    """
    Single-layer EA-LSTM unrolled over time.
    """
    def __init__(self, input_size: int, hidden_size: int, static_size: int, dropout: float = 0.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = EALSTMCell(input_size, hidden_size, static_size)
        self.dropout = nn.Dropout(dropout) if dropout and dropout > 0 else nn.Identity()

    def forward(self, x_seq: Tensor, x_static: Tensor, h0: Tensor = None, c0: Tensor = None, return_sequences: bool = False):
        """
        x_seq:    (B, T, F) dynamic inputs
        x_static: (B, S)    static features
        h0, c0:   optional initial states, shape (B, H)
        """
        B, T, _ = x_seq.shape
        H = self.hidden_size
        device = x_seq.device

        if h0 is None:
            h_t = torch.zeros(B, H, device=device, dtype=x_seq.dtype)
        else:
            h_t = h0

        if c0 is None:
            c_t = torch.zeros(B, H, device=device, dtype=x_seq.dtype)
        else:
            c_t = c0

        outputs = []
        for t in range(T):
            x_t = x_seq[:, t, :]
            h_t, c_t = self.cell(x_t, x_static, h_t, c_t)
            h_t = self.dropout(h_t)
            if return_sequences:
                outputs.append(h_t)

        if return_sequences:
            # (B, T, H)
            h_seq = torch.stack(outputs, dim=1)
            return h_seq, (h_t, c_t)
        else:
            # Last hidden state (B, H)
            return h_t, (h_t, c_t)


class YieldNetEALSTM(nn.Module):
    """
    Yield predictor using an EA-LSTM over daily climate inputs, conditioned on static soil/site features.
    By default, predicts from the final hidden state. Optionally concatenate static features in the head.
    """
    def __init__(
        self,
        seq_input_size: int,     # F
        static_size: int,        # S
        hidden_size: int,        # H
        head_layers: list[int] = None,  # e.g., [H, 64, 1] or [H+S, 64, 1] if concat_static_in_head=True
        dropout: float = 0.0,
        concat_static_in_head: bool = False
    ):
        super().__init__()
        self.concat_static_in_head = concat_static_in_head
        self.ealstm = EALSTM(seq_input_size, hidden_size, static_size, dropout=dropout)

        if head_layers is None:
            in_size = hidden_size + (static_size if concat_static_in_head else 0)
            head_layers = [in_size, 1]

        linears = []
        for i in range(len(head_layers) - 1):
            linears.append(nn.Linear(head_layers[i], head_layers[i+1]))
            if i < len(head_layers) - 2:
                linears.append(nn.ReLU())
                if dropout and dropout > 0:
                    linears.append(nn.Dropout(dropout))
        self.head = nn.Sequential(*linears)

        # Init for head
        for m in self.head:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x_seq: Tensor, x_static: Tensor) -> Tensor:
        """
        x_seq:    (B, T, F)
        x_static: (B, S)
        returns:  (B,) yield prediction
        """
        h_last, _ = self.ealstm(x_seq, x_static, return_sequences=False)
        if self.concat_static_in_head:
            z = torch.cat([h_last, x_static], dim=-1)
        else:
            z = h_last
        y = self.head(z).squeeze(-1)
        return y

In [39]:
# YieldNet
class YieldNet(nn.Module):
    """
    Defines yield neural network. A LSTM processes the daily climate feature timeseries, while a parallel branch 
    handles fixed soil properties per location and year in a fully connected network (FCN). 
    The resuls of both branches are merged to a single yield prediction (again, per location and year).
    """
    def __init__(self, seq_input_size, seq_hidden_size, fcn_layers):
        super().__init__()

        self.seq_input_size = seq_input_size
        self.seq_hidden_size = seq_hidden_size

        self.lstm = nn.LSTM(input_size=seq_input_size, hidden_size=seq_hidden_size, num_layers=1, batch_first=True)
        self.seq_fcn = nn.Linear(self.seq_hidden_size, 1)

        self.fcn_layers = fcn_layers
        self.activation = nn.Tanh()
        self.linears = nn.ModuleList([nn.Linear(fcn_layers[i], fcn_layers[i+1]) for i in range(len(fcn_layers)-1)]) 

        for i in range(len(fcn_layers)-1):         
            nn.init.xavier_normal_(self.linears[i].weight.data, gain=1.0)            
            nn.init.zeros_(self.linears[i].bias.data)   
    
    def forward(self, x_seq, x_fcn):
        h0 = torch.autograd.Variable(torch.zeros(1, x_seq.shape[0], self.seq_hidden_size)).to(x_seq.device)
        c0 = torch.autograd.Variable(torch.zeros(1, x_seq.shape[0], self.seq_hidden_size)).to(x_seq.device)

        out, _ = self.lstm(x_seq, (h0, c0))
        out = self.seq_fcn(out[:, -1, :])

        a = x_fcn
        for i in range(len(self.fcn_layers) - 2):  
            z = self.linears[i](a)
            a = self.activation(z)
        a = self.linears[-1](a)

        return out.reshape(out.shape[0], out.shape[1]) * a

In [56]:
class ClimateDataset(Dataset):
    """
    The ClimateDataset class provides a convenient way for acessing, merging and scaling parquet data. 
    This is used in the main training loop to access features and target variables.
    """
    def __init__(self, crop: str, mode: str, data_dir: str, scalers: list = None):
        self.tasmax = pd.read_parquet(os.path.join(data_dir, f"tasmax_{crop}_{mode}.parquet"))
        self.tasmin = pd.read_parquet(os.path.join(data_dir, f"tasmin_{crop}_{mode}.parquet"))
        # self.tas = pd.read_parquet(os.path.join(data_dir, f"tas_{crop}_{mode}.parquet"))
        self.pr = pd.read_parquet(os.path.join(data_dir, f"pr_{crop}_{mode}.parquet"))
        self.rsds = pd.read_parquet(os.path.join(data_dir, f"rsds_{crop}_{mode}.parquet"))
        self.soil_co2 = pd.read_parquet(os.path.join(data_dir, f"soil_co2_{crop}_{mode}.parquet"))

        self.soil_co2['lon_sin'] = np.sin(self.soil_co2['lon'])*2*math.pi/365
        self.soil_co2['lon_cos'] = np.cos(self.soil_co2['lon'])*2*math.pi/365
        
        if mode == 'train':
            self.yield_ = pd.read_parquet(os.path.join(data_dir, f"{mode}_solutions_{crop}.parquet"))
        else:
            self.yield_ = None

        self.locs = self.tasmax[['lat','lon']].copy()
        # self.locs['ID'] = self.locs.index.copy()

        self.yield_loc = pd.merge(self.yield_, self.locs, on='ID')
        self.yield_var = self.yield_loc.groupby(['lon','lat'])['yield'].var()
        self.yield_loc_with_var = pd.merge(self.yield_loc, self.yield_var, on=['lon','lat'], suffixes=['','_var'])
        # self.yield_loc_with_var['yield_var'] = self.yield_loc_with_var['yield_var'].transform(lambda x: max(x, 0.1))
        # self.yield_loc_with_var['yield_var'] = self.yield_loc_with_var['yield_var'].transform(lambda x: (x - x.min()) / (x.max() - x.min()) + 1) 
        # self.yield_loc_with_var['log_yield'] = self.yield_loc_with_var['yield'].transform(lambda x: np.log(x+1))
        
        if scalers is None:
            self._init_scalers()
        else:
            self.scaler_climate, self.scaler_soil, self.scaler_yield = scalers
        self._check_data([self.tasmax, self.tasmin, self.pr, self.rsds, self.soil_co2], self.yield_)

    def __getitem__(self, index):
        # 240x4 climate matrix per location/year (features in last dimension by convention)
        climate = np.vstack([
            self.tasmax.iloc[index, 5:].astype(np.float32), 
            self.tasmin.iloc[index, 5:].astype(np.float32),
            self.pr.iloc[index, 5:].astype(np.float32),
            self.rsds.iloc[index, 5:].astype(np.float32),
        ]).T

        # Fixed soil properties per location/year
        soil = self.soil_co2.iloc[index][['co2', 'nitrogen','lat','lon_cos','lon_sin']].astype(np.float32)
        
        id = soil.name
        soil = soil.values

        # soil = np.concatenate([soil, loc_enc], axis=0)

        # Yield estimated by process model
        if self.yield_ is not None:

            # retrieving log yield
            yield_ = self.yield_loc_with_var.iloc[index,0].astype(np.float32)
            yield_ = self.scaler_yield.transform(yield_.reshape(1, -1)).reshape(-1)

            # yieldvar_ = self.yield_loc_with_var.iloc[index,-1].astype(np.float32)
            
        else:
            yield_ = None
            yieldvar_ = None

        climate = self.scaler_climate.transform(climate)
        soil = self.scaler_soil.transform(soil.reshape(1, -1)).reshape(-1)

        return torch.tensor(climate), torch.tensor(soil), torch.tensor(yield_ or []), id
        # return torch.tensor(climate), torch.tensor(soil), torch.tensor(yield_ or []), id, torch.tensor(yieldvar_ or [])

    def __len__(self):
        return self.tasmax.shape[0]
        # return 1000

    def _init_scalers(self):
        # Draw random sample from climate data to estimate distribution moments for scaler.
        climate_sample = np.vstack([
            self.tasmax.sample(1000).iloc[:, 5:].values.flatten(), 
            self.tasmin.sample(1000).iloc[:, 5:].values.flatten(),
            self.pr.sample(1000).iloc[:, 5:].values.flatten(),
            self.rsds.sample(1000).iloc[:, 5:].values.flatten(),
        ]).T
        self.scaler_climate = StandardScaler()
        self.scaler_climate.fit(climate_sample)

        # Scaler for fixed soil properties
        self.scaler_soil = StandardScaler()
        self.scaler_soil.fit(self.soil_co2[['co2', 'nitrogen','lat','lon_cos','lon_sin']].values)

        # Scaler for yield
        self.scaler_yield = StandardScaler()
        if self.yield_ is not None:
            self.scaler_yield.fit(self.yield_.values)

    def _check_data(self, climate: list, yield_: pd.DataFrame) -> bool:
        # Check for matching year, lon, lat columns
        for i in range(1, len(climate)):
            assert np.all(climate[0][['year', 'lon', 'lat']] == climate[i][['year', 'lon', 'lat']])
        # Check label for matching length
        assert yield_ is None or climate[0].shape[0] == yield_.shape[0]

In [48]:
# Intialize model, data loader, loss and optimizer
ds_train = ClimateDataset('maize', 'train', data_dir=DATA_DIR)
train_loader = DataLoader(ds_train, batch_size=64, shuffle=True, num_workers=8)



In [57]:
seq_feat = ds_train[1][0].shape[-1]   # e.g., 4
H = 128
S = ds_train[1][1].shape[-1]   

S, H, seq_feat

(5, 128, 4)

In [58]:
model = YieldNetEALSTM(
    seq_input_size=seq_feat,
    static_size=S,
    hidden_size=H,
    head_layers=[H + S, 64, 1],  # <— 130 here
    dropout=0.1,
    concat_static_in_head=True
).to(device)

# model = YieldNet(seq_input_size=4, seq_hidden_size=128, fcn_layers=[S, 32, 32, 1]).to(device)

cost_mse = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def model_device(m):
    return next(m.parameters()).device

In [59]:
load_checkpoint = False

if load_checkpoint:

    checkpoint = torch.load("/kaggle/working/checkpoints/lstm_location/epoch_3.pt", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1

In [60]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [61]:
count_parameters(model)

77889

In [62]:
# task[0].shape, task[1], task[2], task[3]
ds_train[3000][1:]

(tensor([-1.5227, -0.8337,  0.4713, -0.5730, -1.2823]),
 tensor([-1.0632]),
 3000)

In [None]:
# Directory to save checkpoints
ckpt_dir = "./checkpoints/ealstm_location"
os.makedirs(ckpt_dir, exist_ok=True)

# Main training loop
# In this example the model is only trained on maize data!
epoch_loss = []
num_epochs = 3
# pbar = tqdm(range(num_epochs))
pbar = range(num_epochs)

for epoch in range(num_epochs):

    start = time.time()

    loss = []
    model.train()
    for i, (climate, soil, yield_true, _) in enumerate(tqdm(train_loader, leave=False)):
        dev = model_device(model)
        optimizer.zero_grad()
        yield_pred = model(climate.to(device), soil.to(device))
        data_loss = cost_mse(yield_pred, yield_true.to(device))
        # data_loss = ((yield_pred - yield_true.to(device))**2 / yield_var.to(device)).mean()
        data_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # optional but helpful
        optimizer.step()

        loss.append(data_loss.item())

        if i % 100 == 0:
            print(f"{i} -- loss: {np.mean(loss):5.10f}")

    epoch_loss.append(np.mean(loss))
    print(f"Epoch {epoch} - loss: {np.mean(loss):5.10f}")
    print(f"Elapsed: {(time.time() - start):5.2f}")

    # --- Save checkpoint ---
    ckpt_path = os.path.join(ckpt_dir, f"epoch_{epoch+1}.pt")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': np.mean(loss),
    }, ckpt_path)
    print(f"Checkpoint saved: {ckpt_path}")

  0%|          | 0/5465 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


0 -- loss: 1.0978884697
100 -- loss: 1.0220752506
200 -- loss: 1.0161526310
300 -- loss: 1.0103548977
400 -- loss: 1.0029603708
500 -- loss: 1.0031274621
600 -- loss: 1.0019174413
700 -- loss: 0.9965851357
800 -- loss: 0.9985041705
900 -- loss: 0.9987341731
1000 -- loss: 0.9993326206
1100 -- loss: 0.9986532334
1200 -- loss: 0.9987741121
1300 -- loss: 0.9990034068
1400 -- loss: 0.9974697124
1500 -- loss: 0.9996482654
1600 -- loss: 0.9986375407
1700 -- loss: 0.9987637310
1800 -- loss: 0.9994348737
1900 -- loss: 1.0004388048
2000 -- loss: 0.9999675373
2100 -- loss: 0.9985731284
2200 -- loss: 0.9995767428
2300 -- loss: 0.9994997172
2400 -- loss: 0.9993758010
2500 -- loss: 0.9989218766
2600 -- loss: 0.9979687307
2700 -- loss: 0.9972301224
2800 -- loss: 0.9975232727
2900 -- loss: 0.9965779131
3000 -- loss: 0.9976155941
3100 -- loss: 0.9978519903
3200 -- loss: 0.9986435981
3300 -- loss: 0.9980051999
3400 -- loss: 0.9980620271
3500 -- loss: 0.9993258882
3600 -- loss: 0.9991237778
3700 -- loss:

In [None]:
# Predict test set for submission
ds_test_maize = ClimateDataset('maize', 'test', data_dir=DATA_DIR, scalers=(ds_train.scaler_climate, ds_train.scaler_soil, ds_train.scaler_yield))
# ds_test_wheat = ClimateDataset('wheat', 'test', data_dir=DATA_DIR, scalers=(ds_train.scaler_climate, ds_train.scaler_soil, ds_train.scaler_yield))

test_loader_maize = DataLoader(ds_test_maize, batch_size=400, shuffle=False, num_workers=6)
# test_loader_wheat = DataLoader(ds_test_wheat, batch_size=100, shuffle=False, num_workers=4)

In [None]:
ids = []
yields_pred = []

model.eval()

for i, (climate, soil, _, id) in enumerate(tqdm(test_loader_maize)):
    start = time.time()
    ids.append(id.detach().numpy())
    yields_pred.append(model(climate.to(device), soil.to(device)).detach().cpu().numpy())
    if i % 100 == 0:
        print(i)
        print(f"Time elapsed: {time.time() - start}")

# for i, (climate, soil, _, id) in enumerate(tqdm(test_loader_wheat)):
#     ids.append(id.detach().numpy())
#     yields_pred.append(model(climate.to(device), soil.to(device)).detach().cpu().numpy())
#     if i % 100 == 0:
#         print(i)

In [None]:
# # Concatenate all batches
# ids = np.concatenate(ids_list, axis=0)
# yields_pred = np.concatenate(yields_list, axis=0)

In [None]:
yields_pred = np.concatenate(yields_pred)
yields_pred = ds_train.scaler_yield.inverse_transform(yields_pred)
yields_pred = yields_pred.reshape(-1)

In [None]:
len(yields_pred)

In [None]:
predictions = pd.Series(yields_pred, index=np.concatenate(ids))
predictions.index.name = 'ID'
predictions.name = 'yield'
predictions.to_csv('/kaggle/working/last_loc_6ep_maize_submission.csv')

In [None]:
!ls