In [1]:
# ==========================================
# CELL 1: DATA PREPARATION & PLATINUM FEATURE ENGINEERING
# ==========================================
import os
import gc
import json
import hashlib
import numpy as np
import pandas as pd
import polars as pl
from tqdm.auto import tqdm
import warnings
import itertools
import shutil

# Silence every warning category so library noise does not flood the
# notebook output.
warnings.simplefilter('ignore')

# CONFIGURATION
DATA_PATH = '/kaggle/input/MABe-mouse-behavior-detection'
OUTPUT_PATH = './processed_data'

# Always start from an empty output directory so that files written by an
# earlier run (with a different feature set) can never be mixed with new ones.
if os.path.exists(OUTPUT_PATH):
    shutil.rmtree(OUTPUT_PATH)
os.makedirs(OUTPUT_PATH, exist_ok=True)

# --- FEATURE LISTS (PLATINUM) ---
# Columns kept for single-mouse samples: every name here is derived from
# mouse1 only (see calculate_features_polars for how each one is computed).
SINGLE_FEATURES = [
    'velocity_m1',
    'accel_m1',
    'v_long_m1',
    'v_lat_m1',
    'curvature_m1',
    'turn_rate_m1',
    'elongation_m1',
    'vel_m1_std_micro',
    'grooming_score_m1',
    'vel_m1_mean_short',
    'curvature_m1_mean_short',
    'vel_m1_mean_long',
]

# Columns kept for pair samples: the single-mouse columns first, followed by
# the mouse2 and mouse1<->mouse2 interaction columns.
PAIR_FEATURES = [
    *SINGLE_FEATURES,
    'distance',
    'velocity_m2',
    'accel_m2',
    'v_long_m2',
    'v_lat_m2',
    'curvature_m2',
    'turn_rate_m2',
    'speed_ratio_m1',
    'nose1_to_tail2',
    'nose1_to_nose2',
    'facing_angle_m1',
    'elongation_m2',
    'spine_alignment',
    'vel_m2_std_micro',
    'dist_mean_short',
    'vel_m2_mean_short',
    'facing_mean_short',
]

def calculate_features_polars(df_pl, fps):
    """Compute per-frame kinematic and interaction features for a mouse pair.

    Parameters
    ----------
    df_pl : polars.DataFrame
        Wide tracking table with a ``frame`` column and coordinate columns
        ending in ``_x`` / ``_y``. The columns
        ``mouse{1,2}_{body_center,nose,tail_base}_{x,y}`` must be present.
    fps : float
        Frame rate of the video. It sizes the rolling windows and converts
        per-frame differences into per-second rates.

    Returns
    -------
    polars.DataFrame
        The input sorted by ``frame`` with the smoothed coordinates
        (``*_smooth``), the intermediate vectors (``vx1``, ``spine1_x``,
        ``u1_x`` ...) and all feature columns appended.
    """
    # 1. Rolling-window sizes in frames, derived from the frame rate so that
    #    they cover the same time span in every video.
    W_MICRO = max(1, int(round(0.10 * fps)))  # ~100 ms (jitter)
    W_SHORT = max(1, int(round(0.33 * fps)))  # ~330 ms (action)
    W_LONG = max(1, int(round(1.00 * fps)))   # ~1 s (context)

    df_pl = df_pl.sort("frame")

    # =================================================================
    # PRE-SMOOTHING OF COORDINATES
    # =================================================================
    # A centred 5-frame mean leaves the first and last two rows null.
    # Forward fill only repairs the trailing rows, so a backward fill is
    # required as well; otherwise the leading nulls propagate into every
    # feature that is not explicitly null-filled later (v_long/v_lat,
    # elongation, spine_alignment, grooming_score, turn_rate).
    coord_cols = [c for c in df_pl.columns if c.endswith(('_x', '_y'))]
    smooth_exprs = [
        pl.col(c)
        .rolling_mean(window_size=5, center=True)
        .fill_null(strategy="forward")
        .fill_null(strategy="backward")
        .alias(f"{c}_smooth")
        for c in coord_cols
    ]
    df_pl = df_pl.with_columns(smooth_exprs)

    m1_cx, m1_cy = pl.col("mouse1_body_center_x_smooth"), pl.col("mouse1_body_center_y_smooth")
    m2_cx, m2_cy = pl.col("mouse2_body_center_x_smooth"), pl.col("mouse2_body_center_y_smooth")

    m1_nx, m1_ny = pl.col("mouse1_nose_x_smooth"), pl.col("mouse1_nose_y_smooth")
    m2_nx, m2_ny = pl.col("mouse2_nose_x_smooth"), pl.col("mouse2_nose_y_smooth")
    m1_tx, m1_ty = pl.col("mouse1_tail_base_x_smooth"), pl.col("mouse1_tail_base_y_smooth")
    m2_tx, m2_ty = pl.col("mouse2_tail_base_x_smooth"), pl.col("mouse2_tail_base_y_smooth")

    # =================================================================
    # STAGE 1: BASIC QUANTITIES (computed on the smoothed coordinates)
    # =================================================================
    exprs_basic = [
        # Centre-to-centre distance
        (((m1_cx - m2_cx)**2 + (m1_cy - m2_cy)**2).sqrt()).fill_null(0).alias("distance"),

        # Speed (velocity magnitude), per second
        (((m1_cx.diff().fill_null(0))**2 + (m1_cy.diff().fill_null(0))**2).sqrt() * fps).alias("velocity_m1"),
        (((m2_cx.diff().fill_null(0))**2 + (m2_cy.diff().fill_null(0))**2).sqrt() * fps).alias("velocity_m2"),

        # Social distances
        (((m1_nx - m2_tx)**2 + (m1_ny - m2_ty)**2).sqrt()).fill_null(0).alias("nose1_to_tail2"),
        (((m1_nx - m2_nx)**2 + (m1_ny - m2_ny)**2).sqrt()).fill_null(0).alias("nose1_to_nose2"),
    ]
    df_pl = df_pl.with_columns(exprs_basic)

    # =================================================================
    # STAGE 2: VECTORS & HELPERS
    # =================================================================
    df_pl = df_pl.with_columns([
        # Velocity vectors, per second
        (m1_cx.diff().fill_null(0) * fps).alias("vx1"),
        (m1_cy.diff().fill_null(0) * fps).alias("vy1"),
        (m2_cx.diff().fill_null(0) * fps).alias("vx2"),
        (m2_cy.diff().fill_null(0) * fps).alias("vy2"),

        # Spine vectors (tail base -> nose)
        (m1_nx - m1_tx).alias("spine1_x"),
        (m1_ny - m1_ty).alias("spine1_y"),
        (m2_nx - m2_tx).alias("spine2_x"),
        (m2_ny - m2_ty).alias("spine2_y"),

        # Nose position relative to the body centre
        (m1_nx - m1_cx).alias("nose1_rel_x"),
        (m1_ny - m1_cy).alias("nose1_rel_y"),
    ])

    # Acceleration vectors. The frame-to-frame difference of a per-second
    # velocity must itself be scaled by fps; without it the curvature below
    # would be off by a factor of fps and differ between videos recorded at
    # different frame rates.
    df_pl = df_pl.with_columns([
        (pl.col("vx1").diff().fill_null(0) * fps).alias("ax1"),
        (pl.col("vy1").diff().fill_null(0) * fps).alias("ay1"),
        (pl.col("vx2").diff().fill_null(0) * fps).alias("ax2"),
        (pl.col("vy2").diff().fill_null(0) * fps).alias("ay2"),
    ])

    # Spine length (epsilon keeps the normalisation below finite)
    df_pl = df_pl.with_columns([
        ((pl.col("spine1_x")**2 + pl.col("spine1_y")**2).sqrt() + 1e-6).alias("len1"),
        ((pl.col("spine2_x")**2 + pl.col("spine2_y")**2).sqrt() + 1e-6).alias("len2"),
    ])

    # Unit vectors of the body direction
    df_pl = df_pl.with_columns([
        (pl.col("spine1_x") / pl.col("len1")).alias("u1_x"),
        (pl.col("spine1_y") / pl.col("len1")).alias("u1_y"),
        (pl.col("spine2_x") / pl.col("len2")).alias("u2_x"),
        (pl.col("spine2_y") / pl.col("len2")).alias("u2_y"),
    ])

    # Velocity projected onto the body axis (longitudinal) and across it (lateral)
    df_pl = df_pl.with_columns([
        (pl.col("vx1") * pl.col("u1_x") + pl.col("vy1") * pl.col("u1_y")).alias("v_long_m1"),
        (pl.col("vx1") * pl.col("u1_y") - pl.col("vy1") * pl.col("u1_x")).alias("v_lat_m1"),
        (pl.col("vx2") * pl.col("u2_x") + pl.col("vy2") * pl.col("u2_y")).alias("v_long_m2"),
        (pl.col("vx2") * pl.col("u2_y") - pl.col("vy2") * pl.col("u2_x")).alias("v_lat_m2"),
    ])

    # =================================================================
    # STAGE 3: PLATINUM FEATURES
    # =================================================================
    norm_spine1 = pl.col("len1")
    norm_spine2 = pl.col("len2")
    dot_spine = pl.col("spine1_x") * pl.col("spine2_x") + pl.col("spine1_y") * pl.col("spine2_y")

    # Micro-motion: speed of the nose relative to the body centre
    nose_speed_rel = ((pl.col("nose1_rel_x").diff().fill_null(0))**2 +
                      (pl.col("nose1_rel_y").diff().fill_null(0))**2).sqrt() * fps

    # Facing angle: cosine between mouse1's centre->nose vector and the
    # vector from mouse1's centre to mouse2's centre
    v1_x = m1_nx - m1_cx
    v1_y = m1_ny - m1_cy
    v12_x = m2_cx - m1_cx
    v12_y = m2_cy - m1_cy
    dot_face = v1_x * v12_x + v1_y * v12_y
    norm_face1 = (v1_x**2 + v1_y**2).sqrt()
    norm_face12 = (v12_x**2 + v12_y**2).sqrt()

    # Long-term speed baseline for the speed ratio. During the first
    # W_LONG - 1 frames the rolling mean is null; replacing it with 0 would
    # divide by the epsilon alone and yield ratios around 1e6 * speed, so the
    # first available baseline is propagated backwards instead.
    vel1_baseline = (
        pl.col("velocity_m1")
        .rolling_mean(W_LONG)
        .fill_null(strategy="backward")
        .fill_null(0)
    )

    df_pl = df_pl.with_columns([
        # Signed longitudinal acceleration (rate of change of v_long)
        (pl.col("v_long_m1").diff().fill_null(0) * fps).alias("accel_m1"),
        (pl.col("v_long_m2").diff().fill_null(0) * fps).alias("accel_m2"),

        # Spine alignment (cosine between the two spine vectors)
        (dot_spine / (norm_spine1 * norm_spine2 + 1e-6)).fill_nan(0).alias("spine_alignment"),

        # Elongation (tail base -> nose length)
        norm_spine1.alias("elongation_m1"),
        norm_spine2.alias("elongation_m2"),

        # Grooming score: relative nose motion compared with whole-body speed
        (nose_speed_rel / (pl.col("velocity_m1") + 0.5)).fill_nan(0).alias("grooming_score_m1"),

        # Facing angle
        (dot_face / (norm_face1 * norm_face12 + 1e-6)).fill_nan(0).fill_null(0).alias('facing_angle_m1'),

        # Path curvature |v x a| / |v|^3
        ((pl.col("vx1")*pl.col("ay1") - pl.col("vy1")*pl.col("ax1")).abs() /
         ((pl.col("vx1")**2 + pl.col("vy1")**2 + 1e-6)**1.5)).fill_nan(0).alias("curvature_m1"),
        ((pl.col("vx2")*pl.col("ay2") - pl.col("vy2")*pl.col("ax2")).abs() /
         ((pl.col("vx2")**2 + pl.col("vy2")**2 + 1e-6)**1.5)).fill_nan(0).alias("curvature_m2"),

        # Turn rate (absolute lateral speed) and speed ratio
        pl.col("v_lat_m1").abs().alias("turn_rate_m1"),
        pl.col("v_lat_m2").abs().alias("turn_rate_m2"),
        (pl.col("velocity_m1") / (vel1_baseline + 1e-6)).alias("speed_ratio_m1"),
    ])

    # =================================================================
    # STAGE 4: CONTEXT & JITTER
    # =================================================================
    # NOTE(review): these rolling statistics are still reported as 0 while
    # the window warms up (first W - 1 frames of each video).
    roll_exprs = [
        pl.col("velocity_m1").rolling_mean(W_LONG).fill_null(0).alias("vel_m1_mean_long"),

        # Context (short term)
        pl.col("distance").rolling_mean(W_SHORT).fill_null(0).alias("dist_mean_short"),
        pl.col("velocity_m1").rolling_mean(W_SHORT).fill_null(0).alias("vel_m1_mean_short"),
        pl.col("velocity_m2").rolling_mean(W_SHORT).fill_null(0).alias("vel_m2_mean_short"),
        pl.col("facing_angle_m1").rolling_mean(W_SHORT).fill_null(0).alias("facing_mean_short"),
        pl.col("curvature_m1").rolling_mean(W_SHORT).fill_null(0).alias("curvature_m1_mean_short"),

        # Jitter (micro term) - grooming / attack
        pl.col("velocity_m1").rolling_std(W_MICRO).fill_null(0).alias("vel_m1_std_micro"),
        pl.col("velocity_m2").rolling_std(W_MICRO).fill_null(0).alias("vel_m2_std_micro"),
    ]

    df_pl = df_pl.with_columns(roll_exprs)
    return df_pl

def standardize_tracking_data(df_wide, mouse_ids):
    """Guarantee body_center / nose / tail_base columns for every mouse and fill gaps.

    For each id in ``mouse_ids`` the columns ``mouse{id}_body_center_{x,y}``,
    ``mouse{id}_nose_{x,y}`` and ``mouse{id}_tail_base_{x,y}`` are created when
    absent. New columns are added to ``df_wide`` in place; the returned frame
    is a gap-filled copy (linear interpolation over at most 5 consecutive
    rows, then forward fill, then backward fill).
    """
    for mouse in mouse_ids:
        tag = f"mouse{mouse}_"
        cx_name = tag + "body_center_x"
        cy_name = tag + "body_center_y"

        if cx_name not in df_wide.columns:
            # No tracked centre: fall back to the mean of all of this mouse's
            # body parts, or to the origin when the mouse has no columns.
            xs = [name for name in df_wide.columns if name.startswith(tag) and name.endswith('_x')]
            ys = [name for name in df_wide.columns if name.startswith(tag) and name.endswith('_y')]
            if not xs:
                df_wide[cx_name] = 0.0
                df_wide[cy_name] = 0.0
            else:
                df_wide[cx_name] = df_wide[xs].mean(axis=1)
                df_wide[cy_name] = df_wide[ys].mean(axis=1)

        for part in ('nose', 'tail_base'):
            part_x = f"{tag}{part}_x"
            if part_x in df_wide.columns:
                continue
            # A missing landmark is approximated by the body centre.
            df_wide[part_x] = df_wide[cx_name]
            df_wide[f"{tag}{part}_y"] = df_wide[cy_name]

    return df_wide.interpolate(limit=5).ffill().bfill()

def get_group_id(body_parts_list):
    """Return a 6-character hex id that is the same for any ordering of the given body parts."""
    # Sorting first makes the JSON text (and therefore the digest)
    # independent of the order in which the body parts were listed.
    canonical = json.dumps(sorted(body_parts_list)).encode()
    digest = hashlib.md5(canonical)
    return digest.hexdigest()[:6]

# Main Process
def _downcast_float64(df):
    """Return *df* with every float64 column converted to float32."""
    wide_cols = [c for c in df.columns if df[c].dtype == 'float64']
    if not wide_cols:
        return df
    return df.astype({c: np.float32 for c in wide_cols})


def _prefix_map(columns, src_prefix, dst_prefix):
    """Map every column starting with *src_prefix* to the same name under *dst_prefix*."""
    return {c: dst_prefix + c[len(src_prefix):] for c in columns if c.startswith(src_prefix)}


def _label_chunk(feat_pl, feature_names, annotations, video_id):
    """Select *feature_names* and attach the label / frame / video_id columns.

    Frames with ``start_frame <= frame < stop_frame`` receive the annotation's
    action; all other frames keep the label 'other'. When annotations
    overlap, the later row wins.
    """
    chunk = feat_pl.select(feature_names).to_pandas()
    chunk['label'] = 'other'
    chunk['frame'] = feat_pl['frame'].to_numpy()
    frames = chunk['frame']
    spans = annotations[['start_frame', 'stop_frame', 'action']]
    for start, stop, action in spans.itertuples(index=False, name=None):
        chunk.loc[(frames >= start) & (frames < stop), 'label'] = action
    chunk['video_id'] = video_id
    # Downcasting here, rather than after the group-level concat, keeps only
    # float32 chunks in memory while a large group is being accumulated.
    return _downcast_float64(chunk)


def _build_video_chunks(row, t_path, a_path, fps):
    """Build the labelled pair and single-mouse chunks of one video.

    Returns ``(pair_chunks, single_chunks)``. Nothing is returned when an
    exception is raised, so a video that fails half-way contributes no rows.
    """
    vid = row['video_id']

    # Long tracking table -> one row per frame, one column per mouse/bodypart/axis
    track_df = pd.read_parquet(t_path)
    track_df['col_name'] = 'mouse' + track_df['mouse_id'].astype(str) + '_' + track_df['bodypart']
    px = track_df.pivot(index='video_frame', columns='col_name', values='x')
    py = track_df.pivot(index='video_frame', columns='col_name', values='y')
    px.columns = [c + '_x' for c in px.columns]
    py.columns = [c + '_y' for c in py.columns]
    df_wide = pd.concat([px, py], axis=1).sort_index()

    mouse_ids = sorted({int(c.split('_')[0].replace('mouse', '')) for c in px.columns if 'mouse' in c})
    df_wide = standardize_tracking_data(df_wide, mouse_ids)

    # Rescale by the video's pixels-per-cm value; a missing or non-positive
    # value leaves the coordinates unscaled.
    pix_per_cm = row['pix_per_cm_approx'] if row['pix_per_cm_approx'] > 0 else 1.0
    df_wide = df_wide / pix_per_cm
    df_wide = df_wide.reset_index().rename(columns={'video_frame': 'frame'})

    annot_df = pd.read_parquet(a_path)
    pl_wide = pl.from_pandas(df_wide)

    # Ordered pairs: m1 plays the agent (mouse1), m2 the target (mouse2).
    pair_chunks = []
    for m1, m2 in itertools.permutations(mouse_ids, 2):
        col_map = {
            **_prefix_map(pl_wide.columns, f'mouse{m1}_', 'mouse1_'),
            **_prefix_map(pl_wide.columns, f'mouse{m2}_', 'mouse2_'),
        }
        pair_pl = pl_wide.select(['frame'] + list(col_map)).rename(col_map)
        pair_pl = calculate_features_polars(pair_pl, fps=fps)

        pair_annot = annot_df[(annot_df['agent_id'] == m1) & (annot_df['target_id'] == m2)]
        pair_chunks.append(_label_chunk(pair_pl, PAIR_FEATURES, pair_annot, vid))

    # Single-mouse samples: mouse1 is duplicated as a dummy mouse2 so the
    # shared feature function can run; only SINGLE_FEATURES are kept, which
    # drops the meaningless interaction columns.
    single_chunks = []
    for m1 in mouse_ids:
        rename_dict = _prefix_map(pl_wide.columns, f'mouse{m1}_', 'mouse1_')
        base_pl = pl_wide.select(['frame'] + list(rename_dict)).rename(rename_dict)
        dummy_exprs = [
            pl.col(c).alias('mouse2_' + c[len('mouse1_'):])
            for c in base_pl.columns if c.startswith('mouse1_')
        ]
        single_pl = calculate_features_polars(base_pl.with_columns(dummy_exprs), fps=fps)

        # Self-directed behaviours are annotated with agent == target.
        single_annot = annot_df[(annot_df['agent_id'] == m1) & (annot_df['target_id'] == m1)]
        single_chunks.append(_label_chunk(single_pl, SINGLE_FEATURES, single_annot, vid))

    return pair_chunks, single_chunks


def _save_group_chunks(chunks, kind, grp):
    """Concatenate *chunks* and write them to train_{kind}_group_{grp}.parquet (no-op when empty)."""
    if not chunks:
        return
    # Downcast float64 -> float32 to reduce size (a no-op for chunks that
    # were already downcast individually).
    full = _downcast_float64(pd.concat(chunks, ignore_index=True))
    full.to_parquet(f'{OUTPUT_PATH}/train_{kind}_group_{grp}.parquet', index=False)
    print(f"‚úÖ Saved {kind.upper()} Group {grp}: {full.shape}")


def process_train_data_grouped():
    """Build the training feature tables, one pair file and one single file per body-part group.

    Reads ``train.csv`` under DATA_PATH, drops labs whose id starts with
    'MABe22_', groups the videos by their set of tracked body parts, and for
    each group writes ``train_pair_group_<id>.parquet`` and
    ``train_single_group_<id>.parquet`` to OUTPUT_PATH, plus a
    ``group_mapping.json`` that maps each group id to its body-part list.
    Videos whose tracking or annotation file is missing are skipped; videos
    that raise during processing are reported and skipped.
    """
    print("üî• [ETL] B·∫Øt ƒë·∫ßu x·ª≠ l√Ω d·ªØ li·ªáu theo nh√≥m (PLATINUM - FULL RELOAD)...")
    train_meta = pd.read_csv(f'{DATA_PATH}/train.csv')
    # .copy() so the column assignments below act on an independent frame
    # rather than on a filtered view of the original.
    train_meta = train_meta[~train_meta['lab_id'].str.startswith('MABe22_')].copy()

    train_meta['bp_json'] = train_meta['body_parts_tracked'].apply(json.loads)
    train_meta['group_id'] = train_meta['bp_json'].apply(get_group_id)

    group_mapping = train_meta[['group_id', 'body_parts_tracked']].drop_duplicates().set_index('group_id').to_dict()['body_parts_tracked']
    with open(f'{OUTPUT_PATH}/group_mapping.json', 'w') as f:
        json.dump(group_mapping, f)

    # One pass per distinct body-part group
    for grp in train_meta['group_id'].unique():
        grp_meta = train_meta[train_meta['group_id'] == grp]
        print(f"‚öôÔ∏è Processing Group {grp} ({len(grp_meta)} videos)...")

        single_chunks = []
        pair_chunks = []

        for _, row in tqdm(grp_meta.iterrows(), total=len(grp_meta), desc=f"Group {grp}"):
            vid = row['video_id']
            lab = row['lab_id']

            fps = row.get('frames_per_second', 30.0)
            if pd.isna(fps) or fps <= 0:
                fps = 30.0

            t_path = f'{DATA_PATH}/train_tracking/{lab}/{vid}.parquet'
            a_path = f'{DATA_PATH}/train_annotation/{lab}/{vid}.parquet'
            if not os.path.exists(t_path) or not os.path.exists(a_path):
                continue

            # Deliberately broad: one malformed video must not abort the
            # whole run, so the failure is reported and the video skipped.
            try:
                vid_pairs, vid_singles = _build_video_chunks(row, t_path, a_path, fps)
            except Exception as e:
                print(f"Err {vid}: {e}")
                continue

            pair_chunks.extend(vid_pairs)
            single_chunks.extend(vid_singles)

        _save_group_chunks(pair_chunks, 'pair', grp)
        _save_group_chunks(single_chunks, 'single', grp)

        # Release the group's chunks before the next group is built.
        del pair_chunks, single_chunks
        gc.collect()

# The ETL is executed unconditionally; a mapping file left behind by an
# earlier run only changes whether the notice below is printed.
if os.path.exists(f'{OUTPUT_PATH}/group_mapping.json'):
    print("‚ö†Ô∏è Data exists. Deleting and reprocessing to update features...")
process_train_data_grouped()

üî• [ETL] B·∫Øt ƒë·∫ßu x·ª≠ l√Ω d·ªØ li·ªáu theo nh√≥m (PLATINUM - FULL RELOAD)...
‚öôÔ∏è Processing Group 31269b (7 videos)...


Group 31269b:   0%|          | 0/7 [00:00<?, ?it/s]

‚úÖ Saved PAIR Group 31269b: (1744248, 32)
‚úÖ Saved SINGLE Group 31269b: (689252, 15)
‚öôÔ∏è Processing Group bb01ae (10 videos)...


Group bb01ae:   0%|          | 0/10 [00:00<?, ?it/s]

‚úÖ Saved PAIR Group bb01ae: (5881764, 32)
‚úÖ Saved SINGLE Group bb01ae: (2042368, 15)
‚öôÔ∏è Processing Group 501ce1 (19 videos)...


Group 501ce1:   0%|          | 0/19 [00:00<?, ?it/s]

‚úÖ Saved PAIR Group 501ce1: (10212910, 32)
‚úÖ Saved SINGLE Group 501ce1: (10212910, 15)
‚öôÔ∏è Processing Group 7eae46 (634 videos)...


Group 7eae46:   0%|          | 0/634 [00:00<?, ?it/s]

‚úÖ Saved PAIR Group 7eae46: (23086736, 32)
‚úÖ Saved SINGLE Group 7eae46: (23086736, 15)
‚öôÔ∏è Processing Group 879ca7 (42 videos)...


Group 879ca7:   0%|          | 0/42 [00:00<?, ?it/s]

‚úÖ Saved PAIR Group 879ca7: (2534176, 32)
‚úÖ Saved SINGLE Group 879ca7: (2534176, 15)
‚öôÔ∏è Processing Group 2fe7d0 (17 videos)...


Group 2fe7d0:   0%|          | 0/17 [00:00<?, ?it/s]

‚úÖ Saved PAIR Group 2fe7d0: (899134, 32)
‚úÖ Saved SINGLE Group 2fe7d0: (899134, 15)
‚öôÔ∏è Processing Group a912ef (24 videos)...


Group a912ef:   0%|          | 0/24 [00:00<?, ?it/s]

‚úÖ Saved PAIR Group a912ef: (1774618, 32)
‚úÖ Saved SINGLE Group a912ef: (1774618, 15)
‚öôÔ∏è Processing Group 42815e (89 videos)...


Group 42815e:   0%|          | 0/89 [00:00<?, ?it/s]

‚úÖ Saved PAIR Group 42815e: (1849144, 32)
‚úÖ Saved SINGLE Group 42815e: (1849144, 15)
‚öôÔ∏è Processing Group 1d241d (21 videos)...


Group 1d241d:   0%|          | 0/21 [00:00<?, ?it/s]

‚úÖ Saved PAIR Group 1d241d: (628714, 32)
‚úÖ Saved SINGLE Group 1d241d: (628714, 15)
