In [None]:
import sys
from pathlib import Path

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from IPython.display import display
from tqdm import tqdm

In [None]:
# Metadata columns written first in every output file, before the bodypart coordinates.
_META_COLUMNS = ["video_frame", "mouse_id", "action", "agent_id", "target_id"]
# Per-frame annotation fields copied from the annotation files.
_LABEL_COLUMNS = ["action", "agent_id", "target_id"]
# One output row per unique combination of these tracking columns.
_INDEX_COLUMNS = ["video_frame", "mouse_id"]
# Coordinate columns spread out per bodypart.
_VALUE_COLUMNS = ["x", "y"]


def _collect_bodyparts(tracking_files):
    """Return the sorted union of bodypart names found across all tracking files.

    Files that cannot be read, or that have no (or an empty) ``bodypart`` column,
    are reported via ``tqdm.write`` and skipped.
    """
    bodyparts_seen = set()
    for track_file in tqdm(tracking_files, desc="Detecting bodyparts"):
        try:
            df = pd.read_parquet(track_file)
        except Exception as e:
            tqdm.write(f"Error reading {track_file}: {e}")
            continue
        if "bodypart" not in df.columns or df["bodypart"].empty:
            tqdm.write(f"No bodypart column present in {track_file}")
            continue
        # NaN bodyparts are dropped by pivot_table anyway; don't create "nan_x" columns for them.
        bodyparts_seen.update(df["bodypart"].dropna().unique())
    # Sorting makes the output column order identical across kernels and runs.
    return sorted(bodyparts_seen, key=str)


def _frame_labels(frames, annotation_df):
    """Build a per-frame label table from interval annotations.

    Each annotation row covers the inclusive range ``[start_frame, stop_frame]``.
    Rows are applied in order, so where intervals overlap the later row wins.
    Frames covered by no annotation keep ``None``.

    NOTE(review): every mouse in a labelled frame receives the same
    action/agent/target (the label is not restricted to ``agent_id``);
    confirm this is the intended training target.

    Args:
        frames: Unique ``video_frame`` values to label.
        annotation_df: Must contain ``start_frame``, ``stop_frame``, ``action``,
            ``agent_id`` and ``target_id`` columns.

    Returns:
        DataFrame with a ``video_frame`` column plus the label columns.
    """
    labels = pd.DataFrame({"video_frame": frames})
    for col in _LABEL_COLUMNS:
        labels[col] = None
    for row in annotation_df.itertuples(index=False):
        covered = labels["video_frame"].between(row.start_frame, row.stop_frame)
        labels.loc[covered, "action"] = row.action
        labels.loc[covered, "agent_id"] = row.agent_id
        labels.loc[covered, "target_id"] = row.target_id
    return labels


def _pivot_tracking(tracking_df, track_file):
    """Reshape long-format tracking rows into one row per (video_frame, mouse_id).

    Each bodypart's coordinates become two columns, ``<bodypart>_x`` and
    ``<bodypart>_y``. If a (frame, mouse, bodypart) combination appears more than
    once, the first value is kept (``aggfunc="first"``).

    Args:
        tracking_df: Must contain ``video_frame``, ``mouse_id``, ``bodypart``,
            ``x`` and ``y`` columns.
        track_file: Source path; used only in error messages.

    Returns:
        DataFrame with ``video_frame`` and ``mouse_id`` as regular columns,
        followed by the flattened coordinate columns.

    Raises:
        ValueError: If required columns are missing or the pivot has no rows.
    """
    present = set(tracking_df.columns)
    missing_index_cols = [c for c in _INDEX_COLUMNS if c not in present]
    missing_value_cols = [c for c in _VALUE_COLUMNS + ["bodypart"] if c not in present]
    if missing_index_cols or missing_value_cols:
        problems = [f"Missing columns in {track_file}:"]
        if missing_index_cols:
            problems.append(f"  Index columns missing: {missing_index_cols}")
        if missing_value_cols:
            problems.append(f"  Value/bodypart columns missing: {missing_value_cols}")
        problems.append("Cannot pivot this file.")
        raise ValueError("\n".join(problems))

    pivoted = tracking_df.pivot_table(
        index=_INDEX_COLUMNS,
        columns="bodypart",
        values=_VALUE_COLUMNS,
        aggfunc="first",
    )
    if pivoted.empty:
        raise ValueError(
            f"Pivot resulted in empty DataFrame for {track_file}.\n"
            "Tracking DF info:\n"
            f"  Shape: {tracking_df.shape}\n"
            f"  Columns: {tracking_df.columns.tolist()}\n"
            f"  Head:\n{tracking_df.head()}"
        )

    # pivot_table yields (coordinate, bodypart) column pairs; flatten to "<bodypart>_<coordinate>".
    pivoted.columns = [f"{bp}_{coord}" for coord, bp in pivoted.columns]
    return pivoted.reset_index()


def save_mouse_sequences(
    data_root=Path("/kaggle/input/MABe-mouse-behavior-detection"),
    output_root=Path("/kaggle/working"),
):
    """Convert every tracking video into a wide per-frame parquet file.

    For each ``train_tracking/**/<video>.parquet`` under ``data_root``:

    1. Pivot it to one row per (video_frame, mouse_id), with one ``_x``/``_y``
       column per bodypart.
    2. If an annotation file exists at the same relative path under
       ``train_annotation``, attach its per-frame ``action``/``agent_id``/``target_id``.
    3. Align the columns to a single schema shared by all videos: metadata first,
       then sorted bodypart coordinates. Columns a video lacks are filled with NaN.
    4. Write the result to ``output_root/annotated`` or ``output_root/tracking``,
       overwriting any existing file of the same name.

    Args:
        data_root: Competition input directory containing ``train_tracking`` and
            ``train_annotation``.
        output_root: Directory under which ``annotated/`` and ``tracking/`` are created.

    Raises:
        RuntimeError: If a tracking file cannot be pivoted. The original error is chained.
    """
    data_root = Path(data_root)
    output_root = Path(output_root)
    annotated_dir = data_root / "train_annotation"
    tracking_dir = data_root / "train_tracking"
    annotated_output = output_root / "annotated"
    tracking_output = output_root / "tracking"
    annotated_output.mkdir(parents=True, exist_ok=True)
    tracking_output.mkdir(parents=True, exist_ok=True)

    tracking_files = sorted(tracking_dir.rglob("*.parquet"))
    # Match annotations by path relative to their root (e.g. "CRIM13/<video>.parquet"),
    # so identically named files in different lab folders cannot collide.
    annotation_map = {f.relative_to(annotated_dir): f for f in annotated_dir.rglob("*.parquet")}

    bodyparts = _collect_bodyparts(tracking_files)
    column_order = _META_COLUMNS + [f"{bp}_{coord}" for bp in bodyparts for coord in _VALUE_COLUMNS]

    for track_file in tqdm(tracking_files, desc="Processing videos"):
        video_id = track_file.stem
        tracking_df = pd.read_parquet(track_file)
        try:
            pivoted = _pivot_tracking(tracking_df, track_file)
        except Exception as e:
            raise RuntimeError(f"Error pivoting {track_file}: {e}") from e

        annot_file = annotation_map.get(track_file.relative_to(tracking_dir))
        is_annotated = annot_file is not None
        if is_annotated:
            # Labels are attached after the pivot; labelling the long frame first would
            # lose them, because pivot_table keeps only the index/value columns.
            labels = _frame_labels(pivoted["video_frame"].unique(), pd.read_parquet(annot_file))
            pivoted = pivoted.merge(labels, on="video_frame", how="left")

        # Shared schema: missing columns are added as NaN; unknown extras are dropped.
        pivoted = pivoted.reindex(columns=column_order)

        output_dir = annotated_output if is_annotated else tracking_output
        out_path = output_dir / f"{video_id}.parquet"
        # Best-effort save: report the failure and continue with the next video.
        try:
            pivoted.to_parquet(out_path, index=False)
            tqdm.write(f"✅ Saved {out_path} with shape {pivoted.shape}")
        except Exception as e:
            tqdm.write(f"⚠️ Could not save {out_path}: {e}")

    print(f"\n✅ All processed files saved to {annotated_output} and {tracking_output}")


# Build the per-video wide-format parquet files (returns None, so the cell shows no Out[] value).
save_mouse_sequences()

In [None]:
# Peek at one video's raw annotation table and raw tracking table, to see how the
# two datasets are structured before save_mouse_sequences() combines them.
SAMPLE_DATA_ROOT = Path("/kaggle/input/MABe-mouse-behavior-detection")
SAMPLE_LAB = "CRIM13"
SAMPLE_VIDEO_ID = "1763467574"

annotated_example_path = SAMPLE_DATA_ROOT / "train_annotation" / SAMPLE_LAB / f"{SAMPLE_VIDEO_ID}.parquet"
tracking_example_path = SAMPLE_DATA_ROOT / "train_tracking" / SAMPLE_LAB / f"{SAMPLE_VIDEO_ID}.parquet"

labeled_df = pd.read_parquet(annotated_example_path)
track_df = pd.read_parquet(tracking_example_path)

# Rich HTML tables instead of print(); one table per dataset.
display(labeled_df.head(20))
display(track_df.head(20))