In [1]:
import sys
import os
from pathlib import Path
import shutil

# ============================================================
# Import core packages FIRST (before chdir!)
# ============================================================
import gc
import re
import ast
import itertools
import numpy as np
import polars as pl
import pandas as pd
import json
import xgboost as xgb
from tqdm.auto import tqdm

print("✅ Core packages imported")

# ============================================================
# NOW setup paths
# ============================================================
INPUT_DIR = Path("/kaggle/input/MABe-mouse-behavior-detection")
TEST_TRACKING_DIR = INPUT_DIR / "test_tracking"
MODELS_DATASET = Path("/kaggle/input/mabe-trained-models-xgboost-sampled-300k/kaggle_upload_v14/kaggle_upload")

WORKING_DIR = Path("/kaggle/working")
WORKING_DIR.mkdir(parents=True, exist_ok=True)

SELF_FEATURE_DIR = WORKING_DIR / "self_features"
PAIR_FEATURE_DIR = WORKING_DIR / "pair_features"
SELF_FEATURE_DIR.mkdir(parents=True, exist_ok=True)
PAIR_FEATURE_DIR.mkdir(parents=True, exist_ok=True)

MODELS_DIR = MODELS_DATASET / "results"

# ============================================================
# Copy helper scripts
# ============================================================
print("Copying helper scripts...")
for script in ["self_features.py", "pair_features.py", "robustify.py"]:
    src = MODELS_DATASET / script
    dst = WORKING_DIR / script
    if src.exists():
        shutil.copy(src, dst)
        print(f"  ✅ {script}")
    else:
        raise FileNotFoundError(f"Helper script not found: {src}")

# ============================================================
# Load helper scripts (now safe because polars already imported)
# ============================================================
print("Loading helper functions...")
# Add working dir to path temporarily
sys.path.insert(0, str(WORKING_DIR))

with open(WORKING_DIR / "self_features.py") as f:
    exec(f.read(), globals())
with open(WORKING_DIR / "pair_features.py") as f:
    exec(f.read(), globals())
with open(WORKING_DIR / "robustify.py") as f:
    exec(f.read(), globals())

print("✅ Helper functions loaded")

INDEX_COLS = [
    "video_id",
    "agent_mouse_id",
    "target_mouse_id",
    "video_frame",
]

BODY_PARTS = [
    "ear_left",
    "ear_right",
    "nose",
    "neck",
    "body_center",
    "lateral_left",
    "lateral_right",
    "hip_left",
    "hip_right",
    "tail_base",
    "tail_tip",
]

SELF_BEHAVIORS = [
    "biteobject",
    "climb",
    "dig",
    "exploreobject",
    "freeze",
    "genitalgroom",
    "huddle",
    "rear",
    "rest",
    "run",
    "selfgroom",
]

PAIR_BEHAVIORS = [
    "allogroom",
    "approach",
    "attack",
    "attemptmount",
    "avoid",
    "chase",
    "chaseattack",
    "defend",
    "disengage",
    "dominance",
    "dominancegroom",
    "dominancemount",
    "ejaculate",
    "escape",
    "flinch",
    "follow",
    "intromit",
    "mount",
    "reciprocalsniff",
    "shepherd",
    "sniff",
    "sniffbody",
    "sniffface",
    "sniffgenital",
    "submit",
    "tussle",
]

# ============================================================
# Helper functions
# ============================================================

def merge_close_segments(submission_df, max_gap=3, verbose=False):
    
    # Sort by grouping keys and start_frame
    df_sorted = submission_df.sort([
        "video_id", 
        "agent_id", 
        "target_id", 
        "action", 
        "start_frame"
    ])
    
    # Group by (video, agent, target, action)
    grouped = df_sorted.group_by(
        ["video_id", "agent_id", "target_id", "action"],
        maintain_order=True
    )
    
    merged_results = []
    total_merges = 0
    
    for key, group in grouped:
        segments = group.to_dicts()
        
        if len(segments) == 0:
            continue
        
        # Skip if only 1 segment (nothing to merge)
        if len(segments) == 1:
            merged_results.extend(segments)
            continue
        
        # Merge logic
        merged_segments = []
        current = segments[0].copy()
        
        for i in range(1, len(segments)):
            next_seg = segments[i]
            gap = next_seg["start_frame"] - current["stop_frame"]
            
            if gap <= max_gap:
                # Merge: extend current segment to cover next one
                current["stop_frame"] = next_seg["stop_frame"]
                total_merges += 1
            else:
                # No merge: save current and start new
                merged_segments.append(current)
                current = next_seg.copy()
        
        # Don't forget the last segment
        merged_segments.append(current)
        merged_results.extend(merged_segments)
    
    # Convert back to DataFrame
    if not merged_results:
        return submission_df.head(0)  # Empty with same schema
    
    result_df = pl.DataFrame(merged_results)
    
    # Ensure correct column order
    result_df = result_df.select([
        "video_id",
        "agent_id", 
        "target_id",
        "action",
        "start_frame",
        "stop_frame"
    ])
    
    if verbose:
        reduction_pct = 100 * (len(submission_df) - len(result_df)) / len(submission_df)
        print(f"Segment merging: {total_merges} merges, "
              f"{len(submission_df)} → {len(result_df)} segments "
              f"({reduction_pct:.1f}% reduction)")
    
    return result_df


def parse_behaviors_column(behaviors_str: str):
    """
    behaviors_labeled is stored as a Python like list of tuples.
    Use ast.literal_eval for safety instead of eval.

    Example:
      "[('mouse1','mouse2','sniff'), ('mouse2','mouse1','sniff')]"
    """
    if behaviors_str is None:
        return []
    return ast.literal_eval(behaviors_str)


def build_behavior_dataframe(test_df: pl.DataFrame) -> pl.DataFrame:
    """
    Expand behaviors_labeled into one row per (lab, video, agent, target, behavior).
    """
    behavior_df = (
        test_df
        .filter(pl.col("behaviors_labeled").is_not_null())
        .select(["lab_id", "video_id", "behaviors_labeled"])
        .with_columns(
            pl.col("behaviors_labeled")
            .map_elements(
                parse_behaviors_column,
                return_dtype=pl.List(pl.Utf8),
            )
            .alias("behaviors_labeled_list")
        )
        .explode("behaviors_labeled_list")
        .rename({"behaviors_labeled_list": "behaviors_labeled_element"})
        .with_columns(
            pl.col("behaviors_labeled_element").str.split(",").list.get(0)
            .str.replace_all("[()' ]", "")
            .alias("agent"),
            pl.col("behaviors_labeled_element").str.split(",").list.get(1)
            .str.replace_all("[()' ]", "")
            .alias("target"),
            pl.col("behaviors_labeled_element").str.split(",").list.get(2)
            .str.replace_all("[()' ]", "")
            .alias("behavior"),
        )
        .select(["lab_id", "video_id", "agent", "target", "behavior"])
    )
    return behavior_df


def extract_mouse_id(mouse_str: str) -> int:
    """
    Convert 'mouse1' -> 1, 'mouse2' -> 2, 'self' -> -1.
    """
    if mouse_str == "self":
        return -1
    m = re.search(r"mouse(\d+)", mouse_str)
    if m:
        return int(m.group(1))
    raise ValueError(f"Unexpected mouse id format: {mouse_str}")


def load_features_for_group(lab_id, video_id, agent, target):
    """
    Load per frame features for a given (lab, video, agent, target) group.
    Returns:
      index_df   - DataFrame with INDEX_COLS
      feature_df - DataFrame with feature columns only
    """
    agent_mouse_id = extract_mouse_id(agent)
    target_mouse_id = extract_mouse_id(target)

    if target == "self":
        feature_path = SELF_FEATURE_DIR / f"{video_id}.parquet"
        scan = pl.scan_parquet(feature_path).filter(
            pl.col("agent_mouse_id") == agent_mouse_id
        )
    else:
        feature_path = PAIR_FEATURE_DIR / f"{video_id}.parquet"
        scan = pl.scan_parquet(feature_path).filter(
            (pl.col("agent_mouse_id") == agent_mouse_id)
            & (pl.col("target_mouse_id") == target_mouse_id)
        )

    full_df = scan.collect()
    if full_df.height == 0:
        return full_df, full_df

    index_df = full_df.select(INDEX_COLS)
    feature_df = full_df.select(pl.exclude(INDEX_COLS))
    return index_df, feature_df


def load_models_for_behavior(lab_id: str, behavior: str):
    """
    Load all fold models and thresholds for a given (lab, behavior).
    Returns list of (model, threshold).
    """
    behavior_dir = MODELS_DIR / lab_id / behavior
    fold_dirs = sorted(behavior_dir.glob("fold_*"))
    models = []
    for fold_dir in fold_dirs:
        model_file = fold_dir / "model.json"
        thr_file = fold_dir / "threshold.txt"
        if not model_file.exists() or not thr_file.exists():
            continue
        with open(thr_file, "r") as f:
            threshold = float(f.read().strip())
        model = xgb.Booster(model_file=str(model_file))
        models.append((model, threshold))
    return models


def predict_for_group(
    lab_id: str,
    video_id: int,
    agent: str,
    target: str,
    group_behaviors: pl.DataFrame,
):
    """
    Run inference for one group of (lab_id, video_id, agent, target).

    Improvements:
      - Aggregate folds per behavior into a single score column
        (mean of thresholded probabilities).
      - Pick best behavior per frame using those aggregated scores.
    """
    index_df, feature_df = load_features_for_group(lab_id, video_id, agent, target)

    if feature_df.height == 0:
        return None

    # Create XGBoost DMatrix once per group and reuse across behaviors
    dtest = xgb.DMatrix(feature_df.to_pandas(), feature_names=feature_df.columns)

    prediction_df = index_df.clone()
    used_cols = []

    # Unique behaviors for this group
    unique_behaviors = (
        group_behaviors.select("behavior").unique()["behavior"].to_list()
    )

    for behavior in unique_behaviors:
        models = load_models_for_behavior(lab_id, behavior)
        if not models:
            # No trained model for this (lab, behavior) in the starter models
            continue

        # Aggregate over folds: mean of thresholded probabilities
        agg_scores = np.zeros(feature_df.height, dtype=np.float32)

        for model, threshold in models:
            probs = model.predict(dtest)
            labels = (probs >= threshold).astype(np.int8)
            agg_scores += probs * labels

        agg_scores /= max(len(models), 1)

        col_name = behavior
        prediction_df = prediction_df.with_columns(
            pl.Series(name=col_name, values=agg_scores)
        )
        used_cols.append(col_name)

    if not used_cols:
        return None

    # Pick best behavior per frame (over behaviors only)
    cols = used_cols

    prediction_labels_df = (
        prediction_df
        .with_columns(
            pl.struct(pl.col(cols))
            .map_elements(
                lambda row: (
                    "none"
                    if sum(row.values()) == 0
                    else cols[int(np.argmax(list(row.values())))]
                ),
                return_dtype=pl.String,
            )
            .alias("prediction")
        )
        .select(INDEX_COLS + ["prediction"])
    )

    # Convert per frame labels into time segments
    agent_mouse_id = extract_mouse_id(agent)
    target_mouse_id = extract_mouse_id(target)

    group_submission = (
        prediction_labels_df
        .filter(pl.col("prediction") != pl.col("prediction").shift(1))
        .with_columns(
            pl.col("video_frame").shift(-1).alias("stop_frame")
        )
        .filter(pl.col("prediction") != "none")
        .select(
            pl.col("video_id"),
            (pl.lit("mouse") + pl.lit(agent_mouse_id).cast(pl.Utf8)).alias("agent_id"),
            pl.when(pl.lit(target_mouse_id) == -1)
            .then(pl.lit("self"))
            .otherwise(pl.lit("mouse") + pl.lit(target_mouse_id).cast(pl.Utf8))
            .alias("target_id"),
            pl.col("prediction").alias("action"),
            pl.col("video_frame").alias("start_frame"),
            pl.col("stop_frame"),
        )
    )

    return group_submission

# ============================================================
# 1. Load metadata and build behavior table
# ============================================================
print("Loading test metadata...")
test_df = pl.read_csv(INPUT_DIR / "test.csv")

print("Building behavior table from behaviors_labeled...")
behavior_df = build_behavior_dataframe(test_df)

groups = list(
    behavior_df.group_by("lab_id", "video_id", "agent", "target", maintain_order=True)
)
print(f"Number of (lab, video, agent, target) groups: {len(groups)}")

# ============================================================
# 2. Pre compute features for all videos
# ============================================================
print("Generating self and pair features for all test videos...")

rows = test_df.rows(named=True)

for row in tqdm(rows, total=len(rows)):
    lab_id = row["lab_id"]
    video_id = row["video_id"]

    tracking_path = TEST_TRACKING_DIR / f"{lab_id}/{video_id}.parquet"
    tracking = pl.read_parquet(tracking_path)

    self_feat = make_self_features(metadata=row, tracking=tracking)
    pair_feat = make_pair_features(metadata=row, tracking=tracking)

    self_feat.write_parquet(SELF_FEATURE_DIR / f"{video_id}.parquet")
    pair_feat.write_parquet(PAIR_FEATURE_DIR / f"{video_id}.parquet")

    del self_feat, pair_feat, tracking
    gc.collect()

# ============================================================
# 3. Inference by group and segment construction
# ============================================================
print("Running inference and building group submissions...")

group_submissions = []

for (lab_id, video_id, agent, target), group in tqdm(groups, total=len(groups)):
    group_submission = predict_for_group(
        lab_id=lab_id,
        video_id=video_id,
        agent=agent,
        target=target,
        group_behaviors=group,
    )

    if group_submission is not None and group_submission.height > 0:
        group_submissions.append(group_submission)

if not group_submissions:
    raise RuntimeError(
        "No submissions were generated. "
        "Check that starter models exist under /kaggle/working/results."
    )

submission = pl.concat(group_submissions, how="vertical").sort(
    "video_id",
    "agent_id",
    "target_id",
    "action",
    "start_frame",
    "stop_frame",
)

print("Initial submission rows:", submission.height)

# ============================================================
# 4. Robustify and final clean up
# ============================================================
print("Running robustify on submission...")
submission = robustify(submission, test_df, train_test="test")

# Keep only valid intervals
submission = submission.filter(pl.col("start_frame") < pl.col("stop_frame"))

# Drop ultra short segments (likely noise)
submission = submission.with_columns(
    (pl.col("stop_frame") - pl.col("start_frame")).alias("duration")
).filter(pl.col("duration") >= 2).drop("duration")

print("Rows after robustify, validity check and duration filter:", submission.height)

submission = merge_close_segments(submission, max_gap=3, verbose=True)

print("Running robustify round 2 on submission...")
submission = robustify(submission, test_df, train_test="test")
submission = submission.filter(pl.col("start_frame") < pl.col("stop_frame"))
submission = submission.with_columns(
    (pl.col("stop_frame") - pl.col("start_frame")).alias("duration")
).filter(pl.col("duration") >= 2).drop("duration")
print("Rows after robustify, validity check and duration filter round 2:", submission.height)

# Add row_id and save as submission.csv
final_submission = submission.with_row_index("row_id")
final_path = WORKING_DIR / "submission.csv"
final_submission.write_csv(final_path)

print("Saved submission to:", final_path)
print(final_submission.head(10))

✅ Core packages imported
Copying helper scripts...
  ✅ self_features.py
  ✅ pair_features.py
  ✅ robustify.py
Loading helper functions...
✅ Helper functions loaded
Loading test metadata...
Building behavior table from behaviors_labeled...
Number of (lab, video, agent, target) groups: 16
Generating self and pair features for all test videos...


  0%|          | 0/1 [00:00<?, ?it/s]

Running inference and building group submissions...


  0%|          | 0/16 [00:00<?, ?it/s]

Initial submission rows: 2990
Running robustify on submission...
ERROR: Dropped frames with start >= stop
Rows after robustify, validity check and duration filter: 2032
Segment merging: 585 merges, 2032 → 1447 segments (28.8% reduction)
Running robustify round 2 on submission...
ERROR: Dropped duplicate frames
Rows after robustify, validity check and duration filter round 2: 1435
Saved submission to: /kaggle/working/submission.csv
shape: (10, 7)
┌────────┬───────────┬──────────┬───────────┬──────────┬─────────────┬────────────┐
│ row_id ┆ video_id  ┆ agent_id ┆ target_id ┆ action   ┆ start_frame ┆ stop_frame │
│ ---    ┆ ---       ┆ ---      ┆ ---       ┆ ---      ┆ ---         ┆ ---        │
│ u32    ┆ i64       ┆ str      ┆ str       ┆ str      ┆ i64         ┆ i64        │
╞════════╪═══════════╪══════════╪═══════════╪══════════╪═════════════╪════════════╡
│ 0      ┆ 438887472 ┆ mouse4   ┆ mouse1    ┆ avoid    ┆ 205         ┆ 210        │
│ 1      ┆ 438887472 ┆ mouse4   ┆ mouse1    ┆ 