# OmniField (TQV → TQV) — Notebook Guide

This notebook trains/evaluates a **TQV → TQV** setup using ** OmniField** on the ClimSim dataset. It’s configured for perlmutter environment and a **chronological split**.

---

For now use this as a way to view the model code (need to download the samples in order to train)



Dataloader for Venn Diagram setup

In [2]:
import random, numpy as np, torch
from torch.utils.data import Dataset
import netCDF4 as nc

def load_idx2latlon(grid_meta_path):
    with nc.Dataset(grid_meta_path) as ds:
        lat = ds.variables['lat'][:]
        lon = ds.variables['lon'][:]
    return [(float(lat[i]), float(lon[i])) for i in range(len(lat))]

REGIONS = ['T','Q','V','TQ','TV','QV','TQV']
REG2BITS = {
    'T':   (1,0,0),
    'Q':   (0,1,0),
    'V':   (0,0,1),
    'TQ':  (1,1,0),
    'TV':  (1,0,1),
    'QV':  (0,1,1),
    'TQV': (1,1,1),
}

def _assign_venn_indices(n_points, sparsity, triple_fraction, seed, triple_fixed_count=None):
    rng = np.random.RandomState(seed)
    K_mod = int(round(sparsity * n_points))
    assert K_mod > 0, "sparsity too small"
    a = int(triple_fixed_count if triple_fixed_count is not None else round(triple_fraction * K_mod))
    assert 0 <= a <= K_mod, "invalid triple size"

    # distribute counts per region with t > p
    p = (K_mod - a) // 4
    t = K_mod - a - 2 * p
    if t <= p:
        p = max(0, p - 1)
        t = K_mod - a - 2 * p
    assert t >= 0 and p >= 0
    assert a + 2 * p + t == K_mod  # per-modality total

    cnt = {'TQV': a, 'TQ': p, 'TV': p, 'QV': p, 'T': t, 'Q': t, 'V': t}

    perm = rng.permutation(n_points)
    cursor = 0
    region_indices = {}
    for r in ['TQV','TQ','TV','QV','T','Q','V']:
        k = cnt[r]
        if k > 0:
            region_indices[r] = perm[cursor:cursor+k]
            cursor += k
        else:
            region_indices[r] = np.empty((0,), dtype=int)

    mask_T = np.zeros(n_points, dtype=bool)
    mask_Q = np.zeros(n_points, dtype=bool)
    mask_V = np.zeros(n_points, dtype=bool)
    for r, idxs in region_indices.items():
        bT, bQ, bV = REG2BITS[r]
        if bT: mask_T[idxs] = True
        if bQ: mask_Q[idxs] = True
        if bV: mask_V[idxs] = True

    assert mask_T.sum() == K_mod and mask_Q.sum() == K_mod and mask_V.sum() == K_mod
    assert (mask_T & mask_Q & mask_V).sum() == a

    union_mask = mask_T | mask_Q | mask_V
    fixed_idx = np.sort(np.where(union_mask)[0])

    inv = {}
    for r, idxs in region_indices.items():
        for gi in idxs:
            inv[gi] = r
    region_of_local = [inv[gi] for gi in fixed_idx]

    return fixed_idx, region_of_local, {'T': mask_T, 'Q': mask_Q, 'V': mask_V}


class ClimSimTQVForecastVennFixed(Dataset):
    """
    Venn dataset:
      - Fixed Venn partition (masks for T/Q/V).
      - Inputs come from `input_region`: "union" (default) or "triple".
      - Inputs include only modalities requested by `input_modalities` (e.g., (1,0,0) for T-only).
      - Targets are always full-field [T,Q,V] at t_out.
      - Returns `supervised_idx` = input indices.
    """
    def __init__(self, file_list, grid_meta_path, sparsity=0.02, triple_fraction=0.25,
                 norm_stats=None, input_modalities=(1,0,0), input_region="union", seed=123,train=False):
        self.file_list = file_list
        self.idx2latlon = load_idx2latlon(grid_meta_path)
        self.norm_stats = norm_stats
        self.horizons = [3, 6, 9, 12, 15, 18]
        self.seq_len = 19
        self.N = len(self.idx2latlon)

        self.train = train
        
        self.fixed_idx, self.region_of_local, self.mod_masks = _assign_venn_indices(
            n_points=self.N, sparsity=sparsity, triple_fraction=triple_fraction, seed=seed
        )
        self.union_idx = self.fixed_idx
        self.triple_idx = np.where(self.mod_masks['T'] & self.mod_masks['Q'] & self.mod_masks['V'])[0]
        self.triple_idx = np.sort(self.triple_idx)

        self.input_region = input_region  # "union" or "triple"
        self.input_idx = self.triple_idx if input_region == "triple" else self.union_idx

        self._rng = random.Random(seed)
        assert len(input_modalities) == 3
        self.input_modalities = tuple(bool(int(x)) for x in input_modalities)

    def __len__(self):
        return len(self.file_list) - self.seq_len

    def _norm(self, arr, key):
        if self.norm_stats and key in self.norm_stats:
            mu, sigma = self.norm_stats[key]
            arr = (arr - mu) / sigma
        return arr

    def __getitem__(self, idx):
        # load 19-step window
        seq = [np.load(self.file_list[idx + i]) for i in range(self.seq_len)]
        t_in = 0
        # t_out = self._rng.choice(self.horizons)
        if self.train:
            t_out = 3
        else:
            t_out = self._rng.choice(self.horizons)

        T_t  = self._norm(seq[t_in]["state_t"], "T")
        Q_t  = self._norm(seq[t_in]["state_q"], "Q")
        V_t  = self._norm(seq[t_in]["state_v"], "V")
        T_tp = self._norm(seq[t_out]["state_t"], "T")
        Q_tp = self._norm(seq[t_out]["state_q"], "Q")
        V_tp = self._norm(seq[t_out]["state_v"], "V")

        data_T, mesh_T = [], []
        data_Q, mesh_Q = [], []
        data_V, mesh_V = [], []
        supervision_mask = []
        
        tau = float(t_out) / 18.0  # normalize to [0,1]

        # --- INPUTS ONLY FROM CHOSEN REGION ---
        for gi in self.input_idx:
            lat, lon = self.idx2latlon[gi]
            # modality membership from fixed Venn masks
            bT = bool(self.mod_masks['T'][gi])
            bQ = bool(self.mod_masks['Q'][gi])
            bV = bool(self.mod_masks['V'][gi])
            # for triple region, these will all be True
            supervision_mask.append([bT, bQ, bV])

            # gate by requested input_modalities
            if self.input_modalities[0] and bT:
                data_T.append([lon, lat, float(T_t[gi])]); mesh_T.append([lon, lat])
            if self.input_modalities[1] and bQ:
                data_Q.append([lon, lat, float(Q_t[gi])]); mesh_Q.append([lon, lat])
            if self.input_modalities[2] and bV:
                data_V.append([lon, lat, float(V_t[gi])]); mesh_V.append([lon, lat])

        # --- FULL-FIELD TARGETS ---
        data_y, mesh_y = [], []
        for gi in range(self.N):
            lat, lon = self.idx2latlon[gi]
            data_y.append([float(T_tp[gi]), float(Q_tp[gi]), float(V_tp[gi])])
            mesh_y.append([lon, lat])

        used_modalities = [int(self.input_modalities[0]),
                           int(self.input_modalities[1]),
                           int(self.input_modalities[2])]

        to_tensor = lambda x: torch.tensor(x, dtype=torch.float32)
        return (
            to_tensor(data_T), to_tensor(data_Q), to_tensor(data_V),
            to_tensor(mesh_T), to_tensor(mesh_Q), to_tensor(mesh_V),
            to_tensor(data_y), to_tensor(mesh_y),
            torch.tensor(self.input_idx, dtype=torch.long),         
            torch.tensor(used_modalities, dtype=torch.bool),
            torch.tensor(supervision_mask, dtype=torch.bool),
            torch.tensor(tau, dtype=torch.float32)  
        )

    def venn_counts(self):
        counts = {r:0 for r in REGIONS}
        for r in self.region_of_local:
            counts[r] += 1
        return counts


In [3]:
# -----------------------------------------------
from torch.utils.data import DataLoader
import torch, torch.nn as nn
from torch.optim import AdamW
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
import glob
from torch.utils.data import random_split, DataLoader

import glob, numpy as np
from torch.utils.data import DataLoader, Subset


# load the norm stats, samples, and grid info (must download to train)
norm_stats    = dict(np.load("norm_TQV_full.npz", allow_pickle=True))
file_list     = sorted(glob.glob("processed/**/*.npz", recursive=True))
grid_meta_path = "ClimSim_high-res_grid-info.nc"


sparsity = 0.02  # 2% -> 432 


dataset = ClimSimTQVForecastVennFixed(
    file_list=file_list,
    grid_meta_path="ClimSim_high-res_grid-info.nc",
    sparsity=0.02,
    triple_fraction=0.25,
    norm_stats=norm_stats,
    input_modalities=(1,1,1),   # only TQV values
    input_region="union",     
    seed=123,
)


def pretty_counts(counts):
    order = ['T', 'Q', 'V', 'TQ', 'TV', 'QV', 'TQV']
    return {k: counts[k] for k in order}

# Per-modality counts (each modality has exactly K_mod points)
k_T = int(dataset.mod_masks['T'].sum())
k_Q = int(dataset.mod_masks['Q'].sum())
k_V = int(dataset.mod_masks['V'].sum())

# Triple and union sizes
triple_cnt = int((dataset.mod_masks['T'] & dataset.mod_masks['Q'] & dataset.mod_masks['V']).sum())
union_cnt  = int((dataset.mod_masks['T'] | dataset.mod_masks['Q'] | dataset.mod_masks['V']).sum())

# Region-by-region Venn counts from the dataloader
venn = dataset.venn_counts()

print("=== Sensor/Region Summary ===")
print(f"Per-modality (K_mod): T={k_T}, Q={k_Q}, V={k_V}")
print(f"Triple region (T∩Q∩V): {triple_cnt}")
print(f"Union of all modalities (unique locations): {union_cnt}")
print("Venn breakdown (exact memberships):")
for k, v in pretty_counts(venn).items():
    print(f"  {k}: {v}")

print("\n=== DataLoader Input Region Check ===")
print(f"dataset.input_region = {dataset.input_region!r}")
print(f"len(dataset.input_idx) = {len(dataset.input_idx)}")
if dataset.input_region == "union":
    print(f"Matches union? {len(dataset.input_idx) == union_cnt}")
elif dataset.input_region == "triple":
    print(f"Matches triple? {len(dataset.input_idx) == triple_cnt}")
else:
    print("Unknown input_region (expected 'union' or 'triple').")


=== Sensor/Region Summary ===
Per-modality (K_mod): T=432, Q=432, V=432
Triple region (T∩Q∩V): 108
Union of all modalities (unique locations): 837
Venn breakdown (exact memberships):
  T: 162
  Q: 162
  V: 162
  TQ: 81
  TV: 81
  QV: 81
  TQV: 108

=== DataLoader Input Region Check ===
dataset.input_region = 'union'
len(dataset.input_idx) = 837
Matches union? True


In [5]:
# ==== Clean, runnable training cell====

# --- Imports ---
import os
import glob
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, Subset
from torch.optim import AdamW

from itertools import cycle
from tqdm import tqdm
from math import log

from einops import rearrange, repeat
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts


# --- Device ---
DEVICE = "cuda:3" if torch.cuda.is_available() else "cpu"


# ===============================================================
# --- 1. The One True Perceiver IO Model Architecture ---
# ===============================================================
# helpers
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

from functools import wraps
def cache_fn(f):
    cache = None
    @wraps(f)
    def cached_fn(*args, _cache = True, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        nonlocal cache
        if cache is not None:
            return cache
        cache = f(*args, **kwargs)
        return cache
    return cached_fn

# helper classes
class PreNorm(nn.Module):
    def __init__(self, dim, fn, context_dim = None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, **kwargs):
        x = self.norm(x)
        if exists(self.norm_context):
            context = kwargs['context']
            normed_context = self.norm_context(context)
            kwargs.update(context = normed_context)
        return self.fn(x, **kwargs)

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Linear(dim * mult, dim)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, query_dim)
        self.latest_attn = None

    def forward(self, x, context = None, mask = None):
        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h = h)
            sim.masked_fill_(~mask, max_neg_value)
        attn = sim.softmax(dim = -1)
        self.latest_attn = attn.detach()
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)

# This helper function creates the sinusoidal embeddings
def get_sinusoidal_embeddings(n, d):
    """
    Generates sinusoidal positional embeddings.
    
    Args:
        n (int): The number of positions (num_latents).
        d (int): The embedding dimension (latent_dim).

    Returns:
        torch.Tensor: A tensor of shape (n, d) with sinusoidal embeddings.
    """
    assert d % 2 == 0, "latent_dim must be an even number for sinusoidal embeddings"
    position = torch.arange(n, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d, 2).float() * -(log(10000.0) / d))
    pe = torch.zeros(n, d)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe


class CascadedBlock(nn.Module):
    def __init__(self, dim, n_latents, input_dim, cross_heads, cross_dim_head, self_heads, self_dim_head, residual_dim=None):
        super().__init__()
        self.latents = nn.Parameter(get_sinusoidal_embeddings(n_latents, dim), requires_grad=True)
        self.cross_attn = PreNorm(dim, Attention(dim, input_dim, heads=cross_heads, dim_head=cross_dim_head), context_dim=input_dim)
        self.self_attn = PreNorm(dim, Attention(dim, heads=self_heads, dim_head=self_dim_head))
        self.residual_proj = nn.Linear(residual_dim, dim) if residual_dim and residual_dim != dim else None
        self.ff = PreNorm(dim, FeedForward(dim))

    def forward(self, x, context, mask=None, residual=None):
        b = context.size(0)
        latents = repeat(self.latents, 'n d -> b n d', b=b)
        latents = self.cross_attn(latents, context=context, mask=mask) + latents
        if residual is not None:
            if self.residual_proj:
                residual = self.residual_proj(residual)
            latents = latents + residual
        latents = self.self_attn(latents) + latents
        latents = self.ff(latents) + latents
        return latents


class CascadedPerceiverIO(nn.Module):
    def __init__(
        self,
        *,
        input_dim,
        queries_dim,
        logits_dim = None,
        latent_dims=(512, 512, 512),
        num_latents=(256, 256, 256),
        cross_heads = 4,
        cross_dim_head = 128,
        self_heads = 8,
        self_dim_head = 128,
        decoder_ff = False,
        
    ):
        super().__init__()
        
        assert len(latent_dims) == len(num_latents), "latent_dims and num_latents must have same length"
        
        self.latent_dims  = list(latent_dims)
        self.num_latents  = list(num_latents)

        self.input_proj = nn.Sequential(
                nn.Linear(3, 128),
                nn.GELU(),
                nn.Linear(128, 128)
            ).to(DEVICE)
        
        self.input_proj_T = nn.Sequential(
                nn.Linear(3, 128),
                nn.GELU(),
                nn.Linear(128, 128)
            ).to(DEVICE)
        
        self.input_proj_Q = nn.Sequential(
                nn.Linear(3, 128),
                nn.GELU(),
                nn.Linear(128, 128)
            ).to(DEVICE)
        
        self.input_proj_V = nn.Sequential(
                nn.Linear(3, 128),
                nn.GELU(),
                nn.Linear(128, 128)
            ).to(DEVICE)
        
        self.projection_matrix = nn.Parameter(torch.randn(4, 128) / np.sqrt(4)).to(DEVICE)

        # --- 2. Per-Modality Encoder Blocks ---
        def make_encoder_blocks():
            blocks = nn.ModuleList()
            prev_dim = None
            for dim, n_latents in zip(latent_dims, num_latents):
                blocks.append(CascadedBlock(
                    dim=dim,
                    n_latents=n_latents,
                    input_dim=input_dim,
                    cross_heads=cross_heads,
                    cross_dim_head=cross_dim_head,
                    self_heads=self_heads,
                    self_dim_head=self_dim_head,
                    residual_dim=prev_dim
                ))
                prev_dim = dim
            return blocks

        # Cross-attn: Q/V → T
        self.cross_T_from_Q = PreNorm(
            latent_dims[-1],
            Attention(
                query_dim=latent_dims[-1],
                context_dim=latent_dims[-1],
                heads=cross_heads,
                dim_head=cross_dim_head
            )
        )

        self.cross_T_from_V = PreNorm(
            latent_dims[-1],
            Attention(
                query_dim=latent_dims[-1],
                context_dim=latent_dims[-1],
                heads=cross_heads,
                dim_head=cross_dim_head
            )
        )

        # === Q fusion ===
        self.cross_Q_from_T = PreNorm(
            latent_dims[-1],
            Attention(
                query_dim=latent_dims[-1],
                context_dim=latent_dims[-1],
                heads=cross_heads,
                dim_head=cross_dim_head
            )
        )

        self.cross_Q_from_V = PreNorm(
            latent_dims[-1],
            Attention(
                query_dim=latent_dims[-1],
                context_dim=latent_dims[-1],
                heads=cross_heads,
                dim_head=cross_dim_head
            )
        )

        # === V fusion ===
        self.cross_V_from_T = PreNorm(
            latent_dims[-1],
            Attention(
                query_dim=latent_dims[-1],
                context_dim=latent_dims[-1],
                heads=cross_heads,
                dim_head=cross_dim_head
            )
        )

        self.cross_V_from_Q = PreNorm(
            latent_dims[-1],
            Attention(
                query_dim=latent_dims[-1],
                context_dim=latent_dims[-1],
                heads=cross_heads,
                dim_head=cross_dim_head
            )
        )

        self.encoder_blocks_T = make_encoder_blocks()
        self.encoder_blocks_Q = make_encoder_blocks()
        self.encoder_blocks_V = make_encoder_blocks()
        
        self.sa_queries_T = PreNorm(queries_dim, Attention(queries_dim, heads=4, dim_head=64))
        self.sa_queries_Q = PreNorm(queries_dim, Attention(queries_dim, heads=4, dim_head=64))
        self.sa_queries_V = PreNorm(queries_dim, Attention(queries_dim, heads=4, dim_head=64))
        
        final_latent_dim = latent_dims[-1]
        self.global_proj_T = nn.Linear(final_latent_dim, input_dim)
        self.global_proj_Q = nn.Linear(final_latent_dim, input_dim)
        self.global_proj_V = nn.Linear(final_latent_dim, input_dim)
        
        self.global2latent_proj_T = nn.ModuleList([
            nn.Linear(final_latent_dim, num_latents[i] * latent_dims[i]) for i in range(len(latent_dims))
        ])
        self.global2latent_proj_Q = nn.ModuleList([
            nn.Linear(final_latent_dim, num_latents[i] * latent_dims[i]) for i in range(len(latent_dims))
        ])
        self.global2latent_proj_V = nn.ModuleList([
            nn.Linear(final_latent_dim, num_latents[i] * latent_dims[i]) for i in range(len(latent_dims))
        ])

        # Cascaded encoder blocks (generic list) - kept for checkpoint compatibility
        self.encoder_blocks = nn.ModuleList()
        prev_dim = None
        for dim, n_latents in zip(latent_dims, num_latents):
            block = CascadedBlock(
                dim=dim,
                n_latents=n_latents,
                input_dim=input_dim,
                cross_heads=cross_heads,
                cross_dim_head=cross_dim_head,
                self_heads=self_heads,
                self_dim_head=self_dim_head,
                residual_dim=prev_dim
            )
            self.encoder_blocks.append(block)
            prev_dim = dim

        # Decoder
        final_latent_dim = latent_dims[-1]
        self.decoder_cross_attn = PreNorm(queries_dim, Attention(queries_dim, final_latent_dim, heads=cross_heads, dim_head=cross_dim_head), context_dim=final_latent_dim)
        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None
        self.to_logits = nn.Linear(queries_dim, logits_dim) if exists(logits_dim) else nn.Identity()
        
        self.decoder_cross_attn_T = PreNorm(queries_dim, Attention(queries_dim, final_latent_dim, heads=cross_heads, dim_head=cross_dim_head), context_dim=final_latent_dim)
        self.decoder_ff_T = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None
        self.to_logits_T = nn.Linear(queries_dim, 1)
        
        self.decoder_cross_attn_Q = PreNorm(queries_dim, Attention(queries_dim, final_latent_dim, heads=cross_heads, dim_head=cross_dim_head), context_dim=final_latent_dim)
        self.decoder_ff_Q = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None
        self.to_logits_Q = nn.Linear(queries_dim, 1)

        self.decoder_cross_attn_V = PreNorm(queries_dim, Attention(queries_dim, final_latent_dim, heads=cross_heads, dim_head=cross_dim_head), context_dim=final_latent_dim)
        self.decoder_ff_V = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None
        self.to_logits_V = nn.Linear(queries_dim, 1)

        self.self_attn_blocks = nn.Sequential(*[
            nn.Sequential(
                PreNorm(latent_dims[-1], Attention(latent_dims[-1], heads=self_heads, dim_head=self_dim_head)),
                PreNorm(latent_dims[-1], FeedForward(latent_dims[-1]))
            )
            for _ in range(3)
        ])

    def forward(self, x_T, x_Q, x_V, queries, used_modalities):
        def residual_from_global(global_latent, proj_layer, n_latents_k, dim_k):
            if global_latent is None:
                return None
            G_pool = global_latent.mean(dim=1)                      # [B, Dg]
            R = proj_layer(G_pool).view(G_pool.size(0), n_latents_k, dim_k)
            return R

        global_latent = None
        num_stages = 3

        for stage_idx in range(num_stages):
            stage_latents = []
            nL_k = self.num_latents[stage_idx]
            d_k  = self.latent_dims[stage_idx]

            # --- T modality ---
            if used_modalities[0] and x_T is not None:
                R_T = residual_from_global(global_latent, self.global2latent_proj_T[stage_idx], nL_k, d_k)
                latent_T = self.encoder_blocks_T[stage_idx](x=None, context=x_T, residual=R_T)
                stage_latents.append(latent_T)
            else:
                latent_T = None

            # --- Q modality ---
            if used_modalities[1] and x_Q is not None:
                R_Q = residual_from_global(global_latent, self.global2latent_proj_Q[stage_idx], nL_k, d_k)
                latent_Q = self.encoder_blocks_Q[stage_idx](x=None, context=x_Q, residual=R_Q)
                stage_latents.append(latent_Q)
            else:
                latent_Q = None

            # --- V modality ---
            if used_modalities[2] and x_V is not None:
                R_V = residual_from_global(global_latent, self.global2latent_proj_V[stage_idx], nL_k, d_k)
                latent_V = self.encoder_blocks_V[stage_idx](x=None, context=x_V, residual=R_V)
                stage_latents.append(latent_V)
            else:
                latent_V = None

            if not stage_latents:
                raise ValueError("No modalities present in this batch.")

            # === Fuse present modality latents into new global ===
            fused_latent = torch.cat(stage_latents, dim=1)
            for sa_block in self.self_attn_blocks:
                fused_latent = sa_block[0](fused_latent) + fused_latent
                fused_latent = sa_block[1](fused_latent) + fused_latent
            global_latent = fused_latent  # pass to next stage

        # === Prepare queries ===
        if queries.ndim == 2:
            queries = repeat(queries, 'n d -> b n d', b=global_latent.size(0))

        # === Decoder: cross-attention ===
        def decode_branch_with_query(cross_attn, ff, head):
            q = queries
            x = cross_attn(q, context=global_latent)
            x = x + q
            if ff:
                x = x + ff(x)
            return head(x)

        T_out = decode_branch_with_query(self.decoder_cross_attn_T, self.decoder_ff_T, self.to_logits_T)
        Q_out = decode_branch_with_query(self.decoder_cross_attn_Q, self.decoder_ff_Q, self.to_logits_Q)
        V_out = decode_branch_with_query(self.decoder_cross_attn_V, self.decoder_ff_V, self.to_logits_V)

        return T_out, Q_out, V_out


class GaussianFourierFeatures(nn.Module):
    def __init__(self, in_features, mapping_size, scale=15.0):
        super().__init__()
        self.in_features = in_features
        self.mapping_size = mapping_size
        self.register_buffer('B', torch.randn((in_features, mapping_size)) * scale)

    def forward(self, coords):
        projections = coords @ self.B
        fourier_feats = torch.cat([torch.sin(projections), torch.cos(projections)], dim=-1)
        return fourier_feats


# -----------------------------------------------

# Same as before
norm_stats = dict(np.load("norm_TQV_full.npz", allow_pickle=True))
file_list  = sorted(glob.glob("processed/**/*.npz", recursive=True))
grid_meta_path = "ClimSim_high-res_grid-info.nc"

sparsity = 0.02  # 2% -> 432

dataset = ClimSimTQVForecastVennFixed(
    file_list=file_list,
    grid_meta_path=grid_meta_path,
    sparsity=sparsity,
    triple_fraction=0.25,
    norm_stats=norm_stats,
    input_modalities=(1,1,1),   # only TQV values
    input_region="union",
    seed=123,
)

train_len = 9000
assert len(dataset) > train_len
train_T = Subset(dataset, range(0, train_len))
val_T   = Subset(dataset, range(train_len, len(dataset)))

train_loader_T = DataLoader(train_T, batch_size=8, shuffle=True)
val_loader_T   = DataLoader(val_T,   batch_size=1, shuffle=False)

pos_enc  = GaussianFourierFeatures(2, 32).to(DEVICE)
time_enc = GaussianFourierFeatures(1, 16, scale=10.0).to(DEVICE)

model = CascadedPerceiverIO(
    input_dim   = 192,
    queries_dim = 96,
    logits_dim  = None,
    latent_dims = (128,128,128),
    num_latents = (128,128,128),
    decoder_ff  = True,
).to(DEVICE)

device = DEVICE
max_lr = 8e-5
min_lr = 8e-6
warmup_steps = int(0.1 * 10000)
weight_decay = 1e-4

# --------- hyper-params ----------
TOTAL_ITERS  = 100000
PRINT_EVERY  = 100
VAL_EVERY    = 100

opt = AdamW(
    model.parameters(),
    lr=max_lr,
    betas=(0.9, 0.999),
    weight_decay=weight_decay
)

scheduler = CosineAnnealingWarmupRestarts(
    opt,
    first_cycle_steps=TOTAL_ITERS,
    max_lr=max_lr,
    min_lr=min_lr,
    warmup_steps=warmup_steps
)

def save_checkpoint(model, pos_enc, time_enc, optimizer, epoch, val_loss, save_path):
    """
    Save model, positional encoder, time encoder, optimizer state, and metadata.
    """
    dirpath = os.path.dirname(save_path)
    if dirpath:
        os.makedirs(dirpath, exist_ok=True)
    state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "pos_enc_state": pos_enc.state_dict(),
        "time_enc_state": time_enc.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "val_loss": val_loss,
    }
    torch.save(state, save_path)
    print(f"Saved model to {save_path} (val_loss={val_loss:.4f})")

def load_checkpoint(model, pos_enc, time_enc, optimizer, load_path, device=None):
    """
    Load model, positional encoder, time encoder, and optimizer state.
    """
    if device is None:
        device = DEVICE
    checkpoint = torch.load(load_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    pos_enc.load_state_dict(checkpoint["pos_enc_state"])
    time_enc.load_state_dict(checkpoint["time_enc_state"])
    if optimizer is not None and "optimizer_state" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    print(f"Loaded model from {load_path} (epoch={checkpoint.get('epoch')}, val_loss={checkpoint.get('val_loss'):.4f})")
    return checkpoint

def get_lr(optimizer):
    return optimizer.param_groups[0]["lr"]

@torch.no_grad()
def extract_T(outs):
    """
    Normalize model outputs to a [B, N, 1] tensor for T.
    Supports:
      - tuple/list: (pred_T, pred_Q, pred_V)
      - dict: {'T':..., 'Q':..., 'V':...}
      - tensor: [B, N, 1] or [B, N] (assumed to be T)
    """
    if isinstance(outs, (list, tuple)):
        assert len(outs) >= 1, "Empty outputs!"
        return outs[0]
    if isinstance(outs, dict):
        assert 'T' in outs, "Dict output missing key 'T'."
        return outs['T']
    return outs

@torch.no_grad()
def run_validation_fullfield_T(model, val_loader, pos_enc_fn, time_enc):
    """
    Quick T-only validation on a single batch (assumes batch_size==1 in val_loader).
    Returns (val_loss_scalar, val_T_mse_scalar).
    """
    model.eval()
    batch = next(iter(val_loader))
    (
        data_T, data_Q, data_V,
        mesh_T, mesh_Q, mesh_V,
        data_y, mesh_y,
        supervised_idx, used_modalities, supervision_mask, tau
    ) = batch

    data_T, mesh_T = data_T.to(DEVICE), mesh_T.to(DEVICE)
    data_Q, mesh_Q = data_Q.to(DEVICE), mesh_Q.to(DEVICE)
    data_V, mesh_V = data_V.to(DEVICE), mesh_V.to(DEVICE)

    data_y = data_y.to(DEVICE).squeeze(0)
    mesh_y = mesh_y.to(DEVICE).squeeze(0)
    tau    = tau.to(DEVICE).view(1)
    used_modalities = tuple(bool(x) for x in used_modalities.squeeze(0).tolist())

    x_T = x_Q = x_V = None
    if data_T.numel() > 0:
        x_T = torch.cat([model.input_proj_T(data_T), pos_enc_fn(mesh_T)], dim=-1)
    if data_Q.numel() > 0:
        x_Q = torch.cat([model.input_proj_Q(data_Q), pos_enc_fn(mesh_Q)], dim=-1)
    if data_V.numel() > 0:
        x_V = torch.cat([model.input_proj_V(data_V), pos_enc_fn(mesh_V)], dim=-1)

    q_spatial = pos_enc_fn(mesh_y).unsqueeze(0)
    tfeat = time_enc(tau[:, None])
    tfeat = tfeat[:, None, :].expand(-1, q_spatial.size(1), -1)
    queries_full = torch.cat([q_spatial, tfeat], dim=-1)

    outs = model(x_T=x_T, x_Q=x_Q, x_V=x_V, queries=queries_full, used_modalities=used_modalities)
    pred_T = extract_T(outs)
    if pred_T.ndim == 3 and pred_T.size(-1) == 1:
        pred_T = pred_T.squeeze(-1)
    pred_T = pred_T.squeeze(0)

    tgt_T = data_y[:, 0]
    lT = F.mse_loss(pred_T, tgt_T)
    return float(lT), float(lT)

@torch.no_grad()
def run_validation_fullfield_T_total(model, val_loader, pos_enc_fn, time_enc, max_batches=None):
    """
    T-only validation over the entire val_loader (or first max_batches).
    Returns mean T-MSE across evaluated batches.
    """
    model.eval()
    total, count = 0.0, 0
    for b_idx, batch in enumerate(val_loader):
        if (max_batches is not None) and (b_idx >= max_batches):
            break

        (data_T, data_Q, data_V,
         mesh_T, mesh_Q, mesh_V,
         data_y, mesh_y,
         supervised_idx, used_modalities, supervision_mask, tau) = batch

        data_T, mesh_T = data_T.to(DEVICE), mesh_T.to(DEVICE)
        data_Q, mesh_Q = data_Q.to(DEVICE), mesh_Q.to(DEVICE)
        data_V, mesh_V = data_V.to(DEVICE), mesh_V.to(DEVICE)
        data_y = data_y.to(DEVICE).squeeze(0)
        mesh_y = mesh_y.to(DEVICE).squeeze(0)
        tau    = tau.to(DEVICE).view(1)
        used_modalities = tuple(bool(x) for x in used_modalities.squeeze(0).tolist())

        x_T = torch.cat([model.input_proj_T(data_T), pos_enc_fn(mesh_T)], dim=-1) if data_T.numel() > 0 else None
        x_Q = torch.cat([model.input_proj_Q(data_Q), pos_enc_fn(mesh_Q)], dim=-1) if data_Q.numel() > 0 else None
        x_V = torch.cat([model.input_proj_V(data_V), pos_enc_fn(mesh_V)], dim=-1) if data_V.numel() > 0 else None

        q_spatial = pos_enc_fn(mesh_y).unsqueeze(0)
        tfeat = time_enc(tau[:, None])
        tfeat = tfeat[:, None, :].expand(-1, q_spatial.size(1), -1)
        queries_full = torch.cat([q_spatial, tfeat], dim=-1)

        outs = model(x_T=x_T, x_Q=x_Q, x_V=x_V, queries=queries_full, used_modalities=used_modalities)
        pred_T = extract_T(outs)
        if pred_T.ndim == 3 and pred_T.size(-1) == 1:
            pred_T = pred_T.squeeze(-1)
        pred_T = pred_T.squeeze(0)

        tgt_T = data_y[:, 0]
        lT = F.mse_loss(pred_T, tgt_T).item()
        total += lT
        count += 1

    return total / max(1, count)

def pos_enc_batched(pos_enc_fn, coords):
    # coords: [..., 2]
    leading = coords.shape[:-1]
    flat = coords.reshape(-1, coords.shape[-1])        # [*, 2]
    enc  = pos_enc_fn(flat)                            # [*, Dpos]
    return enc.view(*leading, enc.shape[-1])           # [..., Dpos]

train_iter = cycle(train_loader_T)

running_loss = 0.0
running_mse_T = 0.0   # scalar running avg for T only

best_val_loss = float("inf")
save_path = "ClimSim_checkpoints/best_model.pt"

for it in tqdm(range(1, TOTAL_ITERS + 1)):
    model.train()

    (data_T, data_Q, data_V,
     mesh_T, mesh_Q, mesh_V,
     data_y, mesh_y,
     supervised_idx, used_modalities, supervision_mask, tau) = next(train_iter)

    # ---- move (NO squeeze; keep batch dim) ----
    data_T = data_T.to(DEVICE)       # [B, S_T, 3]
    mesh_T = mesh_T.to(DEVICE)       # [B, S_T, 2]
    data_Q = data_Q.to(DEVICE)       # [B, S_Q, 3]
    mesh_Q = mesh_Q.to(DEVICE)       # [B, S_Q, 2]
    data_V = data_V.to(DEVICE)       # [B, S_V, 3]
    mesh_V = mesh_V.to(DEVICE)       # [B, S_V, 2]
    data_y = data_y.to(DEVICE)       # [B, N, 3]
    mesh_y = mesh_y.to(DEVICE)       # [B, N, 2]
    tau    = tau.to(DEVICE).view(-1)

    # Multimodal inputs → T target
    used_modalities_bits = (True, True, True)

    # ---- inputs: project + concat positional enc (batched) ----
    x_T = torch.cat([model.input_proj_T(data_T), pos_enc_batched(pos_enc, mesh_T)], dim=-1) if data_T.numel() > 0 else None
    x_Q = torch.cat([model.input_proj_Q(data_Q), pos_enc_batched(pos_enc, mesh_Q)], dim=-1) if data_Q.numel() > 0 else None
    x_V = torch.cat([model.input_proj_V(data_V), pos_enc_batched(pos_enc, mesh_V)], dim=-1) if data_V.numel() > 0 else None

    # ---- FULL-GRID queries (batched) ----
    tfeat = time_enc(tau[:, None])                                         # [B, Dtime]
    tfeat_expanded = tfeat[:, None, :].expand(-1, mesh_y.shape[1], -1)     # [B, N, Dtime]
    queries_spatial = pos_enc_batched(pos_enc, mesh_y)                      # [B, N, Dpos]
    queries_full = torch.cat([queries_spatial, tfeat_expanded], dim=-1)     # [B, N, Dpos+Dtime]

    # ---- forward ----
    outs = model(
        x_T=x_T, x_Q=x_Q, x_V=x_V,
        queries=queries_full,
        used_modalities=used_modalities_bits
    )
    pred_T = extract_T(outs)           # [B,N,1] or [B,N]
    if pred_T.ndim == 3 and pred_T.size(-1) == 1:
        pred_T = pred_T.squeeze(-1)    # [B,N]

    # ---- targets ----
    tgt_T = data_y[..., 0]             # [B, N]

    # ---- full-field MSE over the batch (T only) ----
    lT = F.mse_loss(pred_T, tgt_T)     # scalar across B*N
    loss = lT

    opt.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    opt.step()
    scheduler.step()

    # ---- logging ----
    with torch.no_grad():
        running_mse_T += lT.item()
        running_loss  += float(loss)

    if it % PRINT_EVERY == 0:
        mT = running_mse_T / PRINT_EVERY
        print(f"  [TQV->T|Full|Batched]  lr={get_lr(opt):.2e}  B={data_T.shape[0]}")
        print(f"[Iter {it}] Total Loss: {running_loss / PRINT_EVERY:.6f}")
        print(f"  Full-field MSE  T: {mT:.6f}")
        running_loss = 0.0
        running_mse_T = 0.0

    if it % VAL_EVERY == 0:
        val_loss, val_mT = run_validation_fullfield_T(model, val_loader_T, pos_enc, time_enc)
        print(f"[VAL] T-only Loss (full-field): {val_loss:.6f}   (T MSE: {val_mT:.6f})")

    if it % (VAL_EVERY * 5) == 0:
        val_loss_full = run_validation_fullfield_T_total(model, val_loader_T, pos_enc, time_enc)
        print(f"[VAL-TOTAL] Mean T MSE over val: {val_loss_full:.6f}")
        if val_loss_full < best_val_loss:
            best_val_loss = val_loss_full
            save_checkpoint(model, pos_enc, time_enc, opt, it, val_loss_full, save_path)


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f3b16e6dc40>>
Traceback (most recent call last):
  File "/global/homes/k/kevinval/miniconda3/envs/ddim/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
  0%|          | 1/100000 [00:01<38:28:35,  1.39s/it]

KeyboardInterrupt



In [6]:
# --- imports & norm stats (same as you have) ---
# ==== Build T-only dataset & loaders ====
import glob, numpy as np
from torch.utils.data import DataLoader, Subset

norm_stats    = dict(np.load("norm_TQV_full.npz", allow_pickle=True))
file_list     = sorted(glob.glob("processed/**/*.npz", recursive=True))
grid_meta_path = "ClimSim_high-res_grid-info.nc"


sparsity = 0.02  # 2% -> 432 T points


dataset = ClimSimTQVForecastVennFixed(
    file_list=file_list,
    grid_meta_path="ClimSim_high-res_grid-info.nc",
    sparsity=0.02,
    triple_fraction=0.25,
    norm_stats=norm_stats,
    input_modalities=(1,1,1),   
    input_region="union",      
    seed=123,
)


train_len = 9000
assert len(dataset) > train_len
train_T = Subset(dataset, range(0, train_len))
val_T   = Subset(dataset, range(train_len, len(dataset)))


train_loader_T = DataLoader(
    train_T, batch_size=8, shuffle=True,
)
val_loader_T = DataLoader(
    val_T, batch_size=1, shuffle=False,
)

print(len(train_T))  # number of samples in the train subset
print(len(val_T)) 





9000
999


In [8]:
# Robust loader for ClimSim norm stats (T/Q/V) from NPZ files
import numpy as np
import torch

TQV_KEYS = ("T", "Q", "V")

def _to_float1(x, fallback=0.0):
    if x is None: return float(fallback)
    arr = np.asarray(x)
    if arr.size == 0: return float(fallback)
    return float(arr.reshape(-1)[0])

def _extract_mu_sd_from_entry(v):
    """
    Accept:
      - dict-like with 'mean'/'std' or 'mu'/'sigma'
      - array-like [mu, std]
      - scalar-like (invalid → None)
    Return (mu, sd) or None
    """
    # dict-like (possibly 0-d object array)
    if isinstance(v, dict):
        mu = _to_float1(v.get("mean", v.get("mu", None)), 0.0)
        sd = max(_to_float1(v.get("std",  v.get("sigma", None)), 1.0), 1e-6)
        return mu, sd

    if isinstance(v, np.ndarray) and v.dtype == object:
        # 0-d or 1-element object array that holds a dict or pair
        if v.shape == () or (v.size == 1 and v.ndim <= 1):
            inner = v.reshape(()).item()
            if isinstance(inner, dict):
                mu = _to_float1(inner.get("mean", inner.get("mu", None)), 0.0)
                sd = max(_to_float1(inner.get("std",  inner.get("sigma", None)), 1.0), 1e-6)
                return mu, sd
            # fallthrough: maybe it's [mu, std] inside object — try array path below
            v = np.asarray(inner)

    # array-like [mu, std]
    arr = np.asarray(v)
    if arr.size >= 2:
        mu = _to_float1(arr[0], 0.0)
        sd = max(_to_float1(arr[1], 1.0), 1e-6)
        return mu, sd

    return None

def load_tqv_norm_stats(npz_path: str):
    """
    Returns:
      - stats_tuples: { 'T': (mu, sd), 'Q': (mu, sd), 'V': (mu, sd) }
      - Normalizer with .mu/.sd dicts and .norm/.denorm methods (NumPy or torch)
    Tries these layouts:
      A) Keys 'T','Q','V' with dicts or [mu, sd] arrays
      B) Split keys: 'T_mu'/'T_std', 'Q_mu'/'Q_std', 'V_mu'/'V_std'
      C) Packed under 'stats' → {'T':..., 'Q':..., 'V':...}
    """
    z = dict(np.load(npz_path, allow_pickle=True))
    stats = {}

    # Case A: direct per-key entries
    if all(k in z for k in TQV_KEYS):
        for k in TQV_KEYS:
            got = _extract_mu_sd_from_entry(z[k])
            if got is None:
                raise ValueError(f"Key '{k}' present but not parseable as (mu, sd). "
                                 f"Type={type(z[k])}, shape={np.shape(z[k])}")
            stats[k] = got
    else:
        # Case B: split keys
        ok = True
        tmp = {}
        for k in TQV_KEYS:
            mu = z.get(f"{k}_mu", z.get(f"mu_{k}", None))
            sd = z.get(f"{k}_std", z.get(f"std_{k}", None))
            if (mu is None) or (sd is None):
                ok = False
                break
            tmp[k] = ( _to_float1(mu, 0.0), max(_to_float1(sd, 1.0), 1e-6) )
        if ok:
            stats = tmp
        else:
            # Case C: packed under a 'stats' key
            packed = z.get("stats", None)
            if packed is None:
                raise ValueError(
                    "Could not find T/Q/V norm stats in NPZ. "
                    f"Available keys: {sorted(z.keys())}"
                )
            if isinstance(packed, np.ndarray) and packed.dtype == object:
                packed = packed.reshape(()).item()
            if not isinstance(packed, dict):
                raise ValueError("Key 'stats' exists but is not a dict-like container.")
            for k in TQV_KEYS:
                if k not in packed:
                    raise ValueError(f"'stats' missing key '{k}'.")
                got = _extract_mu_sd_from_entry(packed[k])
                if got is None:
                    raise ValueError(f"'stats[{k}]' not parseable as (mu, sd).")
                stats[k] = got

    class TQVNormalizer:
        def __init__(self, stats_tuples):
            self.mu = {k: float(stats_tuples[k][0]) for k in TQV_KEYS}
            self.sd = {k: float(stats_tuples[k][1]) for k in TQV_KEYS}
        def norm(self, key, x):
            mu, sd = self.mu[key], self.sd[key]
            if isinstance(x, torch.Tensor):
                return (x - mu) / sd
            return (np.asarray(x) - mu) / sd
        def denorm(self, key, x):
            mu, sd = self.mu[key], self.sd[key]
            if isinstance(x, torch.Tensor):
                return x * sd + mu
            return np.asarray(x) * sd + mu

    return stats, TQVNormalizer(stats)

# ==== Load & print ====
norm_stats_tuples, TQV_NORM = load_tqv_norm_stats("norm_TQV_full.npz")
print("ClimSim normalization (μ, σ):")
for k in TQV_KEYS:
    mu, sd = norm_stats_tuples[k]
    print(f"  {k}: mu={mu:.6f}, std={sd:.6f}")

# Example helper: denorm [B, N, 3] tensor with channels [T,Q,V]
def denorm_tqv_batch(y_norm: torch.Tensor):
    T = TQV_NORM.denorm("T", y_norm[..., 0])
    Q = TQV_NORM.denorm("Q", y_norm[..., 1])
    V = TQV_NORM.denorm("V", y_norm[..., 2])
    return {"T": T, "Q": Q, "V": V}


ClimSim normalization (μ, σ):
  T: mu=288.291795, std=14.433793
  Q: mu=0.009471, std=0.005907
  V: mu=-0.060691, std=5.796191
