In [None]:
import os, re, glob, math, json, random, numpy as np
from typing import List, Tuple, Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import GroupShuffleSplit
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from math import floor
from architecture import UNetBasic, UNetDropout, UNetMultiHeadProfiles
from seg_models_multihead_train import *
import diptest
import tqdm
import pandas as pd

In [None]:
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Earlier glob-based pack location, kept for reference only:
#DATA_DIR = "/home/cj535/palmer_scratch/TNG50_cutouts/MW_sample_maps/packed_aug8_weightemitvel"
#PATTERN  = "TNG50_snap099_subid*_views10_aug8_C5_256x256.npy"
CATALOG = "/home/cj535/palmer_scratch/TNG50_cutouts/MW_sample_maps/catalog_pkls/coldvel_1e-22mask_C8_20_200_profile.pkl"
CHECKPOINTS_DIR = "/home/cj535/palmer_scratch/CNN_checkpoints/coldgas_multihead_C8"
CHECKPOINTS_NAME = "UNetMultihead"

# ---------------------------------------------------------------------------
# Tensor geometry
# ---------------------------------------------------------------------------
H = W = 256
IN_CHANNELS = 5   # mask + bolometric + 3 vel bins
TARGET_C = 3      # u,v,w
# Name kept as spelled ("ALEOTORIC"). Each unit adds another TARGET_C output
# channels; presumably these carry per-pixel aleatoric errors -- TODO confirm.
ALEOTORIC_ERRORS = 0
OUT_CHANNELS = TARGET_C * (1 + ALEOTORIC_ERRORS)
R_MASK = 15       # pixels

# ---------------------------------------------------------------------------
# Run settings (several are not referenced by the cells visible in this notebook)
# ---------------------------------------------------------------------------
BATCH_SIZE = 32
EPOCHS = 100
FREEZE_ENCODER_EPOCHS = 3
LR = 3e-4
NUM_WORKERS = 1
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#VEL_BINS = [(-300, -100), (-100, 100), (100, 300)]  # 3 input channels
COMPRESSION = 'log10'

# weights
LAMBDA_MAPS = 3.0
LAMBDA_MASS = 1.0
LAMBDA_FLOW = 1.0

# Radial shell grid: 20, 25, ..., 200. Both profile lengths start out equal to
# the number of shells (they are re-derived from the catalog further down).
shell_midpoints = np.arange(20, 205, 5)
K = L = len(shell_midpoints)

# Seed every RNG used by the notebook.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


catalog_path = CATALOG
# NOTE(security): read_pickle unpickles arbitrary Python objects -- only point
# CATALOG at files you produced yourself or otherwise trust.
df = pd.read_pickle(catalog_path)
# ensure subids are strings (the packs/dataset dictionaries below use string keys)
df.index = df.index.astype(str)

# Fail early, with a readable message, instead of deep inside the loading loop.
_required_columns = ("maps_path", "mass_profile", "flow_profile")
_missing_columns = [col for col in _required_columns if col not in df.columns]
if _missing_columns:
    raise KeyError(f"Catalog {catalog_path} is missing required columns: {_missing_columns}")
if not df.index.is_unique:
    # Duplicate subids would silently collapse into one dictionary entry.
    _dup = df.index[df.index.duplicated()].unique().tolist()
    raise ValueError(f"Catalog {catalog_path} contains duplicate subids: {_dup[:10]}")

# 1) load all galaxy packs into RAM from catalog
packs = {}            # subid -> float32 map stack (noted as (N, C, H, W) by the original author)
subid_to_Mprof = {}   # subid -> mass profile   (noted as length K)
subid_to_Fprof = {}   # subid -> flow profile   (noted as length L)

for sid, maps_path, mass_profile, flow_profile in zip(
    df.index, df["maps_path"], df["mass_profile"], df["flow_profile"]
):
    # copy=False avoids a second copy when the file is already float32.
    packs[sid] = np.load(maps_path).astype(np.float32, copy=False)
    subid_to_Mprof[sid] = mass_profile
    subid_to_Fprof[sid] = flow_profile

# Numeric (not lexicographic) ordering of the string subids.
all_subids = sorted(packs, key=int)
print(f"Loaded {len(all_subids)} galaxies into RAM (with 1D profiles).")

In [None]:
# 2) Hold out 20% of the galaxies. Each subid is both the sample and its own
#    group, so everything stored under one subid lands on one side of the split.
idx = np.arange(len(all_subids))
groups = np.array(list(map(int, all_subids)))
splitter = GroupShuffleSplit(n_splits=1, test_size=0.20, random_state=SEED)
# n_splits=1 -> the generator yields exactly one (train, test) index pair.
(train_idx, test_idx), = splitter.split(idx, groups=groups)
train_subids = [all_subids[pos] for pos in train_idx]
test_subids = [all_subids[pos] for pos in test_idx]
print(f"Train galaxies: {len(train_subids)} | Test galaxies: {len(test_subids)}")

# Profile lengths as actually stored in the catalog (taken from the first entry
# of each table); these replace the K/L values set in the config cell.
K, L = (
    len(next(iter(table.values())))
    for table in (subid_to_Mprof, subid_to_Fprof)
)

# 3) normalisation statistics are computed on the training galaxies only
M_mean, M_std = compute_profile_norm(subid_to_Mprof, train_subids, mode="log10")

def compute_z_norm(mapping, train_subids):
    """Per-element mean and standard deviation of the training profiles.

    Args:
        mapping: dict-like, subid -> profile (anything ``np.asarray`` accepts).
        train_subids: iterable of keys of ``mapping`` to include.

    Returns:
        ``(mean, std)`` as float32 arrays, reduced over the subid axis.
        ``1e-8`` is added to the std so later divisions cannot hit zero.
    """
    rows = []
    for subid in train_subids:
        profile = np.asarray(mapping[subid], np.float32)
        rows.append(profile[np.newaxis, :])   # add a leading "galaxy" axis
    stacked = np.concatenate(rows, axis=0)

    center = stacked.mean(0).astype(np.float32)
    spread = (stacked.std(0) + 1e-8).astype(np.float32)
    return center, spread

# Flow-profile statistics (plain z-score, training galaxies only).
F_mean, F_std = compute_z_norm(subid_to_Fprof, train_subids)

# Per-channel input statistics, also from the training galaxies only.
mean, std = compute_input_norm(
    packs,
    train_subids,
    compression=COMPRESSION,
    inC=IN_CHANNELS,
    channel0mask=True,
)
print("Input mean:", mean, "std:", std)

# 4) datasets / loaders
# Train and test datasets differ only in their subid list, so the shared
# settings (including the train-derived normalisation) are built once.
_dataset_kwargs = dict(
    mean=mean, std=std,
    r_mask=R_MASK, compression=COMPRESSION,
    in_channels=IN_CHANNELS,
    subid_to_Mprof=subid_to_Mprof,
    subid_to_Fprof=subid_to_Fprof,
    Mprof_mode="log10", Fprof_mode=None,
    M_mean=M_mean, M_std=M_std,
    F_mean=F_mean, F_std=F_std,
)
train_ds = GalaxyPackDataset(packs, train_subids, **_dataset_kwargs)
test_ds = GalaxyPackDataset(packs, test_subids, **_dataset_kwargs)

# Loader settings are identical too, apart from shuffling.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=("cuda" in DEVICE),          # page-locked host memory only when a GPU is used
    persistent_workers=(NUM_WORKERS > 0),   # only valid with at least one worker
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

In [None]:
# Rebuild the same architecture that was used during training, then restore
# the best checkpoint into it.
MODEL = UNetMultiHeadProfiles(in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, p=0.2, K=K, L=L)
model = MODEL.to(DEVICE)

# Build the file name from CHECKPOINTS_NAME so the config cell stays the single
# source of truth (resolves to ".../UNetMultihead_cgm_best.pt" with the current config).
ckpt_path = os.path.join(CHECKPOINTS_DIR, f"{CHECKPOINTS_NAME}_cgm_best.pt")
# NOTE(security): weights_only=False unpickles arbitrary objects; only load
# checkpoints you trust. The weights are read from the "model" entry.
ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)

# This notebook only runs inference: start from eval mode so normalisation and
# dropout layers are deterministic unless MC dropout is switched on explicitly.
model.eval()
model.load_state_dict(ckpt["model"])


In [None]:
def add_quiver(ax,u,v,scale,step=2,color='white'):
    """Draw a subsampled arrow field on ``ax`` and return the quiver artist.

    Args:
        ax: object exposing a matplotlib-style ``quiver`` method.
        u, v: 2-D arrays with the x- and y-components (``v.shape`` fixes the grid).
        scale: forwarded to ``ax.quiver(scale=...)``.
        step: keep every ``step``-th pixel along both axes.
        color: arrow colour.
    """
    # Pixel-index grids: rows -> y coordinate, columns -> x coordinate.
    n_rows, n_cols = v.shape
    row_idx, col_idx = np.indices((n_rows, n_cols))

    # Thin the grid and both components with the same stride.
    thin = np.s_[::step, ::step]
    return ax.quiver(
        col_idx[thin], row_idx[thin],
        u[thin], v[thin],
        scale=scale, color=color,
    )
@torch.no_grad()
def get_batch(loader, batch_idx=0,make_arrays=True,profiles=False):
    """Run the notebook-level ``model`` on one batch of ``loader``.

    The loader is iterated until position ``batch_idx``; that batch (a dict
    with keys "x", "y", "mask", "Mprof", "Fprof", "subid") is moved to
    ``DEVICE`` and passed through the model in eval mode.

    Args:
        loader: iterable of batch dicts.
        batch_idx: zero-based position of the batch to return.
        make_arrays: return NumPy arrays if True, torch tensors otherwise.
        profiles: also return true and predicted mass/flow profiles.

    Returns:
        ``(x, y*mask, y_pred*mask, mask, subids)``, followed by
        ``(Mprof, Fprof, Mprof_pred, Fprof_pred)`` when ``profiles`` is True.

    Raises:
        IndexError: the loader has no batch at ``batch_idx``.
    """
    model.eval()

    def _export(tensor):
        # Output conversion selected by `make_arrays`.
        return tensor.cpu().numpy() if make_arrays else tensor

    for position, batch in enumerate(loader):
        if position != batch_idx:
            continue

        x = batch["x"].to(DEVICE, non_blocking=True)
        y = batch["y"].to(DEVICE, non_blocking=True)
        m = batch["mask"].to(DEVICE, non_blocking=True)
        Mprof = batch["Mprof"].to(DEVICE)
        Fprof = batch["Fprof"].to(DEVICE)

        out = model(x)
        y_pred, Mprof_pred, Fprof_pred = out["maps"], out["mass_prof"], out["flow_prof"]

        mask = _export(m)
        # Truth and prediction are both masked before they are returned.
        result = (_export(x), _export(y) * mask, _export(y_pred) * mask, mask, batch["subid"])
        if profiles:
            result += (_export(Mprof), _export(Fprof), _export(Mprof_pred), _export(Fprof_pred))
        return result

    raise IndexError(f"Batch {batch_idx} not found")

In [None]:
def enable_mc_dropout_only(model: nn.Module):
    """Put ``model`` in eval mode, then re-activate only its dropout layers.

    Every module is first switched to eval (so e.g. batch-norm uses its stored
    statistics); afterwards each ``nn.Dropout`` / ``nn.Dropout2d`` /
    ``nn.Dropout3d`` instance is switched back to train mode so repeated
    forward passes are stochastic. The model is modified in place.
    """
    dropout_types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)

    model.eval()  # keep BN frozen
    stochastic_layers = [
        layer for layer in model.modules() if isinstance(layer, dropout_types)
    ]
    for layer in stochastic_layers:
        layer.train()

@torch.no_grad()
def predict_mc_multihead(model, x, mask=None, T=50, device="cuda"):
    """Monte-Carlo-dropout statistics for the maps and both profile heads.

    The model is switched to "eval + active dropout" (it is left in that state
    on return) and evaluated ``T`` times on the same input. Mean and sample
    standard deviation over the ``T`` passes are returned for every head.

    Args:
        model: module whose forward returns a dict with the keys "maps",
            "mass_prof" and "flow_prof".
        x: input batch; moved to ``device``.
        mask: optional tensor multiplied into every 4-D (map-shaped) output;
            it must broadcast against the map tensors.
        T: number of stochastic forward passes (>= 1). With ``T == 1`` the
            unbiased standard deviation is undefined and comes back as NaN.
        device: device on which the passes are run.

    Returns:
        dict with
            "mu", "std_epistemic"            -- statistics of ``out["maps"]``,
            "mass_mu", "mass_std_epistemic"  -- statistics of ``out["mass_prof"]``,
            "flow_mu", "flow_std_epistemic"  -- statistics of ``out["flow_prof"]``.
        NOTE(review): every map channel is reduced the same way; if the map
        head also emits extra (e.g. aleatoric) channels, confirm that a plain
        mean/std is the intended reduction for them.

    Raises:
        ValueError: ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")

    x = x.to(device, non_blocking=True)
    if mask is not None:
        mask = mask.to(device, non_blocking=True)

    enable_mc_dropout_only(model)

    preds_maps, preds_M, preds_F = [], [], []
    for _ in range(T):
        sample = model(x)
        preds_maps.append(sample["maps"])
        preds_M.append(sample["mass_prof"])
        preds_F.append(sample["flow_prof"])

    # New leading axis = MC sample index.
    preds_maps = torch.stack(preds_maps, 0)
    preds_M = torch.stack(preds_M, 0)
    preds_F = torch.stack(preds_F, 0)

    out = {
        # Map statistics were previously computed but never returned.
        "mu": preds_maps.mean(0),
        "std_epistemic": preds_maps.std(0, unbiased=True),
        "mass_mu": preds_M.mean(0),
        "mass_std_epistemic": preds_M.std(0, unbiased=True),
        "flow_mu": preds_F.mean(0),
        "flow_std_epistemic": preds_F.std(0, unbiased=True),
    }

    if mask is not None:
        # Only spatial (4-D) fields are masked; the profile outputs are left untouched.
        for key, value in list(out.items()):
            if value.ndim == 4:
                out[key] = value * mask

    return out

In [None]:
# True (solid) vs predicted (dotted) profiles for one sample of test batch 10:
# mass profile in blue, flow profile in red.
x_np, y_np, p_np, m_np, subids, Mprof, Fprof, Mprof_pred, Fprof_pred = get_batch(
    test_loader, 10, profiles=True
)
i = 0
for truth, prediction, colour in (
    (Mprof, Mprof_pred, 'blue'),
    (Fprof, Fprof_pred, 'red'),
):
    plt.plot(truth[i], color=colour)
    plt.plot(prediction[i], color=colour, linestyle=':')
plt.show()

In [None]:
# Side-by-side arrow plots of the predicted and true (u, v) maps for one sample
# of the first test batch, drawn over input channel 1.
x_np, y_np, p_np, m_np, subids, Mprof, Fprof, Mprof_pred, Fprof_pred = get_batch(
    test_loader, 0, profiles=True
)
i = 0  # which example in the batch to show
step = 4
scale = 4e3
fig, axs = plt.subplots(1, 2, figsize=(12,12), dpi=200)
for panel, field, title in zip(axs, (p_np, y_np), ('pred quiver', 'true quiver')):
    panel.imshow(x_np[i,1], cmap="inferno", origin='lower')
    add_quiver(panel, field[i,0], field[i,1], scale, step=step)
    panel.set_title(title)
plt.show()

In [None]:
from matplotlib.animation import FuncAnimation, PillowWriter

# Draw T stochastic (MC-dropout) predictions for ONE sample of test batch 25.
x_np, y_np, p_np, m_np, subids = get_batch(test_loader, 25, make_arrays=False)
i = 0
p_np = p_np.cpu().numpy()
y_np = y_np.cpu().numpy()

# Slicing with i:i+1 keeps the batch axis, so the input and its mask both have
# batch size 1 and belong to the same sample.
x = x_np[i:i + 1]
mask_i = m_np[i:i + 1]

T = 100
enable_mc_dropout_only(model)
MC = []
with torch.no_grad():  # inference only: do not build autograd graphs
    for t in tqdm.tqdm(range(T)):
        # The model returns a dict of heads; the velocity maps live under "maps".
        p = model(x)["maps"] * mask_i
        MC.append(p[0].cpu().numpy())
MC = np.array(MC)                   # one entry per MC pass along axis 0
uncertainty = np.std(MC, axis=0)    # spread over the MC passes
x_np = x_np.cpu().numpy()

In [None]:
# Static part of the animation frame: input channel 3 as background, the true
# field in pink, and a white quiver (qu2) that the animation updates per frame.
scale = 8e3
fig, ax = plt.subplots(figsize=(6,6),dpi=200)
step=4

ax.imshow(x_np[i,3], cmap="viridis",origin='lower')
# NOTE(review): the true field uses add_quiver's default step (2) while the
# predicted field below uses `step`; kept as in the original -- confirm intent.
qu1 = add_quiver(ax,y_np[i,0],y_np[i,1],scale,color='pink')
# Reuse the helper instead of a hand-copied meshgrid/quiver block, which also
# overwrote the notebook-level names x, y, u and v.
qu2 = add_quiver(ax,p_np[i,0],p_np[i,1],scale,step=step,color='white')
ax.axis("off")

def update(frame):
    """Animation callback: show MC sample ``frame`` in the ``qu2`` quiver.

    Reads the notebook-level ``MC`` (indexed as [sample, component, row, col]),
    ``step`` and ``qu2``; only the arrow components change, not their positions.
    """
    sample = MC[frame]
    qu2.set_UVC(sample[0][::step, ::step], sample[1][::step, ::step])

# Render one frame per MC sample and write the result as a GIF.
# NOTE(review): GIF frame delays are coarse, so 120 fps may not be honoured by
# viewers -- confirm the playback speed is what you expect.
fps = 120
anim = FuncAnimation(fig, update, frames=T, interval=1000 / fps, blit=False)
writer = PillowWriter(fps=fps)

gif_path = 'plots/errors5.gif'
# Saving fails if the target folder is missing, so create it on a fresh checkout.
os.makedirs(os.path.dirname(gif_path), exist_ok=True)
anim.save(gif_path, writer=writer)

In [None]:
# Quick look at the first MC sample of component 0, with one pixel marked.
plt.imshow(MC[0,0],origin='lower',cmap='RdBu',vmin=-300,vmax=300)
plt.scatter(135,170,color='k')
plt.show()

x_th = np.arange(MC.shape[0])

# Per-pixel dip-test p-value of the MC sample distribution, for components 0
# and 1 only; any further channel keeps its initial value of 1.
multimodal_map = np.ones_like(MC[0])
# Dedicated loop names: the original reused `i`, which is also the notebook-wide
# "which sample" index and was left pointing at the last image row.
for row in tqdm.tqdm(range(multimodal_map.shape[1])):
    for col in range(multimodal_map.shape[2]):
        for comp in (0, 1):
            dip, pval = diptest.diptest(MC[:, comp, row, col])
            multimodal_map[comp, row, col] = pval

In [None]:
# Optional overlay of the dip-test p-values (disabled):
#plt.imshow(multimodal_map[0],vmin=0,vmax=0.10,origin='lower',alpha=0.5,zorder=10)
#plt.colorbar()
i = 6
# MC standard deviation of component 0, with its own colour scale.
std_image = plt.imshow(uncertainty[0], cmap='viridis', origin='lower')
plt.colorbar(std_image)

In [None]:
batch_num = 2
i = 1  # which example in the batch to show

# Draw T stochastic (MC-dropout) predictions for sample `i` of the chosen batch.
x_np, y_np, p_np, m_np, subids = get_batch(test_loader, batch_num, make_arrays=False)
p_np = p_np.cpu().numpy()
y_np = y_np.cpu().numpy()

# i:i+1 keeps the batch axis and pairs the input with ITS OWN mask; the earlier
# code broadcast the whole-batch mask and then read entry 0, i.e. sample 0's
# mask applied to sample i.
x = x_np[i:i + 1]
mask_i = m_np[i:i + 1]

T = 500
enable_mc_dropout_only(model)
MC = []
with torch.no_grad():  # inference only: do not build autograd graphs
    for t in tqdm.tqdm(range(T)):
        # The model returns a dict of heads; the velocity maps live under "maps".
        p = model(x)["maps"] * mask_i
        MC.append(p[0].cpu().numpy())
MC = np.array(MC)                   # one entry per MC pass along axis 0
uncertainty = np.std(MC, axis=0)    # spread over the MC passes
x_np = x_np.cpu().numpy()
m_np = m_np.cpu().numpy()

# 4x4 diagnostic figure for sample `i`:
#   row 0: first four input channels
#   row 1: u -- prediction, truth, squared residual, MC variance
#   row 2: v -- same four panels
#   row 3: predicted / true arrow plots over input channel 1
fig, axs = plt.subplots(4, 4, figsize=(12,12),dpi=200)
for j in range(4):
    axs[0,j].imshow(x_np[i,j], cmap="inferno",origin='lower'); axs[0,j].set_title(f"Input ch{j}")

# (axes row, channel, label, vmax of the residual^2 panel, vmax of the variance panel)
component_rows = (
    (1, 0, "u", 300**2, 300**2),
    (2, 1, "v", 200**2, 50**2),
)
for row, uv, label, resid_vmax, var_vmax in component_rows:
    # Symmetric colour limit from the 2nd/98th percentiles of truth and
    # prediction over the whole batch.
    vlim = np.max(np.abs(np.percentile(np.concatenate([y_np[:,uv].ravel(), p_np[:,uv].ravel()]), [2,98])))
    axs[row,0].imshow(p_np[i,uv], cmap="RdBu", vmin=-vlim, vmax=vlim,origin='lower'); axs[row,0].set_title(f"Pred {label}")
    axs[row,1].imshow(y_np[i,uv], cmap="RdBu", vmin=-vlim, vmax=vlim,origin='lower'); axs[row,1].set_title(f"True {label}")
    # Squared error and MC variance use fixed colour ranges (see component_rows).
    axs[row,2].imshow((p_np[i,uv]-y_np[i,uv])**2, cmap="viridis",vmin=0, vmax=resid_vmax,origin='lower'); axs[row,2].set_title(f"Residual^2 {label}")
    axs[row,3].imshow((uncertainty[uv])**2, cmap="viridis",vmin=0, vmax=var_vmax,origin='lower'); axs[row,3].set_title(f"var {label}")

scale = 4e3
for col, field, title in ((0, p_np, 'pred quiver'), (1, y_np, 'true quiver')):
    axs[3,col].imshow(x_np[i,1], cmap="inferno",origin='lower')
    add_quiver(axs[3,col],field[i,0],field[i,1],scale)
    axs[3,col].set_title(title)

for ax in axs.ravel(): ax.axis("off")
plt.tight_layout(); plt.show()

In [None]:
x_np, y_np, p_np, m_np, subids = get_batch(test_loader,2)

i = 10  # which example in the batch to show

# 5x4 figure for sample `i`:
#   row 0: first four input channels
#   rows 1-3: u, v, w -- prediction, truth, residual
#   row 4: predicted / true arrow plots over input channel 1
# The w panels used to be drawn into row 2, on top of the v panels, and carried
# the title "Residual v"; each component now has its own row.
fig, axs = plt.subplots(5, 4, figsize=(12,15),dpi=200)
for j in range(4):
    axs[0,j].imshow(x_np[i,j], cmap="inferno",origin='lower'); axs[0,j].set_title(f"Input ch{j}")

for row, (uv, label) in enumerate(((0, "u"), (1, "v"), (2, "w")), start=1):
    # Symmetric colour limit from the 2nd/98th percentiles of truth and
    # prediction over the whole batch.
    vlim = np.max(np.abs(np.percentile(np.concatenate([y_np[:,uv].ravel(), p_np[:,uv].ravel()]), [2,98])))
    axs[row,0].imshow(p_np[i,uv], cmap="RdBu", vmin=-vlim, vmax=vlim,origin='lower'); axs[row,0].set_title(f"Pred {label}")
    axs[row,1].imshow(y_np[i,uv], cmap="RdBu", vmin=-vlim, vmax=vlim,origin='lower'); axs[row,1].set_title(f"True {label}")

    residual = p_np[i,uv]-y_np[i,uv]
    rlim = np.max(np.abs(np.percentile(residual.ravel(), [2,98])))
    if uv == 0:
        # u keeps its squared-residual view; the title now says so.
        axs[row,2].imshow(residual**2, cmap="viridis",vmin=0, vmax=rlim**2,origin='lower'); axs[row,2].set_title("Residual^2 u")
    else:
        axs[row,2].imshow(residual, cmap="coolwarm",vmin=-rlim, vmax=rlim,origin='lower'); axs[row,2].set_title(f"Residual {label}")

scale = 4e3
for col, field, title in ((0, p_np, 'pred quiver'), (1, y_np, 'true quiver')):
    axs[4,col].imshow(x_np[i,1], cmap="inferno",origin='lower')
    add_quiver(axs[4,col],field[i,0],field[i,1],scale)
    axs[4,col].set_title(title)

for ax in axs.ravel(): ax.axis("off")
plt.tight_layout(); plt.show()