<a href="https://colab.research.google.com/github/elenancalima/Troph_Min_5to2/blob/main/Simple_troph_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [21]:
# === Cell 1: minimal, Py3.12-safe, no SciPy/skimage ===
# Replaces the preinstalled image stack with three pinned, prebuilt wheels
# (NumPy, headless OpenCV, imageio) — the imaging packages the detector cells
# below import.
# NOTE(review): uninstalling scipy / scikit-image leaves other preinstalled
# Colab packages with unmet requirements (see the pip resolver errors in this
# cell's output). None of the cells shown here import scipy or skimage.
# `|| true` ignores uninstall failures (e.g. a package that is not installed).
!pip -q uninstall -y opencv-python opencv-python-headless numpy scipy scikit-image imageio >/dev/null 2>&1 || true

# --only-binary=:all: forbids source builds; exact pins keep runs reproducible.
!pip -q install --upgrade --only-binary=:all: \
  numpy==2.0.2 \
  opencv-python-headless==4.10.0.84 \
  imageio==2.34.1

# NOTE(review): if numpy/cv2 were already imported earlier in this kernel, the
# versions printed below may be the old ones — restart the runtime if they
# differ from the pins above.
import numpy as np, cv2, imageio
from imageio import v3 as iio

print("OK:",
      "numpy", np.__version__,
      "opencv", cv2.__version__,
      "imageio", imageio.__version__)

# tiny smoke test (ensures read/write + resize work)
# (`os` is imported here but not used in this cell.)
import os
arr = np.zeros((10,10), np.uint8)
cv2.resize(arr, (5,5))
iio.imwrite("/content/_smoke.png", arr)
print("Smoke test: wrote /content/_smoke.png")


[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tsfresh 0.21.1 requires scipy>=1.14.0; python_version >= "3.10", which is not installed.
stumpy 1.13.0 requires scipy>=1.10, which is not installed.
pymc 5.25.1 requires scipy>=1.4.1, which is not installed.
yellowbrick 1.5 requires scipy>=1.0.0, which is not installed.
xarray-einstats 0.9.1 requires scipy>=1.11, which is not installed.
mlxtend 0.23.4 requires scipy>=1.2.1, which is not installed.
arviz 0.22.0 requires scipy>=1.11.0, which is not installed.
shap 0.49.1 requires scipy, which is not installed.
libpysal 4.13.0 requires scipy>=1.8, which is not installed.
hyperopt 0.2.7 requires scipy, which is not installed.
dopamine-rl 4.1.2 requires opencv-python>=3.4.8.29, which is not installed.
jaxlib 0.5.3 requires scipy>=1.11.1, which is not installed.
matplotlib-venn 1.1.2 requires scipy, which is not in

In [22]:
# === Cell 2 (final): Detector with grayscale front_body_mask + cone brightness + direct 170x170 draw ===
import os, glob, math, json, re
import numpy as np
import cv2
from imageio import v3 as iio

PARAMS = {
    # --- geometry / tracking ---
    "scale_factor": 5,            # 850 -> 170 output canvas
    "min_duration": 10,           # frames for stationarity check
    "iou_stationary_tol": 0.10,   # stationary if mean IoU >= 1 - tol

    # --- brightness-change gating inside abdomen masks (composite red channel) ---
    "brightnessThreshold": 160,   # R channel threshold (0-255); pixel counts if R >= threshold
    "brightPropDelta": 0.15,      # min |Δ| between consecutive smoothed bright-red fractions
    "smoothingWindow": 5,         # frames per moving average

    # --- head direction via cone brightness on the GRAYSCALE front_body_mask ---
    "coneAngle_deg": 30,          # half-angle: rays span axis ± this many degrees
    "coneLength": 60,             # ray length in px
    "rays_per_cone": 9,

    # --- head placement / contact detection ---
    "abdomenToHeadLength": 25,    # px from abdomen centroid to head point
    "headRadius": 10,             # px radius of each head circle
    "min_overlap_pairs": 2,       # min number of head circles that must overlap
    "iou_match_threshold": 0.30,  # min IoU for a region to continue a track
}

# ---------- helpers ----------
def ensure_dirs(*ps):
    """Create every directory in ``ps`` (including parents); existing ones are left alone."""
    for path in ps:
        os.makedirs(path, exist_ok=True)

def read_rgb(p):
    """Read an image file as an (H, W, 3) RGB array.

    A 4-channel (RGBA) image is accepted and its alpha channel dropped; callers
    only use the colour channels (red for gating, RGB->BGR for drawing).

    Raises:
        ValueError: if the image is not 3- or 4-channel.
    """
    im = iio.imread(p)
    if im.ndim != 3 or im.shape[2] not in (3, 4):
        raise ValueError(f"expect RGB: {p}")
    if im.shape[2] == 4:
        im = np.ascontiguousarray(im[..., :3])
    return im

def to_bgr(img_rgb: np.ndarray) -> np.ndarray:
    """Return a C-contiguous BGR version of an RGB image (OpenCV's channel order)."""
    bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    return np.ascontiguousarray(bgr)

def read_gray(p):
    """Read an image file as a single-channel uint8 grayscale array.

    Accepted layouts: (H, W); (H, W, 1) and (H, W, 2) gray[+alpha], where the
    first channel is used; (H, W, 3) RGB and (H, W, 4) RGBA, where alpha is
    ignored and colour is converted with OpenCV's gray weights. Boolean images
    map to 0/255. Other non-uint8 dtypes are clipped to [0, 255] before casting.

    Raises:
        ValueError: for any other array layout.
    """
    im = iio.imread(p)
    if im.ndim == 3 and im.shape[2] in (1, 2):
        im = im[..., 0]
    if im.ndim == 2:
        g = im
    elif im.ndim == 3 and im.shape[2] in (3, 4):
        rgb = np.ascontiguousarray(im[..., :3])
        g = cv2.cvtColor(to_bgr(rgb), cv2.COLOR_BGR2GRAY)  # robust if stored RGB-gray
    else:
        raise ValueError(f"expect gray or RGB: {p}")
    if g.dtype == np.bool_:
        g = g.astype(np.uint8) * 255
    elif g.dtype != np.uint8:
        g = np.clip(g, 0, 255).astype(np.uint8)
    return g

def read_bin(p):
    """Read a mask image and return it as a uint8 array of 0/1.

    Multi-channel masks use their first channel. Boolean masks are used as-is.
    Masks whose maximum is <= 1 (0/1 integer or [0, 1] float masks) are
    thresholded at 0.5; a fixed ``> 127`` test would turn them all-zero.
    Everything else (e.g. 0/255 PNGs) is thresholded at ``> 127``.
    """
    im = iio.imread(p)
    if im.ndim == 3:
        im = im[..., 0]
    if im.dtype == np.bool_:
        return im.astype(np.uint8)
    if im.size and im.max() <= 1:
        return (im > 0.5).astype(np.uint8)
    return (im > 127).astype(np.uint8)

def iou(a, b):
    """Intersection-over-union of two 0/1 masks; 0.0 when they do not overlap."""
    overlap = np.bitwise_and(a, b).sum()
    if not overlap:
        return 0.0
    return float(overlap) / float(np.bitwise_or(a, b).sum())

def clamp(y, x, H, W):
    """Round (y, x) to integer pixel coordinates clipped into an H x W image."""
    def _clip(v, size):
        return max(0, min(size - 1, int(round(v))))
    return _clip(y, H), _clip(x, W)

def find_nearest_boundary(mask01, cy, cx):
    """Return the (y, x) outer-contour pixel of ``mask01`` closest to (cy, cx).

    All external contours are searched; on ties the first contour / first point
    wins. If the mask has no contours, (cy, cx) is returned unchanged.
    """
    contours, _ = cv2.findContours((mask01 * 255).astype(np.uint8),
                                   cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if not contours:
        return cy, cx
    xy = np.concatenate([c.reshape(-1, 2) for c in contours])  # rows are (x, y)
    d2 = (xy[:, 0] - cx) ** 2 + (xy[:, 1] - cy) ** 2
    k = int(np.argmin(d2))
    return int(xy[k, 1]), int(xy[k, 0])

def region_axis_unit(mask01):
    """Unit vector (dy, dx) along the major principal axis of a binary mask.

    Takes the eigenvector of the central-moment covariance matrix with the
    largest eigenvalue (``np.linalg.eigh`` sorts eigenvalues ascending). The
    sign of that eigenvector is not meaningful, so (dy, dx) and (-dy, -dx)
    describe the same axis. An empty mask yields (0.0, 1.0).
    """
    mom = cv2.moments((mask01 * 255).astype(np.uint8), binaryImage=True)
    if mom['m00'] == 0:
        return (0.0, 1.0)
    cov = np.array([[mom['mu20'], mom['mu11']],
                    [mom['mu11'], mom['mu02']]], dtype=np.float64)
    _, eigvecs = np.linalg.eigh(cov)
    major_x, major_y = eigvecs[:, 1]
    norm = np.hypot(major_x, major_y) + 1e-9
    return (major_y / norm, major_x / norm)

def axis_phi_deg_from_unit(dy, dx):
    """Angle of direction (dy, dx) in degrees, in image coordinates (y points down)."""
    phi_rad = math.atan2(dy, dx)
    return math.degrees(phi_rad)

def connected_components(mask01):
    """8-connected labelling of the non-zero pixels of ``mask01``.

    Returns OpenCV's (num_labels, labels, stats, centroids); label 0 is background.
    """
    binary = (mask01 > 0).astype(np.uint8)
    return cv2.connectedComponentsWithStats(binary, connectivity=8)

def centroid_from_mask(mask01):
    """Centroid (cy, cx) of a binary mask from image moments, or None if it is empty."""
    mom = cv2.moments((mask01 * 255).astype(np.uint8), binaryImage=True)
    area = mom['m00']
    if area == 0:
        return None
    return (mom['m01'] / area, mom['m10'] / area)

def draw_mask_contours(canvas_bgr, mask01, color, thickness=1):
    """Draw the outer contours of ``mask01`` onto ``canvas_bgr`` in place."""
    mask8 = (mask01 * 255).astype(np.uint8)
    outlines = cv2.findContours(mask8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
    cv2.drawContours(canvas_bgr, outlines, -1, color, thickness)

def colorize_labels(labels):
    """Render a label image as a uint8 colour image for debugging.

    Label 0 stays black. Every other label gets a pseudo-random colour with
    each channel in [50, 255), drawn from a generator seeded with 0 in
    ascending label order, so colours are reproducible across calls.
    (The previous version built this colour table but never applied it, and
    always returned an all-black image.)

    Returns:
        array of shape ``labels.shape + (3,)``, dtype uint8.
    """
    labels = np.asarray(labels)
    rng = np.random.default_rng(0)
    uniq, inv = np.unique(labels, return_inverse=True)
    palette = np.zeros((len(uniq), 3), np.uint8)
    for k, u in enumerate(uniq):
        if u != 0:
            palette[k] = rng.integers(50, 255, size=3)
    # reshape guards against NumPy versions that return a flattened inverse.
    return palette[inv.reshape(labels.shape)]

# --- map full-res point to small canvas (draw directly at 170x170) ---
def downscale_pt(y, x, H, W, h_small, w_small):
    """Map point (y, x) on an H x W image onto an h_small x w_small canvas.

    Each coordinate is scaled, truncated by ``int``, raised to 0 if negative,
    then capped at the canvas' last index.
    """
    def _to_small(v, full, small):
        s = max(int(v * small / full), 0)
        return min(s, small - 1)
    return _to_small(y, H, h_small), _to_small(x, W, w_small)

# --- filename alignment helpers ---
_num_re = re.compile(r'(\d+)')  # NOTE(review): not referenced elsewhere in the visible notebook

def index_pngs_by_basename(folder: str):
    """Map basename -> full path for every ``*.png`` / ``*.PNG`` file directly in ``folder``.

    If a basename is found by both patterns, the ``*.PNG`` match's path is kept.
    """
    found = {}
    for pattern in ("*.png", "*.PNG"):
        for path in glob.glob(os.path.join(folder, pattern)):
            found[os.path.basename(path)] = path
    return found

import os, re

def natural_key(s: str):
    """Sort key that orders embedded numbers numerically, so 'img_2' < 'img_10'.

    ``s`` is split into digit runs and non-digit runs. Each run becomes a tagged
    tuple: ``(0, number, "")`` for digits, ``(1, 0, text.lower())`` otherwise.
    Tagging keeps keys comparable when two names have different run types at
    the same position (e.g. '10a' vs 'a1'), where comparing int with str would
    raise ``TypeError``. Wherever the untagged keys were comparable, the order
    is unchanged.
    """
    key = []
    for digits, text in re.findall(r'(\d+)|(\D+)', s):
        if digits:
            key.append((0, int(digits), ""))
        else:
            key.append((1, 0, text.lower()))
    return key

def _list_images_sorted(dir_path, exts=(".png", ".jpg", ".jpeg", ".bmp")):
    """Return full paths of the images in ``dir_path``, naturally sorted by filename.

    Only entries whose extension (case-insensitive) is in ``exts`` are kept.

    Raises:
        FileNotFoundError: if ``dir_path`` is not a directory.
    """
    if not os.path.isdir(dir_path):
        raise FileNotFoundError(f"Directory not found: {dir_path}")
    names = sorted(
        (name for name in os.listdir(dir_path)
         if os.path.splitext(name)[1].lower() in exts),
        key=natural_key,
    )
    return [os.path.join(dir_path, name) for name in names]

def align_triplet(vid_dir: str, abd_dir: str, fb_dir: str):
    """Pair frames from the three input folders purely by sorted position.

    Each folder is listed on its own in natural filename order, and all three
    lists are truncated to the shortest one. Filenames are NOT matched across
    folders. A warning is printed when the counts differ. Output basenames
    (without extension) come from the video frames.

    Returns:
        (vid_list, abd_list, fb_list, base_list)

    Raises:
        FileNotFoundError: if a folder is missing, or if any folder has no images.
    """
    lists = [_list_images_sorted(d) for d in (vid_dir, abd_dir, fb_dir)]
    counts = tuple(len(lst) for lst in lists)
    n = min(counts)
    if n == 0:
        raise FileNotFoundError(
            "No images to align.\n"
            f"  input_vid: {counts[0]} in {vid_dir}\n"
            f"  abdomen_mask: {counts[1]} in {abd_dir}\n"
            f"  front_body_mask: {counts[2]} in {fb_dir}"
        )
    if len(set(counts)) != 1:
        print(f"[warn] counts differ (vid, abd, fb) = {counts}. Using first {n} frames of each by sorted order.")

    vid_list, abd_list, fb_list = (lst[:n] for lst in lists)
    base_list = [os.path.splitext(os.path.basename(p))[0] for p in vid_list]
    return vid_list, abd_list, fb_list, base_list

# ---- cone brightness scorer on grayscale front_body_mask ----
def cone_brightness_score(gray, center_yx, axis_phi_deg, half_deg, length_px, rays):
    """Mean grayscale intensity sampled along a fan of rays (a "cone").

    ``rays`` rays are spread evenly over [axis_phi_deg - half_deg,
    axis_phi_deg + half_deg] (degrees, image coordinates with y pointing down).
    A single ray is placed on the axis itself; ``linspace`` would otherwise put
    it on the -half_deg edge. Each ray samples the pixel nearest to radii
    1..int(length_px) from ``center_yx`` = (cy, cx) and stops at the first
    sample outside the image. Samples are not de-duplicated, so a pixel hit
    more than once counts more than once.

    Returns:
        (mean_intensity, rays_lines): the mean over all samples (0.0 when none
        land inside the image) and, for debug drawing, one list of (x, y)
        pixel tuples per ray that produced at least one sample.
    """
    H, W = gray.shape
    cy, cx = center_yx
    n_steps = int(length_px)
    if rays == 1:
        offsets = [0.0]
    else:
        offsets = np.linspace(-half_deg, +half_deg, rays)

    total = 0.0
    count = 0
    rays_lines = []
    for a in offsets:
        ang = math.radians(axis_phi_deg + a)
        ux, uy = math.cos(ang), math.sin(ang)  # y-down
        seg = []
        for r in range(1, n_steps + 1):
            y = int(round(cy + uy * r))
            x = int(round(cx + ux * r))
            if not (0 <= y < H and 0 <= x < W):
                break  # ray left the image
            seg.append((x, y))
            total += float(gray[y, x])
            count += 1
        if seg:
            rays_lines.append(seg)
    mean_intensity = total / (count + 1e-6)
    return mean_intensity, rays_lines

# ---------- main ----------
def process_video_root(video_root, P=PARAMS, debug=False, debug_frame=30):
    """Run the simple troph detector over one video root and write per-frame heatmaps.

    Inputs (under ``video_root``), paired purely by sorted position via
    ``align_triplet``:
      * ``input_vid/``        RGB composite frames (red channel drives gating),
      * ``abdomen_mask/``     binary abdomen masks,
      * ``front_body_mask/``  grayscale front-body masks (used for head direction).

    Per frame:
      1. Label abdomen components (area >= 10 px) and link them to the previous
         frame's tracks by greedy best-IoU matching. Each previous track can be
         claimed by at most one current region.
      2. Keep *candidates*: tracks that are stationary (mean IoU over the last
         ``max(1, min_duration - 1)`` frame pairs >= ``1 - iou_stationary_tol``)
         AND whose moving-average fraction of pixels with
         R >= ``brightnessThreshold`` (window ``smoothingWindow``) changed by at
         least ``brightPropDelta`` between consecutive windows.
      3. Place a head point ``abdomenToHeadLength`` px from each candidate's
         centroid, along its principal axis, on the side whose cone in the
         front-body mask is brighter.
      4. Wherever at least ``min_overlap_pairs`` head circles (radius
         ``headRadius``) overlap, mark the overlap centroid and draw a line
         from it to the nearest boundary pixel of each participating abdomen.

    Outputs: one PNG per frame, named after the video frame, of size
    (H // scale_factor, W // scale_factor) for H x W input masks.
      * ``simple_troph_point_heatmap/``      255 at each overlap centroid,
      * ``simple_participant_line_heatmap/`` the centroid-to-abdomen lines.

    Args:
        video_root: folder containing the three input sub-folders.
        P: parameter dict (see ``PARAMS``). The default is bound when this
            function is defined, so pass it explicitly after redefining ``PARAMS``.
        debug: if True, write diagnostic images 01-06 and ``debug_log.txt`` to
            ``video_root/_debug_troph`` for frame ``debug_frame``.
        debug_frame: 0-based index of the frame to dump when ``debug`` is True.

    Raises:
        FileNotFoundError: from ``align_triplet`` if an input folder is missing
            or empty.
        ValueError: from ``read_rgb`` if a composite frame is not RGB.
    """
    sf = P["scale_factor"]
    win = max(1, P["min_duration"] - 1)  # number of frame-pair IoUs judged for stationarity
    sw = P["smoothingWindow"]

    vid_dir = os.path.join(video_root, "input_vid")
    abd_dir = os.path.join(video_root, "abdomen_mask")
    fb_dir  = os.path.join(video_root, "front_body_mask")
    out_pt  = os.path.join(video_root, "simple_troph_point_heatmap")
    out_ln  = os.path.join(video_root, "simple_participant_line_heatmap")
    ensure_dirs(out_pt, out_ln)

    vid_list, abd_list, fb_list, base_list = align_triplet(vid_dir, abd_dir, fb_dir)
    n = len(base_list)

    dbg_dir = os.path.join(video_root, "_debug_troph")
    if debug:
        ensure_dirs(dbg_dir)

    next_id = 1
    prev_regions = {}  # track id -> abdomen mask (uint8 0/1) from the previous frame
    tracks = {}        # track id -> {'iou_hist': [...], 'red_prop_hist': [...]}

    for t in range(n):
        dbg = debug and t == debug_frame
        comp = read_rgb(vid_list[t])
        red  = comp[..., 0]                # RGB order from imageio: channel 0 is red
        abd  = read_bin(abd_list[t])
        fb_g = read_gray(fb_list[t])       # grayscale front_body_mask

        # Geometry follows the actual frame size instead of assuming 850x850.
        H, W = abd.shape[:2]
        h_small, w_small = H // sf, W // sf

        # 01: input overlay (front-body mask binarized with Otsu for display only)
        if dbg:
            vis = to_bgr(comp)
            draw_mask_contours(vis, abd, (0, 255, 0), 1)     # abdomen: green
            if fb_g.std() < 1e-6:
                fb_bin = (fb_g > int(fb_g.mean())).astype(np.uint8)
            else:
                _, fb_bin8 = cv2.threshold(fb_g, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
                fb_bin = (fb_bin8 > 0).astype(np.uint8)
            draw_mask_contours(vis, fb_bin, (255, 0, 0), 1)  # front body: blue (BGR)
            cv2.imwrite(os.path.join(dbg_dir, "01_inputs.png"), vis)

        # Label abdomens; components under 10 px are treated as noise.
        num, labels, stats, cents = connected_components(abd)
        cur = []
        for lab in range(1, num):
            if int(stats[lab, cv2.CC_STAT_AREA]) < 10:
                continue
            cur.append(dict(mask=(labels == lab).astype(np.uint8),
                            centroid=(cents[lab][1], cents[lab][0])))  # (cy, cx)

        if dbg:
            colored = colorize_labels(labels)
            cv2.imwrite(os.path.join(dbg_dir, "02_labels.png"), to_bgr(colored))

        # Greedy IoU matching against the previous frame. A matched track is
        # marked used, so two current regions can never share one track id
        # (that would merge their histories and drop one mask from prev_regions).
        assignments, used = {}, set()
        for i, R in enumerate(cur):
            best_iou, best_tid = 0.0, None
            for tid, pm in prev_regions.items():
                if tid in used:
                    continue
                s = iou(R["mask"], pm)
                if s > best_iou:
                    best_iou, best_tid = s, tid
            if best_tid is not None and best_iou >= P["iou_match_threshold"]:
                assignments[i] = best_tid
                used.add(best_tid)
            else:
                assignments[i] = None

        # Update track histories: IoU with the previous mask + bright-red fraction.
        for i, R in enumerate(cur):
            tid = assignments[i]
            if tid is None:
                tid = next_id
                next_id += 1
                tracks[tid] = dict(iou_hist=[], red_prop_hist=[])
            if tid in prev_regions:
                tracks[tid]["iou_hist"].append(iou(R["mask"], prev_regions[tid]))
            m = R["mask"].astype(bool)
            prop = float((red[m] >= P["brightnessThreshold"]).mean()) if m.sum() > 0 else 0.0
            tracks[tid]["red_prop_hist"].append(prop)
            R["track_id"] = tid

        # From here on prev_regions holds the CURRENT frame's masks.
        prev_regions = {R["track_id"]: R["mask"] for R in cur}

        # Candidate regions: stationary AND brightness changing.
        candidates, cand_dbg = [], []
        for R in cur:
            tid = R["track_id"]
            iou_hist = tracks[tid]["iou_hist"]
            rp_hist  = tracks[tid]["red_prop_hist"]
            iou_mean = float(np.mean(iou_hist[-win:])) if iou_hist else 0.0
            stationary = len(iou_hist) >= win and iou_mean >= 1.0 - P["iou_stationary_tol"]
            changing = False
            curm = prevm = None
            if len(rp_hist) >= sw + 1:
                curm  = float(np.mean(rp_hist[-sw:]))
                prevm = float(np.mean(rp_hist[-sw - 1:-1]))
                changing = abs(curm - prevm) >= P["brightPropDelta"]
            if stationary and changing:
                candidates.append(R)
            cand_dbg.append(dict(
                frame=t, tid=tid,
                iou_mean=iou_mean,
                rp_cur=curm, rp_prev=prevm,
                rp_delta=(None if curm is None or prevm is None else curm - prevm),
                stationary=bool(stationary), changing=bool(changing),
            ))

        if dbg:
            vis = to_bgr(comp)
            for R in candidates:
                draw_mask_contours(vis, R["mask"], (0, 255, 255), 2)
                cy, cx = map(int, R["centroid"])
                cv2.putText(vis, f"id{R['track_id']}", (cx + 5, cy - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1, cv2.LINE_AA)
            cv2.imwrite(os.path.join(dbg_dir, "03_candidates.png"), vis)
            with open(os.path.join(dbg_dir, "debug_log.txt"), "w") as f:
                for row in cand_dbg:
                    f.write(json.dumps(row) + "\n")

        # Head points: principal axis + cone brightness on the grayscale front-body mask.
        head_points, per_tid_circle = [], {}
        if dbg:
            vis_axes = to_bgr(comp)

        cone_args = (P["coneAngle_deg"], P["coneLength"], P["rays_per_cone"])
        for R in candidates:
            dy_u, dx_u = region_axis_unit(R["mask"])
            phi = axis_phi_deg_from_unit(dy_u, dx_u)
            cy, cx = R["centroid"]

            Fscore, rays_f = cone_brightness_score(fb_g, (cy, cx), phi, *cone_args)
            Bscore, rays_b = cone_brightness_score(fb_g, (cy, cx), phi + 180.0, *cone_args)

            # Normalized contrast in [-1, 1]. A tie (ratio == 0) keeps the axis
            # direction returned by region_axis_unit, whose sign is arbitrary.
            ratio = (Fscore - Bscore) / (Fscore + Bscore + 1e-6)
            sign = 1.0 if ratio >= 0.0 else -1.0

            hy, hx = clamp(cy + sign * dy_u * P["abdomenToHeadLength"],
                           cx + sign * dx_u * P["abdomenToHeadLength"], H, W)
            head_points.append((hy, hx, R["track_id"]))

            if dbg:
                p1 = (int(cx - dx_u * 35), int(cy - dy_u * 35))
                p2 = (int(cx + dx_u * 35), int(cy + dy_u * 35))
                cv2.arrowedLine(vis_axes, p1, p2, (255, 255, 255), 1, tipLength=0.2)
                for seg in rays_f:
                    for i in range(1, len(seg)):
                        cv2.line(vis_axes, seg[i - 1], seg[i], (0, 0, 255), 1)
                for seg in rays_b:
                    for i in range(1, len(seg)):
                        cv2.line(vis_axes, seg[i - 1], seg[i], (255, 255, 0), 1)
                cv2.putText(vis_axes, f"F:{Fscore:.1f} B:{Bscore:.1f} r:{ratio:.3f}",
                            (int(cx) + 6, int(cy) + 14), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                            (0, 200, 255), 1, cv2.LINE_AA)

        if dbg:
            cv2.imwrite(os.path.join(dbg_dir, "04_axes_cones.png"), vis_axes)

        # --- outputs: drawn directly on the downscaled canvases ---
        small_point = np.zeros((h_small, w_small), np.uint8)
        small_line  = np.zeros((h_small, w_small), np.uint8)
        if dbg:
            vis_final = to_bgr(comp)

        # Pixels covered by >= min_overlap_pairs head circles.
        intersections = np.zeros((H, W), np.uint8)
        if head_points:
            head_sum = np.zeros((H, W), np.uint16)
            for hy, hx, tid in head_points:
                m = np.zeros((H, W), np.uint8)
                cv2.circle(m, (hx, hy), P["headRadius"], 1, thickness=-1)
                per_tid_circle[tid] = m
                head_sum += m.astype(np.uint16)
            intersections = (head_sum >= P["min_overlap_pairs"]).astype(np.uint8)

        if intersections.any():
            num_cc, labs = cv2.connectedComponents(intersections, connectivity=8)
            for lab in range(1, num_cc):
                comp_mask = (labs == lab).astype(np.uint8)
                c = centroid_from_mask(comp_mask)
                if c is None:
                    continue
                cy, cx = clamp(c[0], c[1], H, W)

                sy, sx = downscale_pt(cy, cx, H, W, h_small, w_small)
                small_point[sy, sx] = 255

                # Abdomens whose head circle touches this overlap component.
                tids = [tid for tid, m in per_tid_circle.items() if (comp_mask & (m > 0)).any()]
                for tid in tids:
                    amask = prev_regions.get(tid)
                    if amask is None:
                        continue
                    by, bx = find_nearest_boundary(amask, cy, cx)
                    sby, sbx = downscale_pt(by, bx, H, W, h_small, w_small)
                    cv2.line(small_line, (sx, sy), (sbx, sby), 255, 1)

                    if dbg:
                        cv2.circle(vis_final, (cx, cy), 3, (0, 255, 0), -1)
                        cv2.line(vis_final, (cx, cy), (bx, by), (0, 255, 0), 2)

        if dbg:
            # 05: head points, head circles and their overlap regions.
            vis_heads = to_bgr(comp)
            for hy, hx, tid in head_points:
                cv2.circle(vis_heads, (hx, hy), P["headRadius"], (0, 165, 255), 1)
                cv2.circle(vis_heads, (hx, hy), 3, (0, 0, 255), -1)
                cv2.putText(vis_heads, f"id{tid}", (hx + 5, hy - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 165, 255), 1, cv2.LINE_AA)
            draw_mask_contours(vis_heads, intersections, (0, 255, 0), 1)
            cv2.imwrite(os.path.join(dbg_dir, "05_headpoints_circles.png"), vis_heads)
            cv2.imwrite(os.path.join(dbg_dir, "06_outputs_overlay.png"), vis_final)

        base = base_list[t] + ".png"
        iio.imwrite(os.path.join(out_pt, base), small_point)
        iio.imwrite(os.path.join(out_ln, base), small_line)

        if (t + 1) % 25 == 0 or t == n - 1:
            print(f"[{t+1}/{n}]")

    print("Done.")

In [23]:
# === Cell 3 (revised): Fake data generator that CLEANS input folders (new names) ===
import os, math, shutil
import numpy as np, cv2
from imageio import v3 as iio

def _ellipse_mask(H, W, cx, cy, a, b, angle_deg):
    """Filled ellipse mask (uint8, 255 inside) centred at (cx, cy).

    ``a`` and ``b`` are the half-axes passed to ``cv2.ellipse``, and
    ``angle_deg`` is its rotation angle.
    """
    canvas = np.zeros((H, W), dtype=np.uint8)
    center = (int(cx), int(cy))
    half_axes = (int(a), int(b))
    cv2.ellipse(canvas, center, half_axes, float(angle_deg), 0, 360, 255, -1)
    return canvas

def _triangle_wedge(H, W, cx, cy, axis_deg, dir_sign, length=60, base=30):
    """Filled triangular "front body" mask (uint8, 255 inside).

    The base (width ``base``) is centred on (cx, cy) and perpendicular to the
    axis. The apex lies ``length`` px away along
    ``dir_sign * (cos(axis_deg), sin(axis_deg))`` in image coordinates (y down).
    This is the same angle convention ``cv2.ellipse`` uses in ``_ellipse_mask``,
    so the wedge lies on the abdomen's major axis. The previous ``-sin``
    mirrored wedges at non-horizontal angles off that axis.

    Vertices are only rounded, not clipped, because per-vertex clipping would
    distort the triangle. ``cv2.fillConvexPoly`` clips to the image itself.
    """
    ux = math.cos(math.radians(axis_deg))
    uy = math.sin(math.radians(axis_deg))
    px, py = cx + dir_sign * ux * length, cy + dir_sign * uy * length
    vx, vy = -uy, ux  # unit normal to the axis
    b = base / 2.0
    poly = np.array([[cx - vx * b, cy - vy * b],
                     [cx + vx * b, cy + vy * b],
                     [px, py]], dtype=np.float64)
    m = np.zeros((H, W), np.uint8)
    cv2.fillConvexPoly(m, np.round(poly).astype(np.int32), 255)
    return m

def _reset_dir(d):
    if os.path.isdir(d):
        shutil.rmtree(d)
    os.makedirs(d, exist_ok=True)

def generate_fake_troph_data(
    video_root: str,
    n_frames: int = 60,
    seed: int = 0,
    prefix: str = "frame_",
    start_index: int = 1,
    pad: int = 5,
):
    """Write a synthetic 850x850 dataset in the folder layout ``process_video_root`` reads.

    WARNING: ``input_vid``, ``abdomen_mask`` and ``front_body_mask`` under
    ``video_root`` are deleted and recreated before writing.

    Each folder receives ``n_frames`` PNGs named
    ``f"{prefix}{start_index + t:0{pad}d}.png"``:
      * input_vid: RGB composite. Background is 235 grey; abdomens are dark
        (60) with red channel 120, and "bright-red" abdomen pixels have red 220.
      * abdomen_mask: 0/255 union of the abdomen ellipses.
      * front_body_mask: 0/255 union of one triangular wedge per abdomen.

    Scene: two stationary abdomens facing each other ("stationary_feed") whose
    bright-red fraction steps through ``steps`` over time. All other abdomens
    use a constant 10 %: one stationary ("stationary_quiet") and two that drift
    at constant velocity and bounce off an 80 px margin ("moving"). Bright-red
    pixels come from one fixed noise field drawn from ``seed``, so the output is
    deterministic for a given seed.
    """
    rng = np.random.default_rng(seed)
    H, W = 850, 850

    # UPDATED folder names
    out_vid = os.path.join(video_root, "input_vid")            # was 'composite_frames'
    out_abd = os.path.join(video_root, "abdomen_mask")
    out_fb  = os.path.join(video_root, "front_body_mask")      # was 'body_front_mask'

    # CLEAN the three input folders, then recreate them
    _reset_dir(out_vid)
    _reset_dir(out_abd)
    _reset_dir(out_fb)

    # scene: 5 abdomens (2 stationary facing, 1 stationary quiet, 2 moving)
    A = [
        dict(cx=375.0, cy=420.0, a=28, b=18, axis_deg=0.0,   dir_sign=+1, kind="stationary_feed"),
        dict(cx=435.0, cy=420.0, a=28, b=18, axis_deg=180.0, dir_sign=+1, kind="stationary_feed"),
        dict(cx=200.0, cy=220.0, a=26, b=16, axis_deg=45.0,  dir_sign=+1, kind="stationary_quiet"),
        dict(cx=620.0, cy=320.0, a=26, b=16, axis_deg=75.0,  dir_sign=+1, kind="moving"),
        dict(cx=180.0, cy=600.0, a=30, b=20, axis_deg=135.0, dir_sign=+1, kind="moving"),
    ]
    vel = {3: (+0.8, +0.5), 4: (+0.6, -0.7)}  # indices into A

    noise = rng.random((H, W), dtype=np.float32)
    steps = [0.10, 0.32, 0.55, 0.30]  # ensures detectable Δ in red proportion
    step_len = max(5, n_frames // len(steps))

    for t in range(n_frames):
        comp_rgb = np.full((H, W, 3), 235, np.uint8)  # RGB light background
        abd_mask_all = np.zeros((H, W), np.uint8)
        fb_mask_all  = np.zeros((H, W), np.uint8)
        red_hi_all   = np.zeros((H, W), dtype=bool)

        k = min(t // step_len, len(steps) - 1)
        p_feed, p_quiet = steps[k], 0.10

        for i, a in enumerate(A):
            if a["kind"] == "moving":
                dx, dy = vel[i]
                a["cx"] += dx; a["cy"] += dy
                # Reflect each velocity component independently, so a corner
                # hit reverses both components instead of only the last one.
                if not (80 < a["cx"] < W - 80):
                    dx = -dx
                    a["cx"] = np.clip(a["cx"], 80, W - 80)
                if not (80 < a["cy"] < H - 80):
                    dy = -dy
                    a["cy"] = np.clip(a["cy"], 80, H - 80)
                vel[i] = (dx, dy)

            cx, cy = a["cx"], a["cy"]
            m_abd = _ellipse_mask(H, W, cx, cy, a["a"], a["b"], a["axis_deg"])
            m_fb  = _triangle_wedge(H, W, cx, cy, a["axis_deg"], a["dir_sign"], length=60, base=30)

            abd_mask_all |= (m_abd > 0).astype(np.uint8)
            fb_mask_all  |= (m_fb  > 0).astype(np.uint8)

            p = p_feed if a["kind"] == "stationary_feed" else p_quiet
            red_hi_all |= (noise < p) & (m_abd > 0)

        # Compose RGB: dark body, base red inside abdomen, extra bright-red "feeding" pixels
        comp_rgb[abd_mask_all > 0] = (60, 60, 60)             # dark body
        comp_rgb[..., 0][abd_mask_all > 0] = 120              # base red inside abdomen
        comp_rgb[..., 0][red_hi_all] = 220                    # brighter red where "feeding"

        base = f"{prefix}{start_index + t:0{pad}d}.png"
        iio.imwrite(os.path.join(out_vid, base), comp_rgb)
        iio.imwrite(os.path.join(out_abd, base), (abd_mask_all * 255).astype(np.uint8))
        iio.imwrite(os.path.join(out_fb,  base), (fb_mask_all * 255).astype(np.uint8))

        if (t + 1) % 20 == 0 or t == n_frames - 1:
            print(f"[fake {t+1}/{n_frames}] wrote {base}")

    print("Fake data ready:",
          f"\n  input_vid:        {len(os.listdir(out_vid))} PNGs",
          f"\n  abdomen_mask:     {len(os.listdir(out_abd))} PNGs",
          f"\n  front_body_mask:  {len(os.listdir(out_fb))} PNGs")


In [24]:
# Synthetic dataset for an end-to-end check of the detector
# (the three input folders under video_root are wiped and rewritten).
video_root = "/content/fake_video_root"
generate_fake_troph_data(video_root=video_root, n_frames=60, seed=0)

[fake 20/60] wrote frame_00020.png
[fake 40/60] wrote frame_00040.png
[fake 60/60] wrote frame_00060.png
Fake data ready: 
  input_vid:        60 PNGs 
  abdomen_mask:     60 PNGs 
  front_body_mask:  60 PNGs


In [25]:

# Parameters tuned for the synthetic data from generate_fake_troph_data.
# NOTE: this rebinds the global PARAMS; it is passed explicitly below because
# process_video_root's default `P` was bound when that function was defined.
PARAMS = dict(
    scale_factor=5,            # 850 -> 170
    min_duration=3,            # stationary once 2 frame-pair IoUs are available
    iou_stationary_tol=0.10,   # mean IoU >= 1 - tol
    brightnessThreshold=160,   # R channel threshold (0-255)
    brightPropDelta=0.02,      # abs diff of smoothed red>=thr fraction
    smoothingWindow=3,
    coneAngle_deg=45,          # half-angle: rays span axis ± 45°
    cone_bg_weight=0.05,       # NOTE(review): not read by process_video_root
    coneLength=80,
    rays_per_cone=9,
    abdomenToHeadLength=45,
    headRadius=30,
    min_overlap_pairs=2,
    iou_match_threshold=0.4,
)


video_root = "/content/fake_video_root"
# Run with debug visuals for frame 30 (written to <video_root>/_debug_troph)
process_video_root(video_root, PARAMS, debug=True, debug_frame=30)


[25/60]
[50/60]
[60/60]
Done.


In [26]:
# === Cell 6B: Detect video_root + sanity check (run even if you skipped 6A) ===
import os, glob
from google.colab import drive

# --- config (repeat here so this cell is standalone) ---
videoName   = "tmpvidroot3"
GDRIVE_ROOT = "/content/drive/MyDrive"  # or '/content/drive/Shared drives/<YourDrive>'

# Ensure Drive is mounted (idempotent)
drive.mount("/content/drive", force_remount=False)

# Folder searched for the video root: <GDRIVE_ROOT>/MainConnection_VidRoots/<videoName>/
dest_dir = os.path.join(GDRIVE_ROOT, f"MainConnection_VidRoots/{videoName}/")

# A directory counts as a video_root if it has all of these sub-folders
# (the same names process_video_root reads).
REQUIRED = ["input_vid", "abdomen_mask", "front_body_mask"]

def has_required(root):
    """True if ``root`` contains every sub-directory named in ``REQUIRED``."""
    for sub in REQUIRED:
        if not os.path.isdir(os.path.join(root, sub)):
            return False
    return True

# Find candidate video_root(s) that contain the required folders: dest_dir
# itself and its immediate sub-folders, and only if none match, one level deeper.
if not os.path.isdir(dest_dir):
    raise FileNotFoundError(f"dest_dir does not exist (check videoName / GDRIVE_ROOT): {dest_dir}")

candidates = []
if has_required(dest_dir):
    candidates.append(dest_dir)

# Sorted so the choice below is deterministic when several candidates tie.
top_level = sorted(p for p in os.listdir(dest_dir) if os.path.isdir(os.path.join(dest_dir, p)))
for name in top_level:
    root = os.path.join(dest_dir, name)
    if has_required(root):
        candidates.append(root)

if not candidates:
    # one more level deep
    for name in top_level:
        lvl1 = os.path.join(dest_dir, name)
        for name2 in sorted(os.listdir(lvl1)):
            root = os.path.join(lvl1, name2)
            if os.path.isdir(root) and has_required(root):
                candidates.append(root)

if not candidates:
    print("❌ Could not find a folder that contains:", REQUIRED)
    print("Top-level entries in dest_dir:", top_level[:20])
    raise RuntimeError("video_root not found.")

# Choose the shortest path (usually the intended root); min() keeps the first on ties.
video_root = min(candidates, key=len)
print("✅ video_root detected:", video_root)

def list_bases(folder):
    """Lexicographically sorted basenames of the ``*.png`` / ``*.PNG`` files directly in ``folder``."""
    names = [os.path.basename(f)
             for pattern in ("*.png", "*.PNG")
             for f in glob.glob(os.path.join(folder, pattern))]
    names.sort()
    return names

vid_names = list_bases(os.path.join(video_root, "input_vid"))
abd_names = list_bases(os.path.join(video_root, "abdomen_mask"))
fb_names  = list_bases(os.path.join(video_root, "front_body_mask"))
vid_b, abd_b, fb_b = set(vid_names), set(abd_names), set(fb_names)
common_all = vid_b & abd_b & fb_b
common = sorted(common_all)[:10]

print(f"Frames: input_vid={len(vid_b)}  abdomen={len(abd_b)}  front_body={len(fb_b)}  common={len(common_all)}")
print("Sample common basenames:", common)
if not common_all:
    # The detector's align_triplet pairs frames by sorted position, not by
    # name, so differing basenames are fine as long as the counts line up.
    print("Note: no shared basenames; the detector pairs frames by sorted order.",
          "First file of each (lexicographic):", vid_names[:1], abd_names[:1], fb_names[:1])

print("\nUse this in the detector cell:")
print(f"video_root = r\"{video_root}\"")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ video_root detected: /content/drive/MyDrive/MainConnection_VidRoots/tmpvidroot3/
Frames: input_vid=132  abdomen=132  front_body=132  common=0
Sample common basenames: []

Use this in the detector cell:
video_root = r"/content/drive/MyDrive/MainConnection_VidRoots/tmpvidroot3/"


In [27]:
# Parameters for the recording located by Cell 6B (`video_root` comes from
# that cell). NOTE: this rebinds the global PARAMS; it is passed explicitly below.
PARAMS = dict(
    scale_factor=5,            # 850 -> 170
    min_duration=4,
    iou_stationary_tol=0.60,   # mean IoU >= 1 - tol, i.e. >= 0.40 here
    brightnessThreshold=30,   # R channel threshold (0-255)
    brightPropDelta=0.00,      # NOTE(review): 0.00 makes abs(Δ) >= 0 always true once enough history exists, so only stationarity gates candidates
    smoothingWindow=3,
    coneAngle_deg=35,          # half-angle: rays span axis ± 35°
    coneLength=80,
    rays_per_cone=9,
    cone_bg_weight=0.05,       # NOTE(review): not read by process_video_root; harmless extra key
    abdomenToHeadLength=50,
    headRadius=15,
    min_overlap_pairs=2,
    iou_match_threshold=0.01,  # almost any overlap continues a track
)

# Debug images for frame index 3 go to <video_root>/_debug_troph
process_video_root(video_root, PARAMS, debug=True, debug_frame=3)

[25/132]
[50/132]
[75/132]
[100/132]
[125/132]
[132/132]
Done.


Run Cell 6A (below) only if the data zip still needs to be extracted; then run Cell 6B to locate the extracted folders.

In [None]:
# === Cell 6A: Mount Drive + UNZIP (run only if you need to extract) ===
import os, zipfile
from google.colab import drive

# --- config ---
# NOTE(review): Cell 6B searches under videoName = "tmpvidroot3"; make the two
# match if 6B should pick up this extraction.
videoName   = "FoodLimit2.5D_gR0033_feeding"
GDRIVE_ROOT = "/content/drive/MyDrive"  # or '/content/drive/Shared drives/<YourDrive>'
# Zip location relative to GDRIVE_ROOT.
ZIP_REL     = f"MainConnection_VidRoots/{videoName}/{videoName}_tmpvidroot2.zip"

# --- mount ---
drive.mount("/content/drive", force_remount=False)

zip_path = os.path.join(GDRIVE_ROOT, ZIP_REL)
dest_dir = os.path.dirname(zip_path)  # extract into the folder that holds the zip

print(f"Zip path: {zip_path}")
print(f"Dest dir: {dest_dir}")

# Fail early and list the expected parent folder to help spot a typo in
# videoName / ZIP_REL.
if not os.path.isfile(zip_path):
    print("❌ Zip not found. Listing parent folder:")
    parent = os.path.dirname(zip_path)
    if os.path.isdir(parent):
        print("Contents of", parent, "→", os.listdir(parent)[:20])
    else:
        print("Parent folder does not exist:", parent)
    raise FileNotFoundError(zip_path)

def _is_within_directory(directory, target):
    directory = os.path.abspath(directory)
    target    = os.path.abspath(target)
    return os.path.commonpath([directory]) == os.path.commonpath([directory, target])

# Validate every member path first (zip-slip guard), then extract one member
# at a time. Files already present with the size recorded in the archive are
# skipped, so an interrupted extraction (e.g. the KeyboardInterrupt seen on a
# slow Drive copy) can simply be re-run instead of starting over.
with zipfile.ZipFile(zip_path, 'r') as zf:
    members = zf.infolist()
    for info in members:
        target_path = os.path.join(dest_dir, info.filename)
        if not _is_within_directory(dest_dir, target_path):
            raise RuntimeError(f"Blocked unsafe path in zip: {info.filename}")
    print("Unzipping...")
    n_skipped = 0
    for k, info in enumerate(members, start=1):
        target_path = os.path.join(dest_dir, info.filename)
        already_done = (not info.is_dir()
                        and os.path.isfile(target_path)
                        and os.path.getsize(target_path) == info.file_size)
        if already_done:
            n_skipped += 1
        else:
            zf.extract(info, dest_dir)
        if k % 100 == 0:
            print(f"  {k}/{len(members)} members")
    print("Unzip complete.")
    if n_skipped:
        print(f"  ({n_skipped} already-extracted files skipped)")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Zip path: /content/drive/MyDrive/MainConnection_VidRoots/FoodLimit2.5D_gR0033_feeding/FoodLimit2.5D_gR0033_feeding_tmpvidroot2.zip
Dest dir: /content/drive/MyDrive/MainConnection_VidRoots/FoodLimit2.5D_gR0033_feeding
Unzipping...


KeyboardInterrupt: 