# OCT Speckle Denoising (Supervised)

Supervised OCT denoising pipeline that corrects frame jitter, builds pseudo-clean targets by aligning and median-averaging neighboring B-scans, trains a 2D UNet, and evaluates results using image-quality metrics and manual retinal-layer annotations.

In [3]:
# Cell 1: Imports, config, basic dataset checks
# Edit DATA_ROOT to point to your unzipped dataset folder before running.
import warnings
warnings.filterwarnings("ignore")

from pathlib import Path
import sys, os, json, hashlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# image I/O & processing
try:
    import tifffile
except Exception as e:
    raise ImportError("tifffile is required. Install with `pip install tifffile`") from e

from scipy.ndimage import shift as ndi_shift

# try to import scikit-image utilities (registration, filters). If missing, show helpful message.
_skimage_ok = True
try:
    from skimage.registration import phase_cross_correlation
    from skimage import filters, restoration, metrics
except Exception:
    _skimage_ok = False
    print("Warning: scikit-image not fully available. Install with `pip install scikit-image` to enable registration/filters.")

# Torch (used later for training)
try:
    import torch
    _torch_ok = True
except Exception:
    _torch_ok = False
    print("Warning: PyTorch not available. Install `torch`/`torchvision` if you plan to train a model in this notebook.")

# --------- USER CONFIG (edit these) ----------
DATA_ROOT = Path("./Data")        # <<< change this to your dataset folder (where Intensity.tif and Excel files live)
BASE_TIF_NAME = "Intensity.tif"   # base used in many annotation filenames
OUT_DIR = Path("./oct_denoise_outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# --------- quick helpers ----------
def md5(path, chunk=8192):
    h = hashlib.md5()
    with open(path, "rb") as f:
        for c in iter(lambda: f.read(chunk), b""):
            h.update(c)
    return h.hexdigest()

def find_files(root: Path):
    tif_paths = sorted(list(root.rglob("*.tif")) + list(root.rglob("*.tiff")))
    mat_paths = sorted(list(root.rglob("*.mat")))
    excel_paths = sorted(list(root.rglob("*.xls")) + list(root.rglob("*.xlsx")))
    return tif_paths, mat_paths, excel_paths

# --------- quick status print ----------
print("Python:", sys.version.splitlines()[0])
print("Working dir:", Path.cwd())
print("DATA_ROOT (edit if incorrect):", DATA_ROOT.resolve())
print("OUT_DIR:", OUT_DIR.resolve())
print("scikit-image available:", _skimage_ok)
print("PyTorch available:", _torch_ok, "| torch.cuda available:", (_torch_ok and torch.cuda.is_available()))

# If DATA_ROOT doesn't exist, stop early (avoid long glob)
if not DATA_ROOT.exists():
    print("\nERROR: DATA_ROOT does not exist ->", DATA_ROOT.resolve())
    print("Please set DATA_ROOT to the folder containing your unzipped dataset and re-run this cell.")
else:
    tifs, mats, excels = find_files(DATA_ROOT)
    print(f"\nFound {len(tifs)} TIFF(s), {len(mats)} .mat file(s), {len(excels)} Excel file(s).")
    if tifs:
        for t in tifs[:6]:
            try:
                stat = t.stat()
                print(" -", t.relative_to(DATA_ROOT), f"(size={stat.st_size} bytes, md5={md5(t)[:8]}...)")
            except Exception:
                print(" -", t.relative_to(DATA_ROOT))
    if excels:
        print("\nExample annotation files (first 12):")
        for e in excels[:12]:
            print(" -", e.relative_to(DATA_ROOT))
    # Save a tiny JSON summary for downstream cells
    summary = {
        "tif_count": len(tifs),
        "mat_count": len(mats),
        "excel_count": len(excels),
        "tifs": [str(p.relative_to(DATA_ROOT)) for p in tifs[:8]],
        "examples_excels": [str(p.relative_to(DATA_ROOT)) for p in excels[:12]]
    }
    with open(OUT_DIR / "dataset_discovery.json", "w") as f:
        json.dump(summary, f, indent=2)
    print("\nWrote dataset_discovery.json to", OUT_DIR / "dataset_discovery.json")


Python: 3.12.5 | packaged by Anaconda, Inc. | (main, Sep 12 2024, 13:22:57) [Clang 14.0.6 ]
Working dir: /Users/kartikgoyal/Desktop/Speckle_Noise_Reduction
DATA_ROOT (edit if incorrect): /Users/kartikgoyal/Desktop/Speckle_Noise_Reduction/Data
OUT_DIR: /Users/kartikgoyal/Desktop/Speckle_Noise_Reduction/oct_denoise_outputs
scikit-image available: True
PyTorch available: True | torch.cuda available: False

Found 1 TIFF(s), 1 .mat file(s), 520 Excel file(s).
 - Intensity.tif (size=136682861 bytes, md5=632aa6e9...)

Example annotation files (first 12):
 - Intensity.tif100_x.xlsx
 - Intensity.tif100_y.xlsx
 - Intensity.tif101_x.xlsx
 - Intensity.tif101_y.xlsx
 - Intensity.tif102_x.xlsx
 - Intensity.tif102_y.xlsx
 - Intensity.tif103_x.xlsx
 - Intensity.tif103_y.xlsx
 - Intensity.tif104_x.xlsx
 - Intensity.tif104_y.xlsx
 - Intensity.tif105_x.xlsx
 - Intensity.tif105_y.xlsx

Wrote dataset_discovery.json to oct_denoise_outputs/dataset_discovery.json


In [4]:
# Cell 2: Parse annotation Excel pairs, save annotations.json + summary, make histogram + overlay images
# Requires DATA_ROOT and OUT_DIR defined in previous cell.

import re
import json
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import tifffile

# safety checks
if not Path(DATA_ROOT).exists():
    raise FileNotFoundError(f"DATA_ROOT does not exist: {DATA_ROOT}")
OUT_DIR.mkdir(parents=True, exist_ok=True)

def find_xy_pairs(root: Path, base_hint="Intensity.tif"):
    """
    Find pairs of Excel files that are annotations:
    - prefer files matching <base_hint><idx>_x.xlsx / _y.xlsx
    - fallback to any files ending with _x.(xls|xlsx) with a corresponding _y file (matching digit group)
    Returns dict mapping slice_index -> list of (x_list, y_list) curves
    """
    root = Path(root)
    x_files = sorted([p for p in root.rglob("*.xlsx")] + [p for p in root.rglob("*.xls")])
    # filter x-files
    x_files = [p for p in x_files if re.search(r"_x(\.xls|\.xlsx)$", str(p.name), flags=re.IGNORECASE)]
    pairs = {}
    for fx in x_files:
        name = fx.name
        # try extract index in form base_hint<digits> or <digits>_x
        m = re.search(rf"{re.escape(base_hint)}(\d+)", name) or re.search(r"(\d+)(?=_x\.)", name) or re.search(r"(\d+)", name)
        if not m:
            # fallback: try to find any file with same prefix but _y replacement
            possible = list(root.rglob(name.replace("_x", "_y")))
            fy = possible[0] if possible else None
            idx = None
        else:
            idx = int(m.group(1))
            # build likely y filename candidates
            cand1 = fx.with_name(re.sub(r"_x(\.xls|\.xlsx)$", "_y\\1", name, flags=re.IGNORECASE))
            cand2 = fx.with_name(re.sub(r"_x", "_y", name, flags=re.IGNORECASE))
            fy = cand1 if cand1.exists() else (cand2 if cand2.exists() else None)
            if fy is None:
                # try search for any file containing same digit group and '_y'
                matches = list(root.rglob(f"*{idx}*_y.*"))
                fy = matches[0] if matches else None
        if fy is None:
            continue
        # read x,y sheets (no header expected)
        try:
            xdf = pd.read_excel(fx, header=None)
            ydf = pd.read_excel(fy, header=None)
        except Exception as e:
            print(f"Warning: could not read pair {fx.relative_to(root)} / {fy.relative_to(root)}: {e}")
            continue
        curves = []
        for r in range(min(xdf.shape[0], ydf.shape[0])):
            xrow = xdf.iloc[r].dropna().values
            yrow = ydf.iloc[r].dropna().values
            if xrow.size == 0 or yrow.size == 0: 
                continue
            if xrow.size != yrow.size:
                n = min(xrow.size, yrow.size)
                xrow = xrow[:n]; yrow = yrow[:n]
            try:
                xlist = [float(v) for v in xrow]
                ylist = [float(v) for v in yrow]
            except:
                continue
            if len(xlist) >= 2:
                curves.append([xlist, ylist])
        if curves:
            pairs.setdefault(int(idx), []).extend(curves)
    return pairs

print("Parsing annotation _x/_y Excel pairs (this may take a moment)...")
annotations = find_xy_pairs(Path(DATA_ROOT), base_hint=BASE_TIF_NAME)
n_annotated = len(annotations)
print(f"Found annotations for {n_annotated} slices.")

# Save annotations.json (keys as strings)
ann_path = OUT_DIR / "annotations.json"
with open(ann_path, "w") as f:
    json.dump({str(k): annotations[k] for k in sorted(annotations.keys())}, f, indent=2)
print("Saved annotations.json ->", ann_path)

# Save annotation summary (counts per slice)
summary = {str(k): len(v) for k,v in annotations.items()}
summary_path = OUT_DIR / "annotation_summary.json"
with open(summary_path, "w") as f:
    json.dump(summary, f, indent=2)
print("Saved annotation_summary.json ->", summary_path)

# Histogram of annotation coverage
counts = np.array(list(summary.values())) if summary else np.array([0])
plt.figure(figsize=(6,3))
plt.hist(counts, bins=range(0, int(counts.max() if counts.size else 1)+2), align='left', rwidth=0.8)
plt.xlabel("Number of traced curves (per slice)")
plt.ylabel("Number of slices")
plt.title("Annotation coverage")
plt.grid(axis='y', alpha=0.4)
hist_path = OUT_DIR / "annotation_histogram.png"
plt.savefig(hist_path, bbox_inches='tight', dpi=150)
plt.close()
print("Saved annotation histogram ->", hist_path)

# Create representative slice and overlays (if a TIFF exists)
tif_list = sorted(list(Path(DATA_ROOT).rglob("*.tif")) + list(Path(DATA_ROOT).rglob("*.tiff")))
if tif_list:
    tif0 = tif_list[0]
    print("Using TIFF for overlays:", tif0.relative_to(DATA_ROOT))
    stack = tifffile.imread(str(tif0)).astype(np.float32)
    # representative middle slice
    mid = stack.shape[0] // 2
    rep = stack[mid]
    rep_norm = (rep - rep.min()) / (rep.max() - rep.min() + 1e-9)
    rep_path = OUT_DIR / "representative_slice.png"
    plt.imsave(str(rep_path), rep_norm, cmap="gray")
    print("Saved representative slice ->", rep_path)
    # overlays: pick first, median, and most-annotated slices if available
    if annotations:
        sorted_idxs = sorted(annotations.keys())
        first_idx = sorted_idxs[0]
        median_idx = sorted_idxs[len(sorted_idxs)//2]
        max_idx = max(sorted_idxs, key=lambda k: len(annotations[k]))
        def save_overlay(idx, ann_list, name):
            if idx < 0 or idx >= stack.shape[0]:
                print("Skipping overlay for out-of-range index", idx)
                return
            img = stack[idx]
            norm = (img - img.min())/(img.max()-img.min()+1e-9)
            plt.figure(figsize=(8,6))
            plt.imshow(norm, cmap='gray')
            for curve in ann_list:
                x,y = curve
                plt.plot(x, y, '-r', linewidth=1)
            plt.axis('off')
            p = OUT_DIR / name
            plt.savefig(p, bbox_inches='tight', dpi=150)
            plt.close()
            print("Saved overlay ->", p)
        save_overlay(first_idx, annotations[first_idx], "overlay_first.png")
        save_overlay(median_idx, annotations[median_idx], "overlay_median.png")
        save_overlay(max_idx, annotations[max_idx], "overlay_max.png")
    else:
        print("No annotations available to create overlays.")
else:
    print("No TIFF files found under DATA_ROOT; overlays and representative slice not created.")

print("\nCell complete. Check files in OUT_DIR:", OUT_DIR.resolve())


Parsing annotation _x/_y Excel pairs (this may take a moment)...
Found annotations for 260 slices.
Saved annotations.json -> oct_denoise_outputs/annotations.json
Saved annotation_summary.json -> oct_denoise_outputs/annotation_summary.json
Saved annotation histogram -> oct_denoise_outputs/annotation_histogram.png
Using TIFF for overlays: Intensity.tif
Saved representative slice -> oct_denoise_outputs/representative_slice.png
Saved overlay -> oct_denoise_outputs/overlay_first.png
Saved overlay -> oct_denoise_outputs/overlay_median.png
Saved overlay -> oct_denoise_outputs/overlay_max.png

Cell complete. Check files in OUT_DIR: /Users/kartikgoyal/Desktop/Speckle_Noise_Reduction/oct_denoise_outputs


In [5]:
# Cell 3 (fixed): Estimate frame-to-frame jitter (shifts) with robust preprocessing and fallback error
# Writes: OUT_DIR/estimated_shifts_all_fixed.csv, OUT_DIR/shifts_plot_fixed.png, OUT_DIR/shifts_summary_fixed.json

import numpy as np, csv, json, time
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.ndimage import shift as ndi_shift
import tifffile

# registration routine
try:
    from skimage.registration import phase_cross_correlation
except Exception:
    raise ImportError("skimage.registration.phase_cross_correlation required (pip install scikit-image)")

# PARAMETERS
UPSAMPLE_FACTOR = 10     # subpixel precision (reduce to 5 if slow)
REFERENCE = "center"     # "center" or integer index
MAX_FRAMES = None        # optionally limit processing for speed
WINDOW_APPLY = True      # apply Hann windowing to reduce edge effects
NORMALIZE_FOR_REG = True # subtract mean and divide by std before registration

# Paths (reuse variables from cell1)
DATA_ROOT = Path(DATA_ROOT) if isinstance(DATA_ROOT, (str, Path)) else Path("./data")
OUT_DIR  = Path(OUT_DIR)  if isinstance(OUT_DIR, (str, Path)) else Path("./oct_denoise_outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Find TIFF
tif_list = sorted(list(DATA_ROOT.rglob("*.tif")) + list(DATA_ROOT.rglob("*.tiff")))
if not tif_list:
    raise FileNotFoundError(f"No TIFF files found under DATA_ROOT: {DATA_ROOT}")
tif_path = tif_list[0]
print("Using TIFF:", tif_path.relative_to(DATA_ROOT))

# Load stack
stack = tifffile.imread(str(tif_path)).astype(np.float32)
n_slices = stack.shape[0]
print("Stack shape (slices, H, W):", stack.shape)

# choose frames
if MAX_FRAMES is None:
    idxs = np.arange(n_slices)
else:
    idxs = np.arange(min(MAX_FRAMES, n_slices))

# choose reference index
if REFERENCE == "center":
    ref_idx = n_slices // 2
elif isinstance(REFERENCE, int):
    ref_idx = int(REFERENCE)
else:
    ref_idx = 0
print("Reference slice index:", ref_idx)

# helper: hann window 2D
def hann2d(h, w):
    wy = np.hanning(h)
    wx = np.hanning(w)
    return np.outer(wy, wx)

eps = 1e-9

# preprocess function for registration
def prep_for_reg(img):
    # convert to float64
    im = img.astype(np.float64)
    # subtract mean and scale by std to avoid large DC component
    if NORMALIZE_FOR_REG:
        std = im.std()
        if std < eps:
            im = im - im.mean()
        else:
            im = (im - im.mean()) / (std + eps)
    # apply windowing to reduce edge artifacts (helpful for FFT-based methods)
    if WINDOW_APPLY:
        h,w = im.shape
        win = hann2d(h,w)
        im = im * win
    return im

# arrays to fill
shifts = np.zeros((n_slices,2), dtype=float)
errors = np.zeros((n_slices,), dtype=float)

ref_raw = stack[ref_idx]
ref_p = prep_for_reg(ref_raw)

start = time.time()
for i in idxs:
    if i == ref_idx:
        shifts[i] = (0.0, 0.0)
        errors[i] = 0.0
        continue
    mov_raw = stack[i]
    mov_p = prep_for_reg(mov_raw)

    # try subpixel registration; if it raises, fallback to upsample=1
    try:
        shift, err, diffphase = phase_cross_correlation(ref_p, mov_p, upsample_factor=UPSAMPLE_FACTOR)
    except Exception:
        shift, err, diffphase = phase_cross_correlation(ref_p, mov_p, upsample_factor=1)

    # If shift contains NaN or large values, clamp it
    if not np.isfinite(shift).all():
        shift = np.array((0.0, 0.0))
    shifts[i] = shift

    # If skimage gave a valid error (0..1, finite, and not exactly 1.0), use it; otherwise compute robust fallback
    use_err = None
    if np.isfinite(err) and (err >= 0.0) and (err < 0.9999):
        use_err = float(err)
    else:
        # fallback: apply the shift to the raw moving image and compute normalized RMSE vs raw reference
        try:
            moved_full = ndi_shift(mov_raw, shift=shift, order=1, mode='nearest')
            # compute normalized RMSE: sqrt(mean((ref - moved)^2)) / (ref.max()-ref.min())
            denom = (ref_raw.max() - ref_raw.min()) if (ref_raw.max() - ref_raw.min())>eps else ref_raw.std()+eps
            nrmse = np.sqrt(np.mean((ref_raw - moved_full)**2)) / (denom + eps)
            # clamp to [0,1]
            nrmse = float(min(max(nrmse, 0.0), 1.0))
            use_err = nrmse
        except Exception:
            use_err = 1.0
    errors[i] = use_err

    # progress
    if (i % 50) == 0:
        print(f"  processed slice {i}/{n_slices}  shift={shifts[i]}  err={errors[i]:.4f}")

end = time.time()
print(f"Done in {(end-start):.1f}s")

# Save CSV with fixed results
csv_path = OUT_DIR / "estimated_shifts_all_fixed.csv"
with open(csv_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["slice_index","shift_y","shift_x","error"])
    for i in range(n_slices):
        writer.writerow([int(i), float(shifts[i,0]), float(shifts[i,1]), float(errors[i])])
print("Saved CSV ->", csv_path)

# Save summary JSON and a recommendation
dy_med = float(np.median(np.abs(shifts[:,0])))
dx_med = float(np.median(np.abs(shifts[:,1])))
dy_max = float(np.max(np.abs(shifts[:,0])))
dx_max = float(np.max(np.abs(shifts[:,1])))

summary = {
    "tif_used": str(tif_path.relative_to(DATA_ROOT)),
    "n_slices": int(n_slices),
    "reference_index": int(ref_idx),
    "upsample_factor": int(UPSAMPLE_FACTOR),
    "median_abs_shift_y": dy_med,
    "median_abs_shift_x": dx_med,
    "max_abs_shift_y": dy_max,
    "max_abs_shift_x": dx_max
}
summary["recommendation"] = ("Register slices before averaging (median shifts > ~1.5 px)."
                            if (dy_med > 1.5 or dx_med > 1.5) else
                            "Registration optional (median shifts small).")

json_path = OUT_DIR / "shifts_summary_fixed.json"
with open(json_path, "w") as jf:
    json.dump(summary, jf, indent=2)
print("Saved summary ->", json_path)
print("Recommendation:", summary["recommendation"])

# Plot shifts and errors
plt.figure(figsize=(11,3))
plt.subplot(1,2,1)
plt.plot(shifts[:,0], label="shift_y (vertical)"); plt.plot(shifts[:,1], label="shift_x (horizontal)")
plt.axhline(0, color='k', linewidth=0.5); plt.legend(); plt.title("Estimated shifts (px)"); plt.xlabel("slice"); plt.ylabel("pixels"); plt.grid(True)

plt.subplot(1,2,2)
plt.plot(errors, label="error (NRMSE-like)"); plt.ylim(0,1.0); plt.title("Per-slice error"); plt.xlabel("slice"); plt.ylabel("error"); plt.grid(True)

plot_path = OUT_DIR / "shifts_plot_fixed.png"
plt.tight_layout(); plt.savefig(plot_path, dpi=150, bbox_inches="tight"); plt.close()
print("Saved plot ->", plot_path)

# Print a few sample rows for quick sanity check
print("\nSample (first 8) slice shifts+errors:")
for i in range(min(8, n_slices)):
    print(f" slice {i:3d}: shift=({shifts[i,0]:6.2f},{shifts[i,1]:6.2f})  error={errors[i]:.4f}")


Using TIFF: Intensity.tif
Stack shape (slices, H, W): (339, 806, 500)
Reference slice index: 169
  processed slice 0/339  shift=[-14.1 -20.7]  err=0.1667
  processed slice 50/339  shift=[-48.    3.9]  err=0.1517
  processed slice 100/339  shift=[-25.  -32.9]  err=0.1403
  processed slice 150/339  shift=[-51.9  54.9]  err=0.1732
  processed slice 200/339  shift=[ 1.7 -7. ]  err=0.1053
  processed slice 250/339  shift=[-53.   32.9]  err=0.2133
  processed slice 300/339  shift=[42.2 45.9]  err=0.1674
Done in 36.7s
Saved CSV -> oct_denoise_outputs/estimated_shifts_all_fixed.csv
Saved summary -> oct_denoise_outputs/shifts_summary_fixed.json
Recommendation: Register slices before averaging (median shifts > ~1.5 px).
Saved plot -> oct_denoise_outputs/shifts_plot_fixed.png

Sample (first 8) slice shifts+errors:
 slice   0: shift=(-14.10,-20.70)  error=0.1667
 slice   1: shift=( 31.70, 63.10)  error=0.1735
 slice   2: shift=(-14.10, 54.80)  error=0.1567
 slice   3: shift=(-76.20, 36.00)  error=

In [7]:
# Cell: Generate improved pseudo-targets with robust registration, confidence maps, and optional NLM post-filter
# (Drop-in replacement â€” minimal safe improvements: robust manual error, exclude ref from med/MAD, optional cc normalization)
import time, json
from pathlib import Path
import numpy as np
import tifffile
from scipy.ndimage import shift as ndi_shift
from scipy.ndimage import median_filter
from skimage.registration import phase_cross_correlation
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Optional fast NLM
try:
    from skimage.restoration import denoise_nl_means
    _HAS_NLM = True
except Exception:
    _HAS_NLM = False

# ---------- PARAMETERS (tune these) ----------
OUT_DIR = Path(OUT_DIR) if isinstance(OUT_DIR, (str, Path)) else Path("./oct_denoise_outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)
DATA_ROOT = Path(DATA_ROOT) if isinstance(DATA_ROOT, (str, Path)) else Path("./data")

RADIUS = 4                # neighbors on each side -> up to 2*RADIUS+1 frames
UPSAMPLE_FACTOR = 6       # subpixel precision for phase_cross_correlation
WINDOW_APPLY = True       # apply Hann window for registration stability (you enabled this)
NORMALIZE_FOR_REG = True  # subtract mean & divide by std before reg (you enabled this)
MIN_FRAMES = 3            # need at least this many aligned frames to accept pseudo
MAX_SHIFT_PIX = 80.0      # clamp extreme shifts (px)
ERR_MAD_MULT = 2.5        # threshold = median_err + ERR_MAD_MULT * mad_err
ERR_CLIP = 10.0           # cap manual_err to avoid huge outliers dominating MAD
NLM_POSTPROCESS = True and _HAS_NLM  # apply small NLM to pseudos if skimage present
NLM_H = 0.05              # very conservative
NLM_PATCH_SIZE = 5
NLM_PATCH_DISTANCE = 6
SAVE_VIS = True
VIS_COUNT = 6

# output filenames
OUT_NPZ = OUT_DIR / "pseudo_clean_improved.npz"
SUMMARY_JSON = OUT_DIR / "pseudo_improved_summary.json"
VIS_DIR = OUT_DIR / "pseudo_improved_vis"
VIS_DIR.mkdir(exist_ok=True)

# ---------- helpers ----------
eps = 1e-9
def hann2d(h,w):
    return np.outer(np.hanning(h), np.hanning(w))

def prep_for_reg(img):
    im = img.astype(np.float64)
    if NORMALIZE_FOR_REG:
        s = im.std()
        if s < eps:
            im = im - im.mean()
        else:
            im = (im - im.mean())/(s + eps)
    if WINDOW_APPLY:
        im = im * hann2d(im.shape[0], im.shape[1])
    return im

def mad(a):
    med = np.median(a)
    return np.median(np.abs(a - med))

# ---------- load stack & optional annotations ----------
tif_list = sorted(list(Path(DATA_ROOT).rglob("*.tif")) + list(Path(DATA_ROOT).rglob("*.tiff")))
if not tif_list:
    raise FileNotFoundError("No TIFF found under DATA_ROOT: " + str(DATA_ROOT))
tif_path = tif_list[0]
stack = tifffile.imread(str(tif_path)).astype(np.float32)   # (S,H,W)
S, H, W = stack.shape
print("Loaded stack:", tif_path.name, "shape:", stack.shape)

# if you want to restrict pseudos to annotated slices (earlier workflow), load annotations.json in OUT_DIR
ann_path = OUT_DIR / "annotations.json"
if ann_path.exists():
    with open(ann_path, "r") as f:
        annotations = json.load(f)
    target_slices = sorted([int(k) for k in annotations.keys()])
    print("Found annotations.json; generating pseudos for annotated slices:", len(target_slices))
else:
    target_slices = list(range(S))
    print("No annotations found in OUT_DIR; generating pseudos for all slices:", len(target_slices))

# ---------- main loop ----------
indices_out = []
imgs_out = []
confs_out = []
used_frames_out = []
shifts_out = []
errs_out = []

summary = {"created": [], "params": {
    "RADIUS": RADIUS, "UPSAMPLE_FACTOR": UPSAMPLE_FACTOR,
    "MIN_FRAMES": MIN_FRAMES, "ERR_MAD_MULT": ERR_MAD_MULT, "NLM_POSTPROCESS": bool(NLM_POSTPROCESS)
}}

t0_all = time.time()
vis_saved = 0

# small diagnostic: print first slice neighbor errs to confirm behavior
DEBUG_PRINT_FIRST = True
printed_debug = False

for count, center in enumerate(target_slices):
    lo = max(0, center - RADIUS)
    hi = min(S-1, center + RADIUS)
    neighbor_idxs = list(range(lo, hi+1))

    ref_raw = stack[center].astype(np.float32)
    ref_p = prep_for_reg(ref_raw)

    aligned = []
    shifts = []
    errs = []
    used_idxs = []

    # register neighbors -> ref
    for j in neighbor_idxs:
        mov_raw = stack[j].astype(np.float32)
        mov_p = prep_for_reg(mov_raw)

        # try to call with explicit normalization if supported (some skimage versions accept it)
        try:
            shift, err, _ = phase_cross_correlation(ref_p, mov_p, upsample_factor=UPSAMPLE_FACTOR, normalization=None)
        except TypeError:
            # older skimage may not support 'normalization' kwarg
            try:
                shift, err, _ = phase_cross_correlation(ref_p, mov_p, upsample_factor=UPSAMPLE_FACTOR)
            except Exception:
                shift, err, _ = phase_cross_correlation(ref_p, mov_p, upsample_factor=1)
        except Exception:
            # any other runtime error fallback
            try:
                shift, err, _ = phase_cross_correlation(ref_p, mov_p, upsample_factor=1)
            except Exception:
                # final fallback: zero shift / huge err
                shift = np.array([0.0, 0.0], dtype=float)
                err = 1e6

        # sanitize err from phase_cross_correlation (may be unreliable in some skimage versions)
        if not np.isfinite(err):
            err = 1e6
        shift = np.array(shift, dtype=float)
        # clamp huge shifts -> mark as bad
        if np.any(np.abs(shift) > MAX_SHIFT_PIX):
            err = 1e6

        if err < 1e6:
            # compute manual residual/error on the preprocessed images (robust)
            aligned_test = ndi_shift(mov_p, shift=(-shift[0], -shift[1]), order=1, mode='reflect')
            resid = (ref_p - aligned_test).ravel()
            # prefer RMS normalization (norm(resid) / norm(ref_p))
            num = float(np.linalg.norm(resid))
            den = float(np.linalg.norm(ref_p.ravel()) + eps)
            manual_err = num / den
            if not np.isfinite(manual_err):
                manual_err = 1e6
            # clip extreme manual errors to avoid dominating MAD
            manual_err = float(min(manual_err, ERR_CLIP))

            # apply shift to original raw image (not preprocessed) and record results
            aligned_full = ndi_shift(mov_raw, shift=(-shift[0], -shift[1]), order=1, mode='reflect')
            aligned.append(aligned_full)
            shifts.append(shift.tolist())
            errs.append(manual_err)   # <- use manual_err
            used_idxs.append(int(j))
        else:
            # ignore badly aligned frame
            pass

    # If no aligned frames (unexpected), fallback: use reference only
    if len(aligned) == 0:
        pseudo = ref_raw.copy()
        conf = np.zeros_like(pseudo, dtype=np.float32)
        used = [int(center)]
        shifts_used = [[0.0,0.0]]
        errs_used = [float(0.0)]
        med_err = None
        mad_err = None
    else:
        # robust per-neighbor error thresholding: compute median & MAD of errs and reject outliers
        errs_arr = np.array(errs, dtype=float)
        # exclude exact-zero entries (typically the reference) when computing med/MAD so zeros don't bias stats
        errs_for_stats = errs_arr[errs_arr > 0.0]
        if errs_for_stats.size == 0:
            # if nothing non-zero, fall back to full array
            errs_for_stats = errs_arr.copy()

        med_err = float(np.median(errs_for_stats))
        mad_err = float(mad(errs_for_stats))
        thr = med_err + ERR_MAD_MULT * (mad_err + 1e-9)

        # choose frames with err <= thr (always keep the reference if present in used_idxs)
        keep_mask = errs_arr <= thr
        # if after rejecting we have fewer than MIN_FRAMES, relax threshold to keep top MIN_FRAMES closest
        if keep_mask.sum() < MIN_FRAMES:
            order = np.argsort(errs_arr)
            keep_idx = order[:min(len(order), MIN_FRAMES)]
            keep_mask = np.zeros_like(keep_mask)
            keep_mask[keep_idx] = True

        aligned_keep = [aligned[i] for i in range(len(aligned)) if keep_mask[i]]
        used = [used_idxs[i] for i in range(len(used_idxs)) if keep_mask[i]]
        shifts_used = [shifts[i] for i in range(len(shifts)) if keep_mask[i]]
        errs_used = [errs[i] for i in range(len(errs)) if keep_mask[i]]

        if len(aligned_keep) == 0:
            pseudo = ref_raw.copy()
            conf = np.zeros_like(pseudo, dtype=np.float32)
        else:
            A = np.stack(aligned_keep, axis=0).astype(np.float32)  # (K,H,W)
            # median pseudo is edge-preserving
            pseudo_med = np.median(A, axis=0).astype(np.float32)
            # per-pixel MAD -> confidence: smaller MAD -> higher confidence
            pixel_mad = np.median(np.abs(A - np.expand_dims(pseudo_med,0)), axis=0)
            # normalize mad to [0,1] (invert -> conf)
            # use robust scale: instead of max which can be impacted by outliers, use 99th percentile
            pm_scale = np.percentile(pixel_mad.ravel(), 99.0)
            pm_max = pm_scale if pm_scale > 0 else 1.0
            conf = 1.0 - (pixel_mad / (pm_max + eps))
            conf = np.clip(conf, 0.0, 1.0)

            # optional light NLM on normalized pseudo (low strength)
            if NLM_POSTPROCESS:
                try:
                    pseudo01 = (pseudo_med - pseudo_med.min())/(pseudo_med.max()-pseudo_med.min()+eps)
                    # skimage API differences: keep the familiar args (fast_mode True, multichannel False)
                    nlm01 = denoise_nl_means(pseudo01, patch_size=NLM_PATCH_SIZE,
                                             patch_distance=NLM_PATCH_DISTANCE, h=NLM_H, fast_mode=True, multichannel=False)
                    # rescale to original pseudo range
                    pseudo = (nlm01 - nlm01.min())/(nlm01.max()-nlm01.min()+eps) * (pseudo_med.max()-pseudo_med.min()) + pseudo_med.min()
                except Exception:
                    pseudo = pseudo_med
            else:
                pseudo = pseudo_med

    # record outputs
    indices_out.append(int(center))
    imgs_out.append(pseudo.astype(np.float32))
    confs_out.append(conf.astype(np.float32))
    used_frames_out.append(used)
    shifts_out.append(shifts_used)
    errs_out.append(errs_used)

    # summary stats for this slice
    summary["created"].append({
        "idx": int(center),
        "n_neighbors_found": int(len(used_idxs)),
        "n_neighbors_used": int(len(used)),
        "median_err_before_filter": float(med_err) if med_err is not None else None,
        "mad_err_before_filter": float(mad_err) if med_err is not None else None,
        "shifts_used": shifts_used,
        "errs_used": errs_used
    })

    # save a few PNGs for QA
    if SAVE_VIS and vis_saved < VIS_COUNT:
        norm_ref = (ref_raw - ref_raw.min())/(ref_raw.max()-ref_raw.min()+eps)
        norm_pseudo = (pseudo - pseudo.min())/(pseudo.max()-pseudo.min()+eps)
        fig, axs = plt.subplots(1,3, figsize=(12,4))
        axs[0].imshow(norm_ref, cmap='gray'); axs[0].set_title(f"ref {center}"); axs[0].axis('off')
        axs[1].imshow(norm_pseudo, cmap='gray'); axs[1].set_title("pseudo (median+opt)"); axs[1].axis('off')
        axs[2].imshow(conf, cmap='viridis'); axs[2].set_title("confidence"); axs[2].axis('off')
        plt.tight_layout()
        plt.savefig(VIS_DIR / f"pseudo_vis_{center:04d}.png", dpi=150, bbox_inches='tight')
        plt.close(fig)
        vis_saved += 1

    # small periodic diagnostics
    if DEBUG_PRINT_FIRST and not printed_debug:
        printed_debug = True
        print("DEBUG sample (first processed slice): center=", center)
        print("  neighbor_idxs:", neighbor_idxs)
        print("  used_idxs sample:", used[:8])
        print("  shifts_used sample:", shifts_used[:8])
        print("  errs_used sample:", errs_used[:8])
    if (count+1) % 50 == 0:
        print(f"Processed {count+1}/{len(target_slices)} target slices (last center={center})")

# ---------- save results ----------
if len(imgs_out) > 0:
    imgs_arr = np.stack(imgs_out, axis=0)         # (N,H,W)
    conf_arr = np.stack(confs_out, axis=0)         # (N,H,W)
    np.savez_compressed(OUT_NPZ,
                        indices=np.array(indices_out, dtype=np.int32),
                        imgs=imgs_arr.astype(np.float32),
                        confs=conf_arr.astype(np.float32),
                        used_frames=np.array(used_frames_out, dtype=object),
                        shifts=np.array(shifts_out, dtype=object),
                        errs=np.array(errs_out, dtype=object))
    with open(SUMMARY_JSON, "w") as f:
        json.dump(summary, f, indent=2)
    print("Saved improved pseudos ->", OUT_NPZ)
    print("Saved summary ->", SUMMARY_JSON)
    print("Saved QA visuals (up to {}) ->".format(VIS_COUNT), VIS_DIR)
else:
    print("No pseudo images created - check parameters / input stack")

print("Done. Total time: {:.1f}s".format(time.time() - t0_all))


Loaded stack: Intensity.tif shape: (339, 806, 500)
Found annotations.json; generating pseudos for annotated slices: 260
DEBUG sample (first processed slice): center= 29
  neighbor_idxs: [25, 26, 27, 28, 29, 30, 31, 32, 33]
  used_idxs sample: [25, 26, 27, 28, 29, 30, 31, 32]
  shifts_used sample: [[-1.3333333333333333, 0.3333333333333333], [-0.5, 1.6666666666666667], [0.0, 0.16666666666666666], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [-0.16666666666666666, 0.3333333333333333], [0.8333333333333334, -1.5]]
  errs_used sample: [0.506396777099378, 0.475615585235201, 0.46715252049187744, 0.37047922306446196, 0.0, 0.36705796436596533, 0.44725449460606537, 0.48843346827124645]
Processed 50/260 target slices (last center=107)
Processed 100/260 target slices (last center=159)
Processed 150/260 target slices (last center=209)
Processed 200/260 target slices (last center=259)
Processed 250/260 target slices (last center=310)
Saved improved pseudos -> oct_denoise_outputs/pseudo_clean_improved.npz
Save

In [8]:
# Cell 5: Baseline denoisers + evaluation (image metrics + anatomical boundary error)
# - Loads pseudo_clean_subset.npz (indices, imgs)
# - For each slice: apply median filter and Non-Local Means (NLM)
# - Compute PSNR and SSIM vs pseudo-clean; compute boundary localization error (using annotations.json)
# - Save results CSV and a few before/after visual examples into OUT_DIR/baseline_visuals

import numpy as np, json, os, csv
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scipy.ndimage import median_filter
try:
    from skimage.restoration import denoise_nl_means
    from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
    from skimage import filters
except Exception as e:
    raise ImportError("Requires scikit-image (denoise_nl_means, metrics). Install `pip install scikit-image`.") from e

# Paths
OUT_DIR = Path(OUT_DIR) if isinstance(OUT_DIR, (str, Path)) else Path("./oct_denoise_outputs")
DATA_ROOT = Path(DATA_ROOT) if isinstance(DATA_ROOT, (str, Path)) else Path("./data")
npz_path = OUT_DIR / "pseudo_clean_improved.npz"
ann_path = OUT_DIR / "annotations.json"
tif_list = sorted(list(Path(DATA_ROOT).rglob("*.tif")) + list(Path(DATA_ROOT).rglob("*.tiff")))
if not npz_path.exists():
    raise FileNotFoundError("pseudo_clean_subset.npz not found in OUT_DIR. Run pseudo-creation cell first.")
if not ann_path.exists():
    raise FileNotFoundError("annotations.json not found in OUT_DIR. Run annotation cell first.")
if not tif_list:
    raise FileNotFoundError("No TIFF found under DATA_ROOT.")

# Parameters (tweak for speed/quality)
MEDIAN_SIZE = 3           # median filter window
NLM_PATCH_SIZE = 5
NLM_PATCH_DISTANCE = 6
NLM_H = 0.08              # denoising strength (relative to image range 0..1)
NLM_FAST = True           # fast mode
MAX_SAMPLES = None        # limit processed slices (None => all pseudo-clean slices)
VISUAL_EXAMPLES = 12      # save up to this many before/after visual panels

# Load data
npz = np.load(npz_path, allow_pickle=True)
indices = npz["indices"].astype(int).tolist()
pseudo_imgs = npz["imgs"]   # shape (N, H, W)
print("Loaded", len(indices), "pseudo-clean targets from", npz_path.name)

with open(ann_path, "r") as f:
    annotations = json.load(f)

# load stack (for raw images)
tif_path = tif_list[0]
import tifffile
stack = tifffile.imread(str(tif_path)).astype(np.float32)

# helper: boundary localization error using vertical sobel (as earlier)
def boundary_localization_error(denoised_img, curves, window=10):
    gy = np.abs(filters.sobel_v(denoised_img))
    errs = []
    for (x_arr, y_arr) in curves:
        xs = np.rint(x_arr).astype(int)
        ys = np.rint(y_arr).astype(int)
        valid = (xs>=0) & (xs<denoised_img.shape[1])
        xs = xs[valid]; ys = ys[valid]
        if xs.size==0:
            continue
        col_positions = []
        for col, ytrue in zip(xs, ys):
            lo = int(max(0, ytrue-window)); hi = int(min(denoised_img.shape[0], ytrue+window+1))
            patch = gy[lo:hi, col]
            if patch.size==0:
                col_positions.append(ytrue)
            else:
                rel = np.argmax(patch)
                y_est = lo + rel
                col_positions.append(y_est)
        col_positions = np.array(col_positions)
        errs.append(np.mean(np.abs(col_positions - ys)))
    if len(errs)==0:
        return None
    return float(np.mean(errs))

# iterate slices and compute metrics
results = []
visuals_saved = 0
visual_dir = OUT_DIR / "baseline_visuals"
visual_dir.mkdir(exist_ok=True)

sample_list = indices if MAX_SAMPLES is None else indices[:MAX_SAMPLES]
for i, idx in enumerate(sample_list):
    raw = stack[idx].astype(np.float32)
    pseudo = pseudo_imgs[i].astype(np.float32)

    # normalize images to [0,1] for NLM and metrics (preserve dynamic range for PSNR)
    raw01 = (raw - raw.min()) / (raw.max() - raw.min() + 1e-9)
    pseudo01 = (pseudo - pseudo.min()) / (pseudo.max() - pseudo.min() + 1e-9)

    # Baseline 1: median filter
    med = median_filter(raw, size=(MEDIAN_SIZE, MEDIAN_SIZE)).astype(np.float32)
    med01 = (med - med.min()) / (med.max() - med.min() + 1e-9)

    # Baseline 2: NLM (on normalized image)
    try:
        nlm01 = denoise_nl_means(raw01, patch_size=NLM_PATCH_SIZE,
                                 patch_distance=NLM_PATCH_DISTANCE, h=NLM_H, fast_mode=NLM_FAST, multichannel=False)
    except TypeError:
        # older versions use different arg names; try without multichannel
        nlm01 = denoise_nl_means(raw01, patch_size=NLM_PATCH_SIZE,
                                 patch_distance=NLM_PATCH_DISTANCE, h=NLM_H, fast_mode=NLM_FAST)
    # rescale nlm back to raw range if needed
    nlm = (nlm01 - nlm01.min()) / (nlm01.max() - nlm01.min() + 1e-9) * (raw.max()-raw.min()) + raw.min()
    nlm01 = (nlm - nlm.min()) / (nlm.max() - nlm.min() + 1e-9)

    # Compute image metrics vs pseudo-clean (use pseudo as reference)
    # PSNR expects data_range; use pseudo.max()-pseudo.min()
    data_range = float(pseudo.max() - pseudo.min())
    med_psnr = psnr(pseudo, med, data_range=data_range)
    nlm_psnr = psnr(pseudo, nlm, data_range=data_range)
    raw_psnr = psnr(pseudo, raw, data_range=data_range)

    med_ssim = ssim(pseudo, med, data_range=data_range)
    nlm_ssim = ssim(pseudo, nlm, data_range=data_range)
    raw_ssim = ssim(pseudo, raw, data_range=data_range)

    # Anatomical metric: boundary localization error for annotated curves on this slice
    ann_curves = annotations.get(str(idx), None)
    be_raw = boundary_localization_error(raw, ann_curves) if ann_curves else None
    be_med = boundary_localization_error(med, ann_curves) if ann_curves else None
    be_nlm = boundary_localization_error(nlm, ann_curves) if ann_curves else None

    results.append({
        "slice_index": int(idx),
        "raw_psnr": float(raw_psnr), "med_psnr": float(med_psnr), "nlm_psnr": float(nlm_psnr),
        "raw_ssim": float(raw_ssim), "med_ssim": float(med_ssim), "nlm_ssim": float(nlm_ssim),
        "be_raw": be_raw, "be_med": be_med, "be_nlm": be_nlm
    })

    # save visual examples (first VISUAL_EXAMPLES)
    if visuals_saved < VISUAL_EXAMPLES:
        fig, axs = plt.subplots(1,4, figsize=(16,6))
        axs[0].imshow(raw, cmap='gray'); axs[0].set_title(f"Raw (slice {idx})"); axs[0].axis('off')
        axs[1].imshow(pseudo, cmap='gray'); axs[1].set_title("Pseudo-clean (target)"); axs[1].axis('off')
        axs[2].imshow(med, cmap='gray'); axs[2].set_title(f"Median\nPSNR={med_psnr:.2f}, BE={be_med}"); axs[2].axis('off')
        axs[3].imshow(nlm, cmap='gray'); axs[3].set_title(f"NLM\nPSNR={nlm_psnr:.2f}, BE={be_nlm}"); axs[3].axis('off')
        # overlay annotation curves on pseudo for reference
        if ann_curves:
            for curve in ann_curves:
                x,y = curve
                axs[1].plot(x, y, '-r', linewidth=1)
                axs[2].plot(x, y, '-r', linewidth=1)
                axs[3].plot(x, y, '-r', linewidth=1)
        outv = visual_dir / f"baseline_vis_{idx:04d}.png"
        plt.tight_layout()
        plt.savefig(outv, dpi=150, bbox_inches='tight')
        plt.close()
        visuals_saved += 1

# Save results CSV
csv_out = OUT_DIR / "results_baselines.csv"
with open(csv_out, "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=list(results[0].keys()))
    writer.writeheader()
    writer.writerows(results)
print("Saved baseline results ->", csv_out)

# Print summary statistics (mean across slices, ignoring None)
def mean_ignore_none(arr):
    vals = [v for v in arr if v is not None]
    return float(np.mean(vals)) if vals else None

raw_psnrs = [r["raw_psnr"] for r in results]
med_psnrs = [r["med_psnr"] for r in results]
nlm_psnrs = [r["nlm_psnr"] for r in results]
raw_ssims = [r["raw_ssim"] for r in results]
med_ssims = [r["med_ssim"] for r in results]
nlm_ssims = [r["nlm_ssim"] for r in results]

be_raws = [r["be_raw"] for r in results if r["be_raw"] is not None]
be_meds = [r["be_med"] for r in results if r["be_med"] is not None]
be_nlms = [r["be_nlm"] for r in results if r["be_nlm"] is not None]

print("\n=== Mean image metrics (vs pseudo-clean target) ===")
print(f"PSNR  - raw: {np.mean(raw_psnrs):.3f}, median: {np.mean(med_psnrs):.3f}, nlm: {np.mean(nlm_psnrs):.3f}")
print(f"SSIM  - raw: {np.mean(raw_ssims):.3f}, median: {np.mean(med_ssims):.3f}, nlm: {np.mean(nlm_ssims):.3f}")

if be_raws:
    print("\n=== Mean boundary localization error (pixels) ===")
    print(f"BE - raw: {np.mean(be_raws):.3f}, median: {np.mean(be_meds):.3f}, nlm: {np.mean(be_nlms):.3f}")

print("\nSaved example visuals to:", visual_dir)


Loaded 260 pseudo-clean targets from pseudo_clean_improved.npz
Saved baseline results -> oct_denoise_outputs/results_baselines.csv

=== Mean image metrics (vs pseudo-clean target) ===
PSNR  - raw: 23.080, median: 24.715, nlm: 25.469
SSIM  - raw: 0.515, median: 0.606, nlm: 0.670

=== Mean boundary localization error (pixels) ===
BE - raw: 5.578, median: 5.560, nlm: 5.687

Saved example visuals to: oct_denoise_outputs/baseline_visuals


In [None]:

import os, time, random, json
from pathlib import Path
import numpy as np
import tifffile
from sklearn.model_selection import train_test_split
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from skimage.metrics import structural_similarity as sk_ssim
from skimage import filters
import matplotlib.pyplot as plt
import scipy.ndimage as ndi    # << ADDED

# ---------- CONFIG ----------
OUT_DIR = Path("./oct_denoise_outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)
DATA_ROOT = Path("./data")
NPZ_PATH = OUT_DIR / "pseudo_clean_improved.npz"
ANN_PATH = OUT_DIR / "annotations.json"

EPOCHS = 20         # smoke default
BATCH_SIZE = 1        # M1-friendly
LR = 1e-3
WEIGHT_DECAY = 1e-5
EDGE_LOSS_WEIGHT = 0.35
VAL_FRACTION = 0.15
SEED = 1337
CKPT_PATH = OUT_DIR / "unet_smoke_fullimage_groupnorm.pth"
EXAMPLES_DIR = OUT_DIR / "examples"
EXAMPLES_DIR.mkdir(exist_ok=True)
EPS = 1e-9
PIXEL_SIZE_UM = None   # set to physical pixel spacing if available to convert BE to microns

# ---------- BE weighting params (minimal additions) ----------
BE_SIGMA_PX = 3.0       # gaussian sigma (band width)
BE_MAX_WEIGHT = 6.0     # maximum multiplier near boundary
BE_MIN_WEIGHT = 1.0     # baseline multiplier away from boundary
BE_LOSS_WEIGHT = 0.0    # keep separate term small for now (you can increase). use 0.0 to start.

# ---------- device ----------
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print("Device:", DEVICE)

random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# ---------- load data ----------
if not NPZ_PATH.exists():
    raise FileNotFoundError(f"Missing {NPZ_PATH}. Create pseudo targets first.")
npz = np.load(str(NPZ_PATH), allow_pickle=True)
pseudo_indices = npz["indices"].astype(int).tolist()
pseudo_imgs = npz["imgs"].astype(np.float32)

# load first TIFF stack under DATA_ROOT
tif_list = sorted(list(Path(DATA_ROOT).rglob("*.tif")) + list(Path(DATA_ROOT).rglob("*.tiff")))
if not tif_list:
    raise FileNotFoundError("No TIFFs found under DATA_ROOT: " + str(DATA_ROOT))
stack = tifffile.imread(str(tif_list[0])).astype(np.float32)   # (S, H, W)
S, H, W = stack.shape
print("Raw stack:", tif_list[0].name, "shape:", stack.shape)
print("Loaded", len(pseudo_imgs), "pseudo-clean targets from", NPZ_PATH.name)

with open(ANN_PATH, "r") as f:
    annotations = json.load(f)
annotated_slices = sorted([int(k) for k in annotations.keys()])

# ---------- PRECOMPUTE boundary weight maps (minimal addition) ----------
# This creates boundary_weight_maps[s] where values are high near annotated boundaries.
boundary_weight_maps = np.ones((S, H, W), dtype=np.float32) * BE_MIN_WEIGHT
for s in range(S):
    ann_curves = annotations.get(str(s), None)
    if not ann_curves:
        continue
    bmap = np.zeros((H, W), dtype=np.uint8)
    for (xcoords, ycoords) in ann_curves:
        xs = np.rint(np.array(xcoords)).astype(int)
        ys = np.rint(np.array(ycoords)).astype(int)
        valid = (xs >= 0) & (xs < W) & (ys >= 0) & (ys < H)
        xs = xs[valid]; ys = ys[valid]
        if xs.size > 0:
            bmap[ys, xs] = 1
    if bmap.sum() == 0:
        continue
    dist = ndi.distance_transform_edt(1 - bmap)  # dist=0 at boundaries
    weight = BE_MIN_WEIGHT + (BE_MAX_WEIGHT - BE_MIN_WEIGHT) * np.exp(-(dist**2) / (2 * (BE_SIGMA_PX**2)))
    boundary_weight_maps[s] = weight
# optionally save
np.save(OUT_DIR / "boundary_weight_maps.npy", boundary_weight_maps)

# prepare train/val positions (only positions with pseudo targets)
all_positions = list(range(len(pseudo_indices)))
train_pos, val_pos = train_test_split(all_positions, test_size=VAL_FRACTION, random_state=SEED)
val_full_slices = sorted(list({pseudo_indices[p] for p in val_pos}))
print(f"Train samples: {len(train_pos)}, Val samples: {len(val_pos)}, Val full-slices: {len(val_full_slices)}")

# ---------- dataset (full images) ----------
class FullSliceDataset(Dataset):
    def __init__(self, pos_list, indices_map, pseudo_imgs, raw_stack):
        self.pos_list = pos_list
        self.indices_map = indices_map
        self.pseudo = pseudo_imgs
        self.raw = raw_stack
    def __len__(self): return len(self.pos_list)
    def __getitem__(self, idx):
        pos = self.pos_list[idx]
        slice_idx = int(self.indices_map[pos])
        raw = self.raw[slice_idx].astype(np.float32)
        tgt = self.pseudo[pos].astype(np.float32)
        raw01 = (raw - raw.min()) / (raw.max() - raw.min() + EPS)
        tgt01 = (tgt - tgt.min()) / (tgt.max() - tgt.min() + EPS)
        x = torch.from_numpy(raw01[None,:,:]).float()
        y = torch.from_numpy(tgt01[None,:,:]).float()
        return x, y, slice_idx

train_ds = FullSliceDataset(train_pos, pseudo_indices, pseudo_imgs, stack)
val_ds   = FullSliceDataset(val_pos,   pseudo_indices, pseudo_imgs, stack)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=False)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=False)

# ---------- small UNet with GroupNorm (stable for B=1) ----------
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        # choose groups that divide out_ch
        def choose_groups(c):
            for g in (8,4,2,1):
                if c % g == 0:
                    return g
            return 1
        groups = choose_groups(out_ch)
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.GroupNorm(groups, out_ch),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.GroupNorm(groups, out_ch)
        )
    def forward(self,x): return self.net(x)

class SmallUNetFull(nn.Module):
    def __init__(self, in_ch=1, base=16):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base)
        self.enc2 = ConvBlock(base, base*2)
        self.enc3 = ConvBlock(base*2, base*4)
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec3 = ConvBlock(base*4, base*2)
        self.dec2 = ConvBlock(base*2, base)
        self.final = nn.Conv2d(base, 1, 1)
    def forward(self, x):
        H_in, W_in = x.shape[2], x.shape[3]
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        d = self.up(e3); d = self.dec3(d)
        d = self.up(d); d = self.dec2(d)
        out = self.final(d)
        if out.shape[2] != H_in or out.shape[3] != W_in:
            out = F.interpolate(out, size=(H_in, W_in), mode='bilinear', align_corners=False)
        return out

# ---------- edge operator & losses ----------
def sobel_edges_torch(x):
    Kx = torch.tensor([[1,0,-1],[2,0,-2],[1,0,-1]], dtype=torch.float32, device=x.device).reshape(1,1,3,3)
    Ky = torch.tensor([[1,2,1],[0,0,0],[-1,-2,-1]], dtype=torch.float32, device=x.device).reshape(1,1,3,3)
    gx = F.conv2d(x, Kx, padding=1); gy = F.conv2d(x, Ky, padding=1)
    return torch.sqrt(gx*gx + gy*gy + 1e-12)

l1_loss = nn.L1Loss()

# ---------- metrics helpers (robust) ----------
def manual_psnr_safe(ref, im, data_range=1.0):
    if ref.shape != im.shape: return np.nan
    if not np.isfinite(ref).all() or not np.isfinite(im).all(): return np.nan
    if (ref.max() - ref.min()) < 1e-6: return np.nan
    mse = float(np.mean((ref - im)**2))
    if mse <= 0: return 100.0
    return 10.0 * np.log10((data_range**2) / mse)


def safe_ssim(ref, im, data_range=1.0):
    try:
        if (ref.max() - ref.min()) < 1e-6: return np.nan
        return float(sk_ssim(ref, im, data_range=data_range))
    except Exception:
        return np.nan


def compute_BE(pred_np, ann_curves):
    if not ann_curves: return None
    gy = np.abs(filters.sobel_v(pred_np))
    errs = []
    for (xcoords, ycoords) in ann_curves:
        xs = np.rint(np.array(xcoords)).astype(int)
        ys = np.rint(np.array(ycoords)).astype(int)
        valid = (xs>=0) & (xs < pred_np.shape[1])
        xs = xs[valid]; ys = ys[valid]
        if xs.size == 0: continue
        col_positions = []
        for col, ytrue in zip(xs, ys):
            lo = int(max(0, ytrue-12)); hi = int(min(pred_np.shape[0], ytrue+13))
            patch = gy[lo:hi, col]
            if patch.size == 0:
                col_positions.append(ytrue)
            else:
                rel = np.argmax(patch); col_positions.append(lo + rel)
        col_positions = np.array(col_positions)
        if col_positions.size>0:
            errs.append(np.mean(np.abs(col_positions - ys)))
    return float(np.mean(errs)) if errs else None

# ---------- model & optimizer ----------
model = SmallUNetFull(in_ch=1, base=16).to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=1, verbose=True)

# ---------- utilities: save visual examples ----------
def save_example(raw, tgt, pred, fname):
    fig, axs = plt.subplots(1,4, figsize=(12,3))
    axs[0].imshow(raw, cmap='gray'); axs[0].set_title('raw'); axs[0].axis('off')
    axs[1].imshow(tgt, cmap='gray'); axs[1].set_title('target'); axs[1].axis('off')
    axs[2].imshow(pred, cmap='gray'); axs[2].set_title('pred'); axs[2].axis('off')
    axs[3].imshow(np.abs(pred - tgt), cmap='hot'); axs[3].set_title('abs diff'); axs[3].axis('off')
    plt.tight_layout(); fig.savefig(str(fname), bbox_inches='tight'); plt.close(fig)

# ---------- training loop (minimal edits to use bw_map) ----------
best_be = float("inf")
history = {"train_loss":[], "val_patch_loss":[], "val_psnr":[], "val_ssim":[], "val_be":[]}

for epoch in range(1, EPOCHS+1):
    model.train()
    running = 0.0; seen = 0
    t0 = time.time()
    for xb, yb, slice_idx in train_loader:   # << CHANGED: unpack slice_idx
        xb = xb.to(DEVICE); yb = yb.to(DEVICE)
        # prepare boundary weight map for this slice (B=1 expected)
        sidx = int(slice_idx) if isinstance(slice_idx, (int, np.integer)) else int(slice_idx.item())
        bw = boundary_weight_maps[sidx]   # (H,W) numpy
        bw_tensor = torch.from_numpy(bw).to(DEVICE)[None, None, :, :].float()  # (1,1,H,W)
        pred = model(xb)
        # per-pixel L1 and edge diffs
        int_l1 = torch.abs(pred - yb)                          # (B,1,H,W)
        edge_l1 = torch.abs(sobel_edges_torch(pred) - sobel_edges_torch(yb))
        # weighted losses
        int_loss = (int_l1 * bw_tensor).mean()
        edge_loss_local = (edge_l1 * bw_tensor).mean()
        loss = int_loss + EDGE_LOSS_WEIGHT * edge_loss_local + BE_LOSS_WEIGHT * edge_loss_local
        opt.zero_grad(); loss.backward(); opt.step()
        running += loss.item() * xb.size(0); seen += xb.size(0)
    train_loss = running / (seen if seen>0 else 1)
    history["train_loss"].append(train_loss)
    t1 = time.time()

    # validation patch loss (full-image here)
    model.eval()
    vrun = 0.0; vseen = 0
    with torch.no_grad():
        for xb, yb, slice_idx in val_loader:   # << CHANGED: unpack slice_idx
            xb = xb.to(DEVICE); yb = yb.to(DEVICE)
            sidx = int(slice_idx) if isinstance(slice_idx, (int, np.integer)) else int(slice_idx.item())
            bw = boundary_weight_maps[sidx]
            bw_tensor = torch.from_numpy(bw).to(DEVICE)[None, None, :, :].float()
            pred = model(xb)
            int_l1 = torch.abs(pred - yb)
            edge_l1 = torch.abs(sobel_edges_torch(pred) - sobel_edges_torch(yb))
            int_loss = (int_l1 * bw_tensor).mean()
            edge_loss_local = (edge_l1 * bw_tensor).mean()
            vloss = (int_loss + EDGE_LOSS_WEIGHT * edge_loss_local).item()
            vrun += vloss * xb.size(0); vseen += xb.size(0)
    val_patch_loss = vrun / (vseen if vseen>0 else 1)
    history["val_patch_loss"].append(val_patch_loss)

    # full-slice eval on val_full_slices (unchanged)
    ps_list = []; ss_list = []; be_list = []
    per_slice = []
    with torch.no_grad():
        for s in val_full_slices:
            raw = stack[s].astype(np.float32)
            # find first pseudo patch that maps to this slice (if any)
            pos_idx = [i for i,si in enumerate(pseudo_indices) if int(si) == int(s)]
            tgt = None
            if len(pos_idx) > 0:
                tgt = pseudo_imgs[pos_idx[0]].astype(np.float32)
            raw01 = (raw - raw.min())/(raw.max()-raw.min()+EPS)
            x = torch.from_numpy(raw01[None,None,:,:]).float().to(DEVICE)
            pred = model(x).detach().cpu().numpy()[0,0]
            pred = np.clip(pred, 0.0, 1.0)

            ps = ss = np.nan
            if tgt is not None:
                tgt01 = (tgt - tgt.min())/(tgt.max()-tgt.min()+EPS)
                # if tgt is a different shape, center-pad/crop to pred shape
                if tgt01.shape != pred.shape:
                    ph, pw = tgt01.shape; H, W = pred.shape
                    if ph <= H and pw <= W:
                        canvas = np.zeros_like(pred); top = (H-ph)//2; left = (W-pw)//2
                        canvas[top:top+ph, left:left+pw] = tgt01; tgt01 = canvas
                    else:
                        top = (ph-H)//2; left = (pw-W)//2
                        tgt01 = tgt01[top:top+H, left:left+W]
                ps = manual_psnr_safe(tgt01, pred, data_range=1.0)
                ss = safe_ssim(tgt01, pred, data_range=1.0)
                if np.isfinite(ps): ps_list.append(ps)
                if np.isfinite(ss): ss_list.append(ss)

            ann_curves = annotations.get(str(s), None)
            be = compute_BE(pred, ann_curves)
            if be is not None: be_list.append(be)
            per_slice.append({"slice":int(s), "ps":ps, "ss":ss, "be":be})

    mean_ps = float(np.nanmean(ps_list)) if len(ps_list)>0 else None
    mean_ss = float(np.nanmean(ss_list)) if len(ss_list)>0 else None
    mean_be = float(np.mean(be_list)) if len(be_list)>0 else None

    history["val_psnr"].append(mean_ps); history["val_ssim"].append(mean_ss); history["val_be"].append(mean_be)

    # scheduler step (monitor val_patch_loss)
    scheduler.step(val_patch_loss)

    saved = False
    if mean_be is not None and mean_be < best_be:
        best_be = mean_be
        torch.save(model.state_dict(), CKPT_PATH)
        saved = True

    print(f"Epoch {epoch}/{EPOCHS}  train_loss={train_loss:.4f} val_patch_loss={val_patch_loss:.4f} "
          f"valid_ps_count={len(ps_list)} valid_ss_count={len(ss_list)} be_count={len(be_list)} mean_psnr={mean_ps} mean_ssim={mean_ss} mean_be={mean_be} saved={saved} time={(t1-t0):.1f}s")
    for d in per_slice[:6]:
        print("  slice",d["slice"], " ps=", d["ps"], " ss=", d["ss"], " be=", d["be"]) 

    # save a few visual examples for quick QA
    for i, s in enumerate(val_full_slices[:3]):
        raw = stack[s].astype(np.float32)
        raw01 = (raw - raw.min())/(raw.max()-raw.min()+EPS)
        pos_idx = [i for i,si in enumerate(pseudo_indices) if int(si) == int(s)]
        if pos_idx:
            tgt = pseudo_imgs[pos_idx[0]].astype(np.float32)
            tgt01 = (tgt - tgt.min())/(tgt.max()-tgt.min()+EPS)
            # compute pred again (cheap)
            x = torch.from_numpy(raw01[None,None,:,:]).float().to(DEVICE)
            pred = model(x).detach().cpu().numpy()[0,0]
            pred = np.clip(pred, 0.0, 1.0)
            # pad/crop tgt01 to pred shape if necessary (same logic as above)
            if tgt01.shape != pred.shape:
                ph, pw = tgt01.shape; H, W = pred.shape
                if ph <= H and pw <= W:
                    canvas = np.zeros_like(pred); top = (H-ph)//2; left = (W-pw)//2
                    canvas[top:top+ph, left:left+pw] = tgt01; tgt01 = canvas
                else:
                    top = (ph-H)//2; left = (pw-W)//2
                    tgt01 = tgt01[top:top+H, left:left+W]
            save_example(raw01, tgt01, pred, EXAMPLES_DIR / f"epoch{epoch}_slice{s}.png")

# Save history
with open(OUT_DIR / "unet_smoke_fullimage_groupnorm_history.json", "w") as f:
    json.dump(history, f, indent=2)

print("Finished. Best-checkpoint:", CKPT_PATH)


Device: mps
Raw stack: Intensity.tif shape: (339, 806, 500)
Loaded 260 pseudo-clean targets from pseudo_clean_improved.npz
Train samples: 221, Val samples: 39, Val full-slices: 39
Epoch 1/20  train_loss=0.1096 val_patch_loss=0.0937 valid_ps_count=39 valid_ss_count=39 be_count=39 mean_psnr=26.459842123558047 mean_ssim=0.6603095807046554 mean_be=7.143664939557795 saved=True time=54.4s
  slice 40  ps= 24.52741113218312  ss= 0.6538476426477776  be= 7.512499999999999
  slice 54  ps= 26.598463765645345  ss= 0.6850156281443891  be= 7.5275
  slice 93  ps= 26.51845889663395  ss= 0.6624479174816245  be= 7.69625
  slice 99  ps= 24.48760540540194  ss= 0.6556070858222166  be= 7.57375
  slice 104  ps= 27.37059081131683  ss= 0.6670788705609234  be= 7.50375
  slice 112  ps= 26.904229859215015  ss= 0.6997827881238587  be= 7.100088383838384
Epoch 2/20  train_loss=0.0884 val_patch_loss=0.0893 valid_ps_count=39 valid_ss_count=39 be_count=39 mean_psnr=26.611076968513625 mean_ssim=0.6625670400323805 mean_be