# MABe Mouse Behavior Detection — EDA

Data structure:
- `train.csv` / `test.csv`: per-video metadata (one row per video)
- `train_tracking/{lab_id}/{video_id}.parquet`: per-frame keypoint coordinates (video_frame, mouse_id, bodypart, x, y)
- `train_annotation/{lab_id}/{video_id}.parquet`: behavior annotations (agent_id, target_id, action, start_frame, stop_frame)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json, os, glob
from collections import Counter
from pathlib import Path

# Plot styling shared by every figure in this notebook.
sns.set_theme(style='whitegrid', palette='muted', font_scale=1.1)
plt.rcParams['figure.dpi'] = 100
plt.rcParams['figure.figsize'] = (12, 5)

# Per-video metadata tables (one row per video); all paths hang off DATA_DIR.
DATA_DIR = Path('../data')
train, test = (pd.read_csv(DATA_DIR / csv_name) for csv_name in ('train.csv', 'test.csv'))

print(f'Train videos: {len(train)},  Test videos: {len(test)}')
train.head(3)

## 1. Laboratory (Lab) Distribution

In [None]:
# Number of training videos per lab, most frequent lab first.
lab_counts = train['lab_id'].value_counts().rename_axis('lab_id').reset_index(name='count')

# Build each split's lab set once, then derive the overlap partitions from them.
# Missing lab ids are dropped so the sets stay sortable and agree with nunique().
train_labs = set(train['lab_id'].dropna().unique())
test_labs = set(test['lab_id'].dropna().unique())
train_only = train_labs - test_labs
test_only = test_labs - train_labs
shared = train_labs & test_labs

fig, axes = plt.subplots(1, 2, figsize=(16, 5))

sns.barplot(data=lab_counts, y='lab_id', x='count', ax=axes[0], color='steelblue')
axes[0].set_title('Train: Number of Videos per Lab')
axes[0].set_xlabel('Number of Videos')
axes[0].set_ylabel('')

# One-row overlap table; includes test-only labs so a lab that is unseen in
# training cannot go unnoticed.
summary = pd.DataFrame({
    'train_labs': [len(train_labs)],
    'test_labs':  [len(test_labs)],
    'shared':     [len(shared)],
    'train_only': [len(train_only)],
    'test_only':  [len(test_only)],
})
axes[1].axis('off')
axes[1].table(cellText=summary.values, colLabels=summary.columns, loc='center', cellLoc='center')
axes[1].set_title('Lab Overlap Between Train and Test')

plt.tight_layout()
plt.show()
print(f'Train-only labs: {sorted(train_only)}')
print(f'Test-only labs: {sorted(test_only)}')

## 2. Video Metadata: FPS / Duration / Resolution / Mouse Count / Arena Shape

In [None]:
# A mouse slot counts as present when its strain column is filled in.
strain_cols = ['mouse1_strain','mouse2_strain','mouse3_strain','mouse4_strain']
train['n_mice'] = train[strain_cols].notna().sum(axis=1)

# Frequency tables feeding the three bar charts.
fps_counts = (
    train['frames_per_second'].value_counts().sort_index()
    .rename_axis('fps').reset_index(name='count')
)
mice_counts = (
    train['n_mice'].value_counts().sort_index()
    .rename_axis('n_mice').reset_index(name='count')
)
arena_counts = (
    train['arena_shape'].value_counts()
    .rename_axis('arena_shape').reset_index(name='count')
)

fig, axes = plt.subplots(2, 2, figsize=(14, 9))
ax_fps, ax_duration, ax_mice, ax_arena = axes.ravel()

# Top-left: frame rate
sns.barplot(data=fps_counts, x='fps', y='count', ax=ax_fps, color='coral')
ax_fps.set(title='Frame Rate (FPS) Distribution', ylabel='Number of Videos', xlabel='FPS')

# Top-right: video duration
sns.histplot(train['video_duration_sec'], bins=50, ax=ax_duration, color='teal', edgecolor='white')
ax_duration.set(title='Video Duration (seconds) Distribution', xlabel='Seconds')

# Bottom-left: mice per video
sns.barplot(data=mice_counts, x='n_mice', y='count', ax=ax_mice, color='mediumpurple')
ax_mice.set(title='Number of Mice per Video', xlabel='Mice', ylabel='Count')

# Bottom-right: arena shape
sns.barplot(data=arena_counts, x='arena_shape', y='count', ax=ax_arena, color='goldenrod')
ax_arena.set(title='Arena Shape', xlabel='', ylabel='Count')

plt.tight_layout()
plt.show()

print('\n--- Numerical Column Statistics ---')
numeric_cols = ['frames_per_second','video_duration_sec','pix_per_cm_approx','video_width_pix','video_height_pix']
train[numeric_cols].describe().round(1)

## 3. Body Part Tracking Schemes (Vary Significantly Across Labs)

In [None]:
# One row per distinct tracking scheme (the raw JSON string of tracked body parts),
# with the sorted list of labs that use it.
bp_groups = (
    train.groupby('body_parts_tracked')['lab_id']
    .apply(lambda lab_ids: sorted(lab_ids.unique()))
    .rename('labs')
    .reset_index()
)

# Scheme size, number of labs using it, and number of videos recorded with it.
videos_per_scheme = train['body_parts_tracked'].value_counts()
bp_groups = bp_groups.assign(
    n_parts=bp_groups['body_parts_tracked'].map(lambda scheme: len(json.loads(scheme))),
    n_labs=bp_groups['labs'].map(len),
    n_videos=bp_groups['body_parts_tracked'].map(videos_per_scheme),
)

display_df = bp_groups[['n_parts','n_labs','n_videos','labs']].sort_values('n_videos', ascending=False)
display_df

## 4. Behavior Annotation Analysis (Key Insight: Extreme Class Imbalance)

In [None]:
# Parse the behaviors_labeled field to extract all action categories.
# Each non-null cell is a JSON list of comma-separated triples whose third
# field is the action name; the first two fields are presumably agent and
# target (TODO confirm against the dataset description).
all_actions = []
n_malformed = 0
for raw_labels in train['behaviors_labeled'].dropna():
    for item in json.loads(raw_labels):
        # Drop stray quotes and surrounding whitespace so that e.g. " sniff"
        # and "sniff" are counted as the same category.
        parts = [part.strip() for part in item.replace("'", "").split(',')]
        if len(parts) == 3:
            all_actions.append(parts[2])  # action name
        else:
            n_malformed += 1

# Report entries that did not match the expected 3-field layout instead of
# dropping them silently.
if n_malformed:
    print(f'Skipped {n_malformed} behaviors_labeled entries without exactly 3 fields')

action_counts = (
    pd.Series(Counter(all_actions), dtype='int64')
    .sort_values(ascending=True)
    .rename_axis('action')
    .reset_index(name='count')
)

fig, ax = plt.subplots(figsize=(10, 8))
sns.barplot(data=action_counts, y='action', x='count', ax=ax, color='steelblue')
ax.set_title('Behavior Category Frequency in behaviors_labeled (Cumulative Across Videos)')
ax.set_xlabel('Occurrence Count')
ax.set_ylabel('')
plt.tight_layout()
plt.show()

print(f'Total behavior categories: {len(action_counts)}')

## 5. Deep Dive into Annotations: Event Duration & Per-Category Frame Distribution

In [None]:
# Load every annotation file (no sampling) so the statistics below cover the
# whole training set. Files are sorted so the row order of `anno` is
# reproducible across runs and machines.
anno_files = sorted(glob.glob(str(DATA_DIR / 'train_annotation' / '*' / '*.parquet')))
print(f'Total annotation files: {len(anno_files)}')
if not anno_files:
    # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
    raise FileNotFoundError(f"No annotation parquet files found under {DATA_DIR / 'train_annotation'}")

anno_list = []
for f in anno_files:
    anno_path = Path(f)
    # Layout is train_annotation/{lab_id}/{video_id}.parquet, so both ids
    # come from the path itself (video_id is kept as a string).
    anno_list.append(
        pd.read_parquet(anno_path).assign(lab_id=anno_path.parent.name, video_id=anno_path.stem)
    )

anno = pd.concat(anno_list, ignore_index=True)
# Event length in frames, computed as stop_frame - start_frame.
anno['duration_frames'] = anno['stop_frame'] - anno['start_frame']
print(f'Total annotation events: {len(anno):,}')
anno.head()

In [None]:
# Per-action totals: number of annotated events, and summed event length in frames
# (the latter reflects the actual training data volume).
event_df = (
    anno['action'].value_counts().sort_values(ascending=True)
    .rename_axis('action').reset_index(name='count')
)
frames_df = (
    anno.groupby('action')['duration_frames'].sum().sort_values(ascending=True)
    .reset_index(name='total_frames')
)

fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# (axis, table, value column, bar colour, title, x label) for each panel.
panels = [
    (axes[0], event_df, 'count', 'darkorange',
     'Number of Annotation Events per Action', 'Event Count'),
    (axes[1], frames_df, 'total_frames', 'seagreen',
     'Total Annotated Frames per Action', 'Frame Count'),
]
for panel_ax, table, value_col, bar_color, title, xlabel in panels:
    sns.barplot(data=table, y='action', x=value_col, ax=panel_ax, color=bar_color)
    panel_ax.set(title=title, xlabel=xlabel, ylabel='')

plt.tight_layout()
plt.show()

In [None]:
# Distribution of event lengths: overall histogram plus per-action medians.
durations = anno['duration_frames']
overall_median = durations.median()

median_dur = (
    anno.groupby('action')['duration_frames'].median().sort_values(ascending=True)
    .reset_index(name='median_frames')
)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_hist, ax_median = axes

# Left: histogram, with events longer than 500 frames folded into the last bin.
sns.histplot(durations.clip(upper=500), bins=80, ax=ax_hist, color='slategray', edgecolor='white')
ax_hist.set(title='Event Duration Distribution (frames, clipped at 500)', xlabel='Frames')
ax_hist.axvline(overall_median, color='red', ls='--', label=f"median={overall_median:.0f}")
ax_hist.legend()

# Right: median event length for each action.
sns.barplot(data=median_dur, y='action', x='median_frames', ax=ax_median, color='indianred')
ax_median.set(title='Median Event Duration per Action (frames)', xlabel='Frames', ylabel='')

plt.tight_layout()
plt.show()

print('--- Event Duration Statistics ---')
durations.describe().round(1)

## 6. Behavior Category Distribution Across Labs (Heatmap)

In [None]:
# Event counts for every (lab, action) pair, in long and wide form.
lab_action = anno.groupby(['lab_id', 'action']).size().reset_index(name='count')
pivot = (
    lab_action.set_index(['lab_id', 'action'])['count']
    .unstack('action')
    .fillna(0)
    .astype(int)
)

# Restrict the columns to the 20 most frequent actions, most frequent first.
top_actions = anno['action'].value_counts().head(20).index
kept_actions = [action for action in top_actions if action in pivot.columns]
pivot_top = pivot[kept_actions]

# log1p keeps zero counts at zero while compressing the very large counts.
fig, ax = plt.subplots(figsize=(16, 8))
sns.heatmap(
    np.log1p(pivot_top),
    ax=ax,
    cmap='YlOrRd',
    linewidths=0.5,
    cbar_kws={'label': 'log(count+1)'},
    xticklabels=True,
    yticklabels=True,
)
ax.set(title='Lab × Action Event Count Heatmap (log scale)', xlabel='Action', ylabel='Lab')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

## 7. Tracking Data Visualization: Example Trajectories

In [None]:
# Pick one example video to visualize: the first tracking file of the lab in
# sorted filename order. os.listdir() returns entries in arbitrary order and
# may include non-parquet files, so a sorted glob keeps the choice reproducible.
sample_lab = 'BoisterousParrot'
sample_files = sorted((DATA_DIR / 'train_tracking' / sample_lab).glob('*.parquet'))
if not sample_files:
    raise FileNotFoundError(f"No tracking parquet files found under {DATA_DIR / 'train_tracking' / sample_lab}")
sample_path = sample_files[0]
sample_vid = sample_path.stem

trk = pd.read_parquet(sample_path)
print(f'Tracking shape: {trk.shape}, columns: {trk.columns.tolist()}')
print(f'Frame range: {trk.video_frame.min()} ~ {trk.video_frame.max()}, Mouse IDs: {sorted(trk.mouse_id.unique())}')
print(f'Body parts: {sorted(trk.bodypart.unique())}')
trk.head()

In [None]:
# Plot body_center trajectories (first 3000 frames)
bc = trk[(trk.bodypart == 'body_center') & (trk.video_frame < 3000)].copy()
if bc.empty:
    raise ValueError(f"No 'body_center' rows in the first 3000 frames of {sample_lab} / {sample_vid}")
bc['mouse_id'] = bc['mouse_id'].astype(str)
# The line is drawn in row order (sort=False), so order rows chronologically
# explicitly rather than relying on how the file happens to be stored.
bc = bc.sort_values('video_frame', kind='stable')

palette = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3']

# Fix one colour per mouse and use it for BOTH the trajectory and its start
# marker. With a list palette seaborn assigns colours in its own hue-level
# order, which need not match the sorted order used for the markers.
mouse_ids = sorted(bc['mouse_id'].unique())
mouse_colors = {mid: palette[i % len(palette)] for i, mid in enumerate(mouse_ids)}

fig, ax = plt.subplots(figsize=(8, 8))
sns.lineplot(data=bc, x='x', y='y', hue='mouse_id', hue_order=mouse_ids, ax=ax,
             palette=mouse_colors,
             alpha=0.5, linewidth=0.5, legend='brief', sort=False, estimator=None)

# Mark starting positions (first frame where the mouse has valid coordinates)
for mid in mouse_ids:
    m = bc[bc.mouse_id == mid].dropna(subset=['x', 'y'])
    if m.empty:
        continue  # this mouse has no valid body_center position in the window
    ax.scatter(m.x.iloc[0], m.y.iloc[0], marker='o', s=60, color=mouse_colors[mid], zorder=5)

ax.set_title(f'{sample_lab} / {sample_vid} — body_center Trajectory (first 3000 frames)')
ax.set_xlabel('x (px)'); ax.set_ylabel('y (px)')
# Flip y so the plot uses image coordinates (origin at the top-left).
ax.invert_yaxis(); ax.set_aspect('equal')
plt.tight_layout()
plt.show()

## 8. Tracking Data Quality: NaN Rates / Coordinate Ranges

In [None]:
# Sample a few labs to check NaN rates
sample_labs = ['BoisterousParrot', 'CautiousGiraffe', 'SparklingTapir', 'ElegantMink']
FILES_PER_LAB = 3  # number of tracking files inspected per lab
nan_stats = []

for lab in sample_labs:
    lab_dir = DATA_DIR / 'train_tracking' / lab
    # Sorted so the same files are inspected on every run (glob order is arbitrary).
    files = sorted(lab_dir.glob('*.parquet'))[:FILES_PER_LAB]
    if not files:
        # Make a misspelled or missing lab visible instead of silently omitting it.
        print(f'No tracking files found for lab {lab!r}; skipping')
        continue
    for f in files:
        # Only the coordinate columns are needed for the NaN check.
        df = pd.read_parquet(f, columns=['x', 'y'])
        nan_stats.append({
            'lab': lab,
            'video': f.stem,
            'rows': len(df),
            # Share of rows with a missing coordinate, in percent.
            'nan_x%': round(df['x'].isna().mean() * 100, 2),
            'nan_y%': round(df['y'].isna().mean() * 100, 2),
        })

nan_df = pd.DataFrame(nan_stats)
nan_df

## 9. Interaction Behaviors vs Self Behaviors

In [None]:
# Actions treated as self-directed; every other action is counted as an interaction.
self_actions = {'biteobject','climb','dig','exploreobject','freeze',
                'genitalgroom','huddle','rear','rest','run','selfgroom'}

anno['action_type'] = np.where(anno['action'].isin(self_actions), 'self', 'interaction')

# Fixed colour per type and fixed bar order, so both panels are directly
# comparable. Without this the first panel follows value_counts() order and
# the second follows alphabetical groupby order, and position-based colours
# can swap between panels.
pal = {'self': '#66c2a5', 'interaction': '#fc8d62'}
present_types = set(anno['action_type'].unique())
type_order = [t for t in ('self', 'interaction') if t in present_types]


def _annotate_bar_values(ax):
    """Write each bar's height, with thousands separators, just above the bar."""
    for p in ax.patches:
        ax.annotate(f'{int(p.get_height()):,}', (p.get_x() + p.get_width()/2., p.get_height()),
                    ha='center', va='bottom', fontsize=10)


fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Event count by type
type_counts = anno['action_type'].value_counts().rename_axis('action_type').reset_index(name='count')
sns.barplot(data=type_counts, x='action_type', y='count', order=type_order, ax=axes[0], palette=pal)
axes[0].set_title('Event Count: Self vs Interaction')
axes[0].set_xlabel('')
axes[0].set_ylabel('Number of Events')
_annotate_bar_values(axes[0])

# Total frames by type
type_frames = anno.groupby('action_type')['duration_frames'].sum().reset_index(name='total_frames')
sns.barplot(data=type_frames, x='action_type', y='total_frames', order=type_order, ax=axes[1], palette=pal)
axes[1].set_title('Total Frames: Self vs Interaction')
axes[1].set_xlabel('')
axes[1].set_ylabel('Total Frames')
_annotate_bar_values(axes[1])

plt.tight_layout()
plt.show()

## 10. Train vs Test Comparison

In [None]:
# Mice per test video: number of strain columns that are filled in.
test['n_mice'] = test[['mouse1_strain','mouse2_strain','mouse3_strain','mouse4_strain']].notna().sum(axis=1)


def _split_summary(meta):
    """Return one split's summary values, in the same order as the 'metric' labels below."""
    durations = meta['video_duration_sec']
    return [
        len(meta),
        meta['lab_id'].nunique(),
        f"{durations.mean():.1f}",
        f"{durations.median():.1f}",
        str(sorted(meta['frames_per_second'].dropna().unique())),
        f"{meta['n_mice'].mean():.2f}",
        meta['body_parts_tracked'].nunique(),
    ]


# Side-by-side summary; relies on train['n_mice'] computed in the metadata section.
compare = pd.DataFrame({
    'metric': ['Number of Videos', 'Number of Labs', 'Mean Duration (s)', 'Median Duration (s)',
               'FPS Values', 'Mean Mice per Video', 'Tracking Schemes'],
    'train': _split_summary(train),
    'test': _split_summary(test),
})
compare

## Key Findings Summary

1. **High lab heterogeneity**: 15+ laboratories with different FPS (25/30), different body part tracking schemes (4–18 keypoints), and different arena shapes
2. **Severe class imbalance**: `sniff` / `approach` are high-frequency, while `shepherd` / `ejaculate` are extremely rare
3. **Wide variation in event duration**: from a few frames to thousands of frames; median is around a few dozen frames
4. **Train-only labs**: CalMS21, CRIM13, MABe22, and other historical datasets appear only in the training set, with no corresponding videos in the test set
5. **Interaction vs self behaviors**: interaction-type behaviors dominate in both variety and total count