[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hammad-ali1/vad_clip_colab_notebook/blob/main/vad_clip_notebook.ipynb)


In [3]:
from IPython.display import clear_output

!pip install ftfy
!pip install comet_ml

clear_output()

In [2]:
#@title Pull Latest Repo
from google.colab import userdata
from pathlib import Path

try:
  gh_token = userdata.get('vad_clip_gh_token')
except Exception:
  gh_token = None

REPO_NAME = "vad_clip_colab_notebook"
REPO_PATH = Path("/content") / REPO_NAME

if gh_token:
  GITHUB_URL = f"https://{gh_token}@github.com/hammad-ali1/{REPO_NAME}.git"
else:
  GITHUB_URL = f'https://github.com/hammad-ali1/{REPO_NAME}'

!rm -rf {REPO_PATH}
!git clone {GITHUB_URL}

!cp -r {REPO_PATH}/* /content

Cloning into 'vad_clip_colab_notebook'...
remote: Enumerating objects: 79, done.[K
remote: Counting objects: 100% (79/79), done.[K
remote: Compressing objects: 100% (66/66), done.[K
remote: Total 79 (delta 24), reused 52 (delta 11), pack-reused 0 (from 0)[K
Receiving objects: 100% (79/79), 1.68 MiB | 12.36 MiB/s, done.
Resolving deltas: 100% (24/24), done.


In [4]:
#@title Import Modules
import os
from pathlib import Path
from collections import OrderedDict
import os
from pathlib import Path
import random

import comet_ml
import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from sklearn.metrics import average_precision_score, roc_auc_score

from clip import clip
from utils.dataset import UCFDataset
from utils.layers import GraphConvolution, DistanceAdj
from utils.tools import get_batch_mask, get_prompt_text, get_batch_label
from utils.ucf_detectionMAP import getDetectionMAP as dmAP

In [5]:
DRIVE_DATASET_DIR = Path('/content/drive/MyDrive/UCFClipFeatures')

if os.path.ismount('/content/drive') and DRIVE_DATASET_DIR.exists():
  DATASET_DIR = DRIVE_DATASET_DIR
else:
  !gdown 1DOkXHObwEGqmU3c9-J1gCIyVXiTMFfC5
  !unzip /content/UCFClipFeatures.zip -d /content
  !rm /content/UCFClipFeatures.zip
  DATASET_DIR = Path('/content/UCFClipFeatures')

clear_output()

In [None]:
class Args:
    # Central configuration read by training (run_train) and evaluation (run_test).
    # Every setting is a plain class attribute, so `Args.x` and `Args().x` agree.

    # Reproducibility.
    seed = 234

    # Temporal transformer over the visual feature sequence.
    visual_length = 256            # sequence length fed to the model (size of the position-embedding table)
    visual_width = 512             # feature width of each time step
    visual_head = 1                # attention heads per transformer block
    visual_layers = 2              # number of transformer blocks
    attn_window = 8                # window for CLIPVAD.build_attention_mask (unused by the full-attention model)

    # Text prompts.
    embed_dim = 512                # width of the learnable text-prompt embeddings
    prompt_prefix = 10             # learnable context tokens placed before each class name
    prompt_postfix = 10            # learnable context tokens placed after each class name
    classes_num = 14               # Normal + 13 anomaly classes

    # Optimisation.
    max_epoch = 10
    batch_size = 64                # per loader; each step uses one normal and one anomaly batch
    lr = 2e-5
    scheduler_rate = 0.1           # MultiStepLR gamma
    scheduler_milestones = [4, 8]  # epochs after which the LR is multiplied by scheduler_rate

    # Weights and checkpoints.
    model_path = "model_cur.pth"
    use_checkpoint = False         # resume from checkpoint_path when it exists
    checkpoint_path = "checkpoint.pth"

    # Feature lists and ground truth (paths relative to the working directory).
    train_list = 'list/ucf_CLIP_rgb.csv'
    test_list = 'list/ucf_CLIP_rgbtest.csv'
    gt_path = 'list/gt_ucf.npy'
    gt_segment_path = 'list/gt_segment_ucf.npy'
    gt_label_path = 'list/gt_label_ucf.npy'


In [None]:
def _rebase_feature_path(p, old_root, new_root):
    """Return feature path `p` re-rooted under `new_root`.

    `p` may live under `old_root` (as in the lists shipped with the repo) or
    already under `new_root` (a list rewritten by an earlier run). Accepting
    both makes rewriting idempotent. Any other location raises ValueError.
    """
    path = Path(p)
    for root in (old_root, new_root):
        try:
            return str(new_root / path.relative_to(root))
        except ValueError:
            continue
    raise ValueError(f"{p!r} is neither under {old_root} nor under {new_root}")


def _sample_balanced_videos(df, limit_list, parts_per_video):
    """Select about `limit_list` rows: half Normal videos, the rest spread over other labels.

    Only videos with exactly `parts_per_video` rows are eligible. Sampling
    uses a fixed `random_state` so the subset is reproducible. The result is
    sorted by label, video and numeric part index.
    """
    complete_videos = df.groupby('video_id').filter(lambda g: len(g) == parts_per_video)
    total_videos = limit_list // parts_per_video

    # One row per video with its label.
    video_meta = (
        complete_videos.groupby('video_id')
        .first()
        .reset_index()[['video_id', 'label']]
    )
    normal_videos = video_meta[video_meta['label'] == 'Normal']
    other_videos = video_meta[video_meta['label'] != 'Normal']
    other_labels = other_videos['label'].unique()

    num_normal_videos = total_videos // 2
    num_other_videos = total_videos - num_normal_videos
    # Guard the division when the list contains no anomaly labels at all.
    per_other_label = max(1, num_other_videos // len(other_labels)) if len(other_labels) else 0

    picks = [normal_videos.sample(n=min(num_normal_videos, len(normal_videos)), random_state=42)]
    for label in other_labels:
        vids = other_videos[other_videos['label'] == label]
        picks.append(vids.sample(n=min(per_other_label, len(vids)), random_state=42))
    selected_ids = pd.concat(picks, ignore_index=True)['video_id']

    # .copy() so adding the helper column never writes into a view of `df`.
    final_df = complete_videos[complete_videos['video_id'].isin(selected_ids)].copy()
    final_df['part'] = final_df['path'].apply(lambda x: int(Path(x).stem.split('__')[-1]))
    return final_df.sort_values(by=['label', 'video_id', 'part']).drop(columns='part')


def update_lists(file_path, limit_list=None, dataset_dir=None,
                 old_root='/content/drive/MyDrive/UCFClipFeatures', parts_per_video=10):
    """Rewrite the feature-list CSV at `file_path` in place.

    Each row's ``path`` ends in ``__<part>`` before its extension; rows that
    share the prefix form one video.

    Args:
        file_path: CSV with at least ``path`` and ``label`` columns.
        limit_list: if given, keep only about this many rows. Roughly half
            come from Normal videos, the rest are spread evenly over the
            other labels, and only complete videos are used.
            If None, every row is kept.
        dataset_dir: root the paths are rewritten to. Defaults to the
            notebook-level ``DATASET_DIR``.
        old_root: root that the paths in the shipped lists are relative to.
        parts_per_video: number of rows that make up one complete video.

    Raises:
        ValueError: if a path is under neither `old_root` nor `dataset_dir`.
    """
    dataset_dir = Path(DATASET_DIR if dataset_dir is None else dataset_dir)
    old_root = Path(old_root)

    df = pd.read_csv(file_path)
    # Video id = path without its trailing "__<part>" suffix.
    df['video_id'] = df['path'].apply(lambda x: x.rsplit('__', 1)[0])

    if limit_list is not None:
        final_df = _sample_balanced_videos(df, limit_list, parts_per_video)
    else:
        final_df = df  # keep every row, incomplete videos included

    final_df['path'] = final_df['path'].apply(
        lambda p: _rebase_feature_path(p, old_root, dataset_dir))

    final_df.drop(columns='video_id').to_csv(file_path, index=False)


# Rebase both split lists onto DATASET_DIR; neither list is subsampled.
update_lists(file_path=Args.test_list)
update_lists(file_path=Args.train_list)

In [None]:
#@title Model
class LayerNorm(nn.LayerNorm):
    """torch LayerNorm that normalises in float32 and returns the input dtype.

    The input is upcast to float32 before normalisation, and the result is cast
    back to whatever dtype the caller passed in.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalised = super().forward(x.to(torch.float32))
        return normalised.to(input_dtype)


class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(x * 1.702)
        return gate * x


class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then a 4x-wide MLP, both residual.

    `forward` consumes and returns an ``(x, padding_mask)`` tuple so blocks can
    be chained with ``nn.Sequential``. ``x`` goes straight into
    ``nn.MultiheadAttention`` built with its default ``batch_first=False``, so it
    is expected as (seq, batch, width). `attn_mask` is forwarded unchanged as
    the attention's ``attn_mask``.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        hidden_width = d_model * 4
        # Creation order and attribute names determine parameter init order and
        # state_dict keys; keep them stable.
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, hidden_width)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(hidden_width, d_model)),
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor, padding_mask: torch.Tensor):
        """Self-attention on `x`; `padding_mask` (if any) is used as a boolean key padding mask."""
        if padding_mask is not None:
            padding_mask = padding_mask.to(dtype=bool, device=x.device)
        if self.attn_mask is not None:
            # Stored back on the module so later calls find it on the right device.
            self.attn_mask = self.attn_mask.to(device=x.device)
        attended, _ = self.attn(x, x, x, need_weights=False,
                                key_padding_mask=padding_mask, attn_mask=self.attn_mask)
        return attended

    def forward(self, x):
        hidden, padding_mask = x
        hidden = hidden + self.attention(self.ln_1(hidden), padding_mask)
        hidden = hidden + self.mlp(self.ln_2(hidden))
        return hidden, padding_mask


class Transformer(nn.Module):
    """Stack of `layers` ResidualAttentionBlocks applied in sequence.

    `forward` takes and returns the ``(x, padding_mask)`` tuple used by the
    blocks. By default every block runs full (unmasked) self-attention. Pass
    `attn_mask` (any mask accepted by ``nn.MultiheadAttention``'s
    ``attn_mask``) to restrict attention; the same mask is shared by all blocks.
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[
            ResidualAttentionBlock(width, heads, attn_mask=attn_mask) for _ in range(layers)
        ])

    def forward(self, x: torch.Tensor):
        # `x` is the (features, padding_mask) tuple expected by each block.
        return self.resblocks(x)


class CLIPVAD(nn.Module):
    """CLIP-based video anomaly detector (full temporal attention variant).

    Visual branch: the input feature sequence gets learned position embeddings
    and is run through a temporal Transformer. A linear head then gives one
    binary anomaly logit per time step (``logits1``).

    Text branch: each class name is wrapped in learnable prompt embeddings and
    encoded by the frozen CLIP text encoder. The result is shifted by a
    ``logits1``-weighted summary of the video and refined by an MLP. Every time
    step is then compared to every class by scaled cosine similarity
    (``logits2``).
    """

    def __init__(self,
                 num_class: int,
                 embed_dim: int,
                 visual_length: int,
                 visual_width: int,
                 visual_head: int,
                 visual_layers: int,
                 attn_window: int,
                 prompt_prefix: int,
                 prompt_postfix: int,
                 device):
        super().__init__()

        self.num_class = num_class
        self.visual_length = visual_length
        self.visual_width = visual_width
        self.embed_dim = embed_dim
        self.attn_window = attn_window
        self.prompt_prefix = prompt_prefix
        self.prompt_postfix = prompt_postfix
        self.device = device

        # Full (unmasked) temporal self-attention over the whole sequence.
        # `build_attention_mask(attn_window)` is kept for a windowed variant.
        self.temporal = Transformer(
            width=visual_width,
            layers=visual_layers,
            heads=visual_head,
        )

        # NOTE: `linear` and `gelu` are not used by `forward` (they belonged to a
        # graph-convolution branch that was removed). `linear` still owns
        # parameters, so it is kept to keep state_dict keys compatible with
        # existing checkpoints. Module creation order is left unchanged.
        self.linear = nn.Linear(visual_width, visual_width)
        self.gelu = QuickGELU()

        self.mlp1 = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(visual_width, visual_width * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(visual_width * 4, visual_width))
        ]))
        self.mlp2 = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(visual_width, visual_width * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(visual_width * 4, visual_width))
        ]))
        self.classifier = nn.Linear(visual_width, 1)

        # Frozen CLIP backbone; only its token embedding and text encoder are used.
        self.clipmodel, _ = clip.load("ViT-B/16", device)
        for clip_param in self.clipmodel.parameters():
            clip_param.requires_grad = False

        self.frame_position_embeddings = nn.Embedding(visual_length, visual_width)
        # One learnable embedding per text position (77 = CLIP context length used below).
        self.text_prompt_embeddings = nn.Embedding(77, self.embed_dim)

        self.initialize_parameters()

    def initialize_parameters(self):
        """Small-std normal init for the learnable prompt and position embeddings."""
        nn.init.normal_(self.text_prompt_embeddings.weight, std=0.01)
        nn.init.normal_(self.frame_position_embeddings.weight, std=0.01)

    def build_attention_mask(self, attn_window):
        """Block-diagonal additive attention mask of shape (visual_length, visual_length).

        Each position attends only within its own window of `attn_window`
        consecutive steps: 0 inside the window, -inf elsewhere. When
        `visual_length` is not a multiple of `attn_window`, the last window is
        truncated rather than dropped. This keeps every row with at least one
        attendable position, so no all -inf softmax (NaN) occurs. The current
        full-attention `temporal` module does not use this mask.
        """
        mask = torch.full((self.visual_length, self.visual_length), float('-inf'))
        for start in range(0, self.visual_length, attn_window):
            end = min(start + attn_window, self.visual_length)
            mask[start:end, start:end] = 0
        return mask

    def adj4(self, x, seq_len):
        """Thresholded cosine-similarity adjacency, row-softmaxed per sample.

        Args:
            x: (B, T, D) features.
            seq_len: None, or per-sample valid lengths. When given, only the
                top-left ``seq_len[i] x seq_len[i]`` block of each sample is
                filled and the rest stays 0.

        Returns:
            (B, T, T) tensor. Similarities <= 0.7 are set to 0 before the
            softmax. Not used by the current `encode_video`.
        """
        soft = nn.Softmax(1)
        x2 = x.matmul(x.permute(0, 2, 1))  # B*T*T dot products
        x_norm = torch.norm(x, p=2, dim=2, keepdim=True)  # B*T*1
        x_norm_x = x_norm.matmul(x_norm.permute(0, 2, 1))
        x2 = x2 / (x_norm_x + 1e-20)  # cosine similarity
        output = torch.zeros_like(x2)
        if seq_len is None:
            for i in range(x.shape[0]):
                output[i] = soft(F.threshold(x2[i], 0.7, 0))
        else:
            for i in range(len(seq_len)):
                n = seq_len[i]
                output[i, :n, :n] = soft(F.threshold(x2[i, :n, :n], 0.7, 0))

        return output

    def encode_video(self, images, padding_mask, lengths):
        """Add position embeddings and run the temporal transformer.

        Args:
            images: (B, visual_length, visual_width) features (the position
                embedding addition below requires this shape); cast to float32.
            padding_mask, lengths: accepted for interface compatibility but
                currently unused; the transformer runs without a padding mask.

        Returns:
            (B, visual_length, visual_width) contextualised features.
        """
        images = images.to(torch.float)

        position_ids = torch.arange(self.visual_length, device=self.device)
        position_ids = position_ids.unsqueeze(0).expand(images.shape[0], -1)
        frame_position_embeddings = self.frame_position_embeddings(position_ids)
        frame_position_embeddings = frame_position_embeddings.permute(1, 0, 2)

        # nn.MultiheadAttention is sequence-first here: work in (T, B, C).
        images = images.permute(1, 0, 2) + frame_position_embeddings

        x, _ = self.temporal((images, None))
        return x.permute(1, 0, 2)  # back to (B, T, C)

    def encode_textprompt(self, text):
        """Encode class names wrapped in learnable prompt embeddings.

        The start token embedding is placed at position 0, followed by
        `prompt_prefix` learnable slots and then the class-name tokens. After
        another `prompt_postfix` learnable slots comes the token at the argmax
        position (presumably CLIP's end-of-text token, which has the largest
        id). Only that final position is set in `text_tokens`, which is passed
        to the project's `encode_text`.

        Returns:
            One text feature per entry of `text`.
        """
        word_tokens = clip.tokenize(text).to(self.device)
        word_embedding = self.clipmodel.encode_token(word_tokens)
        text_embeddings = self.text_prompt_embeddings(torch.arange(77).to(self.device)).unsqueeze(0).repeat([len(text), 1, 1])
        text_tokens = torch.zeros(len(text), 77).to(self.device)

        for i in range(len(text)):
            ind = torch.argmax(word_tokens[i], -1)
            text_embeddings[i, 0] = word_embedding[i, 0]
            text_embeddings[i, self.prompt_prefix + 1: self.prompt_prefix + ind] = word_embedding[i, 1: ind]
            text_embeddings[i, self.prompt_prefix + ind + self.prompt_postfix] = word_embedding[i, ind]
            text_tokens[i, self.prompt_prefix + ind + self.prompt_postfix] = word_tokens[i, ind]

        text_features = self.clipmodel.encode_text(text_embeddings, text_tokens)

        return text_features

    def forward(self, visual, padding_mask, text, lengths):
        """Score every time step against the binary head and every class prompt.

        Returns:
            text_features_ori: (num_text, C) raw prompt features.
            logits1: (B, T, 1) binary anomaly logits.
            logits2: (B, T, num_text) cosine similarities divided by 0.07.
        """
        visual_features = self.encode_video(visual, padding_mask, lengths)
        logits1 = self.classifier(visual_features + self.mlp2(visual_features))

        text_features_ori = self.encode_textprompt(text)

        # Anomaly-weighted video summary: (B, 1, T) @ (B, T, C) -> (B, 1, C), L2-normalised.
        logits_attn = logits1.permute(0, 2, 1)
        visual_attn = logits_attn @ visual_features
        visual_attn = visual_attn / visual_attn.norm(dim=-1, keepdim=True)
        visual_attn = visual_attn.expand(visual_attn.shape[0], text_features_ori.shape[0], visual_attn.shape[2])
        text_features = text_features_ori.unsqueeze(0)
        text_features = text_features.expand(visual_attn.shape[0], text_features.shape[1], text_features.shape[2])
        text_features = text_features + visual_attn
        text_features = text_features + self.mlp1(text_features)

        visual_features_norm = visual_features / visual_features.norm(dim=-1, keepdim=True)
        text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
        text_features_norm = text_features_norm.permute(0, 2, 1)
        logits2 = visual_features_norm @ text_features_norm.type(visual_features_norm.dtype) / 0.07

        return text_features_ori, logits1, logits2


In [None]:
#@title Test Function
def _chunk_lengths(length, maxlen):
    """Valid-step count for each `maxlen`-long chunk of a `length`-step video.

    Produces ``int(length / maxlen) + 1`` entries. Full chunks get `maxlen`.
    Once the remaining count is <= `maxlen`, that count is written to every
    remaining entry, so an exact multiple of `maxlen` yields a trailing
    duplicate (e.g. length == maxlen -> [maxlen, maxlen]).
    NOTE(review): the chunk count presumably mirrors how UCFDataset splits
    test videos; keep the two in sync.
    """
    num_chunks = int(length / maxlen) + 1
    lengths = torch.zeros(num_chunks)
    remaining = length
    for j in range(num_chunks):
        if remaining > maxlen:
            lengths[j] = maxlen
            remaining -= maxlen
        else:
            lengths[j] = remaining
    return lengths.to(int)


def test(model, testdataloader, maxlen, prompt_text, gt, gtsegments, gtlabels, device):
    """Evaluate `model` on the test split and print frame-level and detection metrics.

    Each loader item is one video. item[0] holds its features and item[2] its
    true number of time steps. Two per-step scores are computed:
      - branch 1: sigmoid of the binary logit (`logits1`);
      - branch 2: 1 - softmax probability of prompt 0 (`logits2`), presumably
        the Normal class.
    Per-step scores are repeated 16x to line up with the frame-level `gt`
    (presumably 16 frames per feature step; verify against the extractor).

    Prints AUC/AP for both branches, the first five mAP@IoU values from
    `dmAP`, and their average.

    Returns:
        (ROC1, AP1): ROC-AUC and average precision of branch 1.

    Raises:
        ValueError: if `testdataloader` yields no items.
    """
    model.to(device)
    model.eval()

    prob1_chunks = []
    prob2_chunks = []
    element_logits2_stack = []

    with torch.no_grad():
        for item in testdataloader:
            visual = item[0].squeeze(0)
            len_cur = int(item[2])
            if len_cur < maxlen:
                # A short video is a single chunk; restore the chunk axis.
                visual = visual.unsqueeze(0)
            visual = visual.to(device)

            lengths = _chunk_lengths(len_cur, maxlen)
            padding_mask = get_batch_mask(lengths, maxlen).to(device)
            _, logits1, logits2 = model(visual, padding_mask, prompt_text, lengths)

            # Merge the chunk and time axes, then drop the padded tail.
            logits1 = logits1.flatten(0, 1)
            logits2 = logits2.flatten(0, 1)
            class_probs = logits2[0:len_cur].softmax(dim=-1)

            prob1_chunks.append(torch.sigmoid(logits1[0:len_cur].squeeze(-1)))
            prob2_chunks.append(1 - class_probs[:, 0].squeeze(-1))
            element_logits2_stack.append(np.repeat(class_probs.cpu().numpy(), 16, 0))

    if not prob1_chunks:
        raise ValueError("test dataloader yielded no items")

    ap1 = torch.cat(prob1_chunks, dim=0).cpu().numpy().tolist()
    ap2 = torch.cat(prob2_chunks, dim=0).cpu().numpy().tolist()
    frame_scores1 = np.repeat(ap1, 16)
    frame_scores2 = np.repeat(ap2, 16)

    ROC1 = roc_auc_score(gt, frame_scores1)
    AP1 = average_precision_score(gt, frame_scores1)
    ROC2 = roc_auc_score(gt, frame_scores2)
    AP2 = average_precision_score(gt, frame_scores2)

    print("AUC1: ", ROC1, " AP1: ", AP1)
    print("AUC2: ", ROC2, " AP2:", AP2)

    dmap, iou = dmAP(element_logits2_stack, gtsegments, gtlabels, excludeNormal=False)
    num_iou = 5
    for k in range(num_iou):
        print('mAP@{0:.1f} ={1:.2f}%'.format(iou[k], dmap[k]))
    averageMAP = sum(dmap[:num_iou]) / num_iou
    print('average MAP: {:.2f}'.format(averageMAP))

    return ROC1, AP1


def run_test():
    """Evaluate the weights stored at `Args.model_path` on the test split.

    The state dict is mapped onto the selected device, so weights saved on a
    GPU runtime can also be evaluated on a CPU-only one.

    Returns:
        (ROC1, AP1) as returned by `test`.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    args = Args()

    # NOTE(review): run_train builds its prompts from lower-case names; these are
    # capitalised. Presumably equivalent only if the CLIP tokenizer lower-cases its
    # input; verify, or share one label map between training and testing.
    label_map = dict({'Normal': 'Normal', 'Abuse': 'Abuse', 'Arrest': 'Arrest', 'Arson': 'Arson', 'Assault': 'Assault', 'Burglary': 'Burglary', 'Explosion': 'Explosion', 'Fighting': 'Fighting', 'RoadAccidents': 'RoadAccidents', 'Robbery': 'Robbery', 'Shooting': 'Shooting', 'Shoplifting': 'Shoplifting', 'Stealing': 'Stealing', 'Vandalism': 'Vandalism'})

    testdataset = UCFDataset(args.visual_length, args.test_list, True, label_map)
    testdataloader = DataLoader(testdataset, batch_size=1, shuffle=False)

    prompt_text = get_prompt_text(label_map)
    gt = np.load(args.gt_path)
    gtsegments = np.load(args.gt_segment_path, allow_pickle=True)
    gtlabels = np.load(args.gt_label_path, allow_pickle=True)

    model = CLIPVAD(args.classes_num, args.embed_dim, args.visual_length, args.visual_width, args.visual_head, args.visual_layers, args.attn_window, args.prompt_prefix, args.prompt_postfix, device)
    model_param = torch.load(args.model_path, map_location=device)
    model.load_state_dict(model_param)

    return test(model, testdataloader, args.visual_length, prompt_text, gt, gtsegments, gtlabels, device)

In [None]:
#@title Train Function

def CLASM(logits, labels, lengths, device):
    """Multi-class MIL loss: top-k mean pooling followed by soft-label cross-entropy.

    For each video b, its first ``lengths[b]`` time steps are considered. For
    every class, the ``int(lengths[b] / 16 + 1)`` largest logits are averaged,
    giving one video-level logit vector.

    Args:
        logits: (B, T, C) per-step class logits.
        labels: (B, C) class weights; each row is normalised to sum to 1.
        lengths: per-video valid lengths.
        device: device the loss is computed on.

    Returns:
        Scalar: mean over videos of -sum(target * log_softmax(pooled)).
    """
    targets = (labels / labels.sum(dim=1, keepdim=True)).to(device)

    pooled = []
    for b in range(logits.shape[0]):
        k = int(lengths[b] / 16 + 1)
        top_values = torch.topk(logits[b, :lengths[b]], k=k, largest=True, dim=0)[0]
        pooled.append(top_values.mean(0, keepdim=True))
    instance_logits = torch.cat([torch.zeros(0).to(device)] + pooled, dim=0)

    per_video = (targets * F.log_softmax(instance_logits, dim=1)).sum(dim=1)
    return -per_video.mean(dim=0)

def CLAS2(logits, labels, lengths, device):
    """Binary MIL loss on the anomaly head.

    The video-level anomaly score is the mean of the ``int(lengths[b] / 16 + 1)``
    largest sigmoid scores among the first ``lengths[b]`` time steps. It is
    compared by binary cross-entropy to ``1 - labels[:, 0]``, i.e. target 0
    when column 0 (presumably Normal) is set and 1 otherwise.

    Args:
        logits: (B, T, 1) binary logits.
        labels: (B, C) label matrix; only column 0 is used.
        lengths: per-video valid lengths.
        device: device the loss is computed on.
    """
    targets = (1 - labels[:, 0].reshape(labels.shape[0])).to(device)
    scores = torch.sigmoid(logits).reshape(logits.shape[0], logits.shape[1])

    pooled = []
    for b in range(scores.shape[0]):
        k = int(lengths[b] / 16 + 1)
        top_scores = torch.topk(scores[b, 0:lengths[b]], k=k, largest=True)[0]
        pooled.append(top_scores.mean().view(1))
    instance_logits = torch.cat([torch.zeros(0).to(device)] + pooled, dim=0)

    return F.binary_cross_entropy(instance_logits, targets)


def _text_separation_loss(text_features, device):
    """Mean |cosine similarity| between text feature 0 and each other text feature.

    Pushes prompt 0 (presumably Normal, the first label_map entry) away from
    the others. Returns a shape-(1,) tensor. The sum over rows 1..N-1 is
    divided by N-1, which is 13 for the 14 UCF classes.
    """
    loss = torch.zeros(1).to(device)
    anchor = text_features[0] / text_features[0].norm(dim=-1, keepdim=True)
    for j in range(1, text_features.shape[0]):
        other = text_features[j] / text_features[j].norm(dim=-1, keepdim=True)
        loss += torch.abs(anchor @ other)
    return loss / max(text_features.shape[0] - 1, 1)


def train(model, normal_loader, anomaly_loader, testloader, args, label_map, device, experiment):
    """Train `model` on paired normal/anomaly batches, evaluating periodically.

    Each epoch runs ``min(len(normal_loader), len(anomaly_loader))`` steps.
    Every step concatenates one normal and one anomaly batch and minimises
    CLAS2 + CLASM + 0.1 * text-separation loss.

    Evaluation with `test` runs whenever the per-epoch sample counter `step`
    is a positive multiple of 1280. A new best AUC is saved to
    `args.checkpoint_path` and uploaded to Comet.

    After each epoch the current weights go to ``model_cur.pth`` next to the
    checkpoint, and the best checkpoint (if any) is reloaded. At the end, the
    best weights are written to `args.model_path` and logged to Comet. If no
    checkpoint was ever saved, the last weights are written instead.
    """
    model.to(device)
    gt = np.load(args.gt_path)
    gtsegments = np.load(args.gt_segment_path, allow_pickle=True)
    gtlabels = np.load(args.gt_label_path, allow_pickle=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = MultiStepLR(optimizer, args.scheduler_milestones, args.scheduler_rate)
    prompt_text = get_prompt_text(label_map)
    auc_best = 0
    epoch = 0
    global_step = 0
    # True once a checkpoint from this run (or an explicit resume) is on disk, so a
    # stale file left by an earlier run is never reloaded and a missing one never crashes.
    have_checkpoint = False

    if args.use_checkpoint and Path(args.checkpoint_path).exists():
        # Our own checkpoint file, so full unpickling (weights_only=False) is acceptable.
        checkpoint = torch.load(args.checkpoint_path, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        auc_best = checkpoint['auc']
        have_checkpoint = True
        print("checkpoint info:")
        print("epoch:", epoch+1, " auc:", auc_best)
        # NOTE(review): training still restarts at epoch 0 with a fresh LR schedule;
        # a true resume would start at `epoch` and fast-forward `scheduler`.

    num_steps = min(len(normal_loader), len(anomaly_loader))
    samples_per_step = normal_loader.batch_size * 2  # one normal + one anomaly batch

    for e in range(args.max_epoch):
        experiment.log_current_epoch(e)
        model.train()
        loss_total1 = 0
        loss_total2 = 0
        normal_iter = iter(normal_loader)
        anomaly_iter = iter(anomaly_loader)

        for i in range(num_steps):
            step = i * samples_per_step  # samples seen so far in this epoch
            # Monotonic across epochs so Comet curves from different epochs never overlap.
            global_step = (e * num_steps + i) * samples_per_step

            normal_features, normal_label, normal_lengths = next(normal_iter)
            anomaly_features, anomaly_label, anomaly_lengths = next(anomaly_iter)

            visual_features = torch.cat([normal_features, anomaly_features], dim=0).to(device)
            text_labels = list(normal_label) + list(anomaly_label)
            feat_lengths = torch.cat([normal_lengths, anomaly_lengths], dim=0).to(device)
            text_labels = get_batch_label(text_labels, prompt_text, label_map).to(device)

            text_features, logits1, logits2 = model(visual_features, None, prompt_text, feat_lengths)

            loss1 = CLAS2(logits1, text_labels, feat_lengths, device)
            loss_total1 += loss1.item()

            loss2 = CLASM(logits2, text_labels, feat_lengths, device)
            loss_total2 += loss2.item()

            loss3 = _text_separation_loss(text_features, device) * 1e-1

            loss = loss1 + loss2 + loss3

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            avg_loss1 = loss_total1 / (i + 1)
            avg_loss2 = loss_total2 / (i + 1)
            loss3_val = loss3.item()

            # Comet logging (per step)
            experiment.log_metric("loss1", avg_loss1, step=global_step)
            experiment.log_metric("loss2", avg_loss2, step=global_step)
            experiment.log_metric("loss3", loss3_val, step=global_step)

            if step % 1280 == 0 and step != 0:
                print(f'epoch: {e+1} | step: {step} | loss1: {avg_loss1:.4f} | loss2: {avg_loss2:.4f} | loss3: {loss3_val:.4f}')

                AUC, AP = test(model, testloader, args.visual_length, prompt_text, gt, gtsegments, gtlabels, device)
                # `test` switches the model to eval mode; switch back before the next update.
                model.train()
                experiment.log_metric("AUC", AUC, step=global_step)

                if AUC > auc_best:
                    auc_best = AUC
                    experiment.log_metric("Best_AUC", auc_best, step=global_step)

                    checkpoint = {
                        'epoch': e,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'auc': auc_best
                    }
                    torch.save(checkpoint, args.checkpoint_path)
                    have_checkpoint = True

                    experiment.log_asset(file_data=args.checkpoint_path, file_name="best_checkpoint.pth", overwrite=True)

        scheduler.step()

        experiment.log_metric("epoch", e, step=global_step)

        # Save current model weights separately.
        checkpoint_dir = os.path.dirname(args.checkpoint_path)
        save_path = os.path.join(checkpoint_dir, 'model_cur.pth')
        torch.save(model.state_dict(), save_path)

        # Continue the next epoch from the best weights so far (if any were saved).
        if have_checkpoint:
            checkpoint = torch.load(args.checkpoint_path, weights_only=False)
            model.load_state_dict(checkpoint['model_state_dict'])

    # Save and log the final (best, if available) weights.
    if have_checkpoint:
        checkpoint = torch.load(args.checkpoint_path, weights_only=False)
        torch.save(checkpoint['model_state_dict'], args.model_path)
    else:
        torch.save(model.state_dict(), args.model_path)
    experiment.log_model("final_model", args.model_path)


def setup_seed(seed):
    """Seed Python's `random`, NumPy's global RNG and torch (CPU and all CUDA devices).

    cuDNN determinism is not forced. Set
    ``torch.backends.cudnn.deterministic = True`` for bitwise-reproducible GPU runs.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def run_train(experiment):
    """Build the loaders and model described by `Args`, then call `train`.

    Normal and anomaly videos come from the same training list through two
    UCFDataset instances (last flag True / False). Both loaders shuffle and
    drop the last incomplete batch. The test loader yields one video at a time.
    """
    args = Args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    setup_seed(args.seed)

    label_map = {'Normal': 'normal', 'Abuse': 'abuse', 'Arrest': 'arrest', 'Arson': 'arson', 'Assault': 'assault', 'Burglary': 'burglary', 'Explosion': 'explosion', 'Fighting': 'fighting', 'RoadAccidents': 'roadAccidents', 'Robbery': 'robbery', 'Shooting': 'shooting', 'Shoplifting': 'shoplifting', 'Stealing': 'stealing', 'Vandalism': 'vandalism'}

    train_loader_kwargs = dict(batch_size=args.batch_size, shuffle=True, drop_last=True)

    normal_loader = DataLoader(
        UCFDataset(args.visual_length, args.train_list, False, label_map, True),
        **train_loader_kwargs)
    anomaly_loader = DataLoader(
        UCFDataset(args.visual_length, args.train_list, False, label_map, False),
        **train_loader_kwargs)
    test_loader = DataLoader(
        UCFDataset(args.visual_length, args.test_list, True, label_map),
        batch_size=1, shuffle=False)

    model = CLIPVAD(
        num_class=args.classes_num,
        embed_dim=args.embed_dim,
        visual_length=args.visual_length,
        visual_width=args.visual_width,
        visual_head=args.visual_head,
        visual_layers=args.visual_layers,
        attn_window=args.attn_window,
        prompt_prefix=args.prompt_prefix,
        prompt_postfix=args.prompt_postfix,
        device=device,
    )

    train(model, normal_loader, anomaly_loader, test_loader, args, label_map, device, experiment)



In [None]:
# The Comet API key is read from Colab's Secrets panel, never hard-coded.
COMET_API_KEY = userdata.get('COMET_API_KEY')
experiment = comet_ml.Experiment(
                  api_key=COMET_API_KEY,
                  project_name="vad_clip_notebook")

try:
  with experiment.train():
    run_train(experiment)
finally:
  # Flush pending uploads and close the run even if training fails or is interrupted.
  experiment.end()

[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     name                  : simple_thrush_2165
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/hammad-ali/vad-clip-notebook/ad5243acee4f4b29b08e6b17333dc11f
[1;38;5;39mCOMET INFO:[0m   Others:
[1;38;5;39mCOMET INFO:[0m     notebook_url : https://colab.research.google.com/notebook#fileId=https%3A%2F%2Fgithub.com%2Fhammad-ali1%2Fvad_clip_colab_notebook%2Fblob%2Fmain%2Fvad_clip_notebook.ipynb
[1;38;5;39mCOMET INFO:[0m   Uploads:
[1;38;5;39mCOMET INFO:[0m     environment details : 1
[1;38;5;39mCOMET INFO:[0m     filename            : 1
[1;38;5;39m

epoch: 1 | step: 1280 | loss1: 0.5671 | loss2: 2.4711 | loss3: 0.0980
AUC1:  0.7868346428338847  AP1:  0.17884024646622185
AUC2:  0.7446548214199078  AP2: 0.1755740481650378
mAP@0.1 =2.69%
mAP@0.2 =2.03%
mAP@0.3 =1.17%
mAP@0.4 =0.79%
mAP@0.5 =0.19%
average MAP: 1.37
epoch: 1 | step: 2560 | loss1: 0.4895 | loss2: 2.2704 | loss3: 0.0945
AUC1:  0.794575799736916  AP1:  0.18613478740241668
AUC2:  0.7943967722021964  AP2: 0.1909005789987792
mAP@0.1 =1.65%
mAP@0.2 =1.21%
mAP@0.3 =0.91%
mAP@0.4 =0.39%
mAP@0.5 =0.16%
average MAP: 0.86
epoch: 1 | step: 3840 | loss1: 0.4368 | loss2: 2.0856 | loss3: 0.0780
AUC1:  0.8033252779522663  AP1:  0.19199043662060675
AUC2:  0.7987365114071406  AP2: 0.17942550302258417
mAP@0.1 =3.24%
mAP@0.2 =2.65%
mAP@0.3 =2.11%
mAP@0.4 =0.38%
mAP@0.5 =0.21%
average MAP: 1.72
epoch: 1 | step: 5120 | loss1: 0.3981 | loss2: 1.9399 | loss3: 0.0663
AUC1:  0.8055854052525966  AP1:  0.19861231576427862
AUC2:  0.8078892104246138  AP2: 0.19331347642961827
mAP@0.1 =4.56%
mAP@0.2 =