In [1]:
!pip install pytorchvideo

from IPython import display
display.clear_output()

In [2]:
#-- Import Libraries -------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
from torchvision.models.video import mvit_v2_s
from torchvision.models import efficientnet_b2

import cv2
import numpy as np
import matplotlib.pyplot as plt
import math
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
import os
from sklearn.cluster import KMeans
import json 
from collections import defaultdict
#---------------------------------------------------------------------------------------------------------

In [3]:
#-- Initialize -----------------------------------------------------------------------------------------
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device: {DEVICE}')

NUM_CLASSES = 2
CLASS_NAMES = ['normal', 'fight']

NUM_FRAMES = 16
FRAME_W = 224
FRAME_H = 224

TARGET_LAYER_INDEX = -1
FRAME_INDICES = [1, 3, 5, 7, 9, 11, 13, 15]
NUM_ATTENDED_FRAMES = 8
#---------------------------------------------------------------------------------------------------------

device: cpu


In [4]:
#-- Mvit Model Definition ----------------------------------------------------------------------------------
class VideoMViTModel(nn.Module):
    def __init__(self, num_classes=2):
        super(VideoMViTModel, self).__init__()
        self.model = mvit_v2_s(weights="DEFAULT")
        self.model.head = nn.Linear(768, num_classes)

    def forward(self, x):
        return self.model(x)
#---------------------------------------------------------------------------------------------------------

In [5]:
#-- Load MviT Model -----------------------------------------------------------------------------------------------
weights_dir = "/kaggle/input/fight-wholedata-fold4-60epoch/best_model_fold_4.pth"

mvit_model = VideoMViTModel(num_classes=2)  
checkpoint = torch.load(weights_dir, map_location=torch.device(DEVICE))
checkpoint = {k.replace('module.', ''): v for k, v in checkpoint.items()}

mvit_model.load_state_dict(checkpoint, strict=False)
mvit_model.eval()

print(':)')
#---------------------------------------------------------------------------------------------------------

Downloading: "https://download.pytorch.org/models/mvit_v2_s-ae3be167.pth" to /root/.cache/torch/hub/checkpoints/mvit_v2_s-ae3be167.pth
100%|██████████| 132M/132M [00:00<00:00, 174MB/s]  
  checkpoint = torch.load(weights_dir, map_location=torch.device(DEVICE))


:)


In [6]:
#-- Load EfficentNet Model ----------------------------------------------------------------------------------
eff_model = efficientnet_b2(weights='IMAGENET1K_V1')  
eff_model = torch.nn.Sequential(*(list(eff_model.children())[:-2]), torch.nn.AdaptiveAvgPool2d((1, 1)))  #-- Remove top layers
eff_model = eff_model.to(DEVICE).eval()  
#---------------------------------------------------------------------------------------------------------

Downloading: "https://download.pytorch.org/models/efficientnet_b2_rwightman-c35c1473.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b2_rwightman-c35c1473.pth
100%|██████████| 35.2M/35.2M [00:00<00:00, 146MB/s] 


In [7]:
#-- Preprocessing for input frames -----------------------------------------------------------------------
transform = transforms.Compose([
    transforms.ToPILImage(), 
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
])
#---------------------------------------------------------------------------------------------------------

In [8]:
#-- Function to  Extract all frames from a video ----------------------------------------------------------
def extract_frames(video_path):    
    cap = cv2.VideoCapture(video_path)
    frames = []
    success, frame = cap.read()
    while success:
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  #-- Convert BGR to RGB
        frames.append(transform(frame))
        success, frame = cap.read()
    cap.release()
    return torch.stack(frames).to(DEVICE)
#---------------------------------------------------------------------------------------------------------

In [9]:
#-- Function to "Extract features for a list of frames ------------------------------------------------------
def extract_features(frames, batch_size=32):    
    features = []
    num_frames = frames.shape[0]  #-- Total number of frames
    for start in range(0, num_frames, batch_size):
        # Process a batch of frames
        end = min(start + batch_size, num_frames)
        batch = frames[start:end].to(DEVICE)
        with torch.no_grad():
            batch_features = eff_model(batch).squeeze(-1).squeeze(-1)  #-- Remove extra dimensions
        features.append(batch_features.cpu().numpy())  #-- Convert to NumPy and append
    
    return np.concatenate(features, axis=0)  #-- Combine features from all batches
#---------------------------------------------------------------------------------------------------------

In [10]:
#-- Function to Cluster the frames ---------------------------------------------------------------------------
def cluster_frames(features, n_clusters=16):    
    num_frames = len(features)
    print("number of frames", num_frames)
    if num_frames < 2:
        return 1, [0]  #-- If there are fewer than 2 frames, return 1 cluster
    
    #-- Ensure the number of clusters is not greater than the number of frames --
    n_clusters = min(n_clusters, num_frames)
    
    #-- Perform KMeans clustering --
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init='auto')
    labels = kmeans.fit_predict(features)
    
    return n_clusters, labels
#---------------------------------------------------------------------------------------------------------

In [11]:
#-- function to get all frames in a cluster ---------------------------------------------------------------
def get_frames_per_cluster(labels):
    cluster_dict = defaultdict(list)    
    for frame_idx, label in enumerate(labels):
        cluster_dict[label].append(frame_idx)
        
    return cluster_dict
#---------------------------------------------------------------------------------------------------------

In [12]:
#-- Function to Preprocess videos -------------------------------------------------------------------------------
def preprocess_video(video_path, num_frames=NUM_FRAMES, resize=(FRAME_W, FRAME_H)):    
    
    frames = extract_frames(video_path)
    features = extract_features(frames)

    n_clusters, labels = cluster_frames(features, n_clusters=num_frames)

    cluster_dict = get_frames_per_cluster(labels)

    representative_frames = []
    frame_indices = []

    for cluster in range(n_clusters):
        cluster_indices = np.where(labels == cluster)[0]
        if len(cluster_indices) > 0:
            representative_index = cluster_indices[0]
            representative_frames.append(frames[representative_index])
            frame_indices.append(int(representative_index))  

    # representative_frames_cpu = [frame.cpu().numpy() for frame in representative_frames]
    # video_tensor = torch.stack(representative_frames)
    # video_tensor = video_tensor.permute(1, 0, 2, 3).unsqueeze(0)  #-- [1, 3, 16, 224, 224]

    #- Now we sort the representative frames and their indices --
    sorted_indices = np.argsort(frame_indices)  #-- Get sorted indices based on frame_indices
    sorted_representative_frames = [representative_frames[i] for i in sorted_indices]
    sorted_frame_indices = [frame_indices[i] for i in sorted_indices]

    #-- Convert to tensor and adjust dimensions --
    video_tensor = torch.stack(sorted_representative_frames)
    video_tensor = video_tensor.permute(1, 0, 2, 3).unsqueeze(0)  #-- [1, 3, 16, 224, 224]

    return video_tensor, sorted_frame_indices, cluster_dict
    # return video_tensor, frame_indices
    
#-------------------------------------------------------------------------------------------------------------

In [14]:
#-- function to Predict video label ------------------------------------------------------------------------------
def predict_video(model, preprocessed_video, num_frames=NUM_FRAMES, resize=(FRAME_W, FRAME_H)):    
    
    with torch.no_grad():        
        outputs = model(preprocessed_video)
        _, predicted = torch.max(outputs, 1)      
    
    return CLASS_NAMES[predicted.item()]   
#-----------------------------------------------------------------------------------------------------------------

In [None]:
#-- link of MviT Source Code --------------------------------------------------------------------------------------
#https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py
#-----------------------------------------------------------------------------------------------------------------

In [15]:
#-- Required functions for forward from Source Code --------------------------------------------------------------
#------------------------------------------------------------------------
def _add_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    q_thw: Tuple[int, int, int],
    k_thw: Tuple[int, int, int],
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    rel_pos_t: torch.Tensor,) -> torch.Tensor:
    # Modified code from: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932
    q_t, q_h, q_w = q_thw
    k_t, k_h, k_w = k_thw
    dh = int(2 * max(q_h, k_h) - 1)
    dw = int(2 * max(q_w, k_w) - 1)
    dt = int(2 * max(q_t, k_t) - 1)

    # Scale up rel pos if shapes for q and k are different.
    q_h_ratio = max(k_h / q_h, 1.0)
    k_h_ratio = max(q_h / k_h, 1.0)
    dist_h = torch.arange(q_h)[:, None] * q_h_ratio - (torch.arange(k_h)[None, :] + (1.0 - k_h)) * k_h_ratio
    q_w_ratio = max(k_w / q_w, 1.0)
    k_w_ratio = max(q_w / k_w, 1.0)
    dist_w = torch.arange(q_w)[:, None] * q_w_ratio - (torch.arange(k_w)[None, :] + (1.0 - k_w)) * k_w_ratio
    q_t_ratio = max(k_t / q_t, 1.0)
    k_t_ratio = max(q_t / k_t, 1.0)
    dist_t = torch.arange(q_t)[:, None] * q_t_ratio - (torch.arange(k_t)[None, :] + (1.0 - k_t)) * k_t_ratio

    # Interpolate rel pos if needed.
    rel_pos_h = _interpolate(rel_pos_h, dh)
    rel_pos_w = _interpolate(rel_pos_w, dw)
    rel_pos_t = _interpolate(rel_pos_t, dt)
    Rh = rel_pos_h[dist_h.long()]
    Rw = rel_pos_w[dist_w.long()]
    Rt = rel_pos_t[dist_t.long()]

    B, n_head, _, dim = q.shape

    r_q = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim)
    rel_h_q = torch.einsum("bythwc,hkc->bythwk", r_q, Rh)  # [B, H, q_t, qh, qw, k_h]
    rel_w_q = torch.einsum("bythwc,wkc->bythwk", r_q, Rw)  # [B, H, q_t, qh, qw, k_w]
    # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
    r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(q_t, B * n_head * q_h * q_w, dim)
    # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t]
    rel_q_t = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)
    # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
    rel_q_t = rel_q_t.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)

    # Combine rel pos.
    rel_pos = (
        rel_h_q[:, :, :, :, :, None, :, None]
        + rel_w_q[:, :, :, :, :, None, None, :]
        + rel_q_t[:, :, :, :, :, :, None, None]
    ).reshape(B, n_head, q_t * q_h * q_w, k_t * k_h * k_w)

    # Add it to attention
    attn[:, :, 1:, 1:] += rel_pos

    return attn
#------------------------------------------------------------------------

#------------------------------------------------------------------------
def _interpolate(embedding: torch.Tensor, d: int) -> torch.Tensor:
    if embedding.shape[0] == d:
        return embedding

    return (
        nn.functional.interpolate(
            embedding.permute(1, 0).unsqueeze(0),
            size=d,
            mode="linear",
        )
        .squeeze(0)
        .permute(1, 0)
    )
#------------------------------------------------------------------------

#------------------------------------------------------------------------
def _add_shortcut(x: torch.Tensor, shortcut: torch.Tensor, residual_with_cls_embed: bool):
    if residual_with_cls_embed:
        x.add_(shortcut)
    else:
        x[:, :, 1:, :] += shortcut[:, :, 1:, :]
    return x
#------------------------------------------------------------------------

#-----------------------------------------------------------------------------------------------------------------

In [16]:
#--Override Forward Function to get Attention Map and Cls weights -----------------------------------------------------
def my_forward_wrapper_video(attn_obj):
    
    def my_forward(x, thw):
        B, N, C = x.shape

        q, k, v = attn_obj.qkv(x).reshape(B, N, 3, attn_obj.num_heads, attn_obj.head_dim).transpose(1, 3).unbind(dim=2)
        if attn_obj.pool_k is not None:
            k, k_thw = attn_obj.pool_k(k, thw)
        else:
            k_thw = thw
        if attn_obj.pool_v is not None:
            v = attn_obj.pool_v(v, thw)[0]
        if attn_obj.pool_q is not None:
            q, thw = attn_obj.pool_q(q, thw)        
        

        attn = torch.matmul(attn_obj.scaler * q, k.transpose(2, 3))
        if attn_obj.rel_pos_h is not None and attn_obj.rel_pos_w is not None and attn_obj.rel_pos_t is not None:
            attn = _add_rel_pos(
                attn,
                q,
                thw,
                k_thw,
                attn_obj.rel_pos_h,
                attn_obj.rel_pos_w,
                attn_obj.rel_pos_t,
            )
        attn = attn.softmax(dim=-1)

        
        
        attn_obj.attn_map = attn
        attn_obj.cls_attn_map = attn[:, :, 0, 1:]
        
        # print('attn:', attn.shape)        
        # print('attn_obj.cls_attn_map:', attn_obj.cls_attn_map.shape)        

        x = torch.matmul(attn, v)
        if attn_obj.residual_pool:
            _add_shortcut(x, q, attn_obj.residual_with_cls_embed)
        x = x.transpose(1, 2).reshape(B, -1, attn_obj.output_dim)
        x = attn_obj.project(x)

        return x, thw
    
    return my_forward
#-----------------------------------------------------------------------------------------------------------------

In [18]:
#-- Function to Plot HeatMaps on Frames ------------------------------------------------------------------------
def plot_attention_maps(img, attns, frame_idx):    
    num_heads = attns.shape[0]
    num_cols = 4  # تعداد ستون‌ها
    num_rows = math.ceil(num_heads / num_cols)  
    
    plt.figure(figsize=(15, num_rows * 4))  
    for head in range(num_heads):
        row = head // num_cols  
        col = head % num_cols   
        
        plt.subplot(num_rows, num_cols, head + 1)  
        plt.imshow(img)
        plt.imshow(attns[head].cpu().numpy().squeeze(0), cmap='jet', alpha=0.6)  
        plt.title(f"Head {head+1} - Frame {frame_idx}")
        plt.axis('off')
    
    plt.show()
#-----------------------------------------------------------------------------------------------------------------

In [20]:
#-- function to Unnormalize frames -----------------------------------------------------------------
def unnormalize(tensor):
    device = tensor.device  # گرفتن دیوایس تنسور ورودی
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(3, 1, 1)
    return tensor * std + mean
#-----------------------------------------------------------------------------------------------------------------

In [21]:
#-- Apply Custom Forward on Target Layer ----------------------------------------------------------------------------
mvit_model.model.blocks[TARGET_LAYER_INDEX].attn.forward = my_forward_wrapper_video(mvit_model.model.blocks[TARGET_LAYER_INDEX].attn)
#-----------------------------------------------------------------------------------------------------------------

In [23]:
#-- Function to get all cluster mates for a frame index ----------------------------------------------------
def get_cluster_mates_from_dict(frame_idx, cluster_dict):
    for cluster_label, frame_list in cluster_dict.items():
        if frame_idx in frame_list:
            return frame_list
    return []  
#-----------------------------------------------------------------------------------------------------------

In [25]:
#-----------------------------------------------------------------------------------------------------------
def visualize_heatmaps_on_frames(cls_weight, all_frames, frame_indices):
    patches_per_frame = cls_weight.shape[2] // NUM_ATTENDED_FRAMES  

    framewise_attn = []
    for i in range(NUM_ATTENDED_FRAMES):
        start = i * patches_per_frame
        end = (i + 1) * patches_per_frame
        frame_attn = cls_weight[0, :, start:end]  
        framewise_attn.append(frame_attn)
    
    combined_imgs = {}
    for i, frame_attn in enumerate(framewise_attn):
        for j in range(2):  
            idx = i * 2 + j  
            frame_idx = frame_indices[idx]
    
            print(f"Processing frame {frame_idx}")

            mates = get_cluster_mates_from_dict(frame_idx, cluster_dict)

            for f_idx in mates:            
                x = all_frames[f_idx]
                
                cls_resized_all_heads = []
                for head in range(frame_attn.shape[0]):
                    cls_weight_head = frame_attn[head]
                    cls_weight_2d = cls_weight_head.view(1, 7, 7).unsqueeze(0)
                    cls_resized = F.interpolate(cls_weight_2d, size=(224, 224), mode='bilinear', align_corners=False)
                    cls_resized_all_heads.append(cls_resized)
        
                cls_resized_all_heads = torch.stack(cls_resized_all_heads)
                cls_resized_normalized = cls_resized_all_heads / cls_resized_all_heads.max()
        
                mean_attn_map = cls_resized_normalized.mean(dim=0)                  
                x = unnormalize(x)
                img_resized = x.permute(1, 2, 0).cpu().numpy()
                img_resized = np.clip(img_resized, 0, 1)
                img_resized = (img_resized * 255).astype(np.uint8)        
        
                mean_attn_map_resized = mean_attn_map.squeeze().cpu().numpy()
                mean_attn_map_resized = (mean_attn_map_resized * 255 / mean_attn_map_resized.max()).astype(np.uint8)
                        
                attn_map_colored = cv2.applyColorMap(mean_attn_map_resized, cv2.COLORMAP_JET)       
                combined_image = cv2.addWeighted(img_resized, 0.6, attn_map_colored, 0.4, 0)   
        
                combined_imgs[f_idx]= combined_image

    return combined_imgs
 #-----------------------------------------------------------------------------------------------------------   

In [26]:
#-----------------------------------------------------------------------------------------------------------
def save_video_from_frames_dict(frames_dict, output_path, fps):
    
    sorted_items = sorted(frames_dict.items())
    
    first_frame = sorted_items[0][1]
    height, width, _ = first_frame.shape
    
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    for _, frame in sorted_items:
        if frame.dtype != np.uint8:
            frame = (frame * 255).astype(np.uint8)
        out.write(frame)

    out.release()
    print(f"Video saved at: {output_path}")
#-----------------------------------------------------------------------------------------------------------

In [27]:
#-----------------------------------------------------------------------------------------------------------
def get_video_fps(video_path):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video: {video_path}")
    
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return fps
#-----------------------------------------------------------------------------------------------------------

In [28]:
#-- Run ---------------------------------------------------------------------------------------------------------
# video_path = '/kaggle/input/sample-videos-for-fight-detection-2/fight/fight (2).mp4'
video_path = '/kaggle/input/sample-videos-for-fight-detection-2/normal/normal (1).mp4'
output_video_path = 'output_with_attention_map.mp4'

frame_width = 224  # Assuming the frames are 224x224
frame_height = 224
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_video_path, fourcc, 20.0, (frame_width, frame_height))

all_frames = extract_frames(video_path)
print('all_frames:', len(all_frames), all_frames[0].shape)
video_tensor, frame_indices, cluster_dict = preprocess_video(video_path)
print('video_tensor:', video_tensor.shape)
print('frame_indices:', frame_indices)

predicted_lbl = predict_video(mvit_model , video_tensor)
print('predicted_lbl:', predicted_lbl)

#-- Get the attention map and cls_weight --
attn_map = mvit_model.model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()
cls_weight = mvit_model.model.blocks[-1].attn.cls_attn_map  

print('attn_map:' , attn_map.shape)
print('cls_weight:', cls_weight.shape)



combined_imgs = visualize_heatmaps_on_frames(cls_weight, all_frames, frame_indices)
combined_imgs = dict(sorted(combined_imgs.items()))
# print(combined_imgs)
# for frame_idx , img in combined_imgs.items():
#         # رسم تصویر ترکیب‌شده
#         plt.figure(figsize=(5, 5))
#         # plt.imshow(combined_image)
#         plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
#         plt.title(f"Mean Attention - Frame {frame_idx}")
#         plt.axis('off')
#         plt.show()

fps = get_video_fps(video_path)
save_video_from_frames_dict(combined_imgs, output_video_path, fps)
#-----------------------------------------------------------------------------------------------------------

all_frames: 158 torch.Size([3, 224, 224])
number of frames 158
video_tensor: torch.Size([1, 3, 16, 224, 224])
frame_indices: [0, 17, 26, 37, 45, 57, 68, 78, 87, 96, 118, 119, 126, 135, 141, 153]
predicted_lbl: normal
attn_map: torch.Size([393, 393])
cls_weight: torch.Size([1, 8, 392])
Processing frame 0
Processing frame 17
Processing frame 26
Processing frame 37
Processing frame 45
Processing frame 57
Processing frame 68
Processing frame 78
Processing frame 87
Processing frame 96
Processing frame 118
Processing frame 119
Processing frame 126
Processing frame 135
Processing frame 141
Processing frame 153
Video saved at: output_with_attention_map.mp4
