In [None]:
import polars as pl
import torch
import numpy as np

In [None]:
# Load the serialized dataset object from disk. The trailing bare name has no
# effect in a plain script; presumably it is there to display the object as
# the output of a notebook cell.
dataset = torch.load("data/torch_dataset.pt")
dataset

In [None]:
# Count the NaN entries in the tensor stored under the "y" key.
dataset["y"].isnan().sum()

In [None]:
# simple_shared_sum_model.py
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split


# ----------------------------
# Dataset
# ----------------------------
class WindAreaDataset(Dataset):
    """Tensor-backed dataset that can standardize inputs and targets per item.

    The statistics are never computed here: the caller supplies them (typically
    from the training partition) and they are applied lazily in ``__getitem__``.
    """

    def __init__(
        self,
        X,
        y,
        x_mean=None,
        x_std=None,
        y_mean=None,
        y_std=None,
        normalize=True,
        normalize_y=True,
    ):
        """
        X: tensor, expected (N, L, V); cast to float32
        y: tensor, expected (N,); cast to float32
        x_mean, x_std: input statistics, or None to leave inputs untouched
        y_mean, y_std: target statistics, or None to leave targets untouched
        normalize: switch for input standardization
        normalize_y: switch for target standardization
        """
        # Flags and externally computed statistics.
        self.normalize, self.normalize_y = normalize, normalize_y
        self.x_mean, self.x_std = x_mean, x_std
        self.y_mean, self.y_std = y_mean, y_std

        # Payload, stored as float32.
        self.X = X.float()
        self.y = y.float()

    def __len__(self):
        """Number of samples (size of the leading dimension of X)."""
        return self.X.size(0)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for ``idx``, standardized where enabled."""
        features, target = self.X[idx], self.y[idx]

        # Inputs are scaled only when the flag is set AND both stats exist.
        scale_inputs = self.normalize and not (
            self.x_mean is None or self.x_std is None
        )
        if scale_inputs:
            # 1e-6 keeps the division finite for zero-variance variables.
            features = (features - self.x_mean) / (self.x_std + 1e-6)

        # Same rule for the target.
        scale_target = self.normalize_y and not (
            self.y_mean is None or self.y_std is None
        )
        if scale_target:
            target = (target - self.y_mean) / (self.y_std + 1e-6)

        return features, target


# ----------------------------
# Model: shared per-location MLP + sum
# ----------------------------
class SharedPerLocationSum(nn.Module):
    """Shared per-location MLP whose scalar outputs are summed over locations.

    The same network ``phi`` is applied independently to every location's
    feature vector; the prediction for a sample is the sum of those scalars.
    Because ``phi`` ends in ``Softplus``, every contribution is strictly
    positive and so is the predicted total.

    Args:
        in_dim: number of features per location (V).
        hidden: sizes of the two hidden layers, ``(h1, h2)``.
        dropout: dropout probability applied after the first hidden layer.
        return_locals: if True, ``forward`` also returns the per-location
            contributions.
    """

    def __init__(self, in_dim=7, hidden=(64, 32), dropout=0.0, return_locals=False):
        super().__init__()
        h1, h2 = hidden
        self.return_locals = return_locals
        # Layer order is kept fixed so existing state_dict keys remain valid.
        self.phi = nn.Sequential(
            nn.Linear(in_dim, h1),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(h1, h2),
            nn.ReLU(),
            nn.Linear(h2, 1),  # scalar contribution per location
            # Enforces positivity of each contribution; remove this layer if
            # the summed prediction must be able to go negative.
            nn.Softplus(),
        )

    def forward(self, x):
        """
        x: (B, L, V)
        returns:
          y_hat: (B,) predicted total
          (optionally) loc_contribs: (B, L)

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(f"Expected x to be (B, L, V), got {tuple(x.shape)}")
        B, L, V = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced batches, are accepted; it is still a view when possible.
        z = x.reshape(B * L, V)  # flatten locations
        contribs = self.phi(z).reshape(B, L)  # (B, L)
        y_hat = contribs.sum(dim=1)  # (B,)
        if self.return_locals:
            return y_hat, contribs
        return y_hat


# ----------------------------
# Training utility
# ----------------------------
def train_model(
    data_path="data.pt",
    batch_size=512,
    lr=1e-3,
    weight_decay=1e-4,
    epochs=10,
    val_frac=0.1,
    normalize_x=True,
    normalize_y=True,
    seed=42,
    device=None,
):
    """Train a ``SharedPerLocationSum`` model on a saved ``{"X", "y"}`` blob.

    Args:
        data_path: path of a ``torch.save``-d mapping with keys ``"X"``
            (N, L, V) and ``"y"`` (N,).
        batch_size: mini-batch size for both loaders.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        epochs: number of passes over the training split.
        val_frac: fraction of samples held out for validation, in ``[0, 1)``.
        normalize_x: standardize inputs per variable using train-split stats.
        normalize_y: standardize targets using train-split stats.
        seed: seed for the global RNG and for the train/val split.
        device: target device; defaults to CUDA when available, else CPU.

    Returns:
        ``(model, (x_mean, x_std, y_mean, y_std))``. A statistic is ``None``
        when the corresponding normalization is disabled. The returned std
        values already include a ``1e-6`` offset.

    Raises:
        ValueError: if ``val_frac`` is outside ``[0, 1)`` or the loaded
            tensors do not have the expected shapes.
    """
    if not 0.0 <= val_frac < 1.0:
        raise ValueError(f"val_frac must be in [0, 1), got {val_frac}")

    torch.manual_seed(seed)

    # Load data
    # NOTE(review): torch.load unpickles the file; only load trusted paths
    # (or pass weights_only=True where the installed torch supports it).
    blob = torch.load(data_path)
    X = blob["X"].float()  # (N, L, V)
    y = blob["y"].float()  # (N,)

    # Explicit raises instead of assert: asserts are stripped under -O.
    if X.dim() != 3:
        raise ValueError(f"Expected X to be (N,L,V), got {tuple(X.shape)}")
    if y.dim() != 1 or y.shape[0] != X.shape[0]:
        raise ValueError("y should be (N,) aligned with X")

    N, L, V = X.shape
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    # Pinned host memory only helps (and only works) for CUDA transfers.
    pin_memory = torch.device(device).type == "cuda"

    # Train/val split (shuffled, reproducible through the seeded generator)
    n_val = int(N * val_frac)
    n_train = N - n_val
    train_subset, val_subset = random_split(
        torch.utils.data.TensorDataset(X, y),
        [n_train, n_val],
        generator=torch.Generator().manual_seed(seed),
    )

    def take(subset):
        # random_split always returns Subset objects, so .indices exists;
        # an explicit long tensor also handles the empty split.
        idx = torch.as_tensor(subset.indices, dtype=torch.long)
        return X[idx], y[idx]

    X_train, y_train = take(train_subset)
    X_val, y_val = take(val_subset)

    # Compute normalization stats on train split ONLY
    if normalize_x:
        # mean/std over (sample, location) for each variable, reshaped to
        # (1, V) so they broadcast over an (L, V) item
        x_mean = X_train.mean(dim=(0, 1)).view(1, V)
        x_std = (X_train.std(dim=(0, 1)) + 1e-6).view(1, V)
    else:
        x_mean = x_std = None

    if normalize_y:
        y_mean = y_train.mean()
        y_std = y_train.std() + 1e-6
    else:
        y_mean = y_std = None

    # Wrap both partitions with the same (train-derived) statistics
    train_ds = WindAreaDataset(
        X_train, y_train, x_mean, x_std, y_mean, y_std,
        normalize=normalize_x, normalize_y=normalize_y,
    )
    val_ds = WindAreaDataset(
        X_val, y_val, x_mean, x_std, y_mean, y_std,
        normalize=normalize_x, normalize_y=normalize_y,
    )

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=2,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, num_workers=2,
        pin_memory=pin_memory,
    )

    # Model, loss, optimizer
    # NOTE(review): SharedPerLocationSum as defined in this file ends in
    # Softplus, so predictions are always positive, while standardized targets
    # are negative for below-average samples. Confirm this combination is
    # intended when normalize_y=True.
    model = SharedPerLocationSum(in_dim=V, hidden=(64, 32), dropout=0.1).to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    if normalize_y:
        # WindAreaDataset divides by (y_std + 1e-6); use the same scale so
        # that denormalization is the exact inverse. Moved to device once.
        y_scale = (y_std + 1e-6).to(device)
        y_shift = y_mean.to(device)

    def denorm_y(t):
        if not normalize_y:
            return t
        return t * y_scale + y_shift

    # Training loop
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        for xb, yb in train_loader:
            xb = xb.to(device)  # (B, L, V)
            yb = yb.to(device)  # (B,)

            optimizer.zero_grad()
            preds = model(xb)  # (B,)
            loss = criterion(preds, yb)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            # Weight by batch size so the epoch average is per-sample.
            total_loss += loss.item() * xb.size(0)

        avg_train_loss = total_loss / len(train_loader.dataset)

        # Validation: one pass yields both the normalized MSE and the
        # squared error in original units.
        model.eval()
        val_loss = 0.0
        sq_err_orig = 0.0
        n_seen = 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device)
                yb = yb.to(device)
                preds = model(xb)
                val_loss += criterion(preds, yb).item() * xb.size(0)
                diff = denorm_y(preds) - denorm_y(yb)
                sq_err_orig += (diff ** 2).sum().item()
                n_seen += xb.size(0)

        if n_seen:
            avg_val_loss = val_loss / n_seen
            # True RMSE: root of the mean over ALL samples, not an average
            # of per-batch roots.
            rmse = (sq_err_orig / n_seen) ** 0.5
        else:
            # Empty validation split (e.g. val_frac=0): nothing to report.
            avg_val_loss = rmse = float("nan")

        print(
            f"Epoch {epoch:02d} | train MSE: {avg_train_loss:.4f} | val MSE: {avg_val_loss:.4f} | val RMSE (orig units): {rmse:.3f}"
        )

    return model, (x_mean, x_std, y_mean, y_std)


if __name__ == "__main__":
    # Example entry point: replace data_path with the location of your own
    # .pt file before running.
    model, stats = train_model(
        # data
        data_path="your_data.pt",
        normalize_x=True,
        normalize_y=True,
        # optimization
        lr=2e-3,
        weight_decay=1e-4,
        batch_size=1024,
        epochs=20,
    )