# In [None]:
# Minimal ConvLSTM training on sharded NPZ produced by preprocess.ipynb

import os
import json
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# -------- Dataset (streams shards) --------
class ShardedNPZDataset(Dataset):
    """Map-style dataset over a directory of ``shard_*.npz`` files.

    Every shard must contain an ``X`` array and a ``y`` array whose first axis
    indexes samples. Shard sizes come from ``manifest.parquet`` (columns
    ``shard`` and ``num_samples``) when it exists, otherwise from the shards
    themselves.

    Args:
        split_dir: Directory holding the shards (and optionally the manifest).
        max_cached_shards: How many decoded shards to keep in memory per
            process. ``np.load`` cannot memory-map ``.npz`` archives, so
            without a cache the whole shard is re-read for every sample.

    Raises:
        FileNotFoundError: If no shard is found in ``split_dir``.
    """

    def __init__(self, split_dir: str, max_cached_shards: int = 1):
        self.split_dir = Path(split_dir)
        self.max_cached_shards = max(1, int(max_cached_shards))
        man = self.split_dir / "manifest.parquet"
        if man.exists():
            df = pd.read_parquet(man)
            self.shards = [
                (self.split_dir / shard, int(n))
                for shard, n in zip(df["shard"], df["num_samples"])
            ]
        else:
            self.shards = []
            for p in sorted(self.split_dir.glob("shard_*.npz")):
                # mmap_mode has no effect on .npz archives, so it is not passed.
                with np.load(p, allow_pickle=False) as f:
                    n = f["X"].shape[0]
                self.shards.append((p, n))
        if not self.shards:
            raise FileNotFoundError(f"No shards found in {self.split_dir}")
        # cum[i] is the number of samples in shards 0..i (inclusive).
        self.cum = np.cumsum([n for _, n in self.shards])
        # Decoded shards keyed by path, ordered least- to most-recently used.
        self._cache = {}

    def __len__(self):
        return int(self.cum[-1])

    def __getstate__(self):
        """Drop cached arrays so the dataset pickles cheaply for workers."""
        state = self.__dict__.copy()
        state["_cache"] = {}
        return state

    def _load_shard(self, path: Path):
        """Return ``(X, y)`` arrays of one shard, using a small LRU cache."""
        key = str(path)
        arrays = self._cache.pop(key, None)
        if arrays is None:
            # Read both members once and close the archive immediately so no
            # file handle outlives the call.
            with np.load(path, allow_pickle=False) as f:
                arrays = (f["X"], f["y"])
            # Dicts keep insertion order, so the first key is the oldest entry.
            while len(self._cache) >= self.max_cached_shards:
                self._cache.pop(next(iter(self._cache)))
        # (Re-)insert at the end to mark the shard as most recently used.
        self._cache[key] = arrays
        return arrays

    def __getitem__(self, idx: int):
        """Return ``(X, y)`` float32 tensors for sample ``idx``.

        ``y`` gets a leading channel axis so it lines up with single-channel
        logits. Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_total = len(self)
        requested = int(idx)
        idx = requested + n_total if requested < 0 else requested
        if not 0 <= idx < n_total:
            raise IndexError(
                f"index {requested} is out of range for dataset of size {n_total}"
            )
        # side="right" maps idx == cum[i] to shard i + 1 and skips empty shards.
        shard_idx = int(np.searchsorted(self.cum, idx, side="right"))
        base = 0 if shard_idx == 0 else int(self.cum[shard_idx - 1])
        local_idx = idx - base
        path, _ = self.shards[shard_idx]
        X_all, y_all = self._load_shard(path)
        # Expected layout (per preprocessing): X sample [T,C,H,W], y sample [H,W].
        # np.array copies, so callers may modify the tensors without touching
        # the cached shard.
        X = torch.from_numpy(np.array(X_all[local_idx], dtype=np.float32))
        y = torch.from_numpy(np.array(y_all[local_idx], dtype=np.float32))  # 0/1
        # Add channel to y for BCEWithLogits over [B,1,H,W]
        return X, y.unsqueeze(0)

# -------- ConvLSTM building blocks --------
class ConvLSTMCell(nn.Module):
    """One convolutional LSTM cell.

    A single convolution over the concatenated input and previous hidden
    state produces the pre-activations of all four gates, laid out along the
    channel axis in the order input, forget, output, candidate.
    """

    def __init__(self, in_channels, hidden_channels, kernel_size=3, bias=True):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.conv = nn.Conv2d(
            in_channels=in_channels + hidden_channels,
            out_channels=4 * hidden_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=bias,
        )

    def forward(self, x, state):
        """Advance one time step; returns the new ``(hidden, cell)`` pair."""
        hidden_prev, cell_prev = state
        pre_acts = self.conv(torch.cat((x, hidden_prev), dim=1))
        in_gate, forget_gate, out_gate, candidate = pre_acts.chunk(4, dim=1)
        cell_next = (
            torch.sigmoid(forget_gate) * cell_prev
            + torch.sigmoid(in_gate) * torch.tanh(candidate)
        )
        hidden_next = torch.sigmoid(out_gate) * torch.tanh(cell_next)
        return hidden_next, cell_next

    def init_state(self, B, H, W, device):
        """Return zero-filled ``(hidden, cell)`` tensors of shape [B, hidden, H, W]."""
        shape = (B, self.hidden_channels, H, W)
        return torch.zeros(shape, device=device), torch.zeros(shape, device=device)

class StackedConvLSTM(nn.Module):
    """Stack of ConvLSTM cells unrolled over the time axis.

    The first cell reads the input frames; every later cell reads the hidden
    maps emitted by the cell below it at the same time step.
    """

    def __init__(self, in_channels, hidden_channels_list=(32, 64), kernel_size=3):
        super().__init__()
        # widths[k] feeds layer k, widths[k + 1] is what layer k emits.
        widths = [in_channels, *hidden_channels_list]
        self.cells = nn.ModuleList(
            [
                ConvLSTMCell(n_in, n_out, kernel_size)
                for n_in, n_out in zip(widths[:-1], widths[1:])
            ]
        )

    def forward(self, x):  # x: [B,T,C,H,W]
        """Return the top layer's hidden state after the final time step.

        Output is [B, hidden_last, H, W]; it is ``None`` when the time axis
        is empty.
        """
        B, n_steps, _, H, W = x.shape
        states = [cell.init_state(B, H, W, x.device) for cell in self.cells]
        top_hidden = None
        for step in range(n_steps):
            feat = x.select(1, step)  # [B,C,H,W]
            for layer, cell in enumerate(self.cells):
                states[layer] = cell(feat, states[layer])
                # The hidden half of the new state feeds the next layer up.
                feat = states[layer][0]
            top_hidden = feat
        return top_hidden

class ConvLSTMFireSeg(nn.Module):
    """ConvLSTM backbone followed by a small convolutional head.

    The head maps the backbone's final hidden maps to one channel of
    per-pixel logits.
    """

    def __init__(self, in_channels, hidden_channels=(32, 64)):
        super().__init__()
        self.backbone = StackedConvLSTM(in_channels, hidden_channels)
        head_in = hidden_channels[-1]
        self.head = nn.Sequential(
            nn.Conv2d(head_in, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1),  # logits
        )

    def forward(self, x):  # x: [B,T,C,H,W]
        """Return logits of shape [B,1,H,W] for a batch of sequences."""
        return self.head(self.backbone(x))

# -------- Training utilities --------
def pixel_metrics(logits, y, threshold=0.5):
    """Compute pixel-wise accuracy, precision, recall and F1 for one batch.

    Logits pass through a sigmoid and are thresholded (``>= threshold``
    counts as positive). A small epsilon in each denominator makes empty
    classes yield 0 rather than a division error.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)`` of Python floats.
    """
    eps = 1e-8
    with torch.no_grad():
        hit = (torch.sigmoid(logits) >= threshold).float()
        miss = 1 - hit
        truth = y.float()
        background = 1 - truth
        tp = (hit * truth).sum().item()
        fp = (hit * background).sum().item()
        fn = (miss * truth).sum().item()
        tn = (miss * background).sum().item()
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    accuracy = (tp + tn) / (tp + tn + fp + fn + eps)
    return accuracy, precision, recall, f1

def device_select():
    """Pick the best available torch device: CUDA, then MPS, then CPU.

    Returns:
        A ``torch.device`` of type ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps is missing on older PyTorch releases; look it up
    # defensively so CPU-only setups do not crash with AttributeError.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():  # macOS Metal
        return torch.device("mps")
    return torch.device("cpu")

def _confusion_counts(logits, y, threshold=0.5):
    """Return ``(tp, fp, fn, tn)`` pixel counts for one batch as Python floats."""
    with torch.no_grad():
        preds = (torch.sigmoid(logits) >= threshold).float()
        y = y.float()
        counts = torch.stack([
            (preds * y).sum(),
            (preds * (1 - y)).sum(),
            ((1 - preds) * y).sum(),
            ((1 - preds) * (1 - y)).sum(),
        ])
    # One host transfer per batch instead of one per count.
    return tuple(counts.tolist())


def main():
    """Train the ConvLSTM segmentation model from the command line.

    Reads ``config.json`` and the train/val shard directories under
    ``--data-root``, trains for ``--epochs`` epochs, prints loss and
    pixel metrics after each epoch, and writes the checkpoint with the lowest
    validation loss to ``model/convlstm_best.pt``.

    Raises:
        FileNotFoundError: If ``config.json`` is missing (or a split has no
            shards, raised by the dataset).
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-root", type=str, default="data/convlstm")
    parser.add_argument("--train-split", type=str, default="train")
    parser.add_argument("--val-split", type=str, default="val")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--pos-weight", type=float, default=None, help="Positive class weight for BCEWithLogitsLoss")
    parser.add_argument("--amp", action="store_true", help="Enable mixed precision")
    args = parser.parse_args()

    root = Path(args.data_root)
    cfg_path = root / "config.json"
    if not cfg_path.exists():
        raise FileNotFoundError(f"{cfg_path} not found. Run preprocessing first.")
    with open(cfg_path) as f:
        cfg = json.load(f)

    # Datasets
    train_dir = root / args.train_split
    val_dir = root / args.val_split
    train_ds = ShardedNPZDataset(str(train_dir))
    val_ds = ShardedNPZDataset(str(val_dir))

    # Peek one sample to infer shapes
    X0, _ = train_ds[0]  # X0: [T,C,H,W]
    T, C, H, W = X0.shape

    # Select the device first: pinned host memory only helps CUDA transfers
    # and triggers warnings on CPU/MPS.
    device = device_select()
    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
        persistent_workers=args.num_workers > 0,
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    # Model/Loss/Opt
    model = ConvLSTMFireSeg(in_channels=C, hidden_channels=(32, 64)).to(device)
    # Compare against None so an explicit weight of 0.0 is not silently dropped.
    pos_weight = (
        torch.tensor([args.pos_weight], device=device)
        if args.pos_weight is not None
        else None
    )
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp and device.type == "cuda")

    # MLflow (optional, best effort: any failure just disables it)
    # NOTE(review): use_mlflow is not read anywhere in this function.
    try:
        import mlflow, mlflow.pytorch
        mlflow.pytorch.autolog(log_models=True)
        use_mlflow = True
    except Exception:
        use_mlflow = False

    best_val = float("inf")
    for epoch in range(1, args.epochs + 1):
        model.train()
        total_loss = 0.0
        total_batches = 0
        for X, y in train_loader:
            # X: [B,T,C,H,W], y: [B,1,H,W]
            X = X.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            if scaler.is_enabled():
                with torch.cuda.amp.autocast():
                    logits = model(X)
                    loss = criterion(logits, y)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                logits = model(X)
                loss = criterion(logits, y)
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            total_batches += 1

        train_loss = total_loss / max(1, total_batches)

        # Validation
        model.eval()
        val_loss = 0.0
        n_val_batches = 0
        # Sum confusion counts over the whole split: averaging per-batch
        # precision/recall/F1 is biased when positives are sparse, because
        # batches without positives contribute zeros.
        tp = fp = fn = tn = 0.0
        with torch.no_grad():
            for X, y in val_loader:
                X = X.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                logits = model(X)
                loss = criterion(logits, y)
                val_loss += loss.item()
                n_val_batches += 1
                b_tp, b_fp, b_fn, b_tn = _confusion_counts(logits, y)
                tp += b_tp
                fp += b_fp
                fn += b_fn
                tn += b_tn

        val_loss = val_loss / max(1, n_val_batches)
        eps = 1e-8
        prec = tp / (tp + fp + eps)
        rec = tp / (tp + fn + eps)
        f1 = 2 * prec * rec / (prec + rec + eps)
        acc = (tp + tn) / (tp + tn + fp + fn + eps)

        print(f"Epoch {epoch:03d} | train_loss={train_loss:.4f} val_loss={val_loss:.4f} acc={acc:.4f} prec={prec:.4f} rec={rec:.4f} f1={f1:.4f}")

        # Keep the checkpoint with the lowest validation loss so far
        # (training always runs for the full number of epochs).
        if val_loss < best_val:
            best_val = val_loss
            out_dir = Path("model")
            out_dir.mkdir(parents=True, exist_ok=True)
            torch.save({"model_state": model.state_dict(),
                        "cfg": cfg,
                        "T": T, "C": C, "H": H, "W": W},
                       out_dir / "convlstm_best.pt")

    print("Training complete.")

# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()