In [None]:
# ============================================================
# Epoxy_CyNif-AF + Background-Anchor + GenFill(exp)
# Neu: Settings-Zeile laden (Apply settings)
# (enthält: robuste ZIP-Speicherung + clean-fix für local/ring/hybrid)
# ============================================================

import os, re, json, zipfile, traceback, hashlib, datetime, tempfile, warnings, sys
from typing import Optional
warnings.filterwarnings('ignore')

import numpy as np
from scipy import ndimage as ndi
from skimage.morphology import disk
from skimage import morphology as morph
import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from PIL import Image
from pathlib import Path
import csv
from xml.etree import ElementTree as ET
try:
    import tifffile as tiff
except Exception:
    tiff = None

# ---------- OME-Metadaten Helfer ----------
def parse_ome_channels(ome_xml: str):
    """Extract channel names and physical pixel sizes from an OME-XML string.

    The XML namespace is taken from the document's own root tag (e.g.
    ``{http://www.openmicroscopy.org/Schemas/OME/2016-06}OME``), so any OME
    schema version is handled and no module-level namespace constant is
    needed. Documents without a namespace are supported as well.

    Args:
        ome_xml: OME-XML document text (e.g. ``TiffFile.ome_metadata``).
            Empty or ``None`` yields empty results.

    Returns:
        ``(channels, px_sizes)``: ``channels`` is a list of ``{'Name': ...}``
        dicts for the ``Channel`` children of the first ``Pixels`` element
        (a missing name becomes ``'Channel-<n>'``, 1-based); ``px_sizes`` maps
        the ``PhysicalSize*`` attribute names to their raw string values (or
        ``None``). Both are empty when no ``Pixels`` element is found.

    Raises:
        xml.etree.ElementTree.ParseError: if ``ome_xml`` is not well-formed.
    """
    if not ome_xml:
        return [], {}
    root = ET.fromstring(ome_xml)
    # Reuse the namespace the document declares instead of a hard-coded one.
    ns_match = re.match(r'\{([^}]*)\}', root.tag)
    if ns_match:
        namespaces = {'ome': ns_match.group(1)}
        pixels_path, channel_path = './/ome:Pixels', 'ome:Channel'
    else:
        namespaces = None
        pixels_path, channel_path = './/Pixels', 'Channel'
    pixels = root.find(pixels_path, namespaces)
    if pixels is None:
        return [], {}
    channels = [
        {'Name': ch.get('Name', f'Channel-{idx}')}
        for idx, ch in enumerate(pixels.findall(channel_path, namespaces), start=1)
    ]
    px_sizes = {
        key: pixels.get(key)
        for key in (
            'PhysicalSizeX', 'PhysicalSizeXUnit',
            'PhysicalSizeY', 'PhysicalSizeYUnit',
            'PhysicalSizeZ', 'PhysicalSizeZUnit',
        )
    }
    return channels, px_sizes


def sanitize_pixel_sizes(px_sizes):
    """Turn raw OME pixel-size attributes into typed metadata.

    Entries whose key ends in ``'Unit'`` are copied verbatim; every other
    entry is converted with ``float`` and dropped if that conversion fails.
    ``None`` values are skipped, and a ``None``/empty input gives ``{}``.
    """
    cleaned = {}
    if not px_sizes:
        return cleaned
    for name, raw in px_sizes.items():
        if raw is None:
            continue
        if name.endswith('Unit'):
            cleaned[name] = raw
            continue
        try:
            cleaned[name] = float(raw)
        except (TypeError, ValueError):
            continue
    return cleaned



# Default settings for ACE-style local equalisation. channel_harmonize(..., "ace")
# forwards radii/alpha/iterations/clip/radius/preserve_background/max_gain to
# ace_local_equalize(); 'engine' and 'post_enabled' are not read in this section.
ACE_PARAMS = {
    'radii': (5, 15, 45),      # box-filter radii in px
    'alpha': 6.0,              # contrast gain (ace_local_equalize raises it to >= 1)
    'iterations': 1,
    'clip': (0.2, 99.8),       # percentiles used to restore the intensity range
    'radius': None,            # if set, overrides 'radii'
    'preserve_background': True,
    'max_gain': 3.0,           # per-pixel output/input ratio limit
    'engine': 'internal',
    'post_enabled': True,
}


def ace_local_equalize(
    image: np.ndarray,
    *,
    radii=(5, 15, 45),
    alpha: float = 6.0,
    iterations: int = 1,
    clip=(0.2, 99.8),
    radius=None,
    preserve_background: bool = True,
    max_gain: Optional[float] = 3.0
) -> np.ndarray:
    """ACE-inspired local equalisation with safeguards for fluorescence data.

    For each radius, the deviation from a box-filtered local mean is scaled so
    that its 98th percentile maps to ``alpha``, then saturated to [-1, 1].
    The responses are averaged over all radii and added back to the image.

    Args:
        image: 2-D intensity image (any array-like); it is never modified.
        radii: Box-filter radii in pixels, used when ``radius`` is not set.
        alpha: Contrast gain; values below 1 are raised to 1.
        iterations: Number of enhancement passes (at least 1).
        clip: ``(low, high)`` percentiles. The enhanced image is linearly
            remapped so that these percentiles match those of the input.
            A falsy value disables the remap.
        radius: If set, replaces ``radii`` with ``{r // 4, r // 2, r}``
            (each at least 1).
        preserve_background: Force pixels that are ``<= 0`` in the input to 0.
        max_gain: If > 1, limit the output/input ratio of positive input
            pixels to ``[1 / max_gain, max_gain]``.

    Returns:
        A new float32 array with the input's shape, clipped to be non-negative.
    """
    # ndarray view of the input (no copy for ndarrays) so that the mask
    # comparisons below also work for list/tuple input.
    src = np.asarray(image)
    # Private float32 working copy; the caller's data is never touched.
    work = src.astype(np.float32, copy=True)
    if work.size == 0:
        return work

    use_clip = bool(clip)
    if use_clip:
        lo_ref, hi_ref = np.percentile(work, clip)
        if hi_ref <= lo_ref + 1e-6:
            hi_ref = lo_ref + 1.0

    if radius:
        base_r = max(1, int(radius))
        radii_px = sorted({max(1, base_r // 4), max(1, base_r // 2), base_r})
    else:
        radii_px = tuple(int(max(1, r)) for r in (radii or (5, 15, 45)))

    alpha = float(max(1.0, alpha))
    iterations = max(1, int(iterations))

    for _ in range(iterations):
        accum = np.zeros_like(work, dtype=np.float32)
        for r in radii_px:
            local_mean = ndi.uniform_filter(work, size=2 * r + 1, mode='reflect')
            diff = work - local_mean
            # Normalise so the 98th percentile of |diff| maps to alpha, then
            # saturate; this bounds each scale's contribution to [-1, 1].
            scale = alpha / (np.percentile(np.abs(diff), 98.0) + 1e-6)
            accum += np.clip(diff * scale, -1.0, 1.0)
        work = work + accum / float(len(radii_px))

    if use_clip:
        # Map the enhanced percentile range back onto the input's range.
        lo_curr, hi_curr = np.percentile(work, clip)
        rng_curr = float(max(hi_curr - lo_curr, 1e-6))
        rng_ref = float(max(hi_ref - lo_ref, 1e-6))
        work = (work - lo_curr) * (rng_ref / rng_curr) + lo_ref

    if preserve_background:
        work[src <= 0] = 0.0

    if max_gain and max_gain > 1.0:
        positive = src > 0
        if np.any(positive):
            base = np.maximum(src[positive], 1e-6)
            ratio = work[positive] / base
            np.clip(ratio, 1.0 / max_gain, max_gain, out=ratio)
            work[positive] = ratio * base

    np.clip(work, 0.0, None, out=work)
    # The percentile rescale may promote to float64 (NumPy 2 scalar rules);
    # keep the documented float32 result.
    return work.astype(np.float32, copy=False)


# =============================================================================
# 5.2b UniFORM HISTOGRAM-BASED NORMALIZATION (Dynamic background aware)
# =============================================================================

def _histnorm_sample_values(image, step=4, max_samples=5_000_000):
    """Return a flat float64 sample of a 2-D image's pixel intensities.

    The image is first subsampled on a regular ``step`` grid in both axes; if
    more than ``max_samples`` values remain, they are thinned further with a
    constant stride.
    """
    grid = image[::step, ::step]
    values = grid.astype(np.float64, copy=False).ravel()
    if values.size <= max_samples:
        return values
    stride = int(max(1, values.size // max_samples))
    return values[::stride]


def _histnorm_trimmed_median(values, drop_upper=0.3):
    """Median of ``values`` after discarding the largest ``drop_upper`` fraction.

    At least one value is always kept. This makes background references
    robust against a few unusually bright channels.

    The input is never modified. ``np.sort`` returns a sorted copy, which
    matters because callers may pass an array they keep indexing by channel
    afterwards. An in-place ``ndarray.sort`` would reorder the caller's
    float64 array whenever ``np.asarray`` hands it back unchanged.

    Args:
        values: 1-D sequence of numbers.
        drop_upper: Fraction (0..1) of the highest values to ignore.

    Returns:
        The trimmed median as ``float``; ``0.0`` for empty input.
    """
    arr = np.sort(np.asarray(values, dtype=np.float64))
    if arr.size == 0:
        return 0.0
    keep = int(np.ceil(arr.size * (1.0 - float(drop_upper))))
    keep = int(np.clip(keep, 1, arr.size))
    return float(np.median(arr[:keep]))


def _histnorm_background_peak(values, bins=512, focus_fraction=0.45):
    """Estimate the background level as the dominant low-intensity histogram mode.

    A ``bins``-bin histogram over ``[0, max(values)]`` is smoothed with a
    Gaussian (sigma = 1.25 bins). The peak is searched only in the lowest
    ``focus_fraction`` of the bins (at least 4 bins). Returns the centre of
    the peak bin, or ``0.0`` for empty input, a non-positive maximum, or an
    empty histogram.
    """
    if values.size == 0:
        return 0.0
    top = float(values.max())
    if top <= 0:
        return 0.0
    counts, edges = np.histogram(values, bins=bins, range=(0.0, top))
    if counts.sum() == 0:
        return 0.0
    smoothed = ndi.gaussian_filter1d(counts.astype(np.float64), sigma=1.25, mode='nearest')
    search = max(4, int(len(smoothed) * float(focus_fraction)))
    peak = int(np.argmax(smoothed[:search]))
    return float(0.5 * (edges[peak] + edges[peak + 1]))


def _histnorm_is_reference_channel(name):
    """Return True if a channel may contribute to the background reference.

    Unnamed channels count as references. A channel is excluded when its
    lower-cased name contains an autofluorescence/blank/background/unmixing
    style token.
    """
    lowered = (name or '').lower()
    if lowered == '':
        return True
    for token in ('autoflu', 'blank', ' none', 'none(', 'none_', 'background',
                  'unmix', 'donor', 'spill', 'saibr', 'artifact'):
        if token in lowered:
            return False
    return True


def _histnorm_is_af_channel(name):
    """Return True if the channel name looks like an autofluorescence channel.

    Case-insensitive match on 'autoflu', a standalone ' af ' token, a leading
    'af ' or trailing ' af', or the substring '(af'.

    NOTE(review): '(af' also matches fluorophore labels such as '(AF488)';
    confirm that this is intended.
    """
    lowered = (name or '').lower()
    if 'autoflu' in lowered or ' af ' in lowered or '(af' in lowered:
        return True
    return lowered.startswith('af ') or lowered.endswith(' af')


def _histnorm_cycle_groups(channel_count, cycle_json_path):
    """Group channel indices by acquisition cycle.

    Reads an optional JSON channel map with:

    * ``offsets``: cycle -> first channel index;
    * ``counts_selected`` or ``counts``: cycle -> number of channels;
    * ``channel_names``.

    Skipping and truncation rules:

    * cycles with an unparsable or negative offset are skipped;
    * cycles with a missing, unparsable or non-positive count are skipped;
    * ranges are truncated at ``channel_count``.

    Args:
        channel_count: Number of channels in the stack.
        cycle_json_path: Path to the JSON file, or a falsy value to skip it.
            A path that does not exist is ignored.

    Returns:
        ``(groups, idx_to_cycle, names_from_json)`` where:

        * ``groups`` maps cycle id (str) to a list of channel indices;
        * ``idx_to_cycle`` maps channel index to cycle id;
        * ``names_from_json`` is the JSON ``channel_names`` entry (or ``None``).

        If no cycle could be built, all channels form the single cycle ``'1'``.
    """
    cycle_path = Path(cycle_json_path) if cycle_json_path else None
    groups = {}
    idx_to_cycle = {}
    names_from_json = None
    if cycle_path and cycle_path.exists():
        with open(cycle_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        names_from_json = data.get('channel_names')
        offsets = data.get('offsets') or {}
        counts = data.get('counts_selected') or data.get('counts') or {}
        for key, start in offsets.items():
            try:
                start_idx = int(start)
            except (TypeError, ValueError):
                continue
            if start_idx < 0:
                # A negative offset would silently wrap to the last channels.
                continue
            count_val = counts.get(str(key)) or counts.get(key)
            if not count_val:
                # Integer-keyed lookup only for numeric cycle ids; int('cycA')
                # must not abort the whole grouping.
                try:
                    count_val = counts.get(int(key))
                except (TypeError, ValueError):
                    count_val = None
            try:
                count_val = int(count_val)
            except (TypeError, ValueError):
                count_val = 0
            if count_val <= 0:
                continue
            idxs = list(range(start_idx, min(channel_count, start_idx + count_val)))
            if not idxs:
                continue
            cycle_id = str(key)
            groups[cycle_id] = idxs
            for idx in idxs:
                idx_to_cycle[idx] = cycle_id
    if not groups:
        groups['1'] = list(range(channel_count))
        idx_to_cycle = {idx: '1' for idx in range(channel_count)}
    return groups, idx_to_cycle, names_from_json


def _histnorm_channel_statistics(stack, channel_names, sample_step, low_focus_percentile, high_clip_percentile):
    """Compute per-channel background and percentile statistics.

    For every channel of ``stack`` (indexed along axis 0):

    * the image is subsampled;
    * values above the ``high_clip_percentile`` are discarded;
    * the background is estimated from the histogram peak of the values at or
      below the ``low_focus_percentile``, falling back to all kept values when
      fewer than 128 remain.

    Returns:
        One dict per channel with keys ``index``, ``name`` (from
        ``channel_names`` or ``'Ch<idx>'``), ``bg``, and the percentiles
        ``p001``, ``p1``, ``p5``, ``p50``, ``p95``, ``p995`` of the kept values.
    """
    result = []
    for channel_idx in range(stack.shape[0]):
        plane = stack[channel_idx]
        sample = _histnorm_sample_values(plane, step=sample_step)
        if sample.size == 0:
            sample = plane.astype(np.float64, copy=False).ravel()
        has_values = sample.size > 0
        peak_value = float(sample.max()) if has_values else 0.0

        # Upper cut-off; fall back to the maximum (or 1.0) if degenerate.
        upper = np.percentile(sample, high_clip_percentile) if has_values else 0.0
        if not np.isfinite(upper) or upper <= 0:
            upper = peak_value
        if upper <= 0:
            upper = peak_value if peak_value > 0 else 1.0

        kept = sample[sample <= upper]
        if kept.size == 0:
            kept = sample

        # Background search restricted to the dimmest part of the distribution.
        focus_cut = np.percentile(kept, low_focus_percentile) if kept.size else 0.0
        focus_vals = kept[kept <= focus_cut] if kept.size else kept
        if focus_vals.size < 128:
            focus_vals = kept
        background = _histnorm_background_peak(focus_vals, bins=512)

        if kept.size:
            p001, p1, p5, p50, p95, p995 = np.percentile(kept, [0.1, 1.0, 5.0, 50.0, 95.0, 99.5])
        else:
            p001 = p1 = p5 = p50 = p95 = p995 = 0.0

        if channel_idx < len(channel_names):
            label = channel_names[channel_idx]
        else:
            label = f'Ch{channel_idx:02d}'
        result.append({
            'index': channel_idx,
            'name': label,
            'bg': float(background),
            'p001': float(p001),
            'p1': float(p1),
            'p5': float(p5),
            'p50': float(p50),
            'p95': float(p95),
            'p995': float(p995),
        })
    return result


def run_uniform_histogram_normalization_dynamic(
    input_path=None,
    output_path=None,
    channel_names=None,
    cycle_json_path='split_cycles_out_fixed/channel_map_applied.json',
    sample_step=4,
    low_focus_percentile=35.0,
    high_clip_percentile=99.8,
    drop_upper_fraction=0.35,
    align_spread=True,
    min_scale=0.9,
    max_scale=1.1,
    af_extra_pull=0.2,
    max_shift=4000,
    output_suffix='_Histo_Norm',
    log_preview=5
):
    """Dynamic UniFORM-style histogram normalization with per-cycle references.

    Loads a multi-channel OME-TIFF and estimates each channel's background as
    a low-intensity histogram peak. Every channel is then shifted so that its
    background matches the trimmed-median background of its acquisition
    cycle. Optionally the p5-p95 spread is also rescaled towards the cycle's
    reference spread. The result is written as a zlib-compressed BigTIFF OME
    file.

    Args:
        input_path: Source OME-TIFF; defaults to the global ``IMG_PATH``.
        output_path: Destination file. Defaults to
            ``BASE_EXPORT / "AF_removal" / "fused_decon_AF_cleaned.ome.tif"``.
        channel_names: Explicit channel names. They take precedence over OME
            and JSON names; missing trailing names become ``Channel_<idx>``.
        cycle_json_path: Optional channel-map JSON used to group channels
            into cycles (see ``_histnorm_cycle_groups``).
        sample_step: Pixel subsampling step for the statistics.
        low_focus_percentile: Percentile at or below which values are used
            for the background-peak search.
        high_clip_percentile: Percentile above which values are ignored.
        drop_upper_fraction: Fraction of the highest backgrounds dropped when
            computing the reference levels.
        align_spread: Also rescale around the reference level so the channel's
            p5-p95 spread approaches the cycle reference.
        min_scale, max_scale: Bounds for that spread scale factor.
        af_extra_pull: Extra factor applied to negative shifts of
            autofluorescence-like channels.
        max_shift: Absolute limit for the applied background shift.
        output_suffix: Not used in this function body.
        log_preview: Number of channels printed in the summary (at least 1).

    Returns:
        ``(normalized, stats, output_path)``:

        * ``normalized``: the uint16 normalized stack;
        * ``stats``: per-channel statistic dicts, extended with ``cycle``,
          ``applied_shift``, ``applied_scale`` and ``is_af``;
        * ``output_path``: the written ``Path``.

    Raises:
        RuntimeError: if ``tifffile`` is not installed.

    Side effects:
        Sets the globals ``hist_norm_stack``, ``hist_norm_stats``,
        ``hist_norm_output`` and ``IMG_PATH`` (to the output file). Also
        resets ``STACK``/``STACK_SRC``; per the printed message, the AF/ACE
        steps then use the normalized stack.
    """
    if input_path is None:
        input_path = IMG_PATH
    input_path = Path(input_path)
    if output_path is not None:
        output_path = Path(output_path)
    if tiff is None:
        raise RuntimeError("tifffile not available â€“ required for OME-TIFF normalization")

    print("=" * 80)
    print("[5.2b] UniFORM HISTOGRAM-BASED NORMALIZATION (dynamic)")
    print("=" * 80)
    print(f"[INFO] Input: {input_path}")

    # Read the full stack and the raw OME-XML header (None if absent).
    with tiff.TiffFile(str(input_path)) as tf:
        fused_stack = tf.asarray()
        ome_xml = getattr(tf, 'ome_metadata', None)

    # Channel-first layout: add a channel axis to 2-D images, squeeze singleton
    # axes of higher-dimensional ones.
    # NOTE(review): squeeze does not guarantee 3-D output if a non-singleton
    # Z/T axis exists; later code assumes (C, Y, X) — confirm inputs.
    if fused_stack.ndim == 2:
        fused_stack = fused_stack[None, ...]
    elif fused_stack.ndim > 3:
        fused_stack = np.squeeze(fused_stack)
    # NOTE(review): non-uint16 input (e.g. float) is cast directly, without rescaling.
    fused_stack = fused_stack.astype(np.uint16, copy=False)

    # Channel names and pixel sizes from the OME header; parsing errors are
    # reported but not fatal.
    ome_channel_names = []
    px_meta = {}
    if 'parse_ome_channels' in globals():
        try:
            channels, px_sizes = parse_ome_channels(ome_xml) if ome_xml else ([], {})
        except Exception as exc:
            print(f"[WARN] OME metadata parsing failed: {exc}")
            channels, px_sizes = [], {}
        if channels:
            ome_channel_names = [ch.get('Name', f'Channel_{idx+1:02d}') for idx, ch in enumerate(channels)]
        if px_sizes:
            px_meta = sanitize_pixel_sizes(px_sizes)

    groups, idx_to_cycle, names_from_json = _histnorm_cycle_groups(fused_stack.shape[0], cycle_json_path)

    # Name priority: explicit argument > OME metadata > channel-map JSON.
    if channel_names is not None:
        resolved_names = list(channel_names)
    elif ome_channel_names:
        resolved_names = ome_channel_names
    elif names_from_json:
        resolved_names = names_from_json
    else:
        resolved_names = []

    # Pad missing trailing names with 0-based placeholders.
    if len(resolved_names) < fused_stack.shape[0]:
        resolved_names = list(resolved_names)
        for idx in range(len(resolved_names), fused_stack.shape[0]):
            resolved_names.append(f'Channel_{idx:02d}')

    stats = _histnorm_channel_statistics(
        fused_stack,
        resolved_names,
        sample_step=sample_step,
        low_focus_percentile=low_focus_percentile,
        high_clip_percentile=high_clip_percentile
    )

    # Channels not covered by any cycle get 'NA' and later use the global reference.
    for stat in stats:
        stat['cycle'] = idx_to_cycle.get(stat['index'], 'NA')

    backgrounds = np.array([stat['bg'] for stat in stats], dtype=np.float64)
    ref_mask = np.array([_histnorm_is_reference_channel(stat['name']) for stat in stats], dtype=bool)

    # Global reference: trimmed median over reference-eligible channels (all
    # channels if none qualify).
    # NOTE(review): in the fallback case `backgrounds` itself is passed; this
    # relies on _histnorm_trimmed_median not sorting its argument in place,
    # because `backgrounds` is indexed per channel below.
    global_candidates = backgrounds[ref_mask] if ref_mask.any() else backgrounds
    global_ref = _histnorm_trimmed_median(global_candidates, drop_upper=drop_upper_fraction)

    # Per-cycle background reference and reference p5-p95 spread, preferring
    # reference-eligible channels of that cycle.
    cycle_refs = {}
    cycle_spreads = {}
    for cycle_id, idxs in groups.items():
        idxs = [idx for idx in idxs if idx < len(stats)]
        if not idxs:
            continue
        idxs_arr = np.array(idxs, dtype=int)
        cycle_candidates = backgrounds[idxs_arr][ref_mask[idxs_arr]] if ref_mask[idxs_arr].any() else backgrounds[idxs_arr]
        if cycle_candidates.size:
            cycle_refs[cycle_id] = _histnorm_trimmed_median(cycle_candidates, drop_upper=drop_upper_fraction)
        else:
            cycle_refs[cycle_id] = global_ref
        spread_vals = np.array([max(1.0, stats[i]['p95'] - stats[i]['p5']) for i in idxs], dtype=np.float64)
        if ref_mask[idxs_arr].any():
            spread_vals = np.array([max(1.0, stats[i]['p95'] - stats[i]['p5']) for i in idxs if ref_mask[i]], dtype=np.float64)
        cycle_spreads[cycle_id] = _histnorm_trimmed_median(spread_vals, drop_upper=0.2) if spread_vals.size else 1.0

    normalized = np.empty_like(fused_stack, dtype=np.uint16)
    shifts = []
    scales = []
    af_flags = []

    for stat in stats:
        idx = stat['index']
        cycle_id = stat['cycle']
        name = stat['name']
        ref = cycle_refs.get(cycle_id, global_ref)
        af_flag = _histnorm_is_af_channel(name)
        # Additive shift that moves this channel's background onto the reference.
        shift = float(ref - stat['bg'])
        # AF-like channels are pulled down harder when they sit above the reference.
        if af_flag and shift < 0:
            shift *= (1.0 + float(af_extra_pull))
        shift = float(np.clip(shift, -float(max_shift), float(max_shift)))
        spread_val = max(1.0, stat['p95'] - stat['p5'])
        scale = 1.0
        if align_spread:
            ref_spread = cycle_spreads.get(cycle_id, spread_val)
            if spread_val > 0 and ref_spread > 0:
                scale = float(np.clip(ref_spread / max(spread_val, 1.0), float(min_scale), float(max_scale)))
        channel = fused_stack[idx].astype(np.float64, copy=False)
        shifted = channel + shift
        # Scale around the reference level so the background stays fixed.
        if align_spread:
            shifted = (shifted - ref) * scale + ref
        normalized[idx] = np.clip(shifted, 0, 65535).astype(np.uint16)
        stat['applied_shift'] = shift
        stat['applied_scale'] = scale
        stat['is_af'] = af_flag
        shifts.append(shift)
        scales.append(scale)
        af_flags.append(af_flag)

    shifts = np.array(shifts, dtype=np.float64)
    scales = np.array(scales, dtype=np.float64)

    print()
    print(f"[INFO] Global background reference: {global_ref:.1f}")
    # Numeric cycle ids sort numerically and before non-numeric ones.
    def _cycle_sort_key(val):
        try:
            return (0, int(val))
        except (TypeError, ValueError):
            return (1, str(val))
    for cycle_id in sorted(cycle_refs.keys(), key=_cycle_sort_key):
        print(f"  Cycle {cycle_id}: ref={cycle_refs[cycle_id]:.1f}, spread={cycle_spreads.get(cycle_id, 0.0):.1f}")

    print()
    preview = stats[:max(1, int(log_preview))]
    for stat in preview:
        print(
            f"  C{stat['index']:02d} ({stat['name']} | cycle {stat['cycle']}): "
            f"bg {stat['bg']:.1f} -> shift {stat['applied_shift']:+.1f}, "
            f"scale {stat['applied_scale']:.3f}{' [AF]' if stat['is_af'] else ''}"
        )

    if output_path is None:
        # Output goes to the dedicated AF_removal/ folder (pipeline convention).
        af_removal_dir = BASE_EXPORT / "AF_removal"
        af_removal_dir.mkdir(exist_ok=True, parents=True)

        # Fixed file name (no suffix variants).
        output_path = af_removal_dir / "fused_decon_AF_cleaned.ome.tif"
    else:
        # Already converted above; kept as a harmless re-conversion.
        output_path = Path(output_path)

    output_path.parent.mkdir(parents=True, exist_ok=True)

    # tifffile OME metadata: channel names, fluorophores parsed from a trailing
    # '(...)' in the channel name, and any known physical pixel sizes.
    metadata = {'axes': 'CYX'}
    if resolved_names:
        names = resolved_names[:normalized.shape[0]]
        metadata.setdefault('Channel', {})['Name'] = names
        fluor_list = []
        for nm in names:
            if '(' in nm and nm.strip().endswith(')'):
                base, flav = nm.rsplit('(', 1)
                fluor_list.append(flav.strip(') '))
            else:
                fluor_list.append('Unknown')
        metadata['Channel']['Fluor'] = fluor_list
    if px_meta:
        pixels_meta = {}
        for key in ('PhysicalSizeX', 'PhysicalSizeY', 'PhysicalSizeZ', 'PhysicalSizeXUnit', 'PhysicalSizeYUnit', 'PhysicalSizeZUnit'):
            if key in px_meta and px_meta[key] is not None:
                pixels_meta[key] = px_meta[key]
        if pixels_meta:
            metadata['Pixels'] = pixels_meta

    print()
    print("[SAVE] Writing OME-TIFF with histogram-normalized intensitiesâ€¦")
    tiff.imwrite(
        str(output_path),
        normalized,
        ome=True,
        bigtiff=True,
        compression='zlib',
        photometric='minisblack',
        metadata=metadata
    )

    size_gb = output_path.stat().st_size / (1024 ** 3)
    print()
    print("=" * 80)
    print("SUCCESS: Histogram normalization complete")
    print("=" * 80)
    print(f"Output file:  {output_path}")
    print(f"File size:    {size_gb:.2f} GB")
    print(f"Channels:     {normalized.shape[0]}")
    print(f"Image size:   {normalized.shape[1]} x {normalized.shape[2]} px")
    print(f"Shift range:  {shifts.min():+.1f} â†’ {shifts.max():+.1f} (mean {shifts.mean():+.1f})")
    if align_spread and scales.size:
        print(f"Scale range:  {scales.min():.3f} â†’ {scales.max():.3f} (mean {scales.mean():.3f})")
    if any(af_flags):
        af_shifts = [stat['applied_shift'] for stat in stats if stat.get('is_af')]
        if af_shifts:
            print(f"AF channels:  mean shift {np.mean(af_shifts):+.1f}")
    print()
    print("NEXT STEPS:")
    print("  1. Ã–ffnen Sie die Datei in Napari oder QuPath und prÃ¼fen Sie die Hintergrundlevel")
    print("  2. Vergleichen Sie AF-KanÃ¤le vor/nach der Normalisierung")

    # Publish results and point the rest of the notebook at the new file;
    # STACK/STACK_SRC are cleared alongside the IMG_PATH switch.
    globals()['hist_norm_stack'] = normalized
    globals()['hist_norm_stats'] = stats
    globals()['hist_norm_output'] = output_path
    globals()['IMG_PATH'] = str(output_path)
    globals()['STACK'] = None
    globals()['STACK_SRC'] = ''
    print(f'[CONFIG] IMG_PATH aktualisiert -> {output_path}')
    print('          (AF/ACE nutzen jetzt den histogramm-normalisierten Stack)')
    return normalized, stats, output_path





In [None]:
# ---------- Paths / defaults ----------
# ============================================================================
# >>> PIPELINE INTEGRATION: reads the input automatically from the spillover/ folder <<<
# ============================================================================
SAMPLE_ID = "sample_004"
WORKSPACE_EXPORT_ROOT = Path(r"C:\Users\researcher\data\Epoxy_CyNif\Epoxy_CyNif\data\export")
BASE_EXPORT = WORKSPACE_EXPORT_ROOT / SAMPLE_ID

# STRICT: only the spillover/ directory is accepted (no fallback!).
spillover_dir = BASE_EXPORT / "spillover"
spillover_file = spillover_dir / "fused_decon_spillover_corrected.ome.tif"

if not spillover_file.exists():
    # Message text was previously mis-encoded (UTF-8 read as cp1252).
    raise FileNotFoundError(
        f"❌ FEHLER: Spillover-korrigierter Stack nicht gefunden:\n"
        f"  Erwartet: {spillover_file}\n"
        f"  Bitte Part 2 (Spillover-Pipeline) ausführen."
    )

IMG_PATH = str(spillover_file)
print(f"[CONFIG] AF-Removal nutzt Spillover-Input: {IMG_PATH}")

# Publish the configuration explicitly as notebook globals.
globals()["WORKSPACE_EXPORT_ROOT"] = WORKSPACE_EXPORT_ROOT
globals()["BASE_EXPORT"] = BASE_EXPORT
globals()["SAMPLE_ID"] = SAMPLE_ID
globals()["IMG_PATH"] = IMG_PATH

# Writable fallback directory for preview output (see _safe_root).
SAVE_ROOT_FALLBACK = os.path.join(tempfile.gettempdir(), "MP_Previews")
# Cached stack and the path it was loaded from; empty until first load.
STACK = None
STACK_SRC = ""

# ---------- Utils ----------
def _ensure_dir(p):
    """Create directory ``p`` (including parents) if missing; return ``p`` unchanged."""
    os.makedirs(p, exist_ok=True)
    return p

def _basename_no_ome(path):
    """Return the file name of ``path`` without a trailing '.ome.tif' (case-insensitive)."""
    name = os.path.basename(path)
    return re.sub(r'\.ome\.tif$', '', name, flags=re.I)

def _safe_root(path_hint):
    """Return ``path_hint`` if it is a writable directory, else the temp fallback.

    Writability is probed by creating the directory, then writing and
    deleting a marker file ``_t.tmp``. Any failure selects
    ``SAVE_ROOT_FALLBACK``, which is created on demand.
    """
    try:
        _ensure_dir(path_hint)
        probe = os.path.join(path_hint, "_t.tmp")
        with open(probe, "wb") as fh:
            fh.write(b"ok")
        os.remove(probe)
    except Exception:
        _ensure_dir(SAVE_ROOT_FALLBACK)
        return SAVE_ROOT_FALLBACK
    return path_hint

def p_stretch(x, p1=1.0, p2=99.5):
    """Percentile contrast stretch of ``x`` into [0, 1].

    The ``p1`` percentile maps to 0 and the ``p2`` percentile to 1, with
    clipping outside that range. The span is kept at least 1e-6 wide.
    """
    lo, hi = np.percentile(x, [p1, p2])
    span = max(hi, lo + 1e-6) - lo
    return np.clip((x - lo) / span, 0, 1)

def write_png(fp, arr01):
    """Save an image with values in [0, 1] as an 8-bit PNG at ``fp``.

    Values are clipped to [0, 1], multiplied by 255 and truncated to uint8.
    """
    scaled = np.clip(arr01, 0, 1) * 255
    Image.fromarray(scaled.astype(np.uint8)).save(fp, format="PNG")

def _load_stack(path):
    """Load an image stack from ``path`` as a float32, channel-first array.

    Uses ``tifffile`` when it is available, otherwise ``imageio``. Layout
    handling:

    * a 2-D image gets a leading channel axis;
    * for 4-D data, axis 1 is moved to the front when it has at most 16
      entries (presumably the channel axis — confirm for new data);
    * in the 4-D case, only index 0 of the resulting second axis is kept.

    Other dimensionalities are returned unchanged.
    """
    assert os.path.exists(path), f"Datei nicht gefunden: {path}"
    if tiff is None:
        import imageio.v2 as iio
        data = iio.imread(path)
    else:
        data = tiff.imread(path)
    if data.ndim == 2:
        data = data[np.newaxis, ...]
    elif data.ndim == 4:
        if data.shape[1] <= 16:
            data = np.moveaxis(data, 1, 0)
        data = data[:, 0, :, :]
    return data.astype(np.float32)

# ---------- CH-Harmonisierung ----------
def channel_harmonize(stack, mode="p99"):
    """Bring all channels of a channel-first stack to a comparable intensity scale.

    Modes (case-insensitive):

    * ``"off"`` / ``"none"``: return an unchanged float32 copy.
    * ``"p99"``: scale each channel so its 99th percentile equals the median
      99th percentile across channels.
    * ``"ace"``: run ``ace_local_equalize`` on each channel with ``ACE_PARAMS``.
    * anything else: like ``"p99"``, but using the inter-quartile range.

    Per-channel scale denominators are floored at 1e-6. The input is not modified.
    """
    key = str(mode).lower()
    data = stack.astype(np.float32, copy=True)
    if key in ("off", "none"):
        return data
    n_ch = data.shape[0]
    if key == "ace":
        cfg = ACE_PARAMS
        result = np.empty_like(data, dtype=np.float32)
        for ci in range(n_ch):
            result[ci] = ace_local_equalize(
                data[ci],
                radii=cfg.get('radii'),
                alpha=cfg.get('alpha', 6.0),
                iterations=cfg.get('iterations', 1),
                clip=cfg.get('clip'),
                radius=cfg.get('radius'),
                preserve_background=cfg.get('preserve_background', True),
                max_gain=cfg.get('max_gain', 3.0)
            )
        return result
    if key == "p99":
        spreads = [max(1e-6, np.percentile(data[ci], 99.0)) for ci in range(n_ch)]
    else:
        # Default: inter-quartile range.
        spreads = []
        for ci in range(n_ch):
            q1, q3 = np.percentile(data[ci], [25.0, 75.0])
            spreads.append(max(1e-6, q3 - q1))
    target = float(np.median(spreads))
    for ci, spread in enumerate(spreads):
        data[ci] *= target / spread
    return data

# ---------- robuste Statistik ----------
def _mad(x):
    """Median absolute deviation of ``x`` times 1.4826 (a robust sigma estimate)."""
    center = np.median(x)
    deviations = np.abs(x - center)
    return 1.4826 * np.median(deviations)

# ---------- Huber-IRLS (robust) für 2 Donoren ----------
def huber_ridge_2d(C2, C3, Y, lam=1e-3, delta=1.5, iters=5):
    """Robust two-donor regression ``Y ~ w0 * C2 + w1 * C3`` without intercept.

    Starts from a ridge solution (penalty ``lam``) and refines it with
    ``max(1, iters)`` rounds of iteratively reweighted least squares using
    Huber weights. A residual larger than ``t = delta * (MAD-sigma + 1e-6)``
    gets weight ``t / |r|``.

    Returns:
        Array of the two weights. Negative weights are allowed on purpose:
        they are not clamped to zero when the fit prefers them.
    """
    design = np.stack([C2.ravel(), C3.ravel()], 1).astype(np.float32)
    target = Y.ravel().astype(np.float32)
    ridge = lam * np.eye(2, dtype=np.float32)
    weights = np.linalg.lstsq(design.T @ design + ridge, design.T @ target, rcond=None)[0]
    for _ in range(max(1, iters)):
        resid = target - design @ weights
        thresh = delta * (_mad(resid) + 1e-6)
        abs_resid = np.abs(resid)
        huber_w = np.ones_like(abs_resid, dtype=np.float32)
        large = abs_resid > thresh
        huber_w[large] = (thresh / abs_resid[large]).astype(np.float32)
        lhs = design.T @ (design * huber_w[:, None]) + ridge
        rhs = design.T @ (target * huber_w)
        weights = np.linalg.lstsq(lhs, rhs, rcond=None)[0]
    return weights

# ---------- Lokale Statistiken ----------
def local_mean_std(img, win=15):
    """Box-filtered local mean and standard deviation of ``img``.

    ``win`` is the square window edge (at least 3 px). Borders use reflect
    padding, and the returned std includes a 1e-6 variance floor.
    """
    size = int(max(3, win))
    mean = ndi.uniform_filter(img, size=size, mode='reflect')
    mean_sq = ndi.uniform_filter(img * img, size=size, mode='reflect')
    variance = np.maximum(0.0, mean_sq - mean * mean)
    return mean, np.sqrt(variance + 1e-6)

def masked_local_mean_std(img, mask, win=15):
    """Local mean/std of ``img`` computed only from pixels outside ``mask``.

    A ``win`` x ``win`` box (at least 3 px, reflect borders) sums the unmasked
    pixels and divides by their count (floored at 1), so masked pixels do not
    bias the estimate. The std includes a 1e-6 variance floor.

    Args:
        img: 2-D image.
        mask: Pixels to exclude. Converted to bool first, so integer 0/1
            masks work as well as boolean ones.
        win: Window edge in pixels.

    Returns:
        ``(mu, sd)`` arrays with the shape of ``img``.
    """
    win = int(max(3, win))
    kernel = np.ones((win, win), np.float32)
    # ``~`` on an integer mask is a bitwise NOT (~1 == -2, ~uint8(0) == 255),
    # so coerce to bool before inverting.
    keep = (~np.asarray(mask, dtype=bool)).astype(np.float32)
    kept_img = img * keep
    sum_ = ndi.convolve(kept_img, kernel, mode='reflect')
    cnt = ndi.convolve(keep, kernel, mode='reflect')
    mu = sum_ / np.maximum(1.0, cnt)
    sum2 = ndi.convolve(kept_img ** 2, kernel, mode='reflect')
    var = np.maximum(0.0, sum2 / np.maximum(1.0, cnt) - mu * mu)
    return mu, np.sqrt(var + 1e-6)

def ring_mean_std(mask, img, band=5, win=15):
    """Constant background level and spread estimated from a ring around ``mask``.

    The ring is the ``band``-px disk dilation of ``mask`` minus the mask
    itself. Its median and MAD-based sigma (1.4826 * MAD, plus 1e-6) are
    broadcast to full-size float32 arrays. If the ring is empty (e.g. empty
    or full mask), this falls back to ``local_mean_std(img, win)``.

    Args:
        mask: Region of interest; converted to bool so integer masks work.
        img: 2-D image with the same shape as ``mask``.
        band: Ring width in pixels (at least 1).
        win: Window size for the fallback (at least 3).

    Returns:
        ``(mu, sd)`` arrays with the shape of ``img``.
    """
    band = int(max(1, band)); win = int(max(3, win))
    # Bitwise ``~`` on an integer mask would not invert it; use bool.
    mask = np.asarray(mask, dtype=bool)
    ring = morph.binary_dilation(mask, footprint=disk(band)) & (~mask)
    if not ring.any():
        return local_mean_std(img, win=win)
    vals = img[ring]
    mu_ring = np.median(vals)
    sd_ring = 1.4826 * np.median(np.abs(vals - mu_ring))
    mu = np.full_like(img, float(mu_ring), dtype=np.float32)
    sd = np.full_like(img, float(sd_ring) + 1e-6, dtype=np.float32)
    return mu, sd

# ---------- GenFill(exp): multiskaliger Grain um Anchor-Statistik ----------
def genfill_exp_fill(tgt, mask, mu_loc, sd_loc, anchor_bias_sd=0.20, feather_px=6,
                     beta=0.5, octaves=4, base_sigma=1.2, seed=123):
    """Fill masked pixels with synthetic multi-scale grain around a local anchor.

    The anchor is ``max(0, mu_loc - anchor_bias_sd * sd_loc)``, capped at
    ``tgt``. Noise is a sum of ``octaves`` Gaussian-blurred white-noise
    layers; each layer is normalised to unit std and weighted ``2**-k``, and
    the blur sigma grows by 1.7x per layer. The summed noise is renormalised
    and clipped to +/-2.5. Grain = anchor + ``beta * sd_loc * noise``, also
    capped at ``tgt``.

    With ``feather_px > 0``, masked pixels blend from ``tgt`` at the mask edge
    to pure grain ``feather_px`` pixels inside; otherwise they are replaced
    outright. The RNG is seeded with ``seed``, so results are reproducible.
    Returns a modified copy of ``tgt``.
    """
    generator = np.random.default_rng(int(seed))
    rows, cols = tgt.shape
    anchor = np.minimum(np.maximum(0.0, mu_loc - float(anchor_bias_sd) * sd_loc), tgt)

    noise = np.zeros_like(tgt, dtype=np.float32)
    sigma = float(base_sigma)
    for octave in range(int(max(1, octaves))):
        white = generator.standard_normal(size=(rows, cols)).astype(np.float32)
        blurred = ndi.gaussian_filter(white, sigma=sigma, mode='reflect')
        noise += (blurred / (np.std(blurred) + 1e-6)) / (2.0 ** octave)
        sigma *= 1.7
    noise = noise / (np.std(noise) + 1e-6)

    grain = np.minimum(anchor + float(beta) * sd_loc * np.clip(noise, -2.5, 2.5), tgt)
    filled = tgt.copy()
    if int(feather_px) <= 0:
        filled[mask] = grain[mask]
        return filled
    weight = np.clip(ndi.distance_transform_edt(mask) / float(feather_px), 0.0, 1.0)
    blend = (1.0 - weight) * filled + weight * grain
    filled[mask] = blend[mask]
    return filled



def gaussian_background_estimate(tgt, mask, sigma=0.0):
    """Normalised-convolution estimate of the background underneath ``mask``.

    Unmasked pixels are Gaussian-smoothed and divided by the smoothed
    unmasked indicator, so each pixel receives a distance-weighted average of
    its unmasked surroundings. Details:

    * ``sigma`` of ``None`` or ``<= 0`` selects 1% of the image diagonal
      (at least 1 px);
    * where the smoothed indicator is ``<= 1e-3``, the float32 input value
      is kept;
    * ``tgt`` itself is returned when the mask is empty or covers every pixel.
    """
    selected = mask.astype(bool, copy=False)
    if not selected.any():
        return tgt
    weights = (~selected).astype(np.float32)
    if weights.sum() == 0:
        return tgt
    values = tgt.astype(np.float32, copy=False)
    if sigma is None or sigma <= 0:
        sigma = max(1.0, 0.01 * np.hypot(*tgt.shape))
    smoothed = ndi.gaussian_filter(values * weights, sigma=sigma, mode='reflect')
    coverage = ndi.gaussian_filter(weights, sigma=sigma, mode='reflect')
    return np.where(coverage > 1e-3, smoothed / (coverage + 1e-6), values)

def blended_fill(tgt, fill, mask, feather_px=0, softness=1.0):
    """Replace masked pixels of ``tgt`` with ``fill``, optionally feathered.

    Without feathering (``feather_px`` is None or <= 0), masked pixels are
    copied from ``fill``. Otherwise the blend weight is the distance to the
    mask edge divided by ``feather_px``, clipped to [0, 1] and raised to
    ``1 / softness`` when ``softness > 0``.

    Returns ``tgt`` itself if the mask is empty, otherwise a modified copy.
    """
    inside = mask.astype(bool, copy=False)
    if not inside.any():
        return tgt
    result = tgt.copy()
    if feather_px is None or feather_px <= 0:
        result[inside] = fill[inside]
        return result
    weight = np.clip(ndi.distance_transform_edt(inside) / float(feather_px), 0.0, 1.0)
    if softness and softness > 0:
        weight = np.power(weight, 1.0 / float(softness))
    w_in = weight[inside]
    result[inside] = w_in * fill[inside] + (1.0 - w_in) * tgt[inside]
    return result

# ---------- Epoxy_CyNif-AF auf ROI ----------
def Epoxy_CyNif_af_roi(
    stack, target, roi_xywh,
    ch_mode="p99",
    donor_sigma=1.6,
    kmad=5.5,
    mask_erode_px=0,
    mask_close_px=0,
    mask_expand_px=0,
    huber_delta=2.5,
    ridge_lambda=1e-4,
    cap_kappa=1.8,
    cap_win=15,
    closing_r=2,
    median_k=3,
    viz="Auto",
    pred_gain=1.35,
    cap_mode="anchor_strict",      # "local"|"ring"|"hybrid"|"anchor"|"anchor_strict"
    ring_band_px=5,
    donor_pow=1.10,
    anchor_bias_sd=0.05,
    feather_px=0,
    genfill_on=False,
    gen_beta=0.5,
    gen_octaves=4,
    gen_base_sigma=1.2,
    gen_seed=123
):
    """Remove donor-predicted autofluorescence (AF) from one target channel in a ROI.

    Steps:
      1. Harmonise all channels with ``channel_harmonize(stack, ch_mode)``.
         Crop the ROI ``(x, y, w, h)`` from the target channel and from the
         two donor channels, which are hard-coded as indices 1 and 2. All
         channel indices are clipped to the valid range.
      2. Optionally smooth the donors (``donor_sigma``) and raise them to
         ``donor_pow``. Threshold their sum ``dsum`` at
         ``median + kmad * MAD`` to get the AF mask. The mask can then be
         eroded, closed and dilated (``mask_*_px``).
      3. Fit ``target ~ w_c2 * c2 + w_c3 * c3`` on masked pixels with
         Huber-weighted ridge regression, then subtract
         ``pred_gain * prediction`` (``raw_sub``).
      4. Rebuild masked pixels according to ``cap_mode``:
         * ``"anchor"`` and ``"anchor_strict"`` (treated identically here):
           use GenFill grain when ``genfill_on``. Otherwise, with feathering,
           blend ``raw_sub`` towards the local-background anchor. Without
           feathering, set the mask dilated by 2 px to 0.
         * ``"ring"`` / ``"hybrid"`` / ``"local"`` / any other value: set
           masked pixels to ``max(raw_sub, cap)``.
      5. Apply grey closing (``closing_r`` x ``closing_r``) and a median
         filter (``median_k``) only outside the mask dilated by 2 px, so the
         cut edge stays sharp. If no mask was found, filter everywhere.

    Args:
        stack: Multi-channel stack; ``C, H, W = st.shape`` requires 3-D
            (channels first).
        target: Index of the channel to clean.
        roi_xywh: ``(x, y, w, h)`` ROI in pixels; not clipped to the image.
        Remaining keyword arguments tune the steps above and are echoed in
        ``meta``.

    Returns:
        ``dict(panel=..., meta=...)``:

        * ``panel`` holds display images scaled to [0, 1]: ``raw``, ``c2``,
          ``c3``, ``dsum``, ``mask``, ``pred``, ``clean``.
        * ``meta`` records the parameters, the fitted donor weights and the
          masked area percentage.
    """
    st = channel_harmonize(stack, ch_mode)
    C,H,W = st.shape
    x,y,w,h = roi_xywh; x1,y1 = x+w, y+h
    # Crop channel c (index clipped to [0, C-1]) to the ROI as a float32 copy.
    def R(c): return st[int(np.clip(c,0,C-1)), y:y1, x:x1].astype(np.float32).copy()

    # Donor channels are fixed to indices 1 and 2.
    tgt = R(target); c2 = R(1); c3 = R(2)

    if donor_sigma>0: c2s=ndi.gaussian_filter(c2, donor_sigma); c3s=ndi.gaussian_filter(c3, donor_sigma)
    else: c2s, c3s = c2, c3
    if abs(donor_pow-1.0)>1e-3:
        c2s=np.power(np.maximum(0.0,c2s), donor_pow); c3s=np.power(np.maximum(0.0,c3s), donor_pow)
    dsum = c2s + c3s

    # AF mask: robust outlier threshold on the combined donor signal.
    med = np.median(dsum); mad = _mad(dsum); thr = med + kmad*mad
    mask = (dsum > thr)
    if mask_erode_px>0:  mask = morph.binary_erosion(mask, footprint=disk(int(mask_erode_px)))
    if mask_close_px>0:  mask = morph.binary_closing(mask, footprint=disk(int(mask_close_px)))
    if mask_expand_px>0: mask = morph.binary_dilation(mask, footprint=disk(int(mask_expand_px)))

    if np.any(mask):
        # Fit donor weights on masked pixels only, then predict over the whole ROI.
        wts = huber_ridge_2d(c2s[mask], c3s[mask], tgt[mask], lam=ridge_lambda, delta=huber_delta, iters=5)
        pred = wts[0]*c2s + wts[1]*c3s
        raw_sub = tgt - pred_gain*pred

        mu_loc, sd_loc   = masked_local_mean_std(tgt, mask, win=int(cap_win))
        mu_ring, sd_ring = ring_mean_std(mask, tgt, band=int(ring_band_px), win=int(cap_win))
        cap_local = mu_loc + cap_kappa*sd_loc
        cap_ring  = mu_ring + cap_kappa*sd_ring

        cm = str(cap_mode).lower()
        if cm in ("anchor","anchor_strict"):
            bias = float(anchor_bias_sd) if anchor_bias_sd is not None else 0.0
            # Local background level below the mask, never brighter than the target.
            anchor = np.maximum(0.0, mu_loc - bias*sd_loc)
            anchor = np.minimum(anchor, tgt)
            if genfill_on:
                clean = genfill_exp_fill(
                    tgt, mask, mu_loc, sd_loc,
                    anchor_bias_sd=float(anchor_bias_sd),
                    feather_px=int(feather_px),
                    beta=float(gen_beta),
                    octaves=int(gen_octaves),
                    base_sigma=float(gen_base_sigma),
                    seed=int(gen_seed)
                )
            else:
                base = np.maximum(raw_sub, 0.0)
                clean = tgt.copy()
                if int(feather_px) > 0:
                    # Blend from the subtracted signal at the mask edge to the
                    # anchor deeper inside, never exceeding the original value.
                    di = ndi.distance_transform_edt(mask)
                    w_blend = np.clip(di/float(feather_px), 0.0, 1.0)
                    mix = (1.0 - w_blend)*base + w_blend*anchor
                    clean[mask] = np.minimum(mix[mask], tgt[mask])
                else:
                    # HARD CUT: dilate the mask by 2 px and set it to 0.0
                    # (`anchor` is not used on this path).
                    mask_expanded = morph.binary_dilation(mask, footprint=disk(2))
                    clean[mask_expanded] = 0.0
        else:
            if cm == "ring":
                cap = cap_ring
            elif cm == "hybrid":
                cap = np.minimum(cap_local, cap_ring)
            elif cm == "local":
                cap = cap_local
            else:
                cap = np.minimum(cap_local, tgt)
            base = np.maximum(raw_sub, 0.0)
            clean = tgt.copy()
            # NOTE(review): np.maximum makes `cap` a lower bound for masked
            # pixels, not an upper bound — confirm this is intended.
            clean[mask] = np.maximum(base[mask], cap[mask])
            clean = np.clip(clean, 0, None)
    else:
        # No AF pixels detected: nothing to subtract.
        wts=np.array([0.0,0.0],np.float32); pred=np.zeros_like(tgt); clean=tgt.copy()

    # Post-processing: apply closing/median ONLY outside the dilated AF mask,
    # so that the hard edge is not smoothed again.
    if np.any(mask):
        mask_expanded = morph.binary_dilation(mask, footprint=disk(2))
        temp_clean = clean.copy()

        if closing_r>0:
            temp_clean = ndi.grey_closing(temp_clean, size=(int(closing_r), int(closing_r)))
        if median_k>1:
            temp_clean = ndi.median_filter(temp_clean, size=int(median_k))

        # Take the filtered values only outside the AF zone.
        clean[~mask_expanded] = temp_clean[~mask_expanded]
    else:
        # No AF found: apply the filters everywhere.
        if closing_r>0: clean = ndi.grey_closing(clean, size=(int(closing_r), int(closing_r)))
        if median_k>1:  clean = ndi.median_filter(clean, size=int(median_k))

    # Display mapping: per-image percentile stretch ("Auto"), or one shared
    # 1-99.5 percentile window of the raw target for all panels.
    if viz=="Auto": M=lambda z:p_stretch(z)
    else:
        lo,hi=np.percentile(tgt,[1.0,99.5]); hi=max(hi,lo+1e-6); M=lambda z:np.clip((z-lo)/(hi-lo),0,1)

    panel = dict(raw=M(tgt), c2=M(c2), c3=M(c3), dsum=M(dsum), mask=mask.astype(np.float32), pred=M(pred), clean=M(clean))
    # NOTE(review): float(anchor_bias_sd) raises if anchor_bias_sd is None,
    # although the anchor branch above tolerates None.
    meta  = dict(
        tgt_idx=int(target), ch_mode=str(ch_mode), viz=str(viz),
        donor_sigma=float(donor_sigma), kmad=float(kmad),
        mask_erode_px=int(mask_erode_px), mask_close_px=int(mask_close_px), mask_expand_px=int(mask_expand_px),
        huber_delta=float(huber_delta), ridge_lambda=float(ridge_lambda),
        cap_kappa=float(cap_kappa), cap_win=int(cap_win),
        closing_r=int(closing_r), median_k=int(median_k),
        pred_gain=float(pred_gain), cap_mode=str(cap_mode),
        ring_band_px=int(ring_band_px), donor_pow=float(donor_pow),
        anchor_bias_sd=float(anchor_bias_sd), feather_px=int(feather_px),
        genfill_on=bool(genfill_on), gen_beta=float(gen_beta), gen_octaves=int(gen_octaves),
        gen_base_sigma=float(gen_base_sigma), gen_seed=int(gen_seed),
        weights={'w_c2':float(wts[0]), 'w_c3':float(wts[1])},
        mask_pct=100.0*float(mask.sum())/max(1,mask.size)
    )
    return dict(panel=panel, meta=meta)

# ---------- ZIP-Speicher (robust: kürzer + Fallback in %TEMP%) ----------
def _meta_fingerprint(meta: dict) -> str:
    sig = {k: meta.get(k) for k in [
        'tgt_idx','kmad','mask_erode_px','mask_close_px','mask_expand_px','donor_sigma','huber_delta','ridge_lambda',
        'cap_kappa','cap_win','pred_gain','cap_mode','ring_band_px','donor_pow','anchor_bias_sd','feather_px',
        'genfill_on','gen_beta','gen_octaves','gen_base_sigma'
    ]}
    s = json.dumps(sig, sort_keys=True).encode('utf-8')
    return hashlib.md5(s).hexdigest()[:8]

def save_preview_zip(panel, meta, img_path, roi_xywh, tag="Epoxy_CyNif_AF"):
    """Save a preview panel as PNG images plus ``meta.json`` in a ZIP archive.

    The archive contains ``raw/c2/c3/dsum/mask/pred/clean.png`` (written with
    ``write_png``) and ``meta.json``. It is first written below
    ``_safe_root(<image dir>/MP_Previews)/<base>/``. If anything fails there, a
    shorter-named archive is written below ``<system temp>/MP_Previews/`` instead.

    Args:
        panel: mapping with the keys raw, c2, c3, dsum, mask, pred, clean.
        meta: JSON-serialisable metadata; must contain ``'tgt_idx'``.
        img_path: source image path, used for naming and the primary location.
        roi_xywh: ROI tuple; its first three entries (x, y, size) go into the name.
        tag: label embedded (truncated to 14 chars) in the primary file name.

    Returns:
        Path of the written ZIP file.
    """
    panel_keys = ("raw", "c2", "c3", "dsum", "mask", "pred", "clean")

    base_long = f"{_basename_no_ome(img_path)}_x{roi_xywh[0]}_y{roi_xywh[1]}_s{roi_xywh[2]}"
    # Keep names filesystem-safe and short (long paths are a common failure cause).
    base = re.sub(r'[^A-Za-z0-9_\-]', '_', base_long)[:48]
    finger = _meta_fingerprint(meta)
    ts = datetime.datetime.now().strftime("%H%M%S")
    fname = f"{base}_C{meta['tgt_idx']}_{str(tag)[:14]}_{finger}_{ts}.zip"

    def _write_zip(to_path):
        """Write the archive to *to_path* and return it.

        Staging files live in a private temp dir inside the archive's own
        directory. Previously they always went to the primary directory, even
        for the fallback target. Concurrent saves no longer collide on names.
        """
        target_dir = os.path.dirname(to_path)
        _ensure_dir(target_dir)
        with tempfile.TemporaryDirectory(prefix=".mp_stage_", dir=target_dir) as stage:
            members = []
            for key in panel_keys:
                fp = os.path.join(stage, f"{key}.png")
                write_png(fp, panel[key])
                members.append((key + ".png", fp))
            meta_fp = os.path.join(stage, "meta.json")
            with open(meta_fp, "w", encoding="utf-8") as f:
                json.dump(meta, f, indent=2)
            members.append(("meta.json", meta_fp))
            try:
                with zipfile.ZipFile(to_path, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
                    for arcname, fp in members:
                        zf.write(fp, arcname=arcname)
            except Exception:
                # Do not leave a truncated archive behind.
                try:
                    os.remove(to_path)
                except OSError:
                    pass
                raise
        return to_path

    # Deliberate best effort: any failure at the primary location (including
    # resolving/creating it) falls through to the temp-dir fallback below.
    try:
        root = _safe_root(os.path.join(os.path.dirname(img_path), "MP_Previews"))
        prev_dir = _ensure_dir(os.path.join(root, base))
        return _write_zip(os.path.join(prev_dir, fname))
    except Exception:
        pass

    temp_root = _ensure_dir(os.path.join(tempfile.gettempdir(), "MP_Previews"))
    prev_dir2 = _ensure_dir(os.path.join(temp_root, base[:24]))
    return _write_zip(os.path.join(prev_dir2, f"prev_{finger}_{ts}.zip"))

# ---------- UI ----------
def _build_ui():
    """Build and display the interactive Epoxy_CyNif-AF preview UI (ipywidgets).

    Layout:
      * path/load/overview row;
      * ROI and channel-harmonisation row;
      * ACE sliders;
      * a free-text "Settings" line (key=value tokens);
      * AF parameter rows, patch/anchor row and GenFill row;
      * RUN/ZIP buttons;
      * output areas for the preview, the clickable overview and an info log.

    Side effects: reads and rebinds the module globals STACK, STACK_SRC and
    IMG_PATH (through the nested ``_ensure_loaded``). It also mutates ACE_PARAMS
    in place whenever an ACE widget changes.
    """
    global STACK, STACK_SRC, IMG_PATH, ACE_PARAMS

    w_path = widgets.Text(value=IMG_PATH, description="OME-TIF:", layout=widgets.Layout(width='100%'))
    w_load = widgets.Button(description="Load", button_style='info', tooltip="Bildstapel laden")
    w_over = widgets.Button(description="Show Overview", button_style='info', tooltip="Ãœbersicht & ROI (klicken)")

    w_tgt  = widgets.BoundedIntText(value=3, min=0, max=63, description="Target C (Index)")
    w_x    = widgets.BoundedIntText(value=1696, min=0, max=50000, description="x")
    w_y    = widgets.BoundedIntText(value=4528, min=0, max=50000, description="y")
    w_s    = widgets.BoundedIntText(value=384,  min=64, max=4096, description="size")

    w_ch   = widgets.Dropdown(options=["p99","IQR","ACE","Off"], value="ACE", description="CH (Harmonize)")
    w_viz  = widgets.Dropdown(options=["Auto","Fixed"], value="Auto", description="Viz")
    w_ace_alpha = widgets.FloatSlider(value=float(ACE_PARAMS.get('alpha', 6.0)), min=1.0, max=15.0, step=0.5, description="ACE Î±")
    w_ace_radius = widgets.IntSlider(value=int(ACE_PARAMS.get('radius', 0) or 0), min=0, max=512, step=16, description="ACE Radius", tooltip="0 = global")
    w_ace_gain = widgets.FloatSlider(value=float(ACE_PARAMS.get('max_gain', 3.0)), min=1.0, max=5.0, step=0.1, description="ACE gain cap")
    # ACE post-processing toggle. It is read by _update_ace_params, observed below
    # and placed in row_roi, but it was not created anywhere in this function
    # (NameError when the UI is built).
    w_ace_post = widgets.Checkbox(value=bool(ACE_PARAMS.get('post_enabled', False)), description="ACE post")

    # Base parameters
    w_sig  = widgets.FloatSlider(value=1.6, min=0.0, max=3.0, step=0.1, description="Donor blur (Ïƒ)")
    w_kmad = widgets.FloatSlider(value=5.5, min=2.0, max=12.0, step=0.5, description="AF detect (kÂ·MAD)")
    w_er   = widgets.IntSlider(value=0, min=0, max=5, step=1, description="Mask shrink (erode px)")
    w_cloM = widgets.IntSlider(value=0, min=0, max=7, step=1, description="Mask close (px)")
    w_dil  = widgets.IntSlider(value=1, min=0, max=12, step=1, description="Mask grow (dilate px)")

    w_hub  = widgets.FloatSlider(value=2.5, min=0.5, max=4.0, step=0.1, description="Robust fit (Huber Î´)")
    w_lam  = widgets.FloatLogSlider(value=1e-4, base=10, min=-6, max=-1, step=0.25, description="Weight L2 (Î»)")
    w_ck   = widgets.FloatSlider(value=1.8, min=0.0, max=6.0, step=0.1, description="Clamp softness (Îº)")
    w_cw   = widgets.IntSlider(value=15, min=5, max=31, step=2, description="Clamp window (px)")
    w_clos = widgets.IntSlider(value=2, min=0, max=9, step=1, description="Post smooth (closing r)")
    w_med  = widgets.IntSlider(value=3, min=1, max=7, step=2, description="Post noise (median k)")

    # Patches + anchor
    w_gain = widgets.FloatSlider(value=1.35, min=1.0, max=2.2, step=0.05, description="Donor gain (Î³)")
    w_cmode= widgets.Dropdown(options=["local","hybrid","ring","anchor","anchor_strict"], value="anchor", description="Fill mode")
    w_ring = widgets.IntSlider(value=5, min=1, max=15, step=1, description="Ring width (px)")
    w_pow  = widgets.FloatSlider(value=1.10, min=0.8, max=1.5, step=0.02, description="Donor exponent (p)")
    w_bias = widgets.FloatSlider(value=0.15, min=0.00, max=0.50, step=0.01, description="Anchor bias (Ã—Ïƒ)")
    w_fth  = widgets.IntSlider(value=3, min=0, max=12, step=1, description="Feather (px)")

    # GenFill (experimental)
    w_gf_on   = widgets.Checkbox(value=False, description="GenFill (exp) on")
    w_gf_beta = widgets.FloatSlider(value=0.50, min=0.0, max=1.0, step=0.05, description="GenFill strength (Î²)")
    w_gf_oct  = widgets.IntSlider(value=4, min=1, max=6, step=1, description="GenFill octaves")
    w_gf_sig  = widgets.FloatSlider(value=1.2, min=0.6, max=2.5, step=0.1, description="GenFill base Ïƒ")
    w_gf_seed = widgets.IntText(value=123, description="GenFill seed")

    def _update_ace_params(change=None):
        """Copy the ACE widget values into the shared ACE_PARAMS dict (radius 0 -> None)."""
        ACE_PARAMS["alpha"] = float(w_ace_alpha.value)
        ACE_PARAMS["radius"] = int(w_ace_radius.value) if int(w_ace_radius.value) > 0 else None
        ACE_PARAMS["max_gain"] = float(w_ace_gain.value)
        ACE_PARAMS["post_enabled"] = bool(w_ace_post.value)

    # --- Settings line (key=value tokens) ---
    w_txt = widgets.Text(
        value="",
        placeholder="kmad=4.5 grow=8 close=2 bias=0.35 feather=3 mode=anchor_strict gen=on gbeta=0.5 goct=4 gsig=1.2 gain=1.45",
        description="Settings:"
    )
    btn_apply = widgets.Button(description="Apply settings", button_style='warning', tooltip="Zeile parsen & Regler setzen")

    btn_run   = widgets.Button(description="RUN Preview", button_style='success')
    btn_zip   = widgets.Button(description="Save Preview ZIP", button_style='warning')

    out_over  = widgets.Output()
    out_panel = widgets.Output()
    out_info  = widgets.Output(layout=widgets.Layout(height='360px'))

    def _ensure_loaded():
        """(Re)load STACK when it is missing or the path field changed; adjust widget bounds."""
        global STACK, STACK_SRC, IMG_PATH
        IMG_PATH = w_path.value
        if (STACK is None) or (STACK_SRC != IMG_PATH):
            with out_info: print("â€¢ Lade Stackâ€¦", IMG_PATH)
            st = _load_stack(IMG_PATH)
            STACK = st; STACK_SRC = IMG_PATH
            C,H,W = st.shape
            w_tgt.max = max(0, C-1)
            w_x.max   = max(0, W-64)
            w_y.max   = max(0, H-64)
            with out_info:
                print(f"âœ” Geladen. Shape (C,H,W)={STACK.shape}")
                print("Spickzettel fÃ¼r Settings-Zeile:")
                print("  kmad=â€¦ grow=â€¦ close=â€¦ erode=â€¦ bias=â€¦ feather=â€¦ mode=anchor|anchor_strict|local|hybrid|ring")
                print("  gen=on|off gbeta=â€¦ goct=â€¦ gsig=â€¦ gain=â€¦ ring=â€¦ cwin=â€¦ ck=â€¦ sig=â€¦ pow=â€¦")
                print("  ch=p99|IQR|ACE|Off viz=Auto|Fixed tgt=3 x=1696 y=4528 s=384")

    def on_over(ev=None):
        """Render a downsampled overview of the target channel with the ROI rectangle."""
        out_over.clear_output(True)
        try:
            _ensure_loaded()
            st = channel_harmonize(STACK, w_ch.value)
            img = st[int(w_tgt.value)]
            H,W = img.shape
            # Downsample so the longer side is at most ~1100 px.
            ds = max(1, int(np.ceil(max(H,W)/1100)))
            view = img[::ds, ::ds]
            fig, ax = plt.subplots(figsize=(22,7))
            ax.imshow(p_stretch(view), cmap='gray'); ax.axis('off')
            ax.set_title("Overview â€” Klick setzt ROI")
            s = int(w_s.value); x = int(w_x.value)//ds; y = int(w_y.value)//ds
            rect = Rectangle((x,y), s//ds, s//ds, fill=False, edgecolor='lime', linewidth=2)
            ax.add_patch(rect)
            def onclick(event):
                # Center the ROI on the clicked point, clamped to the image bounds.
                if not event.inaxes: return
                cx, cy = int(event.xdata)*ds, int(event.ydata)*ds
                S = int(w_s.value)
                nx = int(np.clip(cx - S//2, 0, max(0, W - S)))
                ny = int(np.clip(cy - S//2, 0, max(0, H - S)))
                w_x.value, w_y.value = nx, ny
                rect.set_xy((nx//ds, ny//ds)); fig.canvas.draw_idle()
            fig.canvas.mpl_connect('button_press_event', onclick)
            with out_over: display(fig); plt.close(fig)
        except Exception as e:
            with out_info: print("âœ– Overview-Fehler:", e); traceback.print_exc()

    # ---- Parse & apply the settings line ----
    def _parse_bool(v):
        return str(v).strip().lower() in ("1","true","on","yes","y")

    # alias -> (widget, value type). Built once instead of on every click.
    _setting_aliases = {
        # ROI / target
        "tgt": (w_tgt, "int"), "target": (w_tgt, "int"),
        "x": (w_x, "int"), "y": (w_y, "int"), "s": (w_s, "int"), "size": (w_s, "int"),
        # Channel harmonisation / visualisation
        "ch": (w_ch, "str"), "harm": (w_ch, "str"),
        "viz": (w_viz, "str"),
        # Detection / mask
        "kmad": (w_kmad, "float"),
        "erode": (w_er, "int"), "shrink": (w_er, "int"),
        "close": (w_cloM, "int"),
        "grow": (w_dil, "int"), "dilate": (w_dil, "int"),
        # Donor
        "sig": (w_sig, "float"), "sigma": (w_sig, "float"), "donor_sigma": (w_sig, "float"),
        "pow": (w_pow, "float"), "donor_pow": (w_pow, "float"),
        # Robust fit / regularisation
        "huber": (w_hub, "float"), "delta": (w_hub, "float"),
        "lam": (w_lam, "float"), "lambda": (w_lam, "float"), "ridge": (w_lam, "float"),
        # Clamp / post filters
        "ck": (w_ck, "float"), "kappa": (w_ck, "float"), "cap_kappa": (w_ck, "float"),
        "cwin": (w_cw, "int"), "win": (w_cw, "int"), "cap_win": (w_cw, "int"),
        "closing": (w_clos, "int"), "post_close": (w_clos, "int"), "closing_r": (w_clos, "int"),
        "median": (w_med, "int"), "med": (w_med, "int"), "median_k": (w_med, "int"),
        # Fill / anchor
        "gain": (w_gain, "float"), "pred_gain": (w_gain, "float"), "gamma": (w_gain, "float"),
        "mode": (w_cmode, "str"), "cap_mode": (w_cmode, "str"),
        "ring": (w_ring, "int"), "ring_band": (w_ring, "int"),
        "bias": (w_bias, "float"), "anchor_bias_sd": (w_bias, "float"),
        "feather": (w_fth, "int"), "feather_px": (w_fth, "int"),
        # GenFill
        "gen": (w_gf_on, "bool"), "genfill": (w_gf_on, "bool"),
        "gbeta": (w_gf_beta, "float"), "gen_beta": (w_gf_beta, "float"),
        "goct": (w_gf_oct, "int"), "gen_oct": (w_gf_oct, "int"),
        "gsig": (w_gf_sig, "float"), "gen_sigma": (w_gf_sig, "float"), "gen_base_sigma": (w_gf_sig, "float"),
        "gseed": (w_gf_seed, "int"), "seed": (w_gf_seed, "int"),
    }
    _converters = {
        "int": lambda v: int(float(v)),   # accepts "3" as well as "3.0"
        "float": float,
        "bool": _parse_bool,
    }

    def _match_option(widget, value):
        """Return the dropdown option equal to *value* ignoring case, or None."""
        wanted = value.strip().lower()
        for opt in widget.options:
            if str(opt).lower() == wanted:
                return opt
        return None

    def on_apply(ev=None):
        """Apply whitespace-separated key=value tokens to the matching widgets.

        Unknown keys, malformed tokens and values refused by a widget are
        listed in the info log instead of being silently dropped.
        """
        line = w_txt.value.strip()
        if not line:
            with out_info: print("â„¹ Keine Settings-Zeile eingegeben.")
            return
        changes, rejected = [], []
        for tok in line.split():
            key, sep, val = tok.partition('=')
            key = key.strip().lower(); val = val.strip()
            if not sep or key not in _setting_aliases:
                rejected.append(tok)
                continue
            widget, typ = _setting_aliases[key]
            try:
                if typ == "str":
                    # Case-insensitive match against the dropdown's own options
                    # (e.g. "ch=ace", "ch=off", "viz=fixed", "mode=ANCHOR").
                    new_value = _match_option(widget, val)
                    if new_value is None:
                        rejected.append(tok)
                        continue
                else:
                    new_value = _converters[typ](val)
                widget.value = new_value
            except Exception as exc:  # unparsable number or value refused by the widget
                rejected.append(f"{tok} ({exc})")
                continue
            # Report the widget's value afterwards (bounded widgets may clamp it).
            changes.append(f"{key}â†’{widget.value}")
        with out_info:
            if changes:
                print("âœ” Settings angewendet:", ", ".join(changes))
            else:
                print("â„¹ Keine gÃ¼ltigen SchlÃ¼ssel gefunden (siehe Spickzettel).")
            if rejected:
                print("  Ignoriert:", ", ".join(rejected))

    def _run_once():
        """Run Epoxy_CyNif_af_roi on the current ROI with all widget settings."""
        _ensure_loaded()
        x,y,s = int(w_x.value), int(w_y.value), int(w_s.value)
        return Epoxy_CyNif_af_roi(
            STACK, int(w_tgt.value), (x,y,s,s),
            ch_mode=w_ch.value, donor_sigma=float(w_sig.value),
            kmad=float(w_kmad.value), mask_erode_px=int(w_er.value),
            mask_close_px=int(w_cloM.value), mask_expand_px=int(w_dil.value),
            huber_delta=float(w_hub.value), ridge_lambda=float(w_lam.value),
            cap_kappa=float(w_ck.value), cap_win=int(w_cw.value),
            closing_r=int(w_clos.value), median_k=int(w_med.value),
            viz=w_viz.value,
            pred_gain=float(w_gain.value), cap_mode=w_cmode.value,
            ring_band_px=int(w_ring.value), donor_pow=float(w_pow.value),
            anchor_bias_sd=float(w_bias.value), feather_px=int(w_fth.value),
            genfill_on=bool(w_gf_on.value), gen_beta=float(w_gf_beta.value),
            gen_octaves=int(w_gf_oct.value), gen_base_sigma=float(w_gf_sig.value),
            gen_seed=int(w_gf_seed.value)
        )

    def on_run(ev=None):
        """Run the preview and show a 2x3 panel plus a short summary."""
        out_panel.clear_output(True); out_info.clear_output(True)
        try:
            res = _run_once()
            fig, axs = plt.subplots(2, 3, figsize=(22, 14))
            axs[0,0].imshow(res["panel"]["raw"],   cmap='gray'); axs[0,0].set_title("Zoom RAW");   axs[0,0].axis('off')
            axs[0,1].imshow(res["panel"]["dsum"],  cmap='gray'); axs[0,1].set_title("Donor sum"); axs[0,1].axis('off')
            axs[0,2].imshow(res["panel"]["mask"],  cmap='gray'); axs[0,2].set_title("Donor mask");axs[0,2].axis('off')
            axs[1,0].imshow(res["panel"]["c2"],    cmap='gray'); axs[1,0].set_title("C2");        axs[1,0].axis('off')
            axs[1,1].imshow(res["panel"]["c3"],    cmap='gray'); axs[1,1].set_title("C3");        axs[1,1].axis('off')
            axs[1,2].imshow(res["panel"]["clean"], cmap='gray'); axs[1,2].set_title("Zoom CLEAN");axs[1,2].axis('off')
            with out_panel: display(fig); plt.close(fig)
            m = res["meta"]
            with out_info:
                print("âœ” Preview ok.")
                print(f"  Maskenanteil: {m['mask_pct']:.2f}% | w_c2={m['weights']['w_c2']:.4f} | w_c3={m['weights']['w_c3']:.4f}")
                print(f"  Mask ops: shrink={m['mask_erode_px']} | close={m['mask_close_px']} | grow={m['mask_expand_px']}")
                if m['cap_mode'] in ('anchor','anchor_strict'):
                    print(f"  Anchor: biasÃ—Ïƒ={m['anchor_bias_sd']:.2f} | feather={m['feather_px']} px")
                if m.get('genfill_on', False):
                    print(f"  GenFill(exp): Î²={m['gen_beta']:.2f}, octaves={m['gen_octaves']}, base Ïƒ={m['gen_base_sigma']:.2f}, seed={m['gen_seed']}")
        except Exception as e:
            with out_info: print("âœ– Fehler in Preview:", e); traceback.print_exc()

    def on_zip(ev=None):
        """Re-run the preview and save it as a ZIP via save_preview_zip."""
        try:
            res = _run_once()
            x,y,s = int(w_x.value), int(w_y.value), int(w_s.value)
            zpath = save_preview_zip(res["panel"], res["meta"], IMG_PATH, (x,y,s), tag="Epoxy_CyNif_AF")
            with out_info:
                print("âœ” Preview-ZIP gespeichert:")
                print("  Ordner:", os.path.dirname(zpath))
                print("  Datei :", os.path.basename(zpath))
        except Exception as e:
            with out_info: print("âœ– Fehler beim Speichern:", e); traceback.print_exc()

    w_load.on_click(lambda ev: _ensure_loaded())
    w_over.on_click(on_over)
    btn_apply.on_click(on_apply)
    btn_run.on_click(on_run)
    btn_zip.on_click(on_zip)

    for _ctrl in (w_ace_alpha, w_ace_radius, w_ace_gain, w_ace_post):
        _ctrl.observe(_update_ace_params, names="value")
    _update_ace_params()  # sync ACE_PARAMS with the initial widget state

    # Layout
    info_workflow = widgets.HTML("""
    <div style=\"padding:6px 0;\">
      <b>Workflow:</b>
      <ol style=\"margin:4px 0 0 20px;\">
        <li>Histogramm-Normalisierung per <code>run_uniform_histogram_normalization_dynamic(...)</code> auf dem Roh-Stack ausfÃ¼hren.</li>
        <li>Den erzeugten <code>*_Histo_Norm.ome.tif</code>-Pfad in das Feld <b>OME-TIF</b> unten einfÃ¼gen (oder stehen lassen, falls bereits automatisch gesetzt).</li>
        <li>AF-Removal und optional ACE-Optimierung Ã¼ber die Regler starten.</li>
      </ol>
    </div>
    """)
    row_path = widgets.HBox([w_path, w_load, w_over])
    row_roi  = widgets.HBox([w_tgt, w_x, w_y, w_s, w_ch, w_viz, w_ace_post])

    row_ace  = widgets.HBox([w_ace_alpha, w_ace_radius, w_ace_gain])
    row_parA = widgets.HBox([w_sig, w_kmad, w_er, w_cloM, w_dil])
    row_parB = widgets.HBox([w_hub, w_lam, w_ck, w_cw, w_clos, w_med])
    row_patch= widgets.HBox([w_gain, w_cmode, w_ring, w_pow, w_bias, w_fth])
    row_gf   = widgets.HBox([w_gf_on, w_gf_beta, w_gf_oct, w_gf_sig, w_gf_seed])
    row_txt  = widgets.HBox([w_txt, btn_apply])
    row_act  = widgets.HBox([btn_run, btn_zip])

    out_over_placeholder = widgets.HTML("<hr><b>Overview (klickbar)</b>")
    out_panel_placeholder= widgets.HTML("<hr><b>Preview</b>")

    ui = widgets.VBox([
        info_workflow,
        row_path, row_roi,
        widgets.HTML("<b>ACE Normalisierung</b>"),
        row_ace,
        widgets.HTML("<hr><b>Settings-Zeile</b> (Key=Value, Leerzeichen-getrennt)"),
        row_txt,
        widgets.HTML("<hr><b>Epoxy_CyNif-AF Parameter</b>"),
        row_parA, row_parB,
        widgets.HTML("<b>Patches + Background-Anchor</b>"),
        row_patch,
        widgets.HTML("<b>GenFill (exp) â€” optional</b>"),
        row_gf,
        row_act,
        out_panel_placeholder, out_panel,
        out_over_placeholder,  out_over,
        widgets.HTML("<hr><b>Info</b>"),
        out_info
    ])
    display(ui)

# -------- Start UI --------
# _build_ui()  # UI deaktiviert für Batch-Lauf






In [None]:
# ---------- Marker-Referenz & AF-Prozess ----------
# Namespace prefix map for OME-XML (2016-06 schema), used with ElementTree find/findall.
OME_NS = dict(ome='http://www.openmicroscopy.org/Schemas/OME/2016-06')

# MARKER_TABLE_PATH wird später dynamisch aus BASE_EXPORT geladen (siehe weiter unten)

def load_marker_reference(csv_path: Path):
    """Load the marker reference CSV and derive stack-position lookups.

    Only rows whose ``Include`` column equals ``TRUE`` are kept. The comparison
    is case-insensitive and ignores surrounding whitespace. Kept rows are sorted
    by the theoretical global channel index ``(cycle - 1) * 10 + channel_index``;
    a row's position in that order is its stack position. That index is used
    only for sorting. The channel index is read from the first existing column
    among ``'channel index'``, ``'No'`` and ``'index'``.

    Args:
        csv_path: path to the CSV file. A leading UTF-8 BOM (as written by
            Excel) is tolerated.

    Returns:
        ``(include_rows, stack_pos_to_cycle_ch, cycle_to_donors)``:

        * ``include_rows``: the kept CSV rows (dicts) in sorted order.
        * ``stack_pos_to_cycle_ch``: ``{stack_pos: (cycle, channel_index)}``.
        * ``cycle_to_donors``: ``{cycle: tuple_of_stack_positions}``, with the
          AF1 position first and the AF2 position second. Absent donors are
          omitted, and only cycles with at least one AF1/AF2 marker appear.

    Raises:
        KeyError: if none of the channel-index columns exist.
        ValueError: if ``cycle`` or the channel index is not an integer.
    """
    def _channel_index(row):
        # First matching column wins; mirrors the supported CSV layouts.
        for column in ('channel index', 'No', 'index'):
            if column in row:
                return int(row[column])
        raise KeyError(f"Keine 'channel index'/'No'/'index' Spalte gefunden in: {list(row.keys())}")

    def _global_channel_idx(row):
        return (int(row['cycle']) - 1) * 10 + _channel_index(row)

    # 'utf-8-sig' strips a leading BOM; with plain 'utf-8' the first header would
    # become '\ufeffInclude' and no row would be selected.
    with open(csv_path, newline='', encoding='utf-8-sig') as f:
        include_rows = [
            row for row in csv.DictReader(f)
            # `or ''` guards against None values that DictReader yields for short rows.
            if (row.get('Include') or '').strip().upper() == 'TRUE'
        ]

    include_rows.sort(key=_global_channel_idx)

    # Stack position (index after sorting) -> original (cycle, channel index).
    stack_pos_to_cycle_ch = {
        stack_pos: (int(row['cycle']), _channel_index(row))
        for stack_pos, row in enumerate(include_rows)
    }

    # Per cycle: [AF1 stack position, AF2 stack position].
    donor_slots = {}
    for stack_pos, row in enumerate(include_rows):
        marker = (row.get('Marker-Name') or '').strip().lower()
        if marker == 'af1':
            donor_slots.setdefault(int(row['cycle']), [None, None])[0] = stack_pos
        elif marker == 'af2':
            donor_slots.setdefault(int(row['cycle']), [None, None])[1] = stack_pos

    cycle_to_donors = {
        cycle: tuple(pos for pos in pair if pos is not None)
        for cycle, pair in donor_slots.items()
    }

    return include_rows, stack_pos_to_cycle_ch, cycle_to_donors


def is_base_marker(name: str) -> bool:
    """Return True if *name* denotes a base channel rather than a target marker.

    Matching is case-insensitive after stripping whitespace. Base channels are:
      * ``dapi`` (exact match);
      * any name starting with ``af``;
      * any name containing ``autofluor`` or ``blank``.
    Empty or None names are not base markers.

    NOTE(review): the ``af`` prefix rule also matches markers such as "AFP";
    confirm this is intended.
    """
    normalized = (name or '').strip().lower()
    if not normalized:
        return False
    return (
        normalized == 'dapi'
        or normalized.startswith('af')
        or any(fragment in normalized for fragment in ('autofluor', 'blank'))
    )


def prepare_donor(array: np.ndarray, sigma: float, donor_pow: float) -> np.ndarray:
    """Return a float32 copy of *array*, optionally blurred and power-transformed.

    Args:
        array: donor image; it is never modified.
        sigma: Gaussian blur sigma; values <= 0 skip the blur.
        donor_pow: exponent applied after clipping negatives to 0. It is skipped
            when within 1e-3 of 1.0.
    """
    blurred = array.astype(np.float32, copy=True)
    if sigma > 0:
        blurred = ndi.gaussian_filter(blurred, sigma)
    apply_power = abs(donor_pow - 1.0) > 1e-3
    return np.power(np.maximum(0.0, blurred), donor_pow) if apply_power else blurred


def build_mask_from_sum(dsum: np.ndarray, params: dict) -> np.ndarray:
    """Threshold *dsum* at ``median + params['kmad'] * _mad(dsum)`` and clean it up.

    The binary result is passed through up to three morphological steps, in
    this fixed order:
      1. erosion by ``mask_erode_px``;
      2. closing by ``mask_close_px``;
      3. dilation by ``mask_expand_px``.
    Each step uses a disk footprint and is skipped when its radius is 0 or
    missing. ``_mad`` is defined elsewhere in the file.
    """
    threshold = np.median(dsum) + params['kmad'] * _mad(dsum)
    mask = (dsum > threshold)
    morph_steps = (
        ('mask_erode_px', morph.binary_erosion),
        ('mask_close_px', morph.binary_closing),
        ('mask_expand_px', morph.binary_dilation),
    )
    for param_key, operation in morph_steps:
        radius = params.get(param_key, 0)
        if radius > 0:
            mask = operation(mask, footprint=disk(int(radius)))
    return mask


def huber_ridge_multi(donor_arrays, target, mask, lam=1e-4, delta=2.5, iters=5):
    """Fit non-negative weights w with ``target ~= sum_i w_i * donor_i`` on *mask*.

    Starts from a ridge-regularised least-squares solution. It then runs
    ``max(1, iters)`` rounds of iteratively reweighted least squares with Huber
    weights: residuals larger than ``delta * (_mad(residual) + 1e-6)`` are
    down-weighted by ``threshold / |residual|``. Negative weights are clipped to
    0 at the end. ``_mad`` is defined elsewhere in this file.

    Args:
        donor_arrays: donor images of the same shape as *target*. Accepts a
            list/tuple, any iterable, or a stacked array whose first axis
            indexes donors.
        target: target image.
        mask: pixels used for the fit; coerced to bool.
        lam: ridge (L2) penalty added to the diagonal of X^T X.
        delta: Huber threshold in units of the residual MAD scale.
        iters: number of reweighting rounds (at least one is always run).

    Returns:
        Weight array of shape ``(n_donors,)``. It is a float32 zero array when
        there are no donors or the mask selects no pixels.
    """
    donor_arrays = [] if donor_arrays is None else list(donor_arrays)
    if not donor_arrays:
        return np.zeros(0, dtype=np.float32)
    # Coerce to bool: an integer 0/1 mask would otherwise act as fancy (index)
    # selection and silently pick the wrong pixels.
    mask = np.asarray(mask, dtype=bool)
    y = target[mask].astype(np.float32, copy=False)
    if y.size == 0:
        return np.zeros(len(donor_arrays), dtype=np.float32)
    X = np.stack([arr[mask].astype(np.float32, copy=False) for arr in donor_arrays], axis=1)
    lamI = lam * np.eye(X.shape[1], dtype=np.float32)

    def _solve(design, response):
        # Ridge normal equations: (X^T X + lam*I) w = X^T y.
        return np.linalg.lstsq(design.T @ design + lamI, design.T @ response, rcond=None)[0]

    w = _solve(X, y)
    for _ in range(max(1, iters)):
        residual = y - X @ w
        scale = _mad(residual) + 1e-6
        thr = float(delta) * scale
        weights = np.ones_like(residual, dtype=np.float32)
        heavy = np.abs(residual) > thr
        weights[heavy] = thr / np.maximum(np.abs(residual[heavy]), 1e-6)
        sqrt_w = np.sqrt(weights)
        w = _solve(X * sqrt_w[:, None], y * sqrt_w)
    return np.maximum(0.0, w)


def parse_af_refs(raw: str) -> list[int]:
    """Parse a ``;``- or ``,``-separated list of non-negative integer AF references.

    Returns ``[]`` for empty/falsy input and for values starting with ``na``
    (case-insensitive, e.g. ``NA``, ``nan``). Non-string values are converted
    with ``str`` first, so a float NaN from a table yields ``[]``. Tokens that
    are not plain decimal integers are skipped.
    """
    if not raw:
        return []
    text = str(raw).strip()
    if not text or text.lower().startswith('na'):
        return []
    refs = []
    for token in text.replace(',', ';').split(';'):
        token = token.strip()
        # isdecimal(), not isdigit(): isdigit() also accepts characters such as
        # '²' for which int() raises ValueError.
        if token.isdecimal():
            refs.append(int(token))
    return refs


def refine_af_mask(mask, pred, tgt, params):
    """Shrink an AF mask to its confident core.

    Steps, each controlled by *params* (defaults in parentheses):
      1. Robust z-score of *pred* inside the mask: ``(pred - median) / (MAD + 1e-6)``,
         computed with ``_mad`` (defined elsewhere). Keep pixels with
         ``z > refine_sigma`` (1.5).
      2. If ``refine_min_ratio`` (0.4) > 0, also require
         ``pred / (tgt + 1e-6) >= refine_min_ratio``.
      3. If ``refine_guard_ratio`` (6.0) > 0, also require
         ``(tgt + 1e-6) / (pred + 1e-6) <= refine_guard_ratio``.
      4. If ``refine_min_size`` (6) > 1, remove objects smaller than it.
      5. If ``refine_max_size`` (0) > 0, drop connected components with more
         pixels than it.
      6. Dilate by ``refine_dilate`` (1), then erode by ``refine_erode`` (0).

    Returns:
        *mask* unchanged when it is None or contains no True pixel; otherwise a
        boolean core mask.
    """
    if mask is None or not np.any(mask):
        return mask
    mask = mask.astype(bool, copy=True)
    pred_masked = pred[mask]
    if pred_masked.size == 0:
        return mask
    med = np.median(pred_masked)
    mad = _mad(pred_masked) + 1e-6
    z = np.zeros_like(pred, dtype=np.float32)
    z[mask] = (pred_masked - med) / mad
    sigma_thr = float(params.get('refine_sigma', 1.5))
    core = mask & (z > sigma_thr)

    ratio_min = float(params.get('refine_min_ratio', 0.4))
    if ratio_min > 0:
        ratio = np.zeros_like(pred, dtype=np.float32)
        ratio[mask] = pred[mask] / (tgt[mask] + 1e-6)
        core &= ratio >= ratio_min

    guard_ratio = float(params.get('refine_guard_ratio', 6.0))
    if guard_ratio > 0:
        # Outside the mask guard stays 0 (passes), but core is already False there.
        guard = np.zeros_like(pred, dtype=np.float32)
        guard[mask] = (tgt[mask] + 1e-6) / (pred[mask] + 1e-6)
        core &= guard <= guard_ratio

    min_size = int(params.get('refine_min_size', 6))
    if min_size > 1:
        core = morph.remove_small_objects(core, min_size=min_size)
    max_size = int(params.get('refine_max_size', 0))
    if max_size > 0:
        lbl, n = ndi.label(core)
        if n > 0:
            # Vectorised: one lookup instead of one full-image compare per label.
            counts = np.bincount(lbl.ravel())
            too_big = counts > max_size
            too_big[0] = False  # label 0 is background, never part of core
            core &= ~too_big[lbl]

    dil = int(params.get('refine_dilate', 1))
    ero = int(params.get('refine_erode', 0))
    if dil > 0:
        core = morph.binary_dilation(core, footprint=disk(dil))
    if ero > 0:
        core = morph.binary_erosion(core, footprint=disk(ero))
    return core


def build_global_seed(stack, donor_indices, params):
    """Build one AF seed mask from the pixel-wise maximum of all prepared donor planes.

    Each ``stack[i]`` for ``i`` in ``donor_indices`` is passed through ``prepare_donor`` with
    ``params['donor_sigma']`` and ``params['donor_pow']``. The element-wise maximum over the
    donors is taken, starting from a float32 zero image of shape ``stack.shape[1:]``. That
    maximum is turned into a mask by ``build_mask_from_sum``. A non-empty mask is then cleaned
    of objects smaller than ``max(4, refine_min_size)`` pixels and dilated by ``refine_dilate``
    pixels.

    Returns:
        ``None`` if ``donor_indices`` is empty, the unmodified mask if it is empty, otherwise
        the cleaned (and possibly dilated) seed mask.
    """
    if not donor_indices:
        return None

    envelope = np.zeros(stack.shape[1:], dtype=np.float32)
    for donor in donor_indices:
        prepared = prepare_donor(stack[donor], params['donor_sigma'], params['donor_pow'])
        envelope = np.maximum(envelope, prepared)

    seed = build_mask_from_sum(envelope, params)
    if np.any(seed):
        smallest = int(max(4, params.get('refine_min_size', 6)))
        seed = morph.remove_small_objects(seed, min_size=smallest)
        grow_px = int(params.get('refine_dilate', 1))
        if grow_px > 0:
            seed = morph.binary_dilation(seed, footprint=disk(grow_px))
    return seed


# ---------- Global AF run: parameters ----------
PARAMS = dict(
    # --- donor preparation / AF mask construction ---
    kmad=5.5,
    pred_gain=1.35,  # raised from 1.12: stronger AF correction
    donor_sigma=1.6,
    donor_pow=1.10,
    mask_erode_px=0,
    mask_close_px=3,
    mask_expand_px=1,
    # --- robust two-donor regression ---
    huber_delta=3.0,  # raised from 2.5: less aggressive outlier handling
    ridge_lambda=1e-6,  # lowered from 1e-4: less damping of the weights
    # --- capping of the corrected signal ---
    cap_mode='local',  # changed from 'anchor_strict': less restrictive, uses raw_sub directly
    cap_kappa=2.5,  # raised from 1.8: higher cap for stronger signals
    cap_win=15,
    ring_band_px=5,
    anchor_bias_sd=0.15,  # raised from 0.05: higher anchor value if anchor_strict is re-enabled
    feather_px=3,  # raised from 0: softer transitions at the mask edge (reduces a "wide rim")
    # --- generative fill (experimental) ---
    genfill_on=False,
    fill_sigma=6.0,
    blend_softness=1.0,
    gen_beta=0.5,
    gen_octaves=4,
    gen_base_sigma=1.2,
    gen_seed=123,
    # --- post-filtering ---
    closing_r=0,
    median_k=1,
    # --- AF mask refinement (see refine_af_mask) ---
    refine_sigma=1.5,
    refine_min_ratio=0.4,
    refine_guard_ratio=6.0,
    refine_min_size=6,
    refine_max_size=0,
    refine_dilate=1,
    refine_erode=0,
)

# ---------- Locate and load the marker table (markers_<sample>.csv in BASE_EXPORT) ----------
# Prefer the CSV named after the digits of SAMPLE_ID (or SAMPLE_ID itself if it has none);
# otherwise fall back to any markers_*.csv in BASE_EXPORT.
sample_token = ''.join(ch for ch in SAMPLE_ID if ch.isdigit()) or SAMPLE_ID
_preferred_marker_csv = BASE_EXPORT / f'markers_{sample_token}.csv'
if _preferred_marker_csv.is_file():
    # Direct lookup instead of glob(): characters such as '[' or '*' in SAMPLE_ID would
    # otherwise be interpreted as glob wildcards.
    marker_candidates = [_preferred_marker_csv]
else:
    # Path.glob() yields entries in arbitrary (filesystem) order; sort so the fallback pick
    # is reproducible across machines and runs.
    marker_candidates = sorted(BASE_EXPORT.glob('markers_*.csv'))
if not marker_candidates:
    raise FileNotFoundError(
        f"âŒ FEHLER: Keine Marker-CSV gefunden in: {BASE_EXPORT}\n"
        f"  Erwartet: markers_{sample_token}.csv"
    )
if len(marker_candidates) > 1:
    print(f"[CONFIG] Hinweis: {len(marker_candidates)} Marker-CSVs gefunden - verwende {marker_candidates[0].name}")
MARKER_TABLE_PATH = marker_candidates[0]
print(f"[CONFIG] Marker-CSV: {MARKER_TABLE_PATH}")

print("ðŸ”§ Starte AF-Entfernung fÃ¼r den kompletten Stack ...")
marker_rows, no_to_idx, cycle_to_donors = load_marker_reference(MARKER_TABLE_PATH)
print(f"  Marker-Referenzen: {len(marker_rows)} KanÃ¤le (Include=TRUE)")

# ---------- Read the OME-TIFF and bring it into (C, Y, X) order ----------
if tiff is None:
    # The guarded `import tifffile as tiff` at the top of the file falls back to None; fail here
    # with a clear message instead of "'NoneType' object has no attribute 'TiffFile'".
    raise ImportError("tifffile ist nicht verfuegbar - wird zum Lesen/Schreiben des OME-TIFF benoetigt (pip install tifffile).")

with tiff.TiffFile(IMG_PATH) as tf:
    series = tf.series[0]
    axes = series.axes
    orig_dtype = series.dtype
    stack_raw = series.asarray()
    channel_info, px_sizes = parse_ome_channels(tf.ome_metadata)
    px_meta = sanitize_pixel_sizes(px_sizes)

axes_list = list(axes)
# The pipeline needs a channel axis and a 2-D plane; without them the .index() calls below
# would fail with an unhelpful "substring not found".
missing_axes = [ax for ax in ('C', 'Y', 'X') if ax not in axes_list]
if missing_axes:
    raise RuntimeError(f"TIFF-Achsen '{axes}' enthalten nicht {', '.join(missing_axes)} - erwartet werden C, Y und X.")

# Every other axis (Z, T, S, ...) is dropped; this is only possible for singleton axes.
# Iterate from the highest index down so earlier indices stay valid while popping.
extra_dims = [dim for dim, ax in enumerate(axes_list) if ax not in ('C', 'Y', 'X')]
for dim in reversed(extra_dims):
    if stack_raw.shape[dim] != 1:
        raise RuntimeError(f"Achse '{axes_list[dim]}' hat GrÃ¶ÃŸe {stack_raw.shape[dim]} (â‰ 1) â€“ kann nicht automatisch reduziert werden.")
    stack_raw = np.take(stack_raw, indices=0, axis=dim)
    axes_list.pop(dim)
axes_reduced = ''.join(axes_list)
perm = [axes_reduced.index('C'), axes_reduced.index('Y'), axes_reduced.index('X')]
stack_raw = np.transpose(stack_raw, perm)
axes = 'CYX'

# MEMORY: keep the stack in its input dtype (e.g. uint16) instead of converting the whole
# stack to float32, which would double its size. Per-channel work converts only what it needs.
stack = stack_raw
orig_dtype_for_processing = stack.dtype
C, H, W = stack.shape
print(f"  Eingelesen: {C} KanÃ¤le | GrÃ¶ÃŸe: {W}Ã—{H} | dtype: {orig_dtype} | Achsen: {axes_reduced if extra_dims else axes}")
print(f"  âš ï¸ SPEICHER-MODUS: Behalte uint16 ({stack.nbytes / (1024**3):.1f} GB) statt float32-Konvertierung ({stack.nbytes * 2 / (1024**3):.1f} GB)")

if len(marker_rows) != C:
    raise RuntimeError(f"Marker-Referenz liefert {len(marker_rows)} KanÃ¤le, TIFF enthÃ¤lt {C}.")

# ---------- Per-channel names, OME channel metadata and AF donor assignment ----------
channel_names = []
channel_meta_output = []
channel_donor_idx = []
skip_idx = set()
for idx, row in enumerate(marker_rows):
    fluor = (row.get('fluorochrome') or '').strip()
    name = (row.get('Marker-Name') or '').strip()
    # Missing/NA marker names fall back to the fluorochrome, then to a synthetic label.
    if not name or name.upper() == 'NA':
        name = fluor or f'Channel_{idx:02d}'
    channel_names.append(name)

    meta_entry = {'Name': name}
    if fluor and fluor.upper() != 'NA':
        meta_entry['Fluor'] = fluor
    meta_entry['Cycle'] = row.get('cycle', '')
    channel_meta_output.append(meta_entry)

    # The AF_Ref column is not used: cycle_to_donors already maps each cycle to the stack
    # positions of its AF donor channels. A channel is never its own donor.
    cycle = int(row['cycle'])
    donor_stack_positions = [d for d in cycle_to_donors.get(cycle, []) if d != idx]
    channel_donor_idx.append(donor_stack_positions)

    # Base markers are passed through unchanged by the AF removal loop.
    if is_base_marker(name):
        skip_idx.add(idx)

# One global AF seed from every donor used (first two per channel, as in the removal loop).
unique_donors = sorted({d for donors in channel_donor_idx for d in donors[:2]})
global_seed = None
if unique_donors:
    global_seed = build_global_seed(stack, unique_donors, PARAMS)
if global_seed is None or not np.any(global_seed):
    global_seed = None
    print('  Hinweis: Keine globale AF-Saat gebildet (keine Donors oder leer).')
else:
    print(f"  Globale AF-Seed: {global_seed.mean()*100:.2f}% der Pixel")

# ---------- AF removal, channel by channel ----------
# MEMORY: instead of stack.copy() (a second full-size stack), `processed` is preallocated
# with the input dtype and filled channel by channel below.
processed = np.empty_like(stack)
donor_cache = {}  # donor stack index -> prepare_donor() output, reused across channels
mask_cache = {}   # sorted donor pair -> AF mask from build_mask_from_sum() on the donor sum
log_lines = []    # per-channel status lines; the summary later counts the status markers in them

# The cap mode is the same for every channel, so normalise it once.
cm = str(PARAMS['cap_mode']).lower()
# `processed` has the input dtype (integer for uint16 stacks). NumPy's float -> int conversion
# on assignment does not saturate, so results are clipped to the representable range first.
_proc_int_info = np.iinfo(processed.dtype) if np.issubdtype(processed.dtype, np.integer) else None
# Dump the parameters once, for the first channel that is actually corrected. The former
# `idx == 0` check never fired when channel 0 was a skipped base/AF channel.
_af_debug_printed = False

for idx in range(C):
    name = channel_names[idx]
    donors_idx = channel_donor_idx[idx]

    # Pass-through: base/AF channels and channels without donor references stay untouched.
    if idx in skip_idx or not donors_idx:
        processed[idx] = stack[idx]
        reason = 'Basis/AF-Kanal' if idx in skip_idx else 'keine Referenz'
        msg = f"[{idx+1}/{C}] â†· Kanal {idx:02d} ({name}): {reason}"
        print(msg, flush=True)
        log_lines.append(msg)
        continue

    # The two-donor regression below needs two donor channels.
    if len(donors_idx) < 2:
        processed[idx] = stack[idx]
        msg = f"[{idx+1}/{C}] â†· Kanal {idx:02d} ({name}): weniger als zwei Donors"
        print(msg, flush=True)
        log_lines.append(msg)
        continue

    # Only the first two donors are used; sorting makes the cache key order-independent.
    donors_key = tuple(sorted(donors_idx[:2]))
    d1_idx, d2_idx = donors_key

    c2 = donor_cache.get(d1_idx)
    if c2 is None:
        c2 = prepare_donor(stack[d1_idx], PARAMS['donor_sigma'], PARAMS['donor_pow'])
        donor_cache[d1_idx] = c2
    c3 = donor_cache.get(d2_idx)
    if c3 is None:
        c3 = prepare_donor(stack[d2_idx], PARAMS['donor_sigma'], PARAMS['donor_pow'])
        donor_cache[d2_idx] = c3

    if donors_key in mask_cache:
        mask = mask_cache[donors_key]
    else:
        dsum = c2 + c3
        mask = build_mask_from_sum(dsum, PARAMS)
        mask_cache[donors_key] = mask

    if not np.any(mask):
        processed[idx] = stack[idx]
        msg = f"[{idx+1}/{C}] â†· Kanal {idx:02d} ({name}): keine AF-Maske"
        print(msg, flush=True)
        log_lines.append(msg)
        continue

    tgt = stack[idx]
    # Pass 1: fit donor weights on the coarse mask. Only the prediction is needed here (it
    # drives the mask refinement); the subtraction happens after the refit in pass 2.
    wts = huber_ridge_2d(c2[mask], c3[mask], tgt[mask], lam=PARAMS['ridge_lambda'], delta=PARAMS['huber_delta'], iters=5)
    pred = wts[0]*c2 + wts[1]*c3

    mask_refined = refine_af_mask(mask, pred, tgt, PARAMS)
    if global_seed is not None:
        mask_refined = mask_refined & global_seed
    if not np.any(mask_refined):
        processed[idx] = stack[idx]
        msg = f"[{idx+1}/{C}] â†· Kanal {idx:02d} ({name}): Maske entleert nach Refinement"
        print(msg, flush=True)
        log_lines.append(msg)
        continue

    # Pass 2: refit on the refined mask and subtract the gain-scaled AF prediction.
    mask = mask_refined
    wts = huber_ridge_2d(c2[mask], c3[mask], tgt[mask], lam=PARAMS['ridge_lambda'], delta=PARAMS['huber_delta'], iters=5)
    pred = wts[0]*c2 + wts[1]*c3
    raw_sub = tgt - PARAMS['pred_gain'] * pred

    mu_loc, sd_loc = masked_local_mean_std(tgt, mask, win=int(PARAMS['cap_win']))

    if not _af_debug_printed:
        print(f"\nðŸ” DEBUG AF-REMOVAL PARAMETERS (Channel {idx}):")
        print(f"   cap_mode: {PARAMS['cap_mode']}")
        print(f"   feather_px: {PARAMS['feather_px']}")
        print(f"   genfill_on: {PARAMS['genfill_on']}")
        print(f"   mask_expand_px: {PARAMS['mask_expand_px']}")
        print(f"   anchor_bias_sd: {PARAMS['anchor_bias_sd']}\n")
        _af_debug_printed = True

    if cm in ('anchor', 'anchor_strict'):
        # A missing (None) bias means "no bias"; the same value feeds both fill paths.
        bias = float(PARAMS['anchor_bias_sd']) if PARAMS['anchor_bias_sd'] is not None else 0.0

        if PARAMS['genfill_on']:
            clean = genfill_exp_fill(
                tgt, mask, mu_loc, sd_loc,
                anchor_bias_sd=bias,
                feather_px=int(PARAMS['feather_px']),
                beta=float(PARAMS['gen_beta']),
                octaves=int(PARAMS['gen_octaves']),
                base_sigma=float(PARAMS['gen_base_sigma']),
                seed=int(PARAMS['gen_seed'])
            )
        else:
            # Anchor = local mean minus bias*SD, floored at 0 and never above the raw signal.
            anchor = np.minimum(np.maximum(0.0, mu_loc - bias * sd_loc), tgt)
            base = np.maximum(raw_sub, 0.0)
            clean = tgt.copy()
            if int(PARAMS['feather_px']) > 0:
                # Weight grows with distance from the mask edge: AF-subtracted signal near the
                # edge, anchor level once the distance reaches feather_px.
                di = ndi.distance_transform_edt(mask)
                w_blend = np.clip(di / float(PARAMS['feather_px']), 0.0, 1.0)
                mix = (1.0 - w_blend)*base + w_blend*anchor
                clean[mask] = np.minimum(mix[mask], tgt[mask])
            else:
                clean[mask] = np.minimum(anchor[mask], tgt[mask])
    else:
        # Non-anchor modes: use the AF-subtracted signal (clipped at 0), bounded by a cap.
        base = np.maximum(raw_sub, 0.0)
        cap_local = mu_loc + PARAMS['cap_kappa']*sd_loc
        # Ring statistics are only needed by the 'ring' and 'hybrid' caps.
        if cm in ('ring', 'hybrid'):
            mu_ring, sd_ring = ring_mean_std(mask, tgt, band=int(PARAMS['ring_band_px']), win=int(PARAMS['cap_win']))
            cap_ring = mu_ring + PARAMS['cap_kappa']*sd_ring

        if cm == 'ring':
            cap = cap_ring
        elif cm == 'hybrid':
            cap = np.minimum(cap_local, cap_ring)
        elif cm == 'local':
            cap = cap_local
        else:
            cap = np.minimum(cap_local, tgt)

        # The cap never exceeds the raw signal, so correction can only lower pixel values.
        cap = np.clip(cap, 0, tgt)

        clean = tgt.copy()
        clean[mask] = np.minimum(base[mask], cap[mask])

    if PARAMS['closing_r'] > 0:
        clean = ndi.grey_closing(clean, size=(int(PARAMS['closing_r']), int(PARAMS['closing_r'])))
    if PARAMS['median_k'] > 1:
        clean = ndi.median_filter(clean, size=int(PARAMS['median_k']))

    if _proc_int_info is not None:
        clean = np.clip(clean, _proc_int_info.min, _proc_int_info.max)
    processed[idx] = clean
    donors_label = ', '.join(f"{d:02d}:{channel_names[d]}" for d in donors_key)
    weight_str = ', '.join(f"w{j+1}={w:.4f}" for j, w in enumerate(wts))
    msg = f"[{idx+1}/{C}] âœ… Kanal {idx:02d} ({name}): Referenz â†’ {donors_label} | {weight_str}"
    print(msg, flush=True)
    log_lines.append(msg)
# ---------- Optional ACE post-normalisation of every channel ----------
params = ACE_PARAMS
ace_post_widget = globals().get('w_ace_post')
# When the notebook defines a w_ace_post toggle, its value overrides the config flag.
if ace_post_widget is None:
    ace_enabled = params.get('post_enabled', True)
else:
    ace_enabled = bool(getattr(ace_post_widget, 'value', False))

if not ace_enabled:
    print('[ACE] Post-Normalisierung deaktiviert (post_enabled=False).')
else:
    # The keyword arguments are identical for every channel, so collect them once.
    ace_kwargs = {
        'radii': params.get('radii'),
        'alpha': params.get('alpha', 6.0),
        'iterations': params.get('iterations', 1),
        'clip': params.get('clip'),
        'radius': params.get('radius'),
        'preserve_background': params.get('preserve_background', True),
        'max_gain': params.get('max_gain', 3.0),
    }
    print('[ACE] Post-Normalisierung aktiv -> radii=%s, alpha=%.2f, iter=%d, radius=%s' % (
        ace_kwargs['radii'], ace_kwargs['alpha'], ace_kwargs['iterations'], ace_kwargs['radius']))
    total_channels = processed.shape[0]
    for _idx in range(total_channels):
        print(f"  [ACE] Kanal {_idx+1}/{total_channels} ...", end=' ', flush=True)
        # NOTE(review): the result is stored back into `processed` (input dtype, e.g. uint16);
        # confirm that ace_local_equalize returns values within that range.
        processed[_idx] = ace_local_equalize(processed[_idx], **ace_kwargs)
        print('fertig', flush=True)

# ---------- Convert to the input dtype and write the cleaned OME-TIFF ----------
_orig_is_integer = np.issubdtype(orig_dtype, np.integer)
if _orig_is_integer:
    rng = np.iinfo(orig_dtype)
if processed.dtype == orig_dtype:
    # `processed` is allocated with np.empty_like(stack), i.e. already in the input dtype.
    # Reusing it avoids two extra full-size copies (the np.clip result plus the .astype
    # result), which would undo the memory saving of keeping the stack as uint16.
    # NOTE(review): output_array therefore aliases `processed`.
    output_array = processed
elif _orig_is_integer:
    output_array = np.clip(processed, rng.min, rng.max).astype(orig_dtype)
else:
    output_array = processed.astype(orig_dtype)

# Output goes into the dedicated AF_removal/ folder (pipeline convention).
af_removal_dir = BASE_EXPORT / "AF_removal"
af_removal_dir.mkdir(exist_ok=True, parents=True)
output_path = af_removal_dir / "fused_decon_AF_cleaned.ome.tif"

metadata = {
    'axes': 'CYX',
    'DimensionOrder': 'XYCZT',
    'Channel': channel_meta_output,
    'SizeT': 1,
    'SizeZ': 1,
    'SignificantBits': int(np.iinfo(orig_dtype).bits if np.issubdtype(orig_dtype, np.integer) else 32)
}
metadata.update(px_meta)

# Summary: count the status markers written into log_lines by the AF removal loop.
print('--- KanalÃ¼bersicht ---')
success_count = sum('âœ…' in line for line in log_lines)
skip_count = sum('â†·' in line for line in log_lines)
print(f'  âœ… Gereinigt: {success_count}')
print(f'  â†· Ãœbersprungen: {skip_count}')

print(f"ðŸ’¾ Speichere vollstÃ¤ndigen Stack nach {output_path}")
tiff.imwrite(
    output_path,
    output_array,
    metadata=metadata,
    ome=True,
    photometric='minisblack',
    bigtiff=True,
    compression='zlib'
)

print(f"  Ãœbersprungene KanÃ¤le: {sorted(skip_idx)}")
print("  Fertig.")



















