## Import libraries


In [19]:
import importlib
import os
import sys
import time

source_folder = "/beegfs/halder/GITHUB/RESEARCH/crop-yield-forecasting-germany/src"
sys.path.append(source_folder)

import config.winter_wheat as cfg
import numpy as np
import torch
from config.winter_wheat import model_config, train_config
from dataset.dataset import CropFusionNetDataset
from loss.loss import QuantileLoss
from models.AttnLSTM.model import AttnLSTM
from models.ResCNN.model import ResCNN
from models.SimpleTransformer.model import SimpleTransformer
from models.VanillaLSTM.model import VanillaLSTM
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils.utils import set_seed, evaluate_and_save_outputs, load_config, save_config

# Crop
crop = "winter_barley"
cfg, model_config, train_config = load_config(crop)

device = model_config["device"]
set_seed(42)
baseline_model_name = "ResCNN"

## Create datasets and dataloaders


In [20]:
train_dataset = CropFusionNetDataset(cfg, mode="train", scale=True)
val_dataset = CropFusionNetDataset(cfg, mode="val", scale=True)
test_dataset = CropFusionNetDataset(cfg, mode="test", scale=True)

train_loader = DataLoader(
    train_dataset,
    batch_size=train_config["batch_size"],
    shuffle=True,
    num_workers=32,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=train_config["batch_size"],
    shuffle=False,
    num_workers=32,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=train_config["batch_size"],
    shuffle=False,
    num_workers=16,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=2,
)

## Model, optimizer and loss


In [22]:
model = ResCNN(model_config).to(device)
criterion = QuantileLoss(quantiles=model_config["quantiles"]).to(device)
optimizer = Adam(
    model.parameters(), lr=train_config["lr"], weight_decay=train_config["weight_decay"]
)
num_epochs = train_config.get("num_epochs", 50)
patience = train_config.get("early_stopping_patience", 10)
batch_size = train_config.get("batch_size", 32)

# Learning rate scheduler
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="min",  # minimize validation loss
    factor=0.5,  # reduce LR by 50%
    patience=3,  # wait for 3 epochs before reducing
    threshold=1e-4,  # minimal improvement threshold
    min_lr=1e-6,  # lower bound for learning rate
)

## Training


In [23]:
def train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    num_epochs,
    patience,
    scheduler=None,
    checkpoint_dir="checkpoints",
    exp_name="CropFusionNet_experiment",
):
    # 1. Setup Logging
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    log_id = f"run_{exp_name}_{timestamp}"
    log_dir = os.path.join("runs", log_id)
    writer = SummaryWriter(log_dir=log_dir)

    save_folder = os.path.join(checkpoint_dir, log_id)
    os.makedirs(save_folder, exist_ok=True)

    print(f"üìò TensorBoard logs: {log_dir}")
    print(f"üíæ Checkpoints: {save_folder}")

    best_val_loss = np.inf
    epochs_no_improve = 0
    best_model_state = None

    for epoch in range(1, num_epochs + 1):
        start_time = time.time()

        # --- TRAINING PHASE ---
        model.train()
        train_loss_accum = 0.0

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} [Train]")
        for batch in train_pbar:
            optimizer.zero_grad()

            # Move inputs to device
            inputs = {
                "inputs": batch["inputs"].to(device),
                "identifier": batch["identifier"].to(device),
                "mask": batch["mask"].to(device),
                "variable_mask": (
                    batch.get("variable_mask").to(device)
                    if batch.get("variable_mask") is not None
                    else None
                ),
            }
            targets = batch["target"].to(device)

            # Forward Pass
            output_dict = model(inputs)
            preds = output_dict["prediction"]

            # Loss Calculation
            loss = criterion(preds, targets)

            # Backward Pass
            loss.backward()

            # Gradient Clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

            # Optimization Step
            optimizer.step()

            train_loss_accum += loss.item()
            train_pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        avg_train_loss = train_loss_accum / len(train_loader)

        # --- VALIDATION PHASE ---
        model.eval()
        val_loss_accum = 0.0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                inputs = {
                    "inputs": batch["inputs"].to(device),
                    "identifier": batch["identifier"].to(device),
                    "mask": batch["mask"].to(device),
                    "variable_mask": (
                        batch.get("variable_mask").to(device)
                        if batch.get("variable_mask") is not None
                        else None
                    ),
                }
                targets = batch["target"].to(device)

                output_dict = model(inputs)
                preds = output_dict["prediction"]

                loss = criterion(preds, targets)
                val_loss_accum += loss.item()

        avg_val_loss = val_loss_accum / len(val_loader)

        # --- LOGGING & SCHEDULING ---
        elapsed = time.time() - start_time
        current_lr = optimizer.param_groups[0]["lr"]

        print(
            f"Epoch {epoch:03d} | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f} | LR: {current_lr:.2e} | T: {elapsed:.1f}s"
        )

        writer.add_scalars(
            "Loss", {"Train": avg_train_loss, "Val": avg_val_loss}, epoch
        )
        writer.add_scalar("LR", current_lr, epoch)

        if scheduler:
            scheduler.step(avg_val_loss)

        # Early Stopping
        if avg_val_loss < best_val_loss - 1e-4:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            best_model_state = model.state_dict()
            torch.save(best_model_state, os.path.join(save_folder, "best_model.pt"))
            print(f"‚ú® New best model saved.")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"‚èπÔ∏è Early stopping at epoch {epoch}")
                break

    writer.close()
    return best_val_loss

In [24]:
train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    num_epochs,
    patience,
    scheduler,
    exp_name=f"{train_config["exp_name"]}_{baseline_model_name}_Baseline",
)

üìò TensorBoard logs: runs/run_exp_winter_barley_Jul_ResCNN_Baseline_20260301-113405
üíæ Checkpoints: checkpoints/run_exp_winter_barley_Jul_ResCNN_Baseline_20260301-113405


Epoch 1/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:04<00:00,  9.44it/s, loss=0.5911]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:02<00:00,  2.72it/s]


Epoch 001 | Train: 0.6724 | Val: 0.7464 | LR: 1.00e-04 | T: 7.5s
‚ú® New best model saved.


Epoch 2/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:04<00:00, 11.29it/s, loss=0.4895]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  4.94it/s]


Epoch 002 | Train: 0.5572 | Val: 0.6966 | LR: 1.00e-04 | T: 5.5s
‚ú® New best model saved.


Epoch 3/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 12.03it/s, loss=0.5635]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  4.86it/s]


Epoch 003 | Train: 0.5338 | Val: 0.6462 | LR: 1.00e-04 | T: 5.3s
‚ú® New best model saved.


Epoch 4/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 11.78it/s, loss=0.4871]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  4.96it/s]


Epoch 004 | Train: 0.5166 | Val: 0.6208 | LR: 1.00e-04 | T: 5.3s
‚ú® New best model saved.


Epoch 5/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 11.95it/s, loss=0.4662]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  4.69it/s]


Epoch 005 | Train: 0.5033 | Val: 0.6254 | LR: 1.00e-04 | T: 5.3s


Epoch 6/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 12.20it/s, loss=0.4712]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  4.69it/s]


Epoch 006 | Train: 0.4904 | Val: 0.6479 | LR: 1.00e-04 | T: 5.3s


Epoch 7/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 11.93it/s, loss=0.5461]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  4.91it/s]


Epoch 007 | Train: 0.4811 | Val: 0.6660 | LR: 1.00e-04 | T: 5.3s


Epoch 8/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 12.09it/s, loss=0.4487]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  5.22it/s]


Epoch 008 | Train: 0.4699 | Val: 0.6265 | LR: 1.00e-04 | T: 5.2s


Epoch 9/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 12.31it/s, loss=0.4351]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  5.09it/s]


Epoch 009 | Train: 0.4551 | Val: 0.6165 | LR: 5.00e-05 | T: 5.1s
‚ú® New best model saved.


Epoch 10/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 12.36it/s, loss=0.4598]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  4.43it/s]


Epoch 010 | Train: 0.4534 | Val: 0.6533 | LR: 5.00e-05 | T: 5.3s


Epoch 11/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 12.24it/s, loss=0.5295]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  5.04it/s]


Epoch 011 | Train: 0.4484 | Val: 0.6306 | LR: 5.00e-05 | T: 5.2s


Epoch 12/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 12.37it/s, loss=0.4802]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  5.07it/s]


Epoch 012 | Train: 0.4466 | Val: 0.6399 | LR: 5.00e-05 | T: 5.1s


Epoch 13/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 12.32it/s, loss=0.3549]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  4.96it/s]


Epoch 013 | Train: 0.4380 | Val: 0.6510 | LR: 5.00e-05 | T: 5.1s


Epoch 14/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 12.47it/s, loss=0.3996]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  4.89it/s]


Epoch 014 | Train: 0.4301 | Val: 0.6516 | LR: 2.50e-05 | T: 5.1s


Epoch 15/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 11.99it/s, loss=0.4471]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  5.02it/s]


Epoch 015 | Train: 0.4278 | Val: 0.6386 | LR: 2.50e-05 | T: 5.2s


Epoch 16/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 12.01it/s, loss=0.4437]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  4.97it/s]


Epoch 016 | Train: 0.4241 | Val: 0.6531 | LR: 2.50e-05 | T: 5.2s


Epoch 17/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 11.99it/s, loss=0.3945]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  4.85it/s]


Epoch 017 | Train: 0.4220 | Val: 0.6445 | LR: 2.50e-05 | T: 5.3s


Epoch 18/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 11.79it/s, loss=0.3807]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  4.71it/s]


Epoch 018 | Train: 0.4218 | Val: 0.6417 | LR: 1.25e-05 | T: 5.4s


Epoch 19/500 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 11.83it/s, loss=0.4932]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  4.64it/s]

Epoch 019 | Train: 0.4214 | Val: 0.6705 | LR: 1.25e-05 | T: 5.4s
‚èπÔ∏è Early stopping at epoch 19





0.6165075855595725

## Save the trained model, config, and the outputs


In [25]:
# Save the trained model
output_dir = os.path.join(
    source_folder, "train", "baseline", crop, cfg.forecast_month, baseline_model_name
)
os.makedirs(output_dir, exist_ok=True)

model_save_path = os.path.join(output_dir, f"best_model.pt")
torch.save(model.state_dict(), model_save_path)
print(f"üíæ Trained model saved to {model_save_path}")

# Save outputs
print("üîç Evaluating and saving outputs...")

# Evaluate and save outputs for train, validation, and test datasets
evaluate_and_save_outputs(model, train_loader, criterion, device, output_dir, "train")
evaluate_and_save_outputs(
    model, val_loader, criterion, device, output_dir, "validation"
)
evaluate_and_save_outputs(model, test_loader, criterion, device, output_dir, "test")

# Save the model config
save_config(train_config, model_config, output_dir)

üíæ Trained model saved to /beegfs/halder/GITHUB/RESEARCH/crop-yield-forecasting-germany/src/train/baseline/winter_barley/Jul/ResCNN/best_model.pt
üîç Evaluating and saving outputs...


Evaluating train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46/46 [00:03<00:00, 12.51it/s]


Train Loss: 0.4163
Outputs saved to: /beegfs/halder/GITHUB/RESEARCH/crop-yield-forecasting-germany/src/train/baseline/winter_barley/Jul/ResCNN/train_outputs.pkl


Evaluating validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:01<00:00,  5.05it/s]


Validation Loss: 0.6705
Outputs saved to: /beegfs/halder/GITHUB/RESEARCH/crop-yield-forecasting-germany/src/train/baseline/winter_barley/Jul/ResCNN/validation_outputs.pkl


Evaluating test: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:01<00:00,  2.28it/s]

Test Loss: 0.6995
Outputs saved to: /beegfs/halder/GITHUB/RESEARCH/crop-yield-forecasting-germany/src/train/baseline/winter_barley/Jul/ResCNN/test_outputs.pkl
Config saved to: /beegfs/halder/GITHUB/RESEARCH/crop-yield-forecasting-germany/src/train/baseline/winter_barley/Jul/ResCNN/config.json



