In [None]:
import json, argparse
import numpy as np
import torch
import torch.nn as nn
import cv2
import time 
import math

from common import (
    set_seed, find_episodes, load_annotations_json, video_path_for_episode,
    load_refs_for_episode, EmbeddingMatcher, build_template, YOLOProposals,
    l2_normalize, iou_xyxy, frame_to_boxes, TemporalHead, IoUHead,
)


def build_split(split_file):
    """Load the train/validation id lists from a JSON split file.

    Args:
        split_file: Path to a UTF-8 JSON file containing the keys
            ``"train_ids"`` and ``"val_ids"``.

    Returns:
        A ``(train_ids, val_ids)`` tuple with the values stored under
        those two keys.

    Raises:
        KeyError: If either key is missing from the file.
    """
    with open(split_file, "r", encoding="utf-8") as handle:
        split = json.load(handle)
    train_ids = split["train_ids"]
    val_ids = split["val_ids"]
    return train_ids, val_ids


def label_frame_candidates(gt_map, frame_idx, props):
    """Label each proposal box of one frame against that frame's ground truth.

    Args:
        gt_map: Mapping from frame index to a collection of ground-truth boxes.
        frame_idx: Index of the frame being labelled.
        props: Proposal boxes for the frame, compared with ``iou_xyxy``.

    Returns:
        ``(labels, ious)``: two lists, one entry per proposal. ``ious`` holds
        the largest IoU against any ground-truth box (never below 0.0) and
        ``labels`` is 1.0 where that IoU is at least 0.5, otherwise 0.0.
        Frames without an entry in ``gt_map`` yield all zeros.
    """
    if frame_idx not in gt_map:
        count = len(props)
        return [0.0] * count, [0.0] * count

    targets = gt_map[frame_idx]
    # Seeding the max with 0.0 reproduces a running ``max(best, iou)`` that
    # starts at zero, including the no-ground-truth-box case.
    overlaps = [
        max([0.0, *(iou_xyxy(box, gt) for gt in targets)])
        for box in props
    ]
    labels = [1.0 if overlap >= 0.5 else 0.0 for overlap in overlaps]
    return labels, overlaps


def main():
    """Train the temporal and IoU heads and save them to ``--save_ckpt``.

    Every frame of every train/val video is run through the proposal engine;
    each proposal box becomes a 6-value feature row
    ``[similarity, dx, dy, ds, dh, dw]`` where ``similarity`` is the
    embedding/template product (named ``cos`` below) and the deltas are
    measured against the previous frame's highest-similarity box.  The
    temporal head is fit with BCE on the IoU>=0.5 label and the IoU head with
    L1 on the best ground-truth IoU.  Validation metrics are printed per epoch.
    """
    # Imported locally: the module otherwise binds ``Path`` only under the
    # ``__main__`` guard, so calling main() from an importer would raise
    # NameError.
    from pathlib import Path

    ap = argparse.ArgumentParser()
    ap.add_argument("--data_root", required=True)
    ap.add_argument("--split_file", required=True)
    ap.add_argument("--epochs", type=int, default=5)
    ap.add_argument("--batch_size", type=int, default=128)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--freeze_backbone", action="store_true")
    ap.add_argument("--save_ckpt", default="./ckpts/heads_dev.pt")
    ap.add_argument("--imgsz", type=int, default=960)
    ap.add_argument("--max_props", type=int, default=200)
    ap.add_argument("--conf", type=float, default=0.005)
    ap.add_argument("--nms_iou", type=float, default=0.60)
    ap.add_argument("--temporal_T", type=int, default=15)
    ap.add_argument("--yolo_ckpt", type=str, default=None)
    ap.add_argument("--debug", action="store_true")
    args = ap.parse_args()

    set_seed(1337)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    Path(args.save_ckpt).parent.mkdir(parents=True, exist_ok=True)

    train_ids, val_ids = build_split(args.split_file)
    print(f"[SPLIT] train={len(train_ids)} val={len(val_ids)}")

    # NOTE(review): device=0 is passed even when torch falls back to CPU above;
    # confirm YOLOProposals copes with a CUDA-less host.
    props_engine = YOLOProposals(conf=args.conf, iou=args.nms_iou, imgsz=args.imgsz,
                                 max_candidates=args.max_props, device=0,
                                 yolo_on=True, yolo_ckpt=args.yolo_ckpt, debug=args.debug)
    matcher = EmbeddingMatcher(out_dim=256).to(device).eval()

    temp_head = TemporalHead(in_dim=6, proj_dim=32, hidden=64).to(device)
    iou_head = IoUHead(in_dim=6, hidden=64).to(device)

    opt = torch.optim.Adam(list(temp_head.parameters()) + list(iou_head.parameters()), lr=args.lr)
    bce = nn.BCELoss()
    l1 = nn.L1Loss()

    gt_entries = load_annotations_json(args.data_root)

    def collect_samples(video_ids, max_frames_per_vid=400):
        """Build (X_temp, Y_temp, X_iou, Y_iou) sample lists for ``video_ids``.

        NOTE(review): ``max_frames_per_vid`` is accepted but not enforced;
        every readable frame is processed.
        """
        X_temp = []
        Y_temp = []
        X_iou = []
        Y_iou = []
        for vid in video_ids:
            refs = load_refs_for_episode(args.data_root, vid)
            tmpl = build_template(matcher, refs, device=device, augs_per_ref=12, use_adapter=True, debug=args.debug)
            vpath = video_path_for_episode(args.data_root, vid)
            cap = cv2.VideoCapture(vpath)
            try:
                if not cap.isOpened():
                    # read() below fails immediately, so the video simply
                    # contributes no samples; say so instead of staying silent.
                    print(f"[WARN] could not open video for {vid}: {vpath}")
                gt_map = frame_to_boxes(gt_entries, vid, key="annotations")
                last_geom = None
                seq_buf = []
                fidx = 0
                while True:
                    ok, frame = cap.read()
                    if not ok:
                        break
                    fidx += 1  # 1-based index
                    boxes = props_engine(frame)
                    crops = [frame[y1:y2, x1:x2] for (x1, y1, x2, y2) in boxes]
                    t0 = time.time()
                    embs = matcher.encode_np(crops, device)
                    if embs.numel() > 0:
                        cos = (embs @ tmpl.T).squeeze(1).detach().cpu().numpy()
                    else:
                        cos = np.zeros((0,), dtype=np.float32)
                    if args.debug:
                        print(f"[EMBED] vid={vid} f={fidx} props={len(boxes)} embed_time={time.time()-t0:.3f}s")
                    H, W = frame.shape[:2]
                    feat_rows = []
                    geoms = []
                    for b, s in zip(boxes, cos):
                        x1, y1, x2, y2 = b
                        cx = (x1 + x2) / 2.0
                        cy = (y1 + y2) / 2.0
                        w = max(1.0, x2 - x1)
                        h = max(1.0, y2 - y1)
                        if last_geom is None:
                            dx = dy = ds = dh = dw = 0.0
                        else:
                            lcx, lcy, lw, lh = last_geom
                            # Centre shift normalised by frame size.
                            dx = (cx - lcx) / max(1.0, W)
                            dy = (cy - lcy) / max(1.0, H)
                            # Log width / height ratios, then log aspect-ratio change.
                            ds = math.log(w / max(1.0, lw))
                            dh = math.log(h / max(1.0, lh))
                            dw = math.log((w / h) / max(1e-6, lw / lh))
                        feat_rows.append([float(s), dx, dy, ds, dh, dw])
                        geoms.append((cx, cy, w, h))
                    y_bin, y_iou = label_frame_candidates(gt_map, fidx, boxes)
                    if not feat_rows:
                        continue
                    # The highest-similarity candidate drives both the temporal
                    # buffer and the reference geometry for the next frame.
                    top = max(range(len(feat_rows)), key=lambda i: feat_rows[i][0])
                    seq_buf.append(feat_rows[top])
                    if len(seq_buf) > args.temporal_T:
                        del seq_buf[0]
                    # The sequence is identical for every candidate in this
                    # frame, so build the tensor once and share it (it is never
                    # mutated afterwards) instead of re-creating it per box.
                    # NOTE(review): each candidate therefore pairs the same
                    # temporal input with its own label.
                    seq_tensor = None
                    if len(seq_buf) >= 3:
                        seq_tensor = torch.tensor(seq_buf, dtype=torch.float32)
                    for fr, yb, yi in zip(feat_rows, y_bin, y_iou):
                        X_iou.append(torch.tensor(fr, dtype=torch.float32))
                        Y_iou.append(torch.tensor([yi], dtype=torch.float32))
                        if seq_tensor is not None:
                            X_temp.append(seq_tensor)
                            Y_temp.append(torch.tensor([yb], dtype=torch.float32))
                    last_geom = geoms[top]
            finally:
                # Release the decoder even if proposal/embedding code raises.
                cap.release()
        return X_temp, Y_temp, X_iou, Y_iou

    print("[COLLECT] train samples...")
    Xtemp_tr, Ytemp_tr, Xiou_tr, Yiou_tr = collect_samples(train_ids)
    print("[COLLECT] val samples...")
    Xtemp_va, Ytemp_va, Xiou_va, Yiou_va = collect_samples(val_ids)

    def batches(Xs, Ys, bs):
        """Yield shuffled ``(inputs, targets)`` mini-batches as Python lists."""
        idx = np.arange(len(Xs))
        np.random.shuffle(idx)
        for i in range(0, len(idx), bs):
            sl = idx[i:i + bs]
            yield [Xs[j] for j in sl], [Ys[j] for j in sl]

    def eval_head(Xs, Ys, head, is_temp=True):
        """Return accuracy@0.5 (temporal head) or 1 - mean abs error (IoU head).

        The score is accumulated per element rather than as a mean of batch
        means, so a short final batch is not over-weighted.
        """
        if not Xs:
            return 0.0
        total = 0.0
        count = 0
        for i in range(0, len(Xs), args.batch_size):
            xb = Xs[i:i + args.batch_size]
            yb = Ys[i:i + args.batch_size]
            if is_temp:
                X = torch.nn.utils.rnn.pad_sequence(xb, batch_first=True).to(device)
            else:
                X = torch.stack(xb, dim=0).to(device)
            Y = torch.stack(yb, dim=0).to(device)
            yhat = head(X)
            if is_temp:
                per_elem = ((yhat > 0.5) == (Y > 0.5)).float()
            else:
                per_elem = 1.0 - torch.abs(yhat - Y)
            total += float(per_elem.sum().cpu())
            count += per_elem.numel()
        return total / max(1, count)

    for ep in range(1, args.epochs + 1):
        temp_head.train()
        iou_head.train()
        loss_sum = 0.0
        nsteps = 0
        for Xb, Yb in batches(Xtemp_tr, Ytemp_tr, args.batch_size):
            # Buffers hold between 3 and temporal_T rows, hence the padding.
            Xpad = torch.nn.utils.rnn.pad_sequence(Xb, batch_first=True).to(device)
            Ypad = torch.stack(Yb, dim=0).to(device)
            opt.zero_grad()
            yhat = temp_head(Xpad)
            loss = bce(yhat, Ypad)
            loss.backward()
            opt.step()
            loss_sum += float(loss.item())
            nsteps += 1
        for Xb, Yb in batches(Xiou_tr, Yiou_tr, args.batch_size):
            X = torch.stack(Xb, dim=0).to(device)
            Y = torch.stack(Yb, dim=0).to(device)
            opt.zero_grad()
            yhat = iou_head(X)
            loss = l1(yhat, Y)
            loss.backward()
            opt.step()
            loss_sum += float(loss.item())
            nsteps += 1
        temp_head.eval()
        iou_head.eval()
        with torch.no_grad():
            t_acc = eval_head(Xtemp_va, Ytemp_va, temp_head, True)
            q_acc = eval_head(Xiou_va, Yiou_va, iou_head, False)
        print(f"[E{ep}] loss={loss_sum/max(1,nsteps):.4f}  temp_val_acc={t_acc:.3f}  iou_val(1-L1)={q_acc:.3f}")

    torch.save({"temporal_head": temp_head.state_dict(), "iou_head": iou_head.state_dict()}, args.save_ckpt)
    print(f"[SAVED] {args.save_ckpt}")

if __name__ == "__main__":
    # main() references Path when creating the checkpoint directory; importing
    # it here binds the name at module level before main() runs in script mode.
    from pathlib import Path
    main()


In [None]:
import json, argparse
import numpy as np
import torch
import cv2, math, time

from common import (
    set_seed, find_episodes, load_annotations_json, load_refs_for_episode,
    video_path_for_episode, YOLOProposals, EmbeddingMatcher, build_template,
    nms_xyxy, frame_to_boxes, TemporalHead, IoUHead, SingleTargetTracker,
    segmentize, st_iou_mean
)


def run_once(args, video_ids, ckpt=None, tau_high=0.55, tau_low=0.45):
    """Run proposals, scoring and tracking over ``video_ids``.

    Args:
        args: Parsed CLI namespace. Reads ``conf``, ``nms_iou``, ``imgsz``,
            ``max_props``, ``yolo_off``, ``yolo_ckpt``, ``debug``,
            ``data_root``, ``assoc_lambda``, ``frame_stride``,
            ``nms_final_iou``, ``temporal_T`` and ``alpha_fuse``.
        video_ids: Iterable of episode/video identifiers to process.
        ckpt: Optional checkpoint path holding ``"temporal_head"`` and
            ``"iou_head"`` state dicts; when falsy the heads keep their
            initial weights.
        tau_high: Forwarded to ``SingleTargetTracker``.
        tau_low: Forwarded to ``SingleTargetTracker``.

    Returns:
        A list with one ``{"video_id": ..., "detections": ...}`` dict per
        video, where ``detections`` is the output of ``segmentize``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    props = YOLOProposals(conf=args.conf, iou=args.nms_iou, imgsz=args.imgsz,
                          max_candidates=args.max_props, device=0,
                          yolo_on=not args.yolo_off, yolo_ckpt=args.yolo_ckpt, debug=args.debug)
    matcher = EmbeddingMatcher(out_dim=256).to(device).eval()
    # Build the heads with the same sizes the training script uses, so the
    # checkpoint's tensor shapes match regardless of constructor defaults.
    temp_head = TemporalHead(in_dim=6, proj_dim=32, hidden=64).to(device).eval()
    iou_head = IoUHead(in_dim=6, hidden=64).to(device).eval()
    if ckpt:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        sd = torch.load(ckpt, map_location=device)
        temp_head.load_state_dict(sd["temporal_head"])
        iou_head.load_state_dict(sd["iou_head"])

    # Guard the debug cadence against frame_stride == 0 (modulo by zero).
    dbg_every = 5 * max(1, args.frame_stride)

    preds = []
    for vid in video_ids:
        refs = load_refs_for_episode(args.data_root, vid)
        tmpl = build_template(matcher, refs, device=device, augs_per_ref=12, use_adapter=True, debug=args.debug)
        cap = cv2.VideoCapture(video_path_for_episode(args.data_root, vid))
        try:
            tracker = SingleTargetTracker(tau_high=tau_high, tau_low=tau_low,
                                          assoc_lambda=args.assoc_lambda, max_lost=max(10, 3 * args.frame_stride),
                                          min_commit=2, gap_fill=max(1, args.frame_stride - 1),
                                          frame_stride=args.frame_stride, debug=args.debug)
            seq_buf = []
            last_geom = None
            fidx = 0
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                fidx += 1  # 1-based
                # Only frames 1, 1+stride, 1+2*stride, ... are scored.
                if args.frame_stride > 1 and ((fidx - 1) % args.frame_stride) != 0:
                    continue
                t0 = time.time()
                boxes = props(frame)
                t_prop = time.time() - t0
                crops = [frame[y1:y2, x1:x2] for (x1, y1, x2, y2) in boxes]
                t1 = time.time()
                embs = matcher.encode_np(crops, device)
                t_emb = time.time() - t1
                if embs.numel() > 0:
                    cos = (embs @ tmpl.T).squeeze(1).detach().cpu().numpy()
                else:
                    cos = np.zeros((0,), dtype=np.float32)
                H, W = frame.shape[:2]
                feat_rows = []
                for b, s in zip(boxes, cos):
                    x1, y1, x2, y2 = b
                    cx = (x1 + x2) / 2
                    cy = (y1 + y2) / 2
                    w = max(1.0, x2 - x1)
                    h = max(1.0, y2 - y1)
                    if last_geom is None:
                        dx = dy = ds = dh = dw = 0.0
                    else:
                        lcx, lcy, lw, lh = last_geom
                        dx = (cx - lcx) / max(1.0, W)
                        dy = (cy - lcy) / max(1.0, H)
                        ds = math.log(w / max(1.0, lw))
                        dh = math.log(h / max(1.0, lh))
                        dw = math.log((w / h) / max(1e-6, lw / lh))
                    feat_rows.append([float(s), dx, dy, ds, dh, dw])
                keep = nms_xyxy(boxes, cos, iou_thr=args.nms_final_iou)
                boxes = [boxes[i] for i in keep]
                sims = [float(cos[i]) for i in keep]
                feats = [feat_rows[i] for i in keep]
                if feats:
                    # The first kept candidate feeds the temporal buffer
                    # (presumably the best-scoring one; depends on nms_xyxy order).
                    seq_buf.append(feats[0])
                    if len(seq_buf) > args.temporal_T:
                        del seq_buf[0]
                    # Pure inference: skip autograd bookkeeping.
                    with torch.no_grad():
                        seq = torch.tensor(seq_buf, dtype=torch.float32, device=device)
                        seq = seq.unsqueeze(0).repeat(len(feats), 1, 1)
                        s_temp = temp_head(seq).squeeze(-1).detach().cpu().numpy()
                        feat_t = torch.tensor(feats, dtype=torch.float32, device=device)
                        s_iou = iou_head(feat_t).squeeze(-1).detach().cpu().numpy()
                    fused = (args.alpha_fuse * np.array(sims) + (1 - args.alpha_fuse) * s_temp) * 0.5 + 0.5 * s_iou
                else:
                    fused = []
                tracker.update(fidx, boxes, fused)
                if boxes:
                    # Best fused candidate becomes the reference geometry for
                    # the next scored frame's delta features.
                    bi = int(np.argmax(fused))
                    x1, y1, x2, y2 = boxes[bi]
                    last_geom = ((x1 + x2) / 2, (y1 + y2) / 2, max(1.0, x2 - x1), max(1.0, y2 - y1))
                if args.debug and (fidx % dbg_every == 1):
                    print(f"[DBG] {vid} f={fidx} props={len(boxes)} t_prop={t_prop:.3f}s t_emb={t_emb:.3f}s")
        finally:
            # Release the decoder even when scoring or tracking raises.
            cap.release()
        dets = tracker.detections
        segs = segmentize(dets, max_gap=args.frame_stride)
        preds.append({"video_id": vid, "detections": segs})
        if args.debug:
            # First (up to) five detected frame indices, for a quick sanity check.
            sample_frames = [bb["frame"] for s in segs for bb in s["bboxes"]][:5]
            print(f"[DBG] {vid}: det_frames={sum(len(s['bboxes']) for s in segs)} sample={sample_frames}")
    return preds


def main():
    """Sweep tracker thresholds on the validation split and print a config.

    Runs ``run_once`` for each ``(tau_high, tau_low)`` pair of a 3x3 grid,
    scores the predictions with ``st_iou_mean`` and prints the best pair
    together with the other CLI settings as JSON.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_root", required=True)
    ap.add_argument("--split_file", required=True)
    ap.add_argument("--ckpt", default="./ckpts/heads_dev.pt")
    ap.add_argument("--frame_stride", type=int, default=2)
    ap.add_argument("--imgsz", type=int, default=960)
    ap.add_argument("--max_props", type=int, default=200)
    ap.add_argument("--conf", type=float, default=0.005)
    ap.add_argument("--nms_iou", type=float, default=0.60)
    ap.add_argument("--nms_final_iou", type=float, default=0.5)
    ap.add_argument("--assoc_lambda", type=float, default=0.5)
    ap.add_argument("--alpha_fuse", type=float, default=0.5)
    ap.add_argument("--temporal_T", type=int, default=15)
    ap.add_argument("--yolo_ckpt", type=str, default=None)
    ap.add_argument("--yolo_off", action='store_true')
    ap.add_argument("--debug", action='store_true')
    args = ap.parse_args()

    with open(args.split_file, "r", encoding="utf-8") as f:
        sp = json.load(f)
    val_ids = sp["val_ids"]
    gt = load_annotations_json(args.data_root)

    # Debug GT coverage
    if args.debug and len(val_ids) > 0:
        v0 = val_ids[0]
        gmap = frame_to_boxes(gt, v0, key="annotations")
        frames = sorted(gmap.keys())
        print(f"[DEBUG] {v0}: GT frames={len(frames)} min={frames[0] if frames else None} max={frames[-1] if frames else None}")

    # NOTE(review): every grid point re-runs proposals and embeddings from
    # scratch although only the tracker thresholds change.
    best = (0.0, None)
    for th in (0.45, 0.5, 0.55):
        for margin in (0.15, 0.10, 0.05):
            # Round away binary-float noise (0.45 - 0.15 -> 0.30000000000000004)
            # so the emitted config carries clean two-decimal thresholds.
            tl = round(th - margin, 2)
            # Same seed as the training script, re-applied per run so all grid
            # points start from the same RNG state (presumably relevant to the
            # augmented template built in run_once; confirm in common).
            set_seed(1337)
            preds = run_once(args, val_ids, ckpt=args.ckpt, tau_high=th, tau_low=tl)
            sc = st_iou_mean(gt, preds, val_ids)
            print(f"[VAL] tau_high={th:.2f} tau_low={tl:.2f}  ST-IoU={sc:.4f}")
            if sc > best[0]:
                best = (sc, (th, tl))

    print("\n[BEST]", best)
    # Fall back to the run_once defaults when no pair scored above zero.
    best_th, best_tl = best[1] if best[1] else (0.55, 0.45)
    cfg = {
        "frame_stride": args.frame_stride,
        "imgsz": args.imgsz,
        "max_props": args.max_props,
        "conf": args.conf,
        "nms_iou": args.nms_iou,
        "nms_final_iou": args.nms_final_iou,
        "assoc_lambda": args.assoc_lambda,
        "alpha_fuse": args.alpha_fuse,
        "temporal_T": args.temporal_T,
        "tau_high": best_th,
        "tau_low": best_tl,
    }
    print("\n[CONFIG] Paste into configs/release.json:\n" + json.dumps(cfg, indent=2))

# Script entry point: run the validation threshold sweep only when executed directly.
if __name__ == "__main__":
    main()


In [None]:
import json, argparse
import numpy as np
import torch
import cv2, math, time
from pathlib import Path

from common import (
    find_episodes, load_refs_for_episode, video_path_for_episode,
    YOLOProposals, EmbeddingMatcher, build_template, nms_xyxy,
    TemporalHead, IoUHead, SingleTargetTracker, segmentize
)


def _write_json_atomic(out_path, payload):
    """Dump ``payload`` as indented JSON to ``out_path`` via temp file + rename."""
    tmp = out_path.with_suffix(".tmp.json")
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    # Path.replace swaps the file in one step, so readers never see a partial file.
    tmp.replace(out_path)


def _load_existing_predictions(out_path):
    """Return the dict entries already stored in ``out_path``, or ``[]``."""
    if not out_path.exists():
        return []
    try:
        with open(out_path, "r", encoding="utf-8") as f:
            cur = json.load(f)
    except (OSError, ValueError) as err:
        # Best effort: a corrupt or unreadable file is replaced, but not silently.
        print(f"[WARN] ignoring unreadable {out_path}: {err}")
        return []
    if not isinstance(cur, list):
        return []
    return [e for e in cur if isinstance(e, dict)]


def main():
    """Run inference on every episode under ``--data_root`` and save predictions.

    Settings come from the JSON file given by ``--config`` and head weights
    from ``--ckpt``.  After each video the output file is refreshed (merging
    with whatever it already holds), and at the end it is rewritten with
    exactly this run's predictions.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_root", required=True)
    ap.add_argument("--ckpt", required=True)
    ap.add_argument("--config", required=True)
    ap.add_argument("--out", default="./viz_test/predictions.json")
    ap.add_argument("--yolo_ckpt", type=str, default=None)
    ap.add_argument("--yolo_off", action='store_true')
    ap.add_argument("--debug", action='store_true')
    args = ap.parse_args()

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(args.config, "r", encoding="utf-8") as f:
        C = json.load(f)

    stride = C["frame_stride"]
    # Debug cadences use a stride of at least 1 so stride == 0 cannot trigger
    # a modulo-by-zero.
    safe_stride = max(1, stride)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    props = YOLOProposals(conf=C["conf"], iou=C["nms_iou"], imgsz=C["imgsz"],
                          max_candidates=C["max_props"], device=0,
                          yolo_on=not args.yolo_off, yolo_ckpt=args.yolo_ckpt, debug=args.debug)
    matcher = EmbeddingMatcher(out_dim=256).to(device).eval()
    # Same head sizes as the training script, so the checkpoint's tensor shapes
    # match regardless of constructor defaults.
    temp_head = TemporalHead(in_dim=6, proj_dim=32, hidden=64).to(device).eval()
    iou_head = IoUHead(in_dim=6, hidden=64).to(device).eval()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    sd = torch.load(args.ckpt, map_location=device)
    temp_head.load_state_dict(sd["temporal_head"])
    iou_head.load_state_dict(sd["iou_head"])

    vids = find_episodes(args.data_root)
    preds = []
    for vid in vids:
        print(f"[RUN] {vid}")
        refs = load_refs_for_episode(args.data_root, vid)
        tmpl = build_template(matcher, refs, device=device, augs_per_ref=12, use_adapter=True, debug=args.debug)
        cap = cv2.VideoCapture(video_path_for_episode(args.data_root, vid))
        try:
            tracker = SingleTargetTracker(
                tau_high=C["tau_high"], tau_low=C["tau_low"],
                assoc_lambda=C["assoc_lambda"],
                max_lost=int(C.get("max_lost", max(10, 3 * stride))),
                min_commit=int(C.get("min_commit", 2)),
                gap_fill=max(1, stride - 1),
                frame_stride=stride,
                debug=args.debug
            )
            seq_buf = []
            last_geom = None
            fidx = 0
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                fidx += 1  # 1-based
                # Only frames 1, 1+stride, 1+2*stride, ... are scored.
                if stride > 1 and ((fidx - 1) % stride) != 0:
                    continue
                t0 = time.time()
                boxes = props(frame)
                t_prop = time.time() - t0
                crops = [frame[y1:y2, x1:x2] for (x1, y1, x2, y2) in boxes]
                t1 = time.time()
                embs = matcher.encode_np(crops, device)
                t_emb = time.time() - t1
                if embs.numel() > 0:
                    cos = (embs @ tmpl.T).squeeze(1).detach().cpu().numpy()
                else:
                    cos = np.zeros((0,), dtype=np.float32)
                H, W = frame.shape[:2]
                feat_rows = []
                for b, s in zip(boxes, cos):
                    x1, y1, x2, y2 = b
                    cx = (x1 + x2) / 2
                    cy = (y1 + y2) / 2
                    w = max(1.0, x2 - x1)
                    h = max(1.0, y2 - y1)
                    if last_geom is None:
                        dx = dy = ds = dh = dw = 0.0
                    else:
                        lcx, lcy, lw, lh = last_geom
                        dx = (cx - lcx) / max(1.0, W)
                        dy = (cy - lcy) / max(1.0, H)
                        ds = math.log(w / max(1.0, lw))
                        dh = math.log(h / max(1.0, lh))
                        dw = math.log((w / h) / max(1e-6, lw / lh))
                    feat_rows.append([float(s), dx, dy, ds, dh, dw])
                keep = nms_xyxy(boxes, cos, iou_thr=C["nms_final_iou"])
                boxes = [boxes[i] for i in keep]
                sims = [float(cos[i]) for i in keep]
                feats = [feat_rows[i] for i in keep]
                if feats:
                    seq_buf.append(feats[0])
                    if len(seq_buf) > C["temporal_T"]:
                        del seq_buf[0]
                    # Pure inference: skip autograd bookkeeping.
                    with torch.no_grad():
                        seq = torch.tensor(seq_buf, dtype=torch.float32, device=device)
                        seq = seq.unsqueeze(0).repeat(len(feats), 1, 1)
                        s_temp = temp_head(seq).squeeze(-1).detach().cpu().numpy()
                        feat_t = torch.tensor(feats, dtype=torch.float32, device=device)
                        s_iou = iou_head(feat_t).squeeze(-1).detach().cpu().numpy()
                    fused = (C["alpha_fuse"] * np.array(sims) + (1 - C["alpha_fuse"]) * s_temp) * 0.5 + 0.5 * s_iou
                else:
                    fused = []

                # ---- Tracker debug (every 20th kept frame) ----
                if args.debug and (fidx % (20 * safe_stride) == 0):
                    # -1.0 marks "no candidates in this frame".
                    top_score = float(np.max(fused)) if len(fused) > 0 else -1.0
                    state = getattr(tracker, "state", "NA")
                    curr_len = getattr(tracker, "curr_len", 0)
                    print(
                        f"[TRK] vid={vid} f={fidx} top_s={top_score:.3f} "
                        f"state={state} tauH={C['tau_high']:.2f} tauL={C['tau_low']:.2f} "
                        f"len={curr_len}"
                    )
                # -----------------------------------------------
                tracker.update(fidx, boxes, fused)
                if boxes:
                    # Best fused candidate becomes the reference geometry for
                    # the next scored frame's delta features.
                    bi = int(np.argmax(fused))
                    x1, y1, x2, y2 = boxes[bi]
                    last_geom = ((x1 + x2) / 2, (y1 + y2) / 2, max(1.0, x2 - x1), max(1.0, y2 - y1))
                if args.debug and (fidx % (5 * safe_stride) == 1):
                    print(f"[DBG] {vid} f={fidx} props={len(boxes)} t_prop={t_prop:.3f}s t_emb={t_emb:.3f}s")
        finally:
            # Release the decoder even when scoring or tracking raises.
            cap.release()
        segs = segmentize(tracker.detections, max_gap=stride)
        preds.append({"video_id": vid, "detections": segs})

        # Incremental save: keep entries for other videos, replace this one.
        cur = [e for e in _load_existing_predictions(out_path) if e.get("video_id") != vid]
        cur.append({"video_id": vid, "detections": segs})
        _write_json_atomic(out_path, cur)
        print(f"[DONE] {vid}: {sum(len(s['bboxes']) for s in segs)} boxes, {len(segs)} segments")

    # Final file contains exactly this run's predictions, written atomically
    # like the incremental saves.
    _write_json_atomic(out_path, preds)
    print(f"[SAVE] {args.out}")

# Script entry point: run test-set inference only when executed directly.
if __name__ == "__main__":
    main()
