[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hammad-ali1/vad_clip_colab_notebook/blob/main/vad_clip_notebook.ipynb)


In [1]:
from IPython.display import clear_output

!pip install ftfy
!pip install comet_ml
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0
!pip install causal-conv1d==1.4.0 && pip install mamba-ssm==2.2.2
!pip install fvcore

Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ftfy
Successfully installed ftfy-6.3.1
Collecting comet_ml
  Downloading comet_ml-3.51.0-py3-none-any.whl.metadata (4.1 kB)
Collecting dulwich!=0.20.33,>=0.20.6 (from comet_ml)
  Downloading dulwich-0.24.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (5.2 kB)
Collecting everett<3.2.0,>=1.0.1 (from everett[ini]<3.2.0,>=1.0.1->comet_ml)
  Downloading everett-3.1.0-py2.py3-none-any.whl.metadata (17 kB)
Collecting python-box<7.0.0 (from comet_ml)
  Downloading python_box-6.1.0-py3-none-any.whl.metadata (7.8 kB)
Collecting configobj (from everett[ini]<3.2.0,>=1.0.1->comet_ml)
  Downloading configobj-5.0.9

In [19]:
#@title Pull Latest Repo
from google.colab import userdata
from pathlib import Path

try:
  gh_token = userdata.get('vad_clip_gh_token')
except Exception:
  gh_token = None

REPO_NAME = "vad_clip_colab_notebook"
REPO_PATH = Path("/content") / REPO_NAME

if gh_token:
  GITHUB_URL = f"https://{gh_token}@github.com/hammad-ali1/{REPO_NAME}.git"
else:
  GITHUB_URL = f'https://github.com/hammad-ali1/{REPO_NAME}'

!rm -rf {REPO_PATH}
!git clone {GITHUB_URL}

!cp -r {REPO_PATH}/* /content

Cloning into 'vad_clip_colab_notebook'...
remote: Enumerating objects: 177, done.[K
remote: Counting objects: 100% (177/177), done.[K
remote: Compressing objects: 100% (114/114), done.[K
remote: Total 177 (delta 94), reused 123 (delta 61), pack-reused 0 (from 0)[K
Receiving objects: 100% (177/177), 1.74 MiB | 13.09 MiB/s, done.
Resolving deltas: 100% (94/94), done.


In [3]:
#@title Import Modules
import os
from pathlib import Path
from collections import OrderedDict
import os
from pathlib import Path
import random
from enum import Enum

import comet_ml
import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from sklearn.metrics import average_precision_score, roc_auc_score

from clip import clip
from models import CLIP_VMamba_S
from utils.dataset import UCFDataset, XDDataset
from utils.layers import GraphConvolution, DistanceAdj
from utils.tools import get_batch_mask, get_prompt_text, get_batch_label
from utils.ucf_detectionMAP import getDetectionMAP as dmAP

  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd


In [4]:
class Datasets(Enum):
  XD = 'XD'
  UCF = 'UCF'

CURRENT_DATASET = Datasets.XD

In [7]:
def download_dataset(gdown_id, file_path, extract_to):
    !gdown --id {gdown_id} -O {file_path}
    !unzip -q {file_path} -d {extract_to}
    !rm {file_path}

In [None]:
if CURRENT_DATASET == Datasets.UCF:
  download_dataset('1DOkXHObwEGqmU3c9-J1gCIyVXiTMFfC5', '/content/UCFClipFeatures.zip', '/content')
  TRAIN_DIR = Path('/content/UCFClipFeatures')
  TEST_DIR = Path('/content/UCFClipFeatures')
elif CURRENT_DATASET == Datasets.XD:
  download_dataset('1bg2cyrsAaXjXM2MZv3j9x4gdqwl8Ntok', '/content/XDTrainClipFeatures.zip', '/content')
  download_dataset('1fhZZWSfWRl_9xQXByFOx5PM3VUwjIfp5', '/content/XDTestClipFeatures.zip', '/content')
  TRAIN_DIR = Path('/content/XDTrainClipFeatures')
  TEST_DIR = Path('/content/XDTestClipFeatures')

In [10]:
class UCFArgs:
    seed = 234

    embed_dim = 512
    visual_length = 256
    visual_width = 512
    visual_head = 1
    visual_layers = 2
    attn_window = 8
    prompt_prefix = 10
    prompt_postfix = 10
    classes_num = 14

    max_epoch = 10
    model_path = "model_cur.pth"
    use_checkpoint = False
    checkpoint_experiment_key = ''
    checkpoint_file = "best_checkpoint.pth"
    batch_size = 64
    train_list = 'list/ucf_CLIP_rgb.csv'
    test_list = 'list/ucf_CLIP_rgbtest.csv'
    gt_path = 'list/gt_ucf.npy'
    gt_segment_path = 'list/gt_segment_ucf.npy'
    gt_label_path = 'list/gt_label_ucf.npy'

    lr = 2e-5
    scheduler_rate = 0.1
    scheduler_milestones = [4, 8]



class XDArgs:
    seed = 234

    embed_dim = 512
    visual_length = 256
    visual_width = 512
    visual_head = 1
    visual_layers = 1
    attn_window = 64
    prompt_prefix = 10
    prompt_postfix = 10
    classes_num = 7

    max_epoch = 10
    model_path = "model_cur.pth"
    use_checkpoint = False
    checkpoint_experiment_key = ''
    checkpoint_file = "best_checkpoint.pth"
    batch_size = 96
    train_list = 'list/xd_CLIP_rgb.csv'
    test_list = 'list/xd_CLIP_rgbtest.csv'
    gt_path = 'list/gt.npy'
    gt_segment_path = 'list/gt_segment.npy'
    gt_label_path = 'list/gt_label.npy'

    lr = 1e-5
    scheduler_rate = 0.1
    scheduler_milestones = [3, 6, 10]


In [11]:
Args = XDArgs if CURRENT_DATASET == Datasets.XD else UCFArgs

In [12]:
COMET_API_KEY = userdata.get('COMET_API_KEY')
COMET_PROJECT_NAME = "vad_clip_notebook"
COMET_WORKSPACE = 'hammad-ali'

In [13]:
def load_clip_model():
  !curl -L -o VMamba_S_clip.pt "https://huggingface.co/weiquan/mamba-clip/resolve/main/VMamba_S_clip.pt?download=true"

  mamba_clip_model = CLIP_VMamba_S()

  ckpt_path = "VMamba_S_clip.pt"

  ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
  state_dict = OrderedDict()
  for k, v in ckpt['state_dict'].items():
      state_dict[k.replace('module.', '')] = v

  mamba_clip_model.load_state_dict(state_dict, strict=True)

  return mamba_clip_model

In [14]:
def download_commet_asset(file_name):
  comet_api = comet_ml.API(api_key=COMET_API_KEY)
  checkpoint_experiment = comet_api.get_experiment(COMET_WORKSPACE, COMET_PROJECT_NAME, Args.checkpoint_experiment_key)
  asset_link = [asset for asset in checkpoint_experiment.get_asset_list()
                   if asset['fileName'] == file_name][0]['s3Link']
  !curl -o {file_name} "{asset_link}"

In [15]:
def update_lists(file_path, old_root, dataset_dir, limit_list=None):
    df = pd.read_csv(file_path)

    # Extract video ID (removing part suffix)
    df['video_id'] = df['path'].apply(lambda x: x.rsplit('__', 1)[0])

    # Keep only complete videos (with all 10 parts)
    grouped = df.groupby('video_id')
    complete_videos = grouped.filter(lambda x: len(x) == 10)

    if limit_list is not None:
        total_videos = limit_list // 10

        # Create video-level metadata
        video_meta = (
            complete_videos.groupby('video_id')
            .first()
            .reset_index()[['video_id', 'label']]
        )

        # Separate Normal and Other labels
        normal_videos = video_meta[video_meta['label'] == 'Normal']
        other_videos = video_meta[video_meta['label'] != 'Normal']
        other_labels = other_videos['label'].unique()

        num_normal_videos = total_videos // 2
        num_other_videos = total_videos - num_normal_videos
        per_other_label = max(1, num_other_videos // len(other_labels))

        # Sample Normal videos
        selected_videos = normal_videos.sample(
            n=min(num_normal_videos, len(normal_videos)), random_state=42
        )

        # Sample from other labels
        for label in other_labels:
            vids = other_videos[other_videos['label'] == label]
            sampled = vids.sample(n=min(per_other_label, len(vids)), random_state=42)
            selected_videos = pd.concat([selected_videos, sampled], ignore_index=True)

        # Get all 10 parts of selected videos
        final_df = complete_videos[complete_videos['video_id'].isin(selected_videos['video_id'])]

        # Sort by video and clip number
        final_df['part'] = final_df['path'].apply(lambda x: int(Path(x).stem.split('__')[-1]))
        final_df = final_df.sort_values(by=['label', 'video_id', 'part']).drop(columns='part')
    else:
        final_df = df

    # Fix paths relative to DATASET_DIR
    final_df['path'] = final_df['path'].apply(lambda p: str(dataset_dir / Path(p).relative_to(old_root)))

    # Save
    final_df.drop(columns='video_id').to_csv(file_path, index=False)


In [20]:
if CURRENT_DATASET == Datasets.UCF:
  old_train_root = Path('/content/drive/MyDrive/UCFClipFeatures')
  old_test_root = Path('/content/drive/MyDrive/UCFClipFeatures')
elif CURRENT_DATASET == Datasets.XD:
  old_train_root = Path('/home/xbgydx/Desktop/XDTrainClipFeatures')
  old_test_root = Path('/home/xbgydx/Desktop/XDTestClipFeatures')

update_lists(Args.test_list, old_test_root, TEST_DIR)
update_lists(Args.train_list, old_train_root, TRAIN_DIR, limit_list=None)

In [21]:
mamba_clip_model = load_clip_model()

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1319  100  1319    0     0   6800      0 --:--:-- --:--:-- --:--:--  6834
100 1728M  100 1728M    0     0  68.5M      0  0:00:25  0:00:25 --:--:-- 77.3M
=> merge config from /content/vmamba/configs/vssm1/vssm_small_224.yaml


  @torch.cuda.amp.custom_fwd
  @torch.cuda.amp.custom_bwd
  @torch.cuda.amp.custom_fwd
  @torch.cuda.amp.custom_bwd
  @torch.cuda.amp.custom_fwd
  @torch.cuda.amp.custom_bwd


In [22]:
#@title Model
class LayerNorm(nn.LayerNorm):

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor, padding_mask: torch.Tensor):
        padding_mask = padding_mask.to(dtype=bool, device=x.device) if padding_mask is not None else None
        self.attn_mask = self.attn_mask.to(device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, key_padding_mask=padding_mask, attn_mask=self.attn_mask)[0]

    def forward(self, x):
        x, padding_mask = x
        x = x + self.attention(self.ln_1(x), padding_mask)
        x = x + self.mlp(self.ln_2(x))
        return (x, padding_mask)


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class CLIPVAD(nn.Module):
    def __init__(self,
                 num_class: int,
                 embed_dim: int,
                 visual_length: int,
                 visual_width: int,
                 visual_head: int,
                 visual_layers: int,
                 attn_window: int,
                 prompt_prefix: int,
                 prompt_postfix: int,
                 device):
        super().__init__()

        self.num_class = num_class
        self.visual_length = visual_length
        self.visual_width = visual_width
        self.embed_dim = embed_dim
        self.attn_window = attn_window
        self.prompt_prefix = prompt_prefix
        self.prompt_postfix = prompt_postfix
        self.device = device

        self.temporal = Transformer(
            width=visual_width,
            layers=visual_layers,
            heads=visual_head,
            attn_mask=self.build_attention_mask(self.attn_window)
        )

        width = int(visual_width / 2)
        self.gc1 = GraphConvolution(visual_width, width, residual=True)
        self.gc2 = GraphConvolution(width, width, residual=True)
        self.gc3 = GraphConvolution(visual_width, width, residual=True)
        self.gc4 = GraphConvolution(width, width, residual=True)
        self.disAdj = DistanceAdj()
        self.linear = nn.Linear(visual_width, visual_width)
        self.gelu = QuickGELU()

        self.mlp1 = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(visual_width, visual_width * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(visual_width * 4, visual_width))
        ]))
        self.mlp2 = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(visual_width, visual_width * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(visual_width * 4, visual_width))
        ]))
        self.classifier = nn.Linear(visual_width, 1)

        self.clipmodel = mamba_clip_model
        for clip_param in self.clipmodel.parameters():
            clip_param.requires_grad = False

        self.frame_position_embeddings = nn.Embedding(visual_length, visual_width)
        self.text_prompt_embeddings = nn.Embedding(77, self.embed_dim)

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.text_prompt_embeddings.weight, std=0.01)
        nn.init.normal_(self.frame_position_embeddings.weight, std=0.01)

    def build_attention_mask(self, attn_window):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.visual_length, self.visual_length)
        mask.fill_(float('-inf'))
        for i in range(int(self.visual_length / attn_window)):
            if (i + 1) * attn_window < self.visual_length:
                mask[i * attn_window: (i + 1) * attn_window, i * attn_window: (i + 1) * attn_window] = 0
            else:
                mask[i * attn_window: self.visual_length, i * attn_window: self.visual_length] = 0

        return mask

    def adj4(self, x, seq_len):
        soft = nn.Softmax(1)
        x2 = x.matmul(x.permute(0, 2, 1)) # B*T*T
        x_norm = torch.norm(x, p=2, dim=2, keepdim=True)  # B*T*1
        x_norm_x = x_norm.matmul(x_norm.permute(0, 2, 1))
        x2 = x2/(x_norm_x+1e-20)
        output = torch.zeros_like(x2)
        if seq_len is None:
            for i in range(x.shape[0]):
                tmp = x2[i]
                adj2 = tmp
                adj2 = F.threshold(adj2, 0.7, 0)
                adj2 = soft(adj2)
                output[i] = adj2
        else:
            for i in range(len(seq_len)):
                tmp = x2[i, :seq_len[i], :seq_len[i]]
                adj2 = tmp
                adj2 = F.threshold(adj2, 0.7, 0)
                adj2 = soft(adj2)
                output[i, :seq_len[i], :seq_len[i]] = adj2

        return output

    def encode_video(self, images, padding_mask, lengths):
        images = images.to(torch.float)
        position_ids = torch.arange(self.visual_length, device=self.device)
        position_ids = position_ids.unsqueeze(0).expand(images.shape[0], -1)
        frame_position_embeddings = self.frame_position_embeddings(position_ids)
        frame_position_embeddings = frame_position_embeddings.permute(1, 0, 2)
        images = images.permute(1, 0, 2) + frame_position_embeddings

        x, _ = self.temporal((images, None))
        x = x.permute(1, 0, 2)

        adj = self.adj4(x, lengths)
        disadj = self.disAdj(x.shape[0], x.shape[1])
        x1_h = self.gelu(self.gc1(x, adj))
        x2_h = self.gelu(self.gc3(x, disadj))

        x1 = self.gelu(self.gc2(x1_h, adj))
        x2 = self.gelu(self.gc4(x2_h, disadj))

        x = torch.cat((x1, x2), 2)
        x = self.linear(x)

        return x

    def encode_textprompt(self, text):
        word_tokens = clip.tokenize(text).to(self.device)
        word_embedding = self.clipmodel.token_embedding(word_tokens)
        text_embeddings = self.text_prompt_embeddings(torch.arange(77).to(self.device)).unsqueeze(0).repeat([len(text), 1, 1])
        text_tokens = torch.zeros(len(text), 77).to(self.device)

        for i in range(len(text)):
            ind = torch.argmax(word_tokens[i], -1)
            text_embeddings[i, 0] = word_embedding[i, 0]
            text_embeddings[i, self.prompt_prefix + 1: self.prompt_prefix + ind] = word_embedding[i, 1: ind]
            text_embeddings[i, self.prompt_prefix + ind + self.prompt_postfix] = word_embedding[i, ind]
            text_tokens[i, self.prompt_prefix + ind + self.prompt_postfix] = word_tokens[i, ind]

        text_features = self.clipmodel.encode_text(text_embeddings, text_tokens)

        return text_features

    def forward(self, visual, padding_mask, text, lengths):
        visual_features = self.encode_video(visual, padding_mask, lengths)
        logits1 = self.classifier(visual_features + self.mlp2(visual_features))

        text_features_ori = self.encode_textprompt(text)

        text_features = text_features_ori
        logits_attn = logits1.permute(0, 2, 1)
        visual_attn = logits_attn @ visual_features
        visual_attn = visual_attn / visual_attn.norm(dim=-1, keepdim=True)
        visual_attn = visual_attn.expand(visual_attn.shape[0], text_features_ori.shape[0], visual_attn.shape[2])
        text_features = text_features_ori.unsqueeze(0)
        text_features = text_features.expand(visual_attn.shape[0], text_features.shape[1], text_features.shape[2])
        text_features = text_features + visual_attn
        text_features = text_features + self.mlp1(text_features)

        visual_features_norm = visual_features / visual_features.norm(dim=-1, keepdim=True)
        text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
        text_features_norm = text_features_norm.permute(0, 2, 1)
        logits2 = visual_features_norm @ text_features_norm.type(visual_features_norm.dtype) / 0.07

        return text_features_ori, logits1, logits2, visual_features


In [28]:
#@title Test Function
def test(model, testdataloader, maxlen, prompt_text, gt, gtsegments, gtlabels, device):

    model.to(device)
    model.eval()

    element_logits2_stack = []

    with torch.no_grad():
        for i, item in enumerate(testdataloader):
            visual = item[0].squeeze(0)
            length = item[2]

            length = int(length)
            len_cur = length
            if len_cur < maxlen:
                visual = visual.unsqueeze(0)

            visual = visual.to(device)

            lengths = torch.zeros(int(length / maxlen) + 1)
            for j in range(int(length / maxlen) + 1):
                if j == 0 and length < maxlen:
                    lengths[j] = length
                elif j == 0 and length > maxlen:
                    lengths[j] = maxlen
                    length -= maxlen
                elif length > maxlen:
                    lengths[j] = maxlen
                    length -= maxlen
                else:
                    lengths[j] = length
            lengths = lengths.to(int)
            padding_mask = get_batch_mask(lengths, maxlen).to(device)
            _, logits1, logits2, _ = model(visual, padding_mask, prompt_text, lengths)
            logits1 = logits1.reshape(logits1.shape[0] * logits1.shape[1], logits1.shape[2])
            logits2 = logits2.reshape(logits2.shape[0] * logits2.shape[1], logits2.shape[2])
            prob2 = (1 - logits2[0:len_cur].softmax(dim=-1)[:, 0].squeeze(-1))
            prob1 = torch.sigmoid(logits1[0:len_cur].squeeze(-1))

            if i == 0:
                ap1 = prob1
                ap2 = prob2
                #ap3 = prob3
            else:
                ap1 = torch.cat([ap1, prob1], dim=0)
                ap2 = torch.cat([ap2, prob2], dim=0)

            element_logits2 = logits2[0:len_cur].softmax(dim=-1).detach().cpu().numpy()
            element_logits2 = np.repeat(element_logits2, 16, 0)
            element_logits2_stack.append(element_logits2)

    ap1 = ap1.cpu().numpy()
    ap2 = ap2.cpu().numpy()
    ap1 = ap1.tolist()
    ap2 = ap2.tolist()

    ROC1 = roc_auc_score(gt, np.repeat(ap1, 16))
    AP1 = average_precision_score(gt, np.repeat(ap1, 16))
    ROC2 = roc_auc_score(gt, np.repeat(ap2, 16))
    AP2 = average_precision_score(gt, np.repeat(ap2, 16))

    print("AUC1: ", ROC1, " AP1: ", AP1)
    print("AUC2: ", ROC2, " AP2:", AP2)

    if CURRENT_DATASET == Datasets.XD:
      return ROC1, AP2, 0

    dmap, iou = dmAP(element_logits2_stack, gtsegments, gtlabels, excludeNormal=False)
    averageMAP = 0
    for i in range(5):
        print('mAP@{0:.1f} ={1:.2f}%'.format(iou[i], dmap[i]))
        averageMAP += dmap[i]
    averageMAP = averageMAP/(i+1)
    print('average MAP: {:.2f}'.format(averageMAP))

    return ROC1, AP1, averageMAP


def run_test():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    args = Args()

    label_map = dict({'Normal': 'Normal', 'Abuse': 'Abuse', 'Arrest': 'Arrest', 'Arson': 'Arson', 'Assault': 'Assault', 'Burglary': 'Burglary', 'Explosion': 'Explosion', 'Fighting': 'Fighting', 'RoadAccidents': 'RoadAccidents', 'Robbery': 'Robbery', 'Shooting': 'Shooting', 'Shoplifting': 'Shoplifting', 'Stealing': 'Stealing', 'Vandalism': 'Vandalism'})

    testdataset = UCFDataset(args.visual_length, args.test_list, True, label_map)
    testdataloader = DataLoader(testdataset, batch_size=1, shuffle=False)

    prompt_text = get_prompt_text(label_map)
    gt = np.load(args.gt_path)
    gtsegments = np.load(args.gt_segment_path, allow_pickle=True)
    gtlabels = np.load(args.gt_label_path, allow_pickle=True)

    model = CLIPVAD(args.classes_num, args.embed_dim, args.visual_length, args.visual_width, args.visual_head, args.visual_layers, args.attn_window, args.prompt_prefix, args.prompt_postfix, device)
    model_param = torch.load(args.model_path)
    model.load_state_dict(model_param)

    test(model, testdataloader, args.visual_length, prompt_text, gt, gtsegments, gtlabels, device)

In [25]:
class TrainIterator:
    def __init__(self, current_dataset, normal_loader=None, anomaly_loader=None, train_loader=None, device="cpu", label_map=None, prompt_text=None):
        self.current_dataset = current_dataset
        self.normal_loader = normal_loader
        self.anomaly_loader = anomaly_loader
        self.train_loader = train_loader
        self.device = device
        self.label_map = label_map
        self.prompt_text = prompt_text

    def __iter__(self):
        if self.current_dataset == Datasets.UCF:
            normal_iter = iter(self.normal_loader)
            anomaly_iter = iter(self.anomaly_loader)

            for _ in range(min(len(self.normal_loader), len(self.anomaly_loader))):
                normal_features, normal_label, normal_lengths = next(normal_iter)
                anomaly_features, anomaly_label, anomaly_lengths = next(anomaly_iter)

                # Concat features + lengths
                visual_features = torch.cat([normal_features, anomaly_features], dim=0).to(self.device)
                feat_lengths = torch.cat([normal_lengths, anomaly_lengths], dim=0).to(self.device)

                # Merge labels
                text_labels = list(normal_label) + list(anomaly_label)
                text_labels = get_batch_label(text_labels, self.prompt_text, self.label_map).to(self.device)

                yield visual_features, text_labels, feat_lengths

        elif self.current_dataset == Datasets.XD:
            for visual_features, text_labels, feat_lengths in self.train_loader:
                visual_features = visual_features.to(self.device)
                feat_lengths = feat_lengths.to(self.device)
                text_labels = get_batch_label(text_labels, self.prompt_text, self.label_map).to(self.device)

                yield visual_features, text_labels, feat_lengths

        else:
            raise ValueError(f"Unknown dataset {self.current_dataset}")


In [26]:
#@title Train Function

class TripletLoss(nn.Module):
    def __init__(self):
        super(TripletLoss, self).__init__()

    def distance(self, x, y):
        return torch.cdist(x, y, p=2)

    def forward(self, feats, margin=100.0):
        bs = feats.size(0)
        n_feats = feats[:bs // 2]
        a_feats = feats[bs // 2:]

        # Shape: (N, N) and (N, A)
        n_d = self.distance(n_feats, n_feats)
        a_d = self.distance(n_feats, a_feats)

        n_d_max, _ = torch.max(n_d, dim=0)         # shape: (N,)
        a_d_min, _ = torch.min(a_d, dim=0)         # shape: (A,)

        a_d_min = margin - a_d_min
        a_d_min = torch.max(torch.zeros_like(a_d_min), a_d_min)  # element-wise clamp to >= 0

        return torch.mean(n_d_max) + torch.mean(a_d_min)



def CLASM(logits, labels, lengths, device):
    instance_logits = torch.zeros(0).to(device)
    labels = labels / torch.sum(labels, dim=1, keepdim=True)
    labels = labels.to(device)

    for i in range(logits.shape[0]):
        tmp, _ = torch.topk(logits[i, 0:lengths[i]], k=int(lengths[i] / 16 + 1), largest=True, dim=0)
        instance_logits = torch.cat([instance_logits, torch.mean(tmp, 0, keepdim=True)], dim=0)

    milloss = -torch.mean(torch.sum(labels * F.log_softmax(instance_logits, dim=1), dim=1), dim=0)
    return milloss

def CLAS2(logits, labels, lengths, device):
    instance_logits = torch.zeros(0).to(device)
    labels = 1 - labels[:, 0].reshape(labels.shape[0])
    labels = labels.to(device)
    logits = torch.sigmoid(logits).reshape(logits.shape[0], logits.shape[1])

    for i in range(logits.shape[0]):
        tmp, _ = torch.topk(logits[i, 0:lengths[i]], k=int(lengths[i] / 16 + 1), largest=True)
        tmp = torch.mean(tmp).view(1)
        instance_logits = torch.cat([instance_logits, tmp], dim=0)

    clsloss = F.binary_cross_entropy(instance_logits, labels)
    return clsloss


def train(model, train_iter, testloader, args, label_map, device, experiment):
    model.to(device)
    gt = np.load(args.gt_path)
    gtsegments = np.load(args.gt_segment_path, allow_pickle=True)
    gtlabels = np.load(args.gt_label_path, allow_pickle=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = MultiStepLR(optimizer, args.scheduler_milestones, args.scheduler_rate)
    triplet_loss_fn = TripletLoss().to(device)
    prompt_text = get_prompt_text(label_map)
    best_metric = 0
    metric_to_optimize = 'ap' if CURRENT_DATASET == Datasets.XD else 'auc'
    epoch = 0
    global_step = 0

    if args.use_checkpoint:
        download_commet_asset(args.checkpoint_file)
        checkpoint = torch.load(args.checkpoint_file, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        best_metric = checkpoint[metric_to_optimize]
        print("checkpoint info:")
        print("epoch:", epoch+1, f" {metric_to_optimize}:", best_metric)

    for e in range(args.max_epoch):
        curr_epoch = e + 1
        experiment.log_current_epoch(curr_epoch)
        model.train()
        loss_total1 = 0
        loss_total2 = 0

        for i, (visual_features, text_labels, feat_lengths) in enumerate(train_iter):
            step = i * Args.batch_size * 2
            global_step += 1

            text_features, logits1, logits2, visual_feats = model(visual_features, None, prompt_text, feat_lengths)

            loss1 = CLAS2(logits1, text_labels, feat_lengths, device)
            loss_total1 += loss1.item()

            loss2 = CLASM(logits2, text_labels, feat_lengths, device)
            loss_total2 += loss2.item()

            triplet_loss_val = triplet_loss_fn(visual_feats)

            loss3 = torch.zeros(1).to(device)
            text_feature_normal = text_features[0] / text_features[0].norm(dim=-1, keepdim=True)
            for j in range(1, text_features.shape[0]):
                text_feature_abr = text_features[j] / text_features[j].norm(dim=-1, keepdim=True)
                loss3 += torch.abs(text_feature_normal @ text_feature_abr)
            loss3 = loss3 / 13 * 1e-1

            loss = loss1 + loss2 + loss3 + 0.01 * triplet_loss_val

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            avg_loss1 = loss_total1 / (i + 1)
            avg_loss2 = loss_total2 / (i + 1)
            loss3_val = loss3.item()

            # Comet logging (per step)
            experiment.log_metric("loss1", avg_loss1, step=global_step)
            experiment.log_metric("loss2", avg_loss2, step=global_step)
            experiment.log_metric("loss3", loss3_val, step=global_step)
            experiment.log_metric("loss_triplet", triplet_loss_val.item(), step=global_step)


            if step % 1280 == 0 and step != 0:
                print(f'epoch: {curr_epoch} | step: {step} | loss1: {avg_loss1:.4f} | loss2: {avg_loss2:.4f} | loss3: {loss3_val:.4f}')

                # Evaluate
                AUC, AP, averageMAP = test(model, testloader, args.visual_length, prompt_text, gt, gtsegments, gtlabels, device)
                experiment.log_metric("AP", AP, step=global_step)
                experiment.log_metric("AUC", AUC, step=global_step)
                experiment.log_metric("Average_MAP", averageMAP, step=global_step)

                current_metric = AP if metric_to_optimize == 'ap' else AUC

                if current_metric > best_metric:
                    best_metric = current_metric
                    experiment.log_metric(f"Best_{metric_to_optimize.upper()}", best_metric, step=global_step)

                    checkpoint = {
                        'epoch': curr_epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                         metric_to_optimize: best_metric
                    }
                    torch.save(checkpoint, args.checkpoint_file)

                    # Log checkpoint to Comet
                    experiment.log_asset(file_data=args.checkpoint_file, file_name=args.checkpoint_file, overwrite=True)

        scheduler.step()

        experiment.log_metric("epoch", curr_epoch, step=global_step)

        # Reload best checkpoint before next epoch
        checkpoint = torch.load(args.checkpoint_file, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])

    # Save and log final model weights
    checkpoint = torch.load(args.checkpoint_file, weights_only=False)
    torch.save(checkpoint['model_state_dict'], args.model_path)
    experiment.log_model("final_model", args.model_path)


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    #torch.backends.cudnn.deterministic = True

def run_train(experiment):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    args = Args()
    setup_seed(args.seed)

    if CURRENT_DATASET == Datasets.UCF:
      label_map = dict({'Normal': 'normal', 'Abuse': 'abuse', 'Arrest': 'arrest', 'Arson': 'arson', 'Assault': 'assault', 'Burglary': 'burglary', 'Explosion': 'explosion', 'Fighting': 'fighting', 'RoadAccidents': 'roadAccidents', 'Robbery': 'robbery', 'Shooting': 'shooting', 'Shoplifting': 'shoplifting', 'Stealing': 'stealing', 'Vandalism': 'vandalism'})

      normal_dataset = UCFDataset(args.visual_length, args.train_list, False, label_map, True)
      normal_loader = DataLoader(normal_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)

      anomaly_dataset = UCFDataset(args.visual_length, args.train_list, False, label_map, False)
      anomaly_loader = DataLoader(anomaly_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)

      test_dataset = UCFDataset(args.visual_length, args.test_list, True, label_map)
      test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    elif CURRENT_DATASET == Datasets.XD:
      label_map = dict({'A': 'normal', 'B1': 'fighting', 'B2': 'shooting', 'B4': 'riot', 'B5': 'abuse', 'B6': 'car accident', 'G': 'explosion'})

      train_dataset = XDDataset(args.visual_length, args.train_list, False, label_map)
      train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

      test_dataset = XDDataset(args.visual_length, args.test_list, True, label_map)
      test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)


    prompt_text = get_prompt_text(label_map)

    train_iter = TrainIterator(
        current_dataset=CURRENT_DATASET,
        normal_loader=normal_loader if CURRENT_DATASET == Datasets.UCF else None,
        anomaly_loader=anomaly_loader if CURRENT_DATASET == Datasets.UCF else None,
        train_loader=train_loader if CURRENT_DATASET == Datasets.XD else None,
        device=device,
        label_map=label_map,
        prompt_text=prompt_text,
      )

    model = CLIPVAD(args.classes_num, args.embed_dim, args.visual_length, args.visual_width, args.visual_head, args.visual_layers, args.attn_window, args.prompt_prefix, args.prompt_postfix, device)

    train(model, train_iter, test_loader, args, label_map, device, experiment)



In [None]:
experiment = comet_ml.Experiment(
                  api_key=COMET_API_KEY,
                  project_name=COMET_PROJECT_NAME)

with experiment.train():
  run_train(experiment)

experiment.end()

[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     name                  : flat_dessert_6344
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/hammad-ali/vad-clip-notebook/d5008c842b8a4e9d91e5b12ad7cd5e27
[1;38;5;39mCOMET INFO:[0m   Metrics [count] (min, max):
[1;38;5;39mCOMET INFO:[0m     curr_epoch              : 1
[1;38;5;39mCOMET INFO:[0m     train_loss [3]          : (3.688724994659424, 4.881041526794434)
[1;38;5;39mCOMET INFO:[0m     train_loss1 [21]        : (0.6341090557121095, 1.2367403507232666)
[1;38;5;39mCOMET INFO:[0m     train_loss2 [21]        : (1.8316247633525304, 1.909945368

epoch: 1 | step: 3840 | loss1: 0.6341 | loss2: 1.8316 | loss3: 0.0449
AUC1:  0.8118602589274566  AP1:  0.49492551791127704
AUC2:  0.5911856816790272  AP2: 0.2828839954196463
epoch: 1 | step: 7680 | loss1: 0.5219 | loss2: 1.7386 | loss3: 0.0438
AUC1:  0.869501882858055  AP1:  0.6244935460977947
AUC2:  0.6941002307608997  AP2: 0.3900166307098992
epoch: 1 | step: 11520 | loss1: 0.4657 | loss2: 1.6432 | loss3: 0.0418
AUC1:  0.8951022694269679  AP1:  0.6811501609584534
AUC2:  0.789479305779756  AP2: 0.4919985574462133
epoch: 1 | step: 15360 | loss1: 0.4274 | loss2: 1.5582 | loss3: 0.0390
AUC1:  0.9099322663412064  AP1:  0.7147525447227142
AUC2:  0.8539781070010186  AP2: 0.5301988082254321
epoch: 1 | step: 19200 | loss1: 0.3943 | loss2: 1.4747 | loss3: 0.0358
AUC1:  0.9166402212094118  AP1:  0.7297119004740841
AUC2:  0.8678692455108434  AP2: 0.5386584836237396
epoch: 1 | step: 23040 | loss1: 0.3686 | loss2: 1.3895 | loss3: 0.0326
AUC1:  0.9213295630669529  AP1:  0.7394200197332585
AUC2:  0.8