In [None]:
import os, json, numpy as np
from collections import defaultdict
import re, glob, math, random
from typing import List, Tuple, Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import GroupShuffleSplit
import matplotlib.pyplot as plt

In [None]:
# ---------------- config ----------------
# Root directory containing the Halpha/ and velocity/ map subdirectories.
# Set the TNG50_MAPS_ROOT environment variable to point elsewhere without editing the
# notebook. os.path.join(..., "") normalizes the value so it always ends with a path
# separator, whichever form was supplied.
ROOT = os.path.join(
    os.environ.get("TNG50_MAPS_ROOT", "/home/cj535/palmer_scratch/TNG50_cutouts/MW_sample_maps/"),
    "",
)
H, W = 256, 256                 # nominal map size in pixels
R_MASK = 20                     # pixels
BATCH_SIZE = 8
EPOCHS = 100
LR = 1e-3
NUM_WORKERS = 4
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# (vmin, vmax) bins matching the *_Halpha_brightness_{vmin}_{vmax}.npy files -> 3 input channels
VEL_BINS = [(-300, -100), (-100, 100), (100, 300)]

# Seed every RNG this notebook touches (stdlib, NumPy legacy global, torch).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
# ---------------- file indexing ----------------
# Expected names, e.g.:
# TNG50_snap099_subid613192_view08_Halpha_brightness_-100_100.npy
# TNG50_snap099_subid613192_view08_velocity_u.npy
# TNG50_snap099_subid613192_view08_velocity_v.npy
re_bright = re.compile(r".*?_subid(?P<subid>\d+)_view(?P<view>\d+)_Halpha_brightness_(?P<vmin>-?\d+)_(?P<vmax>-?\d+)\.npy$")
re_vel    = re.compile(r".*?_subid(?P<subid>\d+)_view(?P<view>\d+)_velocity_(?P<cmp>[uv])\.npy$")
# Alternative vector-field inputs; swap one in for re_vel above (discover_samples globs
# every .npy in the velocity subdirectory, so no glob needs editing alongside):
#re_vel    = re.compile(r".*?_subid(?P<subid>\d+)_view(?P<view>\d+)_momentum_density_(?P<cmp>[uv])\.npy$")
#re_vel    = re.compile(r".*?_subid(?P<subid>\d+)_view(?P<view>\d+)_cold_velocity_(?P<cmp>[uv])\.npy$")

def discover_samples(
    root: str,
    vel_bins: Optional[List[Tuple[int, int]]] = None,
    halpha_subdir: str = "Halpha",
    vel_subdir: str = "velocity",
) -> Dict[Tuple[str, str], Dict[str, object]]:
    """
    Index the per-view map files under `root`, keeping only complete samples.

    Args:
        root: directory containing the brightness and velocity subdirectories. A
            trailing separator is optional (paths are built with os.path.join).
        vel_bins: (vmin, vmax) brightness bins every sample must provide; defaults to
            the module-level VEL_BINS.
        halpha_subdir: subdirectory of `root` scanned with `re_bright`.
        vel_subdir: subdirectory of `root` scanned with `re_vel`.

    Returns dict keyed by (subid, view) -> {
        "brightness": {(vmin, vmax): path},   # every bin found, may be more than vel_bins
        "vel_u": path, "vel_v": path
    }
    subid and view are the digit strings captured from the filename (e.g. "613192", "08").
    """
    needed_bins = set(VEL_BINS if vel_bins is None else vel_bins)
    entries: Dict[Tuple[str, str], Dict] = {}

    def _record(key: Tuple[str, str]) -> Dict:
        # Get-or-create the partial record for one (subid, view).
        return entries.setdefault(key, {"brightness": {}, "vel_u": None, "vel_v": None})

    # Glob every .npy and let the regexes filter: for the active patterns this yields the
    # same files as the old narrow globs, but switching re_vel to an alternative pattern
    # now works too. Sorted so the scan (and dict insertion order) is deterministic.
    for p in sorted(glob.glob(os.path.join(root, halpha_subdir, "*.npy"))):
        m = re_bright.match(p)
        if m is None:
            continue
        bin_key = (int(m["vmin"]), int(m["vmax"]))
        _record((m["subid"], m["view"]))["brightness"][bin_key] = p

    for p in sorted(glob.glob(os.path.join(root, vel_subdir, "*.npy"))):
        m = re_vel.match(p)
        if m is None:
            continue
        rec = _record((m["subid"], m["view"]))
        if m["cmp"] == "u":
            rec["vel_u"] = p
        elif m["cmp"] == "v":
            rec["vel_v"] = p

    # filter to full samples: all required bins plus both velocity components
    return {
        key: rec
        for key, rec in entries.items()
        if needed_bins.issubset(rec["brightness"]) and rec["vel_u"] and rec["vel_v"]
    }

# Build the (subid, view) -> file-path index of every complete sample under ROOT.
db = discover_samples(ROOT)

In [None]:
# Summarize the index rather than dumping every file path (one entry per (subid, view)).
n_galaxies = len({subid for subid, _view in db})
print(f"{len(db)} complete (subid, view) samples across {n_galaxies} galaxies")
next(iter(db.items()), None)  # one example record

In [None]:
# Labels of the 8 augmentations in the order _apply_augments emits them:
# the four quarter-turn rotations first, then each of them followed by a vertical flip.
AUG_NAMES = [
    f"rot{degrees}{suffix}"
    for suffix in ("", "_vflip")
    for degrees in (0, 90, 180, 270)
]

def _rotate_uv(u, v, k):
    """Rotate vector field (u,v) by k*90° CCW, returning (ur, vr).
    Rotations:
      k=0: ( u,  v)
      k=1: (v,  -u)
      k=2: (-u, -v)
      k=3: ( -v, u)
    """
    if k % 4 == 0:
        ur, vr = u, v
    elif k % 4 == 1:
        ur, vr = v, -u
    elif k % 4 == 2:
        ur, vr = -u, -v
    else:  # k == 3
        ur, vr = -v, u
    # Rotate the arrays spatially as well
    ur = np.rot90(ur, k, axes=(0, 1))
    vr = np.rot90(vr, k, axes=(0, 1))
    return ur, vr

def _apply_augments(x3, u, v):
    """Build the 8 rotation/flip augmentations of one view.

    Args:
        x3: (C, H, W) scalar channels (the brightness maps; C == 3 in this notebook,
            but any C works). Scalars are only moved on the pixel grid.
        u, v: (H, W) velocity components. They are moved on the grid AND their
            directions transformed: via _rotate_uv for the rotations, and v changes
            sign under the vertical flip.

    Returns:
        List of 8 float32 arrays, each (C + 2, H, W) laid out as [scalars..., u, v],
        ordered rot0, rot90, rot180, rot270, then each of those followed by a
        vertical flip along axis 0 of the maps (rows).

    Raises:
        ValueError: if x3 is not 3-D or u/v do not match its (H, W).
    """
    x3 = np.asarray(x3)
    if x3.ndim != 3:
        raise ValueError(f"x3 must be (C, H, W), got shape {x3.shape}")
    if not (np.shape(u) == np.shape(v) == x3.shape[1:]):
        raise ValueError(
            f"u {np.shape(u)} and v {np.shape(v)} must both match x3's (H, W) {x3.shape[1:]}"
        )

    # 4 rotations: scalars rotate as a stack in the (H, W) plane, vectors via _rotate_uv
    rotated = []
    for k in range(4):
        scalars = np.rot90(x3, k, axes=(1, 2))
        ur, vr = _rotate_uv(u, v, k)
        pack = np.concatenate([scalars, ur[None], vr[None]], axis=0)  # (C+2, H, W)
        rotated.append(pack.astype(np.float32, copy=False))

    # vertical flip of each rotated result: flip the rows (axis 1 of the packed array)
    flipped = []
    for pack in rotated:
        # copy() is required: np.flip returns a view and the sign flip below is in place
        packf = np.flip(pack, axis=1).copy()
        packf[-1] *= -1  # v changes sign under a vertical flip; u is unchanged
        flipped.append(packf)

    return rotated + flipped  # 8 arrays of shape (C+2, H, W)

def _load_view(rec: dict, vel_bins, dtype):
    """Load one view: brightness bins stacked as (len(vel_bins), H, W), plus u and v maps.

    `rec` is one discover_samples record:
    {"brightness": {(vmin, vmax): path}, "vel_u": path, "vel_v": path}.
    Files are opened memory-mapped and cast to `dtype` (no copy if already that dtype).
    """
    x3 = np.stack(
        [np.load(rec["brightness"][tuple(b)], mmap_mode='r').astype(dtype, copy=False)
         for b in vel_bins],
        axis=0,
    )
    u = np.load(rec["vel_u"], mmap_mode='r').astype(dtype, copy=False)
    v = np.load(rec["vel_v"], mmap_mode='r').astype(dtype, copy=False)
    return x3, u, v


def pack_galaxies_to_npy_with_aug(ROOT: str, OUT: str, db: dict, vel_bins=None, dtype=np.float32):
    """Pack every view of each galaxy, with its rotation/flip augmentations, into one .npy per subid.

    For each subid in `db` (views sorted numerically) this writes
        OUT/TNG50_snap099_subid{subid}_views{V}_aug{A}_C{C}_{H}x{W}.npy   shape (V*A, C, H, W)
    plus a .json sidecar describing the channel and augmentation layout. Samples are
    grouped per view: rows [i*A, (i+1)*A) are the A augmentations of the i-th view.
    With 3 velocity bins this is ..._aug8_C5_{H}x{W}.npy.

    Args:
        ROOT: unused; kept for backward compatibility (all paths come from `db`).
        OUT: output directory, created if missing.
        db: discover_samples index, (subid, view) -> {"brightness": {(vmin, vmax): path},
            "vel_u": path, "vel_v": path}.
        vel_bins: brightness bins stacked as the leading channels, in this order;
            defaults to (-300, -100), (-100, 100), (100, 300).
        dtype: dtype of the saved arrays.

    Raises:
        ValueError: if `db` is empty, a view's maps do not share the common (H, W), or
            the augmentation channel count disagrees with len(vel_bins) + 2.
    """
    if not db:
        raise ValueError("db is empty: no complete (subid, view) samples to pack")
    os.makedirs(OUT, exist_ok=True)
    vel_bins = vel_bins or [(-300, -100), (-100, 100), (100, 300)]
    expected_channels = len(vel_bins) + 2  # brightness bins + (u, v)

    # Group view ids by subid
    by_subid = defaultdict(list)
    for subid, view in db:
        by_subid[subid].append(view)

    # Discover H, W from any brightness map; every loaded map is checked against it below.
    any_key = next(iter(db))
    H, W = np.load(db[any_key]["brightness"][tuple(vel_bins[0])], mmap_mode='r').shape

    for subid in sorted(by_subid, key=int):
        vlist = sorted(by_subid[subid], key=int)
        V = len(vlist)
        arr = None  # allocated once the first view reveals the augmentation layout
        write_idx = 0

        for view in vlist:
            x3, u, v = _load_view(db[(subid, view)], vel_bins, dtype)
            if x3.shape[1:] != (H, W) or u.shape != (H, W) or v.shape != (H, W):
                raise ValueError(
                    f"subid {subid} view {view}: map shapes {x3.shape[1:]}, {u.shape}, {v.shape} "
                    f"differ from the common ({H}, {W})"
                )

            aug_list = _apply_augments(x3, u, v)
            if arr is None:
                A = len(aug_list)
                C = aug_list[0].shape[0]
                if C != expected_channels:
                    raise ValueError(
                        f"augmentations have {C} channels but vel_bins implies {expected_channels}"
                    )
                arr = np.zeros((V * A, C, H, W), dtype=dtype)

            for aug in aug_list:
                arr[write_idx] = aug  # (C, H, W)
                write_idx += 1

        # Save .npy and metadata json
        out_path = os.path.join(OUT, f"TNG50_snap099_subid{subid}_views{V}_aug{A}_C{C}_{H}x{W}.npy")
        np.save(out_path, arr)

        meta = {
            "subid": subid,
            "views": [int(vw) for vw in vlist],
            "shape": list(arr.shape),              # [V*A, C, H, W]
            "dtype": str(arr.dtype),
            "height": H, "width": W,
            "channels": [f"brightness_{lo}_{hi}" for lo, hi in vel_bins] + ["velocity_u", "velocity_v"],
            "velocity_bins": vel_bins,
            "augmentations": AUG_NAMES,
            "aug_order_note": "Per original view, data is appended in listed order."
        }
        with open(os.path.splitext(out_path)[0] + ".json", "w") as f:
            json.dump(meta, f, indent=2)

        print(f"✓ wrote {out_path}  (views={V}, aug per view={A}, total samples={V*A})")

In [None]:
# ---- usage example ----
# Re-index from the ROOT defined in the config cell. The path is not re-hard-coded
# here, where it could silently drift from the config value.
db = discover_samples(ROOT)
if not db:
    raise FileNotFoundError(f"no complete (subid, view) samples found under {ROOT!r}")
OUT = os.path.join(ROOT, "packed_aug8")
pack_galaxies_to_npy_with_aug(ROOT, OUT, db, vel_bins=VEL_BINS)

In [None]:
# ---- visual QA: middle brightness bin of one view per galaxy, velocity field overlaid ----
# Subids judged bad (presumably from inspecting these plots); they are only marked in the
# plot titles here.
bad_subids = [
    '372754',
    '372755',
    '414918',
    '429471',
    '438148',
    '447914',
    '471248',
    '475619',
    '478216',
    '498522',
    '516101',
    '517899',
    '523889',
    '525533',
    '526478',
    '535050',
    '543114',
    '546114',
    '559386',
    '574037',
    '576516',
    '580250',
    '613192',
]

QA_VIEW = '00'      # view shown for each galaxy
QUIVER_STEP = 4     # draw every QUIVER_STEP-th vector in each direction
MASK_RADIUS = 30    # vectors within MASK_RADIUS / pc_per_pix pixels of the centre are hidden
                    # NOTE(review): pc_per_pix = 500 / nx assumes the map spans 500 units (pc?) — confirm

bad_set = set(bad_subids)
subids_sorted = sorted({subid for subid, _view in db}, key=int)  # sorted once, not per iteration

for i, subid in enumerate(subids_sorted):
    rec = db.get((subid, QA_VIEW))
    if rec is None:
        print(f"[{i}] subid {subid}: no view {QA_VIEW}, skipped")
        continue

    # brightness bins (n_bins, ny, nx) and velocity components (ny, nx)
    bright = np.stack(
        [np.load(rec["brightness"][b], mmap_mode='r').astype(np.float32, copy=False) for b in VEL_BINS],
        axis=0,
    )
    u = np.load(rec["vel_u"], mmap_mode='r').astype(np.float32, copy=False)
    v = np.load(rec["vel_v"], mmap_mode='r').astype(np.float32, copy=False)

    # Hide the central vectors; the geometry comes from this map's own shape.
    ny, nx = v.shape
    X, Y = np.meshgrid(np.arange(nx), np.arange(ny))
    R = np.hypot(X - nx / 2, Y - ny / 2)
    pc_per_pix = 500 / nx
    mask = (R > MASK_RADIUS / pc_per_pix).astype(int)

    # subsample for the quiver
    sl = (slice(None, None, QUIVER_STEP), slice(None, None, QUIVER_STEP))

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    with np.errstate(divide='ignore', invalid='ignore'):  # log10(0) pixels -> -inf, expected
        ax.imshow(np.log10(bright[1]), vmin=-23, vmax=-18, origin='lower')
    ax.quiver(X[sl], Y[sl], (mask * u)[sl], (mask * v)[sl], color='white')
    flag = "  (flagged bad)" if subid in bad_set else ""
    ax.set_title(f"[{i}] subid {subid}, view {QA_VIEW}{flag}")
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per galaxy

In [None]:
import subprocess

# Subids judged bad (presumably from the visual QA cell); delete their packed arrays and
# JSON sidecars from every packed_aug8* directory under ROOT.
bad_subids = [
    '372754',
    '372755',
    '414918',
    '429471',
    '438148',
    '447914',
    '471248',
    '475619',
    '478216',
    '498522',
    '516101',
    '517899',
    '523889',
    '525533',
    '526478',
    '535050',
    '543114',
    '546114',
    '559386',
    '574037',
    '576516',
    '580250',
    '613192',
]

# Pure-Python glob + os.remove instead of `rm` through shell=True. No shell is involved,
# the path follows the configured ROOT, and a missing file or failed deletion is visible
# instead of vanishing into stderr.
for subid in bad_subids:
    pattern = os.path.join(ROOT, "packed_aug8*", f"TNG50_snap099_subid{subid}_views10_aug8_C5_256x256.*")
    matches = sorted(glob.glob(pattern))
    if not matches:
        print(f"subid {subid}: nothing to remove")
    for path in matches:
        os.remove(path)
        print(f"removed {path}")

In [None]:
# ---- inspect one packed file ----
# Locate the file by glob: the view count and channel count embedded in the name vary
# per galaxy. 613192 is also on the bad_subids list, which the cleanup cell deletes;
# in that case fall back to the first packed file that still exists in OUT.
subid = "613192"

candidates = sorted(glob.glob(os.path.join(OUT, f"TNG50_snap099_subid{subid}_views*_aug*_C*_*x*.npy")))
if not candidates:
    candidates = sorted(glob.glob(os.path.join(OUT, "TNG50_snap099_subid*.npy")))
if not candidates:
    raise FileNotFoundError(f"no packed .npy files found in {OUT!r}")
npy_path = candidates[0]
json_path = os.path.splitext(npy_path)[0] + ".json"
print("Using file:", npy_path)

# ---- read metadata ----
with open(json_path, "r") as f:
    meta = json.load(f)
subid = meta["subid"]  # keep in sync with the file actually loaded (may be the fallback)

print("Metadata keys:", list(meta.keys()))
print(json.dumps(meta, indent=2))

# ---- load data and inspect shape ----
arr = np.load(npy_path, mmap_mode='r')  # memory-maps file (doesn't load entire array)
print("Array shape:", arr.shape)
print("Channels order:", meta["channels"])
print("Augmentation order:", meta["augmentations"])

# ---- example: pick a single augmented view ----
i = 0
sample = arr[i]   # (C, H, W)
brightness = sample[:-2]  # brightness channels: everything except the trailing (u, v)
velocity_u, velocity_v = sample[-2], sample[-1]

print("Brightness map stats:", brightness.mean(), brightness.std())
print("Velocity_u/v stats:", velocity_u.mean(), velocity_v.mean())