In [None]:
import sys
import gc
import json
import warnings
import itertools
import numpy as np
import pandas as pd
import xgboost as xgb
import catboost as cb
import lightgbm as lgb
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Any, Optional
from scipy.ndimage import gaussian_filter1d

# --- Configuration & paths ---
# Silence pandas PerformanceWarning / UserWarning and numpy divide / invalid-value
# warnings for the whole notebook.
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)
warnings.filterwarnings('ignore', category=UserWarning)
np.seterr(invalid="ignore", divide="ignore")

# Competition data (read-only) and notebook output locations.
INPUT_DIR = Path("/kaggle/input/MABe-mouse-behavior-detection")
TRAIN_TRACKING_DIR = INPUT_DIR / "train_tracking"
TEST_TRACKING_DIR = INPUT_DIR / "test_tracking"
WORKING_DIR = Path("/kaggle/working")
WORKING_DIR.mkdir(exist_ok=True)

# Behaviour vocabularies: single-mouse actions vs. agent -> target actions.
SELF_BEHAVIORS = [
    "biteobject", "climb", "dig", "exploreobject", "freeze", "genitalgroom",
    "huddle", "rear", "rest", "run", "selfgroom",
]
PAIR_BEHAVIORS = [
    "allogroom", "approach", "attack", "attemptmount", "avoid", "chase",
    "chaseattack", "defend", "disengage", "dominance", "dominancegroom",
    "dominancemount", "ejaculate", "escape", "flinch", "follow", "intromit",
    "mount", "reciprocalsniff", "shepherd", "sniff", "sniffbody", "sniffface",
    "sniffgenital", "submit", "tussle",
]

# --- Lab -> trained-model directory mapping ---
# Four labs use the results-ensemble-optuna model sets (variants 0..3);
# every remaining lab uses DEFAULT_RESULT_DIR (results-xgb-fe).
LAB_RESULT_DIR_MAP = {
    "AdaptableSnail": Path("/kaggle/input/results-ensemble-optuna"),     # 0
    "BoisterousParrot": Path("/kaggle/input/results-ensemble-optuna1"),  # 1
    "NiftyGoldfinch": Path("/kaggle/input/results-ensemble-optuna2"),    # 2
    "TranquilPanther": Path("/kaggle/input/results-ensemble-optuna3"),   # 3
}
DEFAULT_RESULT_DIR = Path("/kaggle/input/results-xgb-fe")

# Labs processed by this notebook: the four mapped labs (in map order),
# followed by the labs served by DEFAULT_RESULT_DIR.
TARGET_LABS = [
    *LAB_RESULT_DIR_MAP,
    "ElegantMink", "GroovyShrew", "JovialSwallow", "PleasantMeerkat", "SparklingTapir",
]

# =============================================================================
# 1. CORE: FEATURE EXTRACTOR (FULL VERSION)
# =============================================================================
@dataclass
class AgentContext:
    """Trajectory bundle for one mouse, consumed by the FeatureExtractor features.

    Every array/series is aligned on the frame axis ``idx`` (F frames).
    Fields are positional: the constructor is called with them in this order.
    """

    idx: pd.Index                          # frame index shared by all fields
    pos: np.ndarray                        # [F, 2] position, cm
    vel: np.ndarray                        # [F, 2] velocity, cm/s
    speed: np.ndarray                      # [F, 1] speed magnitude, cm/s
    acc: np.ndarray                        # [F, 2] acceleration, cm/s^2
    cx: pd.Series                          # x coordinate as a Series over idx
    cy: pd.Series                          # y coordinate as a Series over idx
    speed_series: pd.Series                # speed as a Series over idx
    raw_df: Optional[pd.DataFrame] = None  # raw tracking columns (level 0 = bodypart), if any

@dataclass
class FeatureConfig:
    """Per-video settings read by every FeatureExtractor method via ``self.cfg``."""

    fps: float                           # video frame rate (frames per second)
    pix_per_cm: float                    # pixels per centimetre (pixel -> cm conversion)
    smooth_sigma: Optional[float] = 1.0  # Gaussian sigma in frames; None disables smoothing
    use_pairwise: bool = True            # build a target context when the target is tracked


class FeatureExtractor:
    """Turn one video's pose tracking into per-frame agent(/target) feature tables.

    Positions are converted to centimetres, optionally Gaussian-smoothed, and
    finite-differenced into velocity and acceleration. Each entry of
    ``feature_registry`` maps an agent context (plus an optional target context)
    to a dict of named per-frame columns; ``extract_agent_target`` runs them all
    and assembles one DataFrame indexed by frame.
    """

    # Arena width/height (cm) assumed by the wall features in ``_feat_climb``.
    # This is a fixed assumption, not measured from the video; override on a
    # subclass or instance if a lab's arena differs.
    CLIMB_ARENA_CM = (33.0, 19.0)

    def __init__(self, fps: float, pix_per_cm: float, smooth_sigma: float = 1.0, use_pairwise: bool = True):
        self.cfg = FeatureConfig(float(fps), float(pix_per_cm), smooth_sigma, use_pairwise)
        # Feature groups that raised during the most recent extract_agent_target
        # call, as {group name: "ExceptionType: message"}.
        self.feature_errors: Dict[str, str] = {}

        # Feature groups run, in this order, by extract_agent_target.
        self.feature_registry = {
            "kinematics": self._feat_basic_kinematics,
            "multiscale": self._feat_multiscale,
            "long_range": self._feat_long_range,
            "cumulative": self._feat_cumulative,
            "curvature": self._feat_curvature,
            "speed_asym": self._feat_speed_asym,
            "gauss_shift": self._feat_gauss_shift,
            "avoid": self._feat_avoidance_trajectory,
            "pose": self._feat_pose_shape,
            "pairwise": self._feat_pairwise,
            "follow": self._feat_follow_pattern,
            "short": self._feat_shortburst_social,   # short-burst social cues (attack/chase)
            "climb": self._feat_climb,               # wall features (ElegantMink/GroovyShrew)
            "ejac": self._feat_ejaculate_temporal,
            "submit": self._feat_submission_temporal,
            "atk_sniff": self._feat_attack_sniff,    # speed variability, turn jerk, bbox IoU
        }

    # --- Helpers ---
    @staticmethod
    def _first_present(*candidates):
        """Return the first candidate that is not None, or None if all are.

        Use this instead of ``a or b`` for optional arrays. Truth-testing a
        multi-element ndarray raises ValueError, which used to make whole
        feature groups fail, and those failures were swallowed silently.
        NOTE(review): those groups now produce columns; confirm that inference
        aligns features to each model's expected feature names.
        """
        for cand in candidates:
            if cand is not None:
                return cand
        return None

    def _scale(self, n_frames_30fps: int) -> int:
        """Convert a window length defined at 30 fps to this video's fps (minimum 1 frame)."""
        return max(1, int(round(n_frames_30fps * self.cfg.fps / 30.0)))

    def _to_cm(self, arr):
        """Convert pixel coordinates to centimetres."""
        return arr / self.cfg.pix_per_cm

    def _smooth(self, x):
        """Gaussian-smooth along axis 0.

        This is a no-op when smoothing is disabled, when there are fewer than 3
        frames, or when the input is all NaN.
        """
        if self.cfg.smooth_sigma is None or x.shape[0] < 3:
            return x
        if np.all(np.isnan(x)):
            return x
        return gaussian_filter1d(x, sigma=self.cfg.smooth_sigma, axis=0, mode="nearest")

    def _forward_fill_nan(self, pos):
        """Fill NaN gaps along time: forward-fill, then back-fill the leading edge.

        An all-NaN input becomes zeros. A column that is entirely NaN while
        other columns are not stays NaN.
        """
        if np.all(np.isnan(pos)):
            return np.zeros_like(pos)
        return pd.DataFrame(pos).ffill().bfill().to_numpy()

    def _speed_series(self, cx: pd.Series, cy: pd.Series) -> pd.Series:
        """Frame-to-frame speed (units per second); the first frame is 0."""
        dx, dy = cx.diff(), cy.diff()
        return (np.hypot(dx, dy).fillna(0.0) * self.cfg.fps).astype("float32")

    def _roll_future_mean(self, s: pd.Series, w: int, min_p: int = 1) -> pd.Series:
        """Rolling mean over the current frame and the next w-1 frames."""
        return s.iloc[::-1].rolling(w, min_periods=min_p).mean().iloc[::-1]

    def _roll_future_var(self, s: pd.Series, w: int, min_p: int = 2) -> pd.Series:
        """Rolling variance over the current frame and the next w-1 frames."""
        return s.iloc[::-1].rolling(w, min_periods=min_p).var().iloc[::-1]

    def _extract_part(self, ctx: "AgentContext", part: str) -> Optional[np.ndarray]:
        """Return an [F, 2] array (cm, gap-filled, smoothed) for one body part.

        Returns None if there is no raw tracking or the part is not tracked.
        """
        if ctx.raw_df is None:
            return None
        if part not in ctx.raw_df.columns.get_level_values(0):
            return None
        try:
            sub_df = ctx.raw_df.xs(part, axis=1, level=0)[["x", "y"]].reindex(ctx.idx)
        except KeyError:
            return None

        raw = self._forward_fill_nan(sub_df.to_numpy())
        cm = self._to_cm(raw.astype(np.float32))
        return self._smooth(cm)

    def _extract_parts_dict(self, ctx: "AgentContext", parts: List[str]) -> Dict[str, Optional[np.ndarray]]:
        """Map each requested part name to its array, or None when it is missing."""
        return {p: self._extract_part(ctx, p) for p in parts}

    # --- Core Logic ---
    def _compute_kinematics(self, pos_px: np.ndarray):
        """Return (pos_cm, vel, speed[F,1], acc) using backward differences; row 0 is zero."""
        pos_cm = self._smooth(self._to_cm(self._forward_fill_nan(pos_px).astype(np.float32)))
        dt = 1.0 / self.cfg.fps
        vel = np.zeros_like(pos_cm, dtype=np.float32)
        vel[1:] = (pos_cm[1:] - pos_cm[:-1]) / dt
        speed = np.linalg.norm(vel, axis=1, keepdims=True).astype(np.float32)
        acc = np.zeros_like(pos_cm, dtype=np.float32)
        acc[1:] = (vel[1:] - vel[:-1]) / dt
        return pos_cm.astype(np.float32), vel, speed, acc

    def _build_context(self, frames, pos_px, mouse_df=None) -> "AgentContext":
        """Build an AgentContext from per-frame pixel positions and optional raw tracking."""
        p, v, s, a = self._compute_kinematics(pos_px)
        idx = pd.Index(frames, name="frame")
        return AgentContext(idx, p, v, s, a,
                            pd.Series(p[:, 0], index=idx), pd.Series(p[:, 1], index=idx),
                            pd.Series(s[:, 0], index=idx), mouse_df)

    # =========================================================================
    # FULL FEATURES IMPLEMENTATION
    # =========================================================================

    def _feat_basic_kinematics(self, ctx: "AgentContext", **kwargs) -> Dict:
        """Raw position, velocity, speed and acceleration of the agent."""
        return {
            "a_x": ctx.pos[:, 0], "a_y": ctx.pos[:, 1],
            "a_vx": ctx.vel[:, 0], "a_vy": ctx.vel[:, 1],
            "a_speed": ctx.speed[:, 0],
            "a_ax": ctx.acc[:, 0], "a_ay": ctx.acc[:, 1],
        }

    def _feat_multiscale(self, ctx: "AgentContext", **kwargs) -> Dict:
        """Centered rolling mean/std of speed at 10/40/160 frames (30 fps scale)."""
        feats = {}
        for scale in [10, 40, 160]:
            ws = self._scale(scale)
            if len(ctx.speed_series) >= ws:
                r = ctx.speed_series.rolling(ws, min_periods=max(1, ws // 4), center=True)
                feats[f"sp_m{scale}"] = r.mean().astype("float32")
                feats[f"sp_s{scale}"] = r.std().astype("float32")
        if "sp_m10" in feats and "sp_m160" in feats:
            feats["sp_ratio"] = feats["sp_m10"] / (feats["sp_m160"] + 1e-6)
        return feats

    def _feat_long_range(self, ctx: "AgentContext", **kwargs) -> Dict:
        """Long centered position means and speed percentile ranks."""
        feats = {}
        for w in [120, 240]:
            ws = self._scale(w)
            feats[f"x_ml{w}"] = ctx.cx.rolling(ws, min_periods=max(5, ws // 6), center=True).mean()
            feats[f"y_ml{w}"] = ctx.cy.rolling(ws, min_periods=max(5, ws // 6), center=True).mean()
        for w in [60, 120]:
            ws = self._scale(w)
            feats[f"sp_pct{w}"] = ctx.speed_series.rolling(ws, min_periods=max(5, ws // 6), center=True).rank(pct=True)
        return feats

    def _feat_cumulative(self, ctx: "AgentContext", **kwargs) -> Dict:
        """Path length summed over a centered window of 2*scale(180)+1 frames."""
        half_w = self._scale(180)
        step = np.hypot(ctx.cx.diff(), ctx.cy.diff()).fillna(0.0)
        path = step.rolling(2 * half_w + 1, min_periods=max(5, half_w // 6), center=True).sum()
        return {"path_cum180": path.fillna(0.0).astype("float32")}

    def _feat_curvature(self, ctx: "AgentContext", **kwargs) -> Dict:
        """Rolling mean absolute path curvature and a rolling turn-rate sum."""
        vx, vy = ctx.vel[:, 0], ctx.vel[:, 1]
        ax, ay = ctx.acc[:, 0], ctx.acc[:, 1]
        v_mag = np.maximum(np.hypot(vx, vy), 0.1 / self.cfg.fps)
        curv = (vx * ay - vy * ax) / (v_mag ** 3)
        # Ignore curvature at low speed (noisy), then clip outliers.
        curv = np.clip(curv * (v_mag > 2.0), -2.0, 2.0)
        s_curv = pd.Series(np.abs(curv), index=ctx.idx)
        feats = {}
        for w in [30, 60]:
            feats[f"curv_mean_{w}"] = s_curv.rolling(self._scale(w), min_periods=1).mean()

        # Turn rate: wrapped heading change, counted only while moving.
        angle = np.arctan2(vy, vx)
        angle_change = np.abs(pd.Series(angle, index=ctx.idx).diff().fillna(0.0))
        angle_change = np.where(angle_change > np.pi, 2 * np.pi - angle_change, angle_change)
        ws = self._scale(30)
        feats["turn_rate_30"] = pd.Series(angle_change * (v_mag > 0.5), index=ctx.idx).rolling(ws, min_periods=1).sum()
        return feats

    def _feat_speed_asym(self, ctx: "AgentContext", **kwargs) -> Dict:
        """Future-window mean speed minus past-window mean speed (about 1 s windows)."""
        w = max(3, self._scale(30))
        v = ctx.speed_series
        return {"spd_asym_1s": (self._roll_future_mean(v, w) - v.rolling(w, min_periods=1).mean()).fillna(0.0)}

    def _feat_gauss_shift(self, ctx: "AgentContext", **kwargs) -> Dict:
        """Symmetric (Jeffreys) KL divergence between Gaussians fit to past and future speed."""
        w = max(5, self._scale(30))
        v = ctx.speed_series
        mu_p, va_p = v.rolling(w, min_periods=1).mean(), v.rolling(w, min_periods=1).var().clip(1e-6)
        mu_f, va_f = self._roll_future_mean(v, w), self._roll_future_var(v, w).clip(1e-6)
        kl = 0.5 * ((va_p / va_f) + ((mu_f - mu_p) ** 2) / va_f - 1 + np.log(va_f / va_p)
                    + (va_f / va_p) + ((mu_p - mu_f) ** 2) / va_p - 1 + np.log(va_p / va_f))
        return {"spd_symkl_1s": kl.replace([np.inf, -np.inf], np.nan).fillna(0.0)}

    def _feat_avoidance_trajectory(self, ctx: "AgentContext", target_ctx: Optional["AgentContext"] = None, **kwargs) -> Dict:
        """Heading relative to the target, and future change in distance to the target."""
        if target_ctx is None:
            return {}
        rel_vec = target_ctx.pos - ctx.pos
        ang_target = np.arctan2(rel_vec[:, 1], rel_vec[:, 0])
        ang_self = np.arctan2(ctx.vel[:, 1], ctx.vel[:, 0])
        diff = np.abs(ang_target - ang_self)
        diff = np.minimum(diff, 2 * np.pi - diff)

        feats = {
            "heading_rel_cos": pd.Series(np.cos(diff), index=ctx.idx, dtype="float32"),
            "heading_rel_abs": pd.Series(diff, index=ctx.idx, dtype="float32"),
        }
        s_dist = pd.Series(np.linalg.norm(rel_vec, axis=1), index=ctx.idx)
        for w in [15, 30]:
            ws = self._scale(w)
            feats[f"dist_gain_{w}f"] = (s_dist.shift(-ws) - s_dist).fillna(0.0).astype("float32")
        return feats

    def _feat_pose_shape(self, ctx: "AgentContext", **kwargs) -> Dict:
        """Intra-body distances, body-bend cosine, and smoothed per-part speeds."""
        feats = {}

        def zero():
            return pd.Series(0.0, index=ctx.idx, dtype="float32")

        target_parts = ["nose", "neck", "body_center", "tail_base", "ear_left", "ear_right",
                        "lateral_left", "lateral_right", "hip_left", "hip_right", "head"]
        parts = self._extract_parts_dict(ctx, target_parts)

        def dist(k1, k2):
            p1, p2 = parts.get(k1), parts.get(k2)
            if p1 is None or p2 is None:
                return zero()
            return pd.Series(np.linalg.norm(p1 - p2, axis=1), index=ctx.idx, dtype="float32")

        def vel(k, n_frames):
            p_pos = parts.get(k)
            if p_pos is None:
                return zero()
            s_x, s_y = pd.Series(p_pos[:, 0], index=ctx.idx), pd.Series(p_pos[:, 1], index=ctx.idx)
            raw = self._speed_series(s_x, s_y)
            return raw.rolling(self._scale(n_frames), min_periods=1, center=True).mean().astype("float32")

        # Distances
        feats["aa_nose_tailbase_dist"] = dist("nose", "tail_base")
        feats["aa_nose_bodycenter_dist"] = dist("nose", "body_center")
        feats["a_body_width"] = dist("lateral_left", "lateral_right")
        if feats["a_body_width"].sum() == 0:  # laterals missing -> fall back to hips
            feats["a_body_width"] = dist("hip_left", "hip_right")

        # Cosine of the nose-body_center-tail_base angle
        if parts.get("nose") is not None and parts.get("tail_base") is not None and parts.get("body_center") is not None:
            v1 = parts["nose"] - parts["body_center"]
            v2 = parts["tail_base"] - parts["body_center"]
            dot = np.sum(v1 * v2, axis=1)
            mag = np.linalg.norm(v1, axis=1) * np.linalg.norm(v2, axis=1)
            feats["a_body_angle"] = pd.Series(np.clip(dot / (mag + 1e-6), -1.0, 1.0), index=ctx.idx, dtype="float32")
        else:
            feats["a_body_angle"] = zero()

        # Part velocities over 500 ms, 1 s, 2 s, 3 s
        for p in ["nose", "tail_base", "ear_right", "head"]:
            if parts.get(p) is not None:
                for t in [15, 30, 60, 90]:
                    feats[f"a_{p}_vel_{int(t * 1000 / 30)}ms"] = vel(p, t)

        return feats

    def _feat_pairwise(self, ctx: "AgentContext", target_ctx: Optional["AgentContext"] = None, **kwargs) -> Dict:
        """Agent-target distances, approach/lateral speeds, body and gaze cosines."""
        if target_ctx is None:
            return {}
        feats = {}
        idx = ctx.idx

        def zero():
            return pd.Series(0.0, index=idx, dtype="float32")

        rel_vec = target_ctx.pos - ctx.pos
        dist = np.linalg.norm(rel_vec, axis=1)
        feats["rel_dist"] = pd.Series(dist, index=idx, dtype="float32")

        # Detailed distances from the agent's nose (or head) to target parts
        my_parts = self._extract_parts_dict(ctx, ["nose", "head", "tail_base"])
        t_parts = self._extract_parts_dict(target_ctx, ["nose", "head", "tail_base", "body_center", "ear_left", "ear_right",
                                                        "lateral_left", "lateral_right", "neck", "hip_left", "hip_right"])

        def dist_ab(pa, pb):
            if pa is None or pb is None:
                return zero()
            return pd.Series(np.linalg.norm(pa - pb, axis=1), index=idx, dtype="float32")

        mn = self._first_present(my_parts.get("nose"), my_parts.get("head"))

        feats["dist_nose_nose"] = dist_ab(mn, t_parts.get("nose"))
        feats["dist_nose_tail"] = dist_ab(mn, t_parts.get("tail_base"))
        feats["dist_nose_body"] = dist_ab(mn, t_parts.get("body_center"))
        feats["dist_nose_el"] = dist_ab(mn, t_parts.get("ear_left"))
        feats["dist_nose_er"] = dist_ab(mn, t_parts.get("ear_right"))
        feats["dist_nose_neck"] = dist_ab(mn, t_parts.get("neck"))
        feats["dist_nose_hip_l"] = dist_ab(mn, t_parts.get("hip_left"))
        feats["dist_nose_hip_r"] = dist_ab(mn, t_parts.get("hip_right"))

        # Velocity components along the agent -> target unit vector
        u_vec = rel_vec / (np.where(dist == 0, 1e-6, dist)[:, None])
        a_along = np.sum(ctx.vel * u_vec, axis=1)
        t_along = np.sum(target_ctx.vel * (-u_vec), axis=1)

        feats["approach_speed_agent"] = pd.Series(a_along, index=idx, dtype="float32")
        feats["approach_speed_target"] = pd.Series(t_along, index=idx, dtype="float32")
        feats["approach_speed_rel"] = pd.Series(np.sum((ctx.vel - target_ctx.vel) * u_vec, axis=1), index=idx, dtype="float32")
        feats["lateral_speed_agent"] = pd.Series(np.linalg.norm(ctx.vel - (a_along[:, None] * u_vec), axis=1), index=idx, dtype="float32")

        # Body axis (head-ish minus tail-ish) cosines
        def get_vec(parts_dict):
            h = self._first_present(parts_dict.get("nose"), parts_dict.get("head"))
            t = self._first_present(parts_dict.get("tail_base"), parts_dict.get("body_center"))
            if h is not None and t is not None:
                return h - t
            return None

        va, vt = get_vec(my_parts), get_vec(t_parts)
        if va is not None and vt is not None:
            dot = np.sum(va * vt, axis=1)
            mag = np.linalg.norm(va, axis=1) * np.linalg.norm(vt, axis=1)
            feats["body_cosine"] = pd.Series(np.clip(dot / (mag + 1e-6), -1, 1), index=idx, dtype="float32")
        else:
            feats["body_cosine"] = zero()

        if va is not None:
            dot_g = np.sum(va * rel_vec, axis=1)
            mag_g = np.linalg.norm(va, axis=1) * dist
            feats["gaze_cosine"] = pd.Series(np.clip(dot_g / (mag_g + 1e-6), -1, 1), index=idx, dtype="float32")
        else:
            feats["gaze_cosine"] = zero()

        return feats

    def _feat_shortburst_social(self, ctx: "AgentContext", target_ctx: Optional["AgentContext"] = None, **kwargs) -> Dict:
        """Short-window (10/20/30 frames) attack, chase and escape statistics."""
        if target_ctx is None:
            return {}
        feats = {}
        idx = ctx.idx

        rel_vec = target_ctx.pos - ctx.pos
        dist = np.linalg.norm(rel_vec, axis=1)
        u_vec = rel_vec / (np.where(dist == 0, 1e-6, dist)[:, None])
        a_along = np.sum(ctx.vel * u_vec, axis=1)
        rel_along = np.sum((ctx.vel - target_ctx.vel) * u_vec, axis=1)

        # Cosine between the agent's body axis and the direction to the target
        parts = self._extract_parts_dict(ctx, ["nose", "tail_base", "body_center"])
        head = parts.get("nose")
        tail = self._first_present(parts.get("tail_base"), parts.get("body_center"))
        heading_cos = np.zeros(len(idx), dtype="float32")
        if head is not None and tail is not None:
            body = head - tail
            dot = np.sum(body * rel_vec, axis=1)
            mag = np.linalg.norm(body, axis=1) * dist
            heading_cos = np.clip(dot / (mag + 1e-6), -1, 1)

        s_along = pd.Series(a_along, index=idx)
        s_rel = pd.Series(rel_along, index=idx)
        s_dist = pd.Series(dist, index=idx)
        s_head = pd.Series(heading_cos, index=idx)

        for w30 in [10, 20, 30]:
            ws = self._scale(w30)
            min_p = max(1, ws // 3)
            # Attack
            feats[f"sb_att_approach_mean_{w30}"] = s_along.rolling(ws, min_periods=min_p).mean()
            feats[f"sb_att_rel_along_mean_{w30}"] = s_rel.rolling(ws, min_periods=min_p).mean()
            feats[f"sb_att_dist_delta_{w30}"] = (s_dist - s_dist.shift(ws)).fillna(0.0)
            # Chase
            feats[f"sb_chase_speed_agent_mean_{w30}"] = ctx.speed_series.rolling(ws, min_periods=min_p).mean()
            feats[f"sb_chase_speed_target_mean_{w30}"] = target_ctx.speed_series.rolling(ws, min_periods=min_p).mean()
            feats[f"sb_chase_dist_mean_{w30}"] = s_dist.rolling(ws, min_periods=min_p).mean()
            # Escape
            feats[f"sb_esc_heading_cos_mean_{w30}"] = s_head.rolling(ws, min_periods=min_p).mean()
            feats[f"sb_esc_dist_gain_{w30}"] = (s_dist.shift(-ws) - s_dist).fillna(0.0)

        return {k: v.astype("float32").fillna(0.0) for k, v in feats.items()}

    def _feat_follow_pattern(self, ctx: "AgentContext", target_ctx: Optional["AgentContext"] = None, **kwargs) -> Dict:
        """Rolling distance, body/velocity alignment and speeds, for detecting following."""
        if target_ctx is None:
            return {}
        feats = {}
        idx = ctx.idx
        dist = pd.Series(np.linalg.norm(target_ctx.pos - ctx.pos, axis=1), index=idx)

        # Cosines
        parts_a = self._extract_parts_dict(ctx, ["nose", "tail_base", "body_center"])
        parts_t = self._extract_parts_dict(target_ctx, ["nose", "tail_base", "body_center"])

        def bvec(p):
            h = p.get("nose")
            t = self._first_present(p.get("tail_base"), p.get("body_center"))
            return (h - t) if (h is not None and t is not None) else None

        va, vt = bvec(parts_a), bvec(parts_t)
        cos_body = np.zeros(len(idx), dtype="float32")
        if va is not None and vt is not None:
            cos_body = np.clip(np.sum(va * vt, axis=1) / (np.linalg.norm(va, axis=1) * np.linalg.norm(vt, axis=1) + 1e-6), -1, 1)

        dot_v = np.sum(ctx.vel * target_ctx.vel, axis=1)
        mag_v = ctx.speed[:, 0] * target_ctx.speed[:, 0]
        cos_vel = np.zeros_like(dot_v)
        mask = mag_v > 1e-3  # velocity direction is undefined when either mouse is still
        cos_vel[mask] = np.clip(dot_v[mask] / mag_v[mask], -1, 1)

        s_cb = pd.Series(cos_body, index=idx)
        s_cv = pd.Series(cos_vel, index=idx)

        for w30 in [15, 30, 60]:
            ws = self._scale(w30)
            min_p = max(1, ws // 3)
            feats[f"follow_dist_mean_{w30}"] = dist.rolling(ws, min_periods=min_p).mean()
            feats[f"follow_dist_std_{w30}"] = dist.rolling(ws, min_periods=min_p).std()
            feats[f"follow_cos_body_mean_{w30}"] = s_cb.rolling(ws, min_periods=min_p).mean()
            feats[f"follow_cos_vel_mean_{w30}"] = s_cv.rolling(ws, min_periods=min_p).mean()
            feats[f"follow_speed_agent_mean_{w30}"] = ctx.speed_series.rolling(ws, min_periods=min_p).mean()
            feats[f"follow_speed_target_mean_{w30}"] = target_ctx.speed_series.rolling(ws, min_periods=min_p).mean()

        return {k: v.fillna(0.0).astype("float32") for k, v in feats.items()}

    def _feat_climb(self, ctx: "AgentContext", **kwargs) -> Dict:
        """Compute wall-proximity and climbing cues for a rectangular arena.

        The arena size comes from ``CLIMB_ARENA_CM``, with walls at x=0, x=W,
        y=0 and y=H in cm. These features are intended for ElegantMink
        (rectangular arena) and GroovyShrew.
        """
        W, H = self.CLIMB_ARENA_CM

        feats = {}
        idx = ctx.idx
        parts = self._extract_parts_dict(ctx, ["nose", "head"])
        head = self._first_present(parts.get("nose"), parts.get("head"), ctx.pos)

        cx, cy = head[:, 0], head[:, 1]

        d_l, d_r = cx, W - cx
        d_b, d_t = cy, H - cy
        d_all = np.stack([d_l, d_r, d_b, d_t], axis=1)
        dist_wall = np.min(d_all, axis=1)
        wall_idx = np.argmin(d_all, axis=1)  # 0:L, 1:R, 2:B, 3:T

        s_dw = pd.Series(dist_wall, index=idx, dtype="float32")
        feats["climb_dist_wall"] = s_dw

        # Inward normal of the nearest wall
        nx, ny = np.zeros_like(cx), np.zeros_like(cy)
        nx[wall_idx == 0] = 1
        nx[wall_idx == 1] = -1
        ny[wall_idx == 2] = 1
        ny[wall_idx == 3] = -1

        vx, vy = ctx.vel[:, 0], ctx.vel[:, 1]
        v_norm = vx * nx + vy * ny
        v_tan = np.sqrt((vx - v_norm * nx) ** 2 + (vy - v_norm * ny) ** 2)

        feats["climb_normal_vel"] = pd.Series(v_norm, index=idx, dtype="float32")
        feats["climb_tangent_vel"] = pd.Series(v_tan, index=idx, dtype="float32")

        # Approach speed toward the wall, and the "stick" score
        ws = self._scale(15)
        feats["climb_approach_speed_wall"] = -s_dw.diff().fillna(0.0).rolling(ws, min_periods=1).mean()

        near = (dist_wall < 3.0).astype(float)
        stick = near * (1.0 / (1.0 + np.abs(v_norm))) * (v_tan > 0.5)
        feats["climb_wall_stick_score"] = pd.Series(stick, index=idx, dtype="float32")

        return feats

    def _feat_ejaculate_temporal(self, ctx: "AgentContext", target_ctx: Optional["AgentContext"] = None, **kwargs) -> Dict:
        """Body and genital proximity, a recent close-contact activity memory, and a stillness score."""
        if target_ctx is None:
            return {}
        idx = ctx.idx
        parts_a = self._extract_parts_dict(ctx, ["nose", "body_center", "tail_base"])
        parts_t = self._extract_parts_dict(target_ctx, ["body_center", "tail_base"])

        def dist(p1, p2):
            if p1 is None or p2 is None:
                return pd.Series(0.0, index=idx)
            return pd.Series(np.linalg.norm(p1 - p2, axis=1), index=idx)

        abc = self._first_present(parts_a.get("body_center"), parts_a.get("tail_base"), ctx.pos)
        tbc = self._first_present(parts_t.get("body_center"), parts_t.get("tail_base"), target_ctx.pos)
        t_tail = self._first_present(parts_t.get("tail_base"), tbc)

        d_body = dist(abc, tbc)
        d_gen = dist(abc, t_tail)
        d_nose_gen = dist(parts_a.get("nose"), t_tail)

        feats = {
            "ejac_dist_body": d_body,
            "ejac_dist_gen_body": d_gen,
            "ejac_dist_gen_nose": d_nose_gen,
        }

        # Proximity kernel
        prox = np.exp(-d_body.to_numpy() / 5.0) * (1.0 / (1.0 + d_gen.to_numpy()))

        # Memory: peak speed while in close contact over the last ~3 s
        v = ctx.speed_series
        close = (d_body < 5.0).astype(float)
        mem = (v * close).rolling(self._scale(90), min_periods=1).max().fillna(0.0)
        feats["ejac_activity_memory_3s"] = mem

        is_still = (v < 1.5).astype(float)
        feats["ejac_is_still"] = pd.Series(is_still, index=idx)
        feats["ejac_static_score"] = pd.Series(is_still * prox * mem, index=idx)

        return {k: s.astype("float32") for k, s in feats.items()}

    def _feat_submission_temporal(self, ctx: "AgentContext", target_ctx: Optional["AgentContext"] = None, **kwargs) -> Dict:
        """Recent-threat memory, and a still-and-compact (submission-like) score."""
        if target_ctx is None:
            return {}
        idx = ctx.idx
        rel = target_ctx.pos - ctx.pos
        dist = np.linalg.norm(rel, axis=1)
        dist_s = pd.Series(dist, index=idx).replace(0, 1e-6)

        # Target velocity toward the agent, counted only within 15 cm
        dot = np.sum(target_ctx.vel * (-rel), axis=1)
        threat = (dot / dist_s).clip(lower=0) * (dist_s < 15.0).astype(float)

        mem = threat.rolling(self._scale(90), min_periods=1).max().fillna(0.0)

        is_still = (ctx.speed_series < 1.0).astype(float)

        parts = self._extract_parts_dict(ctx, ["nose", "tail_base"])
        compact = pd.Series(0.0, index=idx)
        if parts.get("nose") is not None and parts.get("tail_base") is not None:
            slen = np.linalg.norm(parts["nose"] - parts["tail_base"], axis=1)
            compact = pd.Series((slen < 8.0).astype(float), index=idx)

        return {
            "fear_memory_3s": mem.astype("float32"),
            "static_submit_prob": (is_still * compact * mem).astype("float32"),
        }

    def _feat_attack_sniff(self, ctx: "AgentContext", target_ctx: Optional["AgentContext"] = None, **kwargs) -> Dict:
        """Combined speed variability, agent turn jerk, and body bounding-box IoU."""
        if target_ctx is None:
            return {}
        idx = ctx.idx
        feats = {}

        # Violence: summed short-window speed std of both mice
        ws = self._scale(15)
        mp = max(1, ws // 3)
        va = ctx.speed_series.rolling(ws, min_periods=mp).std().fillna(0.0)
        vt = target_ctx.speed_series.rolling(ws, min_periods=mp).std().fillna(0.0)
        feats["as_speed_std_sum_05"] = (va + vt).astype("float32")

        # Jerk (turn): rolling sum of wrapped heading changes
        ang = np.arctan2(ctx.vel[:, 1], ctx.vel[:, 0])
        diff = np.abs(np.diff(ang, prepend=ang[0]))
        diff = np.where(diff > np.pi, 2 * np.pi - diff, diff)
        feats["as_a_turn_jerk_05"] = pd.Series(diff, index=idx).rolling(ws, min_periods=mp).sum().fillna(0.0).astype("float32")

        # Bounding-box IoU (important for attack)
        def get_bbox(c, p_list):
            parts = self._extract_parts_dict(c, p_list)
            arrs = [parts[k] for k in p_list if parts.get(k) is not None]
            if not arrs:
                return None
            st = np.stack(arrs, axis=1)
            return np.stack([np.nanmin(st[..., 0], 1), np.nanmin(st[..., 1], 1),
                             np.nanmax(st[..., 0], 1), np.nanmax(st[..., 1], 1)], axis=1)

        plist = ["nose", "tail_base", "neck", "hip_left", "hip_right", "head"]
        b1 = get_bbox(ctx, plist)
        b2 = get_bbox(target_ctx, plist)

        iou = np.zeros(len(idx), dtype="float32")
        if b1 is not None and b2 is not None:
            xA = np.maximum(b1[:, 0], b2[:, 0])
            yA = np.maximum(b1[:, 1], b2[:, 1])
            xB = np.minimum(b1[:, 2], b2[:, 2])
            yB = np.minimum(b1[:, 3], b2[:, 3])
            inter = np.maximum(0, xB - xA) * np.maximum(0, yB - yA)
            a1 = (b1[:, 2] - b1[:, 0]) * (b1[:, 3] - b1[:, 1])
            a2 = (b2[:, 2] - b2[:, 0]) * (b2[:, 3] - b2[:, 1])
            iou = inter / (a1 + a2 - inter + 1e-6)

        feats["as_body_iou"] = pd.Series(iou, index=idx, dtype="float32")
        return feats

    def _feat_attack_defend(self, ctx: "AgentContext", target_ctx: Optional["AgentContext"] = None, **kwargs) -> Dict:
        """Alias of ``_feat_attack_sniff`` (not included in ``feature_registry``)."""
        return self._feat_attack_sniff(ctx, target_ctx)

    def build_pose_tensor(self, tracking: pd.DataFrame):
        """Pivot long-format tracking into per-mouse centroids and per-mouse raw frames.

        Returns ``(frames, mouse_ids, pos, per_mouse_df)``:
        - ``pos`` is [F, n_mice, 2] in pixels.
        - ``per_mouse_df[mid]`` has (bodypart, coord) columns.
        """
        tracking = tracking.sort_values("video_frame")
        frames = np.sort(tracking["video_frame"].unique())

        pvid = tracking.pivot(index="video_frame", columns=["mouse_id", "bodypart"], values=["x", "y"])
        # (coord, mouse, part) -> (mouse, part, coord)
        pvid = pvid.reorder_levels([1, 2, 0], axis=1).sort_index(axis=1).astype("float32")

        mouse_ids = list(pvid.columns.get_level_values(0).unique())
        pos = np.full((len(frames), len(mouse_ids), 2), np.nan, dtype=np.float32)
        per_mouse_df = {}

        for i, mid in enumerate(mouse_ids):
            single = pvid[mid]
            per_mouse_df[mid] = single

            # Prefer body_center; otherwise use the mean over all tracked parts.
            if "body_center" in single.columns.get_level_values(0):
                cx = single["body_center"]["x"]
                cy = single["body_center"]["y"]
            else:
                cx = single.xs("x", level=1, axis=1).mean(axis=1)
                cy = single.xs("y", level=1, axis=1).mean(axis=1)

            pos[:, i, 0] = cx.reindex(frames).values
            pos[:, i, 1] = cy.reindex(frames).values

        return frames, mouse_ids, pos, per_mouse_df

    def extract_agent_target(self, frames, mouse_ids, pos, agent_id, target_id, per_mouse_df=None):
        """Run every registered feature group for one (agent, target) pair.

        Returns a frame-indexed DataFrame with sorted columns and inf/NaN
        replaced by 0. Returns an empty DataFrame if ``agent_id`` is unknown.
        Groups that raise are skipped and recorded in ``self.feature_errors``.
        """
        try:
            aid_idx = mouse_ids.index(agent_id)
        except ValueError:
            return pd.DataFrame()

        ctx_agent = self._build_context(frames, pos[:, aid_idx, :], per_mouse_df.get(agent_id) if per_mouse_df else None)

        ctx_target = None
        if self.cfg.use_pairwise and target_id in mouse_ids:
            tid_idx = mouse_ids.index(target_id)
            ctx_target = self._build_context(frames, pos[:, tid_idx, :], per_mouse_df.get(target_id) if per_mouse_df else None)

        self.feature_errors = {}
        all_data = {}
        for name, func in self.feature_registry.items():
            # Best-effort: skip a group that cannot be computed (e.g. missing
            # body parts) so one group never aborts the whole video. Record why.
            try:
                all_data.update(func(ctx_agent, target_ctx=ctx_target))
            except Exception as exc:
                self.feature_errors[name] = f"{type(exc).__name__}: {exc}"

        df_out = pd.DataFrame(all_data, index=ctx_agent.idx)
        # Final cleanup: replace inf/NaN with 0 and sort columns.
        return df_out.replace([np.inf, -np.inf], np.nan).fillna(0.0).reindex(sorted(df_out.columns), axis=1)

# =============================================================================
# 2. INFERENCE LOOP
# =============================================================================

def _load_behavior_models(bhv_dir):
    """Load the ensemble settings and all per-fold models for one behavior.

    Args:
        bhv_dir: ``Path`` of a behavior folder that may contain
            ``ensemble_params.json`` and ``fold_*`` sub-folders.

    Returns:
        ``(models, weights, threshold)``.
        - ``models``: one dict per ``fold_*`` folder that holds at least one
          model file, keyed by ``"xgb"`` / ``"cat"`` / ``"lgb"``.
        - ``weights`` / ``threshold``: read from ``ensemble_params.json`` when
          it exists, otherwise ``{}`` / ``0.5``.
    """
    weights, threshold = {}, 0.5
    params_path = bhv_dir / "ensemble_params.json"
    if params_path.exists():
        with open(params_path) as f:
            params = json.load(f)
        weights = params.get("weights", {})
        threshold = params.get("threshold", 0.5)

    models = []
    for fold_dir in sorted(bhv_dir.glob("fold_*")):
        model_set = {}
        # XGBoost: "model_xgb.json", falling back to the older "model.json" name.
        for xgb_name in ("model_xgb.json", "model.json"):
            xgb_path = fold_dir / xgb_name
            if xgb_path.exists():
                booster = xgb.Booster()
                booster.load_model(str(xgb_path))
                model_set["xgb"] = booster
                break
        # CatBoost / LightGBM are optional per fold.
        cat_path = fold_dir / "model_cat.cbm"
        if cat_path.exists():
            cat_model = cb.CatBoostClassifier()
            cat_model.load_model(str(cat_path))
            model_set["cat"] = cat_model
        lgb_path = fold_dir / "model_lgb.txt"
        if lgb_path.exists():
            model_set["lgb"] = lgb.Booster(model_file=str(lgb_path))
        if model_set:
            models.append(model_set)
    return models, weights, threshold


def _ensemble_predict(models, weights, threshold, feat_df):
    """Blend the fold models' probabilities and threshold them into a 0/1 mask.

    Each fold's score is ``sum(weight_k * prob_k)`` over the model kinds present
    in that fold (default weights: xgb 1.0, cat 0.33, lgb 0.33). Fold scores
    are averaged, then compared with ``>= threshold``.

    Returns:
        ``np.ndarray`` of ``int8``, one entry per row of ``feat_df``.
    """
    # Column order must match training. Take it from the first fold's XGB
    # booster when available, otherwise use the feature frame's own columns.
    req_cols = models[0]["xgb"].feature_names if "xgb" in models[0] else None
    if not req_cols:
        req_cols = feat_df.columns.tolist()

    # Features the model expects but the frame lacks are filled with 0.0;
    # extra features are dropped.
    X = feat_df.reindex(columns=req_cols, fill_value=0.0)
    dtest = None
    if any("xgb" in m for m in models):
        dtest = xgb.DMatrix(X, feature_names=req_cols)

    w_xgb = weights.get("xgb", 1.0)
    w_cat = weights.get("cat", 0.33)
    w_lgb = weights.get("lgb", 0.33)
    # NOTE(review): weights are applied as-is and are not renormalised to sum
    # to 1. This presumably matches how `threshold` was tuned; confirm against
    # the training code.
    final_prob = np.zeros(len(X), dtype=np.float32)
    for m in models:
        p = 0
        if "xgb" in m:
            p += m["xgb"].predict(dtest) * w_xgb
        if "cat" in m:
            p += m["cat"].predict_proba(X)[:, 1] * w_cat
        if "lgb" in m:
            p += m["lgb"].predict(X) * w_lgb
        final_prob += p
    final_prob /= len(models)
    return (final_prob >= threshold).astype("int8")


def _mask_to_segments(mask, frames):
    """Convert a per-row 0/1 mask into ``(start_frame, stop_frame)`` runs.

    ``frames[i]`` is the frame label of mask row ``i``. ``stop_frame`` is
    exclusive: it is the first frame after the run, or ``frames[-1] + 1`` when
    the run reaches the last row.
    """
    segments = []
    start = None  # sentinel, since frame 0 is a valid start
    for i, v in enumerate(mask):
        if v and start is None:
            start = frames[i]
        elif not v and start is not None:
            segments.append((int(start), int(frames[i])))
            start = None
    if start is not None:
        segments.append((int(start), int(frames[-1]) + 1))
    return segments


def _format_mouse_id(mid):
    """Return ``mid`` as a ``"mouse<N>"`` string (unchanged if already prefixed)."""
    return str(mid) if str(mid).startswith("mouse") else f"mouse{mid}"


def predict_lab_video(lab_id, video_id, test_meta, result_dir):
    """Predict behavior segments for one test video of one lab.

    Args:
        lab_id: lab name. Selects ``TEST_TRACKING_DIR/<lab_id>`` and
            ``result_dir/<lab_id>``.
        video_id: id matching a ``video_id`` row of ``test_meta``.
        test_meta: test metadata frame with ``video_id``, ``frames_per_second``
            and ``pix_per_cm_approx`` columns.
        result_dir: root folder holding ``<lab_id>/<behavior>/fold_*`` models.

    Returns:
        List of ``[video_id, agent_id, target_id, action, start_frame,
        stop_frame]`` rows. ``target_id`` is ``"self"`` for self behaviors and
        ``stop_frame`` is exclusive. The list is empty when the lab has no
        models or the tracking file is missing or unreadable.
    """
    # 1. Check for trained models first so videos of model-less labs skip all I/O.
    lab_res_path = result_dir / lab_id
    if not lab_res_path.exists():
        return []

    # 2. Load tracking data.
    tr_path = TEST_TRACKING_DIR / str(lab_id) / f"{video_id}.parquet"
    if not tr_path.exists():
        return []
    try:
        tracking = pd.read_parquet(tr_path)
    except Exception:
        # Best effort: an unreadable file skips this video only.
        return []

    # 3. Set up the feature extractor from per-video calibration.
    row = test_meta[test_meta["video_id"] == video_id].iloc[0]
    fps = float(row["frames_per_second"])
    # Non-positive (or NaN) calibration falls back to 1 px/cm.
    pix = float(row["pix_per_cm_approx"]) if row["pix_per_cm_approx"] > 0 else 1.0
    fe = FeatureExtractor(fps=fps, pix_per_cm=pix)
    frames, mouse_ids, pos, per_mouse_df = fe.build_pose_tensor(tracking)

    # 4. Find behaviors that have a model folder, split into self / pair types.
    available_bhvs = [p.name for p in lab_res_path.iterdir() if p.is_dir()]
    self_b = [b for b in available_bhvs if b in SELF_BEHAVIORS]
    pair_b = [b for b in available_bhvs if b in PAIR_BEHAVIORS]

    # Load models once per behavior and reuse them for every mouse and pair,
    # rather than re-reading them from disk for each (agent, target) combination.
    model_cache = {}

    def run_inference(feat_df, behavior):
        if behavior not in model_cache:
            model_cache[behavior] = _load_behavior_models(lab_res_path / behavior)
        models, weights, threshold = model_cache[behavior]
        if not models:
            return None
        return _ensemble_predict(models, weights, threshold, feat_df)

    # Each job is (agent, target, target label written to the output, behaviors).
    # Self jobs come first, then every ordered pair of distinct mice.
    jobs = [(m, m, "self", self_b) for m in mouse_ids]
    jobs += [
        (a, t, _format_mouse_id(t), pair_b)
        for a, t in itertools.permutations(mouse_ids, 2)
    ]

    preds = []
    for agent, target, target_label, behaviors in jobs:
        if not behaviors:
            continue  # nothing to score, so skip feature extraction
        fdf = fe.extract_agent_target(
            frames, mouse_ids, pos, agent, target, per_mouse_df=per_mouse_df
        )
        if fdf.empty:
            continue
        agent_label = _format_mouse_id(agent)
        for bhv in behaviors:
            mask = run_inference(fdf, bhv)
            if mask is None:
                continue
            for start, stop in _mask_to_segments(mask, frames):
                preds.append([video_id, agent_label, target_label, bhv, start, stop])
    return preds

# =============================================================================
# 3. MAIN RUN
# =============================================================================
if __name__ == "__main__":
    print("=== STARTING INFERENCE (CORRECTED MAPPING) ===")

    # Remove leftovers from earlier runs. Non-CSV files and all sub-folders
    # are deleted; CSV files are kept.
    import shutil

    for entry in WORKING_DIR.iterdir():
        if entry.is_file() and entry.suffix != ".csv":
            entry.unlink()
        elif entry.is_dir():
            shutil.rmtree(entry)

    test_meta = pd.read_csv(INPUT_DIR / "test.csv")
    all_rows = []

    for lab in TARGET_LABS:
        print(f"\n>> Processing Lab: {lab}")
        # Labs listed in the map use their dedicated model folder; all others
        # fall back to the default folder.
        res_dir = LAB_RESULT_DIR_MAP.get(lab, DEFAULT_RESULT_DIR)
        print(f"   Using models from: {res_dir}")

        lab_meta = test_meta[test_meta["lab_id"] == lab]
        if lab_meta.empty:
            continue

        for vid in sorted(lab_meta["video_id"].unique()):
            try:
                all_rows.extend(predict_lab_video(lab, vid, test_meta, res_dir))
            except Exception as e:
                # One failing video must not abort the whole submission.
                print(f"   Err {vid}: {e}")
            gc.collect()

    # Build and save the submission; row_id follows the sorted order.
    cols = ["video_id", "agent_id", "target_id", "action", "start_frame", "stop_frame"]
    if not all_rows:
        sub = pd.DataFrame(columns=["row_id"] + cols)
    else:
        sub = pd.DataFrame(all_rows, columns=cols).sort_values(
            ["video_id", "agent_id", "target_id", "action", "start_frame"]
        )
        sub.insert(0, "row_id", np.arange(len(sub)))

    sub.to_csv(WORKING_DIR / "submission.csv", index=False)
    print(f"\nDone! Saved {len(sub)} rows to submission.csv")