In [1]:
import datetime
import gc
import itertools
import json
import re
import sys
import time
import traceback
from collections import defaultdict
from pathlib import Path

import joblib
import lightgbm as lgb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import xgboost as xgb
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedGroupKFold
from tqdm.auto import tqdm
from sklearn.metrics import f1_score
#from metric import score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# const
INPUT_DIR = Path("D:/Kaggle_Competition/data")
TRAIN_TRACKING_DIR = INPUT_DIR / "train_tracking"
TRAIN_ANNOTATION_DIR = INPUT_DIR / "train_annotation"
TEST_TRACKING_DIR = INPUT_DIR / "test_tracking"

WORKING_DIR = Path("D:/Kaggle_Competition/notebooks")

INDEX_COLS = [
    "video_id",
    "agent_mouse_id",
    "target_mouse_id",
    "video_frame",
]

BODY_PARTS = [
    "ear_left",
    "ear_right",
    "nose",
    "neck",
    "body_center",
    "lateral_left",
    "lateral_right",
    "hip_left",
    "hip_right",
    "tail_base",
    "tail_tip",
]

SELF_BEHAVIORS = [
    "biteobject",
    "climb",
    "dig",
    "exploreobject",
    "freeze",
    "genitalgroom",
    "huddle",
    "rear",
    "rest",
    "run",
    "selfgroom",
]

PAIR_BEHAVIORS = [
    "allogroom",
    "approach",
    "attack",
    "attemptmount",
    "avoid",
    "chase",
    "chaseattack",
    "defend",
    "disengage",
    "dominance",
    "dominancegroom",
    "dominancemount",
    "ejaculate",
    "escape",
    "flinch",
    "follow",
    "intromit",
    "mount",
    "reciprocalsniff",
    "shepherd",
    "sniff",
    "sniffbody",
    "sniffface",
    "sniffgenital",
    "submit",
    "tussle",
]

In [3]:
# read data
train_dataframe = pl.read_csv(INPUT_DIR / "train.csv")

In [4]:
# preprocess behavior labels
train_behavior_dataframe = (
    train_dataframe.filter(pl.col("behaviors_labeled").is_not_null())
    .select(
        pl.col("lab_id"),
        pl.col("video_id"),
        pl.col("behaviors_labeled").map_elements(eval, return_dtype=pl.List(pl.Utf8)).alias("behaviors_labeled_list"),
    )
    .explode("behaviors_labeled_list")
    .rename({"behaviors_labeled_list": "behaviors_labeled_element"})
    .select(
        pl.col("lab_id"),
        pl.col("video_id"),
        pl.col("behaviors_labeled_element").str.split(",").list[0].str.replace_all("'", "").alias("agent"),
        pl.col("behaviors_labeled_element").str.split(",").list[1].str.replace_all("'", "").alias("target"),
        pl.col("behaviors_labeled_element").str.split(",").list[2].str.replace_all("'", "").alias("behavior"),
    )
)

train_self_behavior_dataframe = train_behavior_dataframe.filter(pl.col("behavior").is_in(SELF_BEHAVIORS))
train_pair_behavior_dataframe = train_behavior_dataframe.filter(pl.col("behavior").is_in(PAIR_BEHAVIORS))

In [5]:
%%writefile self_features.py

def make_self_features(
    metadata: dict,
    tracking: pl.DataFrame,
) -> pl.DataFrame:
    fps = float(metadata.get("frames_per_second", 30.0))
    if fps <= 0: fps = 30.0
    
    pix_per_cm = float(metadata.get("pix_per_cm_approx", 1.0))
    if pix_per_cm <= 0: pix_per_cm = 1.0
    
    W_MICRO = max(1, int(round(0.1 * fps))) 
    W_SHORT = max(1, int(round(0.5 * fps))) 
    W_LONG = max(1, int(round(1.0 * fps))) 
    W_MACRO = max(1, int(round(3.0 * fps))) 

    # --- 2. HELPER FUNCTIONS ---
    def body_parts_distance(body_part_1, body_part_2):
        assert body_part_1 in BODY_PARTS
        assert body_part_2 in BODY_PARTS
        return (
            (pl.col(f"agent_x_{body_part_1}") - pl.col(f"agent_x_{body_part_2}")).pow(2)
            + (pl.col(f"agent_y_{body_part_1}") - pl.col(f"agent_y_{body_part_2}")).pow(2)
        ).sqrt() / pix_per_cm

    # Hàm tính Features Vật Lý
    def calculate_velocity_features(body_part):
        dx = pl.col(f"agent_x_{body_part}").diff().fill_null(0)
        dy = pl.col(f"agent_y_{body_part}").diff().fill_null(0)
        inst_speed = (dx.pow(2) + dy.pow(2)).sqrt() / pix_per_cm * fps
        
        features = []
        
        features.append(inst_speed.rolling_mean(W_MICRO, center=True, min_samples=1).alias(f"agent__{body_part}__speed_micro"))
        features.append(inst_speed.rolling_mean(W_SHORT, center=True, min_samples=1).alias(f"agent__{body_part}__speed_short"))
        features.append(inst_speed.rolling_mean(W_LONG, center=True, min_samples=1).alias(f"agent__{body_part}__speed_long"))
        features.append(inst_speed.rolling_mean(W_MACRO, center=True, min_samples=1).alias(f"agent__{body_part}__speed_macro"))

        speed_short = inst_speed.rolling_mean(W_SHORT, center=True, min_samples=1)
        accel = speed_short.diff().fill_null(0) * fps
        features.append(accel.alias(f"agent__{body_part}__accel"))
        
        jitter = inst_speed.rolling_std(W_MICRO, center=True, min_samples=1).fill_null(0)
        features.append(jitter.alias(f"agent__{body_part}__jitter"))
        
        return features

    def elongation():
        d1 = body_parts_distance("nose", "tail_base")
        d2 = body_parts_distance("ear_left", "ear_right")
        return d1 / (d2 + 1e-06)

    def body_angle():
        v1x = pl.col("agent_x_nose") - pl.col("agent_x_body_center")
        v1y = pl.col("agent_y_nose") - pl.col("agent_y_body_center")
        v2x = pl.col("agent_x_tail_base") - pl.col("agent_x_body_center")
        v2y = pl.col("agent_y_tail_base") - pl.col("agent_y_body_center")
        return (v1x * v2x + v1y * v2y) / ((v1x.pow(2) + v1y.pow(2)).sqrt() * (v2x.pow(2) + v2y.pow(2)).sqrt() + 1e-06)

    def longitudinal_lateral_velocity():
        """Calculate longitudinal and lateral velocity components.
        
        Returns:
            v_long: velocity along body axis (forward/backward)
            v_lat: velocity perpendicular to axis (sideways)
        """
        # 1. Define body axis vector
        axis_x = pl.col("agent_x_nose") - pl.col("agent_x_tail_base")
        axis_y = pl.col("agent_y_nose") - pl.col("agent_y_tail_base")
        axis_len = (axis_x.pow(2) + axis_y.pow(2)).sqrt() + 1e-6
        
        # 2. Unit vector
        u_x = axis_x / axis_len
        u_y = axis_y / axis_len
        
        # 3. Velocity vector
        v_x = pl.col("agent_x_tail_base").diff().fill_null(0)
        v_y = pl.col("agent_y_tail_base").diff().fill_null(0)
        
        # 4. Projection
        v_long = (v_x * u_x + v_y * u_y) / pix_per_cm * fps
        v_lat = (v_x * u_y - v_y * u_x) / pix_per_cm * fps
        
        return v_long, v_lat

    def groom_features():
        """Calculate grooming-specific features.
        
        Returns:
            head_body_decouple, nose_rad_std, head_orient_jitter (if tail_base available)
        """
        # 1. Head-body decoupling
        nose_dx = pl.col("agent_x_nose").diff().fill_null(0)
        nose_dy = pl.col("agent_y_nose").diff().fill_null(0)
        nose_speed = (nose_dx.pow(2) + nose_dy.pow(2)).sqrt() / pix_per_cm * fps
        
        body_dx = pl.col("agent_x_body_center").diff().fill_null(0)
        body_dy = pl.col("agent_y_body_center").diff().fill_null(0)
        body_speed = (body_dx.pow(2) + body_dy.pow(2)).sqrt() / pix_per_cm * fps
        
        head_body_ratio = (nose_speed / (body_speed + 1e-3)).clip(0, 10)
        decouple = head_body_ratio.rolling_median(W_SHORT, center=True, min_samples=1)
        
        # 2. Nose radius variation
        nose_to_body_dist = (
            (pl.col("agent_x_nose") - pl.col("agent_x_body_center")).pow(2) +
            (pl.col("agent_y_nose") - pl.col("agent_y_body_center")).pow(2)
        ).sqrt() / pix_per_cm
        
        nose_rad_std = nose_to_body_dist.rolling_std(W_SHORT, center=True, min_samples=1).fill_null(0)
        
        return decouple, nose_rad_std

    def head_orientation_jitter():
        """Calculate head rotation jitter (requires tail_base)."""
        head_angle = pl.arctan2(
            pl.col("agent_y_nose") - pl.col("agent_y_tail_base"),
            pl.col("agent_x_nose") - pl.col("agent_x_tail_base")
        )
        angle_change_rate = head_angle.diff().abs().fill_null(0)
        return angle_change_rate.rolling_mean(W_SHORT, center=True, min_samples=1)

    def wall_distance():
        """Calculate distance to nearest wall."""
        arena_width_cm = float(metadata.get("arena_width_cm", 60.0))
        arena_height_cm = float(metadata.get("arena_height_cm", 60.0))
        
        x_cm = pl.col("agent_x_nose").forward_fill().backward_fill() / pix_per_cm
        y_cm = pl.col("agent_y_nose").forward_fill().backward_fill() / pix_per_cm
        
        return pl.min_horizontal([
            x_cm,
            arena_width_cm - x_cm,
            y_cm,
            arena_height_cm - y_cm
        ])

    # --- 3. DATA PREPARATION ---
    n_mice = (
        (metadata["mouse1_strain"] is not None)
        + (metadata["mouse2_strain"] is not None)
        + (metadata["mouse3_strain"] is not None)
        + (metadata["mouse4_strain"] is not None)
    )
    start_frame = tracking.select(pl.col("video_frame").min()).item()
    end_frame = tracking.select(pl.col("video_frame").max()).item()

    result = []

    pivot = tracking.pivot(
        on=["bodypart"],
        index=["video_frame", "mouse_id"],
        values=["x", "y"],
    ).sort(["mouse_id", "video_frame"])
    
    pivot_trackings = {mouse_id: pivot.filter(pl.col("mouse_id") == mouse_id) for mouse_id in range(1, n_mice + 1)}

    # --- 4. MAIN LOOP ---
    for agent_mouse_id in range(1, n_mice + 1):
        result_element = pl.DataFrame(
            {
                "video_id": metadata["video_id"],
                "agent_mouse_id": agent_mouse_id,
                "target_mouse_id": -1,
                "video_frame": pl.arange(start_frame, end_frame + 1, eager=True),
            },
            schema={
                "video_id": pl.Int32,
                "agent_mouse_id": pl.Int8,
                "target_mouse_id": pl.Int8,
                "video_frame": pl.Int32,
            },
        )

        pivot_agent = pivot_trackings[agent_mouse_id].select(
            pl.col("video_frame"),
            pl.exclude("video_frame").name.prefix("agent_"),
        )
        
        columns = pivot_agent.columns
        pivot_agent = pivot_agent.with_columns(
            *[pl.lit(None).cast(pl.Float32).alias(f"agent_x_{bp}") for bp in BODY_PARTS if f"agent_x_{bp}" not in columns],
            *[pl.lit(None).cast(pl.Float32).alias(f"agent_y_{bp}") for bp in BODY_PARTS if f"agent_y_{bp}" not in columns],
        )
        
        # SIGNAL SMOOTHING
        smooth_exprs = []
        for col_name in pivot_agent.columns:
            if col_name.startswith("agent_x_") or col_name.startswith("agent_y_"):
                 smooth_exprs.append(pl.col(col_name).rolling_mean(W_MICRO, center=True, min_samples=1).alias(col_name))
        
        pivot_agent = pivot_agent.with_columns(smooth_exprs)

        # --- 6. FEATURE CALCULATION ---
        feature_exprs = [
            pl.lit(agent_mouse_id).alias("agent_mouse_id"),
            pl.lit(-1).alias("target_mouse_id"),
        ]
        
        # A. Distances
        feature_exprs.extend([
            body_parts_distance(bp1, bp2).alias(f"aa__{bp1}__{bp2}__distance")
            for bp1, bp2 in itertools.combinations(BODY_PARTS, 2)
        ])
        
        # B. Velocity, Acceleration, Jitter
        for bp in ["ear_left", "ear_right", "tail_base"]:
            feature_exprs.extend(calculate_velocity_features(bp))
        
        # Longitudinal/Lateral velocity components
        v_long, v_lat = longitudinal_lateral_velocity()
        feature_exprs.append(v_long.rolling_mean(W_SHORT, center=True, min_samples=1).alias("agent__velocity_long"))
        feature_exprs.append(v_lat.rolling_mean(W_SHORT, center=True, min_samples=1).alias("agent__velocity_lat"))

        # C. Geometry
        feature_exprs.append(elongation().alias("agent__elongation"))
        feature_exprs.append(body_angle().alias("agent__body_angle"))

        # D. Angular Velocity
        v_x_center = pl.col("agent_x_body_center").diff().fill_null(0)
        v_y_center = pl.col("agent_y_body_center").diff().fill_null(0)
        heading_angle = pl.arctan2(v_y_center, v_x_center)
        
        angular_vel = heading_angle.diff().fill_null(0) * fps
        feature_exprs.append(
            angular_vel.abs().rolling_mean(W_SHORT, center=True, min_samples=1)
            .alias("agent__angular_velocity")
        )

        # E. Groom Micro-Features
        head_body_decouple, nose_rad_std = groom_features()
        feature_exprs.append(head_body_decouple.alias("agent__head_body_decouple"))
        feature_exprs.append(nose_rad_std.alias("agent__nose_rad_std"))
        
        feature_exprs.append(head_orientation_jitter().alias("agent__head_orient_jitter"))

        # F. Wall Distance
        feature_exprs.append(wall_distance().alias("agent__wall_distance"))

        # Execute
        features = pivot_agent.select(
            pl.col("video_frame"),
            *feature_exprs
        )

        result_element = result_element.join(
            features,
            on=["video_frame", "agent_mouse_id", "target_mouse_id"],
            how="left",
        )
        result.append(result_element)

    return pl.concat(result, how="vertical")



Overwriting self_features.py


In [6]:
%%writefile pair_features.py

def make_pair_features(
    metadata: dict,
    tracking: pl.DataFrame,
) -> pl.DataFrame:
    fps = float(metadata.get("frames_per_second", 30.0))
    if fps <= 0: fps = 30.0
    
    pix_per_cm = float(metadata.get("pix_per_cm_approx", 1.0))
    if pix_per_cm <= 0: pix_per_cm = 1.0
    
    W_MICRO = max(1, int(round(0.1 * fps))) 
    W_SHORT = max(1, int(round(0.5 * fps)))
    W_LONG = max(1, int(round(1.0 * fps)))
    W_MACRO = max(1, int(round(3.0 * fps)))

    # --- 2. HELPER FUNCTIONS ---
    def body_parts_distance(agent_or_target_1, body_part_1, agent_or_target_2, body_part_2):
        assert agent_or_target_1 in ["agent", "target"]
        assert agent_or_target_2 in ["agent", "target"]
        assert body_part_1 in BODY_PARTS
        assert body_part_2 in BODY_PARTS
        return (
            (pl.col(f"{agent_or_target_1}_x_{body_part_1}") - pl.col(f"{agent_or_target_2}_x_{body_part_2}")).pow(2)
            + (pl.col(f"{agent_or_target_1}_y_{body_part_1}") - pl.col(f"{agent_or_target_2}_y_{body_part_2}")).pow(2)
        ).sqrt() / pix_per_cm

    def calculate_velocity_features(agent_or_target, body_part):
        assert agent_or_target in ["agent", "target"]
        assert body_part in BODY_PARTS
        dx = pl.col(f"{agent_or_target}_x_{body_part}").diff().fill_null(0)
        dy = pl.col(f"{agent_or_target}_y_{body_part}").diff().fill_null(0)
        inst_speed = (dx.pow(2) + dy.pow(2)).sqrt() / pix_per_cm * fps
        
        features = []
        features.append(inst_speed.rolling_mean(W_MICRO, center=True, min_samples=1).alias(f"{agent_or_target}__{body_part}__speed_micro"))
        features.append(inst_speed.rolling_mean(W_SHORT, center=True, min_samples=1).alias(f"{agent_or_target}__{body_part}__speed_short"))
        features.append(inst_speed.rolling_mean(W_LONG, center=True, min_samples=1).alias(f"{agent_or_target}__{body_part}__speed_long"))
        features.append(inst_speed.rolling_mean(W_MACRO, center=True, min_samples=1).alias(f"{agent_or_target}__{body_part}__speed_macro"))
        
        speed_short = inst_speed.rolling_mean(W_SHORT, center=True, min_samples=1)
        accel = speed_short.diff().fill_null(0) * fps
        features.append(accel.alias(f"{agent_or_target}__{body_part}__accel"))
        
        jitter = inst_speed.rolling_std(W_MICRO, center=True, min_samples=1).fill_null(0)
        features.append(jitter.alias(f"{agent_or_target}__{body_part}__jitter"))
        return features

    def elongation(agent_or_target):
        assert agent_or_target in ["agent", "target"]
        d1 = body_parts_distance(agent_or_target, "nose", agent_or_target, "tail_base")
        d2 = body_parts_distance(agent_or_target, "ear_left", agent_or_target, "ear_right")
        return d1 / (d2 + 1e-06)

    def body_angle(agent_or_target):
        assert agent_or_target in ["agent", "target"]
        v1x = pl.col(f"{agent_or_target}_x_nose") - pl.col(f"{agent_or_target}_x_body_center")
        v1y = pl.col(f"{agent_or_target}_y_nose") - pl.col(f"{agent_or_target}_y_body_center")
        v2x = pl.col(f"{agent_or_target}_x_tail_base") - pl.col(f"{agent_or_target}_x_body_center")
        v2y = pl.col(f"{agent_or_target}_y_tail_base") - pl.col(f"{agent_or_target}_y_body_center")
        return (v1x * v2x + v1y * v2y) / ((v1x.pow(2) + v1y.pow(2)).sqrt() * (v2x.pow(2) + v2y.pow(2)).sqrt() + 1e-06)

    def velocity_components(agent_or_target):
        """Calculate longitudinal and lateral velocity components.
        
        Returns:
            v_long: velocity along nose-tail axis (forward/backward)
            v_lat: velocity perpendicular to axis (sideways)
        """
        assert agent_or_target in ["agent", "target"]
        
        # 1. Body axis vector (nose to tail)
        axis_x = pl.col(f"{agent_or_target}_x_nose") - pl.col(f"{agent_or_target}_x_tail_base")
        axis_y = pl.col(f"{agent_or_target}_y_nose") - pl.col(f"{agent_or_target}_y_tail_base")
        axis_len = (axis_x.pow(2) + axis_y.pow(2)).sqrt() + 1e-6
        
        # 2. Unit vector along axis
        u_x = axis_x / axis_len
        u_y = axis_y / axis_len
        
        # 3. Velocity vector (tail_base displacement)
        v_x = pl.col(f"{agent_or_target}_x_tail_base").diff().fill_null(0)
        v_y = pl.col(f"{agent_or_target}_y_tail_base").diff().fill_null(0)
        
        # 4. Project velocity onto axis (longitudinal) and perpendicular (lateral)
        v_long = (v_x * u_x + v_y * u_y) / pix_per_cm * fps
        v_lat = (v_x * u_y - v_y * u_x) / pix_per_cm * fps
        
        return v_long, v_lat

    def facing_score():
        """Calculate how much agent is facing toward target."""
        v_head_x = pl.col("agent_x_nose") - pl.col("agent_x_neck")
        v_head_y = pl.col("agent_y_nose") - pl.col("agent_y_neck")
        v_to_target_x = pl.col("target_x_body_center") - pl.col("agent_x_neck")
        v_to_target_y = pl.col("target_y_body_center") - pl.col("agent_y_neck")
        
        dot_face = v_head_x * v_to_target_x + v_head_y * v_to_target_y
        norm_head = (v_head_x.pow(2) + v_head_y.pow(2)).sqrt()
        norm_target = (v_to_target_x.pow(2) + v_to_target_y.pow(2)).sqrt()
        
        return dot_face / (norm_head * norm_target + 1e-6)

    def spine_alignment():
        """Calculate spine alignment between agent and target."""
        v_spine_a_x = pl.col("agent_x_neck") - pl.col("agent_x_tail_base")
        v_spine_a_y = pl.col("agent_y_neck") - pl.col("agent_y_tail_base")
        v_spine_t_x = pl.col("target_x_neck") - pl.col("target_x_tail_base")
        v_spine_t_y = pl.col("target_y_neck") - pl.col("target_y_tail_base")
        
        dot_spine = v_spine_a_x * v_spine_t_x + v_spine_a_y * v_spine_t_y
        norm_spine_a = (v_spine_a_x.pow(2) + v_spine_a_y.pow(2)).sqrt()
        norm_spine_t = (v_spine_t_x.pow(2) + v_spine_t_y.pow(2)).sqrt()
        
        return dot_spine / (norm_spine_a * norm_spine_t + 1e-6)

    def speed_ratio():
        """Calculate agent vs target speed ratio."""
        dx_a = pl.col("agent_x_body_center").diff().fill_null(0)
        dy_a = pl.col("agent_y_body_center").diff().fill_null(0)
        speed_a = ((dx_a.pow(2) + dy_a.pow(2)).sqrt() / pix_per_cm * fps).rolling_mean(W_SHORT, center=True, min_samples=1)

        dx_t = pl.col("target_x_body_center").diff().fill_null(0)
        dy_t = pl.col("target_y_body_center").diff().fill_null(0)
        speed_t = ((dx_t.pow(2) + dy_t.pow(2)).sqrt() / pix_per_cm * fps).rolling_mean(W_SHORT, center=True, min_samples=1)

        return (speed_a - speed_t) / (speed_a + speed_t + 1e-6)

    def pursuit_alignment():
        """Calculate how well agent's velocity aligns with direction to target."""
        head_to_target_x = pl.col("target_x_body_center") - pl.col("agent_x_body_center")
        head_to_target_y = pl.col("target_y_body_center") - pl.col("agent_y_body_center")
        
        agent_vel_x = pl.col("agent_x_body_center").diff().fill_null(0)
        agent_vel_y = pl.col("agent_y_body_center").diff().fill_null(0)
        
        pursuit_dot = agent_vel_x * head_to_target_x + agent_vel_y * head_to_target_y
        pursuit_norm_vel = (agent_vel_x.pow(2) + agent_vel_y.pow(2)).sqrt()
        pursuit_norm_target = (head_to_target_x.pow(2) + head_to_target_y.pow(2)).sqrt()
        
        alignment = pursuit_dot / (pursuit_norm_vel * pursuit_norm_target + 1e-6)
        return alignment.rolling_mean(W_SHORT, center=True, min_samples=1)

    def distance_change_features():
        """Calculate distance change rate and variance."""
        distance = body_parts_distance("agent", "body_center", "target", "body_center")
        distance_change_rate = distance.diff().fill_null(0) * fps
        
        rate_mean = distance_change_rate.rolling_mean(W_SHORT, center=True, min_samples=1)
        rate_var = distance_change_rate.rolling_std(W_MICRO, center=True, min_samples=1)
        
        return rate_mean, rate_var

    def height_differential():
        """Calculate vertical height difference between agent and target."""
        height_diff = (pl.col("agent_y_body_center") - pl.col("target_y_body_center")) / pix_per_cm
        
        return (
            height_diff,
            height_diff.abs(),
            (height_diff > 2.0).cast(pl.Float32)
        )

    # --- 3. DATA PREPARATION ---
    n_mice = (
        (metadata["mouse1_strain"] is not None)
        + (metadata["mouse2_strain"] is not None)
        + (metadata["mouse3_strain"] is not None)
        + (metadata["mouse4_strain"] is not None)
    )
    start_frame = tracking.select(pl.col("video_frame").min()).item()
    end_frame = tracking.select(pl.col("video_frame").max()).item()

    result = []

    pivot = tracking.pivot(
        on=["bodypart"],
        index=["video_frame", "mouse_id"],
        values=["x", "y"],
    ).sort(["mouse_id", "video_frame"])
    
    pivot_trackings = {mouse_id: pivot.filter(pl.col("mouse_id") == mouse_id) for mouse_id in range(1, n_mice + 1)}

    # --- 4. MAIN LOOP ---
    for agent_mouse_id, target_mouse_id in itertools.permutations(range(1, n_mice + 1), 2):
        result_element = pl.DataFrame(
            {
                "video_id": metadata["video_id"],
                "agent_mouse_id": agent_mouse_id,
                "target_mouse_id": target_mouse_id,
                "video_frame": pl.arange(start_frame, end_frame + 1, eager=True),
            },
            schema={
                "video_id": pl.Int32,
                "agent_mouse_id": pl.Int8,
                "target_mouse_id": pl.Int8,
                "video_frame": pl.Int32,
            },
        )

        merged_pivot = (
            pivot_trackings[agent_mouse_id]
            .select(
                pl.col("video_frame"),
                pl.exclude("video_frame").name.prefix("agent_"),
            )
            .join(
                pivot_trackings[target_mouse_id].select(
                    pl.col("video_frame"),
                    pl.exclude("video_frame").name.prefix("target_"),
                ),
                on="video_frame",
                how="inner",
            )
        )
        
        columns = merged_pivot.columns
        merged_pivot = merged_pivot.with_columns(
            *[pl.lit(None).cast(pl.Float32).alias(f"agent_x_{bp}") for bp in BODY_PARTS if f"agent_x_{bp}" not in columns],
            *[pl.lit(None).cast(pl.Float32).alias(f"agent_y_{bp}") for bp in BODY_PARTS if f"agent_y_{bp}" not in columns],
            *[pl.lit(None).cast(pl.Float32).alias(f"target_x_{bp}") for bp in BODY_PARTS if f"target_x_{bp}" not in columns],
            *[pl.lit(None).cast(pl.Float32).alias(f"target_y_{bp}") for bp in BODY_PARTS if f"target_y_{bp}" not in columns],
        )

        smooth_exprs = []
        for col_name in merged_pivot.columns:
            if "_x_" in col_name or "_y_" in col_name:
                 smooth_exprs.append(pl.col(col_name).rolling_mean(W_MICRO, center=True, min_samples=1).alias(col_name))
        
        merged_pivot = merged_pivot.with_columns(smooth_exprs)

        # --- 6. FEATURE CALCULATION ---
        feature_exprs = [
            pl.lit(agent_mouse_id).alias("agent_mouse_id"),
            pl.lit(target_mouse_id).alias("target_mouse_id"),
        ]

        # A. Pair Distances
        feature_exprs.extend([
            body_parts_distance("agent", agent_body_part, "target", target_body_part).alias(
                f"at__{agent_body_part}__{target_body_part}__distance"
            )
            for agent_body_part, target_body_part in itertools.product(BODY_PARTS, repeat=2)
        ])

        # B. Velocity
        for bp in ["ear_left", "ear_right", "tail_base"]:
            feature_exprs.extend(calculate_velocity_features("agent", bp))
            feature_exprs.extend(calculate_velocity_features("target", bp))
        
        # Longitudinal/Lateral velocity components
        agent_v_long, agent_v_lat = velocity_components("agent")
        feature_exprs.append(
            agent_v_long.rolling_mean(W_SHORT, center=True, min_samples=1)
            .alias("agent__velocity_long")
        )
        feature_exprs.append(
            agent_v_lat.rolling_mean(W_SHORT, center=True, min_samples=1)
            .alias("agent__velocity_lat")
        )
        
        target_v_long, target_v_lat = velocity_components("target")
        feature_exprs.append(
            target_v_long.rolling_mean(W_SHORT, center=True, min_samples=1)
            .alias("target__velocity_long")
        )
        feature_exprs.append(
            target_v_lat.rolling_mean(W_SHORT, center=True, min_samples=1)
            .alias("target__velocity_lat")
        )
            
        # C. Geometry
        feature_exprs.append(elongation("agent").alias("agent__elongation"))
        feature_exprs.append(elongation("target").alias("target__elongation"))
        feature_exprs.append(body_angle("agent").alias("agent__body_angle"))
        feature_exprs.append(body_angle("target").alias("target__body_angle"))
        
        # D. Interaction Geometry
        feature_exprs.append(facing_score().alias("interaction__facing_score"))
        feature_exprs.append(spine_alignment().alias("interaction__spine_alignment"))
        feature_exprs.append(speed_ratio().alias("interaction__speed_ratio"))
        feature_exprs.append(pursuit_alignment().alias("interaction__pursuit_alignment"))
        
        # Distance change features
        dist_change_mean, dist_change_var = distance_change_features()
        feature_exprs.extend([
            dist_change_mean.alias("interaction__distance_change_rate"),
            dist_change_var.alias("interaction__distance_change_variance")
        ])
        
        # Height differential features
        height_diff, height_diff_abs, agent_above = height_differential()
        feature_exprs.extend([
            height_diff.alias("interaction__height_diff"),
            height_diff_abs.alias("interaction__height_diff_abs"),
            agent_above.alias("interaction__agent_above")
        ])

        # Execute
        features = merged_pivot.select(
            pl.col("video_frame"),
            *feature_exprs
        )

        result_element = result_element.join(
            features,
            on=["video_frame", "agent_mouse_id", "target_mouse_id"],
            how="left",
        )
        result.append(result_element)

    return pl.concat(result, how="vertical")


Overwriting pair_features.py


In [7]:
%run -i self_features.py
%run -i pair_features.py

def process_video(row):
    """Process a single video to extract self and pair features."""
    lab_id = row["lab_id"]
    video_id = row["video_id"]
    tracking_path = TRAIN_TRACKING_DIR / f"{lab_id}/{video_id}.parquet"
    if not tracking_path.exists(): return None # Safety check
    try:
        tracking = pl.read_parquet(tracking_path)
        # 1. Feature Generation
        self_features = make_self_features(metadata=row, tracking=tracking)
        pair_features = make_pair_features(metadata=row, tracking=tracking)
        # 2. OPTIMIZATION: Downcast Float64 -> Float32
        float_cols_self = [c for c, t in zip(self_features.columns, self_features.dtypes) if t == pl.Float64]
        if float_cols_self:
            self_features = self_features.with_columns([pl.col(c).cast(pl.Float32) for c in float_cols_self])
            
        float_cols_pair = [c for c, t in zip(pair_features.columns, pair_features.dtypes) if t == pl.Float64]
        if float_cols_pair:
            pair_features = pair_features.with_columns([pl.col(c).cast(pl.Float32) for c in float_cols_pair])
        # 3. Write to Disk
        self_features.write_parquet(WORKING_DIR / "self_features" / f"{video_id}.parquet")
        pair_features.write_parquet(WORKING_DIR / "pair_features" / f"{video_id}.parquet")
        
        # 4. Explicit Memory Release
        del tracking, self_features, pair_features
        # (Garbage collector sẽ lo phần còn lại)
    except Exception as e:
        print(f"Error {video_id}: {e}")
        return None
    return video_id


(WORKING_DIR / "self_features").mkdir(exist_ok=True)
(WORKING_DIR / "pair_features").mkdir(exist_ok=True)

rows = list(train_dataframe.filter(pl.col("behaviors_labeled").is_not_null()).rows(named=True))
results = joblib.Parallel(n_jobs=6, verbose=5)(joblib.delayed(process_video)(row) for row in rows)

print(f"Processed {len(results)} videos successfully")

del rows, results
gc.collect()

[Parallel(n_jobs=6)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=6)]: Done   6 tasks      | elapsed:    4.0s
[Parallel(n_jobs=6)]: Done  60 tasks      | elapsed:   42.0s
[Parallel(n_jobs=6)]: Done 150 tasks      | elapsed:   49.7s
[Parallel(n_jobs=6)]: Done 276 tasks      | elapsed:   59.5s
[Parallel(n_jobs=6)]: Done 438 tasks      | elapsed:  1.1min
[Parallel(n_jobs=6)]: Done 636 tasks      | elapsed:  1.4min
[Parallel(n_jobs=6)]: Done 848 out of 848 | elapsed:  1.8min finished


Processed 848 videos successfully


42

In [5]:
def tune_threshold(oof_action, y_action):
    """Find optimal threshold for F1 score."""
    thresholds = np.arange(0, 1.005, 0.005)
    scores = [f1_score(y_action, (oof_action >= th), zero_division=0) for th in thresholds]
    best_idx = np.argmax(scores)
    return thresholds[best_idx]

def get_adaptive_params(lab_id, behavior):
    """Get params based on current F1"""
    f1_file = WORKING_DIR / "results" / lab_id / behavior / "f1.txt"
    current_f1 = 0.0
    if f1_file.exists():
        with open(f1_file) as f:
            current_f1 = float(f.read().strip())
    
    if current_f1 < 0.25:
        max_depth = 6
        config_type = "WEAK"
    else:
        max_depth = 6
        config_type = "STRONG"
    
    return max_depth, config_type, current_f1

def train_validate(lab_id: str, behavior: str, indices: pl.DataFrame, features: pl.DataFrame, labels: pl.Series):
    """Train and validate with memory cleanup (EXACT same logic as reference)."""
    result_dir = WORKING_DIR / "results" / lab_id / behavior
    result_dir.mkdir(exist_ok=True, parents=True)
    
    # Convert to pandas
    features = features.to_pandas().astype(np.float32)
    labels = labels.to_numpy()

    MAX_SAMPLES = 300000
    if len(labels) > MAX_SAMPLES:
        print(f"⚠️  Sampling {MAX_SAMPLES:,} from {len(labels):,}")
        idx = np.random.choice(len(labels), MAX_SAMPLES, replace=False)
        idx.sort()  # Keep temporal order
        indices = indices[idx]
        features = features.iloc[idx]
        labels = labels[idx]
        gc.collect()
    
    # Check for no positive samples
    if labels.sum() == 0:
        with open(result_dir / "f1.txt", "w") as f:
            f.write("0.0\n")
        oof_prediction_dataframe = indices.with_columns(
            pl.Series("fold", [-1] * len(labels), dtype=pl.Int8),
            pl.Series("prediction", [0.0] * len(labels), dtype=pl.Float32),
            pl.Series("predicted_label", [0] * len(labels), dtype=pl.Int8),
        )
        oof_prediction_dataframe.write_parquet(result_dir / "oof_predictions.parquet")
        return 0.0
    
    # Initialize OOF arrays
    folds = np.ones(len(labels), dtype=np.int8) * -1
    oof_predictions = np.zeros(len(labels), dtype=np.float32)
    oof_prediction_labels = np.zeros(len(labels), dtype=np.int8)

    max_depth, config_type, current_f1 = get_adaptive_params(lab_id, behavior)
    print(f"  Config: {config_type} (F1={current_f1:.3f}, max_depth={max_depth})")
    
    # Cross-validation
    for fold, (train_idx, valid_idx) in enumerate(
        StratifiedGroupKFold(n_splits=3, shuffle=True, random_state=42).split(
            X=features, y=labels, groups=indices.get_column("video_id")
        )
    ):
        result_dir_fold = result_dir / f"fold_{fold}"
        result_dir_fold.mkdir(exist_ok=True, parents=True)
        
        # Split data
        X_train = features.iloc[train_idx]
        y_train = labels[train_idx]
        X_valid = features.iloc[valid_idx]
        y_valid = labels[valid_idx]
        
        # Compute scale_pos_weight
        scale_pos_weight = (len(y_train) - y_train.sum()) / y_train.sum()
        
        # XGBoost params with adaptive max_depth
        params = {
            "objective": "binary:logistic",
            "eval_metric": "logloss",
            "device": "cpu",
            "tree_method": "hist",
            "learning_rate": 0.05,
            "max_depth": max_depth, 
            "min_child_weight": 5,
            "subsample": 0.8,
            "colsample_bytree": 0.8,
            "scale_pos_weight": scale_pos_weight,
            "max_bin": 64,
            "seed": 42,
        }
        
        dtrain = xgb.DMatrix(X_train, label=y_train, feature_names=features.columns.tolist())
        dvalid = xgb.DMatrix(X_valid, label=y_valid, feature_names=features.columns.tolist())
        
        # Cleanup after DMatrix creation
        del X_train, y_train
        gc.collect()
        
        # Training
        evals_result = {}
        early_stopping_callback = xgb.callback.EarlyStopping(
            rounds=10, metric_name="logloss", data_name="valid", maximize=False, save_best=True
        )
        
        model = xgb.train(
            params, dtrain=dtrain, num_boost_round=200,
            evals=[(dtrain, "train"), (dvalid, "valid")],
            callbacks=[early_stopping_callback],
            evals_result=evals_result, verbose_eval=0
        )
        
        # Cleanup dtrain after training
        del dtrain
        gc.collect()
        
        # Predict
        fold_predictions = model.predict(dvalid)
        threshold = tune_threshold(fold_predictions, y_valid)
        
        # Save OOF
        folds[valid_idx] = fold
        oof_predictions[valid_idx] = fold_predictions
        oof_prediction_labels[valid_idx] = (fold_predictions >= threshold).astype(np.int8)
        
        # Save model
        model.save_model(result_dir_fold / "model.json")
        with open(result_dir_fold / "threshold.txt", "w") as f:
            f.write(f"{threshold}\n")
        
        # Plot importance
        xgb.plot_importance(model, max_num_features=20, importance_type="gain", values_format="{v:.2f}")
        plt.tight_layout()
        plt.savefig(result_dir_fold / "feature_importance.png")
        plt.close()
        
        # Plot metric
        lgb.plot_metric(evals_result, metric="logloss")
        plt.tight_layout()
        plt.savefig(result_dir_fold / "metric.png")
        plt.close()
        
        # Cleanup everything from this fold
        del X_valid, y_valid, dvalid, model, fold_predictions, evals_result
        gc.collect()
    
    # Calculate F1
    oof_prediction_dataframe = indices.with_columns(
        pl.Series("fold", folds, dtype=pl.Int8),
        pl.Series("prediction", oof_predictions, dtype=pl.Float32),
        pl.Series("predicted_label", oof_prediction_labels, dtype=pl.Int8),
    )
    
    f1 = f1_score(labels, oof_prediction_labels, zero_division=0)
    with open(result_dir / "f1.txt", "w") as f:
        f.write(f"{f1}\n")
    
    oof_prediction_dataframe.write_parquet(result_dir / "oof_predictions.parquet")
    
    # Final cleanup
    del features, labels, folds, oof_predictions, oof_prediction_labels, oof_prediction_dataframe
    gc.collect()
    
    return f1


In [9]:
groups = train_self_behavior_dataframe.group_by("lab_id", "behavior", maintain_order=True)
total_groups = len(list(groups))
start_time = time.perf_counter()

for idx, ((lab_id, behavior), group) in tqdm(enumerate(groups), total=total_groups):
    if idx == 0:
        tqdm.write(
            f"|{'LAB':^25}|{'BEHAVIOR':^15}|{'SAMPLES':^10}|{'POSITIVE':^10}|{'FEATURES':^10}|{'F1':^10}|{'ELAPSED TIME':^15}|",
            end="\n",
        )

    tqdm.write(f"|{lab_id:^25}|{behavior:^15}|", end="")
    index_list = []
    feature_list = []
    label_list = []

    for row in group.rows(named=True):
        video_id = row["video_id"]
        agent = row["agent"]

        agent_mouse_id = int(re.search(r"mouse(\d+)", agent).group(1))

        data = pl.scan_parquet(WORKING_DIR / "self_features" / f"{video_id}.parquet").filter(
            (pl.col("agent_mouse_id") == agent_mouse_id)
        )
        index = data.select(INDEX_COLS).collect(engine="streaming")
        feature = data.select(pl.exclude(INDEX_COLS)).collect(engine="streaming")

        # read annotation
        annotation_path = TRAIN_ANNOTATION_DIR / lab_id / f"{video_id}.parquet"
        if annotation_path.exists():
            annotation = (
                pl.scan_parquet(annotation_path)
                .filter((pl.col("action") == behavior) & (pl.col("agent_id") == agent_mouse_id))
                .collect()
            )
        else:
            annotation = pl.DataFrame(
                schema={
                    "agent_id": pl.Int8,
                    "target_id": pl.Int8,
                    "action": str,
                    "start_frame": pl.Int16,
                    "stop_frame": pl.Int16,
                }
            )

        label_frames = set()
        for annotation_row in annotation.rows(named=True):
            label_frames.update(range(annotation_row["start_frame"], annotation_row["stop_frame"]))
        label = index.select(pl.col("video_frame").is_in(label_frames).cast(pl.Int8).alias("label"))

        if label.get_column("label").sum() == 0:
            continue

        index_list.append(index)
        feature_list.append(feature)
        label_list.append(label.get_column("label"))

    if not index_list:
        elapsed_time = datetime.timedelta(seconds=int(time.perf_counter() - start_time))
        tqdm.write(f"{0:>10,}|{0:>10,}|{0:>10,}|{'-':>10}|{str(elapsed_time):>15}|", end="\n")
        continue

    indices = pl.concat(index_list, how="vertical")
    features = pl.concat(feature_list, how="vertical")
    labels = pl.concat(label_list, how="vertical")

    del index_list, feature_list, label_list
    gc.collect()

    tqdm.write(f"{len(indices):>10,}|{labels.sum():>10,}|{len(features.columns):>10,}|", end="")

    f1 = train_validate(lab_id, behavior, indices, features, labels)
    tqdm.write(f"{f1:>10.2f}|", end="")

    del indices, features, labels
    gc.collect()

    elapsed_time = datetime.timedelta(seconds=int(time.perf_counter() - start_time))
    tqdm.write(f"{str(elapsed_time):>15}|", end="\n")

    gc.collect()

                                      

  0%|          | 0/27 [00:00<?, ?it/s]

|           LAB           |   BEHAVIOR    | SAMPLES  | POSITIVE | FEATURES |    F1    | ELAPSED TIME  |
|     AdaptableSnail      |     rear      |

  0%|          | 0/27 [00:01<?, ?it/s]

   660,348|    85,313|        82|⚠️  Sampling 300,000 from 660,348
  Config: STRONG (F1=0.632, max_depth=6)


  4%|▎         | 1/27 [00:14<06:06, 14.09s/it]

      0.64|        0:00:14|
|         CRIM13          |     rear      |

  4%|▎         | 1/27 [00:14<06:06, 14.09s/it]

   179,132|    12,042|        82|  Config: STRONG (F1=0.412, max_depth=6)


  4%|▎         | 1/27 [00:23<06:06, 14.09s/it]

      0.41|        0:00:23|


  7%|▋         | 2/27 [00:23<04:39, 11.16s/it]

|         CRIM13          |   selfgroom   |

  7%|▋         | 2/27 [00:23<04:39, 11.16s/it]

   205,533|    14,472|        82|  Config: STRONG (F1=0.349, max_depth=6)


 11%|█         | 3/27 [00:32<04:11, 10.48s/it]

      0.35|        0:00:32|


 11%|█         | 3/27 [00:32<04:11, 10.48s/it]

|      CalMS21_task1      | genitalgroom  |

 11%|█         | 3/27 [00:33<04:11, 10.48s/it]

   102,445|     6,270|        82|  Config: STRONG (F1=0.692, max_depth=6)


 11%|█         | 3/27 [00:39<04:11, 10.48s/it]

      0.69|        0:00:39|


 15%|█▍        | 4/27 [00:39<03:24,  8.88s/it]

|       ElegantMink       |     rear      |

 19%|█▊        | 5/27 [00:39<02:08,  5.85s/it]

         0|         0|         0|         -|        0:00:39|
|       ElegantMink       |   selfgroom   |         0|         0|         0|         -|        0:00:39|
|       GroovyShrew       |     rear      |

 19%|█▊        | 5/27 [00:40<02:08,  5.85s/it]

   899,280|    50,768|        82|⚠️  Sampling 300,000 from 899,280
  Config: STRONG (F1=0.569, max_depth=6)


 26%|██▌       | 7/27 [00:51<01:56,  5.80s/it]

      0.57|        0:00:51|


 26%|██▌       | 7/27 [00:51<01:56,  5.80s/it]

|       GroovyShrew       |     rest      |

 26%|██▌       | 7/27 [00:51<01:56,  5.80s/it]

   530,886|    87,573|        82|⚠️  Sampling 300,000 from 530,886
  Config: STRONG (F1=0.700, max_depth=6)


 30%|██▉       | 8/27 [00:59<02:04,  6.53s/it]

      0.70|        0:00:59|


 30%|██▉       | 8/27 [00:59<02:04,  6.53s/it]

|       GroovyShrew       |   selfgroom   |

 30%|██▉       | 8/27 [01:00<02:04,  6.53s/it]

   877,773|    22,893|        82|⚠️  Sampling 300,000 from 877,773
  Config: STRONG (F1=0.335, max_depth=6)


 30%|██▉       | 8/27 [01:11<02:04,  6.53s/it]

      0.34|        0:01:11|


 33%|███▎      | 9/27 [01:11<02:23,  7.95s/it]

|       GroovyShrew       |     climb     |

 33%|███▎      | 9/27 [01:11<02:23,  7.95s/it]

   295,943|     8,647|        82|  Config: STRONG (F1=0.339, max_depth=6)


 33%|███▎      | 9/27 [01:22<02:23,  7.95s/it]

      0.34|        0:01:22|


 37%|███▋      | 10/27 [01:22<02:29,  8.80s/it]

|       GroovyShrew       |      dig      |

 37%|███▋      | 10/27 [01:22<02:29,  8.80s/it]

   771,922|    31,267|        82|⚠️  Sampling 300,000 from 771,922
  Config: STRONG (F1=0.437, max_depth=6)


 41%|████      | 11/27 [01:34<02:35,  9.71s/it]

      0.45|        0:01:34|


 41%|████      | 11/27 [01:34<02:35,  9.71s/it]

|       GroovyShrew       |      run      |

 41%|████      | 11/27 [01:34<02:35,  9.71s/it]

   413,942|     1,732|        82|⚠️  Sampling 300,000 from 413,942
  Config: WEAK (F1=0.191, max_depth=6)


 41%|████      | 11/27 [01:45<02:35,  9.71s/it]

      0.19|

 41%|████      | 11/27 [01:45<02:35,  9.71s/it]

        0:01:45|


 44%|████▍     | 12/27 [01:45<02:32, 10.15s/it]

|   InvincibleJellyfish   |      dig      |

 44%|████▍     | 12/27 [01:46<02:32, 10.15s/it]

   188,949|     6,768|        82|  Config: STRONG (F1=0.288, max_depth=6)


 44%|████▍     | 12/27 [01:55<02:32, 10.15s/it]

      0.29|

 48%|████▊     | 13/27 [01:56<02:23, 10.22s/it]

        0:01:55|


 48%|████▊     | 13/27 [01:56<02:23, 10.22s/it]

|   InvincibleJellyfish   |   selfgroom   |

 48%|████▊     | 13/27 [01:56<02:23, 10.22s/it]

   308,326|     2,791|        82|⚠️  Sampling 300,000 from 308,326
  Config: WEAK (F1=0.198, max_depth=6)


 48%|████▊     | 13/27 [02:08<02:23, 10.22s/it]

      0.18|

 52%|█████▏    | 14/27 [02:09<02:24, 11.11s/it]

        0:02:09|


 52%|█████▏    | 14/27 [02:09<02:24, 11.11s/it]

|       LyricalHare       |    freeze     |

 52%|█████▏    | 14/27 [02:09<02:24, 11.11s/it]

   329,777|    31,660|        82|⚠️  Sampling 300,000 from 329,777
  Config: STRONG (F1=0.524, max_depth=6)


 52%|█████▏    | 14/27 [02:19<02:24, 11.11s/it]

      0.52|

 52%|█████▏    | 14/27 [02:19<02:24, 11.11s/it]

        0:02:19|


 56%|█████▌    | 15/27 [02:19<02:11, 10.95s/it]

|       LyricalHare       |     rear      |

 56%|█████▌    | 15/27 [02:20<02:11, 10.95s/it]

   255,767|    18,953|        82|  Config: STRONG (F1=0.534, max_depth=6)


 56%|█████▌    | 15/27 [02:30<02:11, 10.95s/it]

      0.53|

 59%|█████▉    | 16/27 [02:30<01:59, 10.91s/it]

        0:02:30|


 59%|█████▉    | 16/27 [02:30<01:59, 10.91s/it]

|     NiftyGoldfinch      |  biteobject   |

 59%|█████▉    | 16/27 [02:31<01:59, 10.91s/it]

   558,309|     2,326|        82|⚠️  Sampling 300,000 from 558,309
  Config: WEAK (F1=0.042, max_depth=6)


 59%|█████▉    | 16/27 [02:43<01:59, 10.91s/it]

      0.05|

 59%|█████▉    | 16/27 [02:43<01:59, 10.91s/it]

        0:02:43|


 63%|██████▎   | 17/27 [02:44<01:56, 11.65s/it]

|     NiftyGoldfinch      |     climb     |

 63%|██████▎   | 17/27 [02:44<01:56, 11.65s/it]

   602,654|    51,687|        82|⚠️  Sampling 300,000 from 602,654
  Config: STRONG (F1=0.583, max_depth=6)


 63%|██████▎   | 17/27 [02:57<01:56, 11.65s/it]

      0.58|

 67%|██████▋   | 18/27 [02:57<01:49, 12.20s/it]

        0:02:57|


 67%|██████▋   | 18/27 [02:57<01:49, 12.20s/it]

|     NiftyGoldfinch      |      dig      |

 67%|██████▋   | 18/27 [02:57<01:49, 12.20s/it]

   656,612|    40,735|        82|⚠️  Sampling 300,000 from 656,612
  Config: STRONG (F1=0.557, max_depth=6)


 67%|██████▋   | 18/27 [03:10<01:49, 12.20s/it]

      0.56|

 70%|███████   | 19/27 [03:11<01:41, 12.70s/it]

        0:03:11|


 70%|███████   | 19/27 [03:11<01:41, 12.70s/it]

|     NiftyGoldfinch      | exploreobject |

 70%|███████   | 19/27 [03:11<01:41, 12.70s/it]

   558,859|     3,678|        82|⚠️  Sampling 300,000 from 558,859
  Config: WEAK (F1=0.107, max_depth=6)


 70%|███████   | 19/27 [03:24<01:41, 12.70s/it]

      0.12|

 70%|███████   | 19/27 [03:24<01:41, 12.70s/it]

        0:03:24|


 74%|███████▍  | 20/27 [03:25<01:31, 13.03s/it]

|     NiftyGoldfinch      |     rear      |

 74%|███████▍  | 20/27 [03:25<01:31, 13.03s/it]

   602,308|    40,444|        82|⚠️  Sampling 300,000 from 602,308
  Config: STRONG (F1=0.438, max_depth=6)


 74%|███████▍  | 20/27 [03:39<01:31, 13.03s/it]

      0.44|

 74%|███████▍  | 20/27 [03:39<01:31, 13.03s/it]

        0:03:39|


 78%|███████▊  | 21/27 [03:39<01:20, 13.43s/it]

|     NiftyGoldfinch      |   selfgroom   |

 78%|███████▊  | 21/27 [03:40<01:20, 13.43s/it]

   708,496|    34,621|        82|⚠️  Sampling 300,000 from 708,496
  Config: STRONG (F1=0.466, max_depth=6)


 78%|███████▊  | 21/27 [03:53<01:20, 13.43s/it]

      0.47|

 81%|████████▏ | 22/27 [03:53<01:08, 13.67s/it]

        0:03:53|


 81%|████████▏ | 22/27 [03:53<01:08, 13.67s/it]

|     TranquilPanther     |     rear      |

 81%|████████▏ | 22/27 [03:54<01:08, 13.67s/it]

 1,234,586|    23,369|        82|⚠️  Sampling 300,000 from 1,234,586
  Config: WEAK (F1=0.193, max_depth=6)


 81%|████████▏ | 22/27 [04:09<01:08, 13.67s/it]

      0.21|

 85%|████████▌ | 23/27 [04:09<00:57, 14.36s/it]

        0:04:09|


 85%|████████▌ | 23/27 [04:09<00:57, 14.36s/it]

|     TranquilPanther     |   selfgroom   |

 85%|████████▌ | 23/27 [04:10<00:57, 14.36s/it]

 1,021,199|     6,984|        82|⚠️  Sampling 300,000 from 1,021,199
  Config: WEAK (F1=0.141, max_depth=6)


 85%|████████▌ | 23/27 [04:24<00:57, 14.36s/it]

      0.15|

 89%|████████▉ | 24/27 [04:25<00:44, 14.69s/it]

        0:04:24|


 89%|████████▉ | 24/27 [04:25<00:44, 14.69s/it]

|      UppityFerret       |    huddle     |

 89%|████████▉ | 24/27 [04:25<00:44, 14.69s/it]

   164,371|    24,148|        82|  Config: STRONG (F1=0.627, max_depth=6)


 89%|████████▉ | 24/27 [04:36<00:44, 14.69s/it]

      0.63|

 89%|████████▉ | 24/27 [04:37<00:44, 14.69s/it]

        0:04:37|


 93%|█████████▎| 25/27 [04:37<00:27, 13.98s/it]

|      UppityFerret       |     rear      |

100%|██████████| 27/27 [04:37<00:00, 10.29s/it]

         0|         0|         0|         -|        0:04:37|
|      UppityFerret       |   selfgroom   |         0|         0|         0|         -|        0:04:37|





In [6]:
groups = train_pair_behavior_dataframe.group_by("lab_id", "behavior", maintain_order=True)
total_groups = len(list(groups))
start_time = time.perf_counter()

for idx, ((lab_id, behavior), group) in tqdm(enumerate(groups), total=total_groups):
    if idx == 0:
        tqdm.write(
            f"|{'LAB':^25}|{'BEHAVIOR':^15}|{'SAMPLES':^10}|{'POSITIVE':^10}|{'FEATURES':^10}|{'F1':^10}|{'ELAPSED TIME':^15}|",
            end="\n",
        )

    tqdm.write(f"|{lab_id:^25}|{behavior:^15}|", end="")
    index_list = []
    feature_list = []
    label_list = []

    for row in group.rows(named=True):
        video_id = row["video_id"]
        agent = row["agent"]
        target = row["target"]

        agent_mouse_id = int(re.search(r"mouse(\d+)", agent).group(1))
        target_mouse_id = int(re.search(r"mouse(\d+)", target).group(1))

        data = pl.scan_parquet(WORKING_DIR / "pair_features" / f"{video_id}.parquet").filter(
            (pl.col("agent_mouse_id") == agent_mouse_id) & (pl.col("target_mouse_id") == target_mouse_id)
        )
        index = data.select(INDEX_COLS).collect(engine="streaming")
        feature = data.select(pl.exclude(INDEX_COLS)).collect(engine="streaming")

        # read annotation
        annotation_path = TRAIN_ANNOTATION_DIR / lab_id / f"{video_id}.parquet"
        if annotation_path.exists():
            annotation = (
                pl.scan_parquet(annotation_path)
                .filter(
                    (pl.col("action") == behavior)
                    & (pl.col("agent_id") == agent_mouse_id)
                    & (pl.col("target_id") == target_mouse_id)
                )
                .collect()
            )
        else:
            annotation = pl.DataFrame(
                schema={
                    "agent_id": pl.Int8,
                    "target_id": pl.Int8,
                    "action": str,
                    "start_frame": pl.Int16,
                    "stop_frame": pl.Int16,
                }
            )

        label_frames = set()
        for annotation_row in annotation.rows(named=True):
            label_frames.update(range(annotation_row["start_frame"], annotation_row["stop_frame"]))
        label = index.select(pl.col("video_frame").is_in(label_frames).cast(pl.Int8).alias("label"))

        if label.get_column("label").sum() == 0:
            continue

        index_list.append(index)
        feature_list.append(feature)
        label_list.append(label.get_column("label"))

    if not index_list:
        elapsed_time = datetime.timedelta(seconds=int(time.perf_counter() - start_time))
        tqdm.write(f"{0:>10,}|{0:>10,}|{0:>10,}|{'-':>10}|{str(elapsed_time):>15}|", end="\n")
        continue

    indices = pl.concat(index_list, how="vertical")
    features = pl.concat(feature_list, how="vertical")
    labels = pl.concat(label_list, how="vertical")

    del index_list, feature_list, label_list
    gc.collect()

    tqdm.write(f"{len(indices):>10,}|{labels.sum():>10,}|{len(features.columns):>10,}|", end="")

    f1 = train_validate(lab_id, behavior, indices, features, labels)
    tqdm.write(f"{f1:>10.2f}|", end="")

    del indices, features, labels
    gc.collect()

    elapsed_time = datetime.timedelta(seconds=int(time.perf_counter() - start_time))
    tqdm.write(f"{str(elapsed_time):>15}|", end="\n")

    gc.collect()

 87%|████████▋ | 90/104 [30:02<06:34, 28.20s/it]

      0.68|

 88%|████████▊ | 91/104 [30:04<05:55, 27.38s/it]

        0:30:03|


 88%|████████▊ | 91/104 [30:04<05:55, 27.38s/it]

|     SparklingTapir      |    defend     |

 88%|████████▊ | 91/104 [30:06<05:55, 27.38s/it]

   198,020|     9,960|       174|  Config: STRONG (F1=0.603, max_depth=6)


 88%|████████▊ | 91/104 [30:24<05:55, 27.38s/it]

      0.61|

 88%|████████▊ | 92/104 [30:27<05:09, 25.83s/it]

        0:30:26|


 88%|████████▊ | 92/104 [30:27<05:09, 25.83s/it]

|     SparklingTapir      |    escape     |

 88%|████████▊ | 92/104 [30:28<05:09, 25.83s/it]

   180,020|     5,640|       174|  Config: STRONG (F1=0.751, max_depth=6)


 88%|████████▊ | 92/104 [30:47<05:09, 25.83s/it]

      0.76|

 89%|████████▉ | 93/104 [30:49<04:34, 24.94s/it]

        0:30:48|


 89%|████████▉ | 93/104 [30:49<04:34, 24.94s/it]

|     SparklingTapir      |     mount     |

 89%|████████▉ | 93/104 [30:51<04:34, 24.94s/it]

   162,000|     9,327|       174|  Config: STRONG (F1=0.839, max_depth=6)


 89%|████████▉ | 93/104 [31:09<04:34, 24.94s/it]

      0.84|

 89%|████████▉ | 93/104 [31:10<04:34, 24.94s/it]

        0:31:10|


 90%|█████████ | 94/104 [31:11<04:00, 24.00s/it]

|     SparklingTapir      | sniffgenital  |         0|         0|         0|         -|        0:31:11|
|     TranquilPanther     |   intromit    |

 90%|█████████ | 94/104 [31:13<04:00, 24.00s/it]

 1,061,791|    51,856|       174|⚠️  Sampling 300,000 from 1,061,791
  Config: STRONG (F1=0.568, max_depth=6)


 90%|█████████ | 94/104 [31:36<04:00, 24.00s/it]

      0.56|

 90%|█████████ | 94/104 [31:37<04:00, 24.00s/it]

        0:31:37|


 92%|█████████▏| 96/104 [31:38<02:33, 19.19s/it]

|     TranquilPanther     |     mount     |

 92%|█████████▏| 96/104 [31:40<02:33, 19.19s/it]

 1,061,791|    27,340|       174|⚠️  Sampling 300,000 from 1,061,791
  Config: STRONG (F1=0.454, max_depth=6)


 92%|█████████▏| 96/104 [32:06<02:33, 19.19s/it]

      0.46|

 93%|█████████▎| 97/104 [32:08<02:32, 21.76s/it]

        0:32:07|


 93%|█████████▎| 97/104 [32:08<02:32, 21.76s/it]

|     TranquilPanther     |     sniff     |

 93%|█████████▎| 97/104 [32:10<02:32, 21.76s/it]

 1,234,586|    31,352|       174|⚠️  Sampling 300,000 from 1,234,586
  Config: STRONG (F1=0.471, max_depth=6)


 93%|█████████▎| 97/104 [32:35<02:32, 21.76s/it]

      0.47|

 94%|█████████▍| 98/104 [32:37<02:22, 23.79s/it]

        0:32:36|


 94%|█████████▍| 98/104 [32:37<02:22, 23.79s/it]

|     TranquilPanther     | sniffgenital  |

 94%|█████████▍| 98/104 [32:39<02:22, 23.79s/it]

 1,183,089|    21,942|       174|⚠️  Sampling 300,000 from 1,183,089
  Config: STRONG (F1=0.494, max_depth=6)


 94%|█████████▍| 98/104 [33:05<02:22, 23.79s/it]

      0.50|

 95%|█████████▌| 99/104 [33:07<02:06, 25.30s/it]

        0:33:06|


 95%|█████████▌| 99/104 [33:07<02:06, 25.30s/it]

|      UppityFerret       |reciprocalsniff|

 95%|█████████▌| 99/104 [33:08<02:06, 25.30s/it]

   328,742|    17,384|       174|⚠️  Sampling 300,000 from 328,742
  Config: STRONG (F1=0.622, max_depth=6)


 95%|█████████▌| 99/104 [33:36<02:06, 25.30s/it]

      0.62|

 96%|█████████▌| 100/104 [33:39<01:48, 27.10s/it]

        0:33:37|


 96%|█████████▌| 100/104 [33:39<01:48, 27.10s/it]

|      UppityFerret       |     sniff     |

 97%|█████████▋| 101/104 [33:39<00:58, 19.47s/it]

         0|         0|         0|         -|        0:33:39|
|      UppityFerret       | sniffgenital  |

 97%|█████████▋| 101/104 [33:41<00:58, 19.47s/it]

   613,716|    39,944|       174|⚠️  Sampling 300,000 from 613,716
  Config: STRONG (F1=0.503, max_depth=6)


 97%|█████████▋| 101/104 [34:09<00:58, 19.47s/it]

      0.51|

 98%|█████████▊| 102/104 [34:11<00:46, 23.12s/it]

        0:34:10|


 99%|█████████▉| 103/104 [34:11<00:16, 16.38s/it]

|      UppityFerret       |   intromit    |         0|         0|         0|         -|        0:34:11|
|      UppityFerret       |     mount     |

100%|██████████| 104/104 [34:11<00:00, 19.73s/it]

         0|         0|         0|         -|        0:34:11|





In [None]:
%%writefile robustify.py

def robustify(submission: pl.DataFrame, dataset: pl.DataFrame, train_test: str = "train"):
    traintest_directory = INPUT_DIR / f"{train_test}_tracking"

    old_submission = submission.clone()
    submission = submission.filter(pl.col("start_frame") < pl.col("stop_frame"))
    if len(submission) != len(old_submission):
        print("ERROR: Dropped frames with start >= stop")

    old_submission = submission.clone()
    group_list = []
    for _, group in submission.group_by("video_id", "agent_id", "target_id"):
        group = group.sort("start_frame")
        mask = np.ones(len(group), dtype=bool)
        last_stop_frame = 0
        for i, row in enumerate(group.rows(named=True)):
            if row["start_frame"] < last_stop_frame:
                mask[i] = False
            else:
                last_stop_frame = row["stop_frame"]
        group_list.append(group.filter(pl.Series("mask", mask)))

    submission = pl.concat(group_list)

    if len(submission) != len(old_submission):
        print("ERROR: Dropped duplicate frames")

    s_list = []
    for row in dataset.rows(named=True):
        lab_id = row["lab_id"]
        video_id = row["video_id"]
        if row["behaviors_labeled"] is None:
            continue

        if video_id in submission.get_column("video_id").to_list():
            continue

        if isinstance(row["behaviors_labeled"], str):
            continue

        print(f"Video {video_id} has no predictions.")

        path = traintest_directory / f"/{lab_id}/{video_id}.parquet"
        vid = pd.read_parquet(path)

        vid_behaviors = json.loads(row["behaviors_labeled"])
        vid_behaviors = sorted(list({b.replace("'", "") for b in vid_behaviors}))
        vid_behaviors = [b.split(",") for b in vid_behaviors]
        vid_behaviors = pd.DataFrame(vid_behaviors, columns=["agent", "target", "action"])

        start_frame = vid.video_frame.min()
        stop_frame = vid.video_frame.max() + 1

        for (agent, target), actions in vid_behaviors.groupby(["agent", "target"]):
            batch_length = int(np.ceil((stop_frame - start_frame) / len(actions)))
            for i, action_row in enumerate(actions.itertuples(index=False)):
                batch_start = start_frame + i * batch_length
                batch_stop = min(batch_start + batch_length, stop_frame)
                s_list.append((video_id, agent, target, action_row["action"], batch_start, batch_stop))

    if len(s_list) > 0:
        submission = pd.concat(
            [
                submission,
                pd.DataFrame(s_list, columns=["video_id", "agent_id", "target_id", "action", "start_frame", "stop_frame"]),
            ]
        )
        print("ERROR: Filled empty videos")

    return submission