## Import libraries


In [1]:
import os
import sys
import time

# Project source root; the project-local imports below (config, dataset, loss,
# models, utils) are presumably resolved from this folder.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
source_folder = "/beegfs/halder/GITHUB/RESEARCH/crop-yield-forecasting-germany/src"
# Guard the append so that re-running this cell does not stack duplicate
# entries on sys.path.
if source_folder not in sys.path:
    sys.path.append(source_folder)

import config.winter_wheat as cfg
import numpy as np
import torch
from config.winter_wheat import model_config, train_config
from dataset.dataset import CropFusionNetDataset
from loss.loss import QuantileLoss
from models.CropFusionNet.model import CropFusionNet
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils.utils import set_seed

# Target device comes from the model configuration; the seed is fixed before
# any dataset or model is created.
device = model_config["device"]
set_seed(42)

## Create datasets and dataloaders


In [2]:
# Datasets for the two splits (both built with scale=True).
train_dataset = CropFusionNetDataset(cfg, mode="train", scale=True)
val_dataset = CropFusionNetDataset(cfg, mode="val", scale=True)

# Loader options common to the training and validation splits.
_loader_kwargs = dict(
    batch_size=train_config["batch_size"],
    num_workers=32,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

⚠️ Filtered dataset: Dropped 13 samples missing from Yield Table.


## Model, optimizer and loss


In [3]:
# Training hyperparameters, with fallbacks when a key is absent from train_config.
num_epochs = train_config.get("num_epochs", 50)
patience = train_config.get("early_stopping_patience", 10)
batch_size = train_config.get("batch_size", 32)

# Network and quantile loss, both placed on the configured device.
model = CropFusionNet(model_config).to(device)
criterion = QuantileLoss(quantiles=model_config["quantiles"]).to(device)

# Adam with learning rate and weight decay taken from the training config.
optimizer = Adam(
    model.parameters(),
    lr=train_config["lr"],
    weight_decay=train_config["weight_decay"],
)

# Learning rate scheduler: halve the LR when the monitored (validation) loss
# has not improved by more than `threshold` for 3 epochs, never below 1e-6.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=3,
    threshold=1e-4,
    min_lr=1e-6,
)

## Training


In [4]:
def _move_batch_to_device(batch, device):
    """Split one loader batch into ``(model_inputs, targets)`` placed on ``device``.

    ``variable_mask`` is optional in the batch; when it is absent (or ``None``)
    it is forwarded to the model as ``None``.
    """
    variable_mask = batch.get("variable_mask")
    model_inputs = {
        "inputs": batch["inputs"].to(device),
        "identifier": batch["identifier"].to(device),
        "mask": batch["mask"].to(device),
        "variable_mask": (
            variable_mask.to(device) if variable_mask is not None else None
        ),
    }
    return model_inputs, batch["target"].to(device)


def train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    num_epochs,
    patience,
    scheduler=None,
    checkpoint_dir="checkpoints",
    exp_name="CropFusionNet_experiment",
):
    """Train ``model`` with per-epoch validation, LR scheduling and early stopping.

    Each epoch runs one pass over ``train_loader`` (with gradient-norm clipping
    at 5.0) followed by a no-grad pass over ``val_loader``. The mean per-batch
    losses and the current learning rate are printed and logged to TensorBoard
    under ``runs/run_<exp_name>_<timestamp>``. Whenever the validation loss
    improves by more than 1e-4, the weights are written to
    ``<checkpoint_dir>/run_<exp_name>_<timestamp>/best_model.pt``.

    Args:
        model: Module called as ``model(inputs_dict)`` and returning a mapping
            with a ``"prediction"`` entry.
        train_loader: Iterable of batches (mappings with ``"inputs"``,
            ``"identifier"``, ``"mask"``, ``"target"`` and optionally
            ``"variable_mask"`` tensors).
        val_loader: Iterable of validation batches with the same layout.
        criterion: Callable ``criterion(preds, targets)`` returning a scalar loss.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batch tensors are moved to.
        num_epochs: Maximum number of epochs.
        patience: Number of consecutive non-improving epochs before stopping.
        scheduler: Optional scheduler stepped with the validation loss.
        checkpoint_dir: Root directory for checkpoints.
        exp_name: Experiment name used in the run identifier.

    Returns:
        An independent copy of the state dict from the best validation epoch,
        or ``None`` if no epoch ever improved on the initial ``inf``.

    Raises:
        ValueError: If either loader yields no batches.
    """
    import copy  # local import: only needed to snapshot the best weights

    # 1. Setup Logging
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    log_id = f"run_{exp_name}_{timestamp}"
    log_dir = os.path.join("runs", log_id)
    writer = SummaryWriter(log_dir=log_dir)

    best_val_loss = np.inf
    epochs_no_improve = 0
    best_model_state = None

    # try/finally guarantees the TensorBoard writer is closed even when
    # training is interrupted or raises.
    try:
        save_folder = os.path.join(checkpoint_dir, log_id)
        os.makedirs(save_folder, exist_ok=True)

        print(f"📘 TensorBoard logs: {log_dir}")
        print(f"💾 Checkpoints: {save_folder}")

        for epoch in range(1, num_epochs + 1):
            start_time = time.time()

            # --- TRAINING PHASE ---
            model.train()
            train_loss_accum = 0.0
            num_train_batches = 0

            train_pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} [Train]")
            for batch in train_pbar:
                optimizer.zero_grad()

                inputs, targets = _move_batch_to_device(batch, device)

                # Forward pass and loss
                preds = model(inputs)["prediction"]
                loss = criterion(preds, targets)

                # Backward pass, gradient clipping, optimization step
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
                optimizer.step()

                batch_loss = loss.item()
                train_loss_accum += batch_loss
                num_train_batches += 1
                train_pbar.set_postfix({"loss": f"{batch_loss:.4f}"})

            if num_train_batches == 0:
                raise ValueError("train_loader yielded no batches")
            avg_train_loss = train_loss_accum / num_train_batches

            # --- VALIDATION PHASE ---
            model.eval()
            val_loss_accum = 0.0
            num_val_batches = 0

            with torch.no_grad():
                for batch in tqdm(val_loader, desc="Validation"):
                    inputs, targets = _move_batch_to_device(batch, device)
                    preds = model(inputs)["prediction"]
                    val_loss_accum += criterion(preds, targets).item()
                    num_val_batches += 1

            if num_val_batches == 0:
                raise ValueError("val_loader yielded no batches")
            avg_val_loss = val_loss_accum / num_val_batches

            # --- LOGGING & SCHEDULING ---
            elapsed = time.time() - start_time
            # Read before scheduler.step so the logged LR is the one used this epoch.
            current_lr = optimizer.param_groups[0]["lr"]

            print(
                f"Epoch {epoch:03d} | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f} | LR: {current_lr:.2e} | T: {elapsed:.1f}s"
            )

            writer.add_scalars(
                "Loss", {"Train": avg_train_loss, "Val": avg_val_loss}, epoch
            )
            writer.add_scalar("LR", current_lr, epoch)

            if scheduler is not None:
                scheduler.step(avg_val_loss)

            # Early Stopping
            if avg_val_loss < best_val_loss - 1e-4:
                best_val_loss = avg_val_loss
                epochs_no_improve = 0
                # state_dict() returns references to the live parameter tensors;
                # deep-copy so later optimizer steps cannot overwrite the snapshot.
                best_model_state = copy.deepcopy(model.state_dict())
                torch.save(best_model_state, os.path.join(save_folder, "best_model.pt"))
                print("✨ New best model saved.")
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    print(f"⏹️ Early stopping at epoch {epoch}")
                    break
    finally:
        writer.close()

    return best_model_state

In [7]:
# Training launch — currently commented out; uncomment the call below to
# start a run with the objects built in the cells above.
# NOTE(review): as written, the call discards train_model's return value
# (the best state dict); assign it if the weights are needed in memory.
# train_model(
#     model,
#     train_loader,
#     val_loader,
#     criterion,
#     optimizer,
#     device,
#     num_epochs,
#     patience,
#     scheduler,
#     exp_name=train_config["exp_name"],
# )