In [None]:
import os, re, glob, math, json, random, numpy as np
from typing import List, Tuple, Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import GroupShuffleSplit
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from math import floor
from architecture import UNetBasic, UNetDropout
from seg_models_train import *
import diptest
import tqdm

In [None]:
DATA_DIR = "/home/cj535/palmer_scratch/TNG50_cutouts/MW_sample_maps/packed_aug8_weightemitvel"
PATTERN  = "TNG50_snap099_subid*_views10_aug8_C5_256x256.npy"
CHECKPOINTS_DIR = "/home/cj535/palmer_scratch/CNN_checkpoints/weightemit_vel"
H, W = 256, 256
R_MASK = 20                     # pixels
# Training settings; not used by the evaluation cells shown here.
BATCH_SIZE = 16
EPOCHS = 200
FREEZE_ENCODER_EPOCHS = 10
LR = 5e-4
NUM_WORKERS = 1
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): 3 velocity bins, but the dataset/model below use 4 input channels — confirm the mapping.
VEL_BINS = [(-300, -100), (-100, 100), (100, 300)]
COMPRESSION = 'log10'
IN_CHANNELS = 4       # input maps per sample (matches the model's in_channels)
TARGET_CHANNELS = 2   # target maps per sample: the (u, v) velocity components

# Seed every RNG this notebook touches so MC-dropout samples are reproducible across runs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

packs = find_packs(DATA_DIR, in_channels=IN_CHANNELS, target_c=TARGET_CHANNELS)
all_subids = sorted(packs.keys(), key=int)   # numeric order of the subhalo ids
len(all_subids), all_subids[:5]

In [None]:
# Rebuild the same architecture used during training and load its best checkpoint.
# (An smp.Unet(encoder_name="resnet34", encoder_weights="imagenet", in_channels=3, classes=2)
#  variant was also tried during development.)
model = UNetDropout(in_channels=4, out_channels=2, p=0.2).to(DEVICE)
CKPT_PATH = os.path.join(CHECKPOINTS_DIR, "UNetDropout_ae0_cgm_best.pt")
# weights_only=False unpickles arbitrary Python objects (and can execute code):
# only load checkpoints you produced yourself. The dict may also hold "mean"/"std" (checked below).
ckpt = torch.load(CKPT_PATH, map_location=DEVICE, weights_only=False)
load_status = model.load_state_dict(ckpt["model"])
# Deterministic inference by default; the MC-dropout cells re-enable dropout explicitly.
model.eval()
load_status  # missing/unexpected key report


In [None]:
# 2) Split by galaxy (subid) so that no galaxy contributes to both train and test.
#    Each subid is its own group; presumably SEED must match the training run so the
#    test galaxies are the same ones held out during training — confirm.
groups = np.array(list(map(int, all_subids)))
splitter = GroupShuffleSplit(n_splits=1, test_size=0.20, random_state=SEED)
idx = np.arange(len(all_subids))   # positions to split; the subids themselves act as groups
train_idx, test_idx = next(splitter.split(idx, groups=groups))
train_subids = [all_subids[k] for k in train_idx]
test_subids  = [all_subids[k] for k in test_idx]
print(f"Train galaxies: {len(train_subids)} | Test galaxies: {len(test_subids)}")

# 3) Input normalisation: reuse the statistics saved with the checkpoint when present,
#    otherwise recompute them from the training galaxies only.
if all(key in ckpt for key in ("mean", "std")):
    mean = np.array(ckpt["mean"], dtype=np.float32)
    std = np.array(ckpt["std"], dtype=np.float32)
else:
    mean, std = compute_input_norm(packs, train_subids, compression=COMPRESSION)
print("Input mean:", mean, "std:", std)

# 4) Datasets and loaders: identical settings for both splits, only train is shuffled.
train_ds = GalaxyPackDataset(
    packs, train_subids,
    mean=mean, std=std, r_mask=R_MASK, compression=COMPRESSION, in_channels=4,
)
test_ds = GalaxyPackDataset(
    packs, test_subids,
    mean=mean, std=std, r_mask=R_MASK, compression=COMPRESSION, in_channels=4,
)

loader_opts = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=("cuda" in DEVICE),
    persistent_workers=(NUM_WORKERS > 0),
)
train_loader = DataLoader(train_ds, shuffle=True, **loader_opts)
test_loader = DataLoader(test_ds, shuffle=False, **loader_opts)


In [None]:
def add_quiver(ax,u,v,scale,step=4,color='white'):
    """Overlay the vector field (u, v) on ``ax`` as arrows, subsampled every ``step`` pixels.

    Arrow tails sit at pixel indices (column index along x, row index along y) on every
    ``step``-th row and column. ``scale`` and ``color`` are passed straight to ``ax.quiver``.
    Returns the Quiver artist so callers can update it later (e.g. with ``set_UVC``).
    """
    n_rows, n_cols = v.shape
    grid_x, grid_y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    every = (slice(None, None, step),) * 2   # same stride along rows and columns
    return ax.quiver(grid_x[every], grid_y[every], u[every], v[every], scale=scale, color=color)
@torch.no_grad()
def get_batch(loader, batch_idx=0,make_arrays=True):
    """Run the global ``model`` in eval mode on batch number ``batch_idx`` of ``loader``.

    Each batch is expected to be a dict with "x", "y" and "mask" tensors and a "subid"
    entry. Inputs, targets and predictions are multiplied by each sample's own mask
    (channel 0 of ``batch["mask"]``).

    Args:
        loader: iterable of batches (e.g. a DataLoader); iterated from the start.
        batch_idx: 0-based index of the batch to evaluate.
        make_arrays: if True return NumPy arrays on the CPU, otherwise tensors on ``DEVICE``.

    Returns:
        ``(x, y, pred, mask, subid)``. With ``make_arrays=True``, ``mask`` is the full
        ``batch["mask"]`` as an array. Otherwise it is the (B, 1, H, W) tensor that was applied.

    Raises:
        IndexError: if ``batch_idx`` is negative or ``loader`` yields too few batches.
    """
    if batch_idx < 0:
        raise IndexError(f"Batch {batch_idx} not found")
    model.eval()
    for i, batch in enumerate(loader):
        if i != batch_idx:
            continue
        x = batch["x"].to(DEVICE, non_blocking=True)
        y = batch["y"].to(DEVICE, non_blocking=True)
        m = batch["mask"].to(DEVICE, non_blocking=True)
        pred = model(x)
        # Use each sample's own mask, shape (B, 1, H, W). Broadcasting sample 0's mask over
        # the batch is only correct when every sample shares the same mask.
        sample_mask = m[:, :1]
        x, y, pred = x * sample_mask, y * sample_mask, pred * sample_mask
        if make_arrays:
            return x.cpu().numpy(), y.cpu().numpy(), pred.cpu().numpy(), m.cpu().numpy(), batch["subid"]
        return x, y, pred, sample_mask, batch["subid"]
    raise IndexError(f"Batch {batch_idx} not found")

In [None]:
def enable_mc_dropout_only(model: nn.Module):
    """Put ``model`` in eval mode, then switch only its dropout layers back to train mode.

    This is the Monte-Carlo-dropout setup: normalisation layers such as BatchNorm keep using
    their running statistics, while dropout stays stochastic at inference time. Every dropout
    variant available in ``torch.nn`` is covered (Dropout, Dropout1d/2d/3d, AlphaDropout,
    FeatureAlphaDropout). Before this change only Dropout/Dropout2d/Dropout3d were matched.
    """
    # Dropout1d only exists in newer torch releases, hence getattr.
    dropout_types = tuple(
        cls
        for cls in (
            getattr(nn, name, None)
            for name in ("Dropout", "Dropout1d", "Dropout2d", "Dropout3d",
                         "AlphaDropout", "FeatureAlphaDropout")
        )
        if cls is not None
    )
    model.eval()  # keep BN frozen
    for m in model.modules():
        if isinstance(m, dropout_types):
            m.train()

@torch.no_grad()
def predict_mc(model, x, mask=None, T=100, device="cuda"):
    """Monte-Carlo-dropout prediction: run ``model`` T times with only dropout active.

    x:      (B, inC, H, W) tensor
    mask:   (B, 1, H, W) tensor with 1=valid region (optional); multiplied into every output
    T:      number of stochastic forward passes; must be >= 2 for a sample std
    device: device that x / mask are moved to (the model must already live there)

    returns dict with:
      'mu' (B,C,H,W)            mean prediction over the T passes
      'std_epistemic' (B,C,H,W) unbiased std of the predicted mean over the passes
      and, if the output looks like a mean+log-variance head (even channel count >= 4,
      first half = mean, second half = log-variance):
      'sigma_aleatoric' (B,C,H,W)  average predicted std, mean_t exp(0.5 * logvar_t)
      'std_total' (B,C,H,W)        sqrt(std_epistemic**2 + mean_t exp(logvar_t))

    Side effect: leaves ``model`` in eval mode with its dropout layers in train mode.
    Raises ValueError if T < 2.
    """
    if T < 2:
        raise ValueError(f"T must be >= 2 to estimate a sample std, got {T}")
    x = x.to(device, non_blocking=True)
    if mask is not None:
        mask = mask.to(device, non_blocking=True)

    enable_mc_dropout_only(model)
    preds = torch.stack([model(x) for _ in range(T)], 0)   # (T, B, C_or_2C, H, W)

    # detect aleatoric head
    outC = preds.shape[2]
    if outC % 2 == 0 and outC >= 4:
        C = outC // 2
        mu      = preds[:, :, :C]                    # (T,B,C,H,W)
        logvar  = preds[:, :, C:]                    # (T,B,C,H,W)
        mu_mean = mu.mean(0)                         # (B,C,H,W)
        std_epi = mu.std(0, unbiased=True)           # epistemic
        sigma_ale = torch.exp(0.5 * logvar).mean(0)  # average aleatoric std (reported)
        # Law of total variance: Var = Var_t[mu_t] + E_t[sigma_t^2]. Using (E_t[sigma_t])^2
        # instead would underestimate the aleatoric part (Jensen's inequality).
        var_ale = torch.exp(logvar).mean(0)
        std_total = torch.sqrt(std_epi**2 + var_ale)
        out = {"mu": mu_mean, "std_epistemic": std_epi,
               "sigma_aleatoric": sigma_ale, "std_total": std_total}
    else:
        # plain regression: no aleatoric
        out = {"mu": preds.mean(0), "std_epistemic": preds.std(0, unbiased=True)}

    # apply mask if provided
    if mask is not None:
        out = {k: v * mask for k, v in out.items()}

    return out

In [None]:
# Quick look at input channel 3 of example 10 from test batch 5 (masked, eval-mode prediction).
# Optional overlays: the true v map (cmap='RdBu', vmin=-400, vmax=400) or true/predicted
# arrows via add_quiver(ax, <u>, <v>, scale=scale, ...).
x_np, y_np, p_np, m_np, subids = get_batch(test_loader, 5, make_arrays=True)
i = 10
fig, ax = plt.subplots(figsize=(6, 6), dpi=200)
ax.imshow(x_np[i, 3], origin='lower', vmin=0.5, vmax=5)
ax.axis("off")
scale = 8e3  # arrow scale for the optional quiver overlays
plt.show()

In [None]:
from matplotlib.animation import FuncAnimation, PillowWriter
# Draw T Monte-Carlo-dropout samples for example i of test batch 25.
x_np, y_np, p_np, m_np, subids = get_batch(test_loader,25,make_arrays=False)
i=0
p_np = p_np.cpu().detach().numpy()
y_np = y_np.cpu().detach().numpy()
x = x_np[i:i+1]   # (1, C, H, W): keep the batch dimension for the model
# get_batch may return one mask shared by the batch (1,1,H,W) or one per sample (B,1,H,W);
# pick example i's mask in the latter case.
sample_mask = m_np[i:i+1] if m_np.shape[0] > 1 else m_np
T = 100
enable_mc_dropout_only(model)
MC = []
with torch.no_grad():  # sampling only: don't record an autograd graph for every pass
    for t in tqdm.tqdm(range(T)):
        p = model(x) * sample_mask
        MC.append(p[0].cpu().numpy())
MC = np.array(MC)                   # (T, C_out, H, W)
uncertainty = np.std(MC,axis=0)     # per-pixel std over the MC samples
x_np = x_np.cpu().detach().numpy()

In [None]:
from matplotlib.animation import FuncAnimation, PillowWriter

# Animate the MC-dropout samples over input channel 3: pink arrows = truth,
# white arrows = one MC sample per frame.
scale = 8e3
fig, ax = plt.subplots(figsize=(6,6),dpi=200)
step=4

ax.imshow(x_np[i,3], cmap="viridis",origin='lower')
qu1 = add_quiver(ax,y_np[i,0],y_np[i,1],scale,step=step,color='pink')
# Start from the deterministic (eval-mode) prediction; each frame overwrites these arrows.
qu2 = add_quiver(ax,p_np[i,0],p_np[i,1],scale,step=step,color='white')
ax.axis("off")

def update(frame):
    """Replace the white arrows' (U, V) with MC sample ``frame``, subsampled by ``step``."""
    u, v = MC[frame, 0], MC[frame, 1]
    qu2.set_UVC(u[::step, ::step], v[::step, ::step])   # update only U,V
    return (qu2,)

# NOTE: GIF frame delays are stored in 1/100 s, so very high fps may be rounded or clamped by viewers.
fps = 120
anim = FuncAnimation(fig, update, frames=T, interval=1000 / fps, blit=False)
writer = PillowWriter(fps=fps)
os.makedirs('plots', exist_ok=True)   # anim.save does not create missing directories
anim.save('plots/errors5.gif', writer=writer)

In [None]:
# One MC sample of u with a marker at a hand-picked pixel.
plt.imshow(MC[0,0],origin='lower',cmap='RdBu',vmin=-300,vmax=300)
plt.scatter(135,170,color='k')   # NOTE(review): hard-coded pixel (x=135, y=170), presumably a region of interest
plt.show()

# Per-pixel dip test (diptest package) over the T MC samples of each output component.
# Low p-values flag pixels whose MC-dropout sample distribution looks multimodal.
multimodal_map = np.ones_like(MC[0])          # (C_out, H, W) p-values
n_comp, n_rows, n_cols = multimodal_map.shape
# Loop indices are row/col rather than i/j so the notebook-wide example index `i` is not clobbered.
for row in tqdm.tqdm(range(n_rows)):
    for col in range(n_cols):
        for comp in range(n_comp):
            dip, pval = diptest.diptest(MC[:, comp, row, col])
            multimodal_map[comp, row, col] = pval

In [None]:
# Per-pixel MC-dropout std of the u component (channel 0) for the example sampled above.
# Optional: overlay multimodal_map[0] (vmin=0, vmax=0.10, alpha=0.5) to mark low dip-test p-values.
fig, ax = plt.subplots()
im = ax.imshow(uncertainty[0], origin='lower', cmap='viridis')
fig.colorbar(im, ax=ax)
ax.set_title("MC-dropout std, u")
plt.show()

In [None]:
# MC-dropout check for one test example: prediction, truth, squared residual and MC variance
# for each velocity component, plus predicted / true arrow fields.
batch_num = 2
i = 1  # which example in the batch to show

x_np, y_np, p_np, m_np, subids = get_batch(test_loader,batch_num,make_arrays=False)
p_np = p_np.cpu().detach().numpy()
y_np = y_np.cpu().detach().numpy()
x = x_np[i:i+1]   # (1, C, H, W): keep the batch dimension for the model
# get_batch may return one mask shared by the batch (1,1,H,W) or one per sample (B,1,H,W);
# pick example i's mask in the latter case.
sample_mask = m_np[i:i+1] if m_np.shape[0] > 1 else m_np
T = 500
enable_mc_dropout_only(model)
MC = []
with torch.no_grad():  # sampling only: don't record an autograd graph for every pass
    for t in tqdm.tqdm(range(T)):
        p = model(x) * sample_mask
        MC.append(p[0].cpu().numpy())
MC = np.array(MC)                   # (T, C_out, H, W)
uncertainty = np.std(MC,axis=0)     # per-pixel std over the MC samples
x_np = x_np.cpu().detach().numpy()
m_np = m_np.cpu().detach().numpy()

fig, axs = plt.subplots(4, 4, figsize=(12,12),dpi=200)
for j in range(4):
    axs[0,j].imshow(x_np[i,j], cmap="inferno",origin='lower'); axs[0,j].set_title(f"Input ch{j}")

# (component index, figure row, label, vmax of residual^2 panel, vmax of variance panel).
# NOTE: for v the residual^2 and variance panels use different limits (200^2 vs 50^2);
# the colourbars make the scales explicit.
panels = [(0, 1, "u", 300**2, 300**2), (1, 2, "v", 200**2, 50**2)]
for uv, row, name, resid_vmax, var_vmax in panels:
    # symmetric limit from the 2nd/98th percentiles of truth and prediction over the whole batch
    vlim = np.max(np.abs(np.percentile(np.concatenate([y_np[:,uv].ravel(), p_np[:,uv].ravel()]), [2,98])))
    axs[row,0].imshow(p_np[i,uv], cmap="RdBu", vmin=-vlim, vmax=vlim,origin='lower'); axs[row,0].set_title(f"Pred {name}")
    axs[row,1].imshow(y_np[i,uv], cmap="RdBu", vmin=-vlim, vmax=vlim,origin='lower'); axs[row,1].set_title(f"True {name}")
    im = axs[row,2].imshow((p_np[i,uv]-y_np[i,uv])**2, cmap="viridis", vmin=0, vmax=resid_vmax, origin='lower')
    axs[row,2].set_title(f"Residual^2 {name}")
    fig.colorbar(im, ax=axs[row,2])
    im = axs[row,3].imshow(uncertainty[uv]**2, cmap="viridis", vmin=0, vmax=var_vmax, origin='lower')
    axs[row,3].set_title(f"var {name}")
    fig.colorbar(im, ax=axs[row,3])

scale = 4e3
axs[3,0].imshow(x_np[i,1], cmap="inferno",origin='lower')
add_quiver(axs[3,0],p_np[i,0],p_np[i,1],scale)
axs[3,0].set_title('pred quiver')

axs[3,1].imshow(x_np[i,1], cmap="inferno",origin='lower')
add_quiver(axs[3,1],y_np[i,0],y_np[i,1],scale)
axs[3,1].set_title('true quiver')

for ax in axs.ravel(): ax.axis("off")
plt.tight_layout(); plt.show()

In [None]:
# Deterministic (eval-mode) prediction vs. truth for one example of test batch 2.
x_np, y_np, p_np, m_np, subids = get_batch(test_loader,2)

i = 1  # which example in the batch to show

fig, axs = plt.subplots(4, 4, figsize=(12,12),dpi=200)
for j in range(4):
    axs[0,j].imshow(x_np[i,j], cmap="inferno",origin='lower'); axs[0,j].set_title(f"Input ch{j}")

# Row 1: u (component 0). Pred/true limits are symmetric, from the 2nd/98th percentiles over the batch.
uv = 0
vlim = np.max(np.abs(np.percentile(np.concatenate([y_np[:,uv].ravel(), p_np[:,uv].ravel()]), [2,98])))
axs[1,0].imshow(p_np[i,uv], cmap="RdBu", vmin=-vlim, vmax=vlim,origin='lower'); axs[1,0].set_title("Pred u")
axs[1,1].imshow(y_np[i,uv], cmap="RdBu", vmin=-vlim, vmax=vlim,origin='lower'); axs[1,1].set_title("True u")
vlim = np.max(np.abs(np.percentile((p_np[i,uv]-y_np[i,uv]).ravel(), [2,98])))
# This panel shows the *squared* residual, so its title says so.
axs[1,2].imshow((p_np[i,uv]-y_np[i,uv])**2, cmap="viridis",vmin=0, vmax=vlim**2,origin='lower'); axs[1,2].set_title("Residual^2 u")

# Row 2: v (component 1), with the signed residual.
uv = 1
vlim = np.max(np.abs(np.percentile(np.concatenate([y_np[:,uv].ravel(), p_np[:,uv].ravel()]), [2,98])))
axs[2,0].imshow(p_np[i,uv], cmap="RdBu", vmin=-vlim, vmax=vlim,origin='lower'); axs[2,0].set_title("Pred v")
axs[2,1].imshow(y_np[i,uv], cmap="RdBu", vmin=-vlim, vmax=vlim,origin='lower'); axs[2,1].set_title("True v")
vlim = np.max(np.abs(np.percentile((p_np[i,uv]-y_np[i,uv]).ravel(), [2,98])))
axs[2,2].imshow(p_np[i,uv]-y_np[i,uv], cmap="coolwarm",vmin=-vlim, vmax=vlim,origin='lower'); axs[2,2].set_title("Residual v")

scale = 4e3
axs[3,0].imshow(x_np[i,1], cmap="inferno",origin='lower')
add_quiver(axs[3,0],p_np[i,0],p_np[i,1],scale)
axs[3,0].set_title('pred quiver')

axs[3,1].imshow(x_np[i,1], cmap="inferno",origin='lower')
add_quiver(axs[3,1],y_np[i,0],y_np[i,1],scale)
axs[3,1].set_title('true quiver')

for ax in axs.ravel(): ax.axis("off")
plt.tight_layout(); plt.show()