## Import libraries


In [1]:
import importlib
import os
import sys
import time

source_folder = "/beegfs/halder/GITHUB/RESEARCH/crop-yield-forecasting-germany/src"
sys.path.append(source_folder)

import numpy as np
import torch
from dataset.dataset import CropFusionNetDataset
from loss.loss import QuantileLoss
from models.CropFusionNet.model import CropFusionNet
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils.utils import set_seed, evaluate_and_save_outputs, load_config, save_config

# Crop
crop = "winter_rapeseed"
cfg, model_config, train_config = load_config(crop)

device = model_config["device"]
set_seed(42)

## Create datasets and dataloaders


In [2]:
train_dataset = CropFusionNetDataset(cfg, mode="train", scale=True)
val_dataset = CropFusionNetDataset(cfg, mode="val", scale=True)
test_dataset = CropFusionNetDataset(cfg, mode="test", scale=True)

train_loader = DataLoader(
    train_dataset,
    batch_size=train_config["batch_size"],
    shuffle=True,
    num_workers=32,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=train_config["batch_size"],
    shuffle=False,
    num_workers=32,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=train_config["batch_size"],
    shuffle=False,
    num_workers=16,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=2,
)

## Model, optimizer and loss


In [3]:
model = CropFusionNet(model_config).to(device)
criterion = QuantileLoss(quantiles=model_config["quantiles"]).to(device)
optimizer = Adam(
    model.parameters(), lr=train_config["lr"], weight_decay=train_config["weight_decay"]
)
num_epochs = train_config.get("num_epochs", 50)
patience = train_config.get("early_stopping_patience", 10)
batch_size = train_config.get("batch_size", 32)

# Learning rate scheduler
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="min",  # minimize validation loss
    factor=0.5,  # reduce LR by 50%
    patience=3,  # wait for 3 epochs before reducing
    threshold=1e-4,  # minimal improvement threshold
    min_lr=1e-6,  # lower bound for learning rate
)

## Training


In [4]:
def train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    num_epochs,
    patience,
    scheduler=None,
    checkpoint_dir="checkpoints",
    exp_name="CropFusionNet_experiment",
):
    # 1. Setup Logging
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    log_id = f"run_{exp_name}_{timestamp}"
    log_dir = os.path.join("runs", log_id)
    writer = SummaryWriter(log_dir=log_dir)

    save_folder = os.path.join(checkpoint_dir, log_id)
    os.makedirs(save_folder, exist_ok=True)

    print(f"üìò TensorBoard logs: {log_dir}")
    print(f"üíæ Checkpoints: {save_folder}")

    best_val_loss = np.inf
    epochs_no_improve = 0
    best_model_state = None

    for epoch in range(1, num_epochs + 1):
        start_time = time.time()

        # --- TRAINING PHASE ---
        model.train()
        train_loss_accum = 0.0

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} [Train]")
        for batch in train_pbar:
            optimizer.zero_grad()

            # Move inputs to device
            inputs = {
                "inputs": batch["inputs"].to(device),
                "identifier": batch["identifier"].to(device),
                "mask": batch["mask"].to(device),
                "variable_mask": (
                    batch.get("variable_mask").to(device)
                    if batch.get("variable_mask") is not None
                    else None
                ),
            }
            targets = batch["target"].to(device)

            # Forward Pass
            output_dict = model(inputs)
            preds = output_dict["prediction"]

            # Loss Calculation
            loss = criterion(preds, targets)

            # Backward Pass
            loss.backward()

            # Gradient Clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

            # Optimization Step
            optimizer.step()

            train_loss_accum += loss.item()
            train_pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        avg_train_loss = train_loss_accum / len(train_loader)

        # --- VALIDATION PHASE ---
        model.eval()
        val_loss_accum = 0.0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                inputs = {
                    "inputs": batch["inputs"].to(device),
                    "identifier": batch["identifier"].to(device),
                    "mask": batch["mask"].to(device),
                    "variable_mask": (
                        batch.get("variable_mask").to(device)
                        if batch.get("variable_mask") is not None
                        else None
                    ),
                }
                targets = batch["target"].to(device)

                output_dict = model(inputs)
                preds = output_dict["prediction"]

                loss = criterion(preds, targets)
                val_loss_accum += loss.item()

        avg_val_loss = val_loss_accum / len(val_loader)

        # --- LOGGING & SCHEDULING ---
        elapsed = time.time() - start_time
        current_lr = optimizer.param_groups[0]["lr"]

        print(
            f"Epoch {epoch:03d} | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f} | LR: {current_lr:.2e} | T: {elapsed:.1f}s"
        )

        writer.add_scalars(
            "Loss", {"Train": avg_train_loss, "Val": avg_val_loss}, epoch
        )
        writer.add_scalar("LR", current_lr, epoch)

        if scheduler:
            scheduler.step(avg_val_loss)

        # Early Stopping
        if avg_val_loss < best_val_loss - 1e-4:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            best_model_state = model.state_dict()
            torch.save(best_model_state, os.path.join(save_folder, "best_model.pt"))
            print(f"‚ú® New best model saved.")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"‚èπÔ∏è Early stopping at epoch {epoch}")
                break

    writer.close()
    return best_val_loss

In [5]:
train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    num_epochs,
    patience,
    scheduler,
    exp_name=train_config["exp_name"],
)

üìò TensorBoard logs: runs/run_exp_winter_rapeseed_Jul_20260301-115050
üíæ Checkpoints: checkpoints/run_exp_winter_rapeseed_Jul_20260301-115050


Epoch 1/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:14<00:00, 12.05it/s, loss=0.9908]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:02<00:00,  9.20it/s]


Epoch 001 | Train: 0.7504 | Val: 0.7190 | LR: 1.00e-04 | T: 16.9s
‚ú® New best model saved.


Epoch 2/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:12<00:00, 13.94it/s, loss=0.6566]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:00<00:00, 21.20it/s]


Epoch 002 | Train: 0.6871 | Val: 0.8174 | LR: 1.00e-04 | T: 13.6s


Epoch 3/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:13<00:00, 12.91it/s, loss=0.4182]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:00<00:00, 21.87it/s]


Epoch 003 | Train: 0.6222 | Val: 0.8332 | LR: 1.00e-04 | T: 14.6s


Epoch 4/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:12<00:00, 14.03it/s, loss=0.3757]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:00<00:00, 21.55it/s]


Epoch 004 | Train: 0.5497 | Val: 0.7173 | LR: 1.00e-04 | T: 13.5s
‚ú® New best model saved.


Epoch 5/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:13<00:00, 13.46it/s, loss=0.5717]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:00<00:00, 21.81it/s]


Epoch 005 | Train: 0.5000 | Val: 0.7958 | LR: 1.00e-04 | T: 14.0s


Epoch 6/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:13<00:00, 13.46it/s, loss=0.3174]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:00<00:00, 21.52it/s]


Epoch 006 | Train: 0.4748 | Val: 0.4989 | LR: 1.00e-04 | T: 14.1s
‚ú® New best model saved.


Epoch 7/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:12<00:00, 14.16it/s, loss=0.2904]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:00<00:00, 22.30it/s]


Epoch 007 | Train: 0.4563 | Val: 0.4888 | LR: 1.00e-04 | T: 13.4s
‚ú® New best model saved.


Epoch 8/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:12<00:00, 14.10it/s, loss=0.4949]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:00<00:00, 21.53it/s]


Epoch 008 | Train: 0.4418 | Val: 0.5151 | LR: 1.00e-04 | T: 13.5s


Epoch 9/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:12<00:00, 14.00it/s, loss=0.4542]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:00<00:00, 21.86it/s]


Epoch 009 | Train: 0.4373 | Val: 0.4991 | LR: 1.00e-04 | T: 13.5s


Epoch 10/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:12<00:00, 14.15it/s, loss=0.3722]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:00<00:00, 21.68it/s]


Epoch 010 | Train: 0.4221 | Val: 0.5018 | LR: 1.00e-04 | T: 13.4s


Epoch 11/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:12<00:00, 14.11it/s, loss=0.3090]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:00<00:00, 22.25it/s]


Epoch 011 | Train: 0.4176 | Val: 0.5066 | LR: 1.00e-04 | T: 13.4s


Epoch 12/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:13<00:00, 13.49it/s, loss=0.4328]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:00<00:00, 21.96it/s]


Epoch 012 | Train: 0.4023 | Val: 0.5214 | LR: 5.00e-05 | T: 14.0s


Epoch 13/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:12<00:00, 14.17it/s, loss=0.3116]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:00<00:00, 21.93it/s]


Epoch 013 | Train: 0.3941 | Val: 0.5392 | LR: 5.00e-05 | T: 13.4s


Epoch 14/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:12<00:00, 13.57it/s, loss=0.2083]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:00<00:00, 21.76it/s]


Epoch 014 | Train: 0.3885 | Val: 0.5370 | LR: 5.00e-05 | T: 13.9s


Epoch 15/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:12<00:00, 13.92it/s, loss=0.6822]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:00<00:00, 21.98it/s]


Epoch 015 | Train: 0.3870 | Val: 0.5359 | LR: 5.00e-05 | T: 13.6s


Epoch 16/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:12<00:00, 13.82it/s, loss=0.2316]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:00<00:00, 21.30it/s]


Epoch 016 | Train: 0.3752 | Val: 0.5077 | LR: 2.50e-05 | T: 13.7s


Epoch 17/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:13<00:00, 13.17it/s, loss=0.2156]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21/21 [00:01<00:00, 20.73it/s]

Epoch 017 | Train: 0.3716 | Val: 0.5090 | LR: 2.50e-05 | T: 14.4s
‚èπÔ∏è Early stopping at epoch 17





0.48875375588734943

## Save the trained model, config, and the outputs


In [6]:
# Save the trained model
output_dir = os.path.join(source_folder, "train", "forecast", crop, cfg.forecast_month)
os.makedirs(output_dir, exist_ok=True)

model_save_path = os.path.join(output_dir, f"best_model.pt")
torch.save(model.state_dict(), model_save_path)
print(f"üíæ Trained model saved to {model_save_path}")

# Save outputs
print("üîç Evaluating and saving outputs...")

# Evaluate and save outputs for train, validation, and test datasets
evaluate_and_save_outputs(model, train_loader, criterion, device, output_dir, "train")
evaluate_and_save_outputs(
    model, val_loader, criterion, device, output_dir, "validation"
)
evaluate_and_save_outputs(model, test_loader, criterion, device, output_dir, "test")

# Save the model config
save_config(train_config, model_config, output_dir)

üíæ Trained model saved to /beegfs/halder/GITHUB/RESEARCH/crop-yield-forecasting-germany/src/train/forecast/winter_barley/Jul/best_model.pt
üîç Evaluating and saving outputs...


Evaluating train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:09<00:00,  4.86it/s]


Train Loss: 0.3498
Outputs saved to: /beegfs/halder/GITHUB/RESEARCH/crop-yield-forecasting-germany/src/train/forecast/winter_barley/Jul/train_outputs.pkl


Evaluating validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:03<00:00,  2.23it/s]


Validation Loss: 0.5149
Outputs saved to: /beegfs/halder/GITHUB/RESEARCH/crop-yield-forecasting-germany/src/train/forecast/winter_barley/Jul/validation_outputs.pkl


Evaluating test: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:02<00:00,  1.49it/s]


Test Loss: 0.6741
Outputs saved to: /beegfs/halder/GITHUB/RESEARCH/crop-yield-forecasting-germany/src/train/forecast/winter_barley/Jul/test_outputs.pkl
Config saved to: /beegfs/halder/GITHUB/RESEARCH/crop-yield-forecasting-germany/src/train/forecast/winter_barley/Jul/config.json
