In [5]:
# Imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import sys
import pandas as pd

In [6]:
# Dataset preparations -- GitHub
class YieldDataset(Dataset):
    """Index-aligned container for the monthly, yearly, static and target arrays.

    The four inputs are stored as given (no copy, no validation). Per the
    original author's notes they are expected to be shaped:

    * ``monthly`` -- (N, seq_len, monthly_dim)
    * ``yearly``  -- (N, yearly_dim)
    * ``static``  -- (N, static_dim)
    * ``target``  -- (N, 1)

    Each item is a 4-tuple of ``float32`` tensors in the order
    ``(monthly, yearly, static, target)``.
    """

    def __init__(self, monthly, yearly, static, target):
        self.monthly, self.yearly = monthly, yearly
        self.static, self.target = static, target

    def __len__(self):
        # The target array defines how many samples the dataset exposes.
        return len(self.target)

    def __getitem__(self, idx):
        # Same order as the constructor arguments; each slice is converted
        # to a fresh float32 tensor.
        sources = (self.monthly, self.yearly, self.static, self.target)
        return tuple(
            torch.tensor(source[idx], dtype=torch.float32) for source in sources
        )

In [3]:
# Model
# Model assumes that there is at least 1 dimension in monthly, yearly, and static
class YieldLSTMMLPConnected(nn.Module):
    """Three-branch regressor: an LSTM over monthly sequences plus two MLPs.

    The monthly sequence is encoded by a unidirectional LSTM whose final
    hidden state (top layer) is projected; yearly and static feature vectors
    each pass through a two-layer MLP. The three embeddings are concatenated
    along the feature axis and fed to a fully connected head.

    Args:
        monthly_dim: Features per time step of the monthly sequence
            (author's note: 7 monthly averages from PRISM data).
        monthly_hidden: LSTM hidden size and width of the monthly embedding.
        monthly_layers: Number of stacked LSTM layers.
        yearly_dim: Number of yearly features (author's note: 4 from NASA
            plus the annual yield by county).
        static_dim: Number of static features (author's note: 8 soil features).
        yearly_hidden: Width of the yearly embedding.
        static_hidden: Width of the static embedding.
        head_hidden: Width of the first head layer; the second is half of it.
        output_dim: Size of the prediction vector per sample.
        dropout: Dropout probability used in every branch and in the head.
    """

    def __init__(self,
                 monthly_dim=7,
                 monthly_hidden=64,
                 monthly_layers=1,
                 yearly_dim=5,
                 static_dim=8,
                 yearly_hidden=32,
                 static_hidden=32,
                 head_hidden=64,
                 output_dim=1,
                 dropout=0.1
                 ):
        super().__init__()

        # Monthly branch: LSTM encoder followed by a projection.
        self.lstm = nn.LSTM(
            input_size=monthly_dim,
            hidden_size=monthly_hidden,
            num_layers=monthly_layers,
            batch_first=True,
            bidirectional=False
        )
        self.monthly_proj = nn.Sequential(
            nn.Linear(monthly_hidden, monthly_hidden),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Yearly branch MLP.
        self.yearly_proj = nn.Sequential(
            nn.Linear(yearly_dim, yearly_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(yearly_hidden, yearly_hidden),
            nn.ReLU(),
        )

        # Static branch MLP.
        self.static_proj = nn.Sequential(
            nn.Linear(static_dim, static_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(static_hidden, static_hidden),
            nn.ReLU(),
        )

        # Head consumes the concatenation of the three branch embeddings.
        combined_dim = monthly_hidden + yearly_hidden + static_hidden
        self.head = nn.Sequential(
            nn.Linear(combined_dim, head_hidden),
            nn.ReLU(),
            # Was misspelled `nn.Droput`, which raised AttributeError on construction.
            nn.Dropout(dropout),
            nn.Linear(head_hidden, head_hidden // 2),
            nn.ReLU(),
            nn.Linear(head_hidden // 2, output_dim)
        )

    def forward(self, monthly, yearly, static):
        """Predict from one batch of the three feature groups.

        Args:
            monthly: Tensor of shape (batch, seq_len, monthly_dim).
            yearly: Tensor of shape (batch, yearly_dim).
            static: Tensor of shape (batch, static_dim).

        Returns:
            Tensor of shape (batch, output_dim).
        """
        # h_n has shape (num_layers, batch, hidden); its last entry is the
        # final hidden state of the top LSTM layer.
        _, (h_n, _) = self.lstm(monthly)
        monthly_emb = self.monthly_proj(h_n[-1])  # (batch, monthly_hidden)

        yearly_emb = self.yearly_proj(yearly)
        static_emb = self.static_proj(static)

        combined = torch.cat([monthly_emb, yearly_emb, static_emb], dim=1)
        return self.head(combined)

In [7]:
# Training
def _mean_loss(batch_losses):
    """Return the arithmetic mean of per-batch losses, or NaN if there are none."""
    if not batch_losses:
        return float("nan")
    return sum(batch_losses) / len(batch_losses)


def _batch_loss(model, criterion, batch, device):
    """Move one (monthly, yearly, static, target) batch to ``device`` and return its loss tensor."""
    monthly, yearly, static, target = (tensor.to(device) for tensor in batch)
    preds = model(monthly, yearly, static)
    # A (B,) target against (B, 1) predictions would be broadcast by MSELoss
    # to (B, B) and produce a wrong loss; align the shapes when sizes match.
    if target.shape != preds.shape and target.numel() == preds.numel():
        target = target.reshape(preds.shape)
    return criterion(preds, target)


def train_model(
    model,
    train_loader,
    val_loader,
    num_epochs=50,
    lr=1e-3,
    weight_decay=1e-5,
    device="cuda" if torch.cuda.is_available() else "cpu",
    early_stop_patience=8,
    checkpoint_path="best_yield_model.pt"
):
    """Train ``model`` with Adam + MSE loss, early stopping on validation loss.

    Every epoch runs one pass over ``train_loader`` (with gradient updates)
    and one pass over ``val_loader`` (no gradients). Whenever the mean
    validation loss improves, the model's ``state_dict`` is written to
    ``checkpoint_path``. Training stops early after ``early_stop_patience``
    consecutive epochs without improvement.

    Args:
        model: Module called as ``model(monthly, yearly, static)``.
        train_loader: Iterable of ``(monthly, yearly, static, target)`` batches.
        val_loader: Iterable of batches in the same layout.
        num_epochs: Maximum number of epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        device: Device the model and batches are moved to.
        early_stop_patience: Non-improving epochs tolerated before stopping.
        checkpoint_path: Where the best ``state_dict`` is saved.

    Returns:
        None. Progress is printed and the best weights are saved to disk.
    """
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    best_val_loss = float("inf")
    patience_counter = 0
    checkpoint_saved = False

    for epoch in range(1, num_epochs + 1):
        # -------- TRAIN MODE --------
        model.train()
        train_losses = []
        for batch in train_loader:
            optimizer.zero_grad()
            loss = _batch_loss(model, criterion, batch, device)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        # -------- VAL MODE --------
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch in val_loader:
                val_losses.append(_batch_loss(model, criterion, batch, device).item())

        # Plain-Python mean: the original used `np.mean` although numpy is
        # not imported in the notebook's imports cell.
        train_loss = _mean_loss(train_losses)
        val_loss = _mean_loss(val_losses)

        print(f"Epoch {epoch:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # ---- EARLY STOP ----
        # A NaN validation loss compares False here, so it counts as "no improvement".
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
        else:
            patience_counter += 1
            if patience_counter >= early_stop_patience:
                print("Early stopping triggered!")
                break

    if checkpoint_saved:
        print(f"Training completed. Best model saved as {checkpoint_path}")
    else:
        # Do not claim a checkpoint exists when validation never improved.
        print("Training completed. No checkpoint was saved (validation loss never improved).")