# MABe Challenge - XGBoost training Notebook

📝 **Note:** Comments and explanations in this notebook are written in Japanese. However, I've made an effort to write clear, self-explanatory code that should be accessible to non-Japanese speakers as well.

## inference and submission notebook: 
https://www.kaggle.com/code/hutch1221/mabe-starter-inference-ja/notebook

In [1]:
# !pip install -q --no-index --find-links=/kaggle/input/mabe-package xgboost==3.1.1

In [3]:
import ast
import datetime
import gc
import itertools
import json
import re
import sys
import time
import traceback
from collections import defaultdict
from pathlib import Path

import joblib
import lightgbm as lgb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import xgboost as xgb
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedGroupKFold
from tqdm.auto import tqdm

sys.path.append("../src/utils")
from metric import score
from data import sep, show_df, glob_walk, set_seed, save_config_yaml, dict_to_namespace


In [4]:
# Constants: data/output locations, feature index columns, tracked body parts and behavior vocabularies.
# INPUT_DIR = Path("/kaggle/input/MABe-mouse-behavior-detection")
# NOTE(review): machine-specific absolute path; the commented-out line above is presumably the Kaggle equivalent.
INPUT_DIR = Path("/mnt/nfs/home/hidebu/study/MABe-Challenge---Social-Action-Recognition-in-Mice/data/raw")
TRAIN_TRACKING_DIR = INPUT_DIR / "train_tracking"
TRAIN_ANNOTATION_DIR = INPUT_DIR / "train_annotation"
TEST_TRACKING_DIR = INPUT_DIR / "test_tracking"

# Root for generated feature parquet files and per-(lab, behavior) training results.
WORKING_DIR = Path("/kaggle/working")

# Columns identifying one feature row; the training loop separates these from the model inputs.
INDEX_COLS = [
    "video_id",
    "agent_mouse_id",
    "target_mouse_id",
    "video_frame",
]

# Body parts used by the feature builders; parts missing from a video's tracking are filled with nulls there.
BODY_PARTS = [
    "ear_left",
    "ear_right",
    "nose",
    "neck",
    "body_center",
    "lateral_left",
    "lateral_right",
    "hip_left",
    "hip_right",
    "tail_base",
    "tail_tip",
]

# Behaviors whose target is the agent itself (target "self" in behaviors_labeled).
SELF_BEHAVIORS = [
    "biteobject",
    "climb",
    "dig",
    "exploreobject",
    "freeze",
    "genitalgroom",
    "huddle",
    "rear",
    "rest",
    "run",
    "selfgroom",
]

# Behaviors directed from an agent mouse at a different target mouse.
PAIR_BEHAVIORS = [
    "allogroom",
    "approach",
    "attack",
    "attemptmount",
    "avoid",
    "chase",
    "chaseattack",
    "defend",
    "disengage",
    "dominance",
    "dominancegroom",
    "dominancemount",
    "ejaculate",
    "escape",
    "flinch",
    "follow",
    "intromit",
    "mount",
    "reciprocalsniff",
    "shepherd",
    "sniff",
    "sniffbody",
    "sniffface",
    "sniffgenital",
    "submit",
    "tussle",
]

In [5]:
# Read per-video metadata: one row per video (lab, mouse attributes, fps, pixel scale, arena, labeled behaviors).
train_dataframe = pl.read_csv(INPUT_DIR / "train.csv")
show_df(train_dataframe)

(8789, 38)


lab_id,video_id,mouse1_strain,mouse1_color,mouse1_sex,mouse1_id,mouse1_age,mouse1_condition,mouse2_strain,mouse2_color,mouse2_sex,mouse2_id,mouse2_age,mouse2_condition,mouse3_strain,mouse3_color,mouse3_sex,mouse3_id,mouse3_age,mouse3_condition,mouse4_strain,mouse4_color,mouse4_sex,mouse4_id,mouse4_age,mouse4_condition,frames_per_second,video_duration_sec,pix_per_cm_approx,video_width_pix,video_height_pix,arena_width_cm,arena_height_cm,arena_shape,arena_type,body_parts_tracked,behaviors_labeled,tracking_method
str,i64,str,str,str,f64,str,str,str,str,str,f64,str,str,str,str,str,f64,str,str,str,str,str,f64,str,str,f64,f64,f64,i64,i64,f64,f64,str,str,str,str,str
"""AdaptableSnail""",44566106,"""CD-1 (ICR)""","""white""","""male""",10.0,"""8-12 weeks""","""wireless device""","""CD-1 (ICR)""","""white""","""male""",24.0,"""8-12 weeks""","""wireless device""","""CD-1 (ICR)""","""white""","""male""",38.0,"""8-12 weeks""","""wireless device""","""CD-1 (ICR)""","""white""","""male""",51.0,"""8-12 weeks""","""wireless device""",30.0,615.6,16.0,1228,1068,60.0,60.0,"""square""","""familiar""","""[""body_center"", ""ear_left"", ""e‚Ä¶","""[""mouse1,mouse2,approach"", ""mo‚Ä¶","""DeepLabCut"""
"""AdaptableSnail""",143861384,"""CD-1 (ICR)""","""white""","""male""",3.0,"""8-12 weeks""",,"""CD-1 (ICR)""","""white""","""male""",17.0,"""8-12 weeks""",,"""CD-1 (ICR)""","""white""","""male""",31.0,"""8-12 weeks""",,"""CD-1 (ICR)""","""white""","""male""",44.0,"""8-12 weeks""",,25.0,3599.0,9.7,968,608,60.0,60.0,"""square""","""familiar""","""[""body_center"", ""ear_left"", ""e‚Ä¶","""[""mouse1,mouse2,approach"", ""mo‚Ä¶","""DeepLabCut"""
"""AdaptableSnail""",209576908,"""CD-1 (ICR)""","""white""","""male""",7.0,"""8-12 weeks""",,"""CD-1 (ICR)""","""white""","""male""",21.0,"""8-12 weeks""",,"""CD-1 (ICR)""","""white""","""male""",35.0,"""8-12 weeks""",,"""CD-1 (ICR)""","""white""","""male""",48.0,"""8-12 weeks""",,30.0,615.2,16.0,1266,1100,60.0,60.0,"""square""","""familiar""","""[""body_center"", ""ear_left"", ""e‚Ä¶","""[""mouse1,mouse2,approach"", ""mo‚Ä¶","""DeepLabCut"""


In [6]:
# Preprocess behavior labels into one row per (lab, video, agent, target, behavior).
# `behaviors_labeled` holds a list literal such as '["mouse1,mouse2,approach", ...]'
# whose elements are "agent,target,behavior". ast.literal_eval parses the literal
# like eval did, but cannot execute arbitrary code embedded in the CSV.
train_behavior_dataframe = (
    train_dataframe.filter(pl.col("behaviors_labeled").is_not_null())
    .select(
        pl.col("lab_id"),
        pl.col("video_id"),
        pl.col("behaviors_labeled")
        .map_elements(ast.literal_eval, return_dtype=pl.List(pl.Utf8))
        .alias("behaviors_labeled_list"),
    )
    .explode("behaviors_labeled_list")
    .rename({"behaviors_labeled_list": "behaviors_labeled_element"})
    .select(
        pl.col("lab_id"),
        pl.col("video_id"),
        # Split "agent,target,behavior" once per field and strip any stray single quotes.
        *[
            pl.col("behaviors_labeled_element").str.split(",").list[position].str.replace_all("'", "").alias(field)
            for position, field in enumerate(["agent", "target", "behavior"])
        ],
    )
)

# Separate single-mouse ("self") behaviors from agent->target ("pair") behaviors.
train_self_behavior_dataframe = train_behavior_dataframe.filter(pl.col("behavior").is_in(SELF_BEHAVIORS))
train_pair_behavior_dataframe = train_behavior_dataframe.filter(pl.col("behavior").is_in(PAIR_BEHAVIORS))
sep("train_self_behavior_dataframe")
show_df(train_self_behavior_dataframe)
sep("train_pair_behavior_dataframe")
show_df(train_pair_behavior_dataframe)

train_self_behavior_dataframe
(594, 5)


lab_id,video_id,agent,target,behavior
str,i64,str,str,str
"""AdaptableSnail""",44566106,"""mouse1""","""self""","""rear"""
"""AdaptableSnail""",44566106,"""mouse2""","""self""","""rear"""
"""AdaptableSnail""",44566106,"""mouse3""","""self""","""rear"""


train_pair_behavior_dataframe
(4325, 5)


lab_id,video_id,agent,target,behavior
str,i64,str,str,str
"""AdaptableSnail""",44566106,"""mouse1""","""mouse2""","""approach"""
"""AdaptableSnail""",44566106,"""mouse1""","""mouse2""","""attack"""
"""AdaptableSnail""",44566106,"""mouse1""","""mouse2""","""avoid"""


## ÁâπÂæ¥ÈáèÂä†Â∑•(self)
- agent„ÅÆ‰Ωì„ÅÆÂêÑÈÉ®‰ΩçÈñìË∑ùÈõ¢(cm)

  ‰∏ªË¶Å„Å™11ÈÉ®‰Ωç(ÂÆöÊï∞: BODY_PARTS)„ÅÆÁµÑ„ÅøÂêà„Çè„Åõ„ÅÆË∑ùÈõ¢
  
- ÈÉ®‰Ωç„ÅÆÊé®ÂÆöÈÄüÂ∫¶(cm/s)

  ear_left, ear_right, tail_base „ÅÆ 500, 1000, 2000, 3000msÈñì„ÅÆÊé®ÂÆöÈÄüÂ∫¶
  
- ‰º∏Èï∑Â∫¶

  (nose-tail_base„ÅÆË∑ùÈõ¢) / (ear_left-ear_right„ÅÆË∑ùÈõ¢)

- ‰ΩìËßíÂ∫¶(deg)
  
  nose-body_center, body_center-tail_base„ÄÄ„Åå„Å™„Åô„Éô„ÇØ„Éà„É´„ÅÆËßíÂ∫¶

In [7]:
%%writefile self_features.py

def make_self_features(
    metadata: dict,
    tracking: pl.DataFrame,
) -> pl.DataFrame:
    """Build per-frame single-mouse ("self") features for every mouse in one video.

    Args:
        metadata: One named row of ``train.csv`` for the video. Reads ``video_id``,
            ``pix_per_cm_approx``, ``frames_per_second`` and ``mouse1..4_strain``
            (a non-null strain is taken to mean that mouse slot is in use).
        tracking: Long-format tracking with columns ``video_frame``, ``mouse_id``,
            ``bodypart``, ``x`` and ``y``. Coordinates are presumably pixels, since
            distances are divided by ``pix_per_cm_approx``.

    Returns:
        One row per (agent mouse, frame) over ``[min(video_frame), max(video_frame)]``
        with ``target_mouse_id == -1``. Frames where the mouse has no tracking row
        get null features (left join onto the dense frame grid).
    """

    def body_parts_distance(body_part_1, body_part_2):
        """Distance (cm) between two of the agent's body parts."""
        assert body_part_1 in BODY_PARTS
        assert body_part_2 in BODY_PARTS
        return (
            (pl.col(f"agent_x_{body_part_1}") - pl.col(f"agent_x_{body_part_2}")).pow(2)
            + (pl.col(f"agent_y_{body_part_1}") - pl.col(f"agent_y_{body_part_2}")).pow(2)
        ).sqrt() / metadata["pix_per_cm_approx"]

    def body_part_speed(body_part, period_ms):
        """Frame-to-frame speed (cm/s) of a body part, smoothed by a centered rolling mean over ~period_ms."""
        assert body_part in BODY_PARTS
        window_frames = max(1, int(round(period_ms * metadata["frames_per_second"] / 1000.0)))
        return (
            ((pl.col(f"agent_x_{body_part}").diff()).pow(2) + (pl.col(f"agent_y_{body_part}").diff()).pow(2)).sqrt()
            / metadata["pix_per_cm_approx"]
            * metadata["frames_per_second"]
        ).rolling_mean(window_size=window_frames, center=True, min_samples=1)

    def elongation():
        """Unitless elongation ratio: nose-tail_base distance / ear_left-ear_right distance."""
        d1 = body_parts_distance("nose", "tail_base")
        d2 = body_parts_distance("ear_left", "ear_right")
        return d1 / (d2 + 1e-06)

    def body_angle():
        """Cosine of the angle between (nose - body_center) and (tail_base - body_center).

        NOTE: despite the feature name, this is a cosine (roughly in [-1, 1]), not degrees.
        """
        v1x = pl.col("agent_x_nose") - pl.col("agent_x_body_center")
        v1y = pl.col("agent_y_nose") - pl.col("agent_y_body_center")
        v2x = pl.col("agent_x_tail_base") - pl.col("agent_x_body_center")
        v2y = pl.col("agent_y_tail_base") - pl.col("agent_y_body_center")
        return (v1x * v2x + v1y * v2y) / ((v1x.pow(2) + v1y.pow(2)).sqrt() * (v2x.pow(2) + v2y.pow(2)).sqrt() + 1e-06)

    n_mice = sum(metadata[f"mouse{slot}_strain"] is not None for slot in range(1, 5))
    start_frame = tracking.select(pl.col("video_frame").min()).item()
    end_frame = tracking.select(pl.col("video_frame").max()).item()

    # Wide layout: one row per (frame, mouse), columns x_<part> / y_<part>, sorted so that
    # each per-mouse slice is in frame order for diff()/rolling_* below.
    pivot = tracking.pivot(
        on=["bodypart"],
        index=["video_frame", "mouse_id"],
        values=["x", "y"],
    ).sort(["mouse_id", "video_frame"])
    pivot_trackings = {mouse_id: pivot.filter(pl.col("mouse_id") == mouse_id) for mouse_id in range(1, n_mice + 1)}

    speed_specs = list(itertools.product(["ear_left", "ear_right", "tail_base"], [500, 1000, 2000, 3000]))

    result = []
    for agent_mouse_id in range(1, n_mice + 1):
        # Dense frame grid so every frame gets a row even when the mouse is untracked.
        result_element = pl.DataFrame(
            {
                "video_id": metadata["video_id"],
                "agent_mouse_id": agent_mouse_id,
                "target_mouse_id": -1,
                "video_frame": pl.arange(start_frame, end_frame + 1, eager=True),
            },
            schema={
                "video_id": pl.Int32,
                "agent_mouse_id": pl.Int8,
                "target_mouse_id": pl.Int8,
                "video_frame": pl.Int32,
            },
        )

        # Prefix tracking columns with "agent_" (e.g. x_nose -> agent_x_nose).
        agent_pivot = pivot_trackings[agent_mouse_id].select(
            pl.col("video_frame"),
            pl.exclude("video_frame").name.prefix("agent_"),
        )
        # Body parts absent from this video's tracking become all-null columns so every expression resolves.
        columns = agent_pivot.columns
        agent_pivot = agent_pivot.with_columns(
            *[pl.lit(None).cast(pl.Float32).alias(f"agent_x_{bp}") for bp in BODY_PARTS if f"agent_x_{bp}" not in columns],
            *[pl.lit(None).cast(pl.Float32).alias(f"agent_y_{bp}") for bp in BODY_PARTS if f"agent_y_{bp}" not in columns],
        )

        features = agent_pivot.with_columns(
            # Give the join keys exactly the dtypes of result_element instead of relying on implicit key casting.
            pl.col("video_frame").cast(pl.Int32),
            pl.lit(agent_mouse_id, dtype=pl.Int8).alias("agent_mouse_id"),
            pl.lit(-1, dtype=pl.Int8).alias("target_mouse_id"),
        ).select(
            pl.col("video_frame"),
            pl.col("agent_mouse_id"),
            pl.col("target_mouse_id"),
            *[
                body_parts_distance(body_part_1, body_part_2).alias(f"aa__{body_part_1}__{body_part_2}__distance")
                for body_part_1, body_part_2 in itertools.combinations(BODY_PARTS, 2)
            ],
            *[
                body_part_speed(body_part, period_ms).alias(f"agent__{body_part}__speed_{period_ms}ms")
                for body_part, period_ms in speed_specs
            ],
            elongation().alias("agent__elongation"),
            body_angle().alias("agent__body_angle"),
        )

        result_element = result_element.join(
            features,
            on=["video_frame", "agent_mouse_id", "target_mouse_id"],
            how="left",
        )
        result.append(result_element)

    return pl.concat(result, how="vertical")

Overwriting self_features.py


## ÁâπÂæ¥ÈáèÂä†Â∑•(pair)
- agent-target „ÅÆ‰Ωì„ÅÆÂêÑÈÉ®‰ΩçÈñìË∑ùÈõ¢(cm)

  agent-target „ÅÆ‰∏ªË¶Å„Å™11ÈÉ®‰Ωç(ÂÆöÊï∞: BODY_PARTS)„ÅÆÁµÑ„ÅøÂêà„Çè„Åõ„ÅÆË∑ùÈõ¢
  
- agent, target „ÅÆÈÉ®‰Ωç„ÅÆÊé®ÂÆöÈÄüÂ∫¶(cm/s)

  ear_left, ear_right, tail_base „ÅÆ 500, 1000, 2000, 3000msÈñì„ÅÆÊé®ÂÆöÈÄüÂ∫¶
  
- agent, target „ÅÆ‰º∏Èï∑Â∫¶

  (nose-tail_base„ÅÆË∑ùÈõ¢) / (ear_left-ear_right„ÅÆË∑ùÈõ¢)

- agent, target „ÅÆ‰ΩìËßíÂ∫¶(deg)
  
  nose-body_center, body_center-tail_base„ÄÄ„Åå„Å™„Åô„Éô„ÇØ„Éà„É´„ÅÆËßíÂ∫¶

In [8]:
%%writefile pair_features.py

def make_pair_features(
    metadata: dict,
    tracking: pl.DataFrame,
) -> pl.DataFrame:
    """Build per-frame features for every ordered (agent, target) mouse pair in one video.

    Args:
        metadata: One named row of ``train.csv`` for the video. Reads ``video_id``,
            ``pix_per_cm_approx``, ``frames_per_second`` and ``mouse1..4_strain``
            (a non-null strain is taken to mean that mouse slot is in use).
        tracking: Long-format tracking with columns ``video_frame``, ``mouse_id``,
            ``bodypart``, ``x`` and ``y``. Coordinates are presumably pixels, since
            distances are divided by ``pix_per_cm_approx``.

    Returns:
        One row per (agent, target, frame) over ``[min(video_frame), max(video_frame)]``.
        Features are computed only on frames where both mice have a tracking row; other
        frames get nulls (left join onto the dense frame grid).
    """

    def body_parts_distance(agent_or_target_1, body_part_1, agent_or_target_2, body_part_2):
        """Distance (cm) between a body part of one mouse and a body part of the same or the other mouse."""
        assert agent_or_target_1 == "agent" or agent_or_target_1 == "target"
        assert agent_or_target_2 == "agent" or agent_or_target_2 == "target"
        assert body_part_1 in BODY_PARTS
        assert body_part_2 in BODY_PARTS
        return (
            (pl.col(f"{agent_or_target_1}_x_{body_part_1}") - pl.col(f"{agent_or_target_2}_x_{body_part_2}")).pow(2)
            + (pl.col(f"{agent_or_target_1}_y_{body_part_1}") - pl.col(f"{agent_or_target_2}_y_{body_part_2}")).pow(2)
        ).sqrt() / metadata["pix_per_cm_approx"]

    def body_part_speed(agent_or_target, body_part, period_ms):
        """Frame-to-frame speed (cm/s) of a body part, smoothed by a centered rolling mean over ~period_ms.

        min_samples=1 matches make_self_features: without it a single null in the window
        (the leading diff() null, an untracked frame, or a window edge) nulls the whole window.
        """
        assert agent_or_target == "agent" or agent_or_target == "target"
        assert body_part in BODY_PARTS
        window_frames = max(1, int(round(period_ms * metadata["frames_per_second"] / 1000.0)))
        return (
            (
                (pl.col(f"{agent_or_target}_x_{body_part}").diff()).pow(2)
                + (pl.col(f"{agent_or_target}_y_{body_part}").diff()).pow(2)
            ).sqrt()
            / metadata["pix_per_cm_approx"]
            * metadata["frames_per_second"]
        ).rolling_mean(window_size=window_frames, center=True, min_samples=1)

    def elongation(agent_or_target):
        """Unitless elongation ratio: nose-tail_base distance / ear_left-ear_right distance."""
        assert agent_or_target == "agent" or agent_or_target == "target"
        d1 = body_parts_distance(agent_or_target, "nose", agent_or_target, "tail_base")
        d2 = body_parts_distance(agent_or_target, "ear_left", agent_or_target, "ear_right")
        return d1 / (d2 + 1e-06)

    def body_angle(agent_or_target):
        """Cosine of the angle between (nose - body_center) and (tail_base - body_center).

        NOTE: despite the feature name, this is a cosine (roughly in [-1, 1]), not degrees.
        """
        assert agent_or_target == "agent" or agent_or_target == "target"
        v1x = pl.col(f"{agent_or_target}_x_nose") - pl.col(f"{agent_or_target}_x_body_center")
        v1y = pl.col(f"{agent_or_target}_y_nose") - pl.col(f"{agent_or_target}_y_body_center")
        v2x = pl.col(f"{agent_or_target}_x_tail_base") - pl.col(f"{agent_or_target}_x_body_center")
        v2y = pl.col(f"{agent_or_target}_y_tail_base") - pl.col(f"{agent_or_target}_y_body_center")
        return (v1x * v2x + v1y * v2y) / ((v1x.pow(2) + v1y.pow(2)).sqrt() * (v2x.pow(2) + v2y.pow(2)).sqrt() + 1e-06)

    def body_center_distance_rolling_agg(agg, period_ms):
        """Centered rolling aggregate of the agent-target body_center distance over ~period_ms.

        NOTE(review): not referenced by the feature selection below; wire it in or remove it.
        """
        assert agg in ["mean", "std", "var", "min", "max"]
        expr = body_parts_distance("agent", "body_center", "target", "body_center")
        window_frames = max(1, int(round(period_ms * metadata["frames_per_second"] / 1000.0)))

        if agg == "mean":
            return expr.rolling_mean(window_size=window_frames, center=True, min_samples=1)
        elif agg == "std":
            return expr.rolling_std(window_size=window_frames, center=True, min_samples=1)
        elif agg == "var":
            return expr.rolling_var(window_size=window_frames, center=True, min_samples=1)
        elif agg == "min":
            return expr.rolling_min(window_size=window_frames, center=True, min_samples=1)
        elif agg == "max":
            return expr.rolling_max(window_size=window_frames, center=True, min_samples=1)
        else:
            raise ValueError()

    n_mice = sum(metadata[f"mouse{slot}_strain"] is not None for slot in range(1, 5))
    start_frame = tracking.select(pl.col("video_frame").min()).item()
    end_frame = tracking.select(pl.col("video_frame").max()).item()

    # Wide layout: one row per (frame, mouse), columns x_<part> / y_<part>.
    pivot = tracking.pivot(
        on=["bodypart"],
        index=["video_frame", "mouse_id"],
        values=["x", "y"],
    ).sort(["mouse_id", "video_frame"])
    pivot_trackings = {mouse_id: pivot.filter(pl.col("mouse_id") == mouse_id) for mouse_id in range(1, n_mice + 1)}

    speed_specs = list(itertools.product(["ear_left", "ear_right", "tail_base"], [500, 1000, 2000, 3000]))

    result = []
    for agent_mouse_id, target_mouse_id in itertools.permutations(range(1, n_mice + 1), 2):
        # Dense frame grid so every frame gets a row even when a mouse is untracked.
        result_element = pl.DataFrame(
            {
                "video_id": metadata["video_id"],
                "agent_mouse_id": agent_mouse_id,
                "target_mouse_id": target_mouse_id,
                "video_frame": pl.arange(start_frame, end_frame + 1, eager=True),
            },
            schema={
                "video_id": pl.Int32,
                "agent_mouse_id": pl.Int8,
                "target_mouse_id": pl.Int8,
                "video_frame": pl.Int32,
            },
        )

        merged_pivot = (
            pivot_trackings[agent_mouse_id]
            .select(
                pl.col("video_frame"),
                pl.exclude("video_frame").name.prefix("agent_"),
            )
            .join(
                pivot_trackings[target_mouse_id].select(
                    pl.col("video_frame"),
                    pl.exclude("video_frame").name.prefix("target_"),
                ),
                on="video_frame",
                how="inner",
            )
            # Polars does not guarantee row order after a join; diff()/rolling_* below need time order.
            .sort("video_frame")
        )
        # Body parts absent from this video's tracking become all-null columns so every expression resolves.
        columns = merged_pivot.columns
        merged_pivot = merged_pivot.with_columns(
            *[pl.lit(None).cast(pl.Float32).alias(f"agent_x_{bp}") for bp in BODY_PARTS if f"agent_x_{bp}" not in columns],
            *[pl.lit(None).cast(pl.Float32).alias(f"agent_y_{bp}") for bp in BODY_PARTS if f"agent_y_{bp}" not in columns],
            *[pl.lit(None).cast(pl.Float32).alias(f"target_x_{bp}") for bp in BODY_PARTS if f"target_x_{bp}" not in columns],
            *[pl.lit(None).cast(pl.Float32).alias(f"target_y_{bp}") for bp in BODY_PARTS if f"target_y_{bp}" not in columns],
        )

        features = merged_pivot.with_columns(
            # Give the join keys exactly the dtypes of result_element instead of relying on implicit key casting.
            pl.col("video_frame").cast(pl.Int32),
            pl.lit(agent_mouse_id, dtype=pl.Int8).alias("agent_mouse_id"),
            pl.lit(target_mouse_id, dtype=pl.Int8).alias("target_mouse_id"),
        ).select(
            pl.col("video_frame"),
            pl.col("agent_mouse_id"),
            pl.col("target_mouse_id"),
            *[
                body_parts_distance("agent", agent_body_part, "target", target_body_part).alias(
                    f"at__{agent_body_part}__{target_body_part}__distance"
                )
                for agent_body_part, target_body_part in itertools.product(BODY_PARTS, repeat=2)
            ],
            *[
                body_part_speed("agent", body_part, period_ms).alias(f"agent__{body_part}__speed_{period_ms}ms")
                for body_part, period_ms in speed_specs
            ],
            *[
                body_part_speed("target", body_part, period_ms).alias(f"target__{body_part}__speed_{period_ms}ms")
                for body_part, period_ms in speed_specs
            ],
            elongation("agent").alias("agent__elongation"),
            elongation("target").alias("target__elongation"),
            body_angle("agent").alias("agent__body_angle"),
            body_angle("target").alias("target__body_angle"),
        )

        result_element = result_element.join(
            features,
            on=["video_frame", "agent_mouse_id", "target_mouse_id"],
            how="left",
        )
        result.append(result_element)

    return pl.concat(result, how="vertical")

Overwriting pair_features.py


In [9]:
%run -i self_features.py
%run -i pair_features.py

def process_video(row):
    """Compute self and pair features for one training video and store them as parquet.

    ``row`` is a named row of ``train.csv``; it is passed unchanged to the feature
    builders as their metadata. Output goes to
    ``WORKING_DIR/{self_features,pair_features}/<video_id>.parquet``.
    Returns the processed ``video_id``.
    """
    lab_id, video_id = row["lab_id"], row["video_id"]
    tracking = pl.read_parquet(TRAIN_TRACKING_DIR / f"{lab_id}/{video_id}.parquet")

    # Build both frames first, then write them (self before pair), as before.
    outputs = {
        "self_features": make_self_features(metadata=row, tracking=tracking),
        "pair_features": make_pair_features(metadata=row, tracking=tracking),
    }
    for subdir, frame in outputs.items():
        frame.write_parquet(WORKING_DIR / subdir / f"{video_id}.parquet")

    return video_id


# Create the output directories, then extract features for every video with labeled behaviors in parallel.
for feature_subdir in ("self_features", "pair_features"):
    (WORKING_DIR / feature_subdir).mkdir(exist_ok=True, parents=True)

# DataFrame.rows() already returns a list of dicts.
labeled_rows = train_dataframe.filter(pl.col("behaviors_labeled").is_not_null()).rows(named=True)
processed_video_ids = joblib.Parallel(n_jobs=-1, verbose=5)(
    joblib.delayed(process_video)(labeled_row) for labeled_row in labeled_rows
)

print(f"Processed {len(processed_video_ids)} videos successfully")

del labeled_rows, processed_video_ids, feature_subdir
gc.collect()

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 48 concurrent workers.
[Parallel(n_jobs=-1)]: Done  66 tasks      | elapsed:   11.2s
Exception ignored in: <function ResourceTracker.__del__ at 0x7fd86b430f40>
Traceback (most recent call last):
  File "/root/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/lib/python3.12/multiprocessing/resource_tracker.py", line 77, in __del__
  File "/root/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/lib/python3.12/multiprocessing/resource_tracker.py", line 86, in _stop
  File "/root/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/lib/python3.12/multiprocessing/resource_tracker.py", line 111, in _stop_locked
ChildProcessError: [Errno 10] No child processes
Exception ignored in: <function ResourceTracker.__del__ at 0x7fdb2e8d4f40>
Traceback (most recent call last):
  File "/root/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/lib/python3.12/multiprocessing/resource_tracker.py", line 77, in __del__
  File "/root/.l

Processed 848 videos successfully


<function ResourceTracker.__del__ at 0x7f85ab814f40>
Traceback (most recent call last):
  File "/root/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/lib/python3.12/multiprocessing/resource_tracker.py", line 77, in __del__
  File "/root/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/lib/python3.12/multiprocessing/resource_tracker.py", line 86, in _stop
  File "/root/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/lib/python3.12/multiprocessing/resource_tracker.py", line 111, in _stop_locked
ChildProcessError: [Errno 10] No child processes


46

In [10]:
def tune_threshold(oof_action, y_action, *, thresholds=None):
    """Return the probability threshold that maximizes binary F1.

    F1 is computed from counts as 2*TP / (#predicted positive + #actual positive),
    and is 0 when that denominator is 0 (same as sklearn's ``zero_division=0``).
    This avoids running a full sklearn ``f1_score`` validation pass for each of the
    ~200 grid points on large validation sets.

    Args:
        oof_action: 1-D predicted probabilities.
        y_action: 1-D binary (0/1) ground-truth labels aligned with ``oof_action``
            (any array-like, e.g. a numpy array or polars Series).
        thresholds: Candidate thresholds; defaults to 0.0 to 1.0 in steps of 0.005.

    Returns:
        The first candidate threshold reaching the maximum F1.
    """
    if thresholds is None:
        thresholds = np.arange(0, 1.005, 0.005)
    thresholds = np.asarray(thresholds)
    predictions = np.asarray(oof_action)
    truth = np.asarray(y_action).astype(bool)
    n_actual_positive = int(np.count_nonzero(truth))

    scores = np.zeros(len(thresholds), dtype=np.float64)
    for i, th in enumerate(thresholds):
        predicted = predictions >= th
        true_positive = np.count_nonzero(predicted & truth)
        # 2*TP + FP + FN == (#predicted positive) + (#actual positive)
        denominator = np.count_nonzero(predicted) + n_actual_positive
        scores[i] = 2.0 * true_positive / denominator if denominator else 0.0

    best_idx = np.argmax(scores)
    return thresholds[best_idx]

## 学習・検証

lab, behavior 毎にモデル(XGBoost)を作る

クロスバリデーション(3fold)でf1スコアを計算しながら学習する

In [11]:
def _write_untrained_result(result_dir: Path, indices: pl.DataFrame, n_rows: int) -> float:
    """Persist F1 = 0.0 and all-zero OOF predictions (fold -1) for a group that is not trained."""
    with open(result_dir / "f1.txt", "w") as f:
        f.write("0.0\n")
    oof_prediction_dataframe = indices.with_columns(
        pl.Series("fold", [-1] * n_rows, dtype=pl.Int8),  # -1: row never used for validation
        pl.Series("prediction", [0.0] * n_rows, dtype=pl.Float32),
        pl.Series("predicted_label", [0] * n_rows, dtype=pl.Int8),
    )
    oof_prediction_dataframe.write_parquet(result_dir / "oof_predictions.parquet")
    return 0.0


def _save_fold_plots(model, evals_result: dict, result_dir_fold: Path) -> None:
    """Save the top-20 gain feature-importance plot and the logloss learning curve for one fold."""
    # xgb.plot_importance raises ValueError when the booster has no splits
    # (get_score() is empty), so skip only that plot in this case.
    if model.get_score(importance_type="gain"):
        xgb.plot_importance(model, max_num_features=20, importance_type="gain", values_format="{v:.2f}")
        plt.tight_layout()
        plt.savefig(result_dir_fold / "feature_importance.png")
        plt.close()

    # XGBoost's evals_result has the {dataset: {metric: [values]}} layout that lgb.plot_metric plots.
    lgb.plot_metric(evals_result, metric="logloss")
    plt.tight_layout()
    plt.savefig(result_dir_fold / "metric.png")
    plt.close()


def train_validate(lab_id: str, behavior: str, indices: pl.DataFrame, features: pl.DataFrame, labels: pl.Series):
    """Train one XGBoost classifier per grouped CV fold for a (lab, behavior) and return the OOF F1.

    Artifacts are written under ``WORKING_DIR / "results" / lab_id / behavior``:
    ``f1.txt`` and ``oof_predictions.parquet``, plus per fold ``fold_{k}/model.json``,
    ``threshold.txt``, ``metric.png`` and (when the booster has splits) ``feature_importance.png``.

    Args:
        lab_id: Lab identifier (used in the output path).
        behavior: Behavior name (used in the output path).
        indices: Row identifiers aligned with ``features``/``labels``. Its ``video_id``
            column is the CV group, so a video never spans train and validation.
        features: Model inputs, one row per sample.
        labels: Binary (0/1) targets aligned with ``features``.

    Returns:
        Out-of-fold F1. Returns 0.0 without training when there are no positives or
        fewer than two distinct videos (grouped CV needs at least two groups).
    """
    result_dir = WORKING_DIR / "results" / lab_id / behavior
    result_dir.mkdir(exist_ok=True, parents=True)
    n_rows = len(labels)

    # Nothing to learn without positive examples.
    if labels.sum() == 0:
        return _write_untrained_result(result_dir, indices, n_rows)

    groups = indices.get_column("video_id")
    # StratifiedGroupKFold raises when n_splits exceeds the number of groups, so use
    # up to 3 folds and skip groups that cannot be split at all.
    n_splits = min(3, groups.n_unique())
    if n_splits < 2:
        return _write_untrained_result(result_dir, indices, n_rows)

    folds = np.full(n_rows, -1, dtype=np.int8)  # fold in which each row was validated
    oof_predictions = np.zeros(n_rows, dtype=np.float32)  # predicted probabilities
    oof_prediction_labels = np.zeros(n_rows, dtype=np.int8)  # thresholded 0/1 predictions

    # Stratified by label while keeping every video inside a single fold.
    splitter = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
    for fold, (train_idx, valid_idx) in enumerate(splitter.split(X=features, y=labels, groups=groups)):
        result_dir_fold = result_dir / f"fold_{fold}"
        result_dir_fold.mkdir(exist_ok=True, parents=True)

        X_train, y_train = features[train_idx], labels[train_idx]
        X_valid, y_valid = features[valid_idx], labels[valid_idx]

        # Weight positives by #negatives / #positives to counter class imbalance;
        # fall back to 1.0 if this training split happens to contain no positives.
        n_positive_train = int(y_train.sum())
        scale_pos_weight = (len(y_train) - n_positive_train) / n_positive_train if n_positive_train > 0 else 1.0

        params = {
            "objective": "binary:logistic",
            "eval_metric": "logloss",
            "device": "cpu",
            "tree_method": "hist",
            "learning_rate": 0.05,
            "max_depth": 6,
            "min_child_weight": 5,
            "subsample": 0.8,
            "colsample_bytree": 0.8,
            "scale_pos_weight": scale_pos_weight,
            "max_bin": 64,
            "seed": 42,
        }

        # QuantileDMatrix pre-bins the training data (its max_bin should match params["max_bin"]);
        # the validation set uses a regular DMatrix.
        dtrain = xgb.QuantileDMatrix(X_train, label=y_train, feature_names=features.columns, max_bin=64)
        dvalid = xgb.DMatrix(X_valid, label=y_valid, feature_names=features.columns)

        evals_result = {}
        # Stop after 10 rounds without validation-logloss improvement and keep the best iteration.
        early_stopping_callback = xgb.callback.EarlyStopping(
            rounds=10,
            metric_name="logloss",
            data_name="valid",
            maximize=False,
            save_best=True,
        )
        model = xgb.train(
            params,
            dtrain=dtrain,
            num_boost_round=250,
            evals=[(dtrain, "train"), (dvalid, "valid")],
            callbacks=[early_stopping_callback],
            evals_result=evals_result,
            verbose_eval=0,
        )

        fold_predictions = model.predict(dvalid)
        # Per-fold threshold that maximizes F1 on this fold's validation rows.
        threshold = tune_threshold(fold_predictions, y_valid)

        folds[valid_idx] = fold
        oof_predictions[valid_idx] = fold_predictions
        oof_prediction_labels[valid_idx] = (fold_predictions >= threshold).astype(np.int8)

        model.save_model(result_dir_fold / "model.json")
        with open(result_dir_fold / "threshold.txt", "w") as f:
            f.write(f"{threshold}\n")

        _save_fold_plots(model, evals_result, result_dir_fold)

        # Release this fold's matrices and model before the next fold allocates its own.
        del X_train, X_valid, y_train, y_valid, dtrain, dvalid, model
        gc.collect()

    oof_prediction_dataframe = indices.with_columns(
        pl.Series("fold", folds, dtype=pl.Int8),
        pl.Series("prediction", oof_predictions, dtype=pl.Float32),
        pl.Series("predicted_label", oof_prediction_labels, dtype=pl.Int8),
    )

    f1 = f1_score(labels, oof_prediction_labels, zero_division=0)
    with open(result_dir / "f1.txt", "w") as f:
        f.write(f"{f1}\n")

    oof_prediction_dataframe.write_parquet(result_dir / "oof_predictions.parquet")

    return f1

## (self) lab-behaviorÊØé„Å´Â≠¶Áøí-Ê§úË®º„ÇíË°å„ÅÜ

In [None]:
# Train/validate one model per (lab_id, behavior) over the self-behavior labels.
groups = train_self_behavior_dataframe.group_by("lab_id", "behavior", maintain_order=True)
# Count groups directly instead of materializing every group's sub-frame via list(groups).
total_groups = train_self_behavior_dataframe.n_unique(subset=["lab_id", "behavior"])
start_time = time.perf_counter()

for idx, ((lab_id, behavior), group) in tqdm(enumerate(groups), total=total_groups):
    if idx == 0:
        tqdm.write(
            f"|{'LAB':^25}|{'BEHAVIOR':^15}|{'SAMPLES':^10}|{'POSITIVE':^10}|{'FEATURES':^10}|{'F1':^10}|{'ELAPSED TIME':^15}|",
            end="\n",
        )

    tqdm.write(f"|{lab_id:^25}|{behavior:^15}|", end="")
    index_list = []
    feature_list = []
    label_list = []

    for row in group.rows(named=True):
        video_id = row["video_id"]
        agent = row["agent"]

        agent_mouse_id = int(re.search(r"mouse(\d+)", agent).group(1))

        # Collect once and split, so the index and feature rows come from the same
        # materialization (two separate collects are not documented to return rows in
        # the same order) and the parquet file is scanned only once.
        data = (
            pl.scan_parquet(WORKING_DIR / "self_features" / f"{video_id}.parquet")
            .filter(pl.col("agent_mouse_id") == agent_mouse_id)
            .collect(engine="streaming")
        )
        index = data.select(INDEX_COLS)
        feature = data.select(pl.exclude(INDEX_COLS))

        # Annotations of this behavior performed by this agent (empty when the file is missing).
        annotation_path = TRAIN_ANNOTATION_DIR / lab_id / f"{video_id}.parquet"
        if annotation_path.exists():
            annotation = (
                pl.scan_parquet(annotation_path)
                .filter((pl.col("action") == behavior) & (pl.col("agent_id") == agent_mouse_id))
                .collect()
            )
        else:
            annotation = pl.DataFrame(
                schema={
                    "agent_id": pl.Int8,
                    "target_id": pl.Int8,
                    "action": str,
                    "start_frame": pl.Int16,
                    "stop_frame": pl.Int16,
                }
            )

        # Frames in [start_frame, stop_frame) are positives (stop_frame treated as exclusive).
        label_frames = set()
        for annotation_row in annotation.rows(named=True):
            label_frames.update(range(annotation_row["start_frame"], annotation_row["stop_frame"]))
        label = index.select(pl.col("video_frame").is_in(label_frames).cast(pl.Int8).alias("label"))

        # NOTE(review): videos where the behavior never occurs are dropped, so the model
        # never sees negative-only videos; confirm this is intended.
        if label.get_column("label").sum() == 0:
            continue

        index_list.append(index)
        feature_list.append(feature)
        label_list.append(label.get_column("label"))

    if not index_list:
        elapsed_time = datetime.timedelta(seconds=int(time.perf_counter() - start_time))
        tqdm.write(f"{0:>10,}|{0:>10,}|{0:>10,}|{'-':>10}|{str(elapsed_time):>15}|", end="\n")
        continue

    indices = pl.concat(index_list, how="vertical")
    features = pl.concat(feature_list, how="vertical")
    labels = pl.concat(label_list, how="vertical")

    del index_list, feature_list, label_list
    gc.collect()

    tqdm.write(f"{len(indices):>10,}|{labels.sum():>10,}|{len(features.columns):>10,}|", end="")

    f1 = train_validate(lab_id, behavior, indices, features, labels)
    tqdm.write(f"{f1:>10.2f}|", end="")

    elapsed_time = datetime.timedelta(seconds=int(time.perf_counter() - start_time))
    tqdm.write(f"{str(elapsed_time):>15}|", end="\n")

    gc.collect()

  0%|          | 0/27 [00:00<?, ?it/s]

|           LAB           |   BEHAVIOR    | SAMPLES  | POSITIVE | FEATURES |    F1    | ELAPSED TIME  |
|     AdaptableSnail      |     rear      |   660,348|    85,313|        69|

Exception ignored in: <function ResourceTracker.__del__ at 0x7f29a4690f40>
Traceback (most recent call last):
  File "/root/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/lib/python3.12/multiprocessing/resource_tracker.py", line 77, in __del__
  File "/root/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/lib/python3.12/multiprocessing/resource_tracker.py", line 86, in _stop
  File "/root/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/lib/python3.12/multiprocessing/resource_tracker.py", line 111, in _stop_locked
ChildProcessError: [Errno 10] No child processes
Exception ignored in: <function ResourceTracker.__del__ at 0x7f0e293e8f40>
Traceback (most recent call last):
  File "/root/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/lib/python3.12/multiprocessing/resource_tracker.py", line 77, in __del__
  File "/root/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/lib/python3.12/multiprocessing/resource_tracker.py", line 86, in _stop
  File "/root/

      0.62|        0:13:18|
|         CRIM13          |     rear      |   179,132|    12,042|        69|      0.36|        0:27:10|
|         CRIM13          |   selfgroom   |   205,533|    14,472|        69|      0.36|        0:39:01|
|      CalMS21_task1      | genitalgroom  |   102,445|     6,270|        69|      0.66|        0:45:57|
|       ElegantMink       |     rear      |         0|         0|         0|         -|        0:45:58|
|       ElegantMink       |   selfgroom   |         0|         0|         0|         -|        0:45:59|
|       GroovyShrew       |     rear      |   899,280|    50,768|        69|      0.53|        1:00:45|
|       GroovyShrew       |     rest      |   530,886|    87,573|        69|      0.68|        1:08:31|
|       GroovyShrew       |   selfgroom   |   877,773|    22,893|        69|      0.31|        1:21:08|
|       GroovyShrew       |     climb     |   295,943|     8,647|        69|      0.37|        1:30:57|
|       GroovyShrew       |      dig

## (pair) lab-behaviorÊØé„Å´Â≠¶Áøí-Ê§úË®º„ÇíË°å„ÅÜ

In [None]:
# Train and validate one model per (lab_id, behavior) for pair (agent -> target) behaviors.
# For every video/agent/target row of a group, the precomputed pair features are loaded together
# with frame-level binary labels built from the annotation file, and the stacked data is
# passed to `train_validate`. Progress is printed as a fixed-width table via `tqdm.write`.
groups = train_pair_behavior_dataframe.group_by("lab_id", "behavior", maintain_order=True)
# Count the distinct (lab_id, behavior) keys directly: len(list(groups)) would materialize
# every group's sub-DataFrame only to count them.
total_groups = train_pair_behavior_dataframe.n_unique(subset=["lab_id", "behavior"])
mouse_id_pattern = re.compile(r"mouse(\d+)")  # "mouse1" -> group(1) == "1"
start_time = time.perf_counter()

tqdm.write(
    f"|{'LAB':^25}|{'BEHAVIOR':^15}|{'SAMPLES':^10}|{'POSITIVE':^10}|{'FEATURES':^10}|{'F1':^10}|{'ELAPSED TIME':^15}|",
    end="\n",
)

for (lab_id, behavior), group in tqdm(groups, total=total_groups):
    tqdm.write(f"|{lab_id:^25}|{behavior:^15}|", end="")
    index_list = []
    feature_list = []
    label_list = []

    for row in group.rows(named=True):
        video_id = row["video_id"]
        agent_mouse_id = int(mouse_id_pattern.search(row["agent"]).group(1))
        target_mouse_id = int(mouse_id_pattern.search(row["target"]).group(1))

        # Lazily scan only this (agent, target) pair's rows of the per-video pair-feature file.
        data = pl.scan_parquet(WORKING_DIR / "pair_features" / f"{video_id}.parquet").filter(
            (pl.col("agent_mouse_id") == agent_mouse_id) & (pl.col("target_mouse_id") == target_mouse_id)
        )
        index = data.select(INDEX_COLS).collect(engine="streaming")
        feature = data.select(pl.exclude(INDEX_COLS)).collect(engine="streaming")

        # Annotated intervals of this behavior from this agent to this target; an empty frame with
        # the annotation schema when the video has no annotation file.
        annotation_path = TRAIN_ANNOTATION_DIR / lab_id / f"{video_id}.parquet"
        if annotation_path.exists():
            annotation = (
                pl.scan_parquet(annotation_path)
                .filter(
                    (pl.col("action") == behavior)
                    & (pl.col("agent_id") == agent_mouse_id)
                    & (pl.col("target_id") == target_mouse_id)
                )
                .collect()
            )
        else:
            annotation = pl.DataFrame(
                schema={
                    "agent_id": pl.Int8,
                    "target_id": pl.Int8,
                    "action": str,
                    "start_frame": pl.Int16,
                    "stop_frame": pl.Int16,
                }
            )

        # Label is 1 for frames inside any [start_frame, stop_frame) interval, otherwise 0.
        label_frames = set()
        for annotation_row in annotation.rows(named=True):
            label_frames.update(range(annotation_row["start_frame"], annotation_row["stop_frame"]))
        label = index.select(pl.col("video_frame").is_in(label_frames).cast(pl.Int8).alias("label"))

        # NOTE(review): pairs without a single positive frame are dropped from training entirely,
        # so the model never sees negative-only pairs -- confirm this is intended.
        if label.get_column("label").sum() == 0:
            continue

        index_list.append(index)
        feature_list.append(feature)
        label_list.append(label.get_column("label"))

    if not index_list:
        # No usable pair for this (lab, behavior): print an empty row and move on.
        elapsed_time = datetime.timedelta(seconds=int(time.perf_counter() - start_time))
        tqdm.write(f"{0:>10,}|{0:>10,}|{0:>10,}|{'-':>10}|{str(elapsed_time):>15}|", end="\n")
        continue

    indices = pl.concat(index_list, how="vertical")
    features = pl.concat(feature_list, how="vertical")
    labels = pl.concat(label_list, how="vertical")

    del index_list, feature_list, label_list
    gc.collect()

    tqdm.write(f"{len(indices):>10,}|{labels.sum():>10,}|{len(features.columns):>10,}|", end="")

    f1 = train_validate(lab_id, behavior, indices, features, labels)
    tqdm.write(f"{f1:>10.2f}|", end="")

    # Elapsed time is cumulative since the start of the whole loop.
    elapsed_time = datetime.timedelta(seconds=int(time.perf_counter() - start_time))
    tqdm.write(f"{str(elapsed_time):>15}|", end="\n")

    gc.collect()

In [None]:
%%writefile robustify.py

def robustify(submission: pl.DataFrame, dataset: pl.DataFrame, train_test: str = "train"):
    """Make a submission well-formed.

    Steps:
      1. Drop intervals with ``start_frame >= stop_frame``.
      2. Within each (video_id, agent_id, target_id), visit intervals in ``start_frame``
         order and drop any that starts before the stop of the last kept interval.
      3. For every video in ``dataset`` whose ``behaviors_labeled`` is a JSON string but
         which has no prediction at all, split the video's frame range into consecutive
         equal-length placeholder intervals, one per labeled action of each (agent, target).

    Args:
        submission: Polars frame with video_id, agent_id, target_id, action,
            start_frame, stop_frame.
        dataset: Polars metadata frame with lab_id, video_id, behaviors_labeled.
        train_test: Selects the tracking directory ``INPUT_DIR / f"{train_test}_tracking"``
            used to find each unpredicted video's frame range.

    Returns:
        A Polars DataFrame with the submission columns.
    """
    traintest_directory = INPUT_DIR / f"{train_test}_tracking"

    # 1. Drop empty or inverted intervals.
    n_before = len(submission)
    submission = submission.filter(pl.col("start_frame") < pl.col("stop_frame"))
    if len(submission) != n_before:
        print("ERROR: Dropped frames with start >= stop")

    # 2. Drop intervals overlapping an earlier kept interval of the same (video, agent, target).
    n_before = len(submission)
    group_list = []
    for _, group in submission.group_by("video_id", "agent_id", "target_id"):
        group = group.sort("start_frame")
        mask = np.ones(len(group), dtype=bool)
        last_stop_frame = 0
        for i, (start_frame, stop_frame) in enumerate(group.select("start_frame", "stop_frame").iter_rows()):
            if start_frame < last_stop_frame:
                mask[i] = False
            else:
                last_stop_frame = stop_frame
        group_list.append(group.filter(pl.Series("mask", mask)))
    if group_list:  # pl.concat raises on an empty list (empty submission)
        submission = pl.concat(group_list)
    if len(submission) != n_before:
        print("ERROR: Dropped duplicate frames")

    # 3. Fill videos that have labeled behaviors but no prediction at all.
    predicted_videos = set(submission.get_column("video_id").to_list())  # O(1) membership per row
    s_list = []
    for row in dataset.rows(named=True):
        lab_id = row["lab_id"]
        video_id = row["video_id"]
        behaviors_labeled = row["behaviors_labeled"]
        # Only a behaviors_labeled JSON string tells us what to fill with (also skips None).
        if not isinstance(behaviors_labeled, str):
            continue
        if video_id in predicted_videos:
            continue

        print(f"Video {video_id} has no predictions.")

        # Join path parts without a leading "/": an absolute segment would make pathlib
        # discard traintest_directory.
        path = traintest_directory / lab_id / f"{video_id}.parquet"
        first_frame, last_frame = (
            pl.scan_parquet(path)
            .select(
                pl.col("video_frame").min().alias("first_frame"),
                pl.col("video_frame").max().alias("last_frame"),
            )
            .collect()
            .row(0)
        )
        start_frame = int(first_frame)
        stop_frame = int(last_frame) + 1

        # Each (quote-stripped, de-duplicated) entry is split on "," into agent, target, action.
        actions_by_pair = defaultdict(list)
        for entry in sorted({b.replace("'", "") for b in json.loads(behaviors_labeled)}):
            agent, target, action = entry.split(",")
            actions_by_pair[(agent, target)].append(action)

        for (agent, target), actions in actions_by_pair.items():
            batch_length = int(np.ceil((stop_frame - start_frame) / len(actions)))
            for i, action in enumerate(actions):
                batch_start = start_frame + i * batch_length
                if batch_start >= stop_frame:
                    # Rounding batch_length up can exhaust the range early; never emit start >= stop.
                    break
                batch_stop = min(batch_start + batch_length, stop_frame)
                s_list.append((video_id, agent, target, action, batch_start, batch_stop))

    if s_list:
        filler = pl.DataFrame(
            s_list,
            schema=["video_id", "agent_id", "target_id", "action", "start_frame", "stop_frame"],
            orient="row",
        )
        # Stay in Polars (callers use Polars methods); diagonal_relaxed matches columns by
        # name and upcasts mismatched integer dtypes.
        submission = pl.concat([submission, filler], how="diagonal_relaxed")
        print("ERROR: Filled empty videos")

    return submission

## 検証データに対する予測値を集計 (aggregate out-of-fold predictions on the validation data)

In [None]:
# „Ç∞„É´„Éº„Éó„Åî„Å®„ÅÆOut-of-Fold‰∫àÊ∏¨ÁµêÊûú„Çí‰øùÂ≠ò„Åô„Çã„É™„Çπ„Éà
group_oof_predictions = []

# „Éá„Éº„Çø„Çí lab_id, video_id, agent, target „Åß„Ç∞„É´„Éº„ÉóÂåñ
# maintain_order=True „ÅßÂÖÉ„ÅÆÈ†ÜÂ∫è„Çí‰øùÊåÅ
groups = train_behavior_dataframe.group_by("lab_id", "video_id", "agent", "target", maintain_order=True)

# ÂêÑ„Ç∞„É´„Éº„Éó„Å´ÂØæ„Åó„Å¶Âá¶ÁêÜ„ÇíÂÆüË°åÔºàÈÄ≤Êçó„Éê„Éº„ÇíË°®Á§∫Ôºâ
for (lab_id, video_id, agent, target), group in tqdm(groups, total=len(list(groups))):
    # agentÔºàË°åÂãï‰∏ª‰ΩìÔºâ„Åã„Çâ„Éû„Ç¶„ÇπID„ÇíÊäΩÂá∫
    # ‰æã: "mouse1" ‚Üí 1
    agent_mouse_id = int(re.search(r"mouse(\d+)", agent).group(1))
    
    # targetÔºàË°åÂãïÂØæË±°Ôºâ„Åã„Çâ„Éû„Ç¶„ÇπID„ÇíÊäΩÂá∫
    # "self"ÔºàËá™ÂàÜËá™Ë∫´Ôºâ„ÅÆÂ†¥Âêà„ÅØ -1„ÄÅ„Åù„Çå‰ª•Â§ñ„ÅØ„Éû„Ç¶„ÇπID„ÇíÂèñÂæó
    target_mouse_id = -1 if target == "self" else int(re.search(r"mouse(\d+)", target).group(1))

    # „Åì„ÅÆ„Ç∞„É´„Éº„Éó„ÅÆÂêÑË°åÂãï„ÅÆ‰∫àÊ∏¨ÁµêÊûú„Çí‰øùÂ≠ò„Åô„Çã„É™„Çπ„Éà
    prediction_dataframe_list = []

    # „Ç∞„É´„Éº„ÉóÂÜÖ„ÅÆÂêÑË°åÔºàÂêÑË°åÂãïÔºâ„ÇíÂá¶ÁêÜ
    for row in group.rows(named=True):
        behavior = row["behavior"]  # Ë°åÂãï„ÅÆÁ®ÆÈ°ûÔºà‰æã: "grooming", "sniffing"„Å™„Å©Ôºâ

        # „Åì„ÅÆË°åÂãï„ÅÆOOF‰∫àÊ∏¨ÁµêÊûú„Éï„Ç°„Ç§„É´„ÅÆ„Éë„Çπ„ÇíÊßãÁØâ
        oof_path = WORKING_DIR / "results" / lab_id / behavior / "oof_predictions.parquet"
        
        # „Éï„Ç°„Ç§„É´„ÅåÂ≠òÂú®„Åó„Å™„ÅÑÂ†¥Âêà„ÅØ„Çπ„Ç≠„ÉÉ„Éó
        if not oof_path.exists():
            continue

        # ‰∫àÊ∏¨ÁµêÊûú„ÇíË™≠„ÅøËæº„Åø„ÄÅË©≤ÂΩì„Åô„Çãvideo_id„ÄÅagent„ÄÅtarget„Åß„Éï„Ç£„É´„Çø„É™„É≥„Ç∞
        prediction = (
            pl.scan_parquet(oof_path)  # ÈÅÖÂª∂Ë™≠„ÅøËæº„ÅøÔºà„É°„É¢„É™ÂäπÁéáÁöÑÔºâ
            .filter(
                (pl.col("video_id") == video_id)  # ÂãïÁîªID„Åå‰∏ÄËá¥
                & (pl.col("agent_mouse_id") == agent_mouse_id)  # Ë°åÂãï‰∏ª‰Ωì„Åå‰∏ÄËá¥
                & (pl.col("target_mouse_id") == target_mouse_id)  # Ë°åÂãïÂØæË±°„Åå‰∏ÄËá¥
            )
            .select(
                *INDEX_COLS,  # „Ç§„É≥„Éá„ÉÉ„ÇØ„ÇπÂàó„ÇíÈÅ∏Êäû
                # ‰∫àÊ∏¨Á¢∫Áéá„Å®‰∫àÊ∏¨„É©„Éô„É´„ÇíÊéõ„ÅëÂêà„Çè„Åõ„Å¶„ÄÅ„Åì„ÅÆË°åÂãï„ÅÆ„Çπ„Ç≥„Ç¢„ÇíË®àÁÆó
                # ‰∫àÊ∏¨„É©„Éô„É´„Åå0„ÅÆÂ†¥Âêà„ÅØ„Çπ„Ç≥„Ç¢„ÇÇ0„Å´„Å™„Çã
                (pl.col("prediction") * pl.col("predicted_label")).alias(behavior)
            )
            .collect()  # ÂÆüÈöõ„Å´„Éá„Éº„Çø„ÇíË™≠„ÅøËæº„Çì„ÅßÂÆüË°å
        )

        # „Éï„Ç£„É´„ÇøÂæå„Å´Ë°å„Åå„Å™„ÅÑÂ†¥ÂêàÔºàË©≤ÂΩì„Éá„Éº„Çø„Åå„Å™„ÅÑÂ†¥ÂêàÔºâ„ÅØ„Çπ„Ç≠„ÉÉ„Éó
        if len(prediction) == 0:
            continue

        # „Åì„ÅÆË°åÂãï„ÅÆ‰∫àÊ∏¨ÁµêÊûú„Çí„É™„Çπ„Éà„Å´ËøΩÂä†
        prediction_dataframe_list.append(prediction)

    # „Åì„ÅÆ„Ç∞„É´„Éº„Éó„Åß‰∫àÊ∏¨ÁµêÊûú„Åå1„Å§„ÇÇ„Å™„ÅÑÂ†¥Âêà„ÅØ„Çπ„Ç≠„ÉÉ„Éó
    if not prediction_dataframe_list:
        continue

    # Ë§áÊï∞„ÅÆË°åÂãï„ÅÆ‰∫àÊ∏¨ÁµêÊûú„ÇíÊ®™ÊñπÂêë„Å´ÁµêÂêà
    # how="align"„Åß„ÄÅ„Ç§„É≥„Éá„ÉÉ„ÇØ„ÇπÂàó„ÇíÂü∫Ê∫ñ„Å´Êï¥Âàó„Åó„Å¶ÁµêÂêà
    prediction_dataframe = pl.concat(prediction_dataframe_list, how="align")

    # „Ç§„É≥„Éá„ÉÉ„ÇØ„ÇπÂàó‰ª•Â§ñ„ÅÆÂàóÂêçÔºàÂêÑË°åÂãïÂêçÔºâ„ÇíÂèñÂæó
    cols = prediction_dataframe.select(pl.exclude(INDEX_COLS)).columns
    
    # ÂêÑ„Éï„É¨„Éº„É†„ÅßÊúÄ„ÇÇÁ¢∫‰ø°Â∫¶„ÅÆÈ´ò„ÅÑË°åÂãï„ÇíÈÅ∏Êäû
    prediction_labels_dataframe = prediction_dataframe.with_columns(
        pl.struct(pl.exclude(INDEX_COLS))  # ÂÖ®Ë°åÂãï„ÅÆ„Çπ„Ç≥„Ç¢„ÇíÊßãÈÄ†‰Ωì„Å´„Åæ„Å®„ÇÅ„Çã
        .map_elements(
            # ÂêÑË°å„Å´ÂØæ„Åó„Å¶ÂÆüË°å„Åô„ÇãÈñ¢Êï∞
            lambda row: "none" if sum(row.values()) == 0  # ÂÖ®„Çπ„Ç≥„Ç¢„Åå0„Å™„Çâ"none"
                       else (cols[np.argmax(list(row.values()))]),  # ÊúÄÂ§ß„Çπ„Ç≥„Ç¢„ÅÆË°åÂãï„ÇíÈÅ∏Êäû
            return_dtype=pl.String,
        )
        .alias("prediction")  # Êñ∞„Åó„ÅÑÂàóÂêç„Çí"prediction"„Å®„Åô„Çã
    ).select(INDEX_COLS + ["prediction"])  # „Ç§„É≥„Éá„ÉÉ„ÇØ„ÇπÂàó„Å®‰∫àÊ∏¨Âàó„ÅÆ„Åø„ÇíÈÅ∏Êäû

    # ÈÄ£Á∂ö„Åô„ÇãÂêå„ÅòË°åÂãï„Çí„Åæ„Å®„ÇÅ„Å¶„ÄÅË°åÂãï„ÅÆÈñãÂßã„Å®ÁµÇ‰∫Ü„Éï„É¨„Éº„É†„ÇíÁâπÂÆö
    group_oof_prediction = (
        prediction_labels_dataframe
        .filter((pl.col("prediction") != pl.col("prediction").shift(1)))  # Ââç„ÅÆË°å„Å®Áï∞„Å™„ÇãË°åÂãï„ÅÆ„Åø„ÇíÊäΩÂá∫ÔºàÂ¢ÉÁïåÁÇπÔºâ
        .with_columns(pl.col("video_frame").shift(-1).alias("stop_frame"))  # Ê¨°„ÅÆÂ¢ÉÁïåÁÇπ„ÇíÁµÇ‰∫Ü„Éï„É¨„Éº„É†„Å®„Åô„Çã
        .filter(pl.col("prediction") != "none")  # "none"ÔºàË°åÂãï„Å™„ÅóÔºâ„ÇíÈô§Â§ñ
        .select(
            pl.col("video_id"),  # ÂãïÁîªID
            ("mouse" + pl.col("agent_mouse_id").cast(str)).alias("agent_id"),  # "mouse1"ÂΩ¢Âºè„Å´Â§âÊèõ
            # target_mouse_id„Åå-1„Å™„Çâ"self"„ÄÅ„Åù„Çå‰ª•Â§ñ„ÅØ"mouse2"ÂΩ¢Âºè„Å´Â§âÊèõ
            pl.when(pl.col("target_mouse_id") == -1)
            .then(pl.lit("self"))
            .otherwise("mouse" + pl.col("target_mouse_id").cast(str))
            .alias("target_id"),
            pl.col("prediction").alias("action"),  # Ë°åÂãïÂêç
            pl.col("video_frame").alias("start_frame"),  # ÈñãÂßã„Éï„É¨„Éº„É†
            pl.col("stop_frame"),  # ÁµÇ‰∫Ü„Éï„É¨„Éº„É†
        )
    )

    # „Åì„ÅÆ„Ç∞„É´„Éº„Éó„ÅÆ‰∫àÊ∏¨ÁµêÊûú„Çí„É™„Çπ„Éà„Å´ËøΩÂä†
    group_oof_predictions.append(group_oof_prediction)

%run -i robustify.py

oof_predictions = pl.concat(group_oof_predictions, how="vertical")
oof_predictions = robustify(oof_predictions, train_dataframe, train_test="train")
oof_predictions.with_row_index("row_id").write_csv(WORKING_DIR / "oof_predictions.csv")

## 検証データによるスコアを計算する (compute the score on the validation data)

In [None]:

def _frames_by_key(intervals):
    """Map (video_id, agent_id, target_id, action) to the set of frames it covers.

    `intervals` is a pandas DataFrame with those columns plus start_frame / stop_frame; each
    row covers frames in [start_frame, stop_frame). The id parts are stringified so that
    annotation keys and submission keys are compared as text.
    """
    frames = defaultdict(set)
    columns = ["video_id", "agent_id", "target_id", "action", "start_frame", "stop_frame"]
    for video_id, agent_id, target_id, action, start, stop in intervals[columns].itertuples(index=False):
        frames[(str(video_id), str(agent_id), str(target_id), action)].update(range(int(start), int(stop)))
    return frames


def compute_validation_metrics(submission, verbose=True):
    """Print the overall and per-action F1 of `submission` against the training annotations.

    The ground truth is rebuilt from ``INPUT_DIR / "train.csv"`` and the per-video annotation
    files. MABe22 labs and videos without an annotation file are skipped. The ground truth is
    then restricted to the videos present in `submission`.

    Args:
        submission: pandas DataFrame with row_id, video_id, agent_id, target_id, action,
            start_frame, stop_frame.
        verbose: If True, print a shortened error message and traceback when the metrics
            cannot be computed.

    Returns:
        None; results are printed. Errors raised while scoring are caught and reported (if
        `verbose`) instead of propagated.
    """
    dataset = pl.read_csv(INPUT_DIR / "train.csv").to_pandas()

    solution = []
    for _, row in dataset.iterrows():
        lab_id = row["lab_id"]
        if lab_id.startswith("MABe22"):
            continue

        video_id = row["video_id"]
        path = TRAIN_ANNOTATION_DIR / lab_id / f"{video_id}.parquet"
        try:
            annot = pd.read_parquet(path)
        except FileNotFoundError:
            continue

        annot["lab_id"] = lab_id
        annot["video_id"] = video_id
        annot["behaviors_labeled"] = row["behaviors_labeled"]
        # A target equal to the agent means a self behavior.
        annot["target_id"] = np.where(
            annot.target_id != annot.agent_id, annot["target_id"].apply(lambda s: f"mouse{s}"), "self"
        )
        annot["agent_id"] = annot["agent_id"].apply(lambda s: f"mouse{s}")
        solution.append(annot)

    if not solution:  # pd.concat raises on an empty list
        if verbose:
            print("\nWarning: Could not compute validation metrics: no annotation files found")
        return
    solution = pd.concat(solution)

    try:
        submission_single = submission[submission["target_id"] == "self"]
        submission_pair = submission[submission["target_id"] != "self"]

        # Evaluate only the videos the submission actually covers.
        submission_videos = set(submission["video_id"].unique())
        solution = solution[solution["video_id"].isin(submission_videos)]
        if len(solution) == 0:
            if verbose:
                print("\nWarning: Could not compute validation metrics: no annotated video in submission")
            return

        overall_f1 = score(solution, submission, "row_id", beta=1.0)
        print(f"\n{'=' * 60}")
        print("PERFORMANCE METRICS")
        print(f"{'=' * 60}")
        print(f"Overall F1 Score: {overall_f1:.4f}")
        print(f"Total predictions: {len(submission)}")
        print(f"  - Single behaviors: {len(submission_single)}")
        print(f"  - Pair behaviors: {len(submission_pair)}")

        # Frame-level F1 per (video, agent, target, action) key, accumulated per action and mode.
        action_stats = defaultdict(lambda: {"single": {"count": 0, "f1": 0.0}, "pair": {"count": 0, "f1": 0.0}})
        for _, lab_solution in solution.groupby("lab_id"):
            lab_submission = submission[submission["video_id"].isin(set(lab_solution["video_id"]))]
            label_frames = _frames_by_key(lab_solution)
            prediction_frames = _frames_by_key(lab_submission)

            for key in label_frames.keys() | prediction_frames.keys():
                _, _, target_id, action = key
                # Decide the mode from the parsed target, not a substring test on the whole key.
                mode = "single" if target_id == "self" else "pair"

                label_set = label_frames.get(key, set())
                pred_set = prediction_frames.get(key, set())
                tp = len(pred_set & label_set)
                fn = len(label_set - pred_set)
                fp = len(pred_set - label_set)

                if tp + fn + fp > 0:
                    action_stats[action][mode]["count"] += 1
                    action_stats[action][mode]["f1"] += 2 * tp / (2 * tp + fn + fp)  # F1 (beta = 1)

        print("\nPer-Action Performance Summary:")
        print(f"{'-' * 60}")
        print(f"{'Action':<20} {'Mode':<10} {'Count':<10} {'Avg F1':<10}")
        print(f"{'-' * 60}")

        for action in sorted(action_stats.keys()):
            for mode in ["single", "pair"]:
                stats = action_stats[action][mode]
                if stats["count"] > 0:
                    avg_f1 = stats["f1"] / stats["count"]
                    print(f"{action:<20} {mode:<10} {stats['count']:<10} {avg_f1:<10.4f}")

        # Mean over actions of each action's average F1, per mode.
        single_f1s = [s["single"]["f1"] / s["single"]["count"] for s in action_stats.values() if s["single"]["count"] > 0]
        pair_f1s = [s["pair"]["f1"] / s["pair"]["count"] for s in action_stats.values() if s["pair"]["count"] > 0]
        if single_f1s:
            print(f"\nSingle behaviors: {len(single_f1s)} actions, Avg F1: {np.mean(single_f1s):.4f}")
        if pair_f1s:
            print(f"Pair behaviors: {len(pair_f1s)} actions, Avg F1: {np.mean(pair_f1s):.4f}")

        print(f"{'=' * 60}\n")

    except Exception as e:
        # Diagnostics are best-effort: report (when verbose) rather than abort the notebook.
        if verbose:
            error_msg = str(e)
            if len(error_msg) > 200:
                error_msg = error_msg[:200] + "..."
            print(f"\nWarning: Could not compute validation metrics: {error_msg}")
            print(f"Traceback: {traceback.format_exc()[:300]}")


compute_validation_metrics(submission=pd.read_csv(WORKING_DIR / "oof_predictions.csv"))