
# MABe CTR-GCN Submission (Educational Scaffold)

This notebook is an **educational, fully annotated scaffold** for building a CTR-GCN-based solution for the
[MABe Mouse Behavior Detection](https://www.kaggle.com/competitions/MABe-mouse-behavior-detection) challenge.

It mirrors the structure of your `CTRGCN-model-baseline.py` file, but focuses on **clarity and learning**:
- Configuration and modes
- Data loading and batching
- Skeleton definition and adjacency
- CTR-GCN model classes (1/2/4 stream variants)
- Input preparation (coords, deltas, bones)
- Training, tuning, and submission hooks

Most heavy logic is left as **TODO blocks** so you can paste in or adapt your full implementations.



## 1. High-Level Pipeline

The CTR-GCN workflow has three major phases:

1. **Offline Training / Validation** (local machine or cluster)
2. **Offline Hyperparameter Tuning** (optional, per-stream)
3. **Online Submission** (Kaggle: inference-only, using pre-trained models)

This notebook is primarily designed for **inference and education**, assuming heavy training
is done offline using your Python script.



## 2. Configuration Panel

Edit this cell to control run mode, stream mode, and root paths.


In [None]:

# ==== RUN / STREAM MODES =====================================================

# RUN_MODE selects what executing this notebook does:
#   "dev"    -> quick sanity check on a small subset (after the TODO blocks are filled in)
#   "submit" -> load pre-trained models and write submission.csv on Kaggle
RUN_MODE = "dev"  # switch to "submit" for a real Kaggle submission

# STREAM_MODE selects the CTR-GCN input layout:
#   "one"  -> a single stream with coords/delta/bone/bone_delta merged
#   "two"  -> two streams: (coords + bone) and (delta + bone_delta)
#   "four" -> four independent streams
STREAM_MODE = "one"

# Root directories: competition data, saved model weights, tuning outputs.
# Adjust these for local runs vs. Kaggle.
DATA_ROOT, MODEL_ROOT, TUNING_ROOT = (
    "../input/MABe-mouse-behavior-detection",
    "CTR-GCN-Models",
    "tuning_results",
)



## 3. Imports


In [None]:

import os
import json
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F



## 4. CTR-GCN Configuration & Helpers


In [None]:

from dataclasses import dataclass

@dataclass
class CTRGCNConfig:
    """Configuration object for the CTR-GCN pipeline (simplified educational version)."""

    mode: str = "dev"
    stream_mode: str = "one"

    max_videos: int | None = 2
    max_batches: int | None = 5
    max_windows: int | None = 50

    use_delta: bool = True
    use_bone: bool = True
    use_bone_delta: bool = True

    in_channels_single_stream: int = 8
    in_channels_streamA: int = 4
    in_channels_streamB: int = 4
    in_channels_coords_only: int = 2
    in_channels_delta_only: int = 2
    in_channels_bone_only: int = 2
    in_channels_bone_delta_only: int = 2


def get_stream_mode_tag(cfg: "CTRGCNConfig") -> str:
    """Return the validated stream-mode tag ("one", "two" or "four") of ``cfg``.

    Any object with a ``stream_mode`` attribute is accepted. A missing attribute
    falls back to "one".

    Raises:
        ValueError: if the tag is not a supported stream mode. This is an explicit
            check rather than an ``assert``, so it still runs under ``python -O``.
    """
    mode = getattr(cfg, "stream_mode", "one")
    if mode not in {"one", "two", "four"}:
        raise ValueError(
            f"Unsupported stream_mode {mode!r}; expected 'one', 'two' or 'four'."
        )
    return mode


def get_stream_model_dir(cfg: CTRGCNConfig) -> str:
    """Return the weights directory for ``cfg``'s stream mode, creating it if absent.

    The layout is ``<MODEL_ROOT>/one_stream``, ``two_stream`` or ``four_stream``.
    The stream mode is validated by ``get_stream_mode_tag`` before the lookup.
    """
    subdir_by_tag = {"one": "one_stream", "two": "two_stream", "four": "four_stream"}
    model_dir = os.path.join(MODEL_ROOT, subdir_by_tag[get_stream_mode_tag(cfg)])
    os.makedirs(model_dir, exist_ok=True)
    return model_dir


In [None]:

def get_best_params_path_for_stream(cfg: CTRGCNConfig) -> str:
    """Return the tuned-hyperparameter CSV path for ``cfg``'s stream mode.

    The path is ``<TUNING_ROOT>/best_params_<mode>.csv``. TUNING_ROOT is created
    if missing, so callers can write the file directly.

    Raises:
        Whatever ``get_stream_mode_tag`` raises for an unsupported stream mode.
        Validation happens before any directory is created.
    """
    # Reuse the shared validator instead of duplicating the allowed-mode check.
    mode = get_stream_mode_tag(cfg)
    os.makedirs(TUNING_ROOT, exist_ok=True)
    return os.path.join(TUNING_ROOT, f"best_params_{mode}.csv")


def load_best_params_csv_for_config(config: CTRGCNConfig) -> dict | None:
    """Load tuned hyperparameters for ``config``'s stream mode, if available.

    Reads the CSV at ``get_best_params_path_for_stream(config)`` and returns its
    first data row as a dict (column name -> value). Any further rows are ignored.

    Returns:
        The first row as a dict, or None when the file does not exist, cannot be
        read/parsed (a warning is printed), or has a header but no data rows.
    """
    path = get_best_params_path_for_stream(config)
    if not os.path.exists(path):
        return None
    try:
        df = pd.read_csv(path)
    except Exception as e:  # best-effort: fall back to defaults on any read failure
        print(f"Warning: Could not read {path}: {e}")
        return None
    if df.empty:
        # Previously this surfaced as an opaque IndexError caught by the except above.
        print(f"Warning: {path} has no data rows; ignoring it.")
        return None
    # orient="records" keeps per-column dtypes. df.iloc[0] would upcast int columns
    # to float in a mixed-dtype frame.
    return df.to_dict(orient="records")[0]



## 5. Mouse Skeleton: Joints & Adjacency


In [None]:

# Canonical joint order shared by every skeleton. get_ordered_joints_and_adjacency
# keeps only the tracked joints, in this order, and links consecutive survivors
# into a chain. The order therefore also determines the graph's edges.
MASTER_MOUSE_JOINT_ORDER = [
    # head and head-mounted markers
    "nose", "head",
    "headpiece_topfrontleft", "headpiece_topfrontright",
    "headpiece_topbackleft", "headpiece_topbackright",
    "headpiece_bottomfrontleft", "headpiece_bottomfrontright",
    "headpiece_bottombackleft", "headpiece_bottombackright",
    "ear_left", "ear_right",
    # neck, front paws and trunk
    "neck",
    "forepaw_left", "forepaw_right",
    "body_center",
    "lateral_left", "lateral_right",
    "spine_1", "spine_2",
    # hips and hind paws
    "hip_left", "hip_right",
    "hindpaw_left", "hindpaw_right",
    # tail, base to tip
    "tail_base", "tail_middle_1", "tail_middle_2", "tail_midpoint", "tail_tip",
]


def get_ordered_joints_and_adjacency(body_parts_tracked: list[str]) -> tuple[list[str], np.ndarray]:
    ordered_joints = [bp for bp in MASTER_MOUSE_JOINT_ORDER if bp in body_parts_tracked]
    V = len(ordered_joints)
    adjacency = np.zeros((V, V), dtype=np.float32)
    for i in range(V - 1):
        adjacency[i, i + 1] = 1.0
        adjacency[i + 1, i] = 1.0
    return ordered_joints, adjacency



## 6. CTR-GCN Model Definitions


In [None]:

def _normalize_adjacency_chain(adjacency: np.ndarray) -> np.ndarray:
    V = adjacency.shape[0]
    A = adjacency.astype(np.float32).copy()
    A += np.eye(V, dtype=np.float32)
    row_sum = A.sum(axis=1, keepdims=True)
    row_sum[row_sum == 0.0] = 1.0
    return A / row_sum


class GraphConv(nn.Module):
    """Spatial graph convolution with a fixed, row-normalised adjacency.

    The adjacency, with self-loops added and rows normalised by
    ``_normalize_adjacency_chain``, is registered as a buffer named ``A``. It is
    therefore not a trainable parameter, but it follows ``.to(device)`` and is
    stored in the state_dict. A 1x1 convolution then maps ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, adjacency: np.ndarray):
        super().__init__()
        normalized = torch.from_numpy(_normalize_adjacency_chain(adjacency))
        self.register_buffer("A", normalized)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix features across joints with ``A``, then project channels.

        ``x`` must be 4-D, with the joint axis third (einsum subscripts ``ncvT``).
        The output keeps that layout, with ``out_channels`` channels.
        """
        # out[n, c, w, t] = sum_v x[n, c, v, t] * A[v, w]. This contracts over A's
        # rows, i.e. it aggregates with the row-normalised A transposed.
        # NOTE(review): conventional mean-neighbour aggregation would be
        # "ncvT,wv->ncwT". The two differ because row-normalisation makes A
        # asymmetric. Kept as-is so existing checkpoints behave identically;
        # confirm which was intended.
        mixed = torch.einsum("ncvT,vw->ncwT", x, self.A)
        return self.conv(mixed)


class STBlock(nn.Module):
    """Spatial-temporal block: graph convolution, then a temporal convolution, plus a residual.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the block.
        adjacency: joint adjacency passed through to GraphConv (presumably (V, V)).
        stride: stride along the last tensor axis (presumably the frame axis). It is
            applied to both the temporal conv and the residual projection so their
            output shapes match.
        dropout: dropout probability at the end of the temporal branch.
    """

    def __init__(self, in_channels: int, out_channels: int, adjacency: np.ndarray,
                 stride: int = 1, dropout: float = 0.1):
        super().__init__()
        # Spatial step: mix features across joints, then project to out_channels.
        self.gcn = GraphConv(in_channels, out_channels, adjacency)
        # Temporal step: BN -> ReLU -> (1, 3) conv over the last axis -> BN -> Dropout.
        # NOTE: submodule order fixes the state_dict keys ("tcn.0", "tcn.2", ...), which
        # saved checkpoints depend on. It also fixes the order of random draws at init.
        self.tcn = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=(1, 3),
                padding=(0, 1),  # keeps the last-axis length unchanged when stride == 1
                stride=(1, stride),
                bias=True,
            ),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout),
        )
        # Residual path: a strided 1x1 projection when channels or stride change,
        # otherwise the input passes through unchanged.
        if (in_channels != out_channels) or (stride != 1):
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=(1, stride), bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.residual = nn.Identity()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gcn -> tcn, add the residual, then a final ReLU."""
        # The residual is taken from the untouched input before the main branch runs.
        res = self.residual(x)
        x = self.gcn(x)
        x = self.tcn(x)
        x = x + res
        x = self.relu(x)
        return x


In [None]:

class CTRGCNMinimal(nn.Module):
    """Stack of STBlocks, global average pooling, and a linear head.

    The first block maps ``in_channels`` to ``base_channels``; every later block
    keeps ``base_channels``. After the blocks, the last two axes are averaged and
    ``fc`` produces ``num_classes`` outputs. With ``num_blocks <= 0`` there are no
    blocks and ``fc`` acts on ``in_channels`` directly.

    Note: the graph here is the fixed normalised adjacency held by GraphConv. No
    channel-wise topology refinement is learned in this simplified version.
    """

    def __init__(self, in_channels: int, num_classes: int,
                 adjacency: np.ndarray, base_channels: int = 64,
                 num_blocks: int = 3, dropout: float = 0.1):
        super().__init__()
        # widths[k] -> widths[k + 1] is the channel transition of block k.
        widths = [in_channels] + [base_channels] * num_blocks
        self.st_blocks = nn.ModuleList([
            STBlock(c_in, c_out, adjacency, stride=1, dropout=dropout)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ])
        self.fc = nn.Linear(widths[-1], num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the blocks, average over the last two axes, and return the head's logits."""
        for block in self.st_blocks:
            x = block(x)
        pooled = x.mean(dim=(-2, -1))
        return self.fc(pooled)


In [None]:

class CTRGCNTwoStream(nn.Module):
    """Two-stream model: two independent encoders, summed, then one logit.

    Each stream is a CTRGCNMinimal whose head emits ``base_channels`` features. The
    two feature vectors are added and ``fc`` maps the sum to a single logit. The
    parameter names suggest a coords stream and a delta stream. Per the notebook's
    stream-mode notes these presumably carry (coords + bone) and
    (delta + bone_delta), hence the default of 4 input channels each.
    """

    def __init__(self, adjacency: np.ndarray,
                 in_channels_coords: int = 4,
                 in_channels_delta: int = 4,
                 base_channels: int = 64,
                 num_blocks: int = 3,
                 dropout: float = 0.1):
        super().__init__()
        self.stream_coords = CTRGCNMinimal(
            in_channels=in_channels_coords,
            num_classes=base_channels,
            adjacency=adjacency,
            base_channels=base_channels,
            num_blocks=num_blocks,
            dropout=dropout,
        )
        self.stream_delta = CTRGCNMinimal(
            in_channels=in_channels_delta,
            num_classes=base_channels,
            adjacency=adjacency,
            base_channels=base_channels,
            num_blocks=num_blocks,
            dropout=dropout,
        )
        self.fc = nn.Linear(base_channels, 1)

    def forward(self, coords_x: torch.Tensor, delta_x: torch.Tensor) -> torch.Tensor:
        """Encode both inputs, fuse by summation, and return one logit per sample."""
        return self.fc(self.stream_coords(coords_x) + self.stream_delta(delta_x))


class CTRGCNFourStream(nn.Module):
    """Four-stream model: independent encoders for coords, delta, bone and bone_delta.

    Each stream is a CTRGCNMinimal whose head emits ``base_channels`` features. The
    four feature vectors are summed and ``fc`` maps the sum to a single logit.

    The per-stream input channel counts used to be hard-coded to 2. They are now
    keyword parameters with that same default, matching the
    ``in_channels_*_only`` fields of CTRGCNConfig. Existing positional calls are
    unaffected.
    """

    def __init__(self, adjacency: np.ndarray,
                 base_channels: int = 64,
                 dropout: float = 0.1,
                 num_blocks: int = 3,
                 in_channels_coords: int = 2,
                 in_channels_delta: int = 2,
                 in_channels_bone: int = 2,
                 in_channels_bone_delta: int = 2):
        super().__init__()
        # Creation order (and hence state_dict keys / init RNG order) matches the original.
        self.stream_coords = CTRGCNMinimal(in_channels_coords, base_channels, adjacency, base_channels, num_blocks, dropout)
        self.stream_delta = CTRGCNMinimal(in_channels_delta, base_channels, adjacency, base_channels, num_blocks, dropout)
        self.stream_bone = CTRGCNMinimal(in_channels_bone, base_channels, adjacency, base_channels, num_blocks, dropout)
        self.stream_bone_delta = CTRGCNMinimal(in_channels_bone_delta, base_channels, adjacency, base_channels, num_blocks, dropout)
        self.fc = nn.Linear(base_channels, 1)

    def forward(self, coords_x, delta_x, bone_x, bone_delta_x):
        """Encode each input with its own stream, sum the features, return one logit."""
        f1 = self.stream_coords(coords_x)
        f2 = self.stream_delta(delta_x)
        f3 = self.stream_bone(bone_x)
        f4 = self.stream_bone_delta(bone_delta_x)
        fused = f1 + f2 + f3 + f4
        return self.fc(fused)



## 7. Data Loading & Batch Generation (Placeholder)


In [None]:

def generate_mouse_data(dataset: pd.DataFrame,
                        traintest: str,
                        traintest_directory: str | None = None,
                        generate_single: bool = True,
                        generate_pair: bool = True,
                        config: CTRGCNConfig | None = None):
    """Produce batches of single-mouse and/or mouse-pair frame data.

    Placeholder only. Paste the full implementation from
    ``CTRGCN-model-baseline.py`` here.

    Args:
        dataset: metadata table describing what to load (schema defined by the
            real implementation).
        traintest: which split to read; presumably "train" or "test".
        traintest_directory: optional override for where that split's files live.
        generate_single: whether single-mouse batches should be produced.
        generate_pair: whether mouse-pair batches should be produced.
        config: optional pipeline configuration.

    Raises:
        NotImplementedError: always, until the implementation is pasted in.
    """
    raise NotImplementedError("Paste generate_mouse_data implementation here.")



## 8. Sliding Window Extraction & Input Preparation (Input Preparation Is a Placeholder)


In [None]:

def create_sliding_windows(single_mouse_df: pd.DataFrame,
                           window: int = 90,
                           stride: int = 30,
                           include_tail: bool = False):
    """Yield fixed-length row windows over ``single_mouse_df``.

    Windows start at rows 0, stride, 2*stride, ... and only full windows are
    produced, so a table shorter than ``window`` yields nothing. With the default
    ``include_tail=False``, rows after the last full stride are not covered
    (e.g. 100 rows, window 90, stride 30 -> rows 90-99 are skipped).

    Args:
        single_mouse_df: per-frame rows for one mouse. Its index values are
            returned as the frame indices.
        window: number of consecutive rows per window; must be > 0.
        stride: row offset between consecutive window starts; must be > 0.
        include_tail: if True and the regular windows stop short of the last row,
            one extra window aligned to the end is yielded, so every row lies in
            some window.

    Yields:
        (window_df, frame_indices): the row slice and its index values as a
        NumPy array.

    Raises:
        ValueError: if ``window`` or ``stride`` is not positive. Because this is a
            generator, the error is raised on first iteration.
    """
    if window <= 0:
        raise ValueError(f"window must be positive, got {window}")
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    n_frames = len(single_mouse_df)
    frames = single_mouse_df.index.to_numpy()
    last_start = n_frames - window
    if last_start < 0:
        return  # too short for even one full window

    starts = list(range(0, last_start + 1, stride))  # always contains 0 here
    if include_tail and starts[-1] != last_start:
        starts.append(last_start)

    for start in starts:
        end = start + window
        yield single_mouse_df.iloc[start:end], frames[start:end]


In [None]:

def prepare_ctr_gcn_input(single_mouse_df: pd.DataFrame,
                          ordered_joints: list[str],
                          config: CTRGCNConfig | None = None):
    """Turn one mouse's frames into CTR-GCN input tensors.

    Placeholder only. Paste the full implementation here, including
    normalisation, delta/bone feature construction and the branching on
    ``config.stream_mode``.

    Args:
        single_mouse_df: per-frame data for a single mouse.
        ordered_joints: joint names in canonical order; presumably the first
            value returned by ``get_ordered_joints_and_adjacency``.
        config: optional pipeline configuration.

    Raises:
        NotImplementedError: always, until the implementation is pasted in.
    """
    raise NotImplementedError("Paste prepare_ctr_gcn_input implementation here.")



## 9. Window Collection, Training & Inference (Placeholders)


In [None]:

def collect_ctr_gcn_windows(batches,
                            ordered_joints: list[str],
                            adjacency: np.ndarray,
                            config: CTRGCNConfig,
                            device: str = "cpu"):
    """Gather model-ready windows and their labels across ``batches``.

    Placeholder only. The intended helper loops over the batches, calls
    ``prepare_ctr_gcn_input`` on each, and aggregates the resulting tensors
    and labels.

    Args:
        batches: iterable of batches, presumably from ``generate_mouse_data``.
        ordered_joints: joint names in canonical order.
        adjacency: joint adjacency matrix.
        config: pipeline configuration.
        device: torch device string for the produced tensors.

    Raises:
        NotImplementedError: always, until the implementation is pasted in.
    """
    raise NotImplementedError("Paste collect_ctr_gcn_windows implementation here.")


In [None]:

def train_ctr_gcn_models(batches,
                         ordered_joints: list[str],
                         adjacency: np.ndarray,
                         config: CTRGCNConfig,
                         device: str = "cpu"):
    """Train a CTR-GCN model for each action and persist the weights.

    Placeholder only. Paste the ``train_ctr_gcn_models`` implementation here.

    Args:
        batches: training batches.
        ordered_joints: joint names in canonical order.
        adjacency: joint adjacency matrix.
        config: pipeline configuration (stream mode, limits, feature toggles).
        device: torch device string used for training.

    Raises:
        NotImplementedError: always, until the implementation is pasted in.
    """
    raise NotImplementedError("Paste train_ctr_gcn_models implementation here.")


def load_ctr_gcn_models(actions: list[str],
                        adjacency: np.ndarray,
                        config: CTRGCNConfig,
                        device: str = "cpu"):
    """Load the saved CTR-GCN model of each action for ``config``'s stream mode.

    Placeholder only. Paste the ``load_ctr_gcn_models`` implementation here.

    Args:
        actions: names of the actions whose models should be loaded.
        adjacency: joint adjacency matrix used to rebuild the architectures.
        config: pipeline configuration; its stream mode selects the model layout.
        device: torch device string the models are loaded onto.

    Raises:
        NotImplementedError: always, until the implementation is pasted in.
    """
    raise NotImplementedError("Paste load_ctr_gcn_models implementation here.")


def submit_ctr_gcn(body_parts_tracked_str: str,
                   switch_tr: str,
                   model_dict: dict[str, nn.Module],
                   config: CTRGCNConfig,
                   device: str = "cpu") -> pd.DataFrame:
    """Predict on the test data and build a submission DataFrame.

    Placeholder only. Paste the ``submit_ctr_gcn`` implementation here.

    Args:
        body_parts_tracked_str: string form of the tracked body parts.
        switch_tr: mode switch whose semantics are defined by the real
            implementation (not visible in this notebook).
        model_dict: mapping from name (presumably action) to trained model.
        config: pipeline configuration.
        device: torch device string used for inference.

    Raises:
        NotImplementedError: always, until the implementation is pasted in.
    """
    raise NotImplementedError("Paste submit_ctr_gcn implementation here.")



## 10. Dev Sanity Check (Optional)


In [None]:

if RUN_MODE == "dev":
    # Nothing runs yet: this cell only lists the manual smoke-test steps.
    print(
        "Dev sanity-check placeholder.",
        "Once you've pasted implementations:",
        " - Load a tiny subset of train.csv",
        " - Build ordered_joints, adjacency",
        " - Generate a few batches via generate_mouse_data",
        " - Call prepare_ctr_gcn_input on a single batch",
        " - Instantiate CTRGCNMinimal and run a forward pass",
        sep="\n",
    )
