# Libraries

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import AutoImageProcessor, TimesformerModel, get_scheduler

import numpy as np
import pandas as pd
import os
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score

2024-06-22 15:24:05.998756: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-22 15:24:05.998850: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-22 15:24:06.130103: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


# Get Data

In [2]:
test_data = """
Abuse/Abuse028_x264.mp4
Abuse/Abuse030_x264.mp4
Arrest/Arrest001_x264.mp4
Arrest/Arrest007_x264.mp4
Arrest/Arrest024_x264.mp4
Arrest/Arrest030_x264.mp4
Arrest/Arrest039_x264.mp4
Arson/Arson007_x264.mp4
Arson/Arson009_x264.mp4
Arson/Arson010_x264.mp4
Arson/Arson011_x264.mp4
Arson/Arson016_x264.mp4
Arson/Arson018_x264.mp4
Arson/Arson022_x264.mp4
Arson/Arson035_x264.mp4
Arson/Arson041_x264.mp4
Assault/Assault006_x264.mp4
Assault/Assault010_x264.mp4
Assault/Assault011_x264.mp4
Burglary/Burglary005_x264.mp4
Burglary/Burglary017_x264.mp4
Burglary/Burglary018_x264.mp4
Burglary/Burglary021_x264.mp4
Burglary/Burglary024_x264.mp4
Burglary/Burglary032_x264.mp4
Burglary/Burglary033_x264.mp4
Burglary/Burglary035_x264.mp4
Burglary/Burglary037_x264.mp4
Burglary/Burglary061_x264.mp4
Burglary/Burglary076_x264.mp4
Burglary/Burglary079_x264.mp4
Burglary/Burglary092_x264.mp4
Explosion/Explosion002_x264.mp4
Explosion/Explosion004_x264.mp4
Explosion/Explosion007_x264.mp4
Explosion/Explosion008_x264.mp4
Explosion/Explosion010_x264.mp4
Explosion/Explosion011_x264.mp4
Explosion/Explosion013_x264.mp4
Explosion/Explosion016_x264.mp4
Explosion/Explosion017_x264.mp4
Explosion/Explosion020_x264.mp4
Explosion/Explosion021_x264.mp4
Explosion/Explosion022_x264.mp4
Explosion/Explosion025_x264.mp4
Explosion/Explosion027_x264.mp4
Explosion/Explosion028_x264.mp4
Explosion/Explosion029_x264.mp4
Explosion/Explosion033_x264.mp4
Explosion/Explosion035_x264.mp4
Explosion/Explosion036_x264.mp4
Explosion/Explosion039_x264.mp4
Explosion/Explosion043_x264.mp4
Fighting/Fighting003_x264.mp4
Fighting/Fighting018_x264.mp4
Fighting/Fighting033_x264.mp4
Fighting/Fighting042_x264.mp4
Fighting/Fighting047_x264.mp4
RoadAccidents/RoadAccidents001_x264.mp4
RoadAccidents/RoadAccidents002_x264.mp4
RoadAccidents/RoadAccidents004_x264.mp4
RoadAccidents/RoadAccidents009_x264.mp4
RoadAccidents/RoadAccidents010_x264.mp4
RoadAccidents/RoadAccidents011_x264.mp4
RoadAccidents/RoadAccidents012_x264.mp4
RoadAccidents/RoadAccidents016_x264.mp4
RoadAccidents/RoadAccidents017_x264.mp4
RoadAccidents/RoadAccidents019_x264.mp4
RoadAccidents/RoadAccidents020_x264.mp4
RoadAccidents/RoadAccidents021_x264.mp4
RoadAccidents/RoadAccidents022_x264.mp4
RoadAccidents/RoadAccidents121_x264.mp4
RoadAccidents/RoadAccidents122_x264.mp4
RoadAccidents/RoadAccidents123_x264.mp4
RoadAccidents/RoadAccidents124_x264.mp4
RoadAccidents/RoadAccidents125_x264.mp4
RoadAccidents/RoadAccidents127_x264.mp4
RoadAccidents/RoadAccidents128_x264.mp4
RoadAccidents/RoadAccidents131_x264.mp4
RoadAccidents/RoadAccidents132_x264.mp4
RoadAccidents/RoadAccidents133_x264.mp4
Robbery/Robbery048_x264.mp4
Robbery/Robbery050_x264.mp4
Robbery/Robbery102_x264.mp4
Robbery/Robbery106_x264.mp4
Robbery/Robbery137_x264.mp4
Shooting/Shooting002_x264.mp4
Shooting/Shooting004_x264.mp4
Shooting/Shooting007_x264.mp4
Shooting/Shooting008_x264.mp4
Shooting/Shooting010_x264.mp4
Shooting/Shooting011_x264.mp4
Shooting/Shooting013_x264.mp4
Shooting/Shooting015_x264.mp4
Shooting/Shooting018_x264.mp4
Shooting/Shooting019_x264.mp4
Shooting/Shooting021_x264.mp4
Shooting/Shooting022_x264.mp4
Shooting/Shooting024_x264.mp4
Shooting/Shooting026_x264.mp4
Shooting/Shooting028_x264.mp4
Shooting/Shooting032_x264.mp4
Shooting/Shooting033_x264.mp4
Shooting/Shooting034_x264.mp4
Shooting/Shooting037_x264.mp4
Shooting/Shooting043_x264.mp4
Shooting/Shooting046_x264.mp4
Shooting/Shooting047_x264.mp4
Shooting/Shooting048_x264.mp4
Shoplifting/Shoplifting001_x264.mp4
Shoplifting/Shoplifting004_x264.mp4
Shoplifting/Shoplifting005_x264.mp4
Shoplifting/Shoplifting007_x264.mp4
Shoplifting/Shoplifting010_x264.mp4
Shoplifting/Shoplifting015_x264.mp4
Shoplifting/Shoplifting016_x264.mp4
Shoplifting/Shoplifting017_x264.mp4
Shoplifting/Shoplifting020_x264.mp4
Shoplifting/Shoplifting021_x264.mp4
Shoplifting/Shoplifting022_x264.mp4
Shoplifting/Shoplifting027_x264.mp4
Shoplifting/Shoplifting028_x264.mp4
Shoplifting/Shoplifting029_x264.mp4
Shoplifting/Shoplifting031_x264.mp4
Shoplifting/Shoplifting033_x264.mp4
Shoplifting/Shoplifting034_x264.mp4
Shoplifting/Shoplifting037_x264.mp4
Shoplifting/Shoplifting039_x264.mp4
Shoplifting/Shoplifting044_x264.mp4
Shoplifting/Shoplifting049_x264.mp4
Stealing/Stealing019_x264.mp4
Stealing/Stealing036_x264.mp4
Stealing/Stealing058_x264.mp4
Stealing/Stealing062_x264.mp4
Stealing/Stealing079_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_003_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_006_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_010_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_014_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_015_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_018_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_019_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_024_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_025_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_027_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_033_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_034_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_041_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_042_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_048_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_050_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_051_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_056_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_059_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_063_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_067_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_070_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_100_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_129_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_150_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_168_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_175_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_182_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_189_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_196_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_203_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_210_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_217_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_224_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_246_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_247_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_248_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_251_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_289_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_310_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_312_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_317_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_345_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_352_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_360_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_365_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_401_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_417_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_439_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_452_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_453_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_478_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_576_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_597_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_603_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_606_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_621_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_634_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_641_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_656_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_686_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_696_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_702_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_704_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_710_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_717_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_722_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_725_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_745_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_758_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_778_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_780_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_781_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_782_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_783_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_798_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_801_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_828_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_831_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_866_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_867_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_868_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_869_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_870_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_871_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_872_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_873_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_874_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_875_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_876_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_877_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_878_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_879_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_880_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_881_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_882_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_883_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_884_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_885_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_886_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_887_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_888_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_889_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_890_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_891_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_892_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_893_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_894_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_895_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_896_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_897_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_898_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_899_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_900_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_901_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_902_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_903_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_904_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_905_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_906_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_907_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_908_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_909_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_910_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_911_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_912_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_913_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_914_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_915_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_923_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_924_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_925_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_926_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_927_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_928_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_929_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_930_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_931_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_932_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_933_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_934_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_935_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_936_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_937_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_938_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_939_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_940_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_941_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_943_x264.mp4
Testing_Normal_Videos_Anomaly/Normal_Videos_944_x264.mp4
Vandalism/Vandalism007_x264.mp4
Vandalism/Vandalism015_x264.mp4
Vandalism/Vandalism017_x264.mp4
Vandalism/Vandalism028_x264.mp4
Vandalism/Vandalism036_x264.mp4
"""

In [3]:
file_names = test_data.splitlines()
file_names = [name for name in file_names if name.strip()]  

test_list = []

for file_name in tqdm(file_names):
    file_name_pt = file_name.split("/")[-1][:-3] + "pt"
    test_list.append(file_name_pt)

  0%|          | 0/290 [00:00<?, ?it/s]

In [4]:
train_features = []
train_labels = []
test_features = []
test_labels = []

rgb_anomaly_dir = ['/kaggle/input/ucf-crime-two-stream-tokenized/RGB Anomaly-20240621T095401Z-001/RGB Anomaly', '/kaggle/input/ucf-crime-two-stream-tokenized/RGB Anomaly-20240621T095401Z-002/RGB Anomaly']
flow_anomaly_dir = ['/kaggle/input/ucf-crime-two-stream-tokenized/Flow Anomaly-20240621T101905Z-001/Flow Anomaly', '/kaggle/input/ucf-crime-two-stream-tokenized/Flow Anomaly-20240621T101905Z-002/Flow Anomaly']

rgb_normal_dir = ['/kaggle/input/ucf-crime-two-stream-tokenized/RGB Normal-20240621T103144Z-001/RGB Normal', '/kaggle/input/ucf-crime-two-stream-tokenized/RGB Normal-20240621T103144Z-002/RGB Normal']
flow_normal_dir = ['/kaggle/input/ucf-crime-two-stream-tokenized/Flow Normal-20240621T102819Z-001/Flow Normal']

rgb_test_dir = ['/kaggle/input/ucf-crime-two-stream-tokenized/RGB Test/RGB Test']
flow_test_dir = ['/kaggle/input/ucf-crime-two-stream-tokenized/Flow Test-20240621T100912Z-001/Flow Test']

anomaly_dir = (rgb_anomaly_dir, flow_anomaly_dir)
normal_dir = (rgb_normal_dir, flow_normal_dir)
test_dir = (rgb_test_dir, flow_test_dir)
all_dirs = [anomaly_dir, normal_dir, test_dir]

In [5]:
def get_filename(filepath):
    return os.path.basename(filepath)

In [6]:
for dirs in all_dirs:
    rgb_dirs, flow_dirs = dirs

    rgb_all_files = []
    for rgb_dir in rgb_dirs:
        rgb_files = sorted(os.listdir(rgb_dir))
        rgb_files = [os.path.join(rgb_dir, f) for f in rgb_files]
        rgb_all_files.extend(rgb_files)

    flow_all_files = []
    for flow_dir in flow_dirs:
        flow_files = sorted(os.listdir(flow_dir))
        flow_files = [os.path.join(flow_dir, f) for f in flow_files]
        flow_all_files.extend(flow_files)

    rgb_all_files = sorted(rgb_all_files, key=get_filename)
    flow_all_files = sorted(flow_all_files, key=get_filename)

    for rgb_file, flow_file in zip(rgb_all_files, flow_all_files):
        rgb_filename = os.path.basename(rgb_file)
        flow_filename = os.path.basename(flow_file)
        
        if rgb_filename != flow_filename:
            print(f"Filename mismatch: {rgb_filename} and {flow_filename}")
            continue

        file_path = (rgb_file, flow_file)
        filename = rgb_filename
        
        if filename not in test_list:
            train_features.append(file_path)
            if 'Normal' in filename:
                train_labels.append(0)
            else:
                train_labels.append(1)
        else:
            test_features.append(file_path)
            if 'Normal' in filename:
                test_labels.append(0)
            else:
                test_labels.append(1)

# Dataset Preparation

In [7]:
class TwoStreamUCFDataset(Dataset):
    def __init__(self,paths,labels):
        self.paths = paths
        self.labels = labels
    
    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self,idx):
        rgb_path,flow_path = self.paths[idx]
        rgb_tokens = torch.load(rgb_path)
        flow_tokens = torch.load(flow_path)
        label = torch.tensor(self.labels[idx],dtype=torch.float32)
        return rgb_tokens['pixel_values'].squeeze(),flow_tokens['pixel_values'].squeeze(),label

In [8]:
class RGBStreamUCFDataset(Dataset):
    def __init__(self,paths,labels):
        self.paths = paths
        self.labels = labels
    
    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self,idx):
        rgb_path,flow_path = self.paths[idx]
        rgb_tokens = torch.load(rgb_path)
        label = torch.tensor(self.labels[idx],dtype=torch.float32)
        return rgb_tokens['pixel_values'].squeeze(),label

In [9]:
train_dataset = RGBStreamUCFDataset(train_features,train_labels)
test_dataset = RGBStreamUCFDataset(test_features, test_labels)

In [10]:
# Data loading
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=True)

# Modelling

In [11]:
# Define pretrained model
model_rgb = TimesformerModel.from_pretrained("facebook/timesformer-base-finetuned-k400")
model_flow = TimesformerModel.from_pretrained("facebook/timesformer-base-finetuned-k400")

config.json:   0%|          | 0.00/22.7k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/486M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


In [12]:
class TwoStreamModel(torch.nn.Module):
    def __init__(self, model_rgb, model_flow, input_dim, num_classes):
        super(TwoStreamModel, self).__init__()
        self.model_rgb = model_rgb
        self.model_flow = model_flow
        self.classifier = torch.nn.Linear(input_dim * 2, num_classes)
        self.sigmoid = torch.nn.Sigmoid()
    
    def forward(self, x_rgb, x_flow):
        out_rgb = self.model_rgb(pixel_values=x_rgb)
        out_flow = self.model_flow(pixel_values=x_flow)
        combined = torch.cat((out_rgb.last_hidden_state[:, 0], out_flow.last_hidden_state[:, 0]), dim=1)
        out = self.sigmoid(self.classifier(combined))
        return out

In [13]:
class RGBStreamModel(torch.nn.Module):
    def __init__(self, model_rgb, input_dim, num_classes):
        super(RGBStreamModel, self).__init__()
        self.model_rgb = model_rgb
        self.classifier = torch.nn.Linear(input_dim , num_classes)
        self.sigmoid = torch.nn.Sigmoid()
    
    def forward(self, x_rgb):
        out_rgb = self.model_rgb(pixel_values=x_rgb).last_hidden_state[:, 0]
        out = self.sigmoid(self.classifier(out_rgb))
        return out

In [14]:
# Define model
input_dim = 768
num_classes = 1

model = RGBStreamModel(model_rgb,input_dim,num_classes)

# Training

In [15]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
print("Device : " + str(device))

Device : cuda


In [16]:
# Loss function
loss_fn = torch.nn.BCELoss()

# Optimizer
learning_rate = 5e-5
optimizer = AdamW(model.parameters(), lr=learning_rate)

num_epochs = 3
num_training_steps = num_epochs * len(train_loader)

lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

In [17]:
progress_bar = tqdm(range(num_training_steps))

model.train()
best_auc = 0

for epoch in range(num_epochs):
    for batch in train_loader:
        x_rgb, labels = batch
        x_rgb = x_rgb.to(device)
        labels = labels.to(device)
        
        outputs = model(x_rgb)
        loss = loss_fn(outputs.view(-1),labels)
        loss.backward()
        
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        
    model.eval()
    all_labels = []
    all_probs = []
    all_loss = 0

    with torch.inference_mode():
        for batch in test_loader:
            x_rgb,labels = batch
            x_rgb = x_rgb.to(device)
            labels = labels.to(device)
            
            outputs = model(x_rgb)
            loss = loss_fn(outputs.view(-1),labels)

            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(outputs.view(-1).cpu().numpy())
            all_loss += loss

    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)
    auc = roc_auc_score(all_labels, all_probs)

    if auc > best_auc:
        best_auc = auc
        torch.save(model.state_dict(),'best_model.pth')

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss:.4f}, Test AUC: {auc:.4f}")

    model.train()

  0%|          | 0/2415 [00:00<?, ?it/s]

Epoch [1/3], Loss: 0.1299, Test AUC: 0.9150
Epoch [2/3], Loss: 0.2205, Test AUC: 0.9336
Epoch [3/3], Loss: 0.0008, Test AUC: 0.9422
