<a href="https://colab.research.google.com/github/noahdrakes/mldl-final/blob/main/mm_violence_det_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-Modal Violence Detection Network

original src code: https://github.com/Roc-Ng/XDVioDet.git

### Copying Training and Testing Data

The folders are pretty large (~40/50GB) so it takes a while to copy all of the data over.


In [1]:
from google.colab import drive
drive.mount('/mydrive')

Mounted at /mydrive


In [2]:
%cd /mydrive/MyDrive

/mydrive/MyDrive


In [5]:
!unzip final_dl.zip -d /content/


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/final_dl/dl_files/i3d-features/Flow/v=vsI_pKNcgeQ__#00-03-00_00-06-00_label_A__3.npy  
  inflating: /content/final_dl/dl_files/i3d-features/Flow/v=vsI_pKNcgeQ__#00-03-00_00-06-00_label_A__1.npy  
  inflating: /content/final_dl/dl_files/i3d-features/Flow/v=vsI_pKNcgeQ__#00-03-00_00-06-00_label_A__4.npy  
  inflating: /content/final_dl/dl_files/i3d-features/Flow/v=vsI_pKNcgeQ__#00-06-00_00-09-00_label_A__2.npy  
  inflating: /content/final_dl/dl_files/i3d-features/Flow/v=vsI_pKNcgeQ__#00-06-00_00-09-00_label_A__4.npy  
  inflating: /content/final_dl/dl_files/i3d-features/Flow/v=vsI_pKNcgeQ__#00-06-00_00-09-00_label_A__1.npy  
  inflating: /content/final_dl/dl_files/i3d-features/Flow/v=vsI_pKNcgeQ__#00-06-00_00-09-00_label_A__0.npy  
  inflating: /content/final_dl/dl_files/i3d-features/Flow/v=vsI_pKNcgeQ__#00-06-00_00-09-00_label_A__3.npy  
  inflating: /content/final_dl/dl_files/i3d-features/Flow/v=vsI

You may need to change the directory below depending on where you uploaded the data in Google Drive.

In [None]:
# !cp -r /mydrive/MyDrive/final_dl ./

## 1. Methods


### A) Test


In [None]:
from sklearn.metrics import auc, precision_recall_curve
import numpy as np
import torch

def test(dataloader, model, device, gt):
    """Score every batch of `dataloader` with `model` and return two PR-AUC values.

    Args:
        dataloader: iterable yielding input tensors (moved to `device` here).
        model: callable as ``model(inputs=..., seq_len=None)`` returning two
            logit tensors; the second one is the "online detection" output.
        device: device the inputs and score buffers are placed on.
        gt: ground-truth labels, converted with ``list(gt)`` before scoring.

    Returns:
        ``(pr_auc, pr_auc2)``: area under the precision-recall curve for the
        first and the second (online) model output respectively.
    """
    with torch.no_grad():
        model.eval()
        offline_scores = torch.zeros(0).to(device)
        online_scores = torch.zeros(0).to(device)
        for batch in dataloader:
            batch = batch.to(device)
            offline_logits, online_logits = model(inputs=batch, seq_len=None)

            # First output: sigmoid scores averaged over dim 0 of the squeezed
            # logits (presumably the crops of one video -- TODO confirm).
            offline_prob = torch.sigmoid(torch.squeeze(offline_logits))
            offline_scores = torch.cat((offline_scores, torch.mean(offline_prob, 0)))

            # Online detection output: same reduction, then a trailing axis is
            # added before concatenation (marked "for audio" by the author).
            online_prob = torch.sigmoid(torch.squeeze(online_logits))
            online_prob = torch.unsqueeze(torch.mean(online_prob, 0), 1)
            online_scores = torch.cat((online_scores, online_prob))

        offline_scores = list(offline_scores.cpu().detach().numpy())
        online_scores = list(online_scores.cpu().detach().numpy())

        # Each score is repeated 16 times before comparison with `gt`
        # (presumably one score per 16 frames -- TODO confirm).
        precision, recall, _ = precision_recall_curve(list(gt), np.repeat(offline_scores, 16))
        pr_auc = auc(recall, precision)
        precision, recall, _ = precision_recall_curve(list(gt), np.repeat(online_scores, 16))
        pr_auc2 = auc(recall, precision)
        return pr_auc, pr_auc2




### B) Utils

In [None]:
# -*- coding: utf-8 -*-

import numpy as np


def random_extract(feat, t_max):
    """Return a random contiguous window of `t_max` rows from `feat`.

    The start offset is drawn uniformly from every valid position, including
    the last one (``len(feat) - t_max``), so a sequence whose length equals
    `t_max` is returned whole instead of raising.

    Args:
        feat: sequence/array indexed along its first axis.
        t_max: window length; must not exceed ``len(feat)``.

    Returns:
        ``feat[start:start + t_max]`` for the sampled start offset.

    Raises:
        ValueError: from ``np.random.randint`` when ``t_max > len(feat)``.
    """
    # randint's upper bound is exclusive, hence the +1 to reach the last window.
    start = np.random.randint(len(feat) - t_max + 1)
    return feat[start:start + t_max]

def uniform_extract(feat, t_max):
    """Pick `t_max` rows of `feat` at evenly spaced positions.

    Indices run from the first to the last row inclusive and are truncated to
    integers. A 64-bit index type is used so sequences longer than 65535 rows
    are indexed correctly (a 16-bit index would wrap around).

    Args:
        feat: 2-D array indexed as ``feat[rows, :]``.
        t_max: number of rows to select.

    Returns:
        Array with `t_max` rows taken from `feat`.
    """
    rows = np.linspace(0, len(feat) - 1, t_max, dtype=np.int64)
    return feat[rows, :]

def pad(feat, min_len):
    """Zero-pad `feat` along its first axis up to `min_len` rows.

    Inputs that already have more than `min_len` rows are returned unchanged
    (the very same object). The padding spec covers exactly two axes, so
    `feat` is expected to be 2-D.
    """
    n_rows = np.shape(feat)[0]
    if n_rows > min_len:
        return feat
    missing_rows = min_len - n_rows
    return np.pad(feat, ((0, missing_rows), (0, 0)), mode='constant', constant_values=0)

def process_feat(feat, length, is_random=True):
    """Bring `feat` to `length` rows by sub-sampling or zero-padding.

    Sequences longer than `length` are reduced with `random_extract` when
    `is_random` is true and with `uniform_extract` otherwise; all other
    sequences go through `pad`.
    """
    if len(feat) <= length:
        return pad(feat, length)
    extractor = random_extract if is_random else uniform_extract
    return extractor(feat, length)



### C) Dataset

In [None]:
import torch.utils.data as data
import numpy as np

# from utils import process_feat

class Dataset(data.Dataset):
    """Dataset of pre-extracted ``.npy`` features referenced by list files.

    Each list file holds one feature path per line. `args.modality` selects
    which lists are read and how features are combined:

    * ``'AUDIO'``, ``'RGB'``, ``'FLOW'`` -- a single stream.
    * ``'MIX'`` -- RGB + flow, ``'MIX2'`` -- RGB + audio,
      ``'MIX3'`` -- flow + audio, ``'MIX_ALL'`` -- RGB + flow + audio,
      concatenated along axis 1.

    Audio entries are looked up with ``index // 5`` (presumably five visual
    crops share one audio file -- TODO confirm against the list files).

    Args:
        args: object exposing ``modality``, ``max_seqlen`` and the list paths
            ``rgb_list``/``flow_list``/``audio_list`` (or their ``test_``
            counterparts when `test_mode` is true).
        transform: optional callable applied to the feature array.
        test_mode: when true, items are raw feature arrays; otherwise items
            are ``(features, label)`` with features resized to `max_seqlen`
            by `process_feat`.
    """

    def __init__(self, args, transform=None, test_mode=False):
        self.modality = args.modality

        if test_mode:
            self.rgb_list_file = args.test_rgb_list
            self.flow_list_file = args.test_flow_list
            self.audio_list_file = args.test_audio_list
        else:
            self.rgb_list_file = args.rgb_list
            self.flow_list_file = args.flow_list
            self.audio_list_file = args.audio_list
        self.max_seqlen = args.max_seqlen
        self.transform = transform
        # Historical misspelling, kept as an alias for code that reads it.
        self.tranform = transform
        self.test_mode = test_mode
        self.normal_flag = '_label_A'
        self._parse_list()

    @staticmethod
    def _read_lines(list_file):
        """Return the raw lines of `list_file`, closing the file afterwards."""
        with open(list_file) as handle:
            return list(handle)

    @staticmethod
    def _load(entry):
        """Load the feature file named by a list entry as a float32 array."""
        return np.array(np.load(entry.strip('\n')), dtype=np.float32)

    def _parse_list(self):
        """Read the list files required by the configured modality."""
        modality = self.modality
        # Primary list: drives __len__, indexing and label lookup.
        if modality == 'AUDIO':
            self.list = self._read_lines(self.audio_list_file)
        elif modality in ('RGB', 'MIX', 'MIX2', 'MIX_ALL'):
            self.list = self._read_lines(self.rgb_list_file)
        elif modality in ('FLOW', 'MIX3'):
            self.list = self._read_lines(self.flow_list_file)
        else:
            # Explicit raise (same exception type as before) so the check
            # survives `python -O`, which strips assert statements.
            raise AssertionError('Modality is wrong!')

        # Secondary lists for the fused modalities.
        if modality in ('MIX', 'MIX_ALL'):
            self.flow_list = self._read_lines(self.flow_list_file)
        if modality in ('MIX2', 'MIX3', 'MIX_ALL'):
            self.audio_list = self._read_lines(self.audio_list_file)

    def __getitem__(self, index):
        # A path containing the normal flag is labelled 0.0, anything else 1.0.
        label = 0.0 if self.normal_flag in self.list[index] else 1.0

        modality = self.modality
        if modality in ('AUDIO', 'RGB', 'FLOW'):
            features = self._load(self.list[index])
        elif modality == 'MIX':
            rgb = self._load(self.list[index])
            flow = self._load(self.flow_list[index])
            if rgb.shape[0] == flow.shape[0]:
                features = np.concatenate((rgb, flow), axis=1)
            else:  # because the frames of flow is one less than that of rgb
                features = np.concatenate((rgb[:-1], flow), axis=1)
        elif modality in ('MIX2', 'MIX3'):
            visual = self._load(self.list[index])
            audio = self._load(self.audio_list[index // 5])
            if visual.shape[0] == audio.shape[0]:
                features = np.concatenate((visual, audio), axis=1)
            else:  # on a length mismatch the last visual row is dropped
                features = np.concatenate((visual[:-1], audio), axis=1)
        elif modality == 'MIX_ALL':
            rgb = self._load(self.list[index])
            flow = self._load(self.flow_list[index])
            audio = self._load(self.audio_list[index // 5])
            if rgb.shape[0] == flow.shape[0]:
                features = np.concatenate((rgb, flow, audio), axis=1)
            else:  # because the frames of flow is one less than that of rgb
                features = np.concatenate((rgb[:-1], flow, audio[:-1]), axis=1)
        else:
            raise AssertionError('Modality is wrong!')

        if self.transform is not None:
            features = self.transform(features)
        if self.test_mode:
            return features

        features = process_feat(features, self.max_seqlen, is_random=False)
        return features, label

    def __len__(self):
        return len(self.list)

### D) Layers

In [None]:
from math import sqrt
from torch import FloatTensor
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.spatial.distance import pdist, squareform

class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each output node feature vector.
        dropout: dropout probability applied to the attention weights.
        alpha: negative slope of the LeakyReLU used on attention logits.
        concat: when true the output goes through ELU, otherwise it is
            returned as-is.

    The parameters are created on CUDA when it is available and on CPU
    otherwise, matching the previous behaviour.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        gain = float(np.sqrt(2.0))
        # In-place initialiser: the non-underscore `xavier_uniform` is deprecated.
        self.W = nn.Parameter(
            nn.init.xavier_uniform_(
                torch.empty(in_features, out_features, dtype=torch.float32, device=device), gain=gain),
            requires_grad=True)
        self.a = nn.Parameter(
            nn.init.xavier_uniform_(
                torch.empty(2 * out_features, 1, dtype=torch.float32, device=device), gain=gain),
            requires_grad=True)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        """Attend over the neighbours given by `adj` (entries > 0 are edges).

        `input` is a 2-D node-feature matrix (it is multiplied with `W` via
        ``torch.mm``); `adj` is compared element-wise against the N x N
        attention logits.
        """
        h = torch.mm(input, self.W)
        N = h.size()[0]

        # All ordered node pairs (i, j) as concatenated [h_i || h_j] rows.
        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

        # Non-edges get a very large negative logit so softmax gives them ~0 weight.
        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'

class linear(nn.Module):
    """Bias-free linear projection: ``x @ weight``.

    `weight` has shape ``(in_features, out_features)`` and is drawn uniformly
    from ``[-1/sqrt(out_features), 1/sqrt(out_features)]``. A ``bias``
    parameter slot is registered but left as ``None``.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.register_parameter('bias', None)
        self.weight = Parameter(FloatTensor(in_features, out_features))
        # The bound depends on the second weight dimension (out_features).
        bound = 1. / sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)

    def forward(self, x):
        return x.matmul(self.weight)

class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907

    Computes ``adj @ (input @ weight)`` (+ bias) and adds a residual term:

    * ``residual=False`` -- nothing is added;
    * ``in_features == out_features`` -- the input itself is added;
    * otherwise -- the input is projected by a ``Conv1d`` (kernel 5,
      padding 2) applied over axis 1 of the input.

    Args:
        in_features: size of the last input dimension.
        out_features: size of the last output dimension.
        bias: whether to add a learnable bias (initialised to 0.1).
        residual: whether to add the residual connection.
    """

    def __init__(self, in_features, out_features, bias=False, residual=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(FloatTensor(in_features, out_features))

        if bias:
            self.bias = Parameter(FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        if not residual:
            self.residual = lambda x: 0
        elif (in_features == out_features):
            self.residual = lambda x: x
        else:
            # self.residual = linear(in_features, out_features)
            self.residual = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=5, padding=2)

    def reset_parameters(self):
        """Xavier-initialise the weight and fill the bias (if any) with 0.1."""
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            self.bias.data.fill_(0.1)

    def forward(self, input, adj):
        # To support batch operations
        support = input.matmul(self.weight)
        output = adj.matmul(support)

        if self.bias is not None:
            output = output + self.bias

        # Dispatch on the residual's actual type. The previous test
        # (`in != out and self.residual`) was also true for the "no residual"
        # lambda and then crashed calling `.permute` on its integer result.
        if isinstance(self.residual, nn.Conv1d):
            # Conv1d convolves over the last axis, so swap axes 1 and 2
            # around the call.
            res = self.residual(input.permute(0, 2, 1)).permute(0, 2, 1)
            output = output + res
        else:
            output = output + self.residual(input)

        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

######################################################

class SimilarityAdj(Module):
    """Learned cosine-similarity adjacency over the rows of each sequence.

    The input is projected, pairwise cosine similarities are computed between
    all positions, values not exceeding 0.7 are zeroed and every row is
    soft-max normalised.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight0 = Parameter(FloatTensor(in_features, out_features))
        self.weight1 = Parameter(FloatTensor(in_features, out_features))
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise both projection matrices."""
        nn.init.xavier_uniform_(self.weight0)
        nn.init.xavier_uniform_(self.weight1)

    @staticmethod
    def _normalise(graph):
        """Zero entries <= 0.7, then soft-max along dim 1 of a 2-D graph."""
        return F.softmax(F.threshold(graph, 0.7, 0), dim=1)

    def forward(self, input, seq_len):
        """Return the adjacency for a batched `input` (indexed as B x T x C).

        When `seq_len` is given, only the leading ``seq_len[i]`` positions of
        sample ``i`` are filled; the rest of the output stays zero.
        """
        # NOTE(review): both projections use `weight0`; `weight1` is
        # initialised but never used here -- confirm whether `phi` was meant
        # to use `weight1`.
        theta = torch.matmul(input, self.weight0)
        phi = torch.matmul(input, self.weight0)

        dot = torch.matmul(theta, phi.permute(0, 2, 1))
        theta_norm = torch.norm(theta, p=2, dim=2, keepdim=True)  # B*T*1
        phi_norm = torch.norm(phi, p=2, dim=2, keepdim=True)  # B*T*1
        norm_prod = theta_norm.matmul(phi_norm.permute(0, 2, 1))
        # The epsilon guards against division by zero for all-zero rows.
        sim_graph = dot / (norm_prod + 1e-20)

        output = torch.zeros_like(sim_graph)
        if seq_len is None:
            for idx in range(sim_graph.shape[0]):
                output[idx] = self._normalise(sim_graph[idx])
            return output

        for idx in range(len(seq_len)):
            valid = seq_len[idx]
            output[idx, :valid, :valid] = self._normalise(sim_graph[idx, :valid, :valid])
        return output

    def __repr__(self):
        return '{} ({} -> {})'.format(self.__class__.__name__, self.in_features, self.out_features)

class DistanceAdj(Module):
    """Fixed adjacency that decays exponentially with temporal distance.

    Entry ``(i, j)`` equals ``exp(-|i - j| / e)``. The result is created on
    the device the module's parameter lives on, so the layer works on
    CPU-only hosts and follows ``module.to(device)``.
    """

    def __init__(self):
        super(DistanceAdj, self).__init__()
        # NOTE(review): `sigma` is registered (and initialised to 0.1) but the
        # decay below uses the constant e rather than this parameter.
        self.sigma = Parameter(FloatTensor(1))
        self.sigma.data.fill_(0.1)

    def forward(self, batch_size, max_seqlen):
        """Return a ``(batch_size, max_seqlen, max_seqlen)`` float32 tensor."""
        # To support batch operations
        device = self.sigma.device
        positions = torch.arange(max_seqlen, dtype=torch.float32, device=device)
        # Pairwise |i - j| distances between positions.
        dist = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))
        dist = torch.exp(-dist / torch.exp(torch.tensor(1., device=device)))
        # Kept on the instance because the original implementation exposed it.
        self.dist = torch.unsqueeze(dist, 0).repeat(batch_size, 1, 1)
        return self.dist

### E) Model

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as torch_init
import os
# from layers import GraphConvolution, SimilarityAdj, DistanceAdj


def weight_init(m):
    """Xavier-initialise `m.weight` for modules whose class name contains
    'Conv' or 'Linear'; other modules are left untouched. Biases are not
    modified."""
    name = type(m).__name__
    if 'Conv' in name or 'Linear' in name:
        torch_init.xavier_uniform_(m.weight)

class Model(nn.Module):
    """Graph-convolution network over a feature sequence with two outputs.

    `forward` returns ``(x, logits)``: `x` comes from three parallel
    two-layer GCN branches (feature-similarity, temporal-distance and
    score-similarity adjacency) followed by a linear classifier; `logits`
    comes from a small convolutional "approximator" head.

    `args` must provide ``feature_size`` (input channel count) and
    ``num_classes`` (classifier output size).
    """
    def __init__(self, args):
        super(Model, self).__init__()

        n_features = args.feature_size
        n_class = args.num_classes

        # Kernel-size-1 convolutions: per-position channel reduction
        # n_features -> 512 -> 128.
        self.conv1d1 = nn.Conv1d(in_channels=n_features, out_channels=512, kernel_size=1, padding=0)
        self.conv1d2 = nn.Conv1d(in_channels=512, out_channels=128, kernel_size=1, padding=0)
        # NOTE(review): conv1d3, conv1d4 and simAdj (below) are created but
        # not referenced in forward() of this class.
        self.conv1d3 = nn.Conv1d(in_channels=128, out_channels=32, kernel_size=5, padding=2)
        self.conv1d4 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
        # Graph Convolution
        # Pairs (gc1, gc2), (gc3, gc4), (gc5, gc6) form the three branches used in forward().
        self.gc1 = GraphConvolution(128, 32, residual=True)  # nn.Linear(128, 32)
        self.gc2 = GraphConvolution(32, 32, residual=True)
        self.gc3 = GraphConvolution(128, 32, residual=True)  # nn.Linear(128, 32)
        self.gc4 = GraphConvolution(32, 32, residual=True)
        self.gc5 = GraphConvolution(128, 32, residual=True)  # nn.Linear(128, 32)
        self.gc6 = GraphConvolution(32, 32, residual=True)
        self.simAdj = SimilarityAdj(n_features, 32)
        self.disAdj = DistanceAdj()

        # The classifier consumes the concatenation of the three 32-channel branches.
        self.classifier = nn.Linear(32*3, n_class)
        self.approximator = nn.Sequential(nn.Conv1d(128, 64, 1, padding=0), nn.ReLU(),
                                          nn.Conv1d(64, 32, 1, padding=0), nn.ReLU())
        # Kernel 5 with no padding; forward() left-pads by 4 so the length is preserved.
        self.conv1d_approximator = nn.Conv1d(32, 1, 5, padding=0)
        self.dropout = nn.Dropout(0.6)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        # Applies weight_init to every sub-module (and to this module itself).
        self.apply(weight_init)



    def forward(self, inputs, seq_len):
        """Run the network.

        Args:
            inputs: batched feature tensor; axes 1 and 2 are swapped before
                the convolutions, i.e. it is treated as batch x time x channel.
            seq_len: per-sample valid lengths, or None to use the full length.

        Returns:
            ``(x, logits)``: classifier output of the GCN branches and the
            approximator-head output.
        """
        x = inputs.permute(0, 2, 1)  # for conv1d
        x = self.relu(self.conv1d1(x))
        x = self.dropout(x)
        x = self.relu(self.conv1d2(x))
        x = self.dropout(x)

        logits = self.approximator(x)
        # Pad 4 steps on the left only, so the kernel-5 convolution keeps the
        # sequence length and each output sees the current and 4 previous steps.
        logits = F.pad(logits, (4, 0))
        logits = self.conv1d_approximator(logits)
        logits = logits.permute(0, 2, 1)
        x = x.permute(0, 2, 1)  # b*t*c

        ## gcn
        # detach(): the score-based adjacency does not back-propagate into the approximator head.
        scoadj = self.sadj(logits.detach(), seq_len)
        # Feature-similarity adjacency is computed on the raw inputs, not on x.
        adj = self.adj(inputs, seq_len)
        disadj = self.disAdj(x.shape[0], x.shape[1])
        x1_h = self.relu(self.gc1(x, adj))
        x1_h = self.dropout(x1_h)
        x2_h = self.relu(self.gc3(x, disadj))
        x2_h = self.dropout(x2_h)
        x3_h = self.relu(self.gc5(x, scoadj))
        x3_h = self.dropout(x3_h)
        x1 = self.relu(self.gc2(x1_h, adj))
        x1 = self.dropout(x1)
        x2 = self.relu(self.gc4(x2_h, disadj))
        x2 = self.dropout(x2)
        x3 = self.relu(self.gc6(x3_h, scoadj))
        x3 = self.dropout(x3)
        x = torch.cat((x1, x2, x3), 2)
        x = self.classifier(x)
        return x, logits

    def sadj(self, logits, seq_len):
        """Adjacency from score closeness.

        Pairwise closeness is ``1 - |s_i - s_j|`` on the sigmoid of `logits`,
        sharpened by a logistic function centred at 0.5 (scale 0.1), then
        soft-maxed along dim 1 of each sample's matrix. With `seq_len`, only
        the leading ``seq_len[i]`` positions are filled; the rest stays zero.
        """
        lens = logits.shape[1]
        soft = nn.Softmax(1)
        # Tile the single score channel so logits2[b, i, j] = s_i; its transpose gives s_j.
        logits2 = self.sigmoid(logits).repeat(1, 1, lens)
        tmp = logits2.permute(0, 2, 1)
        adj = 1. - torch.abs(logits2 - tmp)
        # NOTE(review): this lambda is re-assigned as an instance attribute on every call.
        self.sig = lambda x:1/(1+torch.exp(-((x-0.5))/0.1))
        adj = self.sig(adj)
        output = torch.zeros_like(adj)
        if seq_len is None:
            for i in range(logits.shape[0]):
                tmp = adj[i]
                adj2 = soft(tmp)
                output[i] = adj2
        else:
            for i in range(len(seq_len)):
                tmp = adj[i, :seq_len[i], :seq_len[i]]
                adj2 = soft(tmp)
                output[i, :seq_len[i], :seq_len[i]] = adj2
        return output


    def adj(self, x, seq_len):
        """Adjacency from cosine similarity between positions of `x`.

        Similarities not exceeding 0.7 are set to 0 before a soft-max along
        dim 1 of each sample's matrix. With `seq_len`, only the leading
        ``seq_len[i]`` positions are filled; the rest stays zero.
        """
        soft = nn.Softmax(1)
        x2 = x.matmul(x.permute(0,2,1)) # B*T*T
        x_norm = torch.norm(x, p=2, dim=2, keepdim=True)  # B*T*1
        x_norm_x = x_norm.matmul(x_norm.permute(0,2,1))
        # Epsilon avoids division by zero for all-zero feature rows.
        x2 = x2/(x_norm_x+1e-20)
        output = torch.zeros_like(x2)
        if seq_len is None:
            for i in range(x.shape[0]):
                tmp = x2[i]
                adj2 = tmp
                adj2 = F.threshold(adj2, 0.7, 0)
                adj2 = soft(adj2)
                output[i] = adj2
        else:
            for i in range(len(seq_len)):
                tmp = x2[i, :seq_len[i], :seq_len[i]]
                adj2 = tmp
                adj2 = F.threshold(adj2, 0.7, 0)
                adj2 = soft(adj2)
                output[i, :seq_len[i], :seq_len[i]] = adj2

        return output



In [None]:
import os
# Sanity check: show the kernel's current working directory.
print("Current directory:", os.getcwd())

Current directory: /mydrive/MyDrive


### E) Make List

In [None]:
import os
import glob

# Build a feature list file: every *.npy under `root_path`, with the entries
# whose name lacks '_label_A' written first and the '_label_A' entries last.
root_path = '/content/final_dl/list/xx/test'    ## the path of features
files = sorted(glob.glob(os.path.join(root_path, "*.npy")))

violents = []
normal = [path for path in files if '_label_A' in path]
with open('audio_test.list', 'w+') as f:  ## the name of feature list
    # Sorted order is kept inside each of the two groups.
    f.writelines(path + '\n' for path in files if '_label_A' not in path)
    f.writelines(path + '\n' for path in normal)

print(normal)

## F) Updating the List Files

The list files are used to specify the paths of all features (audio and visual) for testing and training. The features are just the latent video and audio information for training and testing (VGGish model for audio, i3d model for video)

In [None]:
# Rewrite the RGB test list in place: every non-empty entry keeps its file
# name but is re-rooted under `new_directory`.
input_list_file = "/content/final_dl/list/rgb_test.list"
new_directory = "/content/final_dl/dl_files/i3d-features/RGBTest"
output_list_file = "/content/final_dl/list/rgb_test.list"

with open(input_list_file, "r") as file:
    original_paths = file.readlines()

# Strip whitespace/newlines, skip blank lines, keep only the last path component.
stripped_entries = (line.strip() for line in original_paths)
updated_paths = [
    new_directory + "/" + entry.split("/")[-1]
    for entry in stripped_entries
    if entry
]

# Entries are newline-separated; no trailing newline is written.
with open(output_list_file, "w") as file:
    file.write("\n".join(updated_paths))

print(f"Updated paths have been written to {output_list_file}")


Updated paths have been written to /content/final_dl/list/rgb_test.list


In [None]:
# Rewrite the audio test list in place: every non-empty entry keeps its file
# name but is re-rooted under `new_directory`.
import os

input_list_file = "/content/final_dl/list/audio_test.list"

# Directory the rewritten entries should point to.
new_directory = "/content/final_dl/list/xx/test/"

output_list_file = "/content/final_dl/list/audio_test.list"

with open(input_list_file, "r") as file:
    original_paths = file.readlines()

# os.path.join copes with the trailing slash of `new_directory`; plain string
# formatting with an extra "/" produced a doubled separator ("test//name.npy").
stripped_entries = (line.strip() for line in original_paths)
updated_paths = [
    os.path.join(new_directory, os.path.basename(entry))
    for entry in stripped_entries
    if entry  # skip blank lines
]

# Entries are newline-separated; no trailing newline is written.
with open(output_list_file, "w") as file:
    file.write("\n".join(updated_paths))

print(f"Updated paths have been written to {output_list_file}")


Updated paths have been written to /content/final_dl/list/audio_test.list


## Args

Here are the default args that were obtained via cmd line arg parser. I just created a class 'Args' that holds the default config for the model.

I think the most important args:

*`Modality`*: Determines which feature streams are used for training: audio alone, video alone, audio and video together, audio + video + flow, etc.

*`List`*: point to the list containing filenames for all training and testing data.

*`workers`*: I believe this is the number of worker processes used to load data during training or testing. In the original model it was set to 4 by default, but that raised an error, so I lowered it to 1. This is probably a sign that we need to do heavy downsampling to compensate for the lack of parallel processing.

In [None]:
class Args:
    """Default run configuration, replacing the original command-line parser."""

    def __init__(self):
        # Which feature streams to use; see Dataset for the accepted values.
        self.modality = 'MIX2'

        # Training feature lists.
        self.rgb_list = '/content/final_dl/list/rgb.list'
        self.flow_list = '/content/final_dl/list/flow.list'
        self.audio_list = '/content/final_dl/list/audio.list'

        # Test feature lists and ground truth.
        self.test_rgb_list = '/content/final_dl/list/rgb_test.list'
        self.test_flow_list = '/content/final_dl/list/flow_test.list'
        self.test_audio_list = '/content/final_dl/list/audio_test.list'
        self.gt = '/content/final_dl/list/gt.npy'

        # Model settings.
        self.model_name = 'wsanodet'
        self.dataset_name = 'XD-Violence'
        self.pretrained_ckpt = None
        self.feature_size = 1152  # 1024 + 128
        self.num_classes = 1
        self.max_seqlen = 200

        # Optimisation and runtime settings.
        self.gpus = 1
        self.lr = 0.0001
        self.batch_size = 128
        self.workers = 1
        self.max_epoch = 50


# Create an instance of the Args class
args = Args()

## Potential Val Split

In [7]:
import os
import random
import numpy as np

def _video_key(path):
    """Return the identifier shared by all feature files of one video.

    The key is the file stem up to and including its ``_label_<X>`` token,
    with any trailing ``__<suffix>`` (e.g. a crop index) removed. Stems
    without a ``_label_`` token are used unchanged.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    head, sep, tail = stem.rpartition('_label_')
    if not sep:
        return stem
    return head + sep + tail.split('__', 1)[0]


def create_validation_split(train_list_file, val_ratio=0.2, seed=42):
    """
    Creates training and validation splits from a training list file.

    The split is made per video rather than per file: all feature files that
    share a video key (see `_video_key`) land in the same split, and both
    returned lists keep the order of the original list file. Because the
    choice depends only on the sorted video keys and `seed`, list files of
    different modalities that describe the same videos are split
    consistently, which keeps index-aligned lists aligned. The split is
    stratified over normal ('_label_A') and abnormal videos.

    Args:
        train_list_file (str): Path to original training list file
        val_ratio (float): Ratio of validation split (default: 0.2)
        seed (int): Random seed for reproducibility

    Returns:
        train_paths (list): List of training file paths
        val_paths (list): List of validation file paths
    """
    # Local generator: reproducible without reseeding the global `random` state.
    rng = random.Random(seed)

    # Read original training list, ignoring blank lines
    with open(train_list_file, 'r') as f:
        paths = [line.strip() for line in f]
    paths = [p for p in paths if p]

    # Map each video to its class (normal / abnormal)
    key_of = {p: _video_key(p) for p in paths}
    is_normal = {}
    for p in paths:
        is_normal.setdefault(key_of[p], '_label_A' in p)

    # Sorted so the sample does not depend on the order of the list file
    normal_keys = sorted(k for k, normal in is_normal.items() if normal)
    abnormal_keys = sorted(k for k, normal in is_normal.items() if not normal)

    # Calculate split sizes for each group and sample validation videos
    n_normal_val = int(len(normal_keys) * val_ratio)
    n_abnormal_val = int(len(abnormal_keys) * val_ratio)
    val_keys = set(rng.sample(normal_keys, n_normal_val))
    val_keys.update(rng.sample(abnormal_keys, n_abnormal_val))

    # Partition the files while preserving their original order
    train_paths = [p for p in paths if key_of[p] not in val_keys]
    val_paths = [p for p in paths if key_of[p] in val_keys]

    return train_paths, val_paths

def save_split_files(train_paths, val_paths, original_list_file):
    """
    Write the two splits next to the original list file.

    The output names are derived from the original file name up to its first
    dot: ``<name>_train.list`` and ``<name>_val.list``. Entries are joined
    with newlines and no trailing newline is written.

    Args:
        train_paths (list): List of training file paths
        val_paths (list): List of validation file paths
        original_list_file (str): Path to original list file

    Returns:
        (train_file, val_file): paths of the two files that were written
    """
    def _dump(target, entries):
        # One entry per line.
        with open(target, 'w') as f:
            f.write('\n'.join(entries))

    base_dir, list_name = os.path.split(original_list_file)
    base_name = list_name.partition('.')[0]

    train_file = os.path.join(base_dir, f'{base_name}_train.list')
    val_file = os.path.join(base_dir, f'{base_name}_val.list')

    _dump(train_file, train_paths)
    _dump(val_file, val_paths)

    return train_file, val_file

class Dataset(data.Dataset):
    """Dataset of pre-extracted ``.npy`` features with train/val/test splits.

    Each list file holds one feature path per line. `args.modality` selects
    which lists are read and how features are combined:

    * ``'AUDIO'``, ``'RGB'``, ``'FLOW'`` -- a single stream.
    * ``'MIX'`` -- RGB + flow, ``'MIX2'`` -- RGB + audio,
      ``'MIX3'`` -- flow + audio, ``'MIX_ALL'`` -- RGB + flow + audio,
      concatenated along axis 1.

    Audio entries are looked up with ``index // 5`` (presumably five visual
    crops share one audio file -- TODO confirm against the list files).
    """

    def __init__(self, args, transform=None, mode='train'):
        """
        Args:
            args: Arguments containing dataset paths and configuration
            transform: Optional transforms to apply
            mode: One of ['train', 'val', 'test'] to specify the dataset split
                (any value other than 'test' or 'val' selects the train lists)
        """
        self.modality = args.modality

        # Get appropriate list files based on mode
        if mode == 'test':
            self.rgb_list_file = args.test_rgb_list
            self.flow_list_file = args.test_flow_list
            self.audio_list_file = args.test_audio_list
        elif mode == 'val':
            self.rgb_list_file = args.val_rgb_list
            self.flow_list_file = args.val_flow_list
            self.audio_list_file = args.val_audio_list
        else: # train
            self.rgb_list_file = args.train_rgb_list
            self.flow_list_file = args.train_flow_list
            self.audio_list_file = args.train_audio_list

        self.max_seqlen = args.max_seqlen
        self.transform = transform
        # __getitem__ used to read the misspelled `tranform`, which was never
        # set and raised AttributeError; the alias keeps that name readable.
        self.tranform = transform
        self.test_mode = (mode == 'test')
        self.normal_flag = '_label_A'
        self._parse_list()

    @staticmethod
    def _read_lines(list_file):
        """Return the raw lines of `list_file`, closing the file afterwards."""
        with open(list_file) as handle:
            return list(handle)

    @staticmethod
    def _load(entry):
        """Load the feature file named by a list entry as a float32 array."""
        return np.array(np.load(entry.strip('\n')), dtype=np.float32)

    def _parse_list(self):
        """Read the list files required by the configured modality."""
        modality = self.modality
        # Primary list: drives __len__, indexing and label lookup.
        if modality == 'AUDIO':
            self.list = self._read_lines(self.audio_list_file)
        elif modality in ('RGB', 'MIX', 'MIX2', 'MIX_ALL'):
            self.list = self._read_lines(self.rgb_list_file)
        elif modality in ('FLOW', 'MIX3'):
            self.list = self._read_lines(self.flow_list_file)
        else:
            # Explicit raise (same exception type as before) so the check
            # survives `python -O`, which strips assert statements.
            raise AssertionError('Modality is wrong!')

        # Secondary lists for the fused modalities.
        if modality in ('MIX', 'MIX_ALL'):
            self.flow_list = self._read_lines(self.flow_list_file)
        if modality in ('MIX2', 'MIX3', 'MIX_ALL'):
            self.audio_list = self._read_lines(self.audio_list_file)

    def __getitem__(self, index):
        # A path containing the normal flag is labelled 0.0, anything else 1.0.
        label = 0.0 if self.normal_flag in self.list[index] else 1.0

        modality = self.modality
        if modality in ('AUDIO', 'RGB', 'FLOW'):
            features = self._load(self.list[index])
        elif modality == 'MIX':
            rgb = self._load(self.list[index])
            flow = self._load(self.flow_list[index])
            if rgb.shape[0] == flow.shape[0]:
                features = np.concatenate((rgb, flow), axis=1)
            else:  # because the frames of flow is one less than that of rgb
                features = np.concatenate((rgb[:-1], flow), axis=1)
        elif modality in ('MIX2', 'MIX3'):
            visual = self._load(self.list[index])
            audio = self._load(self.audio_list[index // 5])
            if visual.shape[0] == audio.shape[0]:
                features = np.concatenate((visual, audio), axis=1)
            else:  # on a length mismatch the last visual row is dropped
                features = np.concatenate((visual[:-1], audio), axis=1)
        elif modality == 'MIX_ALL':
            rgb = self._load(self.list[index])
            flow = self._load(self.flow_list[index])
            audio = self._load(self.audio_list[index // 5])
            if rgb.shape[0] == flow.shape[0]:
                features = np.concatenate((rgb, flow, audio), axis=1)
            else:  # because the frames of flow is one less than that of rgb
                features = np.concatenate((rgb[:-1], flow, audio[:-1]), axis=1)
        else:
            raise AssertionError('Modality is wrong!')

        if self.transform is not None:
            features = self.transform(features)
        if self.test_mode:
            return features

        # Train/val items are resized to max_seqlen and returned with their label.
        features = process_feat(features, self.max_seqlen, is_random=False)
        return features, label

    def __len__(self):
        return len(self.list)

class Args:
    """Run configuration for the train/validation/test split experiment.

    Plain attribute container (stands in for the argparse namespace of the
    original code base). The ``train_*``/``val_*`` paths are defaults that the
    split-generation loop below this class overwrites via ``setattr``.
    """

    def __init__(self):
        # Modality selector. 'MIX2' is the value the other Args config with
        # the same 1152-d feature size uses in this notebook.
        self.modality = 'MIX2'

        # Master (unsplit) feature lists
        self.rgb_list = '/content/final_dl/list/rgb.list'
        self.flow_list = '/content/final_dl/list/flow.list'
        self.audio_list = '/content/final_dl/list/audio.list'

        # Training split lists
        self.train_rgb_list = '/content/final_dl/list/rgb_train.list'
        self.train_flow_list = '/content/final_dl/list/flow_train.list'
        self.train_audio_list = '/content/final_dl/list/audio_train.list'

        # Validation split lists
        self.val_rgb_list = '/content/final_dl/list/rgb_val.list'
        self.val_flow_list = '/content/final_dl/list/flow_val.list'
        self.val_audio_list = '/content/final_dl/list/audio_val.list'

        # Test lists (same paths as the other Args configs in this notebook)
        self.test_rgb_list = '/content/final_dl/list/rgb_test.list'
        self.test_flow_list = '/content/final_dl/list/flow_test.list'
        self.test_audio_list = '/content/final_dl/list/audio_test.list'

        # Ground-truth array consumed by the evaluation routine
        self.gt = '/content/final_dl/list/gt.npy'

        # Optimisation / loader settings
        self.gpus = 1
        self.lr = 0.0001
        self.batch_size = 128
        self.workers = 1
        self.max_epoch = 50

        # Model settings
        self.model_name = 'wsanodet'
        self.pretrained_ckpt = None
        self.feature_size = 1152  # 1024 + 128
        self.num_classes = 1
        self.dataset_name = 'XD-Violence'
        self.max_seqlen = 200

  # Create an instance of the Args class
args = Args()

# For each modality: split its master list into train/validation parts, write
# both parts to disk and point the matching `args` attributes at the new files.
for modality in ('rgb', 'flow', 'audio'):
    list_file = getattr(args, f'{modality}_list')
    train_paths, val_paths = create_validation_split(list_file)
    train_file, val_file = save_split_files(train_paths, val_paths, list_file)
    for split_name, split_file in (('train', train_file), ('val', val_file)):
        setattr(args, f'{split_name}_{modality}_list', split_file)

# Options shared by all three loaders.
# NOTE(review): other cells build Dataset with `test_mode=...` rather than
# `mode=...` — confirm which signature the Dataset class in scope accepts.
shared_loader_opts = {'num_workers': args.workers, 'pin_memory': True}

train_loader = DataLoader(
    Dataset(args, mode='train'),
    batch_size=args.batch_size,
    shuffle=True,
    **shared_loader_opts,
)

val_loader = DataLoader(
    Dataset(args, mode='val'),
    batch_size=args.batch_size,
    shuffle=False,
    **shared_loader_opts,
)

test_loader = DataLoader(
    Dataset(args, mode='test'),
    batch_size=5,
    shuffle=False,
    **shared_loader_opts,
)

## Testing the Pretrained Model

In [None]:
from torch.utils.data import DataLoader
import torch
import numpy as np
# from model import Model
# from dataset import Dataset
# from test import test
# import option
import time

if __name__ == '__main__':

  # Evaluate the released checkpoint on the test list and report PR-AUC.
  # Fall back to the CPU when no GPU runtime is attached instead of failing.
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  test_loader = DataLoader(Dataset(args, test_mode=True),
                            batch_size=5, shuffle=False,
                            num_workers=args.workers, pin_memory=True)
  model = Model(args)
  model = model.to(device)

  # map_location lets a checkpoint saved on a GPU load on whichever device is in use.
  # NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
  checkpoint = torch.load('final_dl/wsanodet_mix2.pkl', map_location=device)
  # Keys carry a 'module.' prefix in the checkpoint; strip it so they match this model.
  model_dict = model.load_state_dict(
      {k.replace('module.', ''): v for k, v in checkpoint.items()})

  gt = np.load(args.gt)
  st = time.time()
  pr_auc, pr_auc_online = test(test_loader, model, device, gt)
  print('Time:{}'.format(time.time()-st))
  print('offline pr_auc:{0:.4}; online pr_auc:{1:.4}\n'.format(pr_auc, pr_auc_online))


  {k.replace('module.', ''): v for k, v in torch.load('final_dl/wsanodet_mix2.pkl').items()})


Time:12.470504760742188
offline pr_auc:0.79; online pr_auc:0.7433



How to save the model's weights for later use:

In [None]:
# torch.save(model.state_dict(), "/content/test.pkl")

# Training HL-Net


In [None]:
import torch


def CLAS(logits, label, seq_len, criterion, device, is_topk=True):
    """Video-level classification loss computed from per-snippet logits.

    Each video's snippet logits (restricted to its first ``seq_len[i]``
    entries) are pooled into one score, passed through a sigmoid and compared
    with ``label`` using ``criterion``.

    Args:
        logits: per-snippet logits, indexed as ``logits[i][:seq_len[i]]``;
            a trailing singleton axis on a 3-D tensor is removed first.
        label: per-video targets handed to ``criterion`` unchanged.
        seq_len: per-video count of valid snippets.
        criterion: callable ``criterion(probabilities, label)``.
        device: device on which the pooled scores are collected.
        is_topk: if True, average the ``seq_len[i] // 16 + 1`` largest
            logits; otherwise average all valid logits.

    Returns:
        Whatever ``criterion`` returns for the pooled probabilities.
    """
    # Remove only the trailing class axis. A bare ``squeeze()`` would also
    # drop the batch axis when the batch holds a single video, after which
    # the per-video indexing below fails.
    if logits.dim() == 3:
        logits = logits.squeeze(-1)

    # Collect the pooled scores and concatenate once instead of growing a
    # tensor inside the loop; the empty seed keeps the empty-batch result.
    video_scores = [torch.zeros(0).to(device)]
    for i in range(logits.shape[0]):
        valid = logits[i][:seq_len[i]]
        if is_topk:
            k = int(seq_len[i]) // 16 + 1
            top_vals, _ = torch.topk(valid, k=k, largest=True)
            pooled = torch.mean(top_vals).view(1)
        else:
            pooled = torch.mean(valid).view(1)
        video_scores.append(pooled)

    instance_logits = torch.sigmoid(torch.cat(video_scores))

    clsloss = criterion(instance_logits, label)
    return clsloss


def CENTROPY(logits, logits2, seq_len, device):
    """Cross-entropy-style consistency term between two sets of logits.

    For every sample ``i`` the first ``seq_len[i]`` entries of both inputs are
    used: ``sigmoid(logits)`` acts as a fixed (detached) target and
    ``log(sigmoid(logits2))`` as the prediction. The per-sample means of
    ``-target * log_prediction`` are averaged over the batch.

    Args:
        logits: tensor providing the detached targets; first axis is the batch.
        logits2: tensor providing the predictions, same layout as ``logits``.
        seq_len: per-sample number of leading entries to use.
        device: device on which the running sum is created.

    Returns:
        Scalar tensor with the batch-averaged loss.
    """
    total = torch.tensor(0.0, device=device)
    for i in range(logits.shape[0]):
        target = torch.sigmoid(logits[i, :seq_len[i]]).squeeze()
        # logsigmoid is the numerically stable form of log(sigmoid(x)):
        # it stays finite where sigmoid underflows to 0 and log would give -inf.
        log_pred = torch.nn.functional.logsigmoid(logits2[i, :seq_len[i]]).squeeze()
        total = total + torch.mean(-target.detach() * log_pred)
    return total / logits.shape[0]


def train(dataloader, model, optimizer, criterion, device, is_topk):
    """Run one optimisation pass over ``dataloader``.

    For every ``(features, label)`` batch:
      * the per-sample sequence length is the number of snippets with at
        least one non-zero feature, and the sequence axis is cut to the
        largest such length in the batch;
      * the model's two logit outputs are each scored with ``CLAS`` and tied
        together with ``CENTROPY`` (weighted by 5);
      * one optimizer step is taken on the summed loss.

    Nothing is returned.
    """
    model.train()
    with torch.enable_grad():
        for feats, label in dataloader:
            # A snippet is counted when any of its features is non-zero
            has_signal = torch.max(torch.abs(feats), dim=2)[0] > 0
            seq_len = torch.sum(has_signal, 1)

            feats = feats[:, :torch.max(seq_len), :]
            feats = feats.float().to(device)
            label = label.float().to(device)

            logits, logits2 = model(feats, seq_len)
            clsloss = CLAS(logits, label, seq_len, criterion, device, is_topk)
            clsloss2 = CLAS(logits2, label, seq_len, criterion, device, is_topk)
            croloss = CENTROPY(logits, logits2, seq_len, device)
            total_loss = clsloss + clsloss2 + 5*croloss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

In [None]:
# Re-root every entry of the audio list under `new_directory`.
# Input and output are the same file, so the list is rewritten in place.
input_list_file = "/content/final_dl/list/audio.list"
new_directory = "/content/final_dl/list/xx/train"
output_list_file = "/content/final_dl/list/audio.list"

with open(input_list_file, "r") as file:
    original_paths = file.readlines()

# Drop blank lines, keep only each entry's file name and prefix the new directory
non_blank = [entry for entry in (line.strip() for line in original_paths) if entry]
updated_paths = [new_directory + "/" + entry.rsplit("/", 1)[-1] for entry in non_blank]

with open(output_list_file, "w") as file:
    file.write("\n".join(updated_paths))

print(f"Updated paths have been written to {output_list_file}")

Updated paths have been written to /content/final_dl/list/audio.list


In [None]:
# Re-root every entry of the RGB list under `new_directory`.
# Input and output are the same file, so the list is rewritten in place.
input_list_file = "/content/final_dl/list/rgb.list"
new_directory = "/content/final_dl/dl_files/i3d-features/RGB"
output_list_file = "/content/final_dl/list/rgb.list"

with open(input_list_file, "r") as file:
    original_paths = file.readlines()

# Drop blank lines, keep only each entry's file name and prefix the new directory
non_blank = [entry for entry in (line.strip() for line in original_paths) if entry]
updated_paths = [new_directory + "/" + entry.rsplit("/", 1)[-1] for entry in non_blank]

with open(output_list_file, "w") as file:
    file.write("\n".join(updated_paths))

print(f"Updated paths have been written to {output_list_file}")

Updated paths have been written to /content/final_dl/list/rgb.list


## test (ignore)

In [None]:
class Args:
    """Settings for the 'MIX2' smoke run: 1152-d features, batches of 128."""

    def __init__(self):
        # Feature selection and shape
        self.modality = 'MIX2'
        self.feature_size = 1024 + 128
        self.max_seqlen = 200
        self.num_classes = 1

        # Training feature lists
        self.rgb_list = '/content/final_dl/list/rgb.list'
        self.flow_list = '/content/final_dl/list/flow.list'
        self.audio_list = '/content/final_dl/list/audio.list'

        # Test feature lists and ground truth
        self.test_rgb_list = '/content/final_dl/list/rgb_test.list'
        self.test_flow_list = '/content/final_dl/list/flow_test.list'
        self.test_audio_list = '/content/final_dl/list/audio_test.list'
        self.gt = '/content/final_dl/list/gt.npy'

        # Optimisation and data loading
        self.lr = 0.0001
        self.batch_size = 128
        self.max_epoch = 50
        self.workers = 1
        self.gpus = 1

        # Bookkeeping
        self.model_name = 'wsanodet'
        self.dataset_name = 'XD-Violence'
        self.pretrained_ckpt = None

  # Create an instance of the Args class
args = Args()

# Shape smoke test: push three training-mode batches through the model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model(args)
# Follow the device chosen above; `.cuda()` would fail on a CPU-only runtime
# even though `device` already falls back to the CPU.
model = model.to(device)

# NOTE: despite its name this loader is built with test_mode=False and shuffling on.
test_loader = DataLoader(Dataset(args, test_mode=False),
                          batch_size=5, shuffle=True,
                          num_workers=args.workers, pin_memory=True)

with torch.no_grad():
  # `feats` rather than `input`, so the builtin input() is not shadowed in the kernel
  for i, (feats, label) in enumerate(test_loader):
    feats = feats.to(device)

    print(feats.shape)
    ############
    ### NOTE: ## setting seq_len to None pads training data in the sequence dim to 200
    ############
    logits, logits2 = model(inputs=feats, seq_len=None)
    # print(logits, logits2)
    if i == 2:
      break

torch.Size([5, 200, 1152])
torch.Size([5, 200, 1152])
torch.Size([5, 200, 1152])


## Training HL-Net

In [None]:
from torch.utils.data import DataLoader
import torch.optim as optim
import torch
import time
import numpy as np
import random
import os
# import option


def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner can select different kernels from run to run;
    # switch it off so `deterministic` above is not undermined.
    torch.backends.cudnn.benchmark = False


# torch.multiprocessing.set_start_method('spawn')
# setup_seed(2333)
# args = option.parser.parse_args()

# `!export ...` runs in a throw-away subshell and never reaches this kernel,
# so the variable is set in the kernel's own environment instead.
# NOTE(review): PyTorch's error text describes TORCH_USE_CUDA_DSA as a
# compile-time option — confirm that setting it at runtime has any effect.
os.environ['TORCH_USE_CUDA_DSA'] = 'ON'

# Single device selection with CPU fallback, used for the model and evaluation.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = DataLoader(Dataset(args, test_mode=False),
                          batch_size=args.batch_size, shuffle=True,
                          num_workers=args.workers, pin_memory=True)
test_loader = DataLoader(Dataset(args, test_mode=True),
                          batch_size=5, shuffle=False,
                          num_workers=args.workers, pin_memory=True)


model = Model(args)
model = model.to(device)

for name, value in model.named_parameters():
    print(name)

# The two approximator sub-modules train at half the base learning rate;
# every other parameter goes into the default group.
approximator_param = list(map(id, model.approximator.parameters()))
approximator_param += list(map(id, model.conv1d_approximator.parameters()))
base_param = filter(lambda p: id(p) not in approximator_param, model.parameters())

os.makedirs('./ckpt', exist_ok=True)
optimizer = optim.Adam([{'params': base_param},
                        {'params': model.approximator.parameters(), 'lr': args.lr / 2},
                        {'params': model.conv1d_approximator.parameters(), 'lr': args.lr / 2},
                        ],
                        lr=args.lr, weight_decay=0.000)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1)
criterion = torch.nn.BCELoss()

is_topk = True
gt = np.load(args.gt)
pr_auc, pr_auc_online = test(test_loader, model, device, gt)
print('Random initalization: offline pr_auc:{0:.4}; online pr_auc:{1:.4}\n'.format(pr_auc, pr_auc_online))
for epoch in range(args.max_epoch):
    st = time.time()
    train(train_loader, model, optimizer, criterion, device, is_topk)
    # Advance the LR schedule only after this epoch's optimizer updates;
    # stepping first skips the schedule's initial value and shifts the
    # milestone one epoch early.
    scheduler.step()
    # Checkpoint every second epoch (epoch 0 excluded)
    if epoch % 2 == 0 and not epoch == 0:
        torch.save(model.state_dict(), './ckpt/'+args.model_name+'{}.pkl'.format(epoch))

    pr_auc, pr_auc_online = test(test_loader, model, device, gt)
    print('Epoch {0}/{1}: offline pr_auc:{2:.4}; online pr_auc:{3:.4}\n'.format(epoch, args.max_epoch, pr_auc, pr_auc_online))
torch.save(model.state_dict(), './ckpt/' + args.model_name + '.pkl')

# Training VAE



## VAE MODEL

In [None]:
import torch
from torch import nn

def _downsampled_len(seq_len, num_layers=3):
    """Sequence length left after ``num_layers`` Conv1d(k=3, stride=2, padding=1) layers."""
    length = seq_len
    for _ in range(num_layers):
        length = (length - 1) // 2 + 1
    return length


class Sampling(nn.Module):
    """Reparameterisation step: ``z = mean + exp(0.5 * log_var) * eps`` with ``eps ~ N(0, 1)``."""

    def __init__(self):
        super().__init__()

    def forward(self, z_means, z_log_vars):
        epsilon = torch.randn_like(z_means, dtype=torch.float32)
        return z_means + torch.exp(0.5 * z_log_vars) * epsilon


class Encoder(nn.Module):
    """1-D convolutional encoder mapping a feature sequence to latent statistics.

    Args:
        latent_dim: size of the latent vector.
        input_dim: feature size of every time step (conv input channels).
        seq_len: number of time steps of the input sequences.

    ``forward`` takes ``x`` of shape (batch, seq_len, input_dim) and returns
    ``(z, z_means, z_log_vars)``, each of shape (batch, latent_dim).
    """

    def __init__(self, latent_dim, input_dim=1024, seq_len=200):
        super().__init__()
        self.latent_dim = latent_dim

        # Every stride-2 layer halves the time axis (rounding up);
        # for seq_len=200 this gives 200 -> 100 -> 50 -> 25.
        self.encoder = nn.Sequential(
            nn.Conv1d(input_dim, 512, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv1d(512, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv1d(256, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Flatten()  # (B, 128, L) -> (B, 128 * L) for the linear heads
        )

        # Derived from seq_len instead of a hard-coded 25 * 128
        flattened_dim = _downsampled_len(seq_len) * 128
        self.lin_mean = nn.Linear(flattened_dim, latent_dim)
        self.lin_log_var = nn.Linear(flattened_dim, latent_dim)
        self.sampling = Sampling()

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B, T, C) -> (B, C, T), as Conv1d expects
        x = self.encoder(x)
        z_means = self.lin_mean(x)
        z_log_vars = self.lin_log_var(x)
        z = self.sampling(z_means, z_log_vars)
        return z, z_means, z_log_vars


class Decoder(nn.Module):
    """Transposed-convolution decoder mirroring ``Encoder``.

    Args:
        latent_dim: size of the latent vector.
        input_dim: number of output channels (feature size per time step).
        seq_len: number of time steps to produce.

    ``forward`` takes a latent batch (batch, latent_dim) and returns a tensor
    of shape (batch, input_dim, seq_len) — channels first. Values lie in
    (0, 1) because of the final Sigmoid.
    """

    def __init__(self, latent_dim, input_dim=1024, seq_len=200):
        super().__init__()
        self.seq_len = seq_len
        # Must equal the time length the Encoder flattens (25 for seq_len=200)
        self.hidden_len = _downsampled_len(seq_len)
        flattened_dim = self.hidden_len * 128

        self.decoder_fc = nn.Sequential(
            nn.Linear(latent_dim, flattened_dim),
            nn.ReLU(True)
        )

        # Every layer doubles the time axis; for seq_len=200: 25 -> 50 -> 100 -> 200
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose1d(128, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(True),
            nn.ConvTranspose1d(256, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(True),
            nn.ConvTranspose1d(512, input_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.decoder_fc(x)
        x = x.view(-1, 128, self.hidden_len)  # back to (B, C, L) for ConvTranspose1d
        x = self.decoder_conv(x)
        # Doubling three times can overshoot when seq_len is not a multiple
        # of 8; trim the surplus steps (no-op for seq_len=200).
        return x[..., :self.seq_len]


class VAE(nn.Module):
    """Variational auto-encoder over feature sequences.

    ``forward`` takes ``x`` of shape (batch, seq_len, input_dim) and returns
    ``(x_reconstructed, z_means, z_log_vars)`` where ``x_reconstructed`` has
    the same (batch, seq_len, input_dim) layout as the input.
    """

    def __init__(self, latent_dim, input_dim=1024, seq_len=200):
        super().__init__()
        self.encoder = Encoder(latent_dim, input_dim, seq_len)
        self.decoder = Decoder(latent_dim, input_dim, seq_len)

    def forward(self, x):
        z, z_means, z_log_vars = self.encoder(x)
        x_reconstructed = self.decoder(z)
        # The decoder emits (B, C, T). Swap the axes to get (B, T, C): a
        # `view` to that shape would keep the memory order and pair every
        # reconstructed value with the wrong time step / channel of the input.
        x_reconstructed = x_reconstructed.permute(0, 2, 1).contiguous()
        return x_reconstructed, z_means, z_log_vars

In [None]:
class Args:
    """Settings for the VAE experiments: 'RGB' modality, 1024-d features, batch size 1."""

    def __init__(self):
        # Feature selection and shape (1024 here; the MIX2 config uses 1152 = 1024 + 128)
        self.modality = 'RGB'
        self.feature_size = 1024
        self.max_seqlen = 200
        self.num_classes = 1

        # Training feature lists
        self.rgb_list = '/content/final_dl/list/rgb.list'
        self.flow_list = '/content/final_dl/list/flow.list'
        self.audio_list = '/content/final_dl/list/audio.list'

        # Test feature lists and ground truth
        self.test_rgb_list = '/content/final_dl/list/rgb_test.list'
        self.test_flow_list = '/content/final_dl/list/flow_test.list'
        self.test_audio_list = '/content/final_dl/list/audio_test.list'
        self.gt = '/content/final_dl/list/gt.npy'

        # Optimisation and data loading
        self.lr = 0.0001
        self.batch_size = 1
        self.max_epoch = 50
        self.workers = 1
        self.gpus = 1

        # Bookkeeping
        self.model_name = 'wsanodet'
        self.dataset_name = 'XD-Violence'
        self.pretrained_ckpt = None


# Instance shared by the VAE cells that follow
args = Args()



## VAE Train FN

In [None]:
# Shape check: send one training sample through an untrained VAE and print
# the input, label and reconstruction shapes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = DataLoader(Dataset(args, test_mode=False),
                          batch_size=args.batch_size, shuffle=True,
                          num_workers=args.workers, pin_memory=True)

model = Model(args)
# Follow `device` so the cell also runs on a CPU-only runtime
model = model.to(device)

# Build the VAE once, outside the batch loop
vae = VAE(latent_dim=128, input_dim=1024, seq_len=200)
vae = vae.to(device)

with torch.set_grad_enabled(True):
    model.train()
    # `batch_feats` rather than `input`, so the builtin input() is not shadowed
    for i, (batch_feats, label) in enumerate(train_loader):

        batch_feats = batch_feats.float().to(device)
        label = label.float().to(device)

        # input and label from train dataset
        print("shape of input: ", batch_feats.shape)
        print("shape of label: ", label.shape)

        # one sample from the batch, batch axis kept
        test_input = batch_feats[:1]

        # passing through vae (its encoder permutes to channels-first itself)
        x_reconstructed, z_means, z_log_vars = vae(test_input)

        print("reconstructed input shape (after vae): ", x_reconstructed.shape)

        # only the first batch is needed to read out the shapes
        if i == 0:
            break

here
shape of input:  torch.Size([1, 200, 1024])
shape of label:  torch.Size([1])
reconstructed input shape (after vae):  torch.Size([1, 200, 1024])


## VAE Training func (Gus)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

def vae_train_loop(vae, train_loader, optimizer, device, num_epochs=50):
    """Train ``vae`` with a summed-MSE reconstruction loss plus the KL term.

    Args:
        vae: module whose call returns ``(reconstruction, mu, logvar)``.
        train_loader: iterable of ``(data, label)`` batches exposing a
            ``dataset`` attribute; labels are ignored.
        optimizer: optimizer over ``vae``'s parameters.
        device: device the batches are moved to.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        List with one ``(avg_loss, avg_recon, avg_kl)`` tuple per epoch, each
        averaged over ``len(train_loader.dataset)``. The same numbers are
        printed after every epoch.
    """
    vae.train()
    # Loop-invariant: build the criterion once instead of once per batch
    recon_criterion = nn.MSELoss(reduction='sum')
    # Guard the per-sample averages against an empty dataset
    num_samples = max(len(train_loader.dataset), 1)
    history = []

    for epoch in range(num_epochs):
        total_loss = 0.0
        recon_loss = 0.0
        kl_loss = 0.0

        for data, _ in train_loader:
            # Cast like train() does, so non-float32 batches do not clash
            # with the float32 weights.
            data = data.float().to(device)
            optimizer.zero_grad()
            recon_data, mu, logvar = vae(data)
            r_loss = recon_criterion(recon_data, data)
            # KL divergence between N(mu, exp(logvar)) and N(0, 1), summed
            kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss = r_loss + kl
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            recon_loss += r_loss.item()
            kl_loss += kl.item()

        avg_loss = total_loss / num_samples
        avg_recon = recon_loss / num_samples
        avg_kl = kl_loss / num_samples
        history.append((avg_loss, avg_recon, avg_kl))

        print(f'Epoch {epoch}: Loss = {avg_loss:.4f}, Recon = {avg_recon:.4f}, KL = {avg_kl:.4f}')

    return history

def extract_vae_features(vae, data_loader, device):
    """
    Extract latent features using trained VAE

    Args:
        vae: module with an ``encoder`` whose call returns
            ``(z, z_means, z_log_vars)``.
        data_loader: iterable of ``(data, label)`` batches.
        device: device the batches are moved to for encoding.

    Returns:
        ``(features, labels)``: the encoder's latent means for every sample
        (on the CPU) and the labels, both concatenated along the first axis.
        Two empty tensors are returned when ``data_loader`` yields nothing.
    """
    vae.eval()
    features = []
    labels = []

    with torch.no_grad():
        for data, label in data_loader:
            data = data.float().to(device)
            # Use the posterior mean rather than the sampled z: the sample
            # carries fresh random noise on every call (also in eval mode),
            # which would make the extracted features non-reproducible.
            _, z_means, _ = vae.encoder(data)
            features.append(z_means.cpu())
            labels.append(label)

    if not features:
        return torch.empty(0), torch.empty(0)
    return torch.cat(features), torch.cat(labels)

class FeatureDataset(torch.utils.data.Dataset):
    """Pairs pre-computed feature rows with their labels.

    Indexing returns ``(features[idx], label)`` where the label is converted
    to a scalar float tensor.
    """

    def __init__(self, features, labels):
        # Row i of `features` belongs to entry i of `labels`
        self.features = features
        self.labels = labels

    def __len__(self):
        num_rows = len(self.features)
        return num_rows

    def __getitem__(self, idx):
        label_value = float(self.labels[idx])
        return self.features[idx], torch.tensor(label_value)

## Integrated Training pipeline

In [None]:


def train_pipeline(vae, hlnet, train_loader, args, device):
    """
    Two-stage pipeline: fit the VAE first, then train HL-Net on its features.

    Stage 1 trains ``vae`` on ``train_loader`` with Adam (lr 1e-3).
    Stage 2 extracts features with the trained VAE, wraps them with their
    labels in a ``FeatureDataset`` and trains ``hlnet`` on that for
    ``args.max_epoch`` epochs with Adam (lr ``args.lr``) and BCE loss.

    Returns the pair ``(vae, hlnet)``.

    NOTE(review): ``extract_vae_features`` appears to yield one latent vector
    per sample, whereas ``train`` reduces over ``dim=2`` of each batch —
    confirm that both stages agree on the tensor rank (and that ``hlnet`` was
    built for the latent feature size) before relying on this pipeline.
    """
    # Stage 1: VAE
    vae_opt = optim.Adam(vae.parameters(), lr=1e-3)
    print("Training VAE...")
    vae_train_loop(vae, train_loader, vae_opt, device)

    # Stage 2a: turn the training set into VAE features
    print("Extracting VAE features...")
    features, labels = extract_vae_features(vae, train_loader, device)
    feature_loader = DataLoader(
        FeatureDataset(features, labels),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
    )

    # Stage 2b: HL-Net on the extracted features
    print("Training HL-Net...")
    hlnet_opt = optim.Adam(hlnet.parameters(), lr=args.lr)
    bce = nn.BCELoss()

    for epoch in range(args.max_epoch):
        train(feature_loader, hlnet, hlnet_opt, bce, device, True)

        # Validation could be added here

        if not epoch % 5:
            print(f"Completed epoch {epoch}")

    return vae, hlnet


## Example run

In [None]:
# Build both networks on the active device, then run the two-stage pipeline.
vae = VAE(latent_dim=128, input_dim=1024, seq_len=200)
vae = vae.to(device)
hlnet = Model(args)
hlnet = hlnet.to(device)
trained_vae, trained_hlnet = train_pipeline(vae, hlnet, train_loader, args, device)