In [None]:
import os, json, numpy as np
from collections import defaultdict
import re, glob, math, random
from typing import List, Tuple, Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import GroupShuffleSplit
import matplotlib.pyplot as plt

In [None]:
# ---------------- config ----------------
# Directory that holds one sub-directory per map type (see discover_samples).
ROOT = "/home/cj535/palmer_scratch/TNG50_cutouts/MW_sample_maps/"
# Nominal map size. NOTE: later cells rebind H and W from the data they load.
H, W = 256, 256
R_MASK = 20                     # pixels
# Training hyper-parameters; not referenced by the cells visible in this part of the notebook.
BATCH_SIZE = 8
EPOCHS = 100
LR = 1e-3
NUM_WORKERS = 4
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# (vmin, vmax) windows of the velocity-binned brightness maps. discover_samples
# compares these tuples with the integers parsed from the file names, so they
# must match the on-disk names exactly.
VEL_BINS = [(-300, -100), (-100, 100), (100, 300)]  # 3 input channels

# Seed every RNG used by the notebook (Python, NumPy, PyTorch).
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)


In [None]:
# ---------------- file indexing ----------------
# File-name patterns. They are matched against the full path returned by glob,
# hence the leading lazy ".*?". Names they accept look like:
#   TNG50_snap099_subid613192_view08_Halpha_brightness_-100_100.npy   (velocity-binned)
#   TNG50_snap099_subid613192_view08_Halpha_brightness.npy            (bolometric)
#   TNG50_snap099_subid613192_view08_cold_gas_velocity_u.npy          (component u, v or w)
re_bright = re.compile(r".*?_subid(?P<subid>\d+)_view(?P<view>\d+)_(?P<line>[A-Za-z0-9]+)_brightness_(?P<vmin>-?\d+)_(?P<vmax>-?\d+)\.npy$")
# The dot before "npy" is escaped so that only a literal ".npy" suffix matches.
re_bright_bolo = re.compile(r".*?_subid(?P<subid>\d+)_view(?P<view>\d+)_(?P<line>[A-Za-z0-9]+)_brightness\.npy$")
# Patterns for other velocity products, kept for reference (only one re_vel may be active):
#re_vel    = re.compile(r".*?_subid(?P<subid>\d+)_view(?P<view>\d+)_velocity_(?P<cmp>[uv])\.npy$")
re_vel    = re.compile(r".*?_subid(?P<subid>\d+)_view(?P<view>\d+)_cold_gas_velocity_(?P<cmp>[uvw])\.npy$")
#re_vel    = re.compile(r".*?_subid(?P<subid>\d+)_view(?P<view>\d+)_momentum_density_(?P<cmp>[uv])\.npy$")
#re_vel    = re.compile(r".*?_subid(?P<subid>\d+)_view(?P<view>\d+)_cold_velocity_(?P<cmp>[uv])\.npy$")

# Emission lines; each one is also the name of a sub-directory of ROOT.
LINES = ["Halpha", "N2", "O3"]

def discover_samples(root: str):
    """
    Index the per-view map files found below ``root``.

    Scans ``root/<line>/`` for every line in LINES (binned and bolometric
    brightness maps) and ``root/cold_gas_velocity/`` (velocity components),
    parsing (subid, view) from each file name with the module-level regexes.

    Returns a dict keyed by (subid, view), both kept as the strings found in
    the file name. Each record contains:
        "vel_u", "vel_v", "vel_w":  path, or None when no such file was found
        "<line>_brightness":        {(vmin, vmax): path}  (key present only if
                                    at least one binned map was found)
        "<line>_bolobrightness":    path                  (key present only if
                                    the bolometric map was found)

    Only records that have an Halpha brightness map for every bin in VEL_BINS
    and both vel_u and vel_v are returned. vel_w, the bolometric maps and the
    other lines are NOT required, so callers must check for them themselves.
    """
    entries: Dict[Tuple[str, str], Dict] = {}

    def record_for(match) -> Dict:
        # Fetch the record of this (subid, view), creating it on first sight.
        return entries.setdefault(
            (match["subid"], match["view"]),
            {"vel_u": None, "vel_v": None, "vel_w": None},
        )

    # ---------- brightness ----------
    for line in LINES:
        # velocity-binned brightness maps
        for path in glob.glob(os.path.join(root, f"{line}", "*_brightness_*.npy")):
            hit = re_bright.match(path)
            if hit is None:
                continue
            vel_window = (int(hit["vmin"]), int(hit["vmax"]))
            record_for(hit).setdefault(f"{line}_brightness", {})[vel_window] = path

        # bolometric brightness map
        for path in glob.glob(os.path.join(root, f"{line}", "*_brightness.npy")):
            hit = re_bright_bolo.match(path)
            if hit is None:
                continue
            record_for(hit)[f"{line}_bolobrightness"] = path

    # ---------- velocities ----------
    for path in glob.glob(os.path.join(root, "cold_gas_velocity", "*cold_gas_velocity_*.npy")):
        hit = re_vel.match(path)
        if hit is None:
            continue
        component = hit["cmp"]  # re_vel only lets "u", "v" or "w" through
        record_for(hit)[f"vel_{component}"] = path

    # ---------- filter to full samples ----------
    needed_bins = set(VEL_BINS)

    def is_complete(rec: Dict) -> bool:
        halpha_bins = rec.get("Halpha_brightness")
        if halpha_bins is None:
            # no Halpha map at all for this view
            return False
        has_velocity = bool(rec.get("vel_u") and rec.get("vel_v"))
        return set(halpha_bins) >= needed_bins and has_velocity

    return {key: rec for key, rec in entries.items() if is_complete(rec)}

# Index every complete (subid, view) sample found under ROOT.
db = discover_samples(ROOT)

In [None]:
# Scratch check: a [-2:] slice returns the last two elements of the list.
x = [1,2,3,4,5,6]
x[-2:]

In [None]:
# Display the whole sample index (one record per (subid, view); output can be long).
db

In [None]:
# Labels written to the "augmentations" field of each packed array's JSON
# metadata. The order (four rotations, then the vertical flip of each) mirrors
# the order in which the augmentation helpers return their 8 arrays.
AUG_NAMES = [
    "rot0", "rot90", "rot180", "rot270",
    "rot0_vflip", "rot90_vflip", "rot180_vflip", "rot270_vflip"
]

def _rotate_uv(u, v, k):
    """Rotate vector field (u,v) by k*90° CCW, returning (ur, vr).
    Rotations:
      k=0: ( u,  v)
      k=1: (v,  -u)
      k=2: (-u, -v)
      k=3: ( -v, u)
    """
    if k % 4 == 0:
        ur, vr = u, v
    elif k % 4 == 1:
        ur, vr = v, -u
    elif k % 4 == 2:
        ur, vr = -u, -v
    else:  # k == 3
        ur, vr = -v, u
    # Rotate the arrays spatially as well
    ur = np.rot90(ur, k, axes=(0, 1))
    vr = np.rot90(vr, k, axes=(0, 1))
    return ur, vr

def _apply_augments(x3, u, v):
    """Given one view:
       x3: (3,H,W) brightness channels
       u,v: (H,W) velocity components
       Returns list of 8 arrays, each (5,H,W), with proper vector transforms.
    """
    outs = []
    inC, H, W = x3.shape[0], x3.shape[1], x3.shape[2]

    # 4 rotations
    for k in range(4):
        # scalars: just rotate
        '''b0 = np.rot90(x3[0], k, axes=(0, 1))
        b1 = np.rot90(x3[1], k, axes=(0, 1))
        b2 = np.rot90(x3[2], k, axes=(0, 1))'''
        bi = [np.rot90(x3[i], k, axes=(0, 1)) for i in range(x3.shape[0])]
        ur, vr = _rotate_uv(u, v, k)
        pack = np.stack(bi+[ur, vr], axis=0)  # (5,H,W)
        outs.append(pack.astype(np.float32, copy=False))

    # vertical flips of each rotated result: flip along axis=0 (rows, y)
    aug_vflips = []
    for k in range(4):
        base = outs[k]
        #b0, b1, b2, ur, vr = base
        bi = base[:inC]
        ur, vr = base[inC:]
        bif = [np.flip(bi_, axis=0) for bi_ in bi]
        #b0f = np.flip(b0, axis=0); b1f = np.flip(b1, axis=0); b2f = np.flip(b2, axis=0)
        urf = np.flip(ur, axis=0)                      # u unchanged sign for vertical flip
        vrf = -np.flip(vr, axis=0)                     # v changes sign under vertical flip
        #packf = np.stack([b0f, b1f, b2f, urf, vrf], axis=0)
        packf = np.stack(bif+[urf, vrf], axis=0)
        aug_vflips.append(packf.astype(np.float32, copy=False))

    return outs + aug_vflips  # 8 tensors of shape (5,H,W)

def _apply_augments_w(x3, u, v, w):
    """Given one view:
       x3: (3,H,W) brightness channels
       u,v,w: (H,W) velocity components
       Returns list of 8 arrays, each (5,H,W), with proper vector transforms.
    """
    outs = []
    inC, H, W = x3.shape[0], x3.shape[1], x3.shape[2]

    # 4 rotations
    for k in range(4):
        # scalars: just rotate
        bi = [np.rot90(x3[i], k, axes=(0, 1)) for i in range(x3.shape[0])]
        wr = np.rot90(w, k, axes=(0, 1))
        ur, vr = _rotate_uv(u, v, k)
        pack = np.stack(bi+[ur, vr, wr], axis=0)  # (5,H,W)
        outs.append(pack.astype(np.float32, copy=False))

    # vertical flips of each rotated result: flip along axis=0 (rows, y)
    aug_vflips = []
    for k in range(4):
        base = outs[k]
        #b0, b1, b2, ur, vr = base
        bi = base[:inC]
        ur, vr, wr = base[inC:]
        bif = [np.flip(bi_, axis=0) for bi_ in bi]
        #b0f = np.flip(b0, axis=0); b1f = np.flip(b1, axis=0); b2f = np.flip(b2, axis=0)
        wrf = np.flip(ur, axis=0)
        urf = np.flip(ur, axis=0)                      # u unchanged sign for vertical flip
        vrf = -np.flip(vr, axis=0)                     # v changes sign under vertical flip
        #packf = np.stack([b0f, b1f, b2f, urf, vrf], axis=0)
        packf = np.stack(bif+[urf, vrf, wrf], axis=0)
        aug_vflips.append(packf.astype(np.float32, copy=False))

    return outs + aug_vflips  # 8 tensors of shape (5,H,W)

def pack_galaxies_to_npy_with_aug(ROOT: str, OUT: str, db: dict, vel_bins=None, dtype=np.float32,bolomask=5e-22,line='Halpha',
                                 include_w=False,C=7):
    """Pack every galaxy in ``db`` into one augmented ``.npy`` plus a JSON sidecar.

    For each subid, all of its views are loaded, each view is expanded into 8
    augmentations (4 rotations and their vertical flips) and the result is
    saved as an array of shape (n_views * 8, C, H, W).

    Channel order of the saved array:
        [mask_bolo, brightness_bolo]        only when ``bolomask`` is not None
        brightness_<vmin>_<vmax>            one per entry of ``vel_bins``
        velocity_u, velocity_v
        velocity_w                          only when ``include_w`` is True

    Args:
        ROOT: unused; kept so existing calls keep working.
        OUT: output directory (created if missing).
        db: sample index as returned by ``discover_samples``.
        vel_bins: list of (vmin, vmax) bins; defaults to the three standard bins.
        dtype: dtype of the packed array.
        bolomask: threshold on the bolometric brightness map. Pixels at or
            below it get mask 0 and have their velocities zeroed. ``None``
            disables the mask and drops the two bolometric channels.
        line: emission line whose brightness maps are packed.
        include_w: also pack the third velocity component.
        C: expected number of channels; must equal the channel count implied
            by ``bolomask``, ``vel_bins`` and ``include_w``.

    Raises:
        ValueError: if ``db`` is empty or ``C`` does not match the channel count.
        KeyError: if a record lacks one of the files needed for the request.
    """
    vel_bins = [tuple(b) for b in (vel_bins or [(-300, -100), (-100, 100), (100, 300)])]
    use_bolo = bolomask is not None

    # Channel names are derived from the options so that the JSON metadata
    # always describes the array that is actually written.
    channels = []
    if use_bolo:
        channels += ["mask_bolo", "brightness_bolo"]
    channels += [f"brightness_{vmin}_{vmax}" for (vmin, vmax) in vel_bins]
    channels += ["velocity_u", "velocity_v"]
    if include_w:
        channels.append("velocity_w")

    if C != len(channels):
        raise ValueError(
            f"C={C} does not match the {len(channels)} channels implied by the "
            f"arguments: {channels}"
        )
    if not db:
        raise ValueError("db is empty: no samples to pack")

    os.makedirs(OUT, exist_ok=True)

    # Group by subid
    by_subid = defaultdict(list)
    for (subid, view) in db.keys():
        by_subid[subid].append(view)

    # Discover H,W from any brightness map
    any_key = next(iter(db))
    H, W = np.load(db[any_key][f"{line}_brightness"][vel_bins[0]], mmap_mode='r').shape

    A = 8  # augmentations per view

    for subid, views in sorted(by_subid.items(), key=lambda kv: int(kv[0])):
        vlist = sorted(views, key=lambda v: int(v))
        V = len(vlist)

        # Allocate (V*A, C, H, W)
        arr = np.zeros((V * A, C, H, W), dtype=dtype)

        write_idx = 0
        for view in vlist:
            rec = db[(subid, view)]

            # Scalar channels: optional (mask, bolometric map), then one map per bin
            bch = []
            if use_bolo:
                bolo = np.load(rec[f"{line}_bolobrightness"], mmap_mode='r').astype(dtype, copy=False)
                mask = (bolo > bolomask).astype(dtype)
                bch += [mask, bolo]
            for vbin in vel_bins:
                bch.append(np.load(rec[f"{line}_brightness"][vbin], mmap_mode='r').astype(dtype, copy=False))
            x3 = np.stack(bch, axis=0)

            # Load velocities
            u = np.load(rec["vel_u"], mmap_mode='r').astype(dtype, copy=False)
            v = np.load(rec["vel_v"], mmap_mode='r').astype(dtype, copy=False)
            if include_w:
                w = np.load(rec["vel_w"], mmap_mode='r').astype(dtype, copy=False)

            # Zero the velocities wherever the bolometric map is below threshold
            if use_bolo:
                u = u*mask
                v = v*mask
                if include_w:
                    w = w*mask

            # Build 8 augmentations and write them in AUG_NAMES order
            if include_w:
                aug_list = _apply_augments_w(x3, u, v, w)
            else:
                aug_list = _apply_augments(x3, u, v)
            for aug in aug_list:
                arr[write_idx] = aug  # (C,H,W)
                write_idx += 1

        # Save .npy and metadata json
        out_path = os.path.join(OUT, f"TNG50_snap099_subid{subid}_views{V}_aug8_C{C}_{H}x{W}.npy")
        np.save(out_path, arr)

        meta = {
            "subid": subid,
            "views": [int(v) for v in vlist],
            "shape": list(arr.shape),              # [V*8, C, H, W]
            "dtype": str(arr.dtype),
            "height": H, "width": W,
            "channels": channels,
            "velocity_bins": vel_bins,
            "augmentations": AUG_NAMES,
            "aug_order_note": "Per original view, data is appended in listed order."
        }
        with open(out_path.replace(".npy", ".json"), "w") as f:
            json.dump(meta, f, indent=2)

        print(f"✓ wrote {out_path}  (views={V}, aug per view=8, total samples={V*A})")

In [None]:
# ---- usage example ----
# Packs every complete sample into OUT, one .npy + .json pair per subid.
OUTROOT = "/home/cj535/palmer_scratch/TNG50_cutouts/MW_sample_maps/packed_arrays/"
db = discover_samples(ROOT)   # from earlier
OUT = os.path.join(OUTROOT, "packed_aug8_coldvel_1e-22mask")
# C=8 here: mask + bolometric map (bolomask is set), three velocity-binned
# maps, and the u, v, w velocity components (include_w=True).
pack_galaxies_to_npy_with_aug(ROOT, OUT, db, vel_bins=VEL_BINS,bolomask=1e-22,include_w=True,C=8)

In [None]:
# Values 20, 25, ..., 200 (the stop value 205 is excluded), i.e. 37 entries.
# NOTE(review): presumably the radial-shell grid of the 1D kinematics profiles
# loaded later - confirm against the code that produces those files.
shell_midpoints = np.arange(20,205,5)
shell_midpoints.shape[0]

In [None]:
# Quick-look check: for the first few galaxies (sorted by subid) plot the
# bolometric Halpha map with the masked in-plane velocity field on top, and
# the u component on its own.
vel_bins = [(-300, -100), (-100, 100), (100, 300)]
by_subid = defaultdict(list)
for (subid, view) in db.keys():
    by_subid[subid].append(view)
any_key = next(iter(db))
H, W = np.load(db[any_key]["Halpha_brightness"][vel_bins[0]], mmap_mode='r').shape
#vlist = sorted(views, key=lambda v: int(v))
# NOTE(review): V, C, A, arr and bad_subids are not used further down in this
# cell; they are kept because other cells may read them.
V = 10
C = 7#5
A = 8  # augmentations per view

# Allocate (V*A, C, H, W)
arr = np.zeros((V * A, C, H, W), dtype=np.float32)

bad_subids = [
    '372754',
    '372755',
    '414918',
    '429471',
    '438148',
    '447914',
    '471248',
    '475619',
    '478216',
    '498522',
    '516101',
    '517899',
    '523889',
    '525533',
    '526478',
    '535050',
    '543114',
    '546114',
    '559386',
    '574037',
    '576516',
    '580250',
    '613192',
]
bolomask=5e-22

N_PREVIEW = 10  # number of galaxies to plot
# Sort once, outside the loop; slicing also copes with fewer than N_PREVIEW galaxies.
subids_sorted = sorted(by_subid, key=int)
for i, subid in enumerate(subids_sorted[:N_PREVIEW]):
    print(i)
    print(subid)

    view = '00'
    if (subid, view) not in db:
        # db only holds complete samples, so a galaxy may lack this view
        print(f"skipping subid {subid}: view {view} is not in db")
        continue
    rec = db[(subid, view)]
    # Load 3 brightness channels
    bch = []
    for (vmin, vmax) in vel_bins:
        x = np.load(rec["Halpha_brightness"][(vmin, vmax)], mmap_mode='r').astype(np.float32, copy=False)
        bch.append(x)
    if bolomask is not None:
        x = np.load(rec["Halpha_bolobrightness"], mmap_mode='r').astype(np.float32, copy=False)
        bch.append(x)
    x3 = np.stack(bch, axis=0)  # (3+1,H,W)

    # Load velocities
    u = np.load(rec["vel_u"], mmap_mode='r').astype(np.float32, copy=False)
    v = np.load(rec["vel_v"], mmap_mode='r').astype(np.float32, copy=False)

    if bolomask is not None:
        # the bolometric map is the last channel of x3
        mask = x3[-1] > bolomask
        u = u*mask
        v = v*mask

    ny, nx = v.shape
    x = np.arange(nx)
    y = np.arange(ny)
    X, Y = np.meshgrid(x, y)
    R = np.sqrt((X-W/2)**2 + (Y-H/2)**2)
    # NOTE(review): 500 is presumably the physical width of the map and 30 an
    # inner radius in the same units - confirm units against the map maker.
    pc_per_pix = 500/W
    # keep only pixels outside the inner radius for the arrows
    mask = np.array(R > 30/pc_per_pix,dtype=int)

    step=4
    # subsample
    sl = (slice(None, None, step), slice(None, None, step))
    Xs, Ys = X[sl], Y[sl]
    Vx, Vy = (mask*u)[sl], (mask*v)[sl]

    # Background image: last channel of x3 (the bolometric map when bolomask
    # is set). Zero pixels become -inf under log10; silence that warning.
    with np.errstate(divide='ignore'):
        log_bg = np.log10(x3[-1])

    fig, ax = plt.subplots(1,1,figsize=(6,6))
    ax.imshow(log_bg,vmin=-23,vmax=-18,origin='lower')
    q = ax.quiver(Xs,Ys,Vx,Vy,color='white')
    plt.show()
    plt.imshow(u,vmin=-300,vmax=300,cmap='RdBu',origin='lower')
    plt.show()

In [None]:
import subprocess  # no longer needed by this cell; kept in case other cells rely on it
import glob, os

# Galaxies whose packed arrays must be discarded.
bad_subids = [
    '372754',
    '372755',
    '414918',
    '429471',
    '438148',
    '447914',
    '471248',
    '475619',
    '478216',
    '498522',
    '516101',
    '517899',
    '523889',
    '525533',
    '526478',
    '535050',
    '543114',
    '546114',
    '559386',
    '574037',
    '576516',
    '580250',
    '613192',
]

# WARNING: destructive. Deletes the packed .npy/.json files of every bad subid
# in all packed_aug8* directories. Done with glob + os.remove instead of
# `rm` through a shell, so nothing is passed to a shell and failures raise
# instead of being silently ignored.
for subid in bad_subids:
    pattern = f'/home/cj535/palmer_scratch/TNG50_cutouts/MW_sample_maps/packed_arrays/packed_aug8*/TNG50_snap099_subid{subid}_views10_aug8_C*_256x256.*'
    for path in glob.glob(pattern):
        if os.path.isfile(path):
            os.remove(path)
            print(f"removed {path}")


In [None]:
import os, glob, re, json
import numpy as np
import pandas as pd

# Build a per-galaxy catalog that links each packed 2D map file to its 1D
# kinematics profile, and pickle it.
ROOT       = "/home/cj535/palmer_scratch/TNG50_cutouts/MW_sample_maps/"
PACKED_DIR = os.path.join(ROOT, "packed_arrays/packed_aug8_coldvel_1e-22mask")
KIN_DIR    = os.path.join(ROOT, "1D_kinematics")

# Packed 2D maps (only the 8-channel variant is matched):
#   TNG50_snap099_subid571454_views10_aug8_C8_256x256.npy
re_packed = re.compile(
    r".*?_subid(?P<subid>\d+)_views(?P<views>\d+)_aug8_C8_(?P<H>\d+)x(?P<W>\d+)\.npy$"
)

# 1D kinematics:
#   TNG50_snap099_subid571454_emitting_1D_kinematics.npy
re_kin = re.compile(
    r".*?_subid(?P<subid>\d+)_emitting_1D_kinematics\.npy$"
)

rows = []

for maps_path in glob.glob(os.path.join(PACKED_DIR, "TNG50_snap099_subid*_views*_aug8_C8_*.npy")):
    m = re_packed.match(maps_path)
    if not m:
        continue

    subid  = m["subid"]
    views  = int(m["views"])
    H      = int(m["H"])
    W      = int(m["W"])
    meta_path = maps_path.replace(".npy", ".json")

    # corresponding 1D kinematics file (may or may not exist)
    kin_path = os.path.join(
        KIN_DIR,
        f"TNG50_snap099_subid{subid}_emitting_1D_kinematics.npy"
    )

    if not os.path.exists(kin_path):
        # No 1D kinematics for this subid → skip if you only want galaxies with both
        # If you prefer to keep them and store None, comment out the `continue`.
        # print(f"Skipping subid {subid}: no 1D kinematics file found.")
        continue

    kin = np.load(kin_path)  # (N_shells, 2) [mass, mass_flow]
    if kin.ndim != 2 or kin.shape[1] != 2:
        raise ValueError(f"Unexpected kinematics shape {kin.shape} in {kin_path}")

    mass_profile = kin[:, 0]  # (N_shells,)
    flow_profile = kin[:, 1]  # (N_shells,)

    rows.append({
        "subid": subid,
        "views": views,
        "H": H,
        "W": W,
        "maps_path": maps_path,
        "maps_meta_path": meta_path,
        "kinematics_path": kin_path,
        "mass_profile": mass_profile,   # numpy array
        "flow_profile": flow_profile    # numpy array
    })

# Fail with a clear message instead of the KeyError that set_index would
# raise on an empty frame, and never overwrite a catalog with an empty one.
if not rows:
    raise RuntimeError(
        f"No galaxy has both a packed C8 map in {PACKED_DIR} and a 1D kinematics "
        f"file in {KIN_DIR}; catalog not written."
    )

# Index by subid (kept as a string, as parsed from the file name).
df = pd.DataFrame(rows).set_index("subid").sort_index()

print(f"Built catalog for {len(df)} galaxies with BOTH 2D maps and 1D kinematics.")
#print(df.head())
PKLPATH = '/home/cj535/palmer_scratch/TNG50_cutouts/MW_sample_maps/catalog_pkls/coldvel_1e-22mask_C8_20_200_profile.pkl'
# make sure the target directory exists before writing
os.makedirs(os.path.dirname(PKLPATH), exist_ok=True)
df.to_pickle(PKLPATH)

In [None]:
# Look up the packed-map path of one galaxy; the catalog index holds subids as strings.
sid = '342447'
df.loc[sid, "maps_path"]