# 06 — Form Classification: Baseline vs 3D CNN vs ST-GCN

Binary form classification (correct vs incorrect) using three approaches:
1. **Logistic Regression** baseline on per-video angle statistics
2. **3D CNN (R3D-18)** fine-tuned on raw video frames
3. **ST-GCN** on skeleton keypoint sequences

All three methods are evaluated on the same 5-fold stratified cross-validation splits.

In [None]:
import sys
sys.path.insert(0, "..")

import json
import logging
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
from pathlib import Path
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from src.classification.stgcn import PushUpSTGCN
from src.classification.video_classifier import PushUpVideoClassifier
from src.classification.datasets import PushUpSkeletonDataset, PushUpVideoDataset
from src.classification.train_classifier import run_kfold_cv

logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

# Load manifest: a JSON object mapping video_id -> per-video metadata.
# JSON is UTF-8 by spec, so decode explicitly instead of relying on the
# platform's locale-dependent default encoding.
with open("../data/processed/keypoints/manifest.json", encoding="utf-8") as f:
    manifest = json.load(f)

# Build video_ids (sorted, so every cell sees the same deterministic order)
# and integer labels: 0 = "correct", 1 = any other label value.
video_ids = sorted(manifest.keys())
labels = [0 if manifest[v]["label"] == "correct" else 1 for v in video_ids]
labels_arr = np.array(labels)

# Labels are strictly 0/1, so the number of "incorrect" videos is their sum.
n_incorrect = int(labels_arr.sum())
print(f"Total videos: {len(video_ids)}")
print(f"  Correct: {len(labels) - n_incorrect}")
print(f"  Incorrect: {n_incorrect}")

# Device selection: prefer CUDA, then MPS, then CPU.
# The hasattr guard covers torch builds that have no torch.backends.mps.
if torch.cuda.is_available():
    device_str = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device_str = "mps"
else:
    device_str = "cpu"
print(f"Using device: {device_str}")

# Common CV config shared by every method in this notebook.
RANDOM_STATE = 42
N_SPLITS = 5
# With an integer random_state, every skf.split() call yields identical folds.
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_STATE)

## Section 1: Baseline — Logistic Regression

Uses 16 per-video angle statistics for the elbow, back, hip, and knee angles: mean/min/max are read from `feature_summary.csv`, and range (max − min) is computed below.

In [None]:
# Load pre-computed feature summary, indexed by video_id.
feat_df = pd.read_csv("../data/processed/features/feature_summary.csv")
feat_df = feat_df.set_index("video_id")

# Build feature matrix with 16 features (mean/min/max/range for 4 angles)
angle_names = ["elbow", "back", "hip", "knee"]
stat_names = ["mean", "min", "max"]

# Existing columns: {angle}_{stat}, in angle-major order
feature_cols = [f"{angle}_{stat}" for angle in angle_names for stat in stat_names]

# Derived range columns (max - min): added to feat_df and appended to the feature list
for angle in angle_names:
    range_col = f"{angle}_range"
    feat_df[range_col] = feat_df[f"{angle}_max"] - feat_df[f"{angle}_min"]
    feature_cols.append(range_col)

# Every manifest video needs a feature row. Report the gaps explicitly instead
# of surfacing only the first KeyError raised by .loc below.
missing_ids = [v for v in video_ids if v not in feat_df.index]
if missing_ids:
    raise ValueError(
        f"{len(missing_ids)} manifest video(s) missing from feature_summary.csv: "
        f"{missing_ids[:10]}"
    )

# Align rows with video_ids order so row i corresponds to labels_arr[i]
X_baseline = feat_df.loc[video_ids, feature_cols].values.astype(np.float32)
y_baseline = labels_arr

print(f"Feature matrix: {X_baseline.shape}")
print(f"Features: {feature_cols}")

# Cross-validated predictions using the SAME splits (skf from the setup cell)
pipeline = Pipeline([
    ("scaler", StandardScaler()),
    ("lr", LogisticRegression(max_iter=1000, random_state=RANDOM_STATE)),
])

baseline_preds = cross_val_predict(pipeline, X_baseline, y_baseline, cv=skf)
baseline_acc = accuracy_score(y_baseline, baseline_preds)

# Per-fold accuracy. skf has an integer random_state, so split() regenerates
# exactly the folds cross_val_predict used; each validation fold's predictions
# came from the model fitted on the remaining folds.
baseline_fold_accs = []
for fold, (_, val_idx) in enumerate(skf.split(X_baseline, y_baseline)):
    fold_acc = accuracy_score(y_baseline[val_idx], baseline_preds[val_idx])
    baseline_fold_accs.append(fold_acc)
    print(f"  Fold {fold}: accuracy={fold_acc:.4f} (n={len(val_idx)})")

print(f"\nBaseline overall accuracy: {baseline_acc:.4f}")
print(f"Per-fold mean: {np.mean(baseline_fold_accs):.4f} +/- {np.std(baseline_fold_accs):.4f}")
print()
print(classification_report(y_baseline, baseline_preds, target_names=["correct", "incorrect"]))

## Section 2: 3D CNN (R3D-18)

Pretrained on Kinetics-400 with a frozen backbone; only the final FC layer is fine-tuned (1,026 trainable parameters).

In [None]:
# Directory containing the raw videos; passed to PushUpVideoDataset as video_dir.
VIDEO_DIR = Path("../data/raw/kaggle_pushups")

def make_r3d_model():
    """Return a new PushUpVideoClassifier with a frozen backbone and 2 output classes."""
    model = PushUpVideoClassifier(num_classes=2, freeze_backbone=True)
    return model

def make_video_dataset(ids):
    """Return a PushUpVideoDataset restricted to the video ids in ``ids``.

    Uses the notebook-level ``manifest`` and ``VIDEO_DIR`` with ``n_frames=16``.
    """
    shared_kwargs = dict(
        manifest=manifest,
        video_dir=VIDEO_DIR,
        n_frames=16,
    )
    return PushUpVideoDataset(video_ids=ids, **shared_kwargs)

# Use N_SPLITS rather than a hard-coded "5" so the message tracks the CV config.
print(f"Training R3D-18 with {N_SPLITS}-fold stratified CV...")
# Count trainable parameters on a separate instance that is discarded afterwards.
r3d_n_params = sum(p.numel() for p in make_r3d_model().parameters() if p.requires_grad)
# Thousands separator, consistent with the ST-GCN cell.
print(f"  Trainable params: {r3d_n_params:,}")

r3d_results = run_kfold_cv(
    model_factory=make_r3d_model,
    dataset_factory=make_video_dataset,
    video_ids=video_ids,
    labels=labels,
    n_splits=N_SPLITS,
    n_epochs=30,
    batch_size=8,
    lr=1e-3,
    patience=10,
    device_str=device_str,
    random_state=RANDOM_STATE,
)

# One validation accuracy per entry in r3d_results["fold_results"].
r3d_fold_accs = [fold_res["val_accuracy"] for fold_res in r3d_results["fold_results"]]
print(f"\nR3D-18 per-fold accuracies: {[f'{a:.4f}' for a in r3d_fold_accs]}")
print(f"Mean: {np.mean(r3d_fold_accs):.4f} +/- {np.std(r3d_fold_accs):.4f}")

## Section 3: ST-GCN

Spatial-Temporal Graph Convolutional Network on torso-normalized skeleton sequences (~245K params).

In [None]:
# Directory of extracted keypoints (yolo subfolder); passed to PushUpSkeletonDataset as keypoint_dir.
KEYPOINT_DIR = Path("../data/processed/keypoints/yolo")

def make_stgcn_model():
    """Return a new PushUpSTGCN (2 input channels, 2 classes, dropout 0.2)."""
    model = PushUpSTGCN(num_classes=2, dropout=0.2, in_channels=2)
    return model

def make_skeleton_dataset(ids):
    """Return a PushUpSkeletonDataset restricted to the video ids in ``ids``.

    Uses the notebook-level ``manifest`` and ``KEYPOINT_DIR`` with
    ``max_frames=150`` and ``normalize=True``.
    """
    shared_kwargs = dict(
        manifest=manifest,
        keypoint_dir=KEYPOINT_DIR,
        max_frames=150,
        normalize=True,
    )
    return PushUpSkeletonDataset(video_ids=ids, **shared_kwargs)

# Use N_SPLITS rather than a hard-coded "5" so the message tracks the CV config.
print(f"Training ST-GCN with {N_SPLITS}-fold stratified CV...")
# Count trainable parameters on a separate instance that is discarded afterwards.
n_params = sum(p.numel() for p in make_stgcn_model().parameters() if p.requires_grad)
print(f"  Trainable params: {n_params:,}")

stgcn_results = run_kfold_cv(
    model_factory=make_stgcn_model,
    dataset_factory=make_skeleton_dataset,
    video_ids=video_ids,
    labels=labels,
    n_splits=N_SPLITS,
    n_epochs=50,
    batch_size=8,
    lr=1e-3,
    patience=15,
    device_str=device_str,
    random_state=RANDOM_STATE,
)

# One validation accuracy per entry in stgcn_results["fold_results"].
stgcn_fold_accs = [fold_res["val_accuracy"] for fold_res in stgcn_results["fold_results"]]
print(f"\nST-GCN per-fold accuracies: {[f'{a:.4f}' for a in stgcn_fold_accs]}")
print(f"Mean: {np.mean(stgcn_fold_accs):.4f} +/- {np.std(stgcn_fold_accs):.4f}")

## Section 4: Comparison

Per-fold accuracy table and bar chart with error bars across all three methods.

In [None]:
# Per-method fold accuracies; insertion order fixes the column/bar order below.
fold_accs_by_method = {
    "Baseline (LR)": baseline_fold_accs,
    "R3D-18": r3d_fold_accs,
    "ST-GCN": stgcn_fold_accs,
}
methods = list(fold_accs_by_method)
means = [np.mean(accs) for accs in fold_accs_by_method.values()]
stds = [np.std(accs) for accs in fold_accs_by_method.values()]

# Per-fold accuracy table
df_comparison = pd.DataFrame({"Fold": list(range(N_SPLITS)), **fold_accs_by_method})

# Add mean row
mean_row = pd.DataFrame([{"Fold": "Mean", **dict(zip(methods, means))}])
df_display = pd.concat([df_comparison, mean_row], ignore_index=True)
print(df_display.to_string(index=False))

# Bar chart with error bars (+/- one std across folds)
colors = ["#4C72B0", "#DD8452", "#55A868"]

fig, ax = plt.subplots(figsize=(8, 5))
bars = ax.bar(methods, means, yerr=stds, capsize=8, color=colors, alpha=0.85, edgecolor="black")

# Add value labels just above each error bar
for bar, mean, std in zip(bars, means, stds):
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + std + 0.01,
            f"{mean:.1%}", ha="center", va="bottom", fontweight="bold")

ax.set_ylabel("Accuracy")
ax.set_title(f"Form Classification: {N_SPLITS}-Fold CV Accuracy Comparison")
ax.set_ylim(0, 1.15)
# NOTE: 0.5 is chance level only for balanced classes.
ax.axhline(0.5, color="gray", linestyle="--", alpha=0.5, label="Chance")
ax.legend()
plt.tight_layout()

# The output directory may not exist yet on a fresh checkout (the save-results
# cell that also creates it runs later), so create it before saving.
FIGURES_DIR = Path("../outputs/figures")
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
plt.savefig(FIGURES_DIR / "06_accuracy_comparison.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# Confusion matrices side-by-side, one panel per method
fig, axes = plt.subplots(1, 3, figsize=(16, 4.5))
class_names = ["correct", "incorrect"]

# Deep-model predictions are looked up per video in video_ids order so they
# line up with the baseline's positional predictions.
r3d_true = [r3d_results["per_video_true"][v] for v in video_ids]
r3d_pred = [r3d_results["per_video_preds"][v] for v in video_ids]
r3d_acc = accuracy_score(r3d_true, r3d_pred)

stgcn_true = [stgcn_results["per_video_true"][v] for v in video_ids]
stgcn_pred = [stgcn_results["per_video_preds"][v] for v in video_ids]
stgcn_acc = accuracy_score(stgcn_true, stgcn_pred)

panels = [
    ("Baseline (LR)", y_baseline, baseline_preds, baseline_acc, "Blues"),
    ("R3D-18", r3d_true, r3d_pred, r3d_acc, "Oranges"),
    ("ST-GCN", stgcn_true, stgcn_pred, stgcn_acc, "Greens"),
]
for ax, (title, y_true, y_pred, acc, cmap) in zip(axes, panels):
    # Fixing labels keeps the matrix 2x2 even if a class is absent from both
    # y_true and y_pred, so it always matches the two tick labels.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    sns.heatmap(cm, annot=True, fmt="d", cmap=cmap,
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title(f"{title} — {acc:.1%}")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")

plt.tight_layout()

# Create the output directory before saving; it may not exist yet on a fresh checkout.
FIGURES_DIR = Path("../outputs/figures")
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
plt.savefig(FIGURES_DIR / "06_confusion_matrices.png", dpi=150, bbox_inches="tight")
plt.show()

## Section 5: Error Analysis

Per-video predictions: identify hard videos (misclassified by at least two of the three methods) and cross-reference them with their angle statistics.

In [None]:
# Per-video prediction summary, one row per video in video_ids order.
_METHOD_PREFIXES = ("baseline", "r3d", "stgcn")


def _label_name(pred):
    """Map an integer class id (0 = correct, anything else = incorrect) to its name."""
    return "correct" if pred == 0 else "incorrect"


results_rows = []
# labels and baseline_preds are positionally aligned with video_ids, so zip them
# instead of doing an O(n) video_ids.index() lookup for every video.
for vid_id, label, baseline_pred in zip(video_ids, labels, baseline_preds):
    # Score predictions against the normalized 0/1 training label rather than the
    # raw manifest string, so every non-"correct" label counts as "incorrect".
    true_name = _label_name(label)
    row = {
        "video_id": vid_id,
        "true_label": manifest[vid_id]["label"],
        "baseline_pred": _label_name(baseline_pred),
        "r3d_pred": _label_name(r3d_results["per_video_preds"][vid_id]),
        "stgcn_pred": _label_name(stgcn_results["per_video_preds"][vid_id]),
    }
    for prefix in _METHOD_PREFIXES:
        row[f"{prefix}_correct"] = row[f"{prefix}_pred"] == true_name
    row["n_correct_models"] = sum(row[f"{prefix}_correct"] for prefix in _METHOD_PREFIXES)
    results_rows.append(row)

df_results = pd.DataFrame(results_rows)

# Hard videos: misclassified by at least 2 of the 3 methods
hard_videos = df_results[df_results["n_correct_models"] <= 1]
print(f"Hard videos (misclassified by >= 2 methods): {len(hard_videos)}")
if len(hard_videos) > 0:
    print(hard_videos[["video_id", "true_label", "baseline_pred", "r3d_pred", "stgcn_pred"]].to_string(index=False))

# Cross-reference with angle stats for hard videos
if len(hard_videos) > 0:
    print("\nAngle statistics for hard videos:")
    hard_ids = hard_videos["video_id"].tolist()
    hard_feats = feat_df.loc[hard_ids, ["elbow_mean", "elbow_min", "back_mean", "hip_mean"]]
    print(hard_feats.to_string())

    # Duration info ("?" when the manifest entry lacks the field)
    print("\nDuration info:")
    for vid_id in hard_ids:
        dur = manifest[vid_id].get("duration_s", "?")
        n_frames = manifest[vid_id].get("n_frames", "?")
        print(f"  {vid_id}: {dur}s, {n_frames} frames")

# Agreement analysis: how many videos every method / a majority got right
all_agree = (df_results["n_correct_models"] == 3).sum()
two_agree = (df_results["n_correct_models"] >= 2).sum()
print(f"\nAll 3 methods correct: {all_agree}/{len(video_ids)} ({all_agree/len(video_ids):.1%})")
print(f"Majority (>= 2) correct: {two_agree}/{len(video_ids)} ({two_agree/len(video_ids):.1%})")

## Section 6: Save Results

In [None]:
# Save results
RESULTS_DIR = Path("../outputs/results")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

FIGURES_DIR = Path("../outputs/figures")
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

MODEL_DIR = Path("../models")
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Save per-video results
df_results.to_csv(RESULTS_DIR / "form_classification_results.csv", index=False)
print(f"Saved: {RESULTS_DIR / 'form_classification_results.csv'} ({len(df_results)} rows)")

# Save comparison summary. Per_Fold is written to the CSV as a list literal, so
# cast to built-in float: NumPy >= 2 scalars would otherwise serialize as
# "np.float64(...)" and be awkward to parse back.
per_fold = [
    [float(acc) for acc in accs]
    for accs in (baseline_fold_accs, r3d_fold_accs, stgcn_fold_accs)
]
summary = pd.DataFrame({
    "Method": methods,
    "Mean_Accuracy": means,
    "Std_Accuracy": stds,
    "Per_Fold": per_fold,
})
summary.to_csv(RESULTS_DIR / "form_classification_summary.csv", index=False)
print(f"Saved: {RESULTS_DIR / 'form_classification_summary.csv'}")

# Save best ST-GCN model state (whatever run_kfold_cv returned as "best_state";
# NOTE(review): confirm which fold/epoch it corresponds to).
if stgcn_results["best_state"] is not None:
    torch.save(stgcn_results["best_state"], MODEL_DIR / "stgcn_best.pt")
    print(f"Saved best ST-GCN model: {MODEL_DIR / 'stgcn_best.pt'}")

print("\n=== Summary ===")
for method, acc, std in zip(methods, means, stds):
    print(f"  {method:15s}: {acc:.1%} +/- {std:.1%}")
print(f"\nFigures saved to: {FIGURES_DIR}")