In [None]:
import os, re, glob, math, json, random, numpy as np
from typing import List, Tuple, Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import GroupShuffleSplit
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from seg_models_train import *
from architecture import UNetBasic, UNetDropout

In [None]:
# ---------------- config ----------------
# Folder with the .npy packs. Set the TNG50_DATA_DIR environment variable to point
# elsewhere instead of editing this hard-coded cluster path.
DATA_DIR   = os.environ.get(
    "TNG50_DATA_DIR",
    "/home/cj535/palmer_scratch/TNG50_cutouts/MW_sample_maps/packed_aug8_weightemitvel",
)
# NOTE(review): PATTERN is not passed to find_packs(DATA_DIR) below — confirm it is still used.
PATTERN    = "TNG50_snap099_subid*_views10_aug8_C5_256x256.npy"
H, W = 256, 256                 # map size in pixels
R_MASK = 20                     # mask radius in pixels
BATCH_SIZE = 16
EPOCHS = 100                    # NOTE: the training cell reassigns EPOCHS for its run
LR = 2e-4
NUM_WORKERS = 1
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
VEL_BINS = [(-300, -100), (-100, 100), (100, 300)]  # 3 input channels

# Seed every RNG source used in this notebook (torch.manual_seed also seeds CUDA devices).
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

In [None]:
# Baseline UNet with dropout: 4 input channels, 4 output channels, dropout p=0.2.
# NOTE(review): a later cell rebinds `model`, `opt` and `scaler` to an smp.FPN with
# in_channels=3 / classes=2 before training, so this model goes unused when cells run in
# order. Confirm which input channel count matches the dataset's 'x'. The training cell
# also reads `model.encoder`, which this class may not provide — TODO confirm.
model = UNetDropout(in_channels=4, out_channels=4, p=0.2)
model = model.to(DEVICE)
opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
scaler = torch.amp.GradScaler(enabled=("cuda" in DEVICE))  # no-op scaler on CPU

In [None]:
# Load the galaxy packs under DATA_DIR (find_packs presumably comes from the
# seg_models_train star import). Order the subid keys numerically, not lexicographically.
packs = find_packs(DATA_DIR)
all_subids = sorted(packs, key=int)
print(f"Loaded {len(all_subids)} galaxies into RAM.")

In [None]:
# 2) split by subid (no leakage): every galaxy, with all of its frames, lands wholly in
# train or wholly in test.
groups = np.array([int(s) for s in all_subids])
splitter = GroupShuffleSplit(n_splits=1, test_size=0.20, random_state=SEED)
# Split operates on indices; use subids as both samples and groups
idx = np.arange(len(all_subids))
train_idx, test_idx = next(splitter.split(idx, groups=groups))
train_subids = [all_subids[i] for i in train_idx]
test_subids  = [all_subids[i] for i in test_idx]

# Guard the split invariants: disjoint galaxies, and nothing dropped.
assert not set(train_subids) & set(test_subids), "galaxy leakage between train and test"
assert len(train_subids) + len(test_subids) == len(all_subids)
print(f"Train galaxies: {len(train_subids)} | Test galaxies: {len(test_subids)}")


# Intensity compression applied to the inputs before normalization (also used by the datasets).
compression = 'log10'
# 3) compute input normalization on train only
mean, std = compute_input_norm(packs, train_subids,compression=compression)
print("Input mean:", mean, "std:", std)

In [None]:
# Inspect one augmented frame of a training galaxy: pack channel 1 (`Hmid`) under the
# chosen compression, then both velocity components (channels 4 and 5).
i = 4          # index into train_subids
los = 9        # line-of-sight index
flip = 0
rot = 0
# Frame index. Assumes 8 augmentations per line of sight (flip in {0, 1}, rot in {0..3}) — TODO confirm.
n = los*8 + flip*4 + rot
frame = packs[train_subids[i]][n]

Hmid = frame[1]
if compression == 'log10':
    Hmid = np.log10(Hmid)
elif compression == 'sqrt':
    Hmid = np.sqrt(Hmid)

fig, ax = plt.subplots()
im = ax.imshow(Hmid, origin='lower')
ax.set_title(f"subid {train_subids[i]}, frame {n}: channel 1 ({compression})")
fig.colorbar(im, ax=ax)
plt.show()

# Plot both velocity components side by side.
vel_u = frame[4]
vel_v = frame[5]
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
for ax, comp, label in zip(axes, (vel_u, vel_v), ("vel_u (channel 4)", "vel_v (channel 5)")):
    im = ax.imshow(comp, origin='lower')
    ax.set_title(label)
    fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# 4) datasets / loaders
# Both splits share the train-only normalization stats (mean, std) and the compression.
train_ds = GalaxyPackDataset(packs, train_subids, mean=mean, std=std,
                             r_mask=R_MASK, compression=compression)
test_ds = GalaxyPackDataset(packs, test_subids, mean=mean, std=std,
                            r_mask=R_MASK, compression=compression)

# Identical loader settings, except that only the training loader shuffles.
train_loader, test_loader = (
    DataLoader(
        ds,
        batch_size=BATCH_SIZE,
        shuffle=is_train,
        num_workers=NUM_WORKERS,
        pin_memory=("cuda" in DEVICE),
        persistent_workers=(NUM_WORKERS > 0),
    )
    for ds, is_train in ((train_ds, True), (test_ds, False))
)

In [None]:
# Display the training-split input std (also printed, with the mean, where it was computed).
std

In [None]:
# Overlay the (u, v) field (pack channels 4 and 5) as a quiver plot on log10 of pack
# channel 3, for one frame of one galaxy. Arrows inside the central R_MASK disk are zeroed.
subid = '421555'
# Hard-coded subid: fail with a clear message if it is not in this dataset's packs.
if subid not in train_ds.packs:
    raise KeyError(f"subid {subid!r} not found in train_ds.packs")
print(train_ds.packs[subid].shape)

los = 9
flip = 0
rot = 1
i = los*8 + flip*4 + rot   # frame index within the pack; the next cell reuses `i`

u = train_ds.packs[subid][i,4]
v = train_ds.packs[subid][i,5]

# Pixel grid and radius from the map centre. Both come from the map itself rather than
# the global H, W, so the mask stays centred if the map size ever differs from the config.
ny, nx = v.shape
X, Y = np.meshgrid(np.arange(nx), np.arange(ny))
R = np.sqrt((X - nx/2)**2 + (Y - ny/2)**2)
mask = R > R_MASK   # same radius the datasets are built with (r_mask=R_MASK)

step = 4
sl = (slice(None, None, step), slice(None, None, step))   # subsample every `step` pixels
Xs, Ys = X[sl], Y[sl]
Vx, Vy = (mask*u)[sl], (mask*v)[sl]

fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.imshow(np.log10(train_ds.packs[subid][i,3]), origin='lower')
q = ax.quiver(Xs, Ys, Vx, Vy, color='white')
ax.set_title(f"subid {subid}, frame {i}: log10(channel 3) with (u, v)")
plt.show()

In [None]:
# Channel 1 of the frame selected in the previous cell (`subid`, `i`).
fig, ax = plt.subplots()
img = ax.imshow(train_ds.packs[subid][i, 1], origin='lower')
fig.colorbar(img, ax=ax)

In [None]:
# Channel 1 of one training input as returned by the dataset (item 'x'), presumably
# already compressed and normalized with mean/std.
sample_idx = min(1150, len(train_ds) - 1)   # 1150 is arbitrary; clamp so a small split doesn't IndexError
plt.imshow(train_ds[sample_idx]['x'][1], origin='lower')   # origin='lower' like every other map here
plt.colorbar();

In [None]:
# Pretrained FPN for dense regression. This rebinds `model`, `opt` and `scaler` from the
# UNetDropout cell above, so it is the model the training cell uses when run in order.
model = smp.FPN(
    encoder_name="resnet50",      # other encoders to try: "convnext_tiny", "efficientnet-b3", ...
    encoder_weights="imagenet",   # initialise the encoder from ImageNet-pretrained weights
    in_channels=3,                # three velocity-bin brightness maps (the UNetDropout cell used 4 — TODO confirm)
    classes=2,                    # two output maps: (u, v)
    activation=None,              # regression: return raw, un-activated outputs
)
model = model.to(DEVICE)

opt = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    weight_decay=1e-4,
)
scaler = torch.amp.GradScaler(
    enabled=("cuda" in DEVICE),   # disabled (no-op) when running on CPU
)

In [None]:
# 6) train
# Train `model` with train_one_epoch / evaluate (presumably from the seg_models_train star
# import). Keep the best checkpoint by validation loss, plus a periodic snapshot every 10 epochs.
EPOCHS = 10                 # NOTE: overrides EPOCHS = 100 from the config cell for this run
FREEZE_ENCODER_EPOCHS = 0   # warm-up epochs with the pretrained encoder frozen (0 = never freeze)

CKPT_DIR = "checkpoints_test"
# torch.save does not create parent directories; create the folder so a fresh run can save.
os.makedirs(CKPT_DIR, exist_ok=True)
# NOTE(review): filenames still say "unet" although the model built above is an FPN.
ckpt_path = os.path.join(CKPT_DIR, "unet_cgm_best.pt")
best_val = float("inf")


def _set_encoder_trainable(trainable: bool) -> None:
    """Set requires_grad on every parameter of model.encoder."""
    for p in model.encoder.parameters():
        p.requires_grad = trainable


def _save_checkpoint(path: str, epoch: int, val_loss: float) -> None:
    """Save model weights, train-split normalization stats and run config to `path`."""
    torch.save({
        "model": model.state_dict(),
        "mean": mean, "std": std,
        "epoch": epoch, "val_loss": val_loss,
        "config": {
            "R_MASK": R_MASK, "H": H, "W": W,
            "BATCH_SIZE": BATCH_SIZE, "LR": LR,
            # mean/std were computed on compressed inputs, so record the compression to reproduce them
            "compression": compression,
        },
    }, path)


# Freeze the encoder only when a warm-up is requested. The flag is set explicitly either way,
# so re-running this cell never inherits a frozen encoder from an earlier run.
_set_encoder_trainable(FREEZE_ENCODER_EPOCHS == 0)

for epoch in range(1, EPOCHS+1):
    if FREEZE_ENCODER_EPOCHS > 0 and epoch == FREEZE_ENCODER_EPOCHS + 1:
        _set_encoder_trainable(True)

    tr_loss, tr_mae = train_one_epoch(model, train_loader, opt, scaler)
    va_loss, va_mae = evaluate(model, test_loader)
    print(f"[{epoch:03d}/{EPOCHS}] train: loss {tr_loss:.5f}, mae {tr_mae:.5f} | "
          f"val: loss {va_loss:.5f}, mae {va_mae:.5f}")

    if va_loss < best_val:
        best_val = va_loss
        _save_checkpoint(ckpt_path, epoch, va_loss)
        print(f"  ✓ saved best → {ckpt_path}")
    # ----- save PERIODIC checkpoint every 10 epochs -----
    if epoch % 10 == 0:
        periodic_path = os.path.join(CKPT_DIR, f"unet_cgm_epoch{epoch:03d}.pt")
        _save_checkpoint(periodic_path, epoch, va_loss)
        print(f"  • saved periodic → {periodic_path}")

print("Done. Best val loss:", best_val)