# Sequential Two-Stage Contextual Bridge Distillation — Sports-in-the-Wild

This notebook now implements a **sequential 2-stage knowledge bridge** instead of training all three models jointly. The idea:

```
Teacher (Base)  ──►  Stage 1: Train Small  ──►  Stage 2: Small (frozen) teaches Tiny
        (global rich knowledge)              (domain-adapted bridge)         (final efficient student)
```

Why a contextual bridge?
- The direct gap Teacher(Base) → Student(Tiny) can be large (capacity + representation mismatch).
- First adapting an intermediate (Small) yields a *domain-specialized assistant*.
- Then the Tiny model learns from both: retained high-level signal (Teacher) + distilled, compressed domain signal (Assistant).

Expected outcome: more stable training and higher Tiny accuracy than single-hop (Teacher → Tiny) distillation.


## Training Flow Overview

We run **two separate training jobs** back to back inside one notebook; Stage 2 consumes the checkpoint written by Stage 1:

### Stage 1: Teacher → Assistant (Train ViT-Small)
Goal: Produce a strong, domain-adapted SMALL model that will act as a frozen assistant in Stage 2.

Configuration principles:
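- Student model = the checkpoint assigned to `pretrained_small_ckpt` in the Stage 1 config cell (it becomes the assistant later)
- The same pretrained checkpoint is also loaded as a frozen assistant with a lower logits weight than the Teacher (0.5 vs 1.0)
- Logits-only KD from the frozen mentors to the Student; feature and attention distillation are disabled to save memory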
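
The logits term used in both stages can be sketched as follows. This is a minimal plain-Python illustration for a single sample, assuming the common temperature-softened KL formulation scaled by T². The helper names `softmax_t` and `kd_loss` and the example logits are hypothetical, and the loss actually computed inside `tri_model_distillation` may differ.

```python
import math

def softmax_t(logits, temperature):
    """Softmax of logits / temperature (a higher temperature gives a flatter distribution)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, mentor_logits, temperature=4.0):
    """KL(mentor || student) on softened distributions, scaled by T^2 (assumed form)."""
    p = softmax_t(mentor_logits, temperature)
    q = softmax_t(student_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
    return kl * temperature ** 2

# Hypothetical 3-class logits for one clip
teacher = [4.0, 1.0, 0.5]
student = [2.0, 1.5, 0.5]
loss = kd_loss(student, teacher, temperature=4.0)
print(f"KD loss: {loss:.4f}")
```

The loss is zero when the student reproduces the mentor's logits exactly and grows as the softened distributions diverge.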

### Stage 2: Assistant → Student (Train ViT-Tiny)
Goal: Train the TINY model using BOTH the frozen Small (assistant) and the original Teacher.

Configuration principles:
- Assistant model path = checkpoint directory produced in Stage 1
- Teacher still provides a small stabilizing signal
- Assistant has higher logits weight (primary mentor)

### Advantages of This Design
- Reduces the representational gap each student has to cross in a single hop
- Lets the Assistant internalize domain specifics before mentoring Tiny
- Often yields higher accuracy than a direct Teacher → Tiny pipeline

Proceed through sections in order. Skip Stage 1 only if you already have a trained Small checkpoint you want to reuse.


In [1]:
# Environment / Common Imports
import os, json, torch
import torchvision, pytorchvideo, transformers
from huggingface_hub import HfFolder
from datetime import datetime
from transformers import TrainingArguments

from tri_model_distillation.config import TriModelConfig
from tri_model_distillation.models import TriModelDistillationFramework
from tri_model_distillation.trainer import TriModelDistillationTrainer, compute_video_classification_metrics
from tri_model_distillation.utils import (
    setup_logging, load_label_mappings, create_data_loaders,
)

token = os.getenv("HUGGINGFACE_TOKEN")

if token:
    HfFolder.save_token(token)
    print("Hugging Face token successfully loaded from HUGGINGFACE_TOKEN environment variable.")
else:
    print("HUGGINGFACE_TOKEN environment variable not set. If you want to push models to the Hub, please set this variable before starting Jupyter Lab.")

print(torch.__version__)
print('CUDA available:', torch.cuda.is_available())
print("torch:", torch.__version__, "cuda:", torch.version.cuda)
print("torchvision:", torchvision.__version__)
print("pytorchvideo:", pytorchvideo.__version__)
print("has functional_tensor:", hasattr(__import__('torchvision.transforms', fromlist=['']), 'functional_tensor'))



Hugging Face token successfully loaded from HUGGINGFACE_TOKEN environment variable.
2.1.0+cu118
CUDA available: True
torch: 2.1.0+cu118 cuda: 11.8
torchvision: 0.16.0+cu118
pytorchvideo: 0.1.5
has functional_tensor: True


In [2]:
# Dataset + Label Mapping
DATASET_ROOT = 'processed_dataset'
BASE_RUN_DIR = f"./contextual_bridge_runs_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

STAGE1_OUTPUT_DIR = f"{BASE_RUN_DIR}/stage1_teacher_to_small"
STAGE2_OUTPUT_DIR = f"{BASE_RUN_DIR}/stage2_small_to_tiny"
os.makedirs(STAGE1_OUTPUT_DIR, exist_ok=True)
os.makedirs(STAGE2_OUTPUT_DIR, exist_ok=True)
print(BASE_RUN_DIR)

label2id, id2label = load_label_mappings(dataset_root=DATASET_ROOT, train_csv='train.csv', classification_type='multiclass')
num_labels = len(label2id)
print(f"Detected {num_labels} classes")
print(list(label2id.keys())[:30], '...')

./contextual_bridge_runs_20250901_065642
Detected 30 classes
['archery', 'baseball', 'basketball', 'bmx', 'bowling', 'boxing', 'cheerleading', 'discusthrow', 'diving', 'football', 'golf', 'gymnastics', 'hammerthrow', 'highjump', 'hockey', 'hurdling', 'javelin', 'longjump', 'polevault', 'rowing', 'running', 'shotput', 'skating', 'skiing', 'soccer', 'swimming', 'tennis', 'volleyball', 'weight', 'wrestling'] ...


In [3]:
# Stage 1 Skip Logic — Detect existing checkpoint to optionally skip Stage 1 training
import os, glob
stage1_checkpoint_exists = any(
    os.path.exists(os.path.join(STAGE1_OUTPUT_DIR, fname))
    for fname in ['pytorch_model.bin', 'model.safetensors', 'config.json']
)
if not stage1_checkpoint_exists:
    for sub in glob.glob(os.path.join(STAGE1_OUTPUT_DIR, 'checkpoint-*')):
        if os.path.exists(os.path.join(sub, 'pytorch_model.bin')) or os.path.exists(os.path.join(sub, 'model.safetensors')):
            stage1_checkpoint_exists = True
            break

SKIP_STAGE1 = stage1_checkpoint_exists
if SKIP_STAGE1:
    print(f"Existing Stage 1 checkpoint detected in {STAGE1_OUTPUT_DIR}. Stage 1 training will be skipped.")
    print("(Delete or rename the directory to force retraining.)")
else:
    print("No existing Stage 1 checkpoint found. Stage 1 training will run.")

No existing Stage 1 checkpoint found. Stage 1 training will run.


In [4]:
# Stage 1 Configuration (Teacher + Pretrained Assistant → Train Small Student)
# Dual supervision with ONLY logits KD (features/attentions disabled to save memory)
if SKIP_STAGE1:
    print('Stage 1 skipped: loading assistant from existing checkpoint for Stage 2.')
else:
    pretrained_small_ckpt = 'mitegvg/videomae-base-finetuned-ucf101-finetuned-sports-videos-in-the-wild'
    
    stage1_config = TriModelConfig(
        classification_type='multiclass',
        num_labels=num_labels,
        teacher_model_name='mitegvg/videomae-base-finetuned-kinetics-finetuned-sports-videos-in-the-wild',
        assistant_model_name=pretrained_small_ckpt,
        student_model_name=pretrained_small_ckpt,
        temperature=4.0,
        logits_temperature=4.0,
        teacher_logits_weight=1.0,
        assistant_logits_weight=0.5,
        classification_loss_weight=1.0,
        logits_distillation_weight=0.35,
        hidden_layers_to_align=[],
        feature_distillation_weight=0.0,
        attention_distillation_weight=0.0,
        use_pretrained_student=True,
        num_frames=16,
        apply_defaults=False,  # NEW: prevent auto override adding hidden/attn needs
    )
    
    print("User logits weight:", stage1_config.logits_distillation_weight)
    # Memory safety knobs
    TOTAL_TRAIN_SAMPLES = 3364
    per_device_train_batch_size = 2  # reduced to fit GPU memory (same value on CPU)
    gradient_accumulation_steps = 16 if torch.cuda.is_available() else 8  # keep effective batch similar
    effective_batch = per_device_train_batch_size * gradient_accumulation_steps
    stage1_epochs = 12
    steps_per_epoch = TOTAL_TRAIN_SAMPLES // effective_batch
    stage1_total_steps = steps_per_epoch * stage1_epochs
    stage1_warmup = min(500, int(0.1 * stage1_total_steps))
    
    print('Stage 1 config ready (logits-only KD; features/attn disabled)')

User logits weight: 0.35
Stage 1 config ready (logits-only KD; features/attn disabled)
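
The schedule arithmetic in the cell above can be worked through by hand. The sketch below re-does it in isolation with the notebook's CUDA values and compares floor and ceiling division for the per-epoch step count; the lowercase variable names are illustrative only. The per-epoch checkpoints logged later (`checkpoint-106`, `checkpoint-212`, ...) correspond to the ceiling value.

```python
import math

total_train_samples = 3364   # matches the "Train: 3364 samples" data loader log
per_device_batch = 2
grad_accum = 16              # CUDA setting
epochs = 12

effective_batch = per_device_batch * grad_accum                  # 32
floor_steps = total_train_samples // effective_batch             # 105
ceil_steps = math.ceil(total_train_samples / effective_batch)    # 106

for name, steps in [("floor", floor_steps), ("ceil", ceil_steps)]:
    total = steps * epochs
    warmup = min(500, int(0.1 * total))
    print(f"{name}: steps/epoch={steps}, total={total}, warmup={warmup}")
```

The two variants differ by one step per epoch (warmup of 126 vs 127 steps), so the choice mostly matters for keeping the planned schedule in line with what the trainer actually runs.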


In [5]:
# Stage 1: Initialize Framework + Dataloaders
if SKIP_STAGE1:
    print('Stage 1 skipped: loading assistant from existing checkpoint for Stage 2.')
else:
    setup_logging()

    stage1_framework = TriModelDistillationFramework(
        config=stage1_config,
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label,
    )

    def _count_params(m):
        return sum(p.numel() for p in m.parameters() if p.requires_grad), sum(p.numel() for p in m.parameters() if not p.requires_grad)

    tr_s, fr_s = _count_params(stage1_framework.student_model)
    tr_t, fr_t = _count_params(stage1_framework.teacher_model)
    tr_a, fr_a = _count_params(stage1_framework.assistant_model)
    print(f"Teacher trainable {tr_t:,} frozen {fr_t:,}")
    print(f"Assistant trainable {tr_a:,} frozen {fr_a:,}")
    print(f"Student trainable {tr_s:,} frozen {fr_s:,}")
    print('Need hidden states:', getattr(stage1_framework,'_need_hidden',True), 'Need attentions:', getattr(stage1_framework,'_need_attn',True))

    train_loader, val_loader, test_loader = create_data_loaders(
        dataset_root=DATASET_ROOT,
        image_processor=stage1_framework.image_processor,
        label2id=label2id,
        batch_size=per_device_train_batch_size,
        num_frames=stage1_config.num_frames,
        num_workers=2,
    )
    print('Data loaders ready')

    # Dry-run memory probe: one forward/backward pass to measure peak GPU memory
    if torch.cuda.is_available():
        import gc
        batch = next(iter(train_loader))
        batch = {k: v.to('cuda') if hasattr(v, 'to') else v for k, v in batch.items()}
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.reset_peak_memory_stats()
        out = stage1_framework(pixel_values=batch['pixel_values'], labels=batch['labels'], output_hidden_states=False, output_attentions=False)
        loss_probe = out['student'].logits.mean()
        loss_probe.backward()
        peak = torch.cuda.max_memory_allocated() / 1024 / 1024
        print(f"Dry-run peak MB: {peak:.1f}")
        # Drop the probe's gradients so they do not hold memory before training starts
        stage1_framework.student_model.zero_grad(set_to_none=True)
        del out, batch, loss_probe
        torch.cuda.empty_cache()
        gc.collect()

2025-09-01 06:57:25,745 - tri_model_distillation.config - INFO - [TriModelConfig] Active components => logits:True features:False attn:False attn_w=(0.556,0.444) apply_defaults=False
2025-09-01 06:57:25,755 - tri_model_distillation.models - INFO - Initializing Tri-Model Distillation Framework for multiclass classification...
2025-09-01 06:57:25,757 - tri_model_distillation.models - INFO - Loading teacher model...
2025-09-01 06:57:28,823 - tri_model_distillation.models - INFO - Loading assistant model...
2025-09-01 06:57:28,824 - tri_model_distillation.models - INFO - Loading assistant model from HuggingFace: mitegvg/videomae-base-finetuned-ucf101-finetuned-sports-videos-in-the-wild
2025-09-01 06:57:31,434 - tri_model_distillation.models - INFO - Loading student model with label alignment...
2025-09-01 06:57:31,565 - tri_model_distillation.models - INFO - All label mappings match perfectly!
2025-09-01 06:57:31,566 - tri_model_distillation.models - INFO - Label mappings actually match - 

Teacher trainable 0 frozen 86,250,270
Assistant trainable 0 frozen 86,250,270
Student trainable 86,250,270 frozen 0
Need hidden states: False Need attentions: True


2025-09-01 06:57:34,357 - tri_model_distillation.utils - INFO - Loaded 420 video paths from val.csv
2025-09-01 06:57:34,377 - tri_model_distillation.utils - INFO - Loaded 422 video paths from test.csv
2025-09-01 06:57:34,378 - tri_model_distillation.utils - INFO - Created data loaders:
2025-09-01 06:57:34,379 - tri_model_distillation.utils - INFO -   Train: 3364 samples, 1682 batches
2025-09-01 06:57:34,380 - tri_model_distillation.utils - INFO -   Val: 420 samples, 210 batches
2025-09-01 06:57:34,381 - tri_model_distillation.utils - INFO -   Test: 422 samples, 211 batches


Data loaders ready
Dry-run peak MB: 11529.2


In [6]:
# Stage 1: Train Small (Assistant-to-be)
if SKIP_STAGE1:
    print('Skipping Stage 1 training; will use existing checkpoint as assistant in Stage 2.')
    stage1_trainer = None
    stage1_framework = None
else:
    # Enable gradient checkpointing for student to reduce activation memory
    if hasattr(stage1_framework.student_model, 'gradient_checkpointing_enable'):
        stage1_framework.student_model.gradient_checkpointing_enable()

    stage1_args = stage1_config.to_training_args(
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_train_batch_size,
        num_train_epochs=stage1_epochs,
        warmup_steps=stage1_warmup,
        evaluation_strategy='epoch',
        logging_strategy='epoch',
        save_strategy='epoch',
        save_total_limit=20,
        output_dir=STAGE1_OUTPUT_DIR,
        overwrite_output_dir=True,
        gradient_accumulation_steps=gradient_accumulation_steps,
        fp16=torch.cuda.is_available(),
        remove_unused_columns=False,
        dataloader_pin_memory=True,
        dataloader_num_workers=0,
        metric_for_best_model='eval_accuracy',
        greater_is_better=True,
        load_best_model_at_end=True,
        report_to=['tensorboard'],
        logging_dir=f'{STAGE1_OUTPUT_DIR}/logs',
    )

    stage1_trainer = TriModelDistillationTrainer(
        framework=stage1_framework,
        distillation_config=stage1_config,
        args=stage1_args,
        train_dataset=train_loader.dataset,
        eval_dataset=val_loader.dataset,
        compute_metrics=lambda eval_pred, **kw: compute_video_classification_metrics(eval_pred, classification_type='multiclass'),
    )

    print('Starting Stage 1 training (logits-only KD, memory optimized)...')
    stage1_train_result = stage1_trainer.train()
    print('Stage 1 training complete')

    stage1_trainer.save_model(STAGE1_OUTPUT_DIR)
    stage1_val_metrics = stage1_trainer.evaluate(eval_dataset=val_loader.dataset)
    print('Stage 1 validation:', stage1_val_metrics)

2025-09-01 06:58:00,036 - tri_model_distillation.trainer - INFO - [Distillation Weights] cls=1.0 feat=0.0 attn=0.0 logits=0.35 T=4.0 logits_T=4.0


Starting Stage 1 training (logits-only KD, memory optimized)...


2025-09-01 06:58:01,772 - tri_model_distillation.trainer - INFO - Projection layers created and ready for training


Epoch,Training Loss,Validation Loss,Classification Loss,Feature Distillation Loss,Attention Distillation Loss,Logits Distillation Loss,Accuracy,Precision Macro,Precision Micro,Precision Weighted,Recall Macro,Recall Micro,Recall Weighted,F1 Macro,F1 Micro,F1 Weighted,Roc Auc Ovr,Roc Auc Ovo,Cohen Kappa,Balanced Accuracy,Top1 Accuracy,Precision,Recall,F1 Score,Class 0 Accuracy,Class 0 F1,Class 1 Accuracy,Class 1 F1,Class 2 Accuracy,Class 2 F1,Class 3 Accuracy,Class 3 F1,Class 4 Accuracy,Class 4 F1,Class 5 Accuracy,Class 5 F1,Class 6 Accuracy,Class 6 F1,Class 7 Accuracy,Class 7 F1,Class 8 Accuracy,Class 8 F1,Class 9 Accuracy,Class 9 F1,Class 10 Accuracy,Class 10 F1,Class 11 Accuracy,Class 11 F1,Class 12 Accuracy,Class 12 F1,Class 13 Accuracy,Class 13 F1,Class 14 Accuracy,Class 14 F1,Class 15 Accuracy,Class 15 F1,Class 16 Accuracy,Class 16 F1,Class 17 Accuracy,Class 17 F1,Class 18 Accuracy,Class 18 F1,Class 19 Accuracy,Class 19 F1,Avg Class Accuracy,Avg Class F1,Top5 Accuracy,Confusion Matrix Trace,Confusion Matrix Total,Num Samples,Num Classes,Classification Type
1,0.5855,1.3014,1.101509,0.0,0.0,0.571116,0.72381,0.710188,0.72381,0.731031,0.713703,0.72381,0.72381,0.70224,0.72381,0.718375,0.975981,0.97548,0.713349,0.713703,0.72381,0.710188,0.713703,0.70224,0.666667,0.64,0.352941,0.4,0.625,0.625,0.7,0.736842,0.941176,0.914286,0.5,0.6,0.684211,0.787879,0.666667,0.615385,0.9375,0.967742,0.6,0.72,0.777778,0.666667,0.941176,0.8,1.0,0.727273,0.6,0.6,0.777778,0.777778,0.6875,0.666667,0.666667,0.571429,0.538462,0.666667,0.916667,0.785714,0.666667,0.705882,0.712343,0.69876,0.935714,304.0,420.0,420,30,multiclass
2,0.3974,1.236174,1.053852,0.0,0.0,0.520918,0.735714,0.722282,0.735714,0.739351,0.722059,0.735714,0.735714,0.71385,0.735714,0.730558,0.978154,0.977587,0.725628,0.722059,0.735714,0.722282,0.722059,0.71385,0.666667,0.727273,0.411765,0.4375,0.625,0.714286,0.75,0.810811,0.882353,0.882353,0.583333,0.666667,0.684211,0.742857,0.666667,0.666667,0.9375,0.9375,0.6,0.692308,0.555556,0.588235,0.941176,0.842105,1.0,0.64,0.7,0.7,0.777778,0.777778,0.75,0.666667,0.666667,0.6,0.538462,0.608696,0.916667,0.846154,0.777778,0.736842,0.721579,0.714235,0.945238,309.0,420.0,420,30,multiclass
3,0.2479,1.292912,1.118959,0.0,0.0,0.49701,0.72619,0.71365,0.72619,0.7315,0.711887,0.72619,0.72619,0.700308,0.72619,0.718053,0.977109,0.976466,0.71568,0.711887,0.72619,0.71365,0.711887,0.700308,0.583333,0.636364,0.411765,0.451613,0.5,0.615385,0.7,0.756757,0.941176,0.941176,0.5,0.6,0.684211,0.764706,0.666667,0.666667,0.9375,0.909091,0.533333,0.666667,0.777778,0.666667,0.882353,0.833333,1.0,0.64,0.6,0.571429,0.888889,0.842105,0.6875,0.628571,0.666667,0.571429,0.384615,0.47619,0.916667,0.814815,0.666667,0.666667,0.696456,0.685982,0.935714,305.0,420.0,420,30,multiclass
4,0.1801,1.303494,1.134222,0.0,0.0,0.483635,0.745238,0.730417,0.745238,0.750309,0.732656,0.745238,0.745238,0.720759,0.745238,0.738382,0.976878,0.976254,0.735535,0.732656,0.745238,0.730417,0.732656,0.720759,0.583333,0.666667,0.411765,0.466667,0.625,0.714286,0.85,0.85,0.941176,0.914286,0.5,0.6,0.684211,0.764706,0.666667,0.64,0.9375,0.9375,0.6,0.72,0.777778,0.7,0.941176,0.842105,1.0,0.666667,0.7,0.636364,0.888889,0.8,0.75,0.685714,0.666667,0.571429,0.461538,0.545455,0.916667,0.846154,0.777778,0.736842,0.734007,0.715242,0.940476,313.0,420.0,420,30,multiclass
5,0.151,1.298142,1.131477,0.0,0.0,0.476186,0.740476,0.725587,0.740476,0.744956,0.730322,0.740476,0.740476,0.71861,0.740476,0.734668,0.977596,0.977093,0.730619,0.730322,0.740476,0.725587,0.730322,0.71861,0.583333,0.636364,0.411765,0.482759,0.625,0.714286,0.75,0.789474,0.941176,0.941176,0.5,0.6,0.684211,0.764706,0.666667,0.666667,0.9375,0.909091,0.6,0.692308,0.666667,0.631579,0.941176,0.864865,1.0,0.695652,0.8,0.64,0.888889,0.8,0.6875,0.647059,0.666667,0.6,0.538462,0.608696,0.916667,0.846154,0.777778,0.736842,0.729173,0.713384,0.938095,311.0,420.0,420,30,multiclass
6,0.1387,1.305276,1.143352,0.0,0.0,0.462639,0.730952,0.723275,0.730952,0.741634,0.720915,0.730952,0.730952,0.710134,0.730952,0.725594,0.977403,0.976825,0.72072,0.720915,0.730952,0.723275,0.720915,0.710134,0.583333,0.666667,0.411765,0.482759,0.625,0.714286,0.75,0.769231,0.941176,0.941176,0.583333,0.666667,0.684211,0.764706,0.666667,0.64,0.9375,0.909091,0.6,0.72,0.777778,0.7,0.882353,0.810811,1.0,0.695652,0.7,0.56,0.888889,0.8,0.6875,0.611111,0.666667,0.6,0.461538,0.571429,0.916667,0.846154,0.777778,0.736842,0.727108,0.710329,0.933333,307.0,420.0,420,30,multiclass
7,0.1315,1.30741,1.148136,0.0,0.0,0.455071,0.738095,0.726614,0.738095,0.743297,0.724517,0.738095,0.738095,0.715558,0.738095,0.731474,0.977331,0.97673,0.728058,0.724517,0.738095,0.726614,0.724517,0.715558,0.583333,0.666667,0.411765,0.482759,0.625,0.714286,0.8,0.8,0.941176,0.941176,0.583333,0.666667,0.684211,0.764706,0.666667,0.64,0.9375,0.9375,0.666667,0.740741,0.666667,0.631579,0.882353,0.789474,1.0,0.727273,0.7,0.583333,0.888889,0.8,0.75,0.648649,0.666667,0.631579,0.461538,0.545455,0.916667,0.846154,0.777778,0.777778,0.73051,0.716789,0.930952,310.0,420.0,420,30,multiclass
8,0.1282,1.311103,1.151739,0.0,0.0,0.455324,0.733333,0.722533,0.733333,0.739572,0.721575,0.733333,0.733333,0.711829,0.733333,0.727112,0.977386,0.976786,0.723146,0.721575,0.733333,0.722533,0.721575,0.711829,0.583333,0.666667,0.411765,0.482759,0.625,0.714286,0.75,0.789474,0.941176,0.941176,0.583333,0.666667,0.684211,0.742857,0.666667,0.64,0.9375,0.9375,0.6,0.692308,0.666667,0.631579,0.882353,0.789474,1.0,0.727273,0.7,0.583333,0.888889,0.8,0.6875,0.611111,0.666667,0.6,0.461538,0.545455,0.916667,0.846154,0.777778,0.777778,0.721552,0.709292,0.930952,308.0,420.0,420,30,multiclass
9,0.1254,1.319571,1.161316,0.0,0.0,0.452159,0.72381,0.714004,0.72381,0.732532,0.710397,0.72381,0.72381,0.702385,0.72381,0.719397,0.977221,0.9766,0.713268,0.710397,0.72381,0.714004,0.710397,0.702385,0.583333,0.666667,0.411765,0.466667,0.625,0.714286,0.7,0.736842,0.941176,0.941176,0.583333,0.666667,0.684211,0.764706,0.666667,0.64,0.9375,0.9375,0.6,0.692308,0.555556,0.555556,0.882353,0.789474,1.0,0.695652,0.7,0.56,0.888889,0.8,0.6875,0.611111,0.666667,0.6,0.461538,0.48,0.833333,0.833333,0.777778,0.777778,0.70933,0.696486,0.930952,304.0,420.0,420,30,multiclass
10,0.1237,1.321066,1.163064,0.0,0.0,0.451435,0.72381,0.71323,0.72381,0.729963,0.710397,0.72381,0.72381,0.702682,0.72381,0.718588,0.977215,0.976574,0.713238,0.710397,0.72381,0.71323,0.710397,0.702682,0.583333,0.666667,0.411765,0.466667,0.625,0.714286,0.7,0.736842,0.941176,0.941176,0.583333,0.666667,0.684211,0.742857,0.666667,0.64,0.9375,0.9375,0.6,0.692308,0.555556,0.555556,0.882353,0.789474,1.0,0.727273,0.7,0.583333,0.888889,0.8,0.6875,0.611111,0.666667,0.6,0.461538,0.48,0.833333,0.833333,0.777778,0.777778,0.70933,0.698141,0.930952,304.0,420.0,420,30,multiclass


2025-09-01 07:23:28,275 - tri_model_distillation.models - INFO - Student model saved to ./contextual_bridge_runs_20250901_065642/stage1_teacher_to_small\checkpoint-106
2025-09-01 07:23:28,277 - tri_model_distillation.trainer - INFO - Tri-model distillation framework saved to ./contextual_bridge_runs_20250901_065642/stage1_teacher_to_small\checkpoint-106
2025-09-01 07:49:04,874 - tri_model_distillation.models - INFO - Student model saved to ./contextual_bridge_runs_20250901_065642/stage1_teacher_to_small\checkpoint-212
2025-09-01 07:49:04,875 - tri_model_distillation.trainer - INFO - Tri-model distillation framework saved to ./contextual_bridge_runs_20250901_065642/stage1_teacher_to_small\checkpoint-212
2025-09-01 08:14:13,429 - tri_model_distillation.models - INFO - Student model saved to ./contextual_bridge_runs_20250901_065642/stage1_teacher_to_small\checkpoint-318
2025-09-01 08:14:13,431 - tri_model_distillation.trainer - INFO - Tri-model distillation framework saved to ./contextual

Stage 1 training complete


2025-09-01 11:51:39,997 - tri_model_distillation.models - INFO - Student model saved to ./contextual_bridge_runs_20250901_065642/stage1_teacher_to_small
2025-09-01 11:51:39,998 - tri_model_distillation.trainer - INFO - Tri-model distillation framework saved to ./contextual_bridge_runs_20250901_065642/stage1_teacher_to_small


Stage 1 validation: {'eval_loss': 1.3041460293034712, 'eval_classification_loss': 1.1342222193446927, 'eval_feature_distillation_loss': 0.0, 'eval_attention_distillation_loss': 0.0, 'eval_logits_distillation_loss': 0.4854966187406154, 'eval_accuracy': 0.7452380952380953, 'eval_precision_macro': 0.7304171713601538, 'eval_precision_micro': 0.7452380952380953, 'eval_precision_weighted': 0.7503088751960932, 'eval_recall_macro': 0.7326563720745867, 'eval_recall_micro': 0.7452380952380953, 'eval_recall_weighted': 0.7452380952380953, 'eval_f1_macro': 0.7207590301376926, 'eval_f1_micro': 0.7452380952380953, 'eval_f1_weighted': 0.7383824491967869, 'eval_roc_auc_ovr': 0.9768781388594466, 'eval_roc_auc_ovo': 0.9762542086524343, 'eval_cohen_kappa': 0.7355350501388823, 'eval_balanced_accuracy': 0.7326563720745867, 'eval_top1_accuracy': 0.7452380952380953, 'eval_precision': 0.7304171713601538, 'eval_recall': 0.7326563720745867, 'eval_f1_score': 0.7207590301376926, 'eval_class_0_accuracy': 0.58333333
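
As a sanity check on how the logged Stage 1 losses combine, the numbers above are consistent with a weighted sum of the classification and logits terms using the configured weights (the feature and attention terms are 0.0). The snippet below only repeats that arithmetic with values copied from the validation output. The variable names are illustrative, and the combination rule is inferred from the numbers rather than taken from the library's source.

```python
classification_loss_weight = 1.0
logits_distillation_weight = 0.35

# Values copied from the Stage 1 validation output
eval_classification_loss = 1.1342222193446927
eval_logits_distillation_loss = 0.4854966187406154
eval_loss_logged = 1.3041460293034712

reconstructed = (classification_loss_weight * eval_classification_loss
                 + logits_distillation_weight * eval_logits_distillation_loss)
print(f"reconstructed={reconstructed:.6f} logged={eval_loss_logged:.6f}")
```

The reconstructed and logged totals agree to about six decimal places.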

### Stage 2 Rationale
The trained Small model now serves as a **domain-adapted bridge**. In Stage 2 we:
- Freeze the Small checkpoint (loaded via its output directory)
- Keep a light stabilizing signal from the original Base teacher (lower weight)
- Emphasize logits distillation from the Assistant (higher weight)

Tuning tips:
- If Tiny underfits early: increase `assistant_logits_weight` or `logits_temperature`
- If overfitting: reduce `classification_loss_weight` slightly or add light feature distillation
- If training unstable: raise `teacher_logits_weight` to 0.4–0.5 for extra regularization
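
To see what the last tip changes, it can help to look at each mentor's relative share of the logits signal. The sketch below is plain arithmetic on the two weights. `mentor_shares` is a hypothetical helper, and whether the library normalizes these weights internally is an assumption not confirmed here.

```python
def mentor_shares(teacher_logits_weight, assistant_logits_weight):
    """Return (teacher_share, assistant_share) as fractions of the combined weight."""
    total = teacher_logits_weight + assistant_logits_weight
    if total <= 0:
        raise ValueError("at least one mentor weight must be positive")
    return teacher_logits_weight / total, assistant_logits_weight / total

# Stage 2 setting used in this notebook: teacher 0.1, assistant 0.8
base_teacher, base_assistant = mentor_shares(0.1, 0.8)
# Tuning tip: raise the teacher weight to 0.5 for extra regularization
tuned_teacher, tuned_assistant = mentor_shares(0.5, 0.8)

print(f"teacher share: {base_teacher:.2f} -> {tuned_teacher:.2f}")
```

With the assistant weight held at 0.8, the teacher's share rises from roughly a ninth to well over a third of the combined logits weight, so the change is substantial and worth testing in smaller increments.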


In [7]:
# Stage 2 Configuration (Assistant → Tiny)
# If Stage 1 skipped, we still expect STAGE1_OUTPUT_DIR to already contain a trained small model
import math
if SKIP_STAGE1:
    assert os.path.exists(os.path.join(STAGE1_OUTPUT_DIR, 'config.json')), 'Expected existing Stage 1 checkpoint missing.'

# Full KD: logits + feature + attention
stage2_config = TriModelConfig(
    classification_type='multiclass',
    num_labels=num_labels,
    teacher_model_name='mitegvg/videomae-base-finetuned-kinetics-finetuned-sports-videos-in-the-wild',
    assistant_model_name=STAGE1_OUTPUT_DIR,
    student_model_name='mitegvg/videomae-tiny-12-finetuned-kinetics-finetuned-sports-videos-in-the-wild',
    use_pretrained_student=True,
    use_tiny_student=False,

    # --- Logits KD (uses both teacher & assistant) ---
    temperature=4.0,
    logits_temperature=4.0,
    teacher_logits_weight=0.1,
    assistant_logits_weight=0.8,
    logits_distillation_weight=0.20,

    # --- Feature & Attention KD (turn features on; keep attention moderate) ---
    feature_distillation_weight=0.12,  
    attention_distillation_weight=0.3,
    hidden_layers_to_align=[-8, -4, -1],  # mid + late
    teacher_feature_weight=0.1,
    assistant_feature_weight=0.7,
    teacher_attention_weight=0.1,
    assistant_attention_weight=0.7,

    # Head squeeze stays on (assistant often has more heads than tiny)
    attention_head_squeeze=True,
    attention_head_squeeze_mode="learned",
    head_squeeze_ortho_weight=1e-3,

    # --- Temporal KD (motion cues) ---
    temporal_delta_distillation_weight=0.05,  # ↑ from 0.00
    temporal_delta_layers=[-4, -1],

    # --- Fusion off for now (optional later once stable) ---
    enable_layer_fusion=False,

    # --- CE weight slightly reduced to let KD act ---
    classification_loss_weight=0.50, 

    # Runtime
    num_frames=16,
    require_hidden_states=True,
    require_attentions=True,

    apply_defaults=False,
)
for k in sorted(stage2_config.__dataclass_fields__.keys()):
    print(f"{k}: {getattr(stage2_config, k)!r}")
stage2_epochs = 20
TOTAL_TRAIN_SAMPLES = 3364
per_device_train_batch_size = 2
gradient_accumulation_steps = 16 if torch.cuda.is_available() else 8
effective_batch = per_device_train_batch_size * gradient_accumulation_steps  # 32 if CUDA
steps_per_epoch = math.ceil(TOTAL_TRAIN_SAMPLES / effective_batch)          # 106 optimizer steps if CUDA (not 3364 samples)
stage2_total_steps = steps_per_epoch * stage2_epochs                        # 2120 if CUDA
stage2_warmup = min(100, int(0.05 * stage2_total_steps))                    # shorter warmup (capped at 100)
print(f"Stage2: steps/epoch={steps_per_epoch}, total_steps={stage2_total_steps}, warmup={stage2_warmup}")
print('Stage 2 config ready (full KD: logits + features + attention)')

2025-09-01 12:22:53,719 - tri_model_distillation.config - INFO - [TriModelConfig] Active components => logits:True features:True attn:True attn_w=(0.100,0.700) apply_defaults=False


align_attention_maps: True
align_hidden_states: True
apply_defaults: False
assistant_attention_weight: 0.7
assistant_feature_weight: 0.7
assistant_logits_weight: 0.8
assistant_model_name: './contextual_bridge_runs_20250901_065642/stage1_teacher_to_small'
assistant_model_path: None
attention_distillation_weight: 0.3
attention_head_squeeze: True
attention_head_squeeze_mode: 'learned'
classification_loss_weight: 0.5
classification_type: 'multiclass'
dataset_root: 'processed_dataset'
enable_layer_fusion: False
eval_strategy: None
evaluation_strategy: 'epoch'
feature_distillation_weight: 0.12
fusion_assistant_weight: 0.5
fusion_projection_dim: None
fusion_source: 'teacher'
fusion_source_weighting: 'learned'
fusion_teacher_weight: 0.5
head_squeeze_ortho_weight: 0.001
hidden_layers_to_align: [-8, -4, -1]
image_size: 224
kd_conf_threshold: 0.5
kd_dynamic_lower: True
kd_min_keep_ratio: 0.25
layer_fusion_assistant_layers: []
layer_fusion_mode: 'attention'
layer_fusion_teacher_layers: []
logging_

In [8]:
# Stage 2: Initialize + Train Tiny with Frozen Assistant
# Clean up Stage 1 objects to free GPU memory before constructing Stage 2 framework

from tri_model_distillation import make_metrics_fn
if torch.cuda.is_available():
    import gc
    if not SKIP_STAGE1:
        del stage1_trainer, stage1_framework
    torch.cuda.empty_cache(); gc.collect()
    print('CUDA memory cache cleared before Stage 2 initialization.')

stage2_framework = TriModelDistillationFramework(
    config=stage2_config,
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
)

per_device_train_batch_size = 2
gradient_accumulation_steps = 16 if torch.cuda.is_available() else 8  # keep effective batch similar
effective_batch = per_device_train_batch_size * gradient_accumulation_steps
train_loader2, val_loader2, test_loader2 = create_data_loaders(
    dataset_root=DATASET_ROOT,
    image_processor=stage2_framework.image_processor,
    label2id=label2id,
    batch_size=per_device_train_batch_size,
    num_frames=stage2_config.num_frames,
    num_workers=2,
)

stage2_args = stage2_config.to_training_args(
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_train_batch_size,
    num_train_epochs=stage2_epochs,
    warmup_steps=stage2_warmup,
    learning_rate=1e-4,   
    eval_strategy='epoch',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=20,
    output_dir=STAGE2_OUTPUT_DIR,
    overwrite_output_dir=True,
    gradient_accumulation_steps=gradient_accumulation_steps,
    fp16=torch.cuda.is_available(),
    remove_unused_columns=False,
    dataloader_pin_memory=True,
    dataloader_num_workers=0,
    metric_for_best_model='eval_accuracy',
    greater_is_better=True,
    load_best_model_at_end=True,
    report_to=['tensorboard', 'wandb'],
    logging_dir=f'{STAGE2_OUTPUT_DIR}/logs',
)
metrics_fn = make_metrics_fn("multiclass")

stage2_trainer = TriModelDistillationTrainer(
    framework=stage2_framework,
    distillation_config=stage2_config,
    args=stage2_args,
    train_dataset=train_loader2.dataset,
    eval_dataset=val_loader2.dataset, 
    compute_metrics=metrics_fn,
)

print('Starting Stage 2 training...')
stage2_train_result = stage2_trainer.train()
print('Stage 2 training complete')

stage2_trainer.save_model(STAGE2_OUTPUT_DIR)
stage2_val_metrics = stage2_trainer.evaluate(eval_dataset=val_loader2.dataset)
print('Stage 2 validation:', stage2_val_metrics)

# For downstream evaluation cells
OUTPUT_DIR = STAGE2_OUTPUT_DIR
framework = stage2_framework
val_loader = val_loader2
test_loader = test_loader2
print('OUTPUT_DIR set to final student:', OUTPUT_DIR)

2025-09-01 12:23:07,475 - tri_model_distillation.config - INFO - [TriModelConfig] Active components => logits:True features:True attn:True attn_w=(0.100,0.700) apply_defaults=False
2025-09-01 12:23:07,476 - tri_model_distillation.models - INFO - Initializing Tri-Model Distillation Framework for multiclass classification...
2025-09-01 12:23:07,476 - tri_model_distillation.models - INFO - Loading teacher model...


CUDA memory cache cleared before Stage 2 initialization.


2025-09-01 12:23:10,024 - tri_model_distillation.models - INFO - Loading assistant model...
2025-09-01 12:23:10,025 - tri_model_distillation.models - INFO - Loading assistant model from HuggingFace: ./contextual_bridge_runs_20250901_065642/stage1_teacher_to_small
2025-09-01 12:23:11,650 - tri_model_distillation.models - INFO - Loading student model with label alignment...
2025-09-01 12:23:12,137 - tri_model_distillation.models - INFO - All label mappings match perfectly!
2025-09-01 12:23:12,138 - tri_model_distillation.models - INFO - Label mappings actually match - loading pretrained model directly
2025-09-01 12:23:13,147 - tri_model_distillation.utils - INFO - Loaded 3364 video paths from train.csv
2025-09-01 12:23:13,177 - tri_model_distillation.utils - INFO - Loaded 420 video paths from val.csv
2025-09-01 12:23:13,197 - tri_model_distillation.utils - INFO - Loaded 422 video paths from test.csv
2025-09-01 12:23:13,198 - tri_model_distillation.utils - INFO - Created data loaders:
202

Starting Stage 2 training...


2025-09-01 12:23:19,355 - tri_model_distillation.trainer - INFO - Projection layers created and ready for training
2025-09-01 12:23:19,358 - tri_model_distillation.trainer - INFO - Adding 3 projection parameters to optimizer
wandb: Currently logged in as: mite_gvg (mitegvg) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


Epoch,Training Loss,Validation Loss,Classification Loss,Feature Distillation Loss,Attention Distillation Loss,Logits Distillation Loss,Accuracy,Precision,Recall,F1,Num Classes,Macro Precision,Macro Recall,Macro F1,Weighted Precision,Weighted Recall,Weighted F1
1,5.7775,3.982114,1.998168,0.927477,0.278194,1.255882,0.407143,0.410689,0.397559,0.389603,30,0.410689,0.397559,0.389603,0.415928,0.407143,0.398665
2,2.975,2.838579,2.003227,0.845385,0.226984,1.245669,0.42619,0.425178,0.41083,0.40299,30,0.425178,0.41083,0.40299,0.437394,0.42619,0.41721
3,2.3322,2.485517,1.964822,0.806331,0.206436,1.195152,0.447619,0.448147,0.434412,0.422762,30,0.448147,0.434412,0.422762,0.461906,0.447619,0.438309
4,2.0358,2.277722,1.923853,0.787017,0.192784,1.153426,0.457143,0.466436,0.440411,0.432965,30,0.466436,0.440411,0.432965,0.484178,0.457143,0.450915
5,1.8361,2.158631,1.934674,0.773372,0.187151,1.110267,0.461905,0.452365,0.436109,0.429925,30,0.452365,0.436109,0.429925,0.467158,0.461905,0.44924
6,1.6772,2.058522,1.901449,0.763389,0.184924,1.105419,0.457143,0.451413,0.440632,0.434543,30,0.451413,0.440632,0.434543,0.47093,0.457143,0.453889
7,1.5492,1.986145,1.9005,0.757892,0.180599,1.068101,0.457143,0.454981,0.440467,0.435891,30,0.454981,0.440467,0.435891,0.464411,0.457143,0.448582
8,1.4464,1.946484,1.928643,0.750219,0.179111,1.053794,0.466667,0.47349,0.45361,0.453366,30,0.47349,0.45361,0.453366,0.480912,0.466667,0.463782
9,1.3483,1.93397,1.991104,0.741436,0.174957,1.055611,0.459524,0.465014,0.451968,0.443603,30,0.465014,0.451968,0.443603,0.478601,0.459524,0.455486
10,1.2678,1.893058,1.976049,0.733905,0.174435,1.054556,0.457143,0.462886,0.445439,0.440216,30,0.462886,0.445439,0.440216,0.479724,0.457143,0.456012


2025-09-01 12:43:55,318 - tri_model_distillation.models - INFO - Student model saved to ./contextual_bridge_runs_20250901_065642/stage2_small_to_tiny\checkpoint-106
2025-09-01 12:43:55,319 - tri_model_distillation.trainer - INFO - Tri-model distillation framework saved to ./contextual_bridge_runs_20250901_065642/stage2_small_to_tiny\checkpoint-106
2025-09-01 13:04:32,354 - tri_model_distillation.models - INFO - Student model saved to ./contextual_bridge_runs_20250901_065642/stage2_small_to_tiny\checkpoint-212
2025-09-01 13:04:32,356 - tri_model_distillation.trainer - INFO - Tri-model distillation framework saved to ./contextual_bridge_runs_20250901_065642/stage2_small_to_tiny\checkpoint-212
2025-09-01 13:25:19,300 - tri_model_distillation.models - INFO - Student model saved to ./contextual_bridge_runs_20250901_065642/stage2_small_to_tiny\checkpoint-318
2025-09-01 13:25:19,302 - tri_model_distillation.trainer - INFO - Tri-model distillation framework saved to ./contextual_bridge_runs_20

Stage 2 training complete


2025-09-01 19:14:57,887 - tri_model_distillation.models - INFO - Student model saved to ./contextual_bridge_runs_20250901_065642/stage2_small_to_tiny
2025-09-01 19:14:57,889 - tri_model_distillation.trainer - INFO - Tri-model distillation framework saved to ./contextual_bridge_runs_20250901_065642/stage2_small_to_tiny


Stage 2 validation: {'eval_loss': 1.7921031492097037, 'eval_classification_loss': 1.981936843090114, 'eval_feature_distillation_loss': 0.7165044177146185, 'eval_attention_distillation_loss': 0.17290812801747096, 'eval_logits_distillation_loss': 1.0398140084175838, 'eval_accuracy': 0.4738095238095238, 'eval_precision': 0.4792325782805659, 'eval_recall': 0.4645112283561716, 'eval_f1': 0.4603697425047173, 'eval_num_classes': 30, 'eval_macro_precision': 0.4792325782805659, 'eval_macro_recall': 0.4645112283561716, 'eval_macro_f1': 0.4603697425047173, 'eval_weighted_precision': 0.4872105208786986, 'eval_weighted_recall': 0.4738095238095238, 'eval_weighted_f1': 0.4695475831518839, 'epoch': 19.81807372175981}
OUTPUT_DIR set to final student: ./contextual_bridge_runs_20250901_065642/stage2_small_to_tiny
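
A note on reading the Stage 2 validation dictionary above: `eval_precision`, `eval_recall`, and `eval_f1` are identical to their `eval_macro_*` counterparts, and `eval_weighted_recall` equals `eval_accuracy`. The second equality is not specific to this run, because support-weighted recall always reduces to plain accuracy. The sketch below illustrates this on a small, made-up label set using only the standard library. The helper name `recall_summary` and the toy labels and predictions are hypothetical and are not taken from the training code or from this run.

```python
from collections import Counter

def recall_summary(y_true, y_pred):
    """Return (accuracy, macro_recall, weighted_recall) for two label lists."""
    support = Counter(y_true)  # number of samples per true class
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)  # correct predictions per class
    per_class_recall = {c: hits[c] / n for c, n in support.items()}

    total = len(y_true)
    accuracy = sum(hits.values()) / total
    # Macro: every class counts equally, regardless of how many samples it has.
    macro_recall = sum(per_class_recall.values()) / len(per_class_recall)
    # Weighted: each class's recall is weighted by its share of the samples.
    weighted_recall = sum(per_class_recall[c] * n / total for c, n in support.items())
    return accuracy, macro_recall, weighted_recall

# Hypothetical, imbalanced toy data (not from the Stage 2 run).
y_true = ["golf"] * 6 + ["bmx"] * 2 + ["diving"] * 2
y_pred = ["golf"] * 5 + ["bmx"] + ["bmx", "golf"] + ["diving", "diving"]

accuracy, macro_recall, weighted_recall = recall_summary(y_true, y_pred)
print(f"accuracy={accuracy:.3f} macro_recall={macro_recall:.3f} weighted_recall={weighted_recall:.3f}")
```

On this toy data the macro recall (about 0.778) falls below the accuracy (0.800) because the small `bmx` class is recalled poorly. The same pattern appears in the validation numbers above, where macro recall (0.4645) is slightly lower than accuracy (0.4738).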


In [9]:
# Evaluate the final student model on the full test set.

import os
import torch
import time
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
import cv2
from tqdm import tqdm

print("Starting evaluation on the full test set...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
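# Load the model directly (not via the pipeline API)
local_model_directory = OUTPUT_DIR
print(OUTPUT_DIR)
student_model = VideoMAEForVideoClassification.from_pretrained(local_model_directory)
processor = VideoMAEImageProcessor.from_pretrained(local_model_directory)
student_model.to(device)
student_model.eval()

# Take the label maps from the checkpoint so this cell does not depend on
# id2label / label2id globals defined in earlier cells. This assumes the saved
# config carries the dataset's label names (the Stage 2 log reports matching mappings).
id2label = student_model.config.id2label
label2id = student_model.config.label2id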

def process_video_for_inference(video_path, processor, num_frames=16):
    """Process video exactly like the training pipeline"""
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return None
        
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count < num_frames:
            cap.release()
            return None
            
        # Sample frames uniformly (same as training)
        frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
        frames = []
        
        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame)
        
        cap.release()
        
        if len(frames) != num_frames:
            return None
            
        # Process frames using the same processor
        inputs = processor(frames, return_tensors="pt")
        return inputs
        
    except Exception as e:
        print(f"Error processing video {video_path}: {e}")
        return None

# Load test data from CSV
def load_test_data_from_csv(csv_file_path, data_root_path):
    test_samples = []
    if not os.path.exists(csv_file_path):
        print(f"ERROR: Test CSV file not found at {csv_file_path}")
        return test_samples

    with open(csv_file_path, "r") as f:
        # Despite the .csv extension, each line is parsed as whitespace-separated:
        # "<relative_video_path> <label>"
        for line_num, line in enumerate(f, 1):
            parts = line.strip().split()
            if len(parts) >= 2:
                relative_video_path = parts[0]
                true_label_str = parts[1]
                full_video_path = os.path.normpath(os.path.join(data_root_path, relative_video_path))
                test_samples.append((full_video_path, true_label_str))
            elif line.strip():
                print(f"Warning: Malformed line {line_num} in {csv_file_path}: '{line.strip()}'")
                
    print(f"Loaded {len(test_samples)} samples from {csv_file_path}")
    return test_samples

# Load test data
dataset_root_path = "processed_dataset"
test_csv_path = os.path.join(dataset_root_path, "test.csv")
test_data = load_test_data_from_csv(test_csv_path, dataset_root_path)

if test_data:
    total_videos_processed = 0
    videos_skipped = 0
    true_labels = []
    predicted_labels = []
    top1_correct_predictions = 0
    top5_correct_predictions = 0
    inference_times = []

    print(f"\nStarting inference on {len(test_data)} test videos...")
    
    with torch.no_grad():
        for i, (video_path, true_label) in enumerate(tqdm(test_data, desc="Processing videos")):
            if not os.path.exists(video_path):
                videos_skipped += 1
                continue

            # Process video using the same pipeline as training
            inputs = process_video_for_inference(video_path, processor)
            if inputs is None:
                videos_skipped += 1
                continue

            try:
                start_time = time.time()
                
                # Move inputs to device
                inputs = {k: v.to(device) for k, v in inputs.items()}
                
                # Get model predictions
                outputs = student_model(**inputs)
                logits = outputs.logits[0]  # Remove batch dimension
                probs = torch.softmax(logits, dim=0)
                
                # Get top-5 predictions
                top5_probs, top5_indices = torch.topk(probs, 5)
                
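                # CUDA kernels run asynchronously; wait for them to finish so the
                # measured time reflects the actual forward pass.
                if device.type == "cuda":
                    torch.cuda.synchronize()
                end_time = time.time()
                # Note: this timer covers only the device transfer, forward pass and
                # top-k; video decoding and preprocessing are excluded.
                inference_times.append(end_time - start_time)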
                total_videos_processed += 1

                # Convert indices to labels using id2label
                predicted_labels_top5 = [id2label[idx.item()] for idx in top5_indices]
                predicted_label_top1 = predicted_labels_top5[0]

                # Store for metrics calculation
                predicted_labels.append(predicted_label_top1)
                true_labels.append(true_label)

                # Calculate top-k accuracy
                if predicted_label_top1 == true_label:
                    top1_correct_predictions += 1
                if true_label in predicted_labels_top5:
                    top5_correct_predictions += 1

            except Exception as e:
                print(f"Error during inference for {video_path}: {e}")
                videos_skipped += 1

    # Print results
    if total_videos_processed > 0:
        top1_accuracy = (top1_correct_predictions / total_videos_processed) * 100
        top5_accuracy = (top5_correct_predictions / total_videos_processed) * 100
        avg_inference_time = sum(inference_times) / len(inference_times)
        fps = 1.0 / avg_inference_time if avg_inference_time > 0 else float('inf')

        print("\n--- Evaluation Complete ---")
        print(f"Total videos in test set: {len(test_data)}")
        print(f"Videos successfully processed: {total_videos_processed}")
        print(f"Videos skipped (missing/corrupt): {videos_skipped}")
        print(f"Top-1 Correct Predictions: {top1_correct_predictions}")
        print(f"Top-5 Correct Predictions: {top5_correct_predictions}")
        print(f"Top-1 Accuracy: {top1_accuracy:.2f}%")
        print(f"Top-5 Accuracy: {top5_accuracy:.2f}%")
        print(f"Average inference time per video: {avg_inference_time:.3f} seconds ({fps:.2f} videos/sec)")

        # Classification report
        if len(predicted_labels) == len(true_labels) and len(true_labels) > 0:
            print("\nDetailed Classification Report:")
            print(classification_report(true_labels, predicted_labels, labels=list(label2id.keys()), zero_division=0))
    else:
        print("No videos were processed successfully.")
else:
    print("No test data loaded.")

Starting evaluation on the full test set...
./contextual_bridge_runs_20250830_111437/stage2_small_to_tiny
Loaded 422 samples from processed_dataset\test.csv

Starting inference on 422 test videos...


Processing videos: 100%|█████████████████████████████████████████| 422/422 [01:08<00:00,  6.19it/s]



--- Evaluation Complete ---
Total videos in test set: 422
Videos successfully processed: 422
Videos skipped (missing/corrupt): 0
Top-1 Correct Predictions: 208
Top-5 Correct Predictions: 347
Top-1 Accuracy: 49.29%
Top-5 Accuracy: 82.23%
Average inference time per video: 0.014 seconds (69.98 videos/sec)

Detailed Classification Report:
              precision    recall  f1-score   support

     archery       0.62      0.62      0.62        13
    baseball       0.57      0.67      0.62        18
  basketball       0.42      0.42      0.42        12
         bmx       0.43      0.27      0.33        11
     bowling       0.38      0.50      0.43        10
      boxing       0.25      0.09      0.13        11
cheerleading       0.44      0.21      0.29        19
 discusthrow       0.25      0.50      0.33         4
      diving       0.50      0.73      0.59        11
    football       0.61      0.70      0.65        20
        golf       0.50      0.55      0.52        11
  gymnastics 