# LSST Asteroid Trail — Evaluation Notebook
This notebook loads a trained model, runs inference on an HDF5 **test set**, matches the detections against a `test.csv` catalog of injected trails, and produces the following outputs:

- **Histogram of LSST stack detections vs. magnitude** and **trail length**  
- **Histogram of Neural Network detections vs. magnitude** and **trail length**  
- Summary counts for NN, stack, and combined detections

> Tip: Adjust the **paths** and **threshold** in the first code cell below.


In [None]:

# --- Configuration (edit these) ---
import os

# Input artifacts.
MODEL_CKPT = "./best_unet_tf_parity.pt"  # trained checkpoint to evaluate
TEST_H5 = "../DATA/test.h5"  # HDF5 test set; must contain datasets "images" and "masks"
CATALOG_CSV = "../DATA/test.csv"  # catalog of injected trails (positions + metadata)

# Detection parameters.
THRESHOLD = 0.80  # NN probabilities >= this value count as "detected" pixels
RADIUS_PX = 3  # half-width (pixels) of the window searched around each catalog (x, y)

# Tile edge length in pixels; should match the tiling used for training/eval.
TILE = 128

# Every file this notebook writes (histograms, predictions, CSV) goes here.
OUTDIR = "./eval_outputs"
os.makedirs(OUTDIR, exist_ok=True)

print("Configured. Edit paths above if needed.")


In [None]:

import os, math, time, numpy as np, torch, h5py, pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Device:", device)

def robust_stats_mad(arr):
    """Return ``(median, sigma)`` of *arr* using the median absolute deviation.

    ``sigma`` is ``1.4826 * MAD``, which approximates the standard deviation
    for normally distributed data while staying robust to outliers.

    Non-finite samples (NaN, +/-inf) are ignored: a single NaN would
    otherwise turn both statistics into NaN. For all-finite input the result
    is identical to taking the statistics over the whole array.

    Args:
        arr: Array-like of numeric samples, any shape.

    Returns:
        Tuple ``(median, sigma)`` of ``np.float32`` scalars. When *arr* holds
        no finite sample at all, the neutral pair ``(0.0, 1.0)`` is returned
        so that ``(x - median) / sigma`` leaves the data unchanged.
    """
    values = np.asarray(arr)
    finite = values[np.isfinite(values)]  # boolean indexing also flattens
    if finite.size == 0:
        return np.float32(0.0), np.float32(1.0)
    med = np.median(finite)
    mad = np.median(np.abs(finite - med))
    # The tiny offset keeps sigma strictly positive for constant input.
    sigma = 1.4826 * (mad + 1e-12)  # 1.4826*MAD ~ std for normal
    return np.float32(med), np.float32(sigma)

class H5TiledDataset(Dataset):
    """Stream fixed-size tiles from an HDF5 file with robust per-image scaling.

    The file must contain two datasets of identical shape ``(N, H, W)``:
    ``images`` and ``masks``. Each image is split into a grid of
    ``tile`` x ``tile`` patches (edge patches are zero-padded). Image patches
    are standardized with the median / MAD-sigma of a central crop of their
    parent image and then clipped to ``+/- k_sigma``.

    The HDF5 handle is opened lazily on first access so that every DataLoader
    worker gets its own handle; it is dropped when the dataset is pickled and
    can be released explicitly with :meth:`close`.

    Args:
        h5_path: Path of the HDF5 file.
        tile: Tile edge length in pixels (must be positive).
        k_sigma: Clip limit in sigma units; values <= 0 disable clipping.
        image_crop_for_stats: ``(height, width)`` of the central crop used to
            estimate the per-image statistics, or ``None`` for a square crop
            of at most 512 pixels.

    Raises:
        ValueError: If ``tile`` is not positive, ``images`` is not 3-D, or
            ``masks`` does not have the same shape as ``images``.
    """

    def __init__(self, h5_path, tile=128, k_sigma=5.0, image_crop_for_stats=(512,512)):
        self.h5_path = h5_path
        self.tile = int(tile)
        if self.tile <= 0:
            raise ValueError(f"tile must be a positive integer, got {tile!r}")
        self.k_sigma = float(k_sigma)
        self.image_crop_for_stats = image_crop_for_stats
        self._h5 = None
        self._stats_cache = {}
        # Only the shapes are needed here; the file is reopened lazily later.
        with h5py.File(self.h5_path, "r") as f:
            image_shape = tuple(f["images"].shape)
            mask_shape = tuple(f["masks"].shape)
        # Explicit exceptions instead of `assert`, which is stripped under -O.
        if len(image_shape) != 3:
            raise ValueError(f"'images' must have shape (N, H, W), got {image_shape}")
        if mask_shape != image_shape:
            raise ValueError(f"'masks' shape {mask_shape} does not match 'images' shape {image_shape}")
        self.N, self.H, self.W = image_shape
        Hb = math.ceil(self.H / self.tile)
        Wb = math.ceil(self.W / self.tile)
        # One entry per tile, ordered image-major, then row, then column.
        self.indices = [(i, r, c) for i in range(self.N) for r in range(Hb) for c in range(Wb)]

    def _ensure_open(self):
        """Open the HDF5 file on first use and bind the two datasets."""
        if self._h5 is None:
            self._h5 = h5py.File(self.h5_path, "r")
            self.x = self._h5["images"]
            self.y = self._h5["masks"]

    def close(self):
        """Release the HDF5 handle; it is reopened automatically on next access."""
        handle = getattr(self, "_h5", None)
        if handle is not None:
            handle.close()
            self._h5 = None
            self.x = None
            self.y = None

    def __getstate__(self):
        """Drop the open HDF5 handle when pickling (e.g. for DataLoader workers)."""
        state = self.__dict__.copy()
        state["_h5"] = None
        state.pop("x", None)
        state.pop("y", None)
        return state

    def _get_image_stats(self, i):
        """Return the cached ``(median, sigma)`` of image *i*, computing it once."""
        if i in self._stats_cache:
            return self._stats_cache[i]
        self._ensure_open()  # makes this helper usable on its own
        H, W = self.H, self.W
        if self.image_crop_for_stats is None:
            s = min(512, H, W)
            h0 = (H - s)//2; w0 = (W - s)//2; h1, w1 = h0+s, w0+s
        else:
            sH, sW = self.image_crop_for_stats
            sH = min(sH, H); sW = min(sW, W)
            h0 = (H - sH)//2; w0 = (W - sW)//2; h1, w1 = h0+sH, w0+sW
        crop = self.x[i, h0:h1, w0:w1].astype("float32")
        med, sigma = robust_stats_mad(crop)
        # Guard against a degenerate scale so the division below stays finite.
        if not np.isfinite(sigma) or sigma <= 0:
            sigma = np.float32(1.0)
        self._stats_cache[i] = (med, sigma)
        return self._stats_cache[i]

    def _normalize_and_clip(self, tile_arr, med, sigma):
        """Standardize *tile_arr* and clip to ``+/- k_sigma`` when ``k_sigma > 0``."""
        x = (tile_arr - med) / sigma
        if self.k_sigma > 0:
            x = np.clip(x, -self.k_sigma, +self.k_sigma)
        return x

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image_tile, mask_tile)`` tensors, each shaped ``(1, tile, tile)``."""
        self._ensure_open()
        i, r, c = self.indices[idx]
        t = self.tile
        r0, c0 = r*t, c*t
        r1, c1 = min(r0+t, self.H), min(c0+t, self.W)
        x = self.x[i, r0:r1, c0:c1].astype("float32")
        y = self.y[i, r0:r1, c0:c1].astype("float32")
        if x.shape[0] != t or x.shape[1] != t:
            # Edge tile: pad bottom/right with zeros up to the full tile size.
            # NOTE(review): the padding is added before normalization, so padded
            # image pixels become (0 - med) / sigma rather than 0. Kept as is;
            # confirm that this matches the preprocessing used for training.
            xp = np.zeros((t, t), np.float32)
            yp = np.zeros((t, t), np.float32)
            xp[:x.shape[0], :x.shape[1]] = x
            yp[:y.shape[0], :y.shape[1]] = y
            x, y = xp, yp
        med, sigma = self._get_image_stats(i)
        x = self._normalize_and_clip(x, med, sigma)
        return torch.from_numpy(x[None, ...]), torch.from_numpy(y[None, ...])


In [None]:

class AttentionGate(nn.Module):
    """Re-weight skip-connection features using a gating signal.

    Both inputs are projected to ``F_int`` channels with 1x1 convolutions,
    summed, passed through ReLU, and mapped back to ``F_l`` channels with a
    sigmoid. The resulting coefficients in (0, 1) multiply the skip features.

    Args:
        F_g: Channels of the gating signal ``g``.
        F_l: Channels of the skip features ``x``.
        F_int: Channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise(c_in, c_out, *tail):
            # 1x1 conv + batch norm, optionally followed by extra layers.
            return nn.Sequential(nn.Conv2d(c_in, c_out, 1, bias=False), nn.BatchNorm2d(c_out), *tail)

        self.W_g = pointwise(F_g, F_int)
        self.W_x = pointwise(F_l, F_int)
        self.psi = pointwise(F_int, F_l, nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled element-wise by the attention coefficients."""
        gate_feat = self.W_g(g)
        skip_feat = self.W_x(x)
        target_hw = skip_feat.shape[-2:]
        if gate_feat.shape[-2:] != target_hw:
            # Resample the gate to the skip resolution before adding them.
            gate_feat = F.interpolate(gate_feat, size=target_hw, mode='bilinear', align_corners=False)
        coeff = self.psi(self.relu(gate_feat + skip_feat))
        return x * coeff

def conv_bn_act(in_ch, out_ch, k=3, act='relu'):
    """Build a ``Conv2d -> BatchNorm2d -> activation`` stack.

    Args:
        in_ch: Input channels.
        out_ch: Output channels.
        k: Square kernel size; padding is ``k // 2``.
        act: Activation name, case-insensitive: ``'relu'``, ``'selu'`` or
            ``'elu'``. Any other name falls back to ReLU.

    Returns:
        ``nn.Sequential`` holding the three layers (the conv has no bias).
    """
    activations = {'relu': nn.ReLU, 'selu': nn.SELU, 'elu': nn.ELU}
    act_cls = activations.get(act.lower(), nn.ReLU)
    conv = nn.Conv2d(in_ch, out_ch, k, padding=k//2, bias=False)
    norm = nn.BatchNorm2d(out_ch)
    return nn.Sequential(conv, norm, act_cls(inplace=True))

class EncoderBlock(nn.Module):
    """Encoder stage: two conv/BN/activation layers, dropout, optional 2x pooling.

    Args:
        in_ch: Input channels.
        out_ch: Output channels of both convolutions.
        act: Activation name forwarded to ``conv_bn_act``.
        dropout: ``Dropout2d`` probability; values <= 0 disable dropout.
        max_pool: When truthy, downsample the output with ``MaxPool2d(2)``.
    """

    def __init__(self, in_ch, out_ch, act='relu', dropout=0.0, max_pool=True):
        super().__init__()
        self.conv1 = conv_bn_act(in_ch, out_ch, 3, act)
        self.conv2 = conv_bn_act(out_ch, out_ch, 3, act)
        if dropout > 0:
            self.drop = nn.Dropout2d(dropout)
        else:
            self.drop = nn.Identity()
        if max_pool:
            self.pool = nn.MaxPool2d(2)
        else:
            self.pool = nn.Identity()

    def forward(self, x):
        """Return ``(pooled, skip)``; ``skip`` is the feature map before pooling."""
        skip = self.drop(self.conv2(self.conv1(x)))
        return self.pool(skip), skip

class DecoderBlock(nn.Module):
    """Decoder stage: optional 2x transposed-conv upsampling, optional
    attention-gated skip concatenation, then two conv/BN/activation layers
    and dropout.

    Args:
        in_ch: Channels of the incoming feature map.
        skip_ch: Channels of the skip connection; ``0`` means no skip is used.
        out_ch: Output channels.
        act: Activation name for the two convolutions (see ``conv_bn_act``).
        dropout: ``Dropout2d`` probability; values <= 0 disable dropout.
        do_up: When false, the upsampling step is an identity.
        use_attn: When true (and ``skip_ch > 0``), gate the skip features
            with an ``AttentionGate``.
    """

    def __init__(self, in_ch, skip_ch, out_ch, act='relu', dropout=0.0, do_up=True, use_attn=True):
        super().__init__()
        self.skip_ch = skip_ch
        # Without a skip the upsampler keeps the channel count; with a skip it
        # maps to the skip's channels so both halves of the concat match.
        if skip_ch == 0:
            up_out_ch = in_ch
            conv1_in = up_out_ch
        else:
            up_out_ch = skip_ch
            conv1_in = up_out_ch + skip_ch

        if do_up:
            # NOTE(review): unlike conv_bn_act, this comparison is case-sensitive
            # and every name other than 'relu' (including 'elu') selects SELU.
            # Left untouched so the network matches the trained checkpoint.
            up_act = nn.ReLU(inplace=True) if act=='relu' else nn.SELU(inplace=True)
            self.up = nn.Sequential(
                nn.ConvTranspose2d(in_ch, up_out_ch, 3, stride=2, padding=1, output_padding=1, bias=False),
                nn.BatchNorm2d(up_out_ch),
                up_act,
            )
        else:
            self.up = nn.Identity()

        if use_attn and skip_ch > 0:
            self.attn = AttentionGate(up_out_ch, skip_ch, max(1, skip_ch//2))
        else:
            self.attn = nn.Identity()

        self.conv1 = conv_bn_act(conv1_in, out_ch, 3, act)
        self.conv2 = conv_bn_act(out_ch, out_ch, 3, act)
        if dropout > 0:
            self.drop = nn.Dropout2d(dropout)
        else:
            self.drop = nn.Identity()

    def forward(self, x, skip=None):
        """Upsample ``x``, merge the (gated) skip when configured, and refine."""
        x = self.up(x)
        if self.skip_ch > 0 and skip is not None:
            skip_hw = skip.shape[-2:]
            if x.shape[-2:] != skip_hw:
                # Align spatial sizes before gating and concatenating.
                x = F.interpolate(x, size=skip_hw, mode='bilinear', align_corners=False)
            gated = self.attn(x, skip)
            x = torch.cat([x, gated], dim=1)
        return self.drop(self.conv2(self.conv1(x)))

class UNetTFParity(nn.Module):
    """U-Net assembled from an architecture dict, ending in a sigmoid map.

    ``arch`` supplies per-stage lists: ``downFilters``, ``downActivation``,
    ``downDropout`` and ``downMaxPool`` for the encoder, and ``upFilters``,
    ``upActivation`` and ``upDropout`` for the decoder. The first decoder
    stage takes no skip connection; decoder stage ``i >= 1`` receives the skip
    of encoder stage ``len(enc) - 1 - i``.

    Args:
        in_ch: Input image channels.
        arch: Architecture dict (required).
        kernel_size: Kernel size of the final 1-channel convolution.
    """

    def __init__(self, in_ch=1, arch=None, kernel_size=5):
        super().__init__()
        assert arch is not None, "Provide architecture dict with keys like 'downFilters', 'downActivation', 'downDropout', 'downMaxPool', 'upFilters', 'upActivation', 'upDropout'."
        self.input_bn = nn.BatchNorm2d(in_ch)

        # Encoder: each stage consumes the previous stage's filter count.
        self.enc = nn.ModuleList()
        channels = in_ch
        encoder_spec = zip(arch["downFilters"], arch["downActivation"], arch["downDropout"], arch["downMaxPool"])
        for nf, act, drop, mp in encoder_spec:
            self.enc.append(EncoderBlock(channels, nf, act, drop, max_pool=mp))
            channels = nf

        # Decoder: stage 0 starts from the deepest encoder output with no skip.
        self.dec = nn.ModuleList()
        skip_chs = arch["downFilters"][:]
        decoder_spec = zip(arch["upFilters"], arch["upActivation"], arch["upDropout"])
        for i, (nf, act, drop) in enumerate(decoder_spec):
            if i == 0:
                in_ch_dec = skip_chs[-1]
                skip_ch = 0
            else:
                in_ch_dec = arch["upFilters"][i-1]
                skip_ch = skip_chs[-1-i]
            self.dec.append(DecoderBlock(in_ch_dec, skip_ch, nf, act, drop, do_up=True, use_attn=True))

        self.out_conv = nn.Conv2d(arch["upFilters"][-1], 1, kernel_size, padding=kernel_size//2)
        self.out_act  = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid output map (values in (0, 1)) for input ``x``."""
        x = self.input_bn(x)
        deepest = len(self.enc) - 1
        skips = []
        for i, blk in enumerate(self.enc):
            x, skip = blk(x)
            # The deepest stage feeds the decoder directly, so it has no skip.
            skips.append(None if i >= deepest else skip)
        x = self.dec[0](x, skips[-1])
        for i, blk in enumerate(self.dec[1:], start=1):
            x = blk(x, skips[-1 - i])
        return self.out_act(self.out_conv(x))

def load_model(ckpt_path):
    """Rebuild the U-Net from a checkpoint and return it in eval mode.

    The checkpoint must be a mapping with ``'arch'`` (architecture dict) and
    ``'state_dict'`` (weights). The model is created on the module-level
    ``device``.

    Args:
        ckpt_path: Path of the checkpoint file.

    Returns:
        The loaded ``UNetTFParity`` instance, on ``device``, in eval mode.

    Raises:
        ValueError: If the checkpoint has no ``'arch'`` entry.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=device)
    arch = checkpoint.get("arch", None)
    if arch is None:
        raise ValueError("Checkpoint missing 'arch'. Re-save your model with {'state_dict':..., 'arch':...}.")
    model = UNetTFParity(in_ch=1, arch=arch, kernel_size=5)
    model = model.to(device)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    print("Loaded model with architecture:", arch)
    return model


In [None]:

@torch.no_grad()
def predict_tiles_to_full(h5_path, loader, model, tile=128):
    """Assemble full-size per-panel predictions from tile predictions.

    Tiles must arrive from ``loader`` unshuffled, ordered panel by panel and
    row-major within each panel (``tile_index = row * Wb + col``). Each
    prediction is written straight into its place in the output, cropping
    edge tiles to the panel boundary, so no warm-up forward pass or tile
    buffer is needed and an empty loader simply yields zeros.

    Args:
        h5_path: HDF5 file whose ``images`` dataset gives the ``(N, H, W)``
            shape of the output.
        loader: Iterable of ``(image_batch, _)`` pairs with image batches
            shaped ``(B, 1, tile, tile)``.
        model: Network returning ``(B, C, h, w)`` maps; channel 0 is used.
            It is switched to eval mode.
        tile: Tile edge length in pixels used to build ``loader``.

    Returns:
        ``np.float32`` array of shape ``(N, H, W)``. Positions for which the
        loader supplied no tile stay zero; tiles beyond the ``N``-th panel
        are ignored.
    """
    model.eval()
    with h5py.File(h5_path, "r") as f:
        N, H, W = f["images"].shape
    Hb, Wb = math.ceil(H / tile), math.ceil(W / tile)
    tiles_per_panel = Hb * Wb
    full_preds = np.zeros((N, H, W), dtype=np.float32)
    model_device = next(model.parameters()).device  # looked up once, not per batch
    tile_idx = 0  # running index over all tiles of all panels
    for xb, _ in loader:
        out = model(xb.to(model_device))
        probs = out.detach()[:, 0]
        if tuple(probs.shape[-2:]) != (tile, tile):
            # The network changed the spatial size: resample to the tile grid.
            probs = F.interpolate(probs.unsqueeze(1), size=(tile, tile), mode='bilinear', align_corners=False).squeeze(1)
        for tile_img in probs.cpu().numpy():
            p, within_panel = divmod(tile_idx, tiles_per_panel)
            if p >= N:
                return full_preds
            r, c = divmod(within_panel, Wb)
            r0, c0 = r * tile, c * tile
            # Edge tiles extend past the panel; keep only the valid region.
            h, w = min(tile, H - r0), min(tile, W - c0)
            full_preds[p, r0:r0 + h, c0:c0 + w] = tile_img[:h, :w]
            tile_idx += 1
    return full_preds


In [None]:

def mark_nn_and_stack(csv_path, p_full, radius=3, thr=0.5):
    """Add detection flags to the injected-trail catalog.

    Two boolean columns are appended:

    * ``stack_detected`` - from ``stack_detection`` when that column exists
      (missing values count as *not* detected), otherwise from ``stack_mag``
      being non-null, otherwise all False.
    * ``nn_detected`` - True when any pixel of ``p_full[image_id]`` inside
      the square window ``[x - radius, x + radius] x [y - radius, y + radius]``
      (clipped to the panel) is ``>= thr``.

    Rows whose ``image_id`` lies outside ``p_full`` or whose ``x`` / ``y`` is
    missing or non-finite are left as not NN-detected. Coordinates are
    truncated to integers and clipped into the panel.

    Args:
        csv_path: Path or file-like object accepted by ``pandas.read_csv``;
            must provide columns ``image_id``, ``x`` and ``y``.
        p_full: Prediction array of shape ``(N, H, W)``.
        radius: Half-width of the search window in pixels.
        thr: Probability threshold for a pixel to count as detected.

    Returns:
        The catalog ``DataFrame`` with the two flag columns added.

    Raises:
        ValueError: If a required column is missing.
    """
    cat = pd.read_csv(csv_path).copy()
    need = {"image_id","x","y"}
    missing = need - set(cat.columns)
    if missing:
        raise ValueError(f"CSV missing columns: {missing}")

    if "stack_detection" in cat.columns:
        col = cat["stack_detection"]
        # NaN is truthy under astype(bool); mask it so missing != detected.
        cat["stack_detected"] = col.astype(bool) & col.notna()
    elif "stack_mag" in cat.columns:
        cat["stack_detected"] = ~cat["stack_mag"].isna()
    else:
        cat["stack_detected"] = False

    n_panels, H, W = p_full.shape
    xs_all = cat["x"].to_numpy(dtype=float)
    ys_all = cat["y"].to_numpy(dtype=float)
    has_position = np.isfinite(xs_all) & np.isfinite(ys_all)

    nn_hit = np.zeros(len(cat), dtype=bool)
    # `.indices` maps each image_id to positional row numbers, so the flags
    # land on the right rows regardless of the DataFrame's index labels.
    for pid, rows in cat.groupby("image_id").indices.items():
        pid = int(pid)
        if pid < 0 or pid >= n_panels:
            continue
        panel = p_full[pid]
        for row in rows:
            if not has_position[row]:
                continue
            x = min(max(int(xs_all[row]), 0), W - 1)
            y = min(max(int(ys_all[row]), 0), H - 1)
            y0, y1 = max(0, y - radius), min(H, y + radius + 1)
            x0, x1 = max(0, x - radius), min(W, x + radius + 1)
            nn_hit[row] = bool((panel[y0:y1, x0:x1] >= thr).any())
    cat["nn_detected"] = nn_hit
    return cat

def _choose_mag_field(df):
    """Return the preferred magnitude column present in *df*, or ``None``.

    Candidates are tried in order: ``PSF_mag``, ``integrated_mag``, ``mag``.
    """
    candidates = ("PSF_mag", "integrated_mag", "mag")
    return next((name for name in candidates if name in df.columns), None)

def plot_detect_hist(cat, field, bins=12, title=None, savepath=None):
    """Plot step histograms of *field* for all trails and each detected subset.

    Four outlines share one set of bin edges (derived from the finite values
    of ``cat[field]``): all injected trails, trails found by either method,
    trails found by the NN, and trails found by the LSST stack.

    Args:
        cat: Catalog with boolean columns ``nn_detected`` and
            ``stack_detected`` plus the column named by *field*.
        field: Column to histogram.
        bins: Bin specification passed to ``np.histogram_bin_edges``.
        title: Optional axes title.
        savepath: When given, the figure is also written to this path.
    """
    by_nn = cat["nn_detected"]
    by_stack = cat["stack_detected"]
    subsets = [
        ("Cumulative (NN ∪ LSST)", cat[by_nn | by_stack]),
        ("NN detected", cat[by_nn]),
        ("LSST stack detected", cat[by_stack]),
    ]

    # Common edges from the finite values keep all outlines comparable.
    all_vals = cat[field].to_numpy()
    edges = np.histogram_bin_edges(all_vals[np.isfinite(all_vals)], bins=bins)

    fig, ax = plt.subplots(figsize=(6.5,4.5))
    ax.hist(cat[field], bins=edges, histtype="step", label="All injected", alpha=0.7)
    for label, subset in subsets:
        ax.hist(subset[field], bins=edges, histtype="step", label=label)

    ax.set_xlabel(field.replace("_"," "))
    ax.set_ylabel("Count")
    if title:
        ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

    if savepath:
        fig.savefig(savepath, dpi=150, bbox_inches="tight")
        print("Saved:", savepath)
    plt.show()


In [None]:

# Build dataset/loader, load model, and run predictions
test_ds = H5TiledDataset(
    h5_path=TEST_H5,
    tile=TILE,
    k_sigma=5.0,
    image_crop_for_stats=(512,512),
)
# shuffle=False is required: tiles are reassembled in dataset order.
test_loader = DataLoader(
    dataset=test_ds,
    batch_size=64,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
)
print("Tiles:", len(test_ds), "| images:", test_ds.N, f"| panel size: {test_ds.H}x{test_ds.W}")

model = load_model(MODEL_CKPT)
p_full = predict_tiles_to_full(TEST_H5, test_loader, model, tile=TILE)

# Keep the raw probability maps so thresholds can be re-tuned without re-running.
np.save(os.path.join(OUTDIR, "p_full.npy"), p_full)
print("Predictions shape:", p_full.shape)


In [None]:

# Catalog matching + histograms
cat = mark_nn_and_stack(CATALOG_CSV, p_full, radius=RADIUS_PX, thr=THRESHOLD)

# Magnitude histogram (uses the first available magnitude column).
mag_field = _choose_mag_field(cat)
if mag_field is None:
    print("No magnitude column found in CSV. Skipping magnitude histogram.")
else:
    plot_detect_hist(cat, mag_field, bins=12, title=f"Detections vs {mag_field}",
                     savepath=os.path.join(OUTDIR, f"hist_vs_{mag_field}.png"))

# Trail-length histogram.
if "trail_length" in cat.columns:
    plot_detect_hist(cat, "trail_length", bins=12, title="Detections vs trail length",
                     savepath=os.path.join(OUTDIR, "hist_vs_trail_length.png"))
else:
    print("CSV missing 'trail_length' column. Skipping trail-length histogram.")

# Persist the catalog together with the two detection flags.
out_csv = os.path.join(OUTDIR, "test_with_detections.csv")
cat.to_csv(out_csv, index=False)
print("Saved catalog with flags:", out_csv)

# Summary counts. The NN count is deliberately NOT stored in a variable named
# `nn`: that would overwrite the module-level `torch.nn as nn` alias and break
# any later re-run of the model-definition cell.
tot = len(cat)
nn_hits = int(cat["nn_detected"].sum())
stk = int(cat["stack_detected"].sum())
cum = int((cat["nn_detected"] | cat["stack_detected"]).sum())
print(f"NN: {nn_hits}/{tot} | LSST stack: {stk}/{tot} | Cumulative: {cum}/{tot}")


In [None]:

# Optional: quick pixel-level PRF1 at chosen threshold
with h5py.File(TEST_H5, "r") as f:
    # Any non-zero label counts as a trail pixel. Boolean masks avoid the
    # uint8 pitfalls of `&` / `1 - x`, which only look at bit 0 and therefore
    # miscount masks whose labels are not exactly 0/1.
    gt_full = f["masks"][:].astype(np.uint8) > 0
bin_full = p_full >= THRESHOLD
tp = int(np.count_nonzero(bin_full & gt_full))
fp = int(np.count_nonzero(bin_full & ~gt_full))
fn = int(np.count_nonzero(~bin_full & gt_full))
# max(..., 1) / max(..., 1e-8) guard the ratios when a denominator is zero.
precision = tp / max(tp+fp, 1)
recall    = tp / max(tp+fn, 1)
f1        = 2*precision*recall / max(precision+recall, 1e-8)
print(f"Pixel-level P={precision:.4f}, R={recall:.4f}, F1={f1:.4f} (thr={THRESHOLD})")
