In [7]:
import os, re
import numpy as np
import pandas as pd
import tifffile as tiff
import matplotlib.pyplot as plt
from scipy import ndimage as ndi
from scipy.signal import find_peaks, savgol_filter
from skimage import measure, morphology, filters

# =====================================================
# CONFIGURATION
# =====================================================
input_dir = "/Users/lokesh.pimpale/Desktop/work_dir/RK/untitled folder"  # scanned for .tif files; per-set output folders and the summary CSV are written here too
process_all = True  # False -> skip grouping and analyze only single_tif_name
single_tif_name = "uniquename.tif"  # file name inside input_dir used when process_all is False

frame_pad_left = 25   # frames kept before each event's detected "left" landmark (clipped at frame 0)
frame_pad_right = 10  # frames kept after each event's peak (clipped at the stack length)
std_thresh_factor = -0.5  # dynamic-pixel cutoff = mean + factor * std of per-pixel temporal STD inside the global mask; negative keeps more pixels
min_size_px = 500   # components smaller than this (pixels) are dropped when building the global mask
pixel_size_um = 3.14  # micrometres per pixel; converts bin-centre distances to µm
frame_rate = 50.0     # frames per second; converts frame differences to seconds
bin_size = 25         # side length (pixels) of the square bins used for the propagation start/end estimate
threshold_frac = 0.75  # intensity fraction for edge detection: share of each pixel's (max - min) range that counts as "active" in image-mode edges — NOTE(review): confirm the call site passes it
edge_mode = "image"    # "image" (recommended) or "map": "image" uses compute_leading_edges_from_image, anything else the frame-map variant

# Per-event summary figure output. NOTE(review): confirm analyze_tif_stack reads these flags.
save_png = True
show_plots = False


# =====================================================
# FILE GROUPING
# =====================================================
def group_tifs_by_basename(input_dir):
    """Group the ``.tif`` files in *input_dir* into multi-file image sets.

    Files whose names differ only by a trailing ``file###`` suffix (optionally
    preceded by ``-`` or ``_``; matched case-insensitively) land in the same
    group, e.g. ``scan.tif``, ``scan_file001.tif`` and ``scan_file002.tif``.

    Parameters
    ----------
    input_dir : str
        Directory to scan (non-recursive). Only names ending in ``.tif`` (any
        case) are considered.

    Returns
    -------
    dict[str, list[str]]
        Maps ``os.path.join(input_dir, base_name)`` to the full paths of that
        group's files, ordered by their ``file###`` index. A file without the
        suffix gets index 0 and therefore comes first; ties keep name order.
    """
    pattern = re.compile(r"^(.*?)(?:[-_]?file(\d{3}))?\.tif$", re.IGNORECASE)
    tif_names = sorted(f for f in os.listdir(input_dir) if f.lower().endswith(".tif"))

    indexed = {}
    for name in tif_names:
        m = pattern.match(name)
        if not m:
            continue
        base = os.path.join(input_dir, m.group(1))
        # Take the part index from the same match that defined the group. This
        # ignores "file###" elsewhere in the path (e.g. a directory name) and
        # uses the same case-insensitivity as the grouping itself.
        part = int(m.group(2)) if m.group(2) else 0
        indexed.setdefault(base, []).append((part, os.path.join(input_dir, name)))

    # sorted() is stable, so files sharing an index keep their name order.
    groups = {
        base: [path for _, path in sorted(items, key=lambda item: item[0])]
        for base, items in indexed.items()
    }
    print(f"📂 Found {len(groups)} grouped image sets.")
    return groups


def load_combined_stack(file_list):
    """Read TIFF files and concatenate their frames along axis 0.

    Parameters
    ----------
    file_list : list[str]
        Paths in the order their frames should appear. A 2-D file is treated
        as a single frame.

    Returns
    -------
    numpy.ndarray
        All frames stacked along axis 0, cast to the dtype of the first file.

    Raises
    ------
    ValueError
        If *file_list* is empty.
    """
    if not file_list:
        raise ValueError("load_combined_stack: file_list is empty")

    parts = []
    for path in file_list:
        arr = tiff.imread(path)
        if arr.ndim == 2:
            arr = arr[np.newaxis, ...]
        parts.append(arr)

    dtype = parts[0].dtype
    # copy=False: when all parts already share the first file's dtype, the
    # concatenated array is returned as-is instead of being duplicated, which
    # halves peak memory for multi-GB stacks.
    stack = np.concatenate(parts, axis=0).astype(dtype, copy=False)
    print(f"🧩 Combined stack shape: {stack.shape}")
    return stack


# =====================================================
# HELPER FUNCTIONS
# =====================================================
def normalize01(a):
    """Rescale *a* linearly to roughly [0, 1] as float32.

    The range is taken with NaN-aware min/max, so NaNs are ignored for scaling
    (and stay NaN in the output). The 1e-6 added to the span keeps a constant
    input from dividing by zero; such an input maps to all zeros.
    """
    arr = a.astype(np.float32)
    lo = np.nanmin(arr)
    hi = np.nanmax(arr)
    span = hi - lo + 1e-6
    return (arr - lo) / span


def detect_events(data, smooth_win=15, prominence_factor=1.5, min_distance=5, baseline_drop=0.75):
    """Find events as peaks of the frame-wise mean intensity.

    Parameters
    ----------
    data : numpy.ndarray
        Stack whose first axis is time; all remaining axes are averaged per frame.
    smooth_win : int
        Requested Savitzky-Golay window (polyorder 3). It is raised to at least
        5 and capped at the largest odd number <= the number of frames. Traces
        shorter than 5 frames are used unsmoothed.
    prominence_factor : float
        Minimum peak prominence, in units of the smoothed trace's std.
    min_distance : int
        Minimum spacing between detected peaks, in frames.
    baseline_drop : float
        A later peak is kept only if the trace between it and the last kept
        peak dips below ``baseline_drop * min(height_prev, height_cur)``.

    Returns
    -------
    landmarks : list[dict]
        One ``{"left": int, "peak": int}`` per kept peak (frame indices).
    mean_trace : numpy.ndarray
        Raw per-frame mean.
    smoothed : numpy.ndarray
        Trace the peaks were detected on.
    peaks : numpy.ndarray
        Kept peak indices (int).
    """
    mean_trace = np.mean(data.reshape(data.shape[0], -1), axis=1)
    n = len(mean_trace)

    # Largest odd window that fits in the trace. The former ``n//2*2+1`` is
    # n+1 for even n, which savgol_filter rejects (window longer than data).
    max_odd = (n - 1) // 2 * 2 + 1
    win = max(5, min(smooth_win, max_odd))
    if win <= n:
        smoothed = savgol_filter(mean_trace, window_length=win, polyorder=3)
    else:
        # Fewer than 5 frames: a polyorder-3 filter with a window >= 5 cannot
        # be applied, so detect on the raw trace.
        smoothed = mean_trace.astype(float)

    std = np.std(smoothed)
    peaks, _ = find_peaks(smoothed, prominence=prominence_factor * std, distance=min_distance)

    # Drop peaks not separated from the last kept one by a deep enough valley.
    filtered = []
    for cur in peaks:
        if not filtered:
            filtered.append(cur)
            continue
        prev = filtered[-1]
        valley = np.min(smoothed[prev:cur])
        if valley < baseline_drop * min(smoothed[prev], smoothed[cur]):
            filtered.append(cur)

    landmarks = []
    if filtered:
        d1 = np.gradient(smoothed)
        eps = 0.1 * np.std(d1)
        for p in filtered:
            # NOTE(review): walking back from the peak, this stops at the first
            # i where d1[i:i+3] all exceed eps. On a smooth rise that is just a
            # few frames before the peak rather than the rise onset (the
            # caller's frame_pad_left widens the window). Index 0 is never
            # tested. Confirm this is the intended "left" definition.
            left = None
            for i in range(p - 1, 0, -1):
                if np.all(d1[i:i + 3] > eps):
                    left = i
                    break
            if left is None:
                left = max(0, p - 10)
            landmarks.append({"left": int(left), "peak": int(p)})
    return landmarks, mean_trace, smoothed, np.array(filtered, int)


def get_global_mask(data, min_size_px=500, relax=0.8):
    """Build one foreground mask from the stack's maximum-intensity projection.

    The MIP over axis 0 is rescaled with normalize01 and thresholded at
    ``relax`` times its Otsu level. The result is hole-filled and stripped of
    components smaller than ``min_size_px``; only the largest remaining
    connected component is kept (and hole-filled again).

    Returns
    -------
    (mip_norm, mask)
        The normalized MIP and a boolean mask of the same shape (all False
        when no component survives).
    """
    projection = normalize01(np.max(data, axis=0))
    cutoff = filters.threshold_otsu(projection) * relax

    foreground = ndi.binary_fill_holes(projection > cutoff)
    foreground = morphology.remove_small_objects(foreground, min_size=min_size_px)

    labelled = measure.label(foreground)
    regions = measure.regionprops(labelled)
    if len(regions) == 0:
        return projection, np.zeros_like(projection, dtype=bool)

    biggest = max(regions, key=lambda region: region.area)
    return projection, ndi.binary_fill_holes(labelled == biggest.label)


def analyze_event_std(event_data, global_mask, std_thresh_factor=0.0, level_frac=0.75):
    """Select dynamic pixels and record when each first reaches a set intensity.

    Parameters
    ----------
    event_data : numpy.ndarray
        Event frames with time on axis 0.
    global_mask : numpy.ndarray of bool
        Region to analyze; same shape as one frame.
    std_thresh_factor : float
        Pixels inside ``global_mask`` whose temporal STD is at least
        ``mean + std_thresh_factor * std`` are kept. Mean and std are taken
        over the masked STD values.
    level_frac : float
        Fraction of each pixel's own ``(max - min)`` range it must reach to be
        counted as crossed (default 0.75, the previously hard-coded value).

    Returns
    -------
    pixel_std, dyn_mask, frame_map, thr
        Per-pixel STD; boolean dynamic mask; float32 map holding the first
        frame index (relative to ``event_data``) at which each dynamic pixel
        reaches its level, NaN elsewhere; and the STD threshold used (NaN when
        ``global_mask`` is empty).
    """
    pixel_std = np.std(event_data, axis=0)
    pixel_min = np.min(event_data, axis=0)
    pixel_max = np.max(event_data, axis=0)

    masked_std = pixel_std[global_mask]
    if masked_std.size:
        thr = np.mean(masked_std) + std_thresh_factor * np.std(masked_std)
    else:
        # Same NaN result as before, without empty-slice warnings.
        thr = np.nan
    dyn_mask = global_mask & (pixel_std >= thr)

    frame_map = np.full_like(pixel_max, np.nan, np.float32)
    level = pixel_min + level_frac * (pixel_max - pixel_min)
    for f, frame in enumerate(event_data):
        # Only fill pixels not yet assigned, so each keeps its first crossing.
        hit = (frame >= level) & dyn_mask & np.isnan(frame_map)
        frame_map[hit] = f
    return pixel_std, dyn_mask, frame_map, thr


def compute_leading_edges_from_image(event_data, dyn_mask, threshold_frac=0.75):
    """Return, for every frame, the outline of the pixels that are "on".

    A pixel is on in a frame when it lies in ``dyn_mask`` and its value is at
    least ``min + threshold_frac * (max - min)`` of its own values across the
    event. The outline is the on-set XOR its binary erosion (default scipy
    structuring element), so on-pixels touching the image border always count
    as outline.

    Returns
    -------
    list[numpy.ndarray]
        One boolean array per frame, shaped like a single frame.
    """
    n_frames, _, _ = event_data.shape
    lo = np.min(event_data, axis=0)
    hi = np.max(event_data, axis=0)
    cutoff = lo + threshold_frac * (hi - lo)

    def outline(frame):
        on = (frame >= cutoff) & dyn_mask
        return on ^ ndi.binary_erosion(on)

    return [outline(event_data[f]) for f in range(n_frames)]


def compute_leading_edges_from_map(frame_map, dyn_mask):
    """Build per-frame outlines of the cumulative active region from a frame map.

    For each frame ``f`` from 0 up to the largest finite value in
    ``frame_map``, the active region is every pixel in ``dyn_mask`` whose map
    value is <= f. Its outline is the region XOR its binary erosion.

    Returns
    -------
    list[numpy.ndarray]
        One boolean outline per frame. Empty when ``frame_map`` has no finite
        entries (no pixel ever crossed); previously ``int(np.nanmax(...))``
        raised ValueError on an all-NaN map.
    """
    finite = frame_map[np.isfinite(frame_map)]
    if finite.size == 0:
        return []
    last = int(finite.max())

    edges = []
    for f in range(last + 1):
        # NaN entries compare False, so uncrossed pixels are never active.
        active = (frame_map <= f) & dyn_mask
        edges.append(active ^ ndi.binary_erosion(active))
    return edges


def leading_edge_centroids(edges):
    """Return the mean (x, y) position of the True pixels in each edge mask.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(edges), 2)``; column 0 is x (column index), column 1 is y
        (row index). Rows are NaN for masks without a True pixel. An empty
        *edges* yields a ``(0, 2)`` array, so ``cents[:, 0]`` still works.
        Previously it produced a 1-D empty array, which made that indexing fail.
    """
    cents = np.full((len(edges), 2), np.nan)
    for k, e in enumerate(edges):
        rows, cols = np.nonzero(e)
        if rows.size:
            cents[k] = (cols.mean(), rows.mean())
    return cents


def compute_propagation_distance(frame_map, dyn_mask, pixel_size_um, frame_rate, bin_size):
    """Estimate travel distance, duration and speed from a frame map.

    The map is tiled into full ``bin_size`` x ``bin_size`` squares. Trailing
    rows/columns that do not fill a whole square are ignored. Each square with
    at least 3 finite values inside ``dyn_mask`` gets their mean as its
    activation time. The earliest and latest squares define start and end; on
    ties the first square scanned row by row wins.

    Returns
    -------
    (dist_um, dt_s, vel_um_s, start_xy, end_xy)
        - Centre-to-centre distance times ``pixel_size_um``.
        - Time difference divided by ``frame_rate``.
        - Their ratio (NaN unless dt_s > 0).
        - The (x, y) pixel centres of the start and end squares.
        All NaN / None when no square qualifies.
    """
    rows, cols = frame_map.shape
    half = bin_size / 2
    bin_xy = []
    bin_t = []
    for top in range(0, (rows // bin_size) * bin_size, bin_size):
        for left in range(0, (cols // bin_size) * bin_size, bin_size):
            window = (slice(top, top + bin_size), slice(left, left + bin_size))
            samples = frame_map[window][dyn_mask[window]]
            samples = samples[np.isfinite(samples)]
            if len(samples) >= 3:
                bin_xy.append((left + half, top + half))
                bin_t.append(np.nanmean(samples))

    if len(bin_xy) == 0:
        return np.nan, np.nan, np.nan, None, None

    bin_xy = np.array(bin_xy)
    bin_t = np.array(bin_t)
    first = np.argmin(bin_t)
    last = np.argmax(bin_t)
    x_a, y_a = bin_xy[first]
    x_b, y_b = bin_xy[last]

    span_px = np.sqrt((x_b - x_a) ** 2 + (y_b - y_a) ** 2)
    dist_um = span_px * pixel_size_um
    dt_s = (bin_t[last] - bin_t[first]) / frame_rate
    if dt_s > 0:
        vel_um_s = dist_um / dt_s
    else:
        vel_um_s = np.nan
    return dist_um, dt_s, vel_um_s, (x_a, y_a), (x_b, y_b)


# =====================================================
# MAIN ANALYSIS
# =====================================================
def _plot_event_summary(title, start, mip_norm, global_mask, pixel_std, dyn_mask,
                        frame_map, mean_img, edges, edge_source, cents,
                        dist_um, vel_um_s, p_start, p_end):
    """Draw the 7-panel QC figure for one event and return it (caller saves/closes).

    Panels:
      1. MIP.
      2. Per-pixel STD with the global mask overlaid.
      3. Dynamic mask.
      4. Frame map, offset by ``start`` so values are stack frame indices.
      5. Up to six sampled leading-edge contours.
      6. Centroid path.
      7. Start/end propagation bins.
    """
    fig, axes = plt.subplots(1, 7, figsize=(36, 6))

    axes[0].imshow(mip_norm, cmap="gray")
    axes[0].set_title("MIP")

    axes[1].imshow(pixel_std, cmap="magma")
    axes[1].imshow(global_mask, cmap="Greens", alpha=0.3)
    axes[1].set_title("STD + Global Mask")

    axes[2].imshow(dyn_mask, cmap="Blues")
    axes[2].set_title("Dynamic Mask")

    im3 = axes[3].imshow(frame_map + start, cmap="turbo")
    fig.colorbar(im3, ax=axes[3])
    axes[3].set_title("Frame Map")

    # Leading-edge contours for up to 6 evenly spaced frames.
    axes[4].imshow(mean_img, cmap="gray", alpha=0.6)
    if edge_source == "image":
        axes[4].set_title("Leading Edges (true from image)")
    else:
        axes[4].set_title("Leading Edges (from frame map)")
    if edges:
        idxs = np.linspace(0, len(edges) - 1, 6, dtype=int)
        colors = plt.cm.plasma(np.linspace(0, 1, len(idxs)))
        for color, idx in zip(colors, idxs):
            for cnt in measure.find_contours(edges[idx].astype(float), 0.5):
                axes[4].plot(cnt[:, 1], cnt[:, 0], color=color, lw=1.8)

    # Centroid trajectory: lime marks the first frame, red the last.
    axes[5].imshow(mean_img, cmap="gray", alpha=0.6)
    axes[5].set_title("Centroid Path")
    if len(cents):
        axes[5].plot(cents[:, 0], cents[:, 1], "w--", lw=2)
        axes[5].scatter(cents[:, 0], cents[:, 1], s=40, c=np.linspace(0, 1, len(cents)), cmap="cool")
        axes[5].scatter(cents[0, 0], cents[0, 1], s=100, c="lime", edgecolors="k")
        axes[5].scatter(cents[-1, 0], cents[-1, 1], s=100, c="red", edgecolors="k")

    # Start/end bins used for the distance and velocity estimate.
    axes[6].imshow(mean_img, cmap="gray", alpha=0.6)
    if p_start is not None and p_end is not None:
        axes[6].plot([p_start[0], p_end[0]], [p_start[1], p_end[1]], "w--", lw=2)
        axes[6].scatter(*p_start, s=150, c="lime", edgecolors="k")
        axes[6].scatter(*p_end, s=150, c="red", edgecolors="k")
        axes[6].set_title(f"Start-End Bins\nDist={dist_um:.1f} µm\nVel={vel_um_s:.1f} µm/s")

    for ax in axes:
        ax.axis("off")

    fig.suptitle(title, fontsize=15)
    fig.tight_layout()
    return fig


def analyze_tif_stack(stack, base_name, output_root):
    """Detect events in *stack*, analyze each one and write per-event outputs.

    Reads module-level settings:
      - windowing: frame_pad_left, frame_pad_right;
      - masking and edges: min_size_px, std_thresh_factor, edge_mode, threshold_frac;
      - propagation: pixel_size_um, frame_rate, bin_size;
      - figure output: save_png, show_plots.

    For event ``i`` it always writes
    ``<output_root>/<base_name>/<base_name>_event{i}_centroids.csv``.
    When ``save_png`` is true it also writes ``..._event{i}_summary.png``.

    Parameters
    ----------
    stack : numpy.ndarray
        3-D stack with time on axis 0.
    base_name : str
        Name used for the output folder and file prefixes.
    output_root : str
        Directory in which the per-stack output folder is created.

    Returns
    -------
    list[dict]
        One summary row per event; empty when no event is detected.
    """
    print(f"\n🔹 Analyzing {base_name}")
    n_frames, _, _ = stack.shape
    out_dir = os.path.join(output_root, base_name)
    os.makedirs(out_dir, exist_ok=True)

    landmarks, _mean_trace, _smoothed, _peaks = detect_events(stack)
    if not landmarks:
        print("⚠️ No events detected.")
        return []

    mip_norm, global_mask = get_global_mask(stack, min_size_px=min_size_px)

    results = []
    for i, ev in enumerate(landmarks, 1):
        start = max(ev["left"] - frame_pad_left, 0)
        end = min(ev["peak"] + frame_pad_right, n_frames)
        event = stack[start:end]

        pixel_std, dyn_mask, frame_map, _thr = analyze_event_std(event, global_mask, std_thresh_factor)
        if edge_mode == "image":
            # Pass the configured fraction; previously the function default was always used.
            edges = compute_leading_edges_from_image(event, dyn_mask, threshold_frac)
        else:
            edges = compute_leading_edges_from_map(frame_map, dyn_mask)
        # Keep a (0, 2) shape when there are no edges so the column slicing below works.
        cents = leading_edge_centroids(edges) if edges else np.empty((0, 2))

        dist_um, dt_s, vel_um_s, p_start, p_end = compute_propagation_distance(
            frame_map, dyn_mask, pixel_size_um, frame_rate, bin_size)
        has_path = p_start is not None and p_end is not None

        # Per-frame centroid CSV; frame_index is relative to the event start.
        df_event = pd.DataFrame({
            "frame_index": np.arange(len(cents)),
            "centroid_x": cents[:, 0],
            "centroid_y": cents[:, 1],
        })
        if has_path:
            df_event["start_x"], df_event["start_y"] = p_start
            df_event["end_x"], df_event["end_y"] = p_end
            df_event["distance_um"] = dist_um
        csv_path = os.path.join(out_dir, f"{base_name}_event{i}_centroids.csv")
        df_event.to_csv(csv_path, index=False)

        if save_png or show_plots:
            fig = _plot_event_summary(
                f"{base_name} — Event {i}", start, mip_norm, global_mask, pixel_std,
                dyn_mask, frame_map, np.mean(event, axis=0), edges, edge_mode, cents,
                dist_um, vel_um_s, p_start, p_end)
            if save_png:
                out_path = os.path.join(out_dir, f"{base_name}_event{i}_summary.png")
                fig.savefig(out_path, dpi=220)
                print(f"💾 Saved {out_path}")
            if show_plots:
                plt.show()
            plt.close(fig)

        results.append({
            "file": base_name,
            "event_id": i,
            "propagation_distance_um": dist_um,
            "propagation_time_s": dt_s,
            "propagation_velocity_um_s": vel_um_s,
            "centroid_csv": os.path.basename(csv_path)
        })
    return results


# =====================================================
# MAIN RUN
# =====================================================
if __name__ == "__main__":
    import traceback

    all_results = []
    if process_all:
        file_groups = group_tifs_by_basename(input_dir)
    else:
        # Analyze a single file; its name without extension becomes the output name.
        file_groups = {os.path.splitext(single_tif_name)[0]: [os.path.join(input_dir, single_tif_name)]}

    for base, files in file_groups.items():
        name = os.path.basename(base)
        stack = None
        try:
            stack = load_combined_stack(files)
            all_results.extend(analyze_tif_stack(stack, name, input_dir))
        except Exception as e:
            # Top-level boundary: report with traceback and continue with the next set.
            print(f"❌ Error on {name}: {e}")
            traceback.print_exc()
        finally:
            # Release the (possibly multi-GB) stack before the next group is
            # loaded, so two stacks are never held in memory at the same time.
            del stack

    if all_results:
        df = pd.DataFrame(all_results)
        csv_path = os.path.join(input_dir, "CalciumEventSummary_LeadingEdges.csv")
        df.to_csv(csv_path, index=False)
        print(f"\n✅ Exported summary CSV: {csv_path}")
    else:
        print("\n⚠️ No events processed.")

📂 Found 2 grouped image sets.
🧩 Combined stack shape: (3000, 1024, 1024)

🔹 Analyzing RK25DE14B-D21_Stream_C07_s1_t1_FITC
💾 Saved /Users/lokesh.pimpale/Desktop/work_dir/RK/untitled folder/RK25DE14B-D21_Stream_C07_s1_t1_FITC/RK25DE14B-D21_Stream_C07_s1_t1_FITC_event1_summary.png
💾 Saved /Users/lokesh.pimpale/Desktop/work_dir/RK/untitled folder/RK25DE14B-D21_Stream_C07_s1_t1_FITC/RK25DE14B-D21_Stream_C07_s1_t1_FITC_event2_summary.png
💾 Saved /Users/lokesh.pimpale/Desktop/work_dir/RK/untitled folder/RK25DE14B-D21_Stream_C07_s1_t1_FITC/RK25DE14B-D21_Stream_C07_s1_t1_FITC_event3_summary.png
🧩 Combined stack shape: (3000, 1024, 1024)

🔹 Analyzing RK25DE14B-D21_Stream_C08_s1_t1_FITC
💾 Saved /Users/lokesh.pimpale/Desktop/work_dir/RK/untitled folder/RK25DE14B-D21_Stream_C08_s1_t1_FITC/RK25DE14B-D21_Stream_C08_s1_t1_FITC_event1_summary.png
💾 Saved /Users/lokesh.pimpale/Desktop/work_dir/RK/untitled folder/RK25DE14B-D21_Stream_C08_s1_t1_FITC/RK25DE14B-D21_Stream_C08_s1_t1_FITC_event2_summary.png
💾