In [14]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import sys
import os
from torch.utils.data import DataLoader

In [15]:
# Make the project's `src` package and the forecasting helpers importable.
# Each path is resolved against the notebook's working directory and appended only
# if absent, so re-running this cell does not keep growing sys.path.
# NOTE(review): '../../src' implies the notebook runs two levels below the repo root;
# if so, 'analysis/forecasting' resolves *below* the notebook directory — confirm it
# points at where `forecasting_analysis` actually lives.
EXTRA_IMPORT_DIRS = ('../../src', 'analysis/forecasting')
for rel_dir in EXTRA_IMPORT_DIRS:
    abs_dir = os.path.abspath(os.path.join(os.getcwd(), rel_dir))
    if abs_dir not in sys.path:
        sys.path.append(abs_dir)

# now imports that rely on those paths
from utils import SequentialDeepONetDataset
from helper import  convert2dim, train_val_test_split, fit, compute_metrics_region

# Use the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: cuda


## Load Datasets (random placeholder arrays with the real data's shapes)

In [16]:
# Placeholder data: random arrays with the same shapes as the real inputs, so the
# pipeline can be exercised end to end.
#   input_data : (n_days, n_features)       -> (8400, 12)
#   trunk      : (n_trunk_points, 2)        -> (361 * 181, 2)
#   target     : (n_days, n_trunk_points)   -> (8400, 361 * 181)
# A fixed seed makes the placeholders, and everything trained on them, reproducible.
# NOTE: `target` alone is ~4.4 GB as float64.
SEED = 42
N_DAYS, N_FEATURES = 8400, 12
N_TRUNK_POINTS = 361 * 181

rng = np.random.default_rng(SEED)
input_data = rng.standard_normal((N_DAYS, N_FEATURES))
trunk = rng.standard_normal((N_TRUNK_POINTS, 2))
target = rng.standard_normal((N_DAYS, N_TRUNK_POINTS))

print("Input data shape:", input_data.shape)
print("Trunk shape:", trunk.shape)
print("Target shape:", target.shape)

Input data shape: (8400, 12)
Trunk shape: (65341, 2)
Target shape: (8400, 65341)


In [17]:
from forecasting_analysis import create_windows_forecasting_with_index

In [18]:
# Calendar date of every row in input_data / target (one row per day).
dates = pd.date_range("2001-01-01", "2023-12-31", freq="D")
# The splits below are keyed on these dates, so a length mismatch would silently
# attach the wrong date to every window (or raise an IndexError further down).
assert len(dates) == len(input_data) == len(target), (
    f"{len(dates)} dates vs {len(input_data)} input rows and {len(target)} target rows"
)

# W: input window length (cf. the (n, 30, 12) X shapes); H: horizon argument of the
# windowing helper (presumably steps past the window end — confirm in forecasting_analysis).
W, H = 30, 1
X_all, y_all, tgt_idx = create_windows_forecasting_with_index(input_data, target, W, H)
tgt_dates = dates[tgt_idx]  # presumably the date of each window's target row

# Chronological split on the target date: train <= 2021, validate on 2022, test on 2023.
# Presumably val/test input windows still reach W days back into the preceding split;
# that is expected for forecasting as long as the targets do not overlap.
TRAIN_END = pd.Timestamp("2021-12-31")
VAL_START, VAL_END = pd.Timestamp("2022-01-01"), pd.Timestamp("2022-12-31")
TEST_START, TEST_END = pd.Timestamp("2023-01-01"), pd.Timestamp("2023-12-31")

train_mask = (tgt_dates <= TRAIN_END)
val_mask   = (tgt_dates >= VAL_START) & (tgt_dates <= VAL_END)
test_mask  = (tgt_dates >= TEST_START) & (tgt_dates <= TEST_END)
assert train_mask.any() and val_mask.any() and test_mask.any(), "a split is empty"

X_train, y_train = X_all[train_mask], y_all[train_mask]
X_val,   y_val   = X_all[val_mask],   y_all[val_mask]
X_test,  y_test  = X_all[test_mask],  y_all[test_mask]

# check shapes
print("Train set:", X_train.shape, y_train.shape)
print("Validation set:", X_val.shape, y_val.shape)
print("Test set:", X_test.shape, y_test.shape)

Train set: torch.Size([7640, 30, 12]) torch.Size([7640, 65341])
Validation set: torch.Size([365, 30, 12]) torch.Size([365, 65341])
Test set: torch.Size([365, 30, 12]) torch.Size([365, 65341])


In [19]:
# Learn the min/max of each target column on the training split only (no val/test
# leakage), then apply that same transform to every split. A trailing singleton axis
# is appended to each result — presumably to match a (batch, points, 1) model output.
scaler_target = MinMaxScaler().fit(y_train)
y_train_scaled, y_val_scaled, y_test_scaled = (
    np.expand_dims(scaler_target.transform(split), axis=-1)
    for split in (y_train, y_val, y_test)
)

In [20]:
# create datasets — all three splits share the same `trunk` array
train_dataset = SequentialDeepONetDataset(X_train, trunk, y_train_scaled)
val_dataset   = SequentialDeepONetDataset(X_val, trunk, y_val_scaled)
test_dataset  = SequentialDeepONetDataset(X_test, trunk, y_test_scaled)

# create dataloaders
# The training loop copies batches with `.to(device, non_blocking=True)`, which only
# overlaps host->GPU transfers with compute when the host tensors are pinned, so pin
# them on CUDA (pinning buys nothing on CPU).
batch_size = 16
use_pinned_memory = device.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=use_pinned_memory)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=use_pinned_memory)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=use_pinned_memory)

## TRON: model initialisation and training

In [21]:
from helper import init_model

In [22]:
# Seed PyTorch's global RNG (CPU and CUDA) before building the model, so weight
# initialisation and the shuffling of `train_loader`, which draws its per-epoch seed from
# the global generator when none is given, repeat across kernel restarts (modulo
# nondeterministic CUDA kernels).
torch.manual_seed(42)
model = init_model()

In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np

def train_model(
    model,
    train_loader,
    val_loader,
    device,
    scaler_target,
    num_epochs=100,
    lr=1e-3,
    weight_decay=0.0,
    scheduler_step=20,
    scheduler_gamma=0.5,
    early_stop_patience=15,
    save_path="dev_test.pt",
):
    """
    Generic training loop for Sequential DeepONet–style models with branch + trunk inputs.

    Trains with Adam + StepLR on an MSE objective and evaluates on ``val_loader`` after
    every epoch. Whenever the validation loss improves, the weights are checkpointed to
    ``save_path``. Training stops early after ``early_stop_patience`` epochs without
    improvement, and the best weights from this run are then restored into ``model``.

    Parameters
    ----------
    model : torch.nn.Module
        The DeepONet model; must accept (branch_batch, trunk_batch) → output tensor.
    train_loader : DataLoader
        Training data loader providing (branch, trunk, target).
    val_loader : DataLoader
        Validation data loader providing (branch, trunk, target).
    device : torch.device
        CUDA or CPU device.
    scaler_target : sklearn-like scaler
        Currently unused: all losses are computed on the scaled targets. Kept so
        existing call sites keep working.
    num_epochs : int
        Maximum number of epochs.
    lr : float
        Learning rate.
    weight_decay : float
        Weight-decay (L2 regularization).
    scheduler_step : int
        StepLR scheduler step size.
    scheduler_gamma : float
        StepLR scheduler decay factor.
    early_stop_patience : int
        Stop training if validation loss does not improve for these many epochs.
    save_path : str
        File path to store the best model checkpoint.

    Returns
    -------
    history : dict
        ``{"train_loss": [...], "val_loss": [...]}``, with one sample-weighted mean MSE
        per completed epoch.

    Notes
    -----
    A checkpoint is only reloaded if this run wrote one. If the validation loss never
    beats ``inf`` (e.g. it is NaN, or ``num_epochs == 0``), the model keeps its final
    weights. It does not pick up a stale file left at ``save_path`` by an earlier run.
    """
    # tqdm.auto renders a widget bar in Jupyter and falls back to text elsewhere.
    from tqdm.auto import tqdm

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)
    criterion = nn.MSELoss()

    best_val_loss = np.inf
    saved_checkpoint = False  # becomes True once *this* run has written save_path
    epochs_no_improve = 0
    history = {"train_loss": [], "val_loss": []}

    for epoch in range(num_epochs):
        # ---------------- Training ----------------
        model.train()
        running_loss = 0.0

        for branch_batch, trunk_batch, target_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False):
            branch_batch = branch_batch.to(device, non_blocking=True)
            trunk_batch  = trunk_batch.to(device, non_blocking=True)
            target_batch = target_batch.to(device, non_blocking=True)

            optimizer.zero_grad()
            output = model(branch_batch, trunk_batch)
            loss = criterion(output, target_batch)
            loss.backward()
            optimizer.step()

            # criterion returns a per-batch mean; weight by batch size so the epoch
            # average is exact even when the last batch is smaller.
            running_loss += loss.item() * branch_batch.size(0)

        train_loss = running_loss / len(train_loader.dataset)
        history["train_loss"].append(train_loss)

        # ---------------- Validation ----------------
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for branch_batch, trunk_batch, target_batch in val_loader:
                branch_batch = branch_batch.to(device, non_blocking=True)
                trunk_batch  = trunk_batch.to(device, non_blocking=True)
                target_batch = target_batch.to(device, non_blocking=True)

                output = model(branch_batch, trunk_batch)
                loss = criterion(output, target_batch)
                val_loss += loss.item() * branch_batch.size(0)

        val_loss /= len(val_loader.dataset)
        history["val_loss"].append(val_loss)

        scheduler.step()

        print(f"[Epoch {epoch+1:03d}] Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f}")

        # ---------------- Early stopping ----------------
        # A NaN val_loss compares False here, so it never counts as an improvement.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), save_path)
            saved_checkpoint = True
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= early_stop_patience:
            print(f"Early stopping at epoch {epoch+1} (no improvement for {early_stop_patience} epochs).")
            break

    print(f"Training complete. Best validation loss: {best_val_loss:.6f}")
    if saved_checkpoint:
        # map_location puts the restored tensors on `device` whatever device they were saved from.
        model.load_state_dict(torch.load(save_path, map_location=device))
    else:
        print(f"No checkpoint saved during this run; keeping final weights (not loading {save_path!r}).")

    return history


In [25]:
# Train TRON on the scaled targets, using the `device` chosen in the setup cell (no
# re-definition here, so the whole notebook has a single source of truth for it).
# The best-validation weights are written to `save_path` and restored into `model`.
history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    scaler_target=scaler_target,
    num_epochs=200,
    lr=1e-4,
    weight_decay=1e-6,
    scheduler_step=50,
    scheduler_gamma=0.7,
    early_stop_patience=20,
    save_path="best_tron_forecast.pt"
)


                                                              

[Epoch 001] Train Loss: 0.019018 | Val Loss: 0.018441


                                                              

KeyboardInterrupt: 