In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
import csv
import zipfile
import optuna
import numpy as np
import tifffile as tiff
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.gridspec import GridSpec
from sklearn.model_selection import ParameterGrid

import sys
sys.path.append("..")
from dyn_signalrise import dyn_signalrise

# Absolute path of the directory one level above the current working directory.
parentdir = os.path.abspath(os.pardir)

In [None]:
def match_events(pred, gt, dt=2, dr=3.0):
    """
    Greedily match predicted events to ground-truth events.

    Ground-truth events are visited in frame order. Each one is paired with the
    still-unused prediction that lies within ``dt`` frames and ``dr`` spatial
    units and has the lowest cost ``|frame difference| + 0.01 * distance``
    (mostly time, slightly space). Every prediction is used at most once.
    Ties are broken by frame order, then by input order (stable sort), so the
    result is deterministic.

    Parameters
    ----------
    pred, gt : array-like, rows of (frame, x, y, ...)
        Predicted and ground-truth events. Only the first three columns are
        used; extra trailing columns (e.g. a score) are ignored.
    dt : float
        Maximum allowed frame difference for a match (inclusive).
    dr : float
        Maximum allowed Euclidean (x, y) distance for a match (inclusive).

    Returns
    -------
    tp, fp, fn : int
        Matched pairs, unmatched predictions and unmatched ground-truth events.
    matches : list of (pred_index, gt_index)
        Indices into ``pred`` and ``gt`` for every matched pair.
    """
    pred = np.asarray(pred, float)
    gt = np.asarray(gt, float)

    if len(pred) == 0:
        return 0, 0, len(gt), []
    if len(gt) == 0:
        return 0, len(pred), 0, []

    # Only (frame, x, y) take part in matching.
    pred = pred[:, :3]
    gt = gt[:, :3]

    # Stable sorts: events sharing a frame keep their input order, which makes
    # the greedy assignment reproducible.
    gt_order = np.argsort(gt[:, 0], kind="stable")
    pred_order = np.argsort(pred[:, 0], kind="stable")
    pred_sorted = pred[pred_order]
    available = np.ones(len(pred_sorted), dtype=bool)

    matches = []
    for i in gt_order:
        t, x, y = gt[i]
        time_diff = np.abs(pred_sorted[:, 0] - t)
        dist = np.hypot(pred_sorted[:, 1] - x, pred_sorted[:, 2] - y)
        candidates = available & (time_diff <= dt) & (dist <= dr)
        if not candidates.any():
            continue
        cost = np.where(candidates, time_diff + 0.01 * dist, np.inf)
        # argmin returns the first minimum, i.e. the earliest prediction on ties.
        k = int(np.argmin(cost))
        available[k] = False
        matches.append((int(pred_order[k]), int(i)))

    tp = len(matches)
    fp = len(pred) - tp
    fn = len(gt) - tp
    return tp, fp, fn, matches

def fbeta(tp, fp, fn, beta=2.0):
    """
    F-beta score computed directly from match counts.

    Uses (1 + beta^2) * tp / ((1 + beta^2) * tp + beta^2 * fn + fp). This is
    equivalent to the precision/recall form but needs no separate precision or
    recall, so it stays defined when there are no predictions.
    beta > 1 weights recall more heavily, beta < 1 weights precision more.

    Returns 0.0 when the denominator is zero (e.g. tp = fp = fn = 0).
    """
    beta_sq = beta * beta
    weighted_tp = (1 + beta_sq) * tp
    denominator = weighted_tp + beta_sq * fn + fp
    if denominator == 0:
        return 0.0
    return weighted_tp / denominator

def load_manual_events_from_zip(folder, file):
    """
    Load manually annotated events from a zip archive of ROI files.

    Each member name is parsed as ``<frame>-<y>-<x>[.ext]``: the first
    dash-separated field is the frame, the second is y, the third is x (any
    ``.ext`` suffix is dropped). Directory entries are skipped.
    NOTE(review): this looks like ImageJ RoiSet naming (slice-y-x). If so,
    frame numbers may be 1-based while predicted frames elsewhere are 0-based
    stack indices -- confirm whether an offset is needed.

    Parameters
    ----------
    folder : str
        Directory containing the archive.
    file : str
        Archive file name inside ``folder``.

    Returns
    -------
    list of (frame, x, y) int tuples, sorted by frame (stable for ties).
    """
    # Context manager so the archive's file handle is released.
    with zipfile.ZipFile(os.path.join(folder, file)) as archive:
        roinames = archive.namelist()

    manual_events = []
    for roiname in roinames:
        if roiname.endswith('/'):
            continue  # directory entry, not an ROI
        parts = roiname.split('-')
        frame = int(parts[0])
        y = int(parts[1].split('.')[0])
        x = int(parts[2].split('.')[0])
        manual_events.append((frame, x, y))
    manual_events.sort(key=lambda v: v[0])
    return manual_events

def run_dyn_pipeline_on_stack(conf_data, dyn_params, warmup_frames=10, keep_all_events=False):
    """
    Run dyn_signalrise frame by frame over a stack and collect detected events.

    Frames with index ``<= warmup_frames`` are skipped entirely: dyn_signalrise
    is not called for them and no tracking state is built. They only serve as
    history for later frames. The third value returned by dyn_signalrise is
    fed back as ``exinfo`` on the next call, so tracking state carries over
    between frames.

    Parameters
    ----------
    conf_data : sequence of frames
        Indexable/sliceable stack (e.g. an array with frames on axis 0).
    dyn_params : dict
        Keyword parameters forwarded to dyn_signalrise (min_dist, num_peaks,
        thresh_abs_lo, thresh_abs_hi, border_limit, memory_frames,
        track_search_dist, frames_appear, thresh_intincratio,
        thresh_intincratio_max, thresh_move_dist).
    warmup_frames : int
        Frames with index <= this value are not processed.
    keep_all_events : bool
        False (default): record only the first detection per frame, as before.
        True: record every row of ``coords_events``.

    Returns
    -------
    list of (frame_index, x, y) int tuples, in frame order.
    """
    tracks_all = None  # dyn_signalrise state, passed back in as exinfo
    pred_events = []

    frames_appear = int(dyn_params["frames_appear"])

    for idx, img_conf in enumerate(conf_data):
        if idx <= warmup_frames:
            continue  # warmup: only provides history for later frames

        # Earlier frames handed to dyn_signalrise for its intensity-ratio
        # logic. Per the original note it expects at least ~2*frames_appear of
        # them; two extra frames are included as margin.
        start = max(0, idx - frames_appear * 2 - 2)
        prev_frames = conf_data[start:idx]

        coords_events, _, tracks_all, _ = dyn_signalrise(
            img_ch1=img_conf,
            prev_frames=prev_frames,
            binary_mask=None,
            exinfo=tracks_all,
            presetROIsize=None,
            # pipeline parameters:
            min_dist=dyn_params["min_dist"],
            num_peaks=dyn_params["num_peaks"],
            thresh_abs_lo=dyn_params["thresh_abs_lo"],
            thresh_abs_hi=dyn_params["thresh_abs_hi"],
            border_limit=dyn_params["border_limit"],
            memory_frames=dyn_params["memory_frames"],
            track_search_dist=dyn_params["track_search_dist"],
            frames_appear=dyn_params["frames_appear"],
            thresh_intincratio=dyn_params["thresh_intincratio"],
            thresh_intincratio_max=dyn_params["thresh_intincratio_max"],
            thresh_move_dist=dyn_params["thresh_move_dist"],
        )

        if coords_events is None or coords_events.size == 0:
            continue
        # NOTE(review): coords_events is read as rows of (x, y, ...) with x in
        # column 0 and y in column 1, following the original indexing -- confirm.
        rows = coords_events if keep_all_events else coords_events[:1]
        for row in rows:
            pred_events.append((idx, int(row[0]), int(row[1])))

    return pred_events

def dedupe_events_greedy(events, merge_dt=10, merge_dr=6.0, keep="earliest"):
    """
    Deduplicate event predictions by merging events close in time and space.

    Events are processed in frame order. Each event joins the most recently
    created cluster whose *last* event is at most ``merge_dt`` frames earlier
    and at most ``merge_dr`` away in (x, y). Otherwise it starts a new cluster.
    Because comparison is against a cluster's last event, clusters can chain
    over time. One representative per cluster is returned.

    Parameters
    ----------
    events : iterable of tuples (frame, x, y, ...)
        Extra trailing entries are kept for keep="earliest"/"latest" and
        ignored for keep="mean".
    merge_dt : float
        Max frame gap to a cluster's last event (inclusive).
    merge_dr : float
        Max Euclidean distance to a cluster's last event (inclusive).
    keep : {"earliest", "latest", "mean"}
        Representative per cluster: the earliest event, the latest event, or
        (round(mean frame), mean x, mean y).

    Returns
    -------
    list of tuples, sorted by frame.

    Raises
    ------
    ValueError
        If ``keep`` is not one of the supported values.
    """
    if keep not in ("earliest", "latest", "mean"):
        raise ValueError("keep must be 'earliest', 'latest', or 'mean'")

    events = list(events)  # also accepts generators / arrays of rows
    if not events:
        return []
    events.sort(key=lambda e: e[0])  # stable: ties keep input order

    clusters = []  # each cluster: list of events, in time order
    active = []    # indices of clusters that can still accept events

    for e in events:
        t, x, y = e[0], e[1], e[2]

        # Events arrive in time order, so a cluster whose last event is more
        # than merge_dt frames behind can never accept this or a later event.
        # Only such clusters are dropped. Last-event times are not monotonic
        # in cluster creation order (clusters chain), so stopping at the first
        # stale cluster could miss an older, still-active one.
        active = [k for k in active if t - clusters[k][-1][0] <= merge_dt]

        # Prefer the most recently created matching cluster.
        for k in reversed(active):
            last = clusters[k][-1]
            if np.hypot(x - last[1], y - last[2]) <= merge_dr:
                clusters[k].append(e)
                break
        else:
            active.append(len(clusters))
            clusters.append([e])

    # choose representative per cluster
    deduped = []
    for c in clusters:
        if keep == "earliest":
            rep = min(c, key=lambda e: e[0])
        elif keep == "latest":
            rep = max(c, key=lambda e: e[0])
        else:  # "mean": average frame, x, y; trailing entries are ignored
            arr = np.array([e[:3] for e in c], float)
            t_mean, x_mean, y_mean = arr.mean(axis=0)
            rep = (int(round(t_mean)), float(x_mean), float(y_mean))
        deduped.append(rep)

    # keep time order
    deduped.sort(key=lambda e: e[0])
    return deduped

def evaluate_params(conf_data, manual_events, dyn_params, dt=30, dr=6.0, beta=2.0,
                    warmup_frames=10, merge_dt=None, merge_dr=None):
    """
    Score one dyn_signalrise parameter set against manual annotations.

    Pipeline: detect events over the stack, deduplicate them (keeping the
    earliest event of each cluster), greedily match them to ``manual_events``,
    and compute the F-beta score.

    Parameters
    ----------
    conf_data : sequence of frames
    manual_events : list of (frame, x, y)
    dyn_params : dict
        Parameters forwarded to dyn_signalrise.
    dt, dr : float
        Frame / spatial tolerance for matching predictions to annotations.
    beta : float
        F-beta weighting (beta < 1 favours precision).
    warmup_frames : int
        Frames with index <= this value are not processed (default 10).
    merge_dt, merge_dr : float or None
        Frame / spatial tolerance for deduplication. None (default) reuses
        ``dt`` / ``dr``, as before.

    Returns
    -------
    (score, tp, fp, fn, pred_events, matches)
    """
    merge_dt = dt if merge_dt is None else merge_dt
    merge_dr = dr if merge_dr is None else merge_dr

    pred_events = run_dyn_pipeline_on_stack(conf_data, dyn_params, warmup_frames=warmup_frames)
    pred_events = dedupe_events_greedy(pred_events, merge_dt=merge_dt, merge_dr=merge_dr, keep='earliest')
    tp, fp, fn, matches = match_events(pred_events, manual_events, dt=dt, dr=dr)
    score = fbeta(tp, fp, fn, beta=beta)
    return score, tp, fp, fn, pred_events, matches

def load_conf_stack(folder, timelapse, analysis_stack_len):
    """
    Load the first frames of one confocal timelapse TIFF from ``folder``.

    Candidate files contain "conftimelapse" and end with ".tif". They are
    sorted by name and ``timelapse`` indexes into that sorted list. A constant
    2**15 is subtracted from the pixel values.
    NOTE(review): presumably the acquisition stores data with a +32768
    offset -- confirm.

    Integer data is first promoted to a signed type of at least 32 bits.
    Without this, ``uint16 - 2**15`` wraps around for values below 32768
    (e.g. 100 -> 32868), and NumPy 2 raises for int16. Float data is left
    unchanged.

    Parameters
    ----------
    folder : str
    timelapse : int
        Index into the name-sorted timelapse files.
    analysis_stack_len : int
        Number of leading frames (first axis) to keep.

    Returns
    -------
    (conf_data, file_conf)
        Offset-corrected stack and the chosen file name.

    Raises
    ------
    IndexError
        If ``timelapse`` does not index an existing timelapse file.
    """
    files_conf = sorted(
        f for f in os.listdir(folder) if "conftimelapse" in f and f.endswith(".tif")
    )
    try:
        file_conf = files_conf[timelapse]
    except IndexError:
        raise IndexError(
            f"timelapse index {timelapse} out of range: found {len(files_conf)} "
            f"'conftimelapse' .tif file(s) in {folder!r}"
        ) from None

    # Slice first so only the analysed frames are converted.
    raw = tiff.imread(os.path.join(folder, file_conf))[:analysis_stack_len]
    if np.issubdtype(raw.dtype, np.integer):
        raw = raw.astype(np.promote_types(raw.dtype, np.int32), copy=False)
    conf_data = raw - 2**15
    return conf_data, file_conf

def optimize_with_optuna(conf_data, manual_events, n_trials=100, dt=30, dr=6.0, beta=0.5, seed=0):
    """
    Tune dyn_signalrise parameters with Optuna's seeded TPE (Bayesian) sampler.

    Each trial samples one parameter set from a fixed search space and scores
    it with evaluate_params (F-beta of deduplicated predictions vs. the manual
    events). The trial's tp/fp/fn counts are recorded as user attributes.
    Trials whose lower thresholds are not below their upper counterparts are
    pruned.

    Returns
    -------
    The finished Optuna study (direction "maximize"). Use ``best_params`` and
    ``best_value`` for the optimum.
    """
    # (name, suggestion type, low, high). Parameters are suggested in this
    # order on every trial; low == high pins a parameter to a constant.
    # Adjust the ranges to your microscope / SNR.
    search_space = (
        ("min_dist", "float", 1.0, 2.0),
        ("num_peaks", "int", 1000, 1000),
        ("thresh_abs_lo", "float", 0.4, 1.3),
        ("thresh_abs_hi", "float", 10.0, 35.0),
        ("border_limit", "int", 4, 4),
        ("memory_frames", "int", 3, 10),
        ("track_search_dist", "int", 4, 8),
        ("frames_appear", "int", 4, 10),
        ("thresh_intincratio", "float", 1.05, 1.3),
        ("thresh_intincratio_max", "float", 3.8, 4.5),
        ("thresh_move_dist", "float", 0.9, 1.5),
    )

    def objective(trial):
        params = {}
        for name, kind, low, high in search_space:
            suggest = trial.suggest_int if kind == "int" else trial.suggest_float
            params[name] = suggest(name, low, high)

        # Hard constraints: each "low" threshold must stay below its "high" one.
        abs_inverted = params["thresh_abs_hi"] <= params["thresh_abs_lo"]
        ratio_inverted = params["thresh_intincratio_max"] <= params["thresh_intincratio"]
        if abs_inverted or ratio_inverted:
            raise optuna.exceptions.TrialPruned()

        trial_score, n_tp, n_fp, n_fn, _events, _pairs = evaluate_params(
            conf_data=conf_data,
            manual_events=manual_events,
            dyn_params=params,
            dt=dt, dr=dr, beta=beta,
        )

        # Diagnostics visible in the study's trial table.
        for attr_name, count in (("tp", n_tp), ("fp", n_fp), ("fn", n_fn)):
            trial.set_user_attr(attr_name, int(count))

        return trial_score

    study = optuna.create_study(
        direction="maximize",
        sampler=optuna.samplers.TPESampler(seed=seed),
    )
    study.optimize(objective, n_trials=n_trials)
    return study

In [None]:
# Data location: <parentdir>/exampledata/dyn1/conftimelapse. Each path
# component is passed to os.path.join separately so the path is built with
# the OS-specific separator ('\\' only separates directories on Windows).
folders = [os.path.join(parentdir, 'exampledata', 'dyn1', 'conftimelapse')]
folderidx = 0  # which entry of `folders` to analyse
folder = folders[folderidx]
timelapse = 0  # index of the timelapse (and matching manual-events file) in `folder`

In [None]:
# Number of leading frames analysed. NOTE(review): the trailing comment seems
# to record the lengths used per dataset (date code) and timelapse index -- confirm.
analysis_stack_len = 200  # 251210: 0: 200, 1: 300, 2: 300; 251209: 0: 100, 1: 100, 2: 200, 3: 200, 4: 200
# F-beta weighting: beta > 1 favours recall, beta < 1 favours precision.
# 0.25 makes precision four times as important as recall: there are many
# events, so avoiding false positives matters more than catching every event.
beta = 0.25
dr = 6   # max (x, y) distance for deduplicating and matching events
dt = 40  # max frame difference for deduplicating and matching events
n_trials = 100

conf_data, file_conf = load_conf_stack(folder, timelapse, analysis_stack_len)
print("Loaded:", file_conf, "frames:", conf_data.shape[0], "shape:", conf_data[0].shape)

# os.listdir order is arbitrary. Sort (as load_conf_stack does for the stacks)
# so index `timelapse` selects the annotation file of the same timelapse.
# NOTE(review): assumes both file families sort in the same order -- confirm.
allfiles = os.listdir(folder)
manualeventsfiles = sorted(file for file in allfiles if 'manualevents' in file)
manual_events = load_manual_events_from_zip(folder, file=manualeventsfiles[timelapse])
print("Manual events:", len(manual_events), "from", manualeventsfiles[timelapse])

# Quick single evaluation with hand-picked parameters (sanity check).
# NOTE(review): thresh_intincratio_max=5.0 is outside the Optuna search range
# (3.8-4.5), so the optimiser cannot sample this baseline configuration.
dyn_params = dict(
  min_dist= 1.5,
  num_peaks= 1000,
  thresh_abs_lo= 0.8,
  thresh_abs_hi= 35.0,
  border_limit= 4,
  memory_frames= 7,
  track_search_dist= 6,
  frames_appear= 6,
  thresh_intincratio= 1.1,
  thresh_intincratio_max= 5.0,
  thresh_move_dist= 1.3
)

score, tp, fp, fn, pred_events, matches = evaluate_params(
    conf_data, manual_events, dyn_params, dt=dt, dr=dr, beta=beta
)
# Label with the beta actually in use (not necessarily 2).
print(f"Baseline F{beta:g}={score:.4f}  TP={tp}  FP={fp}  FN={fn}")

# optimize
study = optimize_with_optuna(
    conf_data, manual_events,
    n_trials=n_trials, dt=dt, dr=dr, beta=beta, seed=0
)

print("\nBest score:", study.best_value)
print("Best params:")
for k, v in study.best_params.items():
    print(f"  {k}: {v}")

# evaluate best params once more (get events/matches)
best_params = study.best_params
score, tp, fp, fn, pred_events, matches = evaluate_params(
    conf_data, manual_events, best_params, dt=dt, dr=dr, beta=beta
)
print(f"\nBest F{beta:g}={score:.4f}  TP={tp}  FP={fp}  FN={fn}")