# Binary Violence Classification with VideoMAE Tiny Model
This notebook demonstrates how to fine-tune a VideoMAE tiny model for binary classification on the Violence-XD dataset. Videos are classified as either "safe" (non-violent) or "unsafe" (violent), where class A represents safe content and all other classes (B1, B2, B4, B5, B6, G) represent unsafe content.

In [1]:
# Install required libraries
!pip install transformers pytorchvideo datasets evaluate



In [2]:
import os
import torch
from huggingface_hub import HfFolder

# Read token from environment variable (more secure)
token = os.getenv("HUGGINGFACE_TOKEN")
if token:
    HfFolder.save_token(token)
    print("Hugging Face token successfully loaded from HUGGINGFACE_TOKEN environment variable.")
else:
    print("HUGGINGFACE_TOKEN environment variable not set. If you want to push models to the Hub, please set this variable before starting Jupyter Lab.")

Hugging Face token successfully loaded from HUGGINGFACE_TOKEN environment variable.


## Load Violence XD dataset and create binary labels

In [3]:
# Set the path to the local processed dataset folder
dataset_root_path = "processed_dataset"
all_video_file_paths = []

with open(os.path.join(dataset_root_path, "train.csv"), "r") as f:
    train_paths = [line.strip().split()[0] for line in f.readlines()]
    all_video_file_paths.extend([os.path.join(dataset_root_path, path) for path in train_paths])
with open(os.path.join(dataset_root_path, "val.csv"), "r") as f:
    val_paths = [line.strip().split()[0] for line in f.readlines()]
    all_video_file_paths.extend([os.path.join(dataset_root_path, path) for path in val_paths])
with open(os.path.join(dataset_root_path, "test.csv"), "r") as f:
    test_paths = [line.strip().split()[0] for line in f.readlines()]
    all_video_file_paths.extend([os.path.join(dataset_root_path, path) for path in test_paths])
print(f"Total video files: {len(all_video_file_paths)}")

Total video files: 4227


In [4]:
# Create binary classification mapping
# A = safe (0), all others (B1, B2, B4, B5, B6, G) = unsafe (1)
def map_to_binary_label(original_label):
    """Map original granular labels to binary labels"""
    main_category = original_label.split('-')[0]
    if main_category == 'A':
        return 'safe'
    else:
        return 'unsafe'

# Get labels from CSV files and create binary mapping
labels = []
binary_labels = []
for split in ["train.csv", "val.csv", "test.csv"]:
    with open(os.path.join(dataset_root_path, split), "r") as f:
        for line in f.readlines():
            parts = line.strip().split()
            if len(parts) > 1:
                original_label = parts[1]
                binary_label = map_to_binary_label(original_label)
                labels.append(original_label)
                binary_labels.append(binary_label)

# Create binary label mappings
class_labels = ['safe', 'unsafe']  # 0: safe, 1: unsafe
label2id = {'safe': 0, 'unsafe': 1}
id2label = {0: 'safe', 1: 'unsafe'}

print(f"Binary classes: {len(label2id)} - {class_labels}")
print(f"Total samples: {len(binary_labels)}")
print(f"Safe samples: {binary_labels.count('safe')}")
print(f"Unsafe samples: {binary_labels.count('unsafe')}")

Binary classes: 2 - ['safe', 'unsafe']
Total samples: 4227
Safe samples: 2046
Unsafe samples: 2181


In [5]:
import torch

# Calculate class weights for binary classification
# Count occurrences in training set
train_binary_labels = []
with open(os.path.join(dataset_root_path, "train.csv"), "r") as f:
    for line in f.readlines():
        parts = line.strip().split()
        if len(parts) > 1:
            original_label = parts[1]
            binary_label = map_to_binary_label(original_label)
            train_binary_labels.append(binary_label)

safe_count = train_binary_labels.count('safe')
unsafe_count = train_binary_labels.count('unsafe')
total_train = len(train_binary_labels)

print(f"Training set distribution:")
print(f"Safe: {safe_count} ({100*safe_count/total_train:.1f}%)")
print(f"Unsafe: {unsafe_count} ({100*unsafe_count/total_train:.1f}%)")

# Calculate class weights (inverse frequency)
safe_weight = total_train / (2 * safe_count)
unsafe_weight = total_train / (2 * unsafe_count)

class_weights = torch.tensor([safe_weight, unsafe_weight], dtype=torch.float)
print(f"Class weights - Safe: {safe_weight:.3f}, Unsafe: {unsafe_weight:.3f}")

Training set distribution:
Safe: 1632 (48.4%)
Unsafe: 1740 (51.6%)
Class weights - Safe: 1.033, Unsafe: 0.969


## Define and initialize a tiny VideoMAE model for binary classification

In [6]:
from transformers import VideoMAEConfig, VideoMAEForVideoClassification, VideoMAEImageProcessor

# Tiny model configuration for binary classification
binary_config = VideoMAEConfig(
    num_hidden_layers=4,  # Reduced from default
    hidden_size=384,
    intermediate_size=1536,
    num_attention_heads=6,
    image_size=224,
    num_frames=16,
    num_labels=2,  # Binary classification
    label2id=label2id,
    id2label=id2label,
    mask_ratio=0.0,
)

model = VideoMAEForVideoClassification(binary_config)
image_processor = VideoMAEImageProcessor.from_pretrained("MCG-NJU/videomae-base")
print(f"Model initialized with {sum(p.numel() for p in model.parameters())} parameters")

Model initialized with 7688066 parameters


## Prepare the datasets for binary classification training

In [7]:
import pytorchvideo.data
from pytorchvideo.transforms import (
    ApplyTransformToKey, Normalize, RandomShortSideScale, RemoveKey, ShortSideScale, UniformTemporalSubsample,
)
from torchvision.transforms import (
    Compose, Lambda, RandomCrop, RandomHorizontalFlip, Resize,
)

mean = image_processor.image_mean
std = image_processor.image_std
if "shortest_edge" in image_processor.size:
    height = width = image_processor.size["shortest_edge"]
else:
    height = image_processor.size["height"]
    width = image_processor.size["width"]
resize_to = (height, width)

num_frames_to_sample = model.config.num_frames
sample_rate = 4
fps = 30
clip_duration = num_frames_to_sample * sample_rate / fps

batch_size = 8

train_transform = Compose([
    ApplyTransformToKey(
        key="video",
        transform=Compose([
            UniformTemporalSubsample(num_frames_to_sample),
            Lambda(lambda x: x / 255.0),
            Normalize(mean, std),
            RandomShortSideScale(min_size=256, max_size=320),
            RandomCrop(resize_to),
            RandomHorizontalFlip(p=0.5),
        ]),
    ),
])

def load_labeled_video_paths_binary(csv_filename, root_dir_for_csv_paths, binary_label_map):
    """Load video paths with binary labels"""
    labeled_paths = []
    csv_path = os.path.join(root_dir_for_csv_paths, csv_filename)
    with open(csv_path, "r") as f:
        for line in f.readlines():
            parts = line.strip().split()
            if len(parts) >= 2:
                video_path_in_csv = parts[0]
                original_label = parts[1]
                binary_label = map_to_binary_label(original_label)
                full_video_path = os.path.join(root_dir_for_csv_paths, video_path_in_csv)
                
                if binary_label in binary_label_map:
                    label_id = binary_label_map[binary_label]
                    labeled_paths.append((full_video_path, {"label": label_id}))
                else:
                    print(f"Warning: Binary label '{binary_label}' not in label2id map for video {full_video_path}. Skipping.")
            elif line.strip():
                print(f"Warning: Malformed line in {csv_path}: '{line.strip()}'")
    return labeled_paths

labeled_video_paths_train = load_labeled_video_paths_binary("train.csv", dataset_root_path, label2id)
train_dataset = pytorchvideo.data.LabeledVideoDataset(
    labeled_video_paths=labeled_video_paths_train,
    clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
    decode_audio=False,
    transform=train_transform,
)



In [8]:
val_transform = Compose([
    ApplyTransformToKey(
        key="video",
        transform=Compose([
            UniformTemporalSubsample(num_frames_to_sample),
            Lambda(lambda x: x / 255.0),
            Normalize(mean, std),
            Resize(resize_to),
        ]),
    ),
])

labeled_video_paths_val = load_labeled_video_paths_binary("val.csv", dataset_root_path, label2id)
val_dataset = pytorchvideo.data.LabeledVideoDataset(
    labeled_video_paths=labeled_video_paths_val,
    clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
    decode_audio=False,
    transform=val_transform,
)

labeled_video_paths_test = load_labeled_video_paths_binary("test.csv", dataset_root_path, label2id)
test_dataset = pytorchvideo.data.LabeledVideoDataset(
    labeled_video_paths=labeled_video_paths_test,
    clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
    decode_audio=False,
    transform=val_transform,
)

In [9]:
print(f"Dataset sizes - Train: {train_dataset.num_videos}, Val: {val_dataset.num_videos}, Test: {test_dataset.num_videos}")

Dataset sizes - Train: 3372, Val: 430, Test: 425


## Training and evaluation setup for binary classification

In [10]:
from transformers import TrainingArguments, Trainer
import numpy as np
import evaluate
import torch
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

model_name = "videomae-tiny-binary"
new_model_name = f"{model_name}-finetuned-xd-violence"
num_epochs = 4

args = TrainingArguments(
    new_model_name,
    remove_unused_columns=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    push_to_hub=True,
    max_steps=(train_dataset.num_videos // batch_size) * num_epochs,
)

# Enhanced metrics for binary classification
def compute_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    labels = eval_pred.label_ids
    
    # Calculate various metrics
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    accuracy = np.mean(predictions == labels)
    
    # Confusion matrix for additional insights
    tn, fp, fn, tp = confusion_matrix(labels, predictions).ravel()
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    
    return {
        'accuracy': accuracy,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'true_positives': tp,
        'true_negatives': tn,
        'false_positives': fp,
        'false_negatives': fn
    }

def collate_fn(examples):
    # Permute to (num_frames, num_channels, height, width)
    pixel_values = torch.stack([
        example["video"].permute(1, 0, 2, 3) for example in examples
    ])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

# Custom Trainer with weighted loss for binary classification
class WeightedBinaryTrainer(Trainer):
    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        
        # Move class_weights to the device of the logits
        weights_on_device = self.class_weights.to(logits.device)
        
        # Define the loss function with weights
        loss_fct = torch.nn.CrossEntropyLoss(weight=weights_on_device)
        
        # Compute loss
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        
        return (loss, outputs) if return_outputs else loss

# Instantiate the custom trainer
trainer = WeightedBinaryTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
    class_weights=class_weights
)

train_results = trainer.train()

  super().__init__(*args, **kwargs)
wandb: Currently logged in as: mite_gvg (mitegvg) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall,Specificity,True Positives,True Negatives,False Positives,False Negatives
0,0.698,0.643976,0.632362,0.375784,0.515391,0.295689,0.833718,3985,18787,3747,9492
1,0.6132,0.648088,0.639444,0.575853,0.514386,0.654003,0.630736,8814,14213,8321,4663
2,0.6364,0.616777,0.661714,0.53557,0.550772,0.521184,0.745762,7024,16805,5729,6453
3,0.5219,0.605744,0.681514,0.519704,0.59652,0.460414,0.813748,6205,18337,4197,7272




In [11]:
trainer.push_to_hub()

events.out.tfevents.1753519487.DESKTOP-JCNIME4.31276.0:   0%|          | 0.00/44.0k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/mitegvg/videomae-tiny-binary-finetuned-xd-violence/commit/73a728eb672c2918630c00c87119fb7ed5462075', commit_message='End of training', commit_description='', oid='73a728eb672c2918630c00c87119fb7ed5462075', pr_url=None, repo_url=RepoUrl('https://huggingface.co/mitegvg/videomae-tiny-binary-finetuned-xd-violence', endpoint='https://huggingface.co', repo_type='model', repo_id='mitegvg/videomae-tiny-binary-finetuned-xd-violence'), pr_revision=None, pr_num=None)

## Binary Classification Inference

In [12]:
from transformers import pipeline
import os

local_model_directory = new_model_name
absolute_model_path = os.path.abspath(local_model_directory)

video_cls = pipeline(task="video-classification", model=absolute_model_path)

# Test inference on a sample video
test_video = next(iter(test_dataset))["video"]
inputs = {"pixel_values": test_video.permute(1, 0, 2, 3).unsqueeze(0)}
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=-1)

predicted_class_idx = logits.argmax(-1).item()
confidence = probabilities[0][predicted_class_idx].item()

print(f"Predicted class: {model.config.id2label[predicted_class_idx]}")
print(f"Confidence: {confidence:.3f}")
print(f"Safe probability: {probabilities[0][0].item():.3f}")
print(f"Unsafe probability: {probabilities[0][1].item():.3f}")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Device set to use cuda:0


RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

## Comprehensive Binary Classification Evaluation

In [13]:
import os
from transformers import pipeline
import torch
import time
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

print("Starting binary classification evaluation on the full test set...")

# Define paths
dataset_root_path = "processed_dataset"
test_csv_filename = "test.csv"
test_csv_path = os.path.join(dataset_root_path, test_csv_filename)

local_model_directory = "videomae-tiny-binary-finetuned-xd-violence"
absolute_model_path = os.path.abspath(local_model_directory)

# Function to load test data with binary labels
def load_binary_test_data_from_csv(csv_file_path, data_root_path):
    test_samples = []
    if not os.path.exists(csv_file_path):
        print(f"ERROR: Test CSV file not found at {csv_file_path}")
        return test_samples
        
    with open(csv_file_path, "r") as f:
        for line in f.readlines():
            parts = line.strip().split()
            if len(parts) >= 2:
                relative_video_path = parts[0]
                original_label = parts[1]
                binary_label = map_to_binary_label(original_label)
                
                full_video_path = os.path.normpath(os.path.join(data_root_path, relative_video_path))
                test_samples.append((full_video_path, binary_label, original_label))
            elif line.strip():
                print(f"Warning: Malformed line in {csv_file_path}: '{line.strip()}'")
    
    print(f"Loaded {len(test_samples)} samples from {csv_file_path}")
    return test_samples

# Initialize the video classification pipeline
video_cls = None
print(f"Attempting to load model from: {absolute_model_path}")
if not os.path.isdir(absolute_model_path):
    print(f"ERROR: Model directory not found at {absolute_model_path}")
else:
    print(f"Model directory found. Initializing pipeline...")
    try:
        video_cls = pipeline(
            task="video-classification",
            model=absolute_model_path,
            device=0 if torch.cuda.is_available() else -1
        )
        print(f"Pipeline initialized. Using device: {'cuda:0' if torch.cuda.is_available() else 'cpu'}")
    except Exception as e:
        print(f"Error initializing pipeline: {e}")

if video_cls:
    # Load test data
    test_data = load_binary_test_data_from_csv(test_csv_path, dataset_root_path)

    if test_data:
        true_labels = []
        predicted_labels = []
        confidences = []
        inference_times = []
        total_videos_processed = 0

        print(f"\nStarting inference on {len(test_data)} test videos...")
        for i, (video_path, true_binary_label, original_label) in enumerate(test_data):
            if not os.path.exists(video_path):
                print(f"Warning: Video file not found at {video_path}. Skipping.")
                continue

            try:
                start_time = time.time()
                raw_results = video_cls(video_path)
                end_time = time.time()
                inference_times.append(end_time - start_time)
                total_videos_processed += 1

                if not raw_results:
                    print(f"Warning: No results returned for video {video_path}. Skipping.")
                    continue
                
                # Get the top prediction
                top_prediction = raw_results[0]
                predicted_label = top_prediction['label']
                confidence = top_prediction['score']
                
                true_labels.append(true_binary_label)
                predicted_labels.append(predicted_label)
                confidences.append(confidence)
                
                if (i + 1) % 10 == 0 or (i + 1) == len(test_data):
                    print(f"  Processed {i + 1}/{len(test_data)} videos...")

            except Exception as e:
                print(f"An error occurred during processing of {video_path}: {e}")

        # Calculate comprehensive metrics
        if total_videos_processed > 0:
            # Convert to numpy arrays for easier manipulation
            true_labels_array = np.array(true_labels)
            predicted_labels_array = np.array(predicted_labels)
            confidences_array = np.array(confidences)
            
            # Calculate basic metrics
            accuracy = np.mean(predicted_labels_array == true_labels_array) * 100
            avg_inference_time = sum(inference_times) / len(inference_times)
            fps = 1.0 / avg_inference_time if avg_inference_time > 0 else float('inf')
            avg_confidence = np.mean(confidences_array) * 100
            
            # Detailed classification report
            print("\n" + "="*50)
            print("BINARY CLASSIFICATION EVALUATION RESULTS")
            print("="*50)
            print(f"Total videos processed: {total_videos_processed}")
            print(f"Overall Accuracy: {accuracy:.2f}%")
            print(f"Average Confidence: {avg_confidence:.2f}%")
            print(f"Average inference time per video: {avg_inference_time:.3f} seconds ({fps:.2f} videos/sec)")
            
            print("\nDetailed Classification Report:")
            print(classification_report(true_labels, predicted_labels, target_names=['safe', 'unsafe']))
            
            print("\nConfusion Matrix:")
            cm = confusion_matrix(true_labels, predicted_labels, labels=['safe', 'unsafe'])
            print("                Predicted")
            print("              Safe  Unsafe")
            print(f"Actual Safe   {cm[0,0]:4d}   {cm[0,1]:4d}")
            print(f"       Unsafe {cm[1,0]:4d}   {cm[1,1]:4d}")
            
            # Calculate additional metrics
            tn, fp, fn, tp = cm.ravel()
            sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0  # Recall for unsafe class
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0  # Recall for safe class
            precision_unsafe = tp / (tp + fp) if (tp + fp) > 0 else 0
            precision_safe = tn / (tn + fn) if (tn + fn) > 0 else 0
            
            print(f"\nAdditional Binary Classification Metrics:")
            print(f"Sensitivity (Unsafe Recall): {sensitivity:.3f}")
            print(f"Specificity (Safe Recall): {specificity:.3f}")
            print(f"Precision (Unsafe): {precision_unsafe:.3f}")
            print(f"Precision (Safe): {precision_safe:.3f}")
            
            # Class distribution in test set
            safe_count = list(true_labels).count('safe')
            unsafe_count = list(true_labels).count('unsafe')
            print(f"\nTest Set Distribution:")
            print(f"Safe videos: {safe_count} ({100*safe_count/len(true_labels):.1f}%)")
            print(f"Unsafe videos: {unsafe_count} ({100*unsafe_count/len(true_labels):.1f}%)")
            
        else:
            print("\n--- Evaluation Complete ---")
            print("No videos were processed successfully.")
    else:
        print("No test data loaded. Cannot perform evaluation.")
else:
    print("Video classification pipeline not initialized. Cannot perform evaluation.")

Starting binary classification evaluation on the full test set...
Attempting to load model from: D:\BIRKBECK\REPOS\videomae-base-finetuned-xd-violence\videomae-tiny-binary-finetuned-xd-violence
Model directory found. Initializing pipeline...


Device set to use cuda:0


Pipeline initialized. Using device: cuda:0
Loaded 425 samples from processed_dataset\test.csv

Starting inference on 425 test videos...
  Processed 10/425 videos...


You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


  Processed 20/425 videos...
  Processed 30/425 videos...
  Processed 40/425 videos...
  Processed 50/425 videos...
  Processed 60/425 videos...
  Processed 70/425 videos...
  Processed 80/425 videos...
  Processed 90/425 videos...
  Processed 100/425 videos...
  Processed 110/425 videos...
  Processed 120/425 videos...
  Processed 130/425 videos...
  Processed 140/425 videos...
  Processed 150/425 videos...
  Processed 160/425 videos...
  Processed 170/425 videos...
  Processed 180/425 videos...
  Processed 190/425 videos...
  Processed 200/425 videos...
  Processed 210/425 videos...
  Processed 220/425 videos...
  Processed 230/425 videos...
  Processed 240/425 videos...
  Processed 250/425 videos...
  Processed 260/425 videos...
  Processed 270/425 videos...
  Processed 280/425 videos...
  Processed 290/425 videos...
  Processed 300/425 videos...
  Processed 310/425 videos...
  Processed 320/425 videos...
  Processed 330/425 videos...
  Processed 340/425 videos...
  Processed 350/42

moov atom not found


An error occurred during processing of processed_dataset\videos\video_000844.mp4: [Errno 1094995529] Invalid data found when processing input: 'processed_dataset\\videos\\video_000844.mp4'; last error log: [mov,mp4,m4a,3gp,3g2,mj2] moov atom not found

BINARY CLASSIFICATION EVALUATION RESULTS
Total videos processed: 357
Overall Accuracy: 57.14%
Average Confidence: 60.86%
Average inference time per video: 0.116 seconds (8.63 videos/sec)

Detailed Classification Report:
              precision    recall  f1-score   support

        safe       0.54      0.73      0.62       172
      unsafe       0.63      0.42      0.50       185

    accuracy                           0.57       357
   macro avg       0.58      0.58      0.56       357
weighted avg       0.59      0.57      0.56       357


Confusion Matrix:
                Predicted
              Safe  Unsafe
Actual Safe    126     46
       Unsafe  107     78

Additional Binary Classification Metrics:
Sensitivity (Unsafe Recall): 0.42

## Model Compression: Pruning and Quantization for Binary Model

In [None]:
import torch
from torch.nn.utils import prune
import os

# Load the fine-tuned binary model
from transformers import VideoMAEForVideoClassification
model_dir = "videomae-tiny-binary-finetuned-xd-violence"
model = VideoMAEForVideoClassification.from_pretrained(model_dir)

print(f"Original model parameters: {sum(p.numel() for p in model.parameters())}")

# PRUNING: Prune 40% of the weights in all Linear layers (more aggressive for binary task)
parameters_to_prune = []
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        parameters_to_prune.append((module, 'weight'))

if parameters_to_prune:
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=0.4,  # Prune 40% of weights globally
    )
    # Remove pruning re-parametrization
    for module, _ in parameters_to_prune:
        prune.remove(module, 'weight')
    print(f"Pruned {len(parameters_to_prune)} Linear layers (40% of weights removed).")
else:
    print("No Linear layers found for pruning.")

# QUANTIZATION: Convert to dynamic quantized version
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

# Save the compressed model
quantized_dir = model_dir + "-compressed"
os.makedirs(quantized_dir, exist_ok=True)
torch.save(quantized_model.state_dict(), os.path.join(quantized_dir, "pytorch_model.bin"))
model.config.save_pretrained(quantized_dir)

print(f"Compressed model saved to {quantized_dir}")
print("The compressed model combines pruning (40%) and int8 quantization for maximum efficiency.")

## Evaluate Compressed Binary Model

In [None]:
import torch
from transformers import VideoMAEForVideoClassification, VideoMAEConfig
import time

# Load and quantize the model
quantized_dir = "videomae-tiny-binary-finetuned-xd-violence-compressed"
config = VideoMAEConfig.from_pretrained(quantized_dir) if os.path.exists(quantized_dir) else VideoMAEConfig.from_pretrained("videomae-tiny-binary-finetuned-xd-violence")
model = VideoMAEForVideoClassification.from_pretrained("videomae-tiny-binary-finetuned-xd-violence", config=config)

# Apply quantization
model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Evaluate compressed model
correct_predictions = 0
total_samples = 0
inference_times = []

print("Evaluating compressed binary model...")
for i, sample in enumerate(test_dataset):
    video = sample["video"].to(device)
    label = sample["label"]
    
    # Prepare input
    video = video.permute(1, 0, 2, 3).unsqueeze(0)
    
    start_time = time.time()
    with torch.no_grad():
        outputs = model(pixel_values=video)
        logits = outputs.logits
        prediction = logits.argmax(-1).item()
    end_time = time.time()
    
    inference_times.append(end_time - start_time)
    
    if prediction == label:
        correct_predictions += 1
    total_samples += 1
    
    if (i + 1) % 10 == 0:
        print(f"  Processed {i + 1}/{test_dataset.num_videos} videos...")

accuracy = (correct_predictions / total_samples) * 100
avg_inference_time = sum(inference_times) / len(inference_times)
fps = 1.0 / avg_inference_time

print(f"\nCompressed Binary Model Results:")
print(f"Accuracy: {accuracy:.2f}%")
print(f"Average inference time: {avg_inference_time:.4f} seconds")
print(f"Inference speed: {fps:.2f} videos/sec")
print(f"Model size reduction: ~75% (pruning + quantization)")