In [1]:
# Import libraries
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import numpy as np
import pandas as pd
pd.set_option('display.max_columns', None)
import polars as pl
from tqdm import tqdm

In [2]:
# Declare dataset configs
# Paths are assembled with os.path.join so the separator matches the host OS.
# On Windows this yields the same backslash-separated strings as before; on
# POSIX systems a literal backslash is not a path separator, so hard-coded
# r'..\datasets\...' strings would not resolve there.
_datasets_root = os.path.join('..', 'datasets')
dataset_configs = {
    'train_metadata_path': os.path.join(_datasets_root, 'train.csv'),
    'train_tracking_dir_path': os.path.join(_datasets_root, 'train_tracking'),
    'train_annotation_dir_path': os.path.join(_datasets_root, 'train_annotation'),
    'test_metadata_path': os.path.join(_datasets_root, 'test.csv'),
    'test_tracking_dir_path': os.path.join(_datasets_root, 'test_tracking'),
    'sample_submission_file_path': os.path.join(_datasets_root, 'sample_submission.csv')
}

In [3]:
# Load the training metadata CSV into a DataFrame and preview its first rows
train_metadata_csv = dataset_configs['train_metadata_path']
metadata_df = pd.read_csv(train_metadata_csv)
metadata_df.head(5)

Unnamed: 0,lab_id,video_id,mouse1_strain,mouse1_color,mouse1_sex,mouse1_id,mouse1_age,mouse1_condition,mouse2_strain,mouse2_color,mouse2_sex,mouse2_id,mouse2_age,mouse2_condition,mouse3_strain,mouse3_color,mouse3_sex,mouse3_id,mouse3_age,mouse3_condition,mouse4_strain,mouse4_color,mouse4_sex,mouse4_id,mouse4_age,mouse4_condition,frames_per_second,video_duration_sec,pix_per_cm_approx,video_width_pix,video_height_pix,arena_width_cm,arena_height_cm,arena_shape,arena_type,body_parts_tracked,behaviors_labeled,tracking_method
0,AdaptableSnail,44566106,CD-1 (ICR),white,male,10.0,8-12 weeks,wireless device,CD-1 (ICR),white,male,24.0,8-12 weeks,wireless device,CD-1 (ICR),white,male,38.0,8-12 weeks,wireless device,CD-1 (ICR),white,male,51.0,8-12 weeks,wireless device,30.0,615.6,16.0,1228,1068,60.0,60.0,square,familiar,"[""body_center"", ""ear_left"", ""ear_right"", ""head...","[""mouse1,mouse2,approach"", ""mouse1,mouse2,atta...",DeepLabCut
1,AdaptableSnail,143861384,CD-1 (ICR),white,male,3.0,8-12 weeks,,CD-1 (ICR),white,male,17.0,8-12 weeks,,CD-1 (ICR),white,male,31.0,8-12 weeks,,CD-1 (ICR),white,male,44.0,8-12 weeks,,25.0,3599.0,9.7,968,608,60.0,60.0,square,familiar,"[""body_center"", ""ear_left"", ""ear_right"", ""late...","[""mouse1,mouse2,approach"", ""mouse1,mouse2,atta...",DeepLabCut
2,AdaptableSnail,209576908,CD-1 (ICR),white,male,7.0,8-12 weeks,,CD-1 (ICR),white,male,21.0,8-12 weeks,,CD-1 (ICR),white,male,35.0,8-12 weeks,,CD-1 (ICR),white,male,48.0,8-12 weeks,,30.0,615.2,16.0,1266,1100,60.0,60.0,square,familiar,"[""body_center"", ""ear_left"", ""ear_right"", ""late...","[""mouse1,mouse2,approach"", ""mouse1,mouse2,atta...",DeepLabCut
3,AdaptableSnail,278643799,CD-1 (ICR),white,male,11.0,8-12 weeks,wireless device,CD-1 (ICR),white,male,25.0,8-12 weeks,wireless device,CD-1 (ICR),white,male,39.0,8-12 weeks,wireless device,,,,,,,30.0,619.7,16.0,1224,1100,60.0,60.0,square,familiar,"[""body_center"", ""ear_left"", ""ear_right"", ""head...","[""mouse1,mouse2,approach"", ""mouse1,mouse2,atta...",DeepLabCut
4,AdaptableSnail,351967631,CD-1 (ICR),white,male,14.0,8-12 weeks,,CD-1 (ICR),white,male,28.0,8-12 weeks,,CD-1 (ICR),white,male,42.0,8-12 weeks,,,,,,8-12 weeks,,30.0,602.6,16.0,1204,1068,60.0,60.0,square,familiar,"[""body_center"", ""ear_left"", ""ear_right"", ""late...","[""mouse1,mouse2,approach"", ""mouse1,mouse2,atta...",DeepLabCut


In [4]:
# Get list of lab_id and its video_ids
# Only videos that have BOTH a tracking parquet and an annotation parquet on
# disk are kept; each kept video records the frame size from the metadata row.
lab_video_map = {}
for idx, row in tqdm(metadata_df.iterrows(), desc="Mapping lab_id to video_ids", total=metadata_df.shape[0]):
    lab_id, video_id = row['lab_id'], row['video_id']
    parquet_name = f"{video_id}.parquet"
    has_both_files = (
        os.path.exists(os.path.join(dataset_configs['train_tracking_dir_path'], lab_id, parquet_name))
        and os.path.exists(os.path.join(dataset_configs['train_annotation_dir_path'], lab_id, parquet_name))
    )
    if not has_both_files:
        continue
    video_data = {
        'video_id': video_id,
        'video_width_pix': row['video_width_pix'],
        'video_height_pix': row['video_height_pix']
    }
    lab_video_map.setdefault(lab_id, []).append(video_data)

# The output directory is created up front so a fresh checkout (or a
# Restart & Run All) does not fail with FileNotFoundError on open().
os.makedirs('dumps', exist_ok=True)
with open(os.path.join('dumps', 'valid_lab_video_map.json'), 'w') as f:
    json.dump(lab_video_map, f, indent=4)

Mapping lab_id to video_ids: 100%|██████████| 8789/8789 [00:01<00:00, 6071.16it/s]


In [5]:
# Enlist bodyparts
bodyparts = {
    'body_center',
    'ear_left',
    'ear_right',
    'forepaw_left',
    'forepaw_right',
    'head',
    'headpiece_bottombackleft',
    'headpiece_bottombackright',
    'headpiece_bottomfrontleft',
    'headpiece_bottomfrontright',
    'headpiece_topbackleft',
    'headpiece_topbackright',
    'headpiece_topfrontleft',
    'headpiece_topfrontright',
    'hindpaw_left',
    'hindpaw_right',
    'hip_left',
    'hip_right',
    'lateral_left',
    'lateral_right',
    'neck',
    'nose',
    'spine_1',
    'spine_2',
    'tail_base',
    'tail_middle_1',
    'tail_middle_2',
    'tail_midpoint',
    'tail_tip'
}

In [6]:
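# NOTE: legacy row-by-row implementation, left commented out. The active
# process_one_video / filter_data definitions are in the next cell. This
# version also reads video_data['arena_width_cm'] and
# video_data['arena_height_cm'], keys that the lab_video_map built above
# does not contain.
# # Load data for a particular agent and target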
# def process_lab_videos(lab_id, video_data_list, agent_id, target_id, save_dir_path):
#     """Process all videos for a single lab_id. Used by thread workers."""
#     results_for_lab = []
#     for video_data in video_data_list:
#         results = []
#         video_id = video_data['video_id']
#         tracking_file_path = os.path.join(
#             dataset_configs['train_tracking_dir_path'], lab_id, f"{video_id}.parquet"
#         )
#         annotation_file_path = os.path.join(
#             dataset_configs['train_annotation_dir_path'], lab_id, f"{video_id}.parquet"
#         )
#         tracking_df = pd.read_parquet(tracking_file_path)
#         annotation_df = pd.read_parquet(annotation_file_path)
#         # Filter frames
#         filtered_tracking_df = tracking_df[
#             (tracking_df['mouse_id'] == agent_id) |
#             (tracking_df['mouse_id'] == target_id)
#         ]
#         filtered_annotation_df = annotation_df[
#             (annotation_df['agent_id'] == agent_id) &
#             (annotation_df['target_id'] == target_id)
#         ]
#         # Build rows
#         for _, ann_row in filtered_annotation_df.iterrows():
#             for frame in range(ann_row['start_frame'], ann_row['stop_frame'] + 1):
#                 curr_row = {
#                     'lab_id': lab_id,
#                     'video_id': video_id,
#                     'arena_width_cm': video_data['arena_width_cm'],
#                     'arena_height_cm': video_data['arena_height_cm'],
#                     'action': ann_row['action']
#                 }
#                 # Initialize NaNs
#                 for bodypart in bodyparts:
#                     curr_row[f'agent_{bodypart}_x'] = np.nan
#                     curr_row[f'agent_{bodypart}_y'] = np.nan
#                     curr_row[f'target_{bodypart}_x'] = np.nan
#                     curr_row[f'target_{bodypart}_y'] = np.nan
#                 # Fill tracking
#                 agent_frame = filtered_tracking_df[
#                     (filtered_tracking_df['video_frame'] == frame) &
#                     (filtered_tracking_df['mouse_id'] == agent_id)
#                 ]
#                 target_frame = filtered_tracking_df[
#                     (filtered_tracking_df['video_frame'] == frame) &
#                     (filtered_tracking_df['mouse_id'] == target_id)
#                 ]
#                 for trk_row in agent_frame.itertuples():
#                     curr_row[f'agent_{trk_row.bodypart}_x'] = trk_row.x
#                     curr_row[f'agent_{trk_row.bodypart}_y'] = trk_row.y
#                 for trk_row in target_frame.itertuples():
#                     curr_row[f'target_{trk_row.bodypart}_x'] = trk_row.x
#                     curr_row[f'target_{trk_row.bodypart}_y'] = trk_row.y
#                 results.append(curr_row)
#         # Save parquet
#         df = pd.DataFrame(results)
#         save_path = os.path.join(save_dir_path, f"{lab_id}_{video_id}.parquet")
#         df.to_parquet(save_path)
#     return lab_id  # Used only for progress reporting


# def filter_data(agent_id: str, target_id: str, lab_video_map: dict, save_dir_path: str, max_workers=8):
#     os.makedirs(save_dir_path, exist_ok=True)
#     tasks = []
#     with ThreadPoolExecutor(max_workers=max_workers) as executor:
#         pbar = tqdm(total=len(lab_video_map), desc=f"Processing agent {agent_id} -> {target_id}")
#         for lab_id, video_data_list in lab_video_map.items():
#             fut = executor.submit(
#                 process_lab_videos,
#                 lab_id,
#                 video_data_list,
#                 agent_id,
#                 target_id,
#                 save_dir_path
#             )
#             tasks.append(fut)
#         for fut in as_completed(tasks):
#             _ = fut.result()  # catch any exceptions
#             pbar.update(1)
#     pbar.close()

In [7]:
from concurrent.futures import ThreadPoolExecutor, as_completed
import pandas as pd
import numpy as np
from tqdm import tqdm
import os


# =====================================================================
# Process one video (vectorized, fast, fragmentation-free)
# =====================================================================
def process_one_video(lab_id, video_data, agent_id, target_id, save_dir_path, bodyparts):
    """Build and save the per-frame agent/target keypoint table for one video.

    Reads ``<lab_id>/<video_id>.parquet`` from the annotation and tracking
    directories named in the module-level ``dataset_configs``, expands every
    annotation interval of the (agent_id, target_id) pair into one row per
    frame, attaches the x/y coordinates of both mice for that frame, and
    writes the result to ``<save_dir_path>/<lab_id>_<video_id>.parquet``.

    Parameters
    ----------
    lab_id : str
        Sub-directory name of the video inside the tracking/annotation dirs.
    video_data : dict
        Must provide ``video_id``, ``video_width_pix`` and ``video_height_pix``.
    agent_id, target_id
        Values matched against ``agent_id`` / ``target_id`` in the annotation
        file and against ``mouse_id`` in the tracking file. They may be equal;
        the agent and target columns are then both filled from the same mouse.
    save_dir_path : str
        Existing directory the output parquet is written to.
    bodyparts : iterable of str
        Body part names that define the output coordinate columns. A
        ``set``/``frozenset`` is sorted so the column order does not depend on
        string hashing; any other iterable keeps its own order (duplicates
        removed).

    Returns
    -------
    The ``video_id``. No file is written when the video has no annotation
    rows for the requested pair.

    Output columns
    --------------
    ``lab_id, video_id, video_width_pix, video_height_pix``, then
    ``agent_<bodypart>_<x|y>``, then ``target_<bodypart>_<x|y>``, then
    ``action``. Coordinates missing from the tracking data are NaN.
    """
    video_id = video_data["video_id"]
    parquet_name = f"{video_id}.parquet"

    # Annotations are read first: when the pair never occurs in this video we
    # can return without loading the tracking file at all.
    annotation_df = pd.read_parquet(
        os.path.join(dataset_configs["train_annotation_dir_path"], lab_id, parquet_name)
    )
    ann = annotation_df[
        (annotation_df["agent_id"] == agent_id)
        & (annotation_df["target_id"] == target_id)
    ]
    if ann.empty:
        return video_id   # nothing to write

    # Expand each [start_frame, stop_frame] interval (inclusive) into one row
    # per frame; the repeated rows carry the interval's other columns along.
    lengths = (ann["stop_frame"] - ann["start_frame"] + 1).to_numpy()
    full_ann = ann.loc[ann.index.repeat(lengths)].copy()
    full_ann["video_frame"] = np.concatenate([
        np.arange(start, stop + 1)
        for start, stop in zip(ann["start_frame"], ann["stop_frame"])
    ])
    full_ann = full_ann.assign(
        lab_id=lab_id,
        video_id=video_id,
        video_width_pix=video_data["video_width_pix"],
        video_height_pix=video_data["video_height_pix"],
    )

    # Only the five columns used below are loaded from the tracking file.
    tracking_df = pd.read_parquet(
        os.path.join(dataset_configs["train_tracking_dir_path"], lab_id, parquet_name),
        columns=["video_frame", "mouse_id", "bodypart", "x", "y"],
    )
    # Frames outside the annotated intervals would be discarded by the left
    # merge anyway, so drop them before the (comparatively costly) pivot.
    trk = tracking_df[tracking_df["video_frame"].isin(full_ann["video_frame"].unique())]

    # Iterating a set of strings is not stable across interpreter runs, which
    # would make the saved column order vary from run to run.
    if isinstance(bodyparts, (set, frozenset)):
        ordered_bodyparts = sorted(bodyparts)
    else:
        ordered_bodyparts = list(dict.fromkeys(bodyparts))

    # Agent and target are pivoted separately and named by role directly.
    # Naming by mouse id first and renaming afterwards collapses the two roles
    # into one whenever agent_id == target_id.
    agent_wide = _pivot_mouse_coords(trk, agent_id, "agent", ordered_bodyparts)
    target_wide = _pivot_mouse_coords(trk, target_id, "target", ordered_bodyparts)

    merged = (
        full_ann
        .merge(agent_wide, on="video_frame", how="left")
        .merge(target_wide, on="video_frame", how="left")
    )

    output_columns = (
        ["lab_id", "video_id", "video_width_pix", "video_height_pix"]
        + [f"agent_{bp}_{axis}" for bp in ordered_bodyparts for axis in ("x", "y")]
        + [f"target_{bp}_{axis}" for bp in ordered_bodyparts for axis in ("x", "y")]
        + ["action"]
    )
    # reindex selects and orders the output columns and creates any column
    # that has no data in this video as all-NaN in a single step.
    merged = merged.reindex(columns=output_columns)

    merged.to_parquet(os.path.join(save_dir_path, f"{lab_id}_{video_id}.parquet"))

    return video_id


def _pivot_mouse_coords(trk, mouse_id, role, bodyparts):
    """Return one row per video_frame with '<role>_<bodypart>_<x|y>' columns for one mouse."""
    rows = trk[trk["mouse_id"] == mouse_id]
    rows = rows.assign(bodypart=rows["bodypart"].astype(str))
    rows = rows[rows["bodypart"].isin(bodyparts)]
    # Rows with neither coordinate contribute nothing to the pivot; removing
    # them here also guarantees the pivot below never runs on all-NaN input.
    rows = rows.dropna(subset=["x", "y"], how="all")
    if rows.empty:
        # No usable tracking for this mouse: return a frame-key-only table so
        # the caller's left merge adds no columns.
        return rows[["video_frame"]].reset_index(drop=True)

    # aggfunc="first" keeps the first non-null value if a
    # (frame, bodypart) pair appears more than once.
    wide = rows.pivot_table(
        index="video_frame", columns="bodypart", values=["x", "y"], aggfunc="first"
    )
    # Column labels are (axis, bodypart) tuples after pivoting two value columns.
    wide.columns = [f"{role}_{bodypart}_{axis}" for axis, bodypart in wide.columns]
    return wide.reset_index()



# =====================================================================
# Master controller — multithreaded over videos
# =====================================================================
def filter_data(
    agent_id,
    target_id,
    lab_video_map,
    save_dir_path,
    bodyparts,
    max_workers=8
):
    """Run process_one_video for every video in lab_video_map on a thread pool.

    Parameters
    ----------
    agent_id, target_id
        Forwarded unchanged to process_one_video.
    lab_video_map : dict
        Maps lab_id -> list of per-video dicts (each forwarded as video_data).
    save_dir_path : str
        Output directory; created if it does not exist.
    bodyparts
        Forwarded unchanged to process_one_video.
    max_workers : int, default 8
        Size of the ThreadPoolExecutor.

    Raises
    ------
    Re-raises the first exception raised by any worker. Videos that have not
    started yet are cancelled in that case; videos already running finish
    before the executor shuts down.
    """
    os.makedirs(save_dir_path, exist_ok=True)

    jobs = [
        (lab_id, video_data)
        for lab_id, video_list in lab_video_map.items()
        for video_data in video_list
    ]

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        pbar = tqdm(total=len(jobs), desc=f"Processing agent={agent_id} target={target_id}")
        tasks = [
            executor.submit(
                process_one_video,
                lab_id,
                video_data,
                agent_id,
                target_id,
                save_dir_path,
                bodyparts
            )
            for lab_id, video_data in jobs
        ]
        try:
            for fut in as_completed(tasks):
                fut.result()     # raise if error
                pbar.update(1)
        except BaseException:
            # Fail fast: without this the executor's exit would still run
            # every queued video before the error reaches the caller.
            for fut in tasks:
                fut.cancel()
            raise
        finally:
            # Close the bar on the error path too, so a failed run does not
            # leave a dangling progress bar in the notebook output.
            pbar.close()

In [None]:
filter_data(
    agent_id=1,
    target_id=2,
    bodyparts=bodyparts,
    lab_video_map=lab_video_map,
    save_dir_path=os.path.join('dumps', 'data_split', 'agent_1_target_2'),
    # os.cpu_count() returns None when the count cannot be determined; fall
    # back to 1 so the subtraction cannot raise TypeError. Two cores are left
    # free, with a floor of one worker.
    max_workers=max((os.cpu_count() or 1) - 2, 1)
)

Processing agent=1 target=2:  21%|██        | 175/847 [01:37<02:23,  4.69it/s] 