In [None]:
import os, re, glob, math, json, random, numpy as np
from typing import List, Tuple, Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import GroupShuffleSplit
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from math import floor
from architecture import UNetBasic, UNetDropout

In [None]:
# Data / checkpoint locations. Each can be overridden with an environment
# variable so the notebook runs outside the original cluster account; the
# defaults are the original paths.
DATA_DIR = os.environ.get(
    "TNG50_DATA_DIR",
    "/home/cj535/palmer_scratch/TNG50_cutouts/MW_sample_maps/packed_aug8_coldvel")
PATTERN  = "TNG50_snap099_subid*_views10_aug8_C5_256x256.npy"
CHECKPOINTS_DIR = os.environ.get(
    "TNG50_CHECKPOINTS_DIR",
    "/home/cj535/palmer_scratch/CNN_checkpoints/coldgas_vel")
H, W = 256, 256                 # map size in pixels; must agree with the 256x256 in PATTERN
R_MASK = 20                     # radius (pixels) of the central disc zeroed in the inputs
BATCH_SIZE = 16
# Training hyper-parameters: not used by the evaluation cells of this notebook.
EPOCHS = 200
FREEZE_ENCODER_EPOCHS = 10
LR = 5e-4
NUM_WORKERS = 1
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
VEL_BINS = [(-300, -100), (-100, 100), (100, 300)]  # 3 input channels
COMPRESSION = 'log10'           # 'log10', 'sqrt', or any other value for no compression

# Seed every global RNG in use so stochastic steps (e.g. training-loader
# shuffling) are reproducible across Restart & Run All.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

def circular_outer_mask(H, W, R, device="cpu"):
    """
    Build a mask that is 1 outside a central disc of radius R and 0 inside it.

    The disc centre is (H/2, W/2) in pixel-index coordinates, i.e. half a pixel
    away from the geometric centre ((H-1)/2, (W-1)/2) of an H x W grid.
    Pixels at squared distance exactly R**2 count as outside (value 1).

    Args:
        H, W: mask height and width in pixels.
        R: disc radius in pixels.
        device: torch device on which the mask is built.

    Returns:
        float32 tensor of shape (1, H, W).
    """
    dy = torch.arange(H, device=device)[:, None] - H / 2.0   # (H, 1) row offsets
    dx = torch.arange(W, device=device)[None, :] - W / 2.0   # (1, W) column offsets
    outside = dy ** 2 + dx ** 2 >= R ** 2                    # broadcasts to (H, W)
    return outside.float()[None]

subid_re = re.compile(r".*?_subid(?P<subid>\d+)_views10_aug8_C5_256x256\.npy$")
def find_packs(data_dir: str) -> Dict[str, np.ndarray]:
    """
    Load every pack file matching PATTERN in ``data_dir`` fully into RAM.

    Args:
        data_dir: directory searched with ``glob`` using PATTERN.

    Returns:
        dict mapping subid (the digit string from the filename) to a float32
        array of shape (N, 5, H, W).

    Raises:
        RuntimeError: if a matched file does not have shape (N, 5, H, W), or if
            no pack is found at all.
    """
    packs = {}
    # sorted(): glob order is filesystem-dependent; make loading deterministic.
    for path in sorted(glob.glob(os.path.join(data_dir, PATTERN))):
        m = subid_re.match(path)
        if not m:
            continue
        subid = m.group("subid")
        arr = np.load(path)
        # Check the rank first so a malformed file hits the RuntimeError below
        # rather than an IndexError on arr.shape[1].
        if arr.ndim != 4 or arr.shape[1] != 5 or arr.shape[2:] != (H, W):
            raise RuntimeError(f"Unexpected shape {arr.shape} in {path}")
        # Downstream code feeds these values to the model as float32; this is a
        # no-op (no copy) when the file was already saved as float32.
        packs[subid] = arr.astype(np.float32, copy=False)
    if not packs:
        raise RuntimeError("No packs found. Check DATA_DIR/PATTERN.")
    return packs

def compute_input_norm(packs: Dict[str, np.ndarray], subids: List[str],compression='log10'):
    """
    Compute per-channel mean/std of the 3 input (brightness) channels using ONLY
    the given (training) subids.

    The values are compressed first, exactly like the dataset does per sample:
    'sqrt' -> sqrt(x), 'log10' -> log10(x + 1e-25), any other value -> unchanged.

    Args:
        packs: subid -> 4-D array (N, C, H, W); channels 0-2 are the inputs.
        subids: keys of ``packs`` to include in the statistics.
        compression: 'log10', 'sqrt', or any other value for no compression.

    Returns:
        (mean, std): two float32 arrays of shape (3,).

    Raises:
        ValueError: if the selected subids contain no samples.
    """
    s = np.zeros(3, dtype=np.float64)   # per-channel sum
    q = np.zeros(3, dtype=np.float64)   # per-channel sum of squares
    n = 0                               # number of values per channel
    for sid in subids:
        x = packs[sid][:, :3, :, :]     # (N,3,H,W)
        if compression == 'sqrt':
            x = np.sqrt(x)
        elif compression == 'log10':
            x = np.log10(x+1e-25)
        n += x.shape[0] * x.shape[2] * x.shape[3]
        # Reduce over samples and pixels, accumulating in float64: float32
        # partial sums over ~N*H*W values lose precision, and the
        # E[x^2] - E[x]^2 formula below amplifies that error.
        s += x.sum(axis=(0, 2, 3), dtype=np.float64)
        q += np.square(x, dtype=np.float64).sum(axis=(0, 2, 3))
    if n == 0:
        raise ValueError("compute_input_norm: no samples found for the given subids")
    mean = s / n
    var  = (q / n) - mean**2
    # Rounding can push the variance of a (near-)constant channel slightly
    # below zero; clip at 0 so std is 0 there instead of NaN.
    std  = np.sqrt(np.clip(var, 0.0, None))
    return mean.astype(np.float32), std.astype(np.float32)

class GalaxyPackDataset(Dataset):
    """
    Yields individual (view x aug) samples from a list of subids.

    Expects a memory-resident dict: subid -> ndarray (N,5,H,W). Channels 0-2
    are the inputs, channels 3-4 the targets.
    Inputs are compressed ('sqrt' or 'log10'), standardized with per-channel
    mean/std when both are given, and zeroed inside a central disc of radius
    ``r_mask`` pixels. Targets are not transformed (only cast to float32).
    Each item is a dict with keys "x", "y", "mask" and "subid".
    """
    def __init__(self, packs: Dict[str, np.ndarray], subids: List[str], mean=None, std=None, compression='log10',r_mask=R_MASK):
        self.subids = list(subids)
        self.packs = packs
        # Flat index over every sample of every galaxy: list of (subid, local_idx).
        self.items = [(sid, i)
                      for sid in self.subids
                      for i in range(self.packs[sid].shape[0])]
        self.mask = circular_outer_mask(H, W, r_mask, device="cpu")  # (1,H,W) float32, 0 inside the disc
        # Keep the statistics as float32 (3,) arrays so normalization never
        # upcasts the inputs (e.g. when float64 arrays or lists are passed in).
        self.mean = None if mean is None else np.asarray(mean, dtype=np.float32)  # (3,)
        self.std  = None if std is None else np.asarray(std, dtype=np.float32)    # (3,)
        self.compression = compression

    def __len__(self): return len(self.items)

    def __getitem__(self, idx):
        sid, i = self.items[idx]
        sample = self.packs[sid][i]       # (5,H,W)
        x = sample[:3]                    # (3,H,W) inputs
        y = sample[3:]                    # (2,H,W) targets
        if self.compression == 'sqrt':
            x = np.sqrt(x)
        elif self.compression == 'log10':
            x = np.log10(x+1e-25)         # offset keeps zero-valued pixels finite
        if self.mean is not None and self.std is not None:
            x = (x - self.mean[:, None, None]) / (self.std[:, None, None] + 1e-21)
        # zero inputs in the central disc
        x = x * self.mask.numpy()
        return {
            # always float32 so batches reach the model with a consistent dtype
            "x": torch.from_numpy(x.astype(np.float32, copy=False)),
            "y": torch.from_numpy(y.astype(np.float32, copy=False)),
            "mask": self.mask.clone(),    # torch float32 (1,H,W)
            "subid": sid
        }

packs = find_packs(DATA_DIR)
# Numeric (not lexicographic) order, so "9" precedes "10"; the train/test
# split below is computed on this ordering.
all_subids = sorted(packs, key=int)
len(all_subids), all_subids[:5]

In [None]:
# Same architecture as during training. Every parameter and buffer is then
# overwritten by the (strict) load_state_dict below, so no pretrained encoder
# weights need to be downloaded here.
model = smp.Unet(
        encoder_name="resnet34",
        encoder_weights=None,          # weights come from the checkpoint below
        in_channels=3,                 # the three velocity-bin brightness maps (VEL_BINS)
        classes=2,                     # 2 output channels (u, v)
        activation=None                # regression: keep raw outputs
    ).to(DEVICE)
# Alternative checkpoints/architectures:
#model = UNetDropout(in_channels=3,out_channels=4,p=0.2).to(DEVICE)
#ckpt = torch.load(os.path.join(CHECKPOINTS_DIR, "fpn_cgm_best.pt"), map_location=DEVICE, weights_only=False)

# weights_only=False unpickles arbitrary Python objects: only load trusted checkpoints.
ckpt = torch.load(os.path.join(CHECKPOINTS_DIR, "unet_cgm_best.pt"), map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt["model"])
# Inference only: eval() makes BatchNorm use the loaded running statistics
# (instead of per-batch stats, which would also overwrite them) and disables dropout.
model.eval();


In [None]:
# 2) Split by galaxy (subid) so no galaxy contributes samples to both sets.
#    Every subid is its own group. With the same SEED, test_size and sorted
#    subid list this is meant to reproduce the training-time split
#    (TODO confirm against the run that produced the checkpoint).
idx = np.arange(len(all_subids))
groups = np.array([int(sid) for sid in all_subids])
splitter = GroupShuffleSplit(n_splits=1, test_size=0.20, random_state=SEED)
train_idx, test_idx = next(splitter.split(idx, groups=groups))
train_subids = [all_subids[k] for k in train_idx]
test_subids  = [all_subids[k] for k in test_idx]
print(f"Train galaxies: {len(train_subids)} | Test galaxies: {len(test_subids)}")

# 3) Input normalization: prefer the statistics saved with the checkpoint;
#    otherwise recompute them from the training galaxies only.
if not ("mean" in ckpt and "std" in ckpt):
    mean, std = compute_input_norm(packs, train_subids, compression=COMPRESSION)
else:
    mean = np.array(ckpt["mean"], dtype=np.float32)
    std = np.array(ckpt["std"], dtype=np.float32)
print("Input mean:", mean, "std:", std)

# 4) Datasets / loaders: identical preprocessing; only training batches are shuffled.
train_ds = GalaxyPackDataset(packs, train_subids, mean=mean, std=std,
                             compression=COMPRESSION, r_mask=R_MASK)
test_ds  = GalaxyPackDataset(packs, test_subids, mean=mean, std=std,
                             compression=COMPRESSION, r_mask=R_MASK)

_loader_opts = dict(batch_size=BATCH_SIZE,
                    num_workers=NUM_WORKERS,
                    pin_memory="cuda" in DEVICE,
                    persistent_workers=NUM_WORKERS > 0)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_opts)
test_loader  = DataLoader(test_ds, shuffle=False, **_loader_opts)


In [None]:
def add_quiver(ax,u,v,scale,step=4,**quiver_kwargs):
    """
    Overlay a subsampled 2-D vector field (u, v) on ``ax`` via ``ax.quiver``.

    Arrows are placed at every ``step``-th pixel along both axes, at pixel
    coordinates (x = column index, y = row index).

    Args:
        ax: matplotlib Axes to draw on.
        u, v: 2-D arrays of the same shape; x- and y-components of the field.
        scale: forwarded to ``ax.quiver`` (larger -> shorter arrows).
        step: subsampling stride in pixels.
        **quiver_kwargs: extra keywords for ``ax.quiver``; ``color`` defaults
            to 'white'.

    Returns:
        Whatever ``ax.quiver`` returns (the Quiver artist, e.g. for quiverkey).

    Raises:
        ValueError: if u and v have different shapes.
    """
    u = np.asarray(u)
    v = np.asarray(v)
    if u.shape != v.shape:
        raise ValueError(f"u and v must have the same shape, got {u.shape} and {v.shape}")
    ny, nx = v.shape
    # Subsampled grid coordinates: Ys = row indices, Xs = column indices.
    Ys, Xs = np.mgrid[0:ny:step, 0:nx:step]
    quiver_kwargs.setdefault("color", "white")
    return ax.quiver(Xs, Ys, u[::step, ::step], v[::step, ::step],
                     scale=scale, **quiver_kwargs)

In [None]:
@torch.no_grad()
def get_batch(loader, batch_idx=0):
    """
    Run the global ``model`` on one batch of ``loader`` and return masked numpy arrays.

    Args:
        loader: DataLoader yielding dicts with keys "x", "y", "mask", "subid".
        batch_idx: 0-based index of the batch to use; earlier batches are
            loaded and discarded.

    Returns:
        (x, y, pred, mask, subids): x, y and pred as numpy arrays multiplied by
        the first sample's mask; mask as a numpy array of shape (B,1,H,W);
        subids as collated by the loader.

    Raises:
        IndexError: if batch_idx is negative or the loader has fewer than
            batch_idx + 1 batches.
    """
    if batch_idx < 0:
        raise IndexError(f"Batch {batch_idx} not found")
    # Inference: BatchNorm must use its stored running statistics (and must not
    # update them), and dropout must be off.
    model.eval()
    for i, batch in enumerate(loader):
        if i != batch_idx:
            continue
        x = batch["x"].to(DEVICE, non_blocking=True)
        pred = model(x)
        # y and the mask are only needed on the CPU; no device round-trip.
        m_np = batch["mask"].numpy()               # (B,1,H,W)
        mask = m_np[0, 0][None, None, :, :]        # (1,1,H,W), broadcast over batch/channels
        return (x.cpu().numpy() * mask,
                batch["y"].numpy() * mask,
                pred.cpu().numpy() * mask,
                m_np,
                batch["subid"])
    raise IndexError(f"Batch {batch_idx} not found")

x_np, y_np, p_np, m_np, subids = get_batch(test_loader, 2)

i = 12  # which example in the batch to show (must be < batch size)

fig, axs = plt.subplots(4, 3, figsize=(8, 12), dpi=200)

# Row 0: the three preprocessed input channels.
for j in range(3):
    axs[0, j].imshow(x_np[i, j], cmap="inferno", origin='lower')
    axs[0, j].set_title(f"Input ch{j}")
mask = m_np[i, 0]

# Rows 1-2: prediction / truth / residual for each velocity component.
# Pred and True share a symmetric colour range from the 2nd/98th percentiles
# of the WHOLE batch; the residual range comes from this example alone.
for uv, comp in enumerate(("u", "v")):
    row = uv + 1
    pred_map, true_map = p_np[i, uv], y_np[i, uv]
    batch_vals = np.concatenate([y_np[:, uv].ravel(), p_np[:, uv].ravel()])
    vlim = np.max(np.abs(np.percentile(batch_vals, [2, 98])))
    axs[row, 0].imshow(pred_map, cmap="RdBu", vmin=-vlim, vmax=vlim, origin='lower')
    axs[row, 0].set_title(f"Pred {comp}")
    axs[row, 1].imshow(true_map, cmap="RdBu", vmin=-vlim, vmax=vlim, origin='lower')
    axs[row, 1].set_title(f"True {comp}")
    resid = pred_map - true_map
    vlim = np.max(np.abs(np.percentile(resid.ravel(), [2, 98])))
    axs[row, 2].imshow(resid, cmap="coolwarm", vmin=-vlim, vmax=vlim, origin='lower')
    axs[row, 2].set_title(f"Residual {comp}")

axs[3, 2].axis("off")
axs[3, 2].text(0, 0.1, f"subid {subids[i]}", fontsize=12)

# Row 3: predicted and true velocity arrows over input channel 1.
scale = 4e3
for col, (vel, label) in enumerate(((p_np, 'pred quiver'), (y_np, 'true quiver'))):
    axs[3, col].imshow(x_np[i, 1], cmap="inferno", origin='lower')
    add_quiver(axs[3, col], vel[i, 0], vel[i, 1], scale)
    axs[3, col].set_title(label)

for ax in axs.ravel():
    ax.axis("off")
plt.tight_layout()
plt.show()