<a href="https://colab.research.google.com/github/noahdrakes/mldl-final/blob/main/mm_violence_det_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-Modal Violence Detection Network

Original source code: https://github.com/Roc-Ng/XDVioDet.git

### Copying Training and Testing Data

The feature folders are large (roughly 40 to 50 GB), so copying all of the data takes a while.

Download the data from https://roc-ng.github.io/XD-Violence/ (under "V1.0 Features"), then upload it to your Google Drive.

In [1]:
from google.colab import drive
drive.mount('/mydrive', force_remount=True)

Mounted at /mydrive


In [2]:
%cd /mydrive/MyDrive

/mydrive/MyDrive


In [None]:
!unzip final_dl.zip -d /content/


You may need to change the directory, depending on where you uploaded the data in Google Drive.

In [None]:
# !cp -r /mydrive/MyDrive/final_dl ./

# Pre-Flight

Utilities, classes, helpers, and data splitters to run before selecting a workflow.

### Utils

In [13]:
# -*- coding: utf-8 -*-

import numpy as np


def random_extract(feat, t_max):
    # Take a random contiguous window of t_max rows
    r = np.random.randint(len(feat)-t_max)
    return feat[r:r+t_max]

def uniform_extract(feat, t_max):
    # Take t_max rows at evenly spaced positions
    r = np.linspace(0, len(feat)-1, t_max, dtype=np.uint16)
    return feat[r, :]

def pad(feat, min_len):
    # Zero-pad along the time axis up to min_len rows
    if np.shape(feat)[0] <= min_len:
        return np.pad(feat, ((0, min_len-np.shape(feat)[0]), (0, 0)), mode='constant', constant_values=0)
    else:
        return feat

def process_feat(feat, length, is_random=True):
    if len(feat) > length:
        if is_random:
            return random_extract(feat, length)
        else:
            return uniform_extract(feat, length)
    else:
        return pad(feat, length)
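
As a quick sanity check of the length handling, the sketch below mirrors the crop-or-pad behaviour on made-up inputs. `fit_length` is a hypothetical stand-in written only for this example (it is not part of the original code), and the array shapes are assumptions: a sequence shorter than the target is zero-padded at the end, while a longer one is cropped or uniformly subsampled.

```python
import numpy as np

def fit_length(feat, length, is_random=True):
    """Hypothetical stand-in for process_feat: crop or zero-pad to `length` rows."""
    n = len(feat)
    if n > length:
        if is_random:
            # Random contiguous window of `length` rows
            start = np.random.randint(n - length)
            return feat[start:start + length]
        # Evenly spaced rows across the whole sequence
        idx = np.linspace(0, n - 1, length, dtype=np.int64)
        return feat[idx]
    # Zero-pad at the end of the time axis
    return np.pad(feat, ((0, length - n), (0, 0)), mode="constant")

short = np.ones((120, 1152), dtype=np.float32)   # made-up short sequence
long_ = np.ones((350, 1152), dtype=np.float32)   # made-up long sequence

short_out = fit_length(short, 200)
long_out = fit_length(long_, 200, is_random=False)
print(short_out.shape, long_out.shape)
```

Both outputs end up with 200 rows, which is what lets the data loader stack samples of different durations into one batch.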



## Args

These are the default args that the original code obtained via a command-line argument parser. I created a class `Args` that holds the default config for the model.

I think the most important args are:

*`modality`*: Determines which features are used for training: audio alone, video alone, audio and video together, or other combinations such as audio, video, and flow.

*`List`* arguments (`rgb_list`, `flow_list`, `audio_list`, and their train/val/test variants): point to the list files containing the filenames of all training, validation, and testing data.

*`workers`*: I believe this is the number of worker processes used to load data during training or testing. In the original model it was set to 4 by default, but that produced an error, so I lowered it to 1. This is probably a sign that we need to do heavy downsampling to compensate for the lack of parallel processing.
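
Because `feature_size` has to match the selected modality, a small lookup can make that dependency explicit. This is only a sketch: `FEATURE_SIZES` and `feature_size_for` are hypothetical names that do not exist in the original code, and the sizes are the ones used in this notebook's configuration cells (flow is left out because its size is not stated here).

```python
# Hypothetical lookup (not in the original code): feature width per modality,
# using the sizes that appear in this notebook's configuration.
FEATURE_SIZES = {
    "RGB": 1024,          # I3D RGB features
    "AUDIO": 128,         # VGGish audio features
    "MIX2": 1024 + 128,   # RGB and audio concatenated along the feature axis
}

def feature_size_for(modality):
    """Return the expected feature width, or raise for modalities not listed."""
    try:
        return FEATURE_SIZES[modality]
    except KeyError:
        raise ValueError(f"No feature size known for modality {modality!r}") from None

print(feature_size_for("MIX2"))  # 1152
```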

In [22]:
class Args:
    def __init__(self):
        self.modality = 'MIX2'
        # Original paths
        self.rgb_list = '/content/final_dl/list/rgb.list'
        self.flow_list = '/content/final_dl/list/flow.list'
        self.audio_list = '/content/final_dl/list/audio.list'

        # Train paths
        self.train_rgb_list = '/content/final_dl/list/rgb_train.list'
        self.train_flow_list = '/content/final_dl/list/flow_train.list'
        self.train_audio_list = '/content/final_dl/list/audio_train.list'

        # Val paths
        self.val_rgb_list = '/content/final_dl/list/rgb_val.list'
        self.val_flow_list = '/content/final_dl/list/flow_val.list'
        self.val_audio_list = '/content/final_dl/list/audio_val.list'

        # Test paths
        self.test_rgb_list = '/content/final_dl/list/rgb_test.list'
        self.test_flow_list = '/content/final_dl/list/flow_test.list'
        self.test_audio_list = '/content/final_dl/list/audio_test.list'

        self.gt = '/content/final_dl/list/gt.npy'
        self.gpus = 1
        self.lr = 0.0001
        self.batch_size = 128
        self.workers = 1  # Reduced from 4 to avoid memory issues
        self.model_name = 'wsanodet'
        self.pretrained_ckpt = None
        self.feature_size = 1152  # 1024 + 128
        self.num_classes = 1
        self.dataset_name = 'XD-Violence'
        self.max_seqlen = 200
        self.max_epoch = 50

args = Args()

## Dataset

Class for building a dataset from the list files.

In [39]:
import torch.utils.data as data
import numpy as np
class Dataset(data.Dataset):
    def __init__(self, args, transform=None, mode='train'):
        self.modality = args.modality
        self.max_seqlen = args.max_seqlen
        self.transform = transform
        self.test_mode = (mode == 'test')

        # Set appropriate file lists based on mode
        if mode == 'test':
            self.rgb_list_file = args.test_rgb_list
            self.flow_list_file = args.test_flow_list
            self.audio_list_file = args.test_audio_list
        elif mode == 'val':
            self.rgb_list_file = args.val_rgb_list
            self.flow_list_file = args.val_flow_list
            self.audio_list_file = args.val_audio_list
        else:  # train
            self.rgb_list_file = args.train_rgb_list
            self.flow_list_file = args.train_flow_list
            self.audio_list_file = args.train_audio_list

        self._parse_list()

    def _parse_list(self):
        """Parse file lists - assumes lists are already properly aligned"""
        def read_lines(path):
            # Read one path per line and close the file handle afterwards
            with open(path) as f:
                return [line.strip() for line in f]

        if self.modality == 'MIX2':
            self.list = read_lines(self.rgb_list_file)
            self.audio_list = read_lines(self.audio_list_file)
        elif self.modality == 'AUDIO':
            self.list = read_lines(self.audio_list_file)
        elif self.modality == 'RGB':
            self.list = read_lines(self.rgb_list_file)
        elif self.modality == 'FLOW':
            self.list = read_lines(self.flow_list_file)
        else:
            # Fail early instead of hitting a missing self.list later
            raise ValueError(f"Unsupported modality: {self.modality}")

    def __getitem__(self, index):
        if self.modality in ['RGB', 'FLOW', 'AUDIO']:
            file_path = self.list[index].strip()
            features = np.array(np.load(file_path), dtype=np.float32)
            label = 0.0 if '_label_A' in file_path else 1.0
        elif self.modality == 'MIX2':
            # Load RGB features
            file_path1 = self.list[index].strip()
            features1 = np.array(np.load(file_path1), dtype=np.float32)
            label = 0.0 if '_label_A' in file_path1 else 1.0

            # Load corresponding audio features
            audio_index = index // 5
            file_path2 = self.audio_list[audio_index].strip()
            features2 = np.array(np.load(file_path2), dtype=np.float32)

            features = np.concatenate((features1, features2), axis=1)

        if self.transform is not None:
            features = self.transform(features)

        features = process_feat(features, self.max_seqlen, is_random=not self.test_mode)
        return features, label

    def __len__(self):
        return len(self.list)
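
In `MIX2` mode the RGB list holds five feature files per video while the audio list holds one, so `index // 5` picks the matching audio entry as long as both lists keep the same video order. A minimal sketch of that mapping, using hypothetical, shortened file names (`audio_for` is an illustrative helper, not part of the class above):

```python
# Hypothetical, shortened file names; the real lists hold full .npy paths.
rgb_list = [f"{vid}__{k}.npy" for vid in ("vidA", "vidB") for k in range(5)]
audio_list = [f"{vid}__vggish.npy" for vid in ("vidA", "vidB")]

def audio_for(rgb_index, files_per_video=5):
    """Map an RGB list index to the audio entry of the same video."""
    return audio_list[rgb_index // files_per_video]

pairs = [(rgb_list[i], audio_for(i)) for i in range(len(rgb_list))]
for rgb, audio in pairs:
    print(rgb, "->", audio)
```

If a video were missing one of its five RGB files, every later index would point at the wrong audio file, which is why the splitting code only keeps complete groups.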


## Data Splitting (Test, Train, Val)

Paths are hard-coded and can be adjusted as needed.

In [40]:
import os
import glob
import random
from pathlib import Path
import numpy as np
def get_video_id(filepath):
    """Extract consistent video ID for both RGB and audio files"""
    filename = os.path.basename(filepath)

    # RGB files end with __[0-4].npy and audio files end with __vggish.npy
    if filename.endswith('.npy'):
        base = filename.rsplit('__', 1)[0]  # Split from the right on __ to remove the number or vggish
        return base

    # Fallback - strip only the final extension (names such as Bad.Boys.1995 contain dots)
    return os.path.splitext(filename)[0]

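def create_splits(aligned_files, train_ratio=0.8, seed=42):
    """Split the video IDs first, then we'll expand to files in write_list_files"""
    random.seed(seed)
    # Sort first: the IDs originate from a set, whose iteration order can differ
    # between Python processes, which would make the seeded split non-reproducible
    video_ids = sorted(aligned_files.keys())
    train_size = int(len(video_ids) * train_ratio)
    train_ids = random.sample(video_ids, train_size)
    train_id_set = set(train_ids)  # constant-time membership checks
    val_ids = [vid for vid in video_ids if vid not in train_id_set]
    return {
        'train': train_ids,
        'val': val_ids
    }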

def write_list_files(split_data, aligned_files, outputdir="/content/final_dl/list"):
    """Write list files: five RGB entries per video and one audio entry per video"""
    #os.makedirs(outputdir, exist_ok=True)
    for split_name, video_ids in split_data.items():
        # RGB list - one entry per RGB feature file (five per video)
        rgb_path = os.path.join(outputdir, f'rgb_{split_name}.list')
        with open(rgb_path, 'w') as f:
            for vid_id in video_ids:
                for rgb_file in aligned_files[vid_id]['rgb']:
                    f.write(f"{rgb_file}\n")
        # Audio list - one entry per video (not per frame)
        audio_path = os.path.join(outputdir, f'audio_{split_name}.list')
        with open(audio_path, 'w') as f:
            for vid_id in video_ids:
                audio_file = aligned_files[vid_id]['audio']
                f.write(f"{audio_file}\n")

def find_matching_files():
    """
    Find and align RGB and audio feature files.
    Only includes files that actually exist in the filesystem.
    Returns dict mapping video IDs to their RGB and audio paths
    """
    rgb_path = "/content/final_dl/dl_files/i3d-features/RGB"
    audio_path = "/content/final_dl/list/xx/train"

    # Get all files
    rgb_files = [f for f in glob.glob(os.path.join(rgb_path, "*.npy")) if os.path.exists(f)]
    audio_files = [f for f in glob.glob(os.path.join(audio_path, "*.npy")) if os.path.exists(f)]

    # Create mappings that preserve the 5:1 ratio
    rgb_map = {}
    for f in rgb_files:
        vid_id = get_video_id(f)
        if vid_id not in rgb_map:
            rgb_map[vid_id] = []
        rgb_map[vid_id].append(f)
    audio_map = {}
    for f in audio_files:
        if os.path.exists(f):
            vid_id = get_video_id(f)
            audio_map[vid_id] = f

    # Find common video IDs and verify each RGB group has exactly 5 files
    common_ids = set(rgb_map.keys()) & set(audio_map.keys())
    complete_ids = {vid_id for vid_id in common_ids if len(rgb_map[vid_id]) == 5}

    # Create aligned mapping only for complete groups
    aligned_files = {
        vid_id: {
            'rgb': sorted(rgb_map[vid_id]),  # Sort to maintain consistent ordering
            'audio': audio_map[vid_id],
            'is_normal': '_label_A' in rgb_map[vid_id][0]  # Check first RGB file for label
        }
        for vid_id in complete_ids
    }

    return aligned_files
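
The alignment relies on RGB and audio files of the same video reducing to the same ID once the suffix after the last `__` is dropped. The sketch below checks this on two file paths taken from the test lists in this notebook; `video_id` is a hypothetical compact copy of the idea behind `get_video_id`, written only for this example.

```python
import os

def video_id(path):
    """Same idea as get_video_id: drop everything after the last '__'."""
    return os.path.basename(path).rsplit("__", 1)[0]

rgb_file = ("/content/final_dl/dl_files/i3d-features/RGBTest/"
            "Bad.Boys.1995__#01-11-55_01-12-40_label_G-B2-B6__0.npy")
audio_file = ("/content/final_dl/list/xx/test/"
              "Bad.Boys.1995__#01-11-55_01-12-40_label_G-B2-B6__vggish.npy")

rgb_id = video_id(rgb_file)
audio_id = video_id(audio_file)
print(rgb_id)
print(rgb_id == audio_id)
```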

In [41]:
# Run this:
aligned_files = find_matching_files()

# Create train/val splits
split_data = create_splits(aligned_files)

# Write list files
write_list_files(split_data, aligned_files)
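
Splitting at the video level, rather than per feature file, keeps all five RGB files of a video and its audio file in the same split. The sketch below illustrates that idea with made-up IDs. `split_ids` is a hypothetical helper, and sorting the IDs and using a local `random.Random` instance are choices made for this sketch so that repeated runs give the same split.

```python
import random

def split_ids(video_ids, train_ratio=0.8, seed=42):
    """Seeded split over sorted video IDs (hypothetical helper for illustration)."""
    ids = sorted(video_ids)
    rng = random.Random(seed)  # local RNG, leaves the global random state untouched
    train = set(rng.sample(ids, int(len(ids) * train_ratio)))
    val = [vid for vid in ids if vid not in train]
    return sorted(train), val

ids = [f"video_{i:03d}" for i in range(10)]  # made-up IDs
train_ids, val_ids = split_ids(ids)
print(len(train_ids), len(val_ids))  # 8 2
```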

### Fix Test Files

There are already {mode}_test.list files in the list dir. We need to update these paths because they are hardcoded to a destination on the original authors' machine. Run these next few lines just once:

In [42]:
%cd /content/final_dl/list
# Uncomment these and run ONLY ONCE
#!sed -i 's|/media/peng/Samsung_T5|/content/final_dl/dl_files|g' rgb_test.list
#!sed -i 's|/media/peng/Samsung_T5/vggish-features|/content/final_dl/list/xx|g' audio_test.list
# Verify paths
!head -n 5 /content/final_dl/list/rgb_test.list
!head -n 5 /content/final_dl/list/audio_test.list
%cd /mydrive/MyDrive/

/content/final_dl/list
/content/final_dl/dl_files/i3d-features/RGBTest/Bad.Boys.1995__#01-11-55_01-12-40_label_G-B2-B6__0.npy
/content/final_dl/dl_files/i3d-features/RGBTest/Bad.Boys.1995__#01-11-55_01-12-40_label_G-B2-B6__1.npy
/content/final_dl/dl_files/i3d-features/RGBTest/Bad.Boys.1995__#01-11-55_01-12-40_label_G-B2-B6__2.npy
/content/final_dl/dl_files/i3d-features/RGBTest/Bad.Boys.1995__#01-11-55_01-12-40_label_G-B2-B6__3.npy
/content/final_dl/dl_files/i3d-features/RGBTest/Bad.Boys.1995__#01-11-55_01-12-40_label_G-B2-B6__4.npy
/content/final_dl/list/xx/test/Bad.Boys.1995__#01-11-55_01-12-40_label_G-B2-B6__vggish.npy
/content/final_dl/list/xx/test/Bad.Boys.1995__#01-33-51_01-34-37_label_B2-0-0__vggish.npy
/content/final_dl/list/xx/test/Bad.Boys.II.2003__#00-06-42_00-10-00_label_B2-G-0__vggish.npy
/content/final_dl/list/xx/test/Black.Hawk.Down.2001__#01-13-59_01-14-49_label_B2-0-0__vggish.npy
/content/final_dl/list/xx/test/Black.Hawk.Down.2001__#01-32-40_01-34-00_label_B4-0-0__vggis
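
The same prefix rewrite can be sketched in Python, which makes it easy to try on a sample line before touching the list files. `rewrite_prefix` and the sample path are hypothetical; only the two prefixes come from the `sed` command above.

```python
def rewrite_prefix(paths, old_prefix, new_prefix):
    """Replace old_prefix at the start of each path; other paths pass through."""
    return [
        new_prefix + p[len(old_prefix):] if p.startswith(old_prefix) else p
        for p in paths
    ]

# Hypothetical input line; the prefixes match the ones used in the sed command.
original = ["/media/peng/Samsung_T5/i3d-features/RGBTest/example__0.npy"]
once = rewrite_prefix(original, "/media/peng/Samsung_T5", "/content/final_dl/dl_files")
twice = rewrite_prefix(once, "/media/peng/Samsung_T5", "/content/final_dl/dl_files")
print(once[0])
```

Because the function only touches paths that still start with the old prefix, applying it a second time leaves already rewritten paths unchanged.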

### Create Dataloaders for multimodal

In [43]:
from torch.utils.data import DataLoader
def create_data_loaders(args):
    """
    Create train, validation and test data loaders
    """
    print("Creating data loaders...")

    # Create train loader
    train_dataset = Dataset(args, mode='train')
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True
    )
    print(f"Train loader created with {len(train_dataset)} samples")

    # Create validation loader
    val_dataset = Dataset(args, mode='val')
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,  # No need to shuffle validation data
        num_workers=args.workers,
        pin_memory=True
    )
    print(f"Validation loader created with {len(val_dataset)} samples")

    # Create test loader with smaller batch size as per original code
    test_dataset = Dataset(args, mode='test')
    test_loader = DataLoader(
        test_dataset,
        batch_size=5,  # Using smaller batch size for testing
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True
    )
    print(f"Test loader created with {len(test_dataset)} samples")

    return train_loader, val_loader, test_loader

args = Args()
train_loader, val_loader, test_loader = create_data_loaders(args)

Creating data loaders...
Train loader created with 15815 samples
Validation loader created with 3955 samples
Test loader created with 4000 samples
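
The sample counts can be cross-checked against the 80/20 video-level split: each video contributes five RGB entries, so dividing by five should give whole numbers of videos. A small arithmetic check, assuming the counts printed above and the default `train_ratio=0.8`:

```python
FILES_PER_VIDEO = 5  # RGB feature files per video (the 5:1 RGB-to-audio ratio)

# Sample counts as printed by the data loader cell
train_samples, val_samples, test_samples = 15815, 3955, 4000

train_videos = train_samples // FILES_PER_VIDEO
val_videos = val_samples // FILES_PER_VIDEO
test_videos = test_samples // FILES_PER_VIDEO
total = train_videos + val_videos

print(train_videos, val_videos, test_videos)  # 3163 791 800
print(int(total * 0.8) == train_videos)       # True
```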


### VAL SPLIT FOR SINGLE MODALITY

In [44]:
def create_single_modality_data_loaders(args, modality='AUDIO'):
    """
    Create train, validation and test data loaders for a single modality
    """
    print(f"Creating {modality} data loaders...")

    # Create new args with only needed attributes
    args_new = Args()
    args_new.modality = modality

    # List files needed for train/val/test splits
    if modality == 'AUDIO':
        args_new.train_audio_list = args.train_audio_list
        args_new.val_audio_list = args.val_audio_list
        args_new.test_audio_list = args.test_audio_list
    elif modality == 'RGB':
        args_new.train_rgb_list = args.train_rgb_list
        args_new.val_rgb_list = args.val_rgb_list
        args_new.test_rgb_list = args.test_rgb_list
    elif modality == 'FLOW':
        args_new.train_flow_list = args.train_flow_list
        args_new.val_flow_list = args.val_flow_list
        args_new.test_flow_list = args.test_flow_list

    # Create data loaders
    train_dataset = Dataset(args_new, mode='train')
    train_loader = DataLoader(
        train_dataset,
        batch_size=args_new.batch_size,
        shuffle=True,
        num_workers=args_new.workers,
        pin_memory=True
    )
    print(f"Train loader created with {len(train_dataset)} samples")

    val_dataset = Dataset(args_new, mode='val')
    val_loader = DataLoader(
        val_dataset,
        batch_size=args_new.batch_size,
        shuffle=False,
        num_workers=args_new.workers,
        pin_memory=True
    )
    print(f"Validation loader created with {len(val_dataset)} samples")

    test_dataset = Dataset(args_new, mode='test')
    test_loader = DataLoader(
        test_dataset,
        batch_size=5,
        shuffle=False,
        num_workers=args_new.workers,
        pin_memory=True
    )
    print(f"Test loader created with {len(test_dataset)} samples")

    return train_loader, val_loader, test_loader

In [None]:
# For audio only
#train_loader, val_loader, test_loader = create_single_modality_data_loaders(args, modality='AUDIO')

# For RGB only
#train_loader, val_loader, test_loader = create_single_modality_data_loaders(args, modality='RGB')

# For flow only
## Flow is currently not supported, but implementing it should not be that challenging: you would
## have to write an equivalent of "find_matching_files" for the flow data in the data-splitting cell above.
#train_loader, val_loader, test_loader = create_single_modality_data_loaders(args, modality='FLOW')

# Training VAE
Trains the multimodal VAE and single-modality VAEs


## VAE MODEL

In [51]:
import torch
from torch import nn

class Sampling(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, z_means, z_log_vars):
        epsilon = torch.randn_like(z_means, dtype=torch.float32)
        return z_means + torch.exp(0.5 * z_log_vars) * epsilon

class Encoder(nn.Module):
    def __init__(self, latent_dim, input_dim=1152, seq_len=200):
        super().__init__()
        self.latent_dim = latent_dim

        # Reduced number of feature maps in encoder
        self.encoder = nn.Sequential(
            nn.Conv1d(input_dim, 256, kernel_size=3, stride=2, padding=1),  # Reduced from 576
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.Conv1d(256, 128, kernel_size=3, stride=2, padding=1),  # Reduced from 288
            nn.BatchNorm1d(128),
            nn.ReLU(True),
            nn.Conv1d(128, 64, kernel_size=3, stride=2, padding=1),  # Reduced from 144
            nn.BatchNorm1d(64),
            nn.ReLU(True),
            nn.Flatten()
        )

        flattened_dim = 64 * 25  # 64 channels x 25 time steps (seq_len 200 halved by three stride-2 convs)

        self.lin_mean = nn.Sequential(
            nn.Linear(flattened_dim, latent_dim),
            nn.BatchNorm1d(latent_dim)
        )

        self.lin_log_var = nn.Sequential(
            nn.Linear(flattened_dim, latent_dim),
            nn.BatchNorm1d(latent_dim)
        )

        self.sampling = Sampling()

    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = self.encoder(x)
        z_means = self.lin_mean(x)
        z_log_vars = self.lin_log_var(x)
        z = self.sampling(z_means, z_log_vars)
        return z, z_means, z_log_vars

class Decoder(nn.Module):
    def __init__(self, latent_dim, input_dim=1152, seq_len=200):
        super().__init__()
        self.seq_len = seq_len
        flattened_dim = 64 * 25  # Must match the encoder output: 64 channels x 25 time steps

        self.decoder_fc = nn.Sequential(
            nn.Linear(latent_dim, flattened_dim),
            nn.BatchNorm1d(flattened_dim),
            nn.ReLU(True)
        )

        # Reduced number of feature maps in decoder
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose1d(64, 128, kernel_size=3, stride=2, padding=1, output_padding=1),  # Reduced from 144->288
            nn.BatchNorm1d(128),
            nn.ReLU(True),
            nn.ConvTranspose1d(128, 256, kernel_size=3, stride=2, padding=1, output_padding=1),  # Reduced from 288->576
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.ConvTranspose1d(256, input_dim, kernel_size=3, stride=2, padding=1, output_padding=1),  # Reduced from 576->input_dim
            nn.BatchNorm1d(input_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.decoder_fc(x)
        x = x.view(-1, 64, 25)  # Updated based on reduced features
        x = self.decoder_conv(x)
        x = x.permute(0, 2, 1)
        return x

class VAE(nn.Module):
    def __init__(self, latent_dim, input_dim=1152, seq_len=200):
        super().__init__()
        self.encoder = Encoder(latent_dim, input_dim, seq_len)
        self.decoder = Decoder(latent_dim, input_dim, seq_len)

    def forward(self, x):
        z, z_means, z_log_vars = self.encoder(x)
        x_reconstructed = self.decoder(z)
        return x_reconstructed, z_means, z_log_vars
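
The hard-coded `64 * 25` follows from the layer settings: each stride-2 convolution halves the 200-step sequence, and each transposed convolution doubles it again. The sketch below reproduces that arithmetic without PyTorch. The two helper names are hypothetical; the formulas are the standard output-length expressions for `Conv1d` and `ConvTranspose1d` with dilation 1.

```python
def conv1d_out_len(length, kernel=3, stride=2, padding=1):
    """Output length of a 1-D convolution (dilation 1)."""
    return (length + 2 * padding - kernel) // stride + 1

def conv_transpose1d_out_len(length, kernel=3, stride=2, padding=1, output_padding=1):
    """Output length of a 1-D transposed convolution (dilation 1)."""
    return (length - 1) * stride - 2 * padding + kernel + output_padding

enc_lengths = [200]
for _ in range(3):  # three stride-2 encoder layers
    enc_lengths.append(conv1d_out_len(enc_lengths[-1]))

dec_lengths = [enc_lengths[-1]]
for _ in range(3):  # three stride-2 decoder layers
    dec_lengths.append(conv_transpose1d_out_len(dec_lengths[-1]))

print(enc_lengths)  # [200, 100, 50, 25]
print(dec_lengths)  # [25, 50, 100, 200]
```

With a different `max_seqlen`, the value 25 in the encoder, the decoder, and the `view` call would need to change accordingly.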

## Reimplement Dataset Class for Normal Samples

The VAE is pre-trained on only normal (non-anomalous) samples
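
Normal samples are identified purely by file name: a path containing `_label_A` is treated as normal (label 0), and anything else as violent (label 1). A minimal sketch of that rule, where `is_normal` is a hypothetical helper and the second file name is made up for illustration:

```python
NORMAL_FLAG = "_label_A"

def is_normal(path):
    """A file counts as normal when its name carries the '_label_A' tag."""
    return NORMAL_FLAG in path

files = [
    "Bad.Boys.1995__#01-11-55_01-12-40_label_G-B2-B6__0.npy",  # from the test list
    "some_video__#1_label_A__0.npy",                           # hypothetical normal clip
]
normal_only = [f for f in files if is_normal(f)]
labels = [0.0 if is_normal(f) else 1.0 for f in files]
print(normal_only, labels)
```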

In [52]:
import torch.utils.data as data
import numpy as np
import os
import glob
import random
from pathlib import Path

class NormalDataset(data.Dataset):
    def __init__(self, args, transform=None, mode='train'):
        self.modality = args.modality
        self.normal_flag = '_label_A'
        self.max_seqlen = args.max_seqlen
        self.transform = transform
        self.test_mode = (mode == 'test')

        # Set appropriate file lists based on mode
        if mode == 'test':
            self.rgb_list_file = args.test_rgb_list
            self.flow_list_file = args.test_flow_list
            self.audio_list_file = args.test_audio_list
        elif mode == 'val':
            self.rgb_list_file = args.val_rgb_list
            self.flow_list_file = args.val_flow_list
            self.audio_list_file = args.val_audio_list
        else:  # train
            self.rgb_list_file = args.train_rgb_list
            self.flow_list_file = args.train_flow_list
            self.audio_list_file = args.train_audio_list

        self._parse_list()

    def _parse_list(self):
        """Parse file lists and filter for normal samples only"""
        def filter_normal_samples(file_list):
            return [f for f in file_list if self.normal_flag in f]

        if self.modality == 'AUDIO':
            self.list = filter_normal_samples(list(open(self.audio_list_file)))
        elif self.modality == 'RGB':
            self.list = filter_normal_samples(list(open(self.rgb_list_file)))
        elif self.modality == 'FLOW':
            self.list = filter_normal_samples(list(open(self.flow_list_file)))
        elif self.modality == 'MIX2':
            # For MIX2, we need to handle the 5:1 ratio between RGB and audio
            self.list = filter_normal_samples(list(open(self.rgb_list_file)))
            # Filter audio list and ensure alignment
            all_audio = list(open(self.audio_list_file))
            self.audio_list = [f for f in all_audio if self.normal_flag in f]

            # Ensure RGB and audio lists are aligned (5:1 ratio)
            rgb_video_ids = set([self._get_video_id(f) for f in self.list])
            audio_video_ids = set([self._get_video_id(f) for f in self.audio_list])
            common_ids = rgb_video_ids & audio_video_ids

            # Filter lists to only include common videos
            self.list = [f for f in self.list if self._get_video_id(f) in common_ids]
            self.audio_list = [f for f in self.audio_list if self._get_video_id(f) in common_ids]

    def _get_video_id(self, filepath):
        """Extract video ID from filepath"""
        filename = os.path.basename(filepath.strip('\n'))
        return filename.split('_label')[0]

    def __getitem__(self, index):
        if self.modality in ['RGB', 'FLOW', 'AUDIO']:
            features = np.array(np.load(self.list[index].strip('\n')), dtype=np.float32)
        elif self.modality == 'MIX2':
            # Load RGB features
            features1 = np.array(np.load(self.list[index].strip('\n')), dtype=np.float32)
            # Load corresponding audio features (accounting for 5:1 ratio)
            audio_index = index // 5
            features2 = np.array(np.load(self.audio_list[audio_index].strip('\n')), dtype=np.float32)

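            # Handle potential length mismatch in either direction by
            # truncating both feature arrays to the shorter one
            min_len = min(features1.shape[0], features2.shape[0])
            features1 = features1[:min_len]
            features2 = features2[:min_len]
            features = np.concatenate((features1, features2), axis=1)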

        if self.transform is not None:
            features = self.transform(features)

        features = process_feat(features, self.max_seqlen, is_random=not self.test_mode)

        # Always return label 0 since these are normal samples
        return features, 0.0

    def __len__(self):
        return len(self.list)

def create_normal_data_loaders(args):
    """Create data loaders for normal samples only"""
    print("Creating normal-only data loaders...")

    # Create train loader
    train_dataset = NormalDataset(args, mode='train')
    train_loader = data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True
    )
    print(f"Normal train loader created with {len(train_dataset)} samples")

    # Create validation loader
    val_dataset = NormalDataset(args, mode='val')
    val_loader = data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True
    )
    print(f"Normal validation loader created with {len(val_dataset)} samples")

    # Create test loader
    test_dataset = NormalDataset(args, mode='test')
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True
    )
    print(f"Normal test loader created with {len(test_dataset)} samples")

    return train_loader, val_loader, test_loader

def process_feat(feat, length, is_random=True):
    """Process features to have consistent length"""
    if len(feat) > length:
        if is_random:
            r = np.random.randint(len(feat) - length)
            return feat[r:r + length]
        else:
            r = np.linspace(0, len(feat) - 1, length, dtype=np.uint16)
            return feat[r, :]
    else:
        return np.pad(feat, ((0, length - len(feat)), (0, 0)), mode='constant', constant_values=0)


## VAE Pretraining


In [53]:
import torch
from datetime import datetime
import os

class EarlyStopping:
    """Early stopping to prevent overfitting"""
    def __init__(self, patience=7, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0
        return self.early_stop

def validate_vae(vae, val_loader, device):
    """Run validation loop and return average loss"""
    vae.eval()
    total_loss = 0
    total_recon_loss = 0
    total_kl_loss = 0
    n_samples = 0

    with torch.no_grad():
        for data, labels in val_loader:
            # Only process normal samples (label == 0)
            normal_mask = (labels == 0.0)
            if not normal_mask.any():
                continue

            data = data[normal_mask].to(device)
            recon_data, mu, logvar = vae(data)

            # Reconstruction loss
            recon_criterion = torch.nn.MSELoss(reduction='sum')
            recon_loss = recon_criterion(recon_data, data)

            # KL divergence loss
            kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

            # Total loss
            loss = recon_loss + kl_loss

            total_loss += loss.item()
            total_recon_loss += recon_loss.item()
            total_kl_loss += kl_loss.item()
            n_samples += data.size(0)

    # Calculate averages
    if n_samples > 0:
        avg_loss = total_loss / n_samples
        avg_recon = total_recon_loss / n_samples
        avg_kl = total_kl_loss / n_samples
    else:
        avg_loss = float('inf')
        avg_recon = float('inf')
        avg_kl = float('inf')

    vae.train()
    return avg_loss, avg_recon, avg_kl

def train_vae(vae, train_loader, val_loader, args, save_dir='vae_checkpoints'):
    """Main training loop for VAE"""

    # Create directory for saving checkpoints
    os.makedirs(save_dir, exist_ok=True)

    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae = vae.to(device)
    optimizer = torch.optim.Adam(vae.parameters(), lr=args.lr)
    early_stopping = EarlyStopping(patience=5)

    history = {
        'train_loss': [],
        'train_recon': [],
        'train_kl': [],
        'val_loss': [],
        'val_recon': [],
        'val_kl': []
    }

    # Training loop
    best_val_loss = float('inf')
    for epoch in range(args.max_epoch):
        # Training
        vae.train()
        train_loss = 0
        train_recon = 0
        train_kl = 0
        n_samples = 0

        for batch_idx, (data, labels) in enumerate(train_loader):
            # Only process normal samples (label == 0)
            normal_mask = (labels == 0.0)
            if not normal_mask.any():
                continue

            data = data[normal_mask].to(device)
            optimizer.zero_grad()

            # Forward pass
            recon_data, mu, logvar = vae(data)

            # Losses
            recon_criterion = torch.nn.MSELoss(reduction='sum')
            recon_loss = recon_criterion(recon_data, data)
            kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss = recon_loss + kl_loss

            # Backward pass
            loss.backward()
            optimizer.step()

            # Record losses
            train_loss += loss.item()
            train_recon += recon_loss.item()
            train_kl += kl_loss.item()
            n_samples += data.size(0)

        # Calculate average training losses
        if n_samples > 0:
            avg_train_loss = train_loss / n_samples
            avg_train_recon = train_recon / n_samples
            avg_train_kl = train_kl / n_samples
        else:
            print("Warning: No normal samples found in this epoch, skipping it")
            continue

        # Validation
        val_loss, val_recon, val_kl = validate_vae(vae, val_loader, device)

        # Print progress
        print(f'Epoch {epoch+1}/{args.max_epoch}:')
        print(f'Training - Loss: {avg_train_loss:.4f}, Recon: {avg_train_recon:.4f}, KL: {avg_train_kl:.4f}')
        print(f'Validation - Loss: {val_loss:.4f}, Recon: {val_recon:.4f}, KL: {val_kl:.4f}\n')

        history['train_loss'].append(avg_train_loss)
        history['train_recon'].append(avg_train_recon)
        history['train_kl'].append(avg_train_kl)
        history['val_loss'].append(val_loss)
        history['val_recon'].append(val_recon)
        history['val_kl'].append(val_kl)

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            save_path = os.path.join(save_dir, f'vae_{args.modality}_best_{timestamp}.pt')
            torch.save({
                'epoch': epoch,
                'model_state_dict': vae.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': val_loss,
                'history': history
            }, save_path)
            print(f'Saved best model to {save_path}')
            torch.save(vae.state_dict(), f"/content/best_trained_vae_{args.modality}.pkl")


        # Early stopping
        if early_stopping(val_loss):
            print("Early stopping triggered")
            break

    return vae
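
The loss in both loops is a summed squared reconstruction error plus the KL divergence between the encoder's Gaussian and a standard normal prior. The sketch below evaluates the same expressions on plain Python lists so the numbers can be checked by hand. `kl_divergence` and `vae_loss` are hypothetical helper names, and the KL term is written in a rearranged but algebraically equivalent form.

```python
import math

def kl_divergence(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, 1)), summed over latent dimensions.

    Equivalent to -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
    """
    return 0.5 * sum(m ** 2 + math.exp(lv) - lv - 1 for m, lv in zip(mu, logvar))

def vae_loss(x, x_recon, mu, logvar):
    """Summed squared error plus the KL term."""
    recon = sum((a - b) ** 2 for a, b in zip(x, x_recon))
    return recon + kl_divergence(mu, logvar)

kl_zero = kl_divergence([0.0, 0.0], [0.0, 0.0])  # posterior equals the prior
kl_shifted = kl_divergence([1.0], [0.0])         # mean moved by 1, unit variance
print(kl_zero, kl_shifted)  # 0.0 0.5
```

Since both terms are sums rather than means, the training code divides the accumulated totals by the number of samples to report per-sample averages.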

### Run for pretraining multimodal VAE

In [None]:
# adjust args needed for VAE
args_vae = Args()
args_vae.feature_size = 1152  # 1024 (RGB) + 128 (audio)
args_vae.batch_size = 64
args_vae.modality = 'MIX2'
args_vae.max_epoch = 500
args_vae.lr = 0.0005

# initialize VAE with correct input dimension
vae = VAE(latent_dim=64, input_dim=args_vae.feature_size, seq_len=200)

# Create normal-only dataloaders
normal_train_loader, normal_val_loader, normal_test_loader = create_normal_data_loaders(args_vae)

# Train the VAE
trained_vae = train_vae(vae, normal_train_loader, normal_val_loader, args_vae)
# Save the model (adjust path name for modality)
#torch.save(trained_vae.state_dict(), "/content/last_trained_vae.pkl")

In [None]:
#%cp '/content/best_trained_vae_MIX2.pkl' ./vae_checkpoints/

### Run for pretraining VAE on RGB single modality data

In [None]:
# adjust args needed for VAE
args_vae = Args()
args_vae.feature_size = 1024  # 1024 (RGB only)
args_vae.batch_size = 64
args_vae.modality = 'RGB'
args_vae.max_epoch = 500
args_vae.lr = 0.0005

# initialize VAE with correct input dimension
vae = VAE(latent_dim=64, input_dim=args_vae.feature_size, seq_len=200)

# Create normal-only dataloaders
normal_train_loader, normal_val_loader, normal_test_loader = create_normal_data_loaders(args_vae)

# Train the VAE
trained_vae = train_vae(vae, normal_train_loader, normal_val_loader, args_vae)
# Save the model
#torch.save(trained_vae.state_dict(), "./last_trained_vae.pkl")

### Run for pretraining VAE on audio single modality data

In [None]:
# adjust args needed for VAE
args_vae = Args()
args_vae.feature_size = 128  # 128 (audio only)
args_vae.batch_size = 64
args_vae.modality = 'AUDIO'
args_vae.max_epoch = 500
args_vae.lr = 0.0005

# initialize VAE with correct input dimension
vae = VAE(latent_dim=64, input_dim=args_vae.feature_size, seq_len=200)

# Create normal-only dataloaders
normal_train_loader, normal_val_loader, normal_test_loader = create_normal_data_loaders(args_vae)

# Train the VAE
trained_vae = train_vae(vae, normal_train_loader, normal_val_loader, args_vae)
# Save the model
#torch.save(trained_vae.state_dict(), "./last_trained_vae.pkl")
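
The three pretraining cells above differ only in the modality string and the matching feature size. A minimal sketch of a lookup that keeps the two in sync is shown below. `FEATURE_SIZES` and `feature_size_for` are hypothetical names that do not exist in this notebook; the sizes are the ones set in the cells above.

```python
# Hypothetical helper (not part of the original notebook): map each modality
# string used above to its per-snippet feature dimension.
FEATURE_SIZES = {
    "RGB": 1024,          # RGB features only
    "AUDIO": 128,         # audio features only
    "MIX2": 1024 + 128,   # RGB and audio features concatenated
}


def feature_size_for(modality: str) -> int:
    """Return the feature dimension for a modality, or raise on unknown names."""
    try:
        return FEATURE_SIZES[modality]
    except KeyError:
        raise ValueError(f"Unsupported modality: {modality}") from None


print(feature_size_for("MIX2"))  # 1152
```

With a helper like this, each cell could set `args_vae.feature_size = feature_size_for(args_vae.modality)` instead of repeating the number by hand.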

# Training HL-Net and HL-Net + VAE

This section first trains HL-Net on its own for each modality. It then loads the pre-trained VAE and trains HL-Net again while it receives the reconstruction signal from that VAE.

## Pre-flight

Classes, model, layers, utils for running the HL-Net

In [71]:
from math import sqrt
from torch import FloatTensor
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.spatial.distance import pdist, squareform
import torch.nn.init as torch_init
import os

def weight_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        torch_init.xavier_uniform_(m.weight)
        # m.bias.data.fill_(0.1)

class Model(nn.Module):
    def __init__(self, args):
        super(Model, self).__init__()

        n_features = args.feature_size
        n_class = args.num_classes

        self.conv1d1 = nn.Conv1d(in_channels=n_features, out_channels=512, kernel_size=1, padding=0)
        self.conv1d2 = nn.Conv1d(in_channels=512, out_channels=128, kernel_size=1, padding=0)
        self.conv1d3 = nn.Conv1d(in_channels=128, out_channels=32, kernel_size=5, padding=2)
        self.conv1d4 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
        # Graph Convolution
        self.gc1 = GraphConvolution(128, 32, residual=True)  # nn.Linear(128, 32)
        self.gc2 = GraphConvolution(32, 32, residual=True)
        self.gc3 = GraphConvolution(128, 32, residual=True)  # nn.Linear(128, 32)
        self.gc4 = GraphConvolution(32, 32, residual=True)
        self.gc5 = GraphConvolution(128, 32, residual=True)  # nn.Linear(128, 32)
        self.gc6 = GraphConvolution(32, 32, residual=True)
        self.simAdj = SimilarityAdj(n_features, 32)
        self.disAdj = DistanceAdj()

        self.classifier = nn.Linear(32*3, n_class)
        self.approximator = nn.Sequential(nn.Conv1d(128, 64, 1, padding=0), nn.ReLU(),
                                          nn.Conv1d(64, 32, 1, padding=0), nn.ReLU())
        self.conv1d_approximator = nn.Conv1d(32, 1, 5, padding=0)
        self.dropout = nn.Dropout(0.6)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.apply(weight_init)



    def forward(self, inputs, seq_len):
        x = inputs.permute(0, 2, 1)  # for conv1d
        x = self.relu(self.conv1d1(x))
        x = self.dropout(x)
        x = self.relu(self.conv1d2(x))
        x = self.dropout(x)

        logits = self.approximator(x)
        logits = F.pad(logits, (4, 0))
        logits = self.conv1d_approximator(logits)
        logits = logits.permute(0, 2, 1)
        x = x.permute(0, 2, 1)  # b*t*c

        ## gcn
        scoadj = self.sadj(logits.detach(), seq_len)
        adj = self.adj(inputs, seq_len)
        disadj = self.disAdj(x.shape[0], x.shape[1])
        x1_h = self.relu(self.gc1(x, adj))
        x1_h = self.dropout(x1_h)
        x2_h = self.relu(self.gc3(x, disadj))
        x2_h = self.dropout(x2_h)
        x3_h = self.relu(self.gc5(x, scoadj))
        x3_h = self.dropout(x3_h)
        x1 = self.relu(self.gc2(x1_h, adj))
        x1 = self.dropout(x1)
        x2 = self.relu(self.gc4(x2_h, disadj))
        x2 = self.dropout(x2)
        x3 = self.relu(self.gc6(x3_h, scoadj))
        x3 = self.dropout(x3)
        x = torch.cat((x1, x2, x3), 2)
        x = self.classifier(x)
        return x, logits

    def sadj(self, logits, seq_len):
        lens = logits.shape[1]
        soft = nn.Softmax(1)
        logits2 = self.sigmoid(logits).repeat(1, 1, lens)
        tmp = logits2.permute(0, 2, 1)
        adj = 1. - torch.abs(logits2 - tmp)
        self.sig = lambda x:1/(1+torch.exp(-((x-0.5))/0.1))
        adj = self.sig(adj)
        output = torch.zeros_like(adj)
        if seq_len is None:
            for i in range(logits.shape[0]):
                tmp = adj[i]
                adj2 = soft(tmp)
                output[i] = adj2
        else:
            for i in range(len(seq_len)):
                tmp = adj[i, :seq_len[i], :seq_len[i]]
                adj2 = soft(tmp)
                output[i, :seq_len[i], :seq_len[i]] = adj2
        return output


    def adj(self, x, seq_len):
        soft = nn.Softmax(1)
        x2 = x.matmul(x.permute(0,2,1)) # B*T*T
        x_norm = torch.norm(x, p=2, dim=2, keepdim=True)  # B*T*1
        x_norm_x = x_norm.matmul(x_norm.permute(0,2,1))
        x2 = x2/(x_norm_x+1e-20)
        output = torch.zeros_like(x2)
        if seq_len is None:
            for i in range(x.shape[0]):
                tmp = x2[i]
                adj2 = tmp
                adj2 = F.threshold(adj2, 0.7, 0)
                adj2 = soft(adj2)
                output[i] = adj2
        else:
            for i in range(len(seq_len)):
                tmp = x2[i, :seq_len[i], :seq_len[i]]
                adj2 = tmp
                adj2 = F.threshold(adj2, 0.7, 0)
                adj2 = soft(adj2)
                output[i, :seq_len[i], :seq_len[i]] = adj2

        return output

class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        # xavier_uniform_ is the in-place initializer; the un-suffixed name is deprecated
        self.W = nn.Parameter(nn.init.xavier_uniform_(torch.Tensor(in_features, out_features).type(torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor), gain=np.sqrt(2.0)), requires_grad=True)
        self.a = nn.Parameter(nn.init.xavier_uniform_(torch.Tensor(2*out_features, 1).type(torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor), gain=np.sqrt(2.0)), requires_grad=True)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        h = torch.mm(input, self.W)
        N = h.size()[0]

        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'

class linear(nn.Module):
    def __init__(self, in_features, out_features):
        super(linear, self).__init__()
        self.weight = Parameter(FloatTensor(in_features, out_features))
        self.register_parameter('bias', None)
        stdv = 1. / sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
    def forward(self, x):
        x = x.matmul(self.weight)
        return x

class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=False, residual=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(FloatTensor(in_features, out_features))

        if bias:
            self.bias = Parameter(FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        if not residual:
            self.residual = lambda x: 0
        elif (in_features == out_features):
            self.residual = lambda x: x
        else:
            # self.residual = linear(in_features, out_features)
            self.residual = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=5, padding=2)
    def reset_parameters(self):
        # stdv = 1. / sqrt(self.weight.size(1))
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            self.bias.data.fill_(0.1)

    def forward(self, input, adj):
        # To support batch operations
        support = input.matmul(self.weight)
        output = adj.matmul(support)

        if self.bias is not None:
            output = output + self.bias
        if self.in_features != self.out_features and self.residual:
            input = input.permute(0,2,1)
            res = self.residual(input)
            res = res.permute(0,2,1)
            output = output + res
        else:
            output = output + self.residual(input)

        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

######################################################

class SimilarityAdj(Module):

    def __init__(self, in_features, out_features):
        super(SimilarityAdj, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight0 = Parameter(FloatTensor(in_features, out_features))
        self.weight1 = Parameter(FloatTensor(in_features, out_features))
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # stdv = 1. / sqrt(self.weight0.size(1))
        nn.init.xavier_uniform_(self.weight0)
        nn.init.xavier_uniform_(self.weight1)

    def forward(self, input, seq_len):
        # To support batch operations
        soft = nn.Softmax(1)
        theta = torch.matmul(input, self.weight0)
        phi = torch.matmul(input, self.weight1)  # second projection; weight1 was otherwise unused
        phi2 = phi.permute(0, 2, 1)
        sim_graph = torch.matmul(theta, phi2)

        theta_norm = torch.norm(theta, p=2, dim=2, keepdim=True)  # B*T*1
        phi_norm = torch.norm(phi, p=2, dim=2, keepdim=True)  # B*T*1
        x_norm_x = theta_norm.matmul(phi_norm.permute(0, 2, 1))
        sim_graph = sim_graph / (x_norm_x + 1e-20)

        output = torch.zeros_like(sim_graph)
        if seq_len is None:
            for i in range(sim_graph.shape[0]):
                tmp = sim_graph[i]
                adj2 = tmp
                adj2 = F.threshold(adj2, 0.7, 0)
                adj2 = soft(adj2)
                output[i] = adj2
        else:
            for i in range(len(seq_len)):
                tmp = sim_graph[i, :seq_len[i], :seq_len[i]]
                adj2 = tmp
                adj2 = F.threshold(adj2, 0.7, 0)
                adj2 = soft(adj2)
                output[i, :seq_len[i], :seq_len[i]] = adj2

        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

class DistanceAdj(Module):

    def __init__(self):
        super(DistanceAdj, self).__init__()
        self.sigma = Parameter(FloatTensor(1))
        self.sigma.data.fill_(0.1)

    def forward(self, batch_size, max_seqlen):
        # To support batch operations
        # Build the matrix on the module's own device instead of hard-coding 'cuda',
        # so the model also runs when only a CPU is available.
        device = self.sigma.device
        self.arith = np.arange(max_seqlen).reshape(-1, 1)
        dist = pdist(self.arith, metric='cityblock').astype(np.float32)
        self.dist = torch.from_numpy(squareform(dist)).to(device)
        self.dist = torch.exp(-self.dist / torch.exp(torch.tensor(1.)))
        self.dist = torch.unsqueeze(self.dist, 0).repeat(batch_size, 1, 1)
        return self.dist

def CLAS(logits, label, seq_len, criterion, device, is_topk=True):
    logits = logits.squeeze()
    instance_logits = torch.zeros(0).to(device)  # tensor([])
    for i in range(logits.shape[0]):
        if is_topk:
            tmp, _ = torch.topk(logits[i][:seq_len[i]], k=int(seq_len[i]//16+1), largest=True)
            tmp = torch.mean(tmp).view(1)
        else:
            tmp = torch.mean(logits[i, :seq_len[i]]).view(1)
        instance_logits = torch.cat((instance_logits, tmp))

    instance_logits = torch.sigmoid(instance_logits)

    clsloss = criterion(instance_logits, label)
    return clsloss


def CENTROPY(logits, logits2, seq_len, device):
    instance_logits = torch.tensor(0).to(device)  # tensor([])
    for i in range(logits.shape[0]):
        tmp1 = torch.sigmoid(logits[i, :seq_len[i]]).squeeze()
        tmp2 = torch.sigmoid(logits2[i, :seq_len[i]]).squeeze()
        loss = torch.mean(-tmp1.detach() * torch.log(tmp2))
        instance_logits = instance_logits + loss
    instance_logits = instance_logits/logits.shape[0]
    return instance_logits

def create_data_loaders(args):
    """
    Create train, validation and test data loaders
    """
    print("Creating data loaders...")

    # Create train loader
    train_dataset = Dataset(args, mode='train')
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True
    )
    print(f"Train loader created with {len(train_dataset)} samples")

    # Create validation loader
    val_dataset = Dataset(args, mode='val')
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,  # No need to shuffle validation data
        num_workers=args.workers,
        pin_memory=True
    )
    print(f"Validation loader created with {len(val_dataset)} samples")

    # Create test loader with smaller batch size as per original code
    test_dataset = Dataset(args, mode='test')
    test_loader = DataLoader(
        test_dataset,
        batch_size=5,  # Using smaller batch size for testing
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True
    )
    print(f"Test loader created with {len(test_dataset)} samples")

    return train_loader, val_loader, test_loader
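
To make the temporal prior in `DistanceAdj` easier to picture: its `forward` builds the pairwise distance `|i - j|` between snippet indices and maps it through `exp(-d / e)`, so nearby snippets get weights close to 1 and distant ones decay toward 0. As written above, the learnable `sigma` does not enter that computation. The sketch below reproduces the same matrix for a single video with plain NumPy broadcasting; `distance_adjacency` is a hypothetical name used only for illustration.

```python
import numpy as np


def distance_adjacency(num_snippets: int) -> np.ndarray:
    """Return a (T, T) matrix with entries exp(-|i - j| / e)."""
    idx = np.arange(num_snippets)
    # Absolute index difference, i.e. the city-block distance on a 1-D timeline.
    dist = np.abs(idx[:, None] - idx[None, :]).astype(np.float32)
    return np.exp(-dist / np.e)


adj = distance_adjacency(4)
print(adj.shape)   # (4, 4)
print(adj[0, 0])   # 1.0 on the diagonal (zero distance)
```

The matrix depends only on the sequence length, which is why the layer above can repeat one copy across the whole batch.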



### Plotting
This function is called often in the cells below, so adjust the plotting commands here.

In [None]:
import matplotlib.pyplot as plt

def plot_training_history(train_losses, val_losses, save_path='hl-vae-training_history.png'):
    """
    Plot training and validation losses over epochs.

    Args:
        train_losses (list): List of training losses per epoch
        val_losses (list): List of validation losses per epoch
        save_path (str): Path to save the plot
    """
    plt.figure(figsize=(12, 8))
    epochs = range(len(train_losses))

    plt.plot(epochs, train_losses, 'b-o', label='Training Loss')
    plt.plot(epochs, val_losses, 'r-o', label='Validation Loss')

    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    plt.legend()

    # Adjust layout to prevent label cutoff
    plt.tight_layout()

    # Save the plot
    plt.savefig(save_path)
    plt.close()

## HL-Net Training Multimodal (with no VAE)

In [58]:
import torch
import torch.utils.data as data
import torch.optim as optim
import numpy as np
import os
import glob
import random
from pathlib import Path
from torch import nn
from datetime import datetime
from sklearn.metrics import auc, precision_recall_curve
import time
import argparse
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

class EarlyStopping:
    """Early stopping to prevent overfitting"""
    def __init__(self, patience=7, min_delta=0, verbose=False):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_state_dict = None

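    def __call__(self, val_loss, model):
        # state_dict() returns references to the live parameters, so clone the
        # tensors; otherwise the stored "best" weights keep changing during training.
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}
            self.counter = 0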

    def get_best_model_state(self):
        return self.best_state_dict

def validate_epoch(val_loader, model, criterion, device, is_topk):
    """Run validation for one epoch"""
    model.eval()
    total_loss = 0.0
    batch_count = 0

    with torch.no_grad():
        for input, label in val_loader:
            seq_len = torch.sum(torch.max(torch.abs(input), dim=2)[0]>0, 1)
            input = input[:, :torch.max(seq_len), :]
            input, label = input.float().to(device), label.float().to(device)

            logits, logits2 = model(input, seq_len)
            clsloss = CLAS(logits, label, seq_len, criterion, device, is_topk)
            clsloss2 = CLAS(logits2, label, seq_len, criterion, device, is_topk)
            croloss = CENTROPY(logits, logits2, seq_len, device)

            total_loss += (clsloss + clsloss2 + 5*croloss).item()
            batch_count += 1

    return total_loss / batch_count

def test_hlnet(dataloader, model, device):
    """Test function that evaluates video-level predictions"""
    print("Starting test...")
    with torch.no_grad():
        model.eval()
        pred = []
        pred2 = []
        gt = []

        for i, (input, label) in enumerate(dataloader):
            gt.append(label[0].item())
            input = input.to(device)
            logits, logits2 = model(inputs=input, seq_len=None)

            # Get predictions for offline model
            logits = torch.squeeze(logits)
            sig = torch.sigmoid(logits)
            batch_pred = torch.mean(sig).item()
            pred.append(batch_pred)

            # Get predictions for online model
            logits2 = torch.squeeze(logits2)
            sig2 = torch.sigmoid(logits2)
            batch_pred2 = torch.mean(sig2).item()
            pred2.append(batch_pred2)

        # Convert to numpy arrays
        gt = np.array(gt)
        pred = np.array(pred)
        pred2 = np.array(pred2)
        precision, recall, _ = precision_recall_curve(gt, pred)
        pr_auc = auc(recall, precision)

        precision2, recall2, _ = precision_recall_curve(gt, pred2)
        pr_auc2 = auc(recall2, precision2)

        return pr_auc, pr_auc2

def train_hlnet(train_loader, model, optimizer, scheduler, criterion,
                device, is_topk, val_loader=None):
    """Training function with loss tracking"""
    model.train()
    epoch_loss = 0.0
    batch_count = 0

    for i, (input, label) in enumerate(train_loader):
        seq_len = torch.sum(torch.max(torch.abs(input), dim=2)[0]>0, 1)
        input = input[:, :torch.max(seq_len), :]
        input, label = input.float().to(device), label.float().to(device)

        logits, logits2 = model(input, seq_len)
        clsloss = CLAS(logits, label, seq_len, criterion, device, is_topk)
        clsloss2 = CLAS(logits2, label, seq_len, criterion, device, is_topk)
        croloss = CENTROPY(logits, logits2, seq_len, device)

        total_loss = clsloss + clsloss2 + 5*croloss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        epoch_loss += total_loss.item()
        batch_count += 1

        if (i + 1) % 100 == 0:
            print(f"Step {i+1}: Training Loss: {total_loss.item():.4f}")

    avg_epoch_loss = epoch_loss / batch_count

    # Calculate validation loss if provided
    val_epoch_loss = None
    if val_loader is not None:
        val_epoch_loss = validate_epoch(val_loader, model, criterion, device, is_topk)

    return model, avg_epoch_loss, val_epoch_loss
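
Two small steps in the training and validation loops above are easy to miss: the sequence length is recovered by counting snippets whose features are not all zero (so it assumes zero padding and no genuinely all-zero snippet), and `CLAS` with `is_topk=True` averages the `seq_len // 16 + 1` highest scores of each video. The NumPy sketch below mirrors both steps on toy data; `valid_lengths` and `topk_mean` are hypothetical names, not functions from this notebook.

```python
import numpy as np


def valid_lengths(batch: np.ndarray) -> np.ndarray:
    """Count snippets per video whose feature vector is not all zeros."""
    return (np.abs(batch).max(axis=2) > 0).sum(axis=1)


def topk_mean(scores: np.ndarray, length: int) -> float:
    """Mean of the k = length // 16 + 1 largest scores among the first `length`."""
    k = length // 16 + 1
    top = np.sort(scores[:length])[-k:]
    return float(top.mean())


# Two videos, 40 time steps, 3 features; the second has only 20 real snippets.
batch = np.ones((2, 40, 3), dtype=np.float32)
batch[1, 20:] = 0.0
lengths = valid_lengths(batch)
print(lengths)  # [40 20]

scores = np.arange(40, dtype=np.float32)
print(topk_mean(scores, int(lengths[0])))  # k = 3 -> mean of 37, 38, 39 = 38.0
print(topk_mean(scores, int(lengths[1])))  # k = 2 -> mean of 18, 19 = 18.5
```

Pooling only the top scores means a video-level label can be satisfied by a few high-scoring snippets, which suits weak, video-level supervision.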


In [None]:
# Initialize settings
args = Args()
args.max_epoch = 200
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create model and move to device
model = Model(args).to(device)

# Setup optimizer
approximator_param = list(map(id, model.approximator.parameters()))
approximator_param += list(map(id, model.conv1d_approximator.parameters()))
base_param = filter(lambda p: id(p) not in approximator_param, model.parameters())

optimizer = optim.Adam([
    {'params': base_param},
    {'params': model.approximator.parameters(), 'lr': args.lr / 2},
    {'params': model.conv1d_approximator.parameters(), 'lr': args.lr / 2},
], lr=args.lr, weight_decay=0.000)

scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1)
criterion = torch.nn.BCELoss()
is_topk = True

# Create data loaders
train_loader, val_loader, test_loader = create_data_loaders(args)

# Initialize early stopping
early_stopping = EarlyStopping(patience=7, verbose=True)

# Training tracking
train_losses = []
val_losses = []
best_pr_auc = 0
best_epoch = 0

print(f"Starting training on {device}")

for epoch in range(args.max_epoch):
    print(f"\nEpoch {epoch+1}/{args.max_epoch}")
    st = time.time()

    # Training step
    model, train_loss, val_loss = train_hlnet(
        train_loader=train_loader,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        device=device,
        is_topk=is_topk,
        val_loader=val_loader
    )

    # Store losses
    train_losses.append(train_loss)
    if val_loss is not None:
        val_losses.append(val_loss)

    pr_auc, pr_auc_online = test_hlnet(test_loader, model, device)

    if pr_auc > best_pr_auc:
        best_pr_auc = pr_auc
        best_epoch = epoch

    print(f'Epoch {epoch+1}/{args.max_epoch}:')
    print(f'Train Loss: {train_loss:.4f}')
    if val_loss is not None:
        print(f'Validation Loss: {val_loss:.4f}')
    print(f'PR-AUC (Offline/Online): {pr_auc:.4f}/{pr_auc_online:.4f}')
    print(f'Best PR-AUC: {best_pr_auc:.4f} (Epoch {best_epoch+1})')
    print(f'Epoch time: {time.time() - st:.2f}s')

    # Early stopping check
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        model.load_state_dict(early_stopping.get_best_model_state())
        break

    scheduler.step()

    # Save checkpoint every 5 epochs
    if epoch % 5 == 0 and epoch > 0:
        save_dir = './only_hlnet_saves_mm/checkpoints'
        os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), f'{save_dir}/{args.model_name}_epoch{epoch}.pth')

# Save final model
save_dir = './only_hlnet_saves_mm/final'
os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), f'{save_dir}/{args.model_name}_final.pth')

# Plot and save training history
plot_training_history(train_losses, val_losses, save_path='only_hlnet_training_history.png')

# Final evaluation
final_pr_auc, final_pr_auc_online = test_hlnet(test_loader, model, device)
print("\nTraining completed!")
print(f'Final PR-AUC (Offline/Online): {final_pr_auc:.4f}/{final_pr_auc_online:.4f}')
print(f'Best PR-AUC: {best_pr_auc:.4f} (Epoch {best_epoch+1})')
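
The early-stopping rule used in the loop above can be checked in isolation. The sketch below replays the same comparison as `EarlyStopping.__call__` (an epoch counts as an improvement only when the loss is at most `best_loss - min_delta`) over a list of validation losses, without any model. `stopping_epoch` is a hypothetical helper written for illustration.

```python
def stopping_epoch(val_losses, patience=7, min_delta=0.0):
    """Return (stop_index, best_index); stop_index is None if never triggered."""
    best_loss, best_index, counter = None, None, 0
    for epoch, loss in enumerate(val_losses):
        if best_loss is None or loss <= best_loss - min_delta:
            # Improvement: remember it and reset the patience counter.
            best_loss, best_index, counter = loss, epoch, 0
        else:
            counter += 1
            if counter >= patience:
                return epoch, best_index
    return None, best_index


# Loss improves until index 1, then fails to improve three times in a row.
print(stopping_epoch([1.0, 0.8, 0.85, 0.9, 0.82], patience=3))  # (4, 1)
```

With `patience=7` as configured above, training continues for seven non-improving epochs before the loop breaks and the best weights are restored.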


## HL-Net Training RGB (single modality, no VAE)

In [None]:
import torch
import torch.utils.data as data
import torch.optim as optim
import numpy as np
import os
import glob
import random
from pathlib import Path
from torch import nn
from datetime import datetime
from sklearn.metrics import auc, precision_recall_curve
import time
import argparse
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

class EarlyStopping:
    """Early stopping to prevent overfitting"""
    def __init__(self, patience=7, min_delta=0, verbose=False):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_state_dict = None

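    def __call__(self, val_loss, model):
        # state_dict() returns references to the live parameters, so clone the
        # tensors; otherwise the stored "best" weights keep changing during training.
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}
            self.counter = 0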

    def get_best_model_state(self):
        return self.best_state_dict

def validate_epoch(val_loader, model, criterion, device, is_topk):
    """Run validation for one epoch"""
    model.eval()
    total_loss = 0.0
    batch_count = 0

    with torch.no_grad():
        for input, label in val_loader:
            seq_len = torch.sum(torch.max(torch.abs(input), dim=2)[0]>0, 1)
            input = input[:, :torch.max(seq_len), :]
            input, label = input.float().to(device), label.float().to(device)

            logits, logits2 = model(input, seq_len)
            clsloss = CLAS(logits, label, seq_len, criterion, device, is_topk)
            clsloss2 = CLAS(logits2, label, seq_len, criterion, device, is_topk)
            croloss = CENTROPY(logits, logits2, seq_len, device)

            total_loss += (clsloss + clsloss2 + 5*croloss).item()
            batch_count += 1

    return total_loss / batch_count

def test_hlnet_single(dataloader, model, device):
    """Test function that evaluates video-level predictions"""
    print("Starting test...")
    with torch.no_grad():
        model.eval()
        pred = []
        pred2 = []
        gt = []

        for i, (input, label) in enumerate(dataloader):
            gt.append(label[0].item())  # Take first label as they're all same for a video

            input = input.to(device)
            logits, logits2 = model(inputs=input, seq_len=None)

            # Get predictions for offline model
            logits = torch.squeeze(logits)
            sig = torch.sigmoid(logits)
            batch_pred = torch.mean(sig).item()
            pred.append(batch_pred)

            # Get predictions for online model
            logits2 = torch.squeeze(logits2)
            sig2 = torch.sigmoid(logits2)
            batch_pred2 = torch.mean(sig2).item()
            pred2.append(batch_pred2)

        # Convert to numpy arrays
        gt = np.array(gt)
        pred = np.array(pred)
        pred2 = np.array(pred2)

        # Calculate metrics
        precision, recall, _ = precision_recall_curve(gt, pred)
        pr_auc = auc(recall, precision)

        precision2, recall2, _ = precision_recall_curve(gt, pred2)
        pr_auc2 = auc(recall2, precision2)

        return pr_auc, pr_auc2

def train_hlnet_single(train_loader, model, optimizer, scheduler, criterion,
                      device, is_topk, val_loader=None):
    """Training function with loss tracking"""
    model.train()
    epoch_loss = 0.0
    batch_count = 0

    for i, (input, label) in enumerate(train_loader):
        seq_len = torch.sum(torch.max(torch.abs(input), dim=2)[0]>0, 1)
        input = input[:, :torch.max(seq_len), :]
        input, label = input.float().to(device), label.float().to(device)

        logits, logits2 = model(input, seq_len)
        clsloss = CLAS(logits, label, seq_len, criterion, device, is_topk)
        clsloss2 = CLAS(logits2, label, seq_len, criterion, device, is_topk)
        croloss = CENTROPY(logits, logits2, seq_len, device)

        total_loss = clsloss + clsloss2 + 5*croloss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        epoch_loss += total_loss.item()
        batch_count += 1

        if (i + 1) % 100 == 0:
            print(f"Step {i+1}: Training Loss: {total_loss.item():.4f}")

    avg_epoch_loss = epoch_loss / batch_count

    # Calculate validation loss if provided
    val_epoch_loss = None
    if val_loader is not None:
        val_epoch_loss = validate_epoch(val_loader, model, criterion, device, is_topk)

    return model, avg_epoch_loss, val_epoch_loss
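
The third term of the loss, `5*croloss`, comes from `CENTROPY`: the sigmoid of the offline scores (`logits`, detached) acts as a soft target for the sigmoid of the online scores (`logits2`), averaged over the valid snippets of each video and then over the batch. A NumPy sketch of that quantity is below. `consistency_loss` is a hypothetical name, and the small `eps` inside the logarithm is an addition for numerical safety that the notebook's version does not have.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def consistency_loss(offline_logits, online_logits, lengths, eps=1e-8):
    """Average over videos of mean(-sigmoid(offline) * log(sigmoid(online)))."""
    total = 0.0
    for off, on, n in zip(offline_logits, online_logits, lengths):
        target = sigmoid(off[:n])      # soft target, treated as a constant
        prob = sigmoid(on[:n])         # online prediction being trained
        total += np.mean(-target * np.log(prob + eps))  # eps is an assumption
    return total / len(lengths)


# One video, two valid snippets: target close to 1, online probability 0.5.
offline = np.array([[20.0, 20.0, 0.0]])
online = np.array([[0.0, 0.0, 0.0]])
print(round(consistency_loss(offline, online, [2]), 4))  # about log(2) = 0.6931
```

Only positions before `lengths[i]` contribute, so padded snippets never influence the term.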

In [None]:
args = Args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.modality = 'RGB'
args.max_epoch = 200
# Set feature size based on modality
if args.modality == 'RGB':
    args.feature_size = 1024
elif args.modality == 'AUDIO':
    args.feature_size = 128
else:
    raise ValueError(f"Unsupported modality: {args.modality}")

# Create model and move to device
model = Model(args).to(device)

# Setup optimizer
approximator_param = list(map(id, model.approximator.parameters()))
approximator_param += list(map(id, model.conv1d_approximator.parameters()))
base_param = filter(lambda p: id(p) not in approximator_param, model.parameters())

optimizer = optim.Adam([
    {'params': base_param},
    {'params': model.approximator.parameters(), 'lr': args.lr / 2},
    {'params': model.conv1d_approximator.parameters(), 'lr': args.lr / 2},
], lr=args.lr, weight_decay=0.000)

scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1)
criterion = torch.nn.BCELoss()
is_topk = True

# Create single-modality data loaders
train_loader, val_loader, test_loader = create_single_modality_data_loaders(args)

# Initialize early stopping
early_stopping = EarlyStopping(patience=7, verbose=True)

# Training tracking
train_losses = []
val_losses = []
best_pr_auc = 0
best_epoch = 0

print(f"Starting training on {device} for {args.modality} modality")

for epoch in range(args.max_epoch):
    print(f"\nEpoch {epoch+1}/{args.max_epoch}")
    st = time.time()

    # Training step
    model, train_loss, val_loss = train_hlnet_single(
        train_loader=train_loader,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        device=device,
        is_topk=is_topk,
        val_loader=val_loader
    )

    # Store losses
    train_losses.append(train_loss)
    if val_loss is not None:
        val_losses.append(val_loss)

    # Calculate PR-AUC for monitoring
    pr_auc, pr_auc_online = test_hlnet_single(test_loader, model, device)

    # Update best PR-AUC
    if pr_auc > best_pr_auc:
        best_pr_auc = pr_auc
        best_epoch = epoch

    print(f'Epoch {epoch+1}/{args.max_epoch}:')
    print(f'Train Loss: {train_loss:.4f}')
    if val_loss is not None:
        print(f'Validation Loss: {val_loss:.4f}')
    print(f'PR-AUC (Offline/Online): {pr_auc:.4f}/{pr_auc_online:.4f}')
    print(f'Best PR-AUC: {best_pr_auc:.4f} (Epoch {best_epoch+1})')
    print(f'Epoch time: {time.time() - st:.2f}s')

    # Early stopping check
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        model.load_state_dict(early_stopping.get_best_model_state())
        break

    scheduler.step()

    # Save checkpoint every 5 epochs
    if epoch % 5 == 0 and epoch > 0:
        save_dir = f'./only_hlnet_saves_{args.modality.lower()}/checkpoints'
        os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(),
                  f'{save_dir}/{args.model_name}_epoch{epoch}.pth')

# Save final model
save_dir = f'./only_hlnet_saves_{args.modality.lower()}/final'
os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(),
          f'{save_dir}/{args.model_name}_final.pth')

# Plot and save training history
plot_training_history(train_losses, val_losses, save_path='only_hlnet_rgb_training_history.png')

# Final evaluation
final_pr_auc, final_pr_auc_online = test_hlnet_single(test_loader, model, device)
print("\nTraining completed!")
print(f'Final PR-AUC (Offline/Online): {final_pr_auc:.4f}/{final_pr_auc_online:.4f}')
print(f'Best PR-AUC: {best_pr_auc:.4f} (Epoch {best_epoch+1})')


## HLNet Training Audio (single modality, no VAE)

In [None]:
args = Args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.modality = 'AUDIO'
args.max_epoch = 200
# Set feature size based on modality
if args.modality == 'RGB':
    args.feature_size = 1024
elif args.modality == 'AUDIO':
    args.feature_size = 128
else:
    raise ValueError(f"Unsupported modality: {args.modality}")

# Create model and move to device
model = Model(args).to(device)

# Setup optimizer
approximator_param = list(map(id, model.approximator.parameters()))
approximator_param += list(map(id, model.conv1d_approximator.parameters()))
base_param = filter(lambda p: id(p) not in approximator_param, model.parameters())

optimizer = optim.Adam([
    {'params': base_param},
    {'params': model.approximator.parameters(), 'lr': args.lr / 2},
    {'params': model.conv1d_approximator.parameters(), 'lr': args.lr / 2},
], lr=args.lr, weight_decay=0.000)

scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1)
criterion = torch.nn.BCELoss()
is_topk = True

# Create data loaders using your existing function
train_loader, val_loader, test_loader = create_single_modality_data_loaders(args)

# Initialize early stopping
early_stopping = EarlyStopping(patience=7, verbose=True)

# Training tracking
train_losses = []
val_losses = []
best_pr_auc = 0
best_epoch = 0

print(f"Starting training on {device} for {args.modality} modality")

for epoch in range(args.max_epoch):
    print(f"\nEpoch {epoch+1}/{args.max_epoch}")
    st = time.time()

    # Training step
    model, train_loss, val_loss = train_hlnet_single(
        train_loader=train_loader,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        device=device,
        is_topk=is_topk,
        val_loader=val_loader
    )

    # Store losses
    train_losses.append(train_loss)
    if val_loss is not None:
        val_losses.append(val_loss)

    # Calculate PR-AUC for monitoring
    pr_auc, pr_auc_online = test_hlnet_single(test_loader, model, device)

    # Update best PR-AUC
    if pr_auc > best_pr_auc:
        best_pr_auc = pr_auc
        best_epoch = epoch

    print(f'Epoch {epoch+1}/{args.max_epoch}:')
    print(f'Train Loss: {train_loss:.4f}')
    if val_loss is not None:
        print(f'Validation Loss: {val_loss:.4f}')
    print(f'PR-AUC (Offline/Online): {pr_auc:.4f}/{pr_auc_online:.4f}')
    print(f'Best PR-AUC: {best_pr_auc:.4f} (Epoch {best_epoch+1})')
    print(f'Epoch time: {time.time() - st:.2f}s')

    # Early stopping check
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        model.load_state_dict(early_stopping.get_best_model_state())
        break

    scheduler.step()

    # Save checkpoint every 5 epochs
    if epoch % 5 == 0 and epoch > 0:
        save_dir = f'./only_hlnet_saves_{args.modality.lower()}/checkpoints'
        os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(),
                  f'{save_dir}/{args.model_name}_epoch{epoch}.pth')

# Save final model
save_dir = f'./only_hlnet_saves_{args.modality.lower()}/final'
os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(),
          f'{save_dir}/{args.model_name}_final.pth')

# Plot and save training history
plot_training_history(train_losses, val_losses, save_path='only_hlnet_audio_training_history.png')

# Final evaluation
final_pr_auc, final_pr_auc_online = test_hlnet_single(test_loader, model, device)
print("\nTraining completed!")
print(f'Final PR-AUC (Offline/Online): {final_pr_auc:.4f}/{final_pr_auc_online:.4f}')
print(f'Best PR-AUC: {best_pr_auc:.4f} (Epoch {best_epoch+1})')
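
The cell above builds `MultiStepLR(optimizer, milestones=[10], gamma=0.1)` and calls `scheduler.step()` once per epoch, so every parameter group's learning rate drops by a factor of ten from the eleventh epoch onward. A rough pure-Python sketch of that schedule follows; `multistep_lr` and the base rate of `1e-3` are illustrative assumptions, not values taken from `Args`.

```python
from bisect import bisect_right

def multistep_lr(base_lr, epoch, milestones=(10,), gamma=0.1):
    """Learning rate in effect during a 0-based epoch, assuming the scheduler
    is stepped once at the end of every epoch (as in the loop above)."""
    # Number of milestones already reached at this epoch
    drops = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** drops

base_lr = 1e-3  # hypothetical value; the notebook reads it from args.lr
for epoch in (0, 9, 10, 11):
    print(f"epoch {epoch + 1:>2}: lr = {multistep_lr(base_lr, epoch):.6f}")
```

With `max_epoch = 200` and a single milestone at 10, almost all of a full run would happen at the reduced rate unless early stopping ends it sooner.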


## Train HL-Net + VAE (Multimodal: Audio + RGB)

In [65]:
import torch
import torch.utils.data as data
import torch.optim as optim
import numpy as np
import os
import glob
import random
from pathlib import Path
from torch import nn
from datetime import datetime
from sklearn.metrics import auc, precision_recall_curve
import time
import argparse
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

class EarlyStopping:
    """Early stopping to prevent overfitting"""
    def __init__(self, patience=7, min_delta=0, verbose=False):
        """
        Args:
            patience (int): Number of epochs to wait before stopping if no improvement
            min_delta (float): Minimum change in monitored value to qualify as an improvement
            verbose (bool): If True, prints out early stopping information
        """
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_state_dict = None

    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            # Clone the tensors: state_dict() returns live references that
            # would otherwise keep changing as training continues
            self.best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}
            self.counter = 0

    def get_best_model_state(self):
        return self.best_state_dict


def validate_epoch(val_loader, hlnet, vae, criterion, device, is_topk,
                  HLNET_LOSS_WEIGHT, RECON_LOSS_WEIGHT):
    """Run validation for one epoch"""
    hlnet.eval()
    vae.eval()
    total_loss = 0.0
    batch_count = 0

    with torch.no_grad():
        for input, label in val_loader:
            inputcpy = input.float().to(device)
            seq_len = torch.sum(torch.max(torch.abs(input), dim=2)[0]>0, 1)
            input = input[:, :torch.max(seq_len), :]
            input, label = input.float().to(device), label.float().to(device)

            logits, logits2 = hlnet(input, seq_len)
            clsloss = CLAS(logits, label, seq_len, criterion, device, is_topk)
            clsloss2 = CLAS(logits2, label, seq_len, criterion, device, is_topk)
            croloss = CENTROPY(logits, logits2, seq_len, device)

            recon_data, mu, logvar = vae(inputcpy)
            recon_criterion = torch.nn.MSELoss(reduction='mean')
            recon_loss = recon_criterion(recon_data, inputcpy)

            total_loss += (HLNET_LOSS_WEIGHT * (clsloss + clsloss2 + 5*croloss) +
                         RECON_LOSS_WEIGHT * recon_loss).item()
            batch_count += 1

    return total_loss / batch_count

def test_hl_vae(dataloader, model, device):
    """
    Test function for HL-Net that evaluates video-level predictions
    Returns PR-AUC scores for both offline and online predictions
    """
    model.eval()
    video_gt = []
    video_pred = []
    video_pred2 = []

    current_video_preds = []
    current_video_preds2 = []

    with torch.no_grad():
        for i, (input, label) in enumerate(dataloader):
            # Process input
            input = input.to(device)
            logits, logits2 = model(inputs=input, seq_len=None)

            # Drop the trailing singleton dim: (B, T, 1) -> (B, T)
            logits = torch.squeeze(logits, dim=-1)
            sig = torch.sigmoid(logits)
            # Mean over the time dimension gives one score per sample
            sample_preds = torch.mean(sig, dim=1).cpu().numpy()
            current_video_preds.extend(sample_preds)

            logits2 = torch.squeeze(logits2, dim=-1)
            sig2 = torch.sigmoid(logits2)
            sample_preds2 = torch.mean(sig2, dim=1).cpu().numpy()
            current_video_preds2.extend(sample_preds2)

            # Every 5 frames, compute video-level prediction
            if (i + 1) % 5 == 0:
                # Take mean of the 5 frame predictions for this video
                video_pred.append(np.mean(current_video_preds[-5:]))
                video_pred2.append(np.mean(current_video_preds2[-5:]))
                # Only take one label per video (they're all the same)
                video_gt.append(label[0].item())
                # Reset for next video
                current_video_preds = []
                current_video_preds2 = []

    # Convert to numpy arrays
    video_gt = np.array(video_gt)
    video_pred = np.array(video_pred)
    video_pred2 = np.array(video_pred2)

    precision, recall, _ = precision_recall_curve(video_gt, video_pred)
    pr_auc = auc(recall, precision)

    precision2, recall2, _ = precision_recall_curve(video_gt, video_pred2)
    pr_auc2 = auc(recall2, precision2)

    return pr_auc, pr_auc2

def train_hlnet_vae(train_loader, hlnet, vae, optimizer, scheduler, criterion,
                    device, is_topk, HLNET_LOSS_WEIGHT, RECON_LOSS_WEIGHT, val_loader=None):
    """Training function with loss tracking"""
    hlnet.train()
    vae.eval()
    epoch_loss = 0.0
    batch_count = 0

    for i, (input, label) in enumerate(train_loader):
        inputcpy = input.float().to(device)
        seq_len = torch.sum(torch.max(torch.abs(input), dim=2)[0]>0, 1)
        input = input[:, :torch.max(seq_len), :]
        input, label = input.float().to(device), label.float().to(device)

        logits, logits2 = hlnet(input, seq_len)
        clsloss = CLAS(logits, label, seq_len, criterion, device, is_topk)
        clsloss2 = CLAS(logits2, label, seq_len, criterion, device, is_topk)
        croloss = CENTROPY(logits, logits2, seq_len, device)

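        # The VAE is frozen and evaluated without autograd, so recon_loss is a
        # constant offset in total_loss and sends no gradient to HL-Net
        with torch.inference_mode():
            recon_data, mu, logvar = vae(inputcpy)
            recon_criterion = torch.nn.MSELoss(reduction='mean')
            recon_loss = recon_criterion(recon_data, inputcpy)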

        total_loss = HLNET_LOSS_WEIGHT * (clsloss + clsloss2 + 5*croloss) + RECON_LOSS_WEIGHT * recon_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        epoch_loss += total_loss.item()
        batch_count += 1

        if (i + 1) % 100 == 0:
            print(f"Step {i+1}: Training Loss: {total_loss.item():.4f}")

    avg_epoch_loss = epoch_loss / batch_count

    # Calculate validation loss if provided
    val_epoch_loss = None
    if val_loader is not None:
        val_epoch_loss = validate_epoch(val_loader, hlnet, vae, criterion,
                                      device, is_topk, HLNET_LOSS_WEIGHT, RECON_LOSS_WEIGHT)

    return hlnet, avg_epoch_loss, val_epoch_loss
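
The `EarlyStopping` class above counts consecutive epochs in which the validation loss fails to beat the best value seen so far by more than `min_delta`, and resets that count on every improvement. The same counter logic can be sketched with plain floats; `first_stop_epoch` and the loss values are made up for illustration and are not part of the notebook.

```python
def first_stop_epoch(val_losses, patience=7, min_delta=0.0):
    """Return the 0-based epoch at which early stopping would trigger,
    or None if it never does. Mirrors the counter logic of EarlyStopping."""
    best = None
    counter = 0
    for epoch, loss in enumerate(val_losses):
        if best is None:
            best = loss
        elif loss > best - min_delta:
            counter += 1  # no sufficient improvement this epoch
            if counter >= patience:
                return epoch
        else:
            best = loss  # improvement: remember it and reset the counter
            counter = 0
    return None

# Hypothetical validation losses, for illustration only
losses = [1.00, 0.90, 0.95, 0.92, 0.91]
print(first_stop_epoch(losses, patience=3))  # 4
print(first_stop_epoch(losses, patience=7))  # None
```

Because the counter resets on each improvement, `patience` bounds the number of consecutive non-improving epochs, not the total number.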


#### Set the VAE model to load
Change line 9 of the cell below so the path points to the VAE checkpoint you are loading.

In [66]:
args_vae = Args()
args_vae.feature_size = 1152
args_vae.batch_size = 64
args_vae.modality = 'MIX2'
args_vae.max_epoch = 500
args_vae.lr = 0.0005

vae_model = VAE(latent_dim=64, input_dim=args_vae.feature_size, seq_len=200)
vae_dir = "/mydrive/MyDrive/vae_checkpoints/best_trained_vae.pkl" # Change this
# Load onto the CPU first so the checkpoint also opens on machines without a GPU
vae_model.load_state_dict(torch.load(vae_dir, map_location="cpu"))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae_model = vae_model.to(device)
vae_model.eval()

HL_NET_LOSS_weight = .8
RECON_LOSS_weight = .2

args = Args()
args.feature_size = 1152
args.batch_size = 128
args.modality = 'MIX2'
args.max_seqlen = 200
args.workers = 1
args.max_epoch = 200

train_loader, val_loader, test_loader = create_data_loaders(args)
model = Model(args).to(device)

approximator_param = list(map(id, model.approximator.parameters()))
approximator_param += list(map(id, model.conv1d_approximator.parameters()))
base_param = filter(lambda p: id(p) not in approximator_param, model.parameters())

optimizer = optim.Adam([
    {'params': base_param},
    {'params': model.approximator.parameters(), 'lr': args.lr / 2},
    {'params': model.conv1d_approximator.parameters(), 'lr': args.lr / 2},
], lr=args.lr, weight_decay=0.000)

scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1)
criterion = torch.nn.BCELoss()
is_topk = True

# Lists to store training history
train_losses = []
val_losses = []
best_pr_auc = 0
best_epoch = 0
early_stopping = EarlyStopping(patience=7, verbose=True)

for epoch in range(args.max_epoch):
    print(f"\nEpoch {epoch+1}/{args.max_epoch}")
    st = time.time()

    # Training step
    model, train_loss, val_loss = train_hlnet_vae(
        train_loader=train_loader,
        hlnet=model,
        vae=vae_model,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        device=device,
        is_topk=is_topk,
        HLNET_LOSS_WEIGHT=HL_NET_LOSS_weight,
        RECON_LOSS_WEIGHT=RECON_LOSS_weight,
        val_loader=val_loader
    )

    # Store losses
    train_losses.append(train_loss)
    if val_loss is not None:
        val_losses.append(val_loss)

    # Calculate PR-AUC for monitoring
    pr_auc, pr_auc_online = test_hl_vae(test_loader, model, device)

    # Update best PR-AUC
    if pr_auc > best_pr_auc:
        best_pr_auc = pr_auc
        best_epoch = epoch

    print(f'Epoch {epoch+1}/{args.max_epoch}:')
    print(f'Train Loss: {train_loss:.4f}')
    if val_loss is not None:
        print(f'Validation Loss: {val_loss:.4f}')
    print(f'PR-AUC (Offline/Online): {pr_auc:.4f}/{pr_auc_online:.4f}')
    print(f'Best PR-AUC: {best_pr_auc:.4f} (Epoch {best_epoch+1})')
    print(f'Epoch time: {time.time() - st:.2f}s')

    # Early stopping check
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        # Load the best model state
        model.load_state_dict(early_stopping.get_best_model_state())
        break

    scheduler.step()

    if epoch % 5 == 0 and epoch > 0:
        save_dir = './hlnet_saves_mm'
        os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), f'{save_dir}/{args.model_name}{epoch}.pth')

# Save final model
save_dir = './hlnet_saves_mm/final'
os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), f'{save_dir}/{args.model_name}{epoch}.pth')

# Plot and save training history
plot_training_history(train_losses, val_losses, save_path='hlnet_vae_training_history.png')



Creating data loaders...
Train loader created with 15815 samples
Validation loader created with 3955 samples
Test loader created with 4000 samples

Epoch 1/200
Step 100: Training Loss: 1.2349
Epoch 1/200:
Train Loss: 1.5289
Validation Loss: 1.7576
PR-AUC (Offline/Online): 0.8281/0.9180
Best PR-AUC: 0.8281 (Epoch 1)
Epoch time: 281.91s

Epoch 2/200


KeyboardInterrupt: 
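
The objective minimised in `train_hlnet_vae` is a weighted sum: the HL-Net part (two classification losses plus five times the `CENTROPY` term) is scaled by `HL_NET_LOSS_weight`, and the VAE reconstruction error by `RECON_LOSS_weight`. The arithmetic can be sketched as follows; `combined_loss` is a hypothetical helper and the component values are invented, not measured.

```python
def combined_loss(cls_loss, cls_loss2, cro_loss, recon_loss,
                  hlnet_weight=0.8, recon_weight=0.2, cro_scale=5.0):
    """Weighted sum used as total_loss in train_hlnet_vae and validate_epoch."""
    hlnet_part = cls_loss + cls_loss2 + cro_scale * cro_loss
    return hlnet_weight * hlnet_part + recon_weight * recon_loss

# Invented component values, for illustration only
total = combined_loss(cls_loss=0.5, cls_loss2=0.6, cro_loss=0.1, recon_loss=2.0)
print(f"{total:.2f}")  # 1.68
```

Since the VAE is frozen and its forward pass runs without autograd, the reconstruction term does not depend on HL-Net's parameters: it shifts the reported loss value but contributes no gradient.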

## Train HL-Net + VAE (RGB Only)

In [73]:
import torch
import torch.utils.data as data
import torch.optim as optim
import numpy as np
import os
import glob
import random
from pathlib import Path
from torch import nn
from datetime import datetime
from sklearn.metrics import auc, precision_recall_curve
import time
import argparse
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

class EarlyStopping:
    """Early stopping to prevent overfitting"""
    def __init__(self, patience=7, min_delta=0, verbose=False):
        """
        Args:
            patience (int): Number of epochs to wait before stopping if no improvement
            min_delta (float): Minimum change in monitored value to qualify as an improvement
            verbose (bool): If True, prints out early stopping information
        """
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_state_dict = None

    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            # Clone the tensors: state_dict() returns live references that
            # would otherwise keep changing as training continues
            self.best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}
            self.counter = 0

    def get_best_model_state(self):
        return self.best_state_dict

def validate_epoch(val_loader, hlnet, vae, criterion, device, is_topk,
                  HLNET_LOSS_WEIGHT, RECON_LOSS_WEIGHT):
    hlnet.eval()
    vae.eval()
    total_loss = 0.0
    batch_count = 0

    with torch.no_grad():
        for input, label in val_loader:
            inputcpy = input.float().to(device)
            seq_len = torch.sum(torch.max(torch.abs(input), dim=2)[0]>0, 1)
            input = input[:, :torch.max(seq_len), :]
            input, label = input.float().to(device), label.float().to(device)

            logits, logits2 = hlnet(input, seq_len)
            clsloss = CLAS(logits, label, seq_len, criterion, device, is_topk)
            clsloss2 = CLAS(logits2, label, seq_len, criterion, device, is_topk)
            croloss = CENTROPY(logits, logits2, seq_len, device)

            recon_data, mu, logvar = vae(inputcpy)
            recon_criterion = torch.nn.MSELoss(reduction='mean')
            recon_loss = recon_criterion(recon_data, inputcpy)

            total_loss += (HLNET_LOSS_WEIGHT * (clsloss + clsloss2 + 5*croloss) +
                         RECON_LOSS_WEIGHT * recon_loss).item()
            batch_count += 1

    return total_loss / batch_count

def test_hl_vae(dataloader, model, device):
    print("Starting test...")
    with torch.no_grad():
        model.eval()
        pred = []
        pred2 = []
        gt = []

        for i, (input, label) in enumerate(dataloader):
            if i == 0:  # Print shapes for first batch
                print(f"Input shape: {input.shape}")
                print(f"Label shape: {label.shape}")

            # For ground truth, just take the label of first frame in batch
            # (they should all be same for a video segment)
            gt.append(label[0].item())

            input = input.to(device)
            logits, logits2 = model(inputs=input, seq_len=None)
            if i == 0:
                print(f"Logits shape: {logits.shape}")

            # Get one prediction per batch/video segment
            logits = torch.squeeze(logits)
            sig = torch.sigmoid(logits)
            if i == 0:
                print(f"Sig shape before mean: {sig.shape}")

            # Average over both frames and sequence length to get one score per video
            batch_pred = torch.mean(sig).item()
            pred.append(batch_pred)

            # Same for online predictions
            logits2 = torch.squeeze(logits2)
            sig2 = torch.sigmoid(logits2)
            batch_pred2 = torch.mean(sig2).item()
            pred2.append(batch_pred2)

        # Convert to numpy arrays
        gt = np.array(gt)
        pred = np.array(pred)
        pred2 = np.array(pred2)

        print(f"Final shapes - GT: {gt.shape}, Pred: {pred.shape}, Pred2: {pred2.shape}")

        precision, recall, th = precision_recall_curve(gt, pred)
        pr_auc = auc(recall, precision)
        precision, recall, th = precision_recall_curve(gt, pred2)
        pr_auc2 = auc(recall, precision)
        return pr_auc, pr_auc2

def train_hlnet_vae(train_loader, hlnet, vae, optimizer, scheduler, criterion,
                    device, is_topk, HLNET_LOSS_WEIGHT, RECON_LOSS_WEIGHT, val_loader=None):
    hlnet.train()
    vae.eval()
    epoch_loss = 0.0
    batch_count = 0

    for i, (input, label) in enumerate(train_loader):
        inputcpy = input.float().to(device)
        seq_len = torch.sum(torch.max(torch.abs(input), dim=2)[0]>0, 1)
        input = input[:, :torch.max(seq_len), :]
        input, label = input.float().to(device), label.float().to(device)

        logits, logits2 = hlnet(input, seq_len)
        clsloss = CLAS(logits, label, seq_len, criterion, device, is_topk)
        clsloss2 = CLAS(logits2, label, seq_len, criterion, device, is_topk)
        croloss = CENTROPY(logits, logits2, seq_len, device)

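        # The VAE is frozen and evaluated without autograd, so recon_loss is a
        # constant offset in total_loss and sends no gradient to HL-Net
        with torch.inference_mode():
            recon_data, mu, logvar = vae(inputcpy)
            recon_criterion = torch.nn.MSELoss(reduction='mean')
            recon_loss = recon_criterion(recon_data, inputcpy)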

        total_loss = HLNET_LOSS_WEIGHT * (clsloss + clsloss2 + 5*croloss) + RECON_LOSS_WEIGHT * recon_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        epoch_loss += total_loss.item()
        batch_count += 1

        if (i + 1) % 100 == 0:
            print(f"Step {i+1}: Training Loss: {total_loss.item():.4f}")

    avg_epoch_loss = epoch_loss / batch_count

    val_epoch_loss = None
    if val_loader is not None:
        val_epoch_loss = validate_epoch(val_loader, hlnet, vae, criterion,
                                      device, is_topk, HLNET_LOSS_WEIGHT, RECON_LOSS_WEIGHT)

    return hlnet, avg_epoch_loss, val_epoch_loss

class RgbDataset(data.Dataset):
    def __init__(self, args, transform=None, mode='train'):
        self.modality = args.modality
        self.max_seqlen = args.max_seqlen
        self.transform = transform
        self.test_mode = (mode == 'test')

        # Set appropriate file list based on mode
        if mode == 'test':
            self.rgb_list_file = args.test_rgb_list
        elif mode == 'val':
            self.rgb_list_file = args.val_rgb_list
        else:  # train
            self.rgb_list_file = args.train_rgb_list

        self._parse_list()

    def _parse_list(self):
        # Read one feature-file path per line; the context manager closes the file
        with open(self.rgb_list_file) as f:
            self.list = [line.strip() for line in f]

    def __getitem__(self, index):
        file_path = self.list[index].strip()
        features = np.array(np.load(file_path), dtype=np.float32)
        label = 0.0 if '_label_A' in file_path else 1.0

        if self.transform is not None:
            features = self.transform(features)

        features = process_feat(features, self.max_seqlen, is_random=not self.test_mode)
        return features, label

    def __len__(self):
        return len(self.list)

def rgbcreate_data_loaders(args):
    """
    Create train, validation and test data loaders
    """
    print("Creating data loaders...")

    # Create train loader
    train_dataset = RgbDataset(args, mode='train')
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True
    )
    print(f"Train loader created with {len(train_dataset)} samples")

    # Create validation loader
    val_dataset = RgbDataset(args, mode='val')
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,  # No need to shuffle validation data
        num_workers=args.workers,
        pin_memory=True
    )
    print(f"Validation loader created with {len(val_dataset)} samples")

    # Create test loader with smaller batch size as per original code
    test_dataset = RgbDataset(args, mode='test')
    test_loader = DataLoader(
        test_dataset,
        batch_size=5,  # Using smaller batch size for testing
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True
    )
    print(f"Test loader created with {len(test_dataset)} samples")

    return train_loader, val_loader, test_loader
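
`train_hlnet_vae` and `validate_epoch` recover each sequence's unpadded length from the zero rows that `process_feat` appends: a time step counts as real if any of its features is non-zero, and the batch is then cut to the longest real length. A minimal pure-Python sketch of that computation is shown below; the helper names and toy numbers are assumptions for illustration, and the notebook itself does this with tensor operations.

```python
def valid_lengths(batch):
    """Count, per sequence, the time steps whose feature vector is not all zeros.
    `batch` is a list of sequences; each sequence is a list of feature vectors."""
    return [sum(1 for step in seq if max(abs(x) for x in step) > 0) for seq in batch]

def trim_to_longest(batch):
    """Drop time steps beyond the longest valid length in the batch."""
    lengths = valid_lengths(batch)
    longest = max(lengths)
    return [seq[:longest] for seq in batch], lengths

# Two toy sequences padded to 4 steps with 2 features each (made-up numbers)
batch = [
    [[0.5, -1.0], [0.2, 0.0], [0.0, 0.0], [0.0, 0.0]],  # 2 real steps
    [[1.0, 1.0], [0.0, 0.3], [0.7, 0.1], [0.0, 0.0]],   # 3 real steps
]
trimmed, lengths = trim_to_longest(batch)
print(lengths)          # [2, 3]
print(len(trimmed[0]))  # 3
```

Note that this counts non-zero steps instead of locating the last one, so a genuinely all-zero feature vector in the middle of a clip would shorten the computed length by one.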

#### Set the VAE model to load
Change line 9 of the cell below so the path points to the VAE checkpoint you are loading.

In [None]:
args_vae = Args()
args_vae.feature_size = 1024  # RGB feature size
args_vae.batch_size = 64
args_vae.modality = 'RGB'
args_vae.max_epoch = 200
args_vae.lr = 0.0005

vae_model = VAE(latent_dim=64, input_dim=args_vae.feature_size, seq_len=200)
vae_dir = "/mydrive/MyDrive/single_modality/vae_checkpoints/best_trained_vae_RGB.pkl" # CHANGE THIS
vae_model.load_state_dict(torch.load(vae_dir, map_location="cpu"))  # load on CPU, move to device below
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae_model = vae_model.to(device)
vae_model.eval()

HL_NET_LOSS_weight = 0.8
RECON_LOSS_weight = 0.2

# Initialize HL-Net
args = Args()
args.feature_size = 1024 # RGB Feature size
args.max_epoch = 200
model = Model(args).to(device)

# Setup optimizer
approximator_param = list(map(id, model.approximator.parameters()))
approximator_param += list(map(id, model.conv1d_approximator.parameters()))
base_param = filter(lambda p: id(p) not in approximator_param, model.parameters())

optimizer = optim.Adam([
    {'params': base_param},
    {'params': model.approximator.parameters(), 'lr': args.lr / 2},
    {'params': model.conv1d_approximator.parameters(), 'lr': args.lr / 2},
], lr=args.lr, weight_decay=0.000)

scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1)
criterion = torch.nn.BCELoss()
is_topk = True

# Create datasets and dataloaders
train_loader, val_loader, test_loader = rgbcreate_data_loaders(args)

# Track losses
train_losses = []
val_losses = []
best_pr_auc = 0
best_epoch = 0

early_stopping = EarlyStopping(patience=7, verbose=True)
for epoch in range(args.max_epoch):
    print(f"\nEpoch {epoch+1}/{args.max_epoch}")
    st = time.time()

    model, train_loss, val_loss = train_hlnet_vae(
        train_loader=train_loader,
        hlnet=model,
        vae=vae_model,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        device=device,
        is_topk=is_topk,
        HLNET_LOSS_WEIGHT=HL_NET_LOSS_weight,
        RECON_LOSS_WEIGHT=RECON_LOSS_weight,
        val_loader=val_loader
    )
    train_losses.append(train_loss)

    if val_loss is not None:
        val_losses.append(val_loss)

    pr_auc, pr_auc_online = test_hl_vae(test_loader, model, device)

    if pr_auc > best_pr_auc:
        best_pr_auc = pr_auc
        best_epoch = epoch

    print(f'Epoch {epoch+1}/{args.max_epoch}:')
    print(f'Train Loss: {train_loss:.4f}')
    if val_loss is not None:
        print(f'Validation Loss: {val_loss:.4f}')
    print(f'PR-AUC (Offline/Online): {pr_auc:.4f}/{pr_auc_online:.4f}')
    print(f'Best PR-AUC: {best_pr_auc:.4f} (Epoch {best_epoch+1})')
    print(f'Epoch time: {time.time() - st:.2f}s')

    # Early stopping check
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        # Load the best model state
        model.load_state_dict(early_stopping.get_best_model_state())
        break

    scheduler.step()

    if epoch % 5 == 0 and epoch > 0:
        save_dir = './hlnet_saves_rgb'
        os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), f'{save_dir}/{args.model_name}{epoch}.pth')

save_dir = './hlnet_saves_rgb/final'
os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), f'{save_dir}/rgb_{args.model_name}{epoch}.pth')
plot_training_history(train_losses, val_losses, save_path='hlnet_vae_rgb_training_history.png')


## Train HL-Net + VAE (Audio Only)

In [75]:
class AudDataset(data.Dataset):
    def __init__(self, args, transform=None, mode='train'):
        self.modality = args.modality
        self.max_seqlen = args.max_seqlen
        self.transform = transform
        self.test_mode = (mode == 'test')

        # Set appropriate file list based on mode
        if mode == 'test':
            self.audio_list_file = args.test_audio_list
        elif mode == 'val':
            self.audio_list_file = args.val_audio_list
        else:  # train
            self.audio_list_file = args.train_audio_list

        self._parse_list()

    def _parse_list(self):
        # Read one feature-file path per line; the context manager closes the file
        with open(self.audio_list_file) as f:
            self.list = [line.strip() for line in f]

    def __getitem__(self, index):
        file_path = self.list[index].strip()
        features = np.array(np.load(file_path), dtype=np.float32)
        label = 0.0 if '_label_A' in file_path else 1.0

        if self.transform is not None:
            features = self.transform(features)

        features = process_feat(features, self.max_seqlen, is_random=not self.test_mode)
        return features, label

    def __len__(self):
        return len(self.list)

def audcreate_data_loaders(args):
    """
    Create train, validation and test data loaders
    """
    print("Creating data loaders...")

    # Create train loader
    train_dataset = AudDataset(args, mode='train')
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True
    )
    print(f"Train loader created with {len(train_dataset)} samples")

    # Create validation loader
    val_dataset = AudDataset(args, mode='val')
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,  # No need to shuffle validation data
        num_workers=args.workers,
        pin_memory=True
    )
    print(f"Validation loader created with {len(val_dataset)} samples")

    # Create test loader with smaller batch size as per original code
    test_dataset = AudDataset(args, mode='test')
    test_loader = DataLoader(
        test_dataset,
        batch_size=5,  # Using smaller batch size for testing
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True
    )
    print(f"Test loader created with {len(test_dataset)} samples")

    return train_loader, val_loader, test_loader
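
The test loaders use `batch_size=5` with `shuffle=False`, and `test_hl_vae` reduces each batch to one score and one label (`label[0]`). That only gives per-video scores if the test list stores five consecutive entries per video. This is an assumption: it is consistent with the 800 test samples and 160 predictions in the logged output, but the notebook does not state it. Under that assumption the grouping can be sketched as follows, with `group_scores` as a hypothetical helper and made-up scores.

```python
def group_scores(sample_scores, sample_labels, group_size=5):
    """Average scores over consecutive groups and keep the first label of each group."""
    if len(sample_scores) % group_size != 0:
        raise ValueError("number of samples must be a multiple of group_size")
    scores, labels = [], []
    for start in range(0, len(sample_scores), group_size):
        chunk = sample_scores[start:start + group_size]
        scores.append(sum(chunk) / len(chunk))
        labels.append(sample_labels[start])
    return scores, labels

# Ten made-up sample scores: two videos with five entries each
scores, labels = group_scores([0.2] * 5 + [0.8] * 5, [0] * 5 + [1] * 5)
print(labels)  # [0, 1]
```

If the test list length were not a multiple of five, the last batch would mix fewer entries; the sketch raises an error in that case so the mismatch is noticed.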

In [76]:
args_vae = Args()
args_vae.feature_size = 128  # AUDIO feature size
args_vae.batch_size = 64
args_vae.modality = 'AUDIO'
args_vae.max_epoch = 200
args_vae.lr = 0.0005

vae_model = VAE(latent_dim=64, input_dim=args_vae.feature_size, seq_len=200)
vae_dir = "/mydrive/MyDrive/single_modality/vae_checkpoints/best_trained_vae_AUDIO.pkl" # CHANGE THIS
vae_model.load_state_dict(torch.load(vae_dir, map_location="cpu"))  # load on CPU, move to device below
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae_model = vae_model.to(device)
vae_model.eval()

HL_NET_LOSS_weight = 0.8
RECON_LOSS_weight = 0.2

# Initialize HL-Net
args = Args()
args.feature_size = 128 # AUDIO Feature size
args.max_epoch = 200
model = Model(args).to(device)

# Setup optimizer
approximator_param = list(map(id, model.approximator.parameters()))
approximator_param += list(map(id, model.conv1d_approximator.parameters()))
base_param = filter(lambda p: id(p) not in approximator_param, model.parameters())

optimizer = optim.Adam([
    {'params': base_param},
    {'params': model.approximator.parameters(), 'lr': args.lr / 2},
    {'params': model.conv1d_approximator.parameters(), 'lr': args.lr / 2},
], lr=args.lr, weight_decay=0.000)

scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1)
criterion = torch.nn.BCELoss()
is_topk = True

# Create datasets and dataloaders
train_loader, val_loader, test_loader = audcreate_data_loaders(args)

# Track losses
train_losses = []
val_losses = []
best_pr_auc = 0
best_epoch = 0

early_stopping = EarlyStopping(patience=7, verbose=True)
for epoch in range(args.max_epoch):
    print(f"\nEpoch {epoch+1}/{args.max_epoch}")
    st = time.time()

    model, train_loss, val_loss = train_hlnet_vae(
        train_loader=train_loader,
        hlnet=model,
        vae=vae_model,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        device=device,
        is_topk=is_topk,
        HLNET_LOSS_WEIGHT=HL_NET_LOSS_weight,
        RECON_LOSS_WEIGHT=RECON_LOSS_weight,
        val_loader=val_loader
    )
    train_losses.append(train_loss)

    if val_loss is not None:
        val_losses.append(val_loss)

    # Evaluate on the test loader after every epoch (offline and online PR-AUC)
    pr_auc, pr_auc_online = test_hl_vae(test_loader, model, device)

    if pr_auc > best_pr_auc:
        best_pr_auc = pr_auc
        best_epoch = epoch

    print(f'Epoch {epoch+1}/{args.max_epoch}:')
    print(f'Train Loss: {train_loss:.4f}')
    if val_loss is not None:
        print(f'Validation Loss: {val_loss:.4f}')
    print(f'PR-AUC (Offline/Online): {pr_auc:.4f}/{pr_auc_online:.4f}')
    print(f'Best PR-AUC: {best_pr_auc:.4f} (Epoch {best_epoch+1})')
    print(f'Epoch time: {time.time() - st:.2f}s')

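    # Early stopping check (skipped if no validation loss was computed this epoch)
    if val_loss is not None:
        early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        # Restore the best model state seen so far before leaving the loop
        model.load_state_dict(early_stopping.get_best_model_state())
        break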

    scheduler.step()

    if epoch % 5 == 0 and epoch > 0:
        save_dir = './hlnet_saves_audio'
        os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), f'{save_dir}/{args.model_name}{epoch}.pth')

save_dir = './hlnet_saves_audio/final'
os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), f'{save_dir}/audio_{args.model_name}{epoch}.pth')
plot_training_history(train_losses, val_losses, save_path='hlnet_vae_audio_training_history.png')


Creating data loaders...
Train loader created with 3163 samples
Validation loader created with 791 samples
Test loader created with 800 samples

Epoch 1/200


Starting test...
Input shape: torch.Size([5, 200, 128])
Label shape: torch.Size([5])
Logits shape: torch.Size([5, 200, 1])
Sig shape before mean: torch.Size([5, 200])
Final shapes - GT: (160,), Pred: (160,), Pred2: (160,)
Epoch 1/200:
Train Loss: 2.3285
Validation Loss: 1.6906
PR-AUC (Offline/Online): 0.8730/0.7986
Best PR-AUC: 0.8730 (Epoch 1)
Epoch time: 15.30s

Epoch 2/200
Starting test...
Input shape: torch.Size([5, 200, 128])
Label shape: torch.Size([5])
Logits shape: torch.Size([5, 200, 1])
Sig shape before mean: torch.Size([5, 200])
Final shapes - GT: (160,), Pred: (160,), Pred2: (160,)
Epoch 2/200:
Train Loss: 1.4376
Validation Loss: 1.6638
PR-AUC (Offline/Online): 0.8920/0.5679
Best PR-AUC: 0.8920 (Epoch 2)
Epoch time: 10.14s

Epoch 3/200
Starting test...
Input shape: torch.Size([5, 200, 128])
Label shape: torch.Size([5])
Logits shape: torch.Size([5, 200, 1])
Sig shape before mean: torch.Size([5, 200])
Final shapes - GT: (160,), Pred: (160,), Pred2: (160,)
Epoch 3/200:
Train L