# Simple Teacher-to-Student Distillation with VideoMAE
This notebook demonstrates a minimal workflow for distilling a VideoMAE teacher model into a student model. Knowledge is transferred through **feature (hidden-state) alignment only**. Logit alignment is not used, because the teacher and student have different class spaces. The student is additionally trained with a standard cross-entropy loss on its own labels. The workflow covers model loading, a preprocessing function, a single example training step, and saving the model.

In [43]:
# Install required libraries
!pip install transformers torch

import os
from huggingface_hub import HfFolder

# Read the Hugging Face token from an environment variable rather than hard-coding
# it in the notebook. Set it before starting Jupyter:
#   export HUGGINGFACE_TOKEN=your_token_here   (Linux/Mac)
#   set HUGGINGFACE_TOKEN=your_token_here      (Windows)
# HF_TOKEN, the variable huggingface_hub reads natively, is accepted as a fallback.
# Alternatively, run `huggingface-cli login` once in a terminal.
token_env_var = "HUGGINGFACE_TOKEN" if os.getenv("HUGGINGFACE_TOKEN", "").strip() else "HF_TOKEN"
# A whitespace-only value is treated the same as an unset variable.
token = os.getenv(token_env_var, "").strip()

if token:
    # Writes the token to huggingface_hub's local token file so later Hub calls can use it.
    HfFolder.save_token(token)
    print(f"Hugging Face token successfully loaded from {token_env_var} environment variable.")
else:
    print("HUGGINGFACE_TOKEN environment variable not set. If you want to push models to the Hub, please set this variable before starting Jupyter Lab.")

Hugging Face token successfully loaded from HUGGINGFACE_TOKEN environment variable.


In [44]:
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
import torch
import torch.nn as nn

# Load teacher model (pretrained, frozen). Only its hidden states are used for
# distillation. Its classification head is not part of this checkpoint and is
# randomly initialized, so it is never used.
teacher_model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base").eval()
teacher_model.requires_grad_(False)

# Load student model (same architecture, initialized from a fine-tuned checkpoint).
# No masking setup is needed: VideoMAEForVideoClassification feeds all patches to
# the backbone, and `mask_ratio` is not a VideoMAEConfig option. Setting
# `config.mask_ratio` would only add an unused attribute to the config.
student_model = VideoMAEForVideoClassification.from_pretrained("mitegvg/videomae-base-finetuned-xd-violence")

# Load processor for preprocessing videos
image_processor = VideoMAEImageProcessor.from_pretrained("MCG-NJU/videomae-base")

Some weights of VideoMAEForVideoClassification were not initialized from the model checkpoint at MCG-NJU/videomae-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [45]:
# Define losses: feature (hidden state) alignment for distillation, plus
# cross-entropy for the student's own classification task.
mse_loss = nn.MSELoss()
loss_fn = nn.CrossEntropyLoss()


def distillation_loss(student_outputs, teacher_outputs, layer: int = -1):
    """
    Compute the feature-distillation loss as the MSE between student and teacher hidden states.

    Logit alignment is not used because teacher and student have different class spaces.

    Args:
        student_outputs: student model output, produced with ``output_hidden_states=True``.
        teacher_outputs: teacher model output, produced with ``output_hidden_states=True``.
        layer: index into ``hidden_states`` to align. The default, -1, is the last layer.

    Returns:
        Scalar tensor with the mean squared error between the selected hidden states.

    Raises:
        ValueError: if either output has no hidden states, or the selected hidden
            states have different shapes.
    """
    if student_outputs.hidden_states is None or teacher_outputs.hidden_states is None:
        raise ValueError("Both forward passes must be run with output_hidden_states=True.")

    student_hidden = student_outputs.hidden_states[layer]
    # The teacher is a fixed target. Detach it so no gradient can flow into it,
    # even if the caller forgot torch.no_grad().
    teacher_hidden = teacher_outputs.hidden_states[layer].detach()

    # nn.MSELoss broadcasts mismatched shapes with only a warning, which would
    # silently produce a meaningless loss. Fail loudly instead.
    if student_hidden.shape != teacher_hidden.shape:
        raise ValueError(
            f"Hidden-state shapes differ: student {tuple(student_hidden.shape)} "
            f"vs teacher {tuple(teacher_hidden.shape)}"
        )
    return mse_loss(student_hidden, teacher_hidden)

In [46]:
import os
import numpy as np
import cv2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training index: one sample per line, "<relative video path> <label>", split on
# whitespace (despite the .csv extension). Only the main class before the first
# "-" is used as the label (e.g. 'A', 'B1', ...).
TRAIN_INDEX_PATH = os.path.join("processed_dataset", "train.csv")

with open(TRAIN_INDEX_PATH, "r") as f:
    # Read the index once; skip blank or malformed rows that lack a label column.
    train_rows = [parts for parts in (line.split() for line in f) if len(parts) > 1]

if not train_rows:
    raise ValueError(f"No labelled samples found in {TRAIN_INDEX_PATH}")

# Main classes in order of first appearance in the index
dataset_labels = list(dict.fromkeys(parts[1].split("-")[0] for parts in train_rows))

# The classification loss indexes the student's classifier head, so label ids must
# follow the student's own label mapping when its config names these classes.
# Otherwise fall back to order of first appearance.
# NOTE(review): the fallback assumes appearance order matches the head's class
# order; confirm this for the chosen checkpoint.
head_label2id = student_model.config.label2id
if set(dataset_labels) <= set(head_label2id):
    label2id = {label: int(head_label2id[label]) for label in dataset_labels}
else:
    label2id = {label: idx for idx, label in enumerate(dataset_labels)}
id2label = {idx: label for label, idx in label2id.items()}

# CrossEntropyLoss fails (with a cryptic device-side assert on CUDA) on
# out-of-range targets, so check against the head size up front.
num_head_labels = student_model.config.num_labels
max_label_id = max(label2id.values())
if max_label_id >= num_head_labels:
    raise ValueError(
        f"Label id {max_label_id} is out of range for the student's classifier head "
        f"({num_head_labels} outputs); dataset has {len(label2id)} classes."
    )

# Pick the first labelled video and its label from the training set
first_line = train_rows[0]
video_rel_path = first_line[0]
label_str = first_line[1].split("-")[0]  # Only use the main class
label_idx = label2id[label_str]

video_path = os.path.join("processed_dataset", video_rel_path)

def read_video_frames(video_path, num_frames=16, size=(224, 224)):
    """
    Read ``num_frames`` evenly spaced RGB frames from a video file.

    Frames are sampled with a fixed stride of ``total_frames // num_frames``
    (at least 1). If the clip is shorter than requested, or decoding stops early,
    the last decoded frame is repeated so that exactly ``num_frames`` frames are
    returned. VideoMAE is configured for a fixed clip length, so a short list
    would not match what the model expects.

    Args:
        video_path: path to a video file readable by OpenCV.
        num_frames: number of frames to return.
        size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        List of ``num_frames`` float32 RGB frames scaled to [0, 1].

    Raises:
        IOError: if the video cannot be opened or no frame can be decoded.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Could not open video: {video_path}")

    frames = []
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        step = max(total_frames // num_frames, 1)
        for i in range(num_frames):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
            ok, frame = cap.read()
            if not ok:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
            frame = cv2.resize(frame, size)
            frames.append(frame.astype(np.float32) / 255.0)  # normalize to [0, 1]
    finally:
        # Release the capture even if decoding raises.
        cap.release()

    if not frames:
        raise IOError(f"No frames could be decoded from video: {video_path}")

    # Pad short clips by repeating the last decoded frame.
    while len(frames) < num_frames:
        frames.append(frames[-1].copy())
    return frames

# Sample as many frames as the models are configured for. Teacher and student
# receive the same pixel tensor, so their clip lengths must agree.
num_frames = student_model.config.num_frames
assert teacher_model.config.num_frames == num_frames, "teacher and student expect different clip lengths"

video_frames = read_video_frames(video_path, num_frames=num_frames)
assert len(video_frames) == num_frames, (
    f"expected {num_frames} frames from {video_path}, got {len(video_frames)}"
)
videos = [video_frames]  # batch of 1 video
labels = torch.tensor([label_idx], dtype=torch.long)  # class index for CrossEntropyLoss

# Move models and labels to device
teacher_model = teacher_model.to(device)
student_model = student_model.to(device)
labels = labels.to(device)

In [47]:
# Preprocessing helper that wraps the VideoMAE image processor loaded earlier.


def processor(videos, return_tensors="pt", **kwargs):
    """
    Preprocess a batch of videos for VideoMAE.

    Args:
        videos: list of videos, each a list of frames (numpy arrays or tensors).
        return_tensors: tensor type of the returned batch ("pt" for PyTorch).
        **kwargs: extra options forwarded to the image processor.
            ``do_rescale`` defaults to False because the frames read in this
            notebook are already scaled to [0, 1]. Pass ``do_rescale=True``
            for raw 0-255 frames.

    Returns:
        The image processor's batch output, which is passed to the models as ``**inputs``.
    """
    kwargs.setdefault("do_rescale", False)
    return image_processor(videos, return_tensors=return_tensors, **kwargs)

In [48]:
# Example training step (single batch)

# `from_pretrained` returns models in eval mode. Put the student in training mode
# so train-time layers (e.g. dropout) behave as intended. The frozen teacher
# stays in eval mode.
student_model.train()
teacher_model.eval()

# Create the optimizer before the step. In a real training loop, create it once
# outside the loop so Adam's running moment estimates persist across steps.
optimizer = torch.optim.Adam(student_model.parameters(), lr=5e-5)

# Weight of the feature-alignment (MSE) term relative to the classification loss.
# The raw hidden-state MSE can be far larger than the cross-entropy term, so tune
# this value. A weight of 1.0 gives the plain, unweighted sum of both losses.
kd_weight = 1.0

inputs = processor(videos, return_tensors="pt").to(device)

# Forward pass for teacher (no grad): it only provides target features
with torch.no_grad():
    teacher_outputs = teacher_model(**inputs, output_hidden_states=True)

# Forward pass for student. Hidden states are needed for the distillation loss.
student_outputs = student_model(**inputs, output_hidden_states=True)

# Supervised loss on the student's own labels
classification_loss = loss_fn(student_outputs.logits, labels)

# Feature-alignment loss against the teacher's hidden states
kd_loss = distillation_loss(student_outputs, teacher_outputs)

# Combine losses
loss = classification_loss + kd_weight * kd_loss

# Backpropagation and optimizer step
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"Distillation step complete. Total loss: {loss.item():.4f}")
print(f"  classification loss: {classification_loss.item():.4f}, feature MSE: {kd_loss.item():.4f}")

Distillation step complete. Total loss: 1569.2889


In [49]:
# Save the distilled student model (weights in safetensors format) together with
# the image processor, so the directory is self-contained and can be reloaded
# with `from_pretrained` for inference.
OUTPUT_DIR = "distilled_videomae_student"
student_model.save_pretrained(OUTPUT_DIR, safe_serialization=True)
image_processor.save_pretrained(OUTPUT_DIR)
print(f"Student model saved to '{OUTPUT_DIR}'.")

Student model saved to 'distilled_videomae_student'.
