In [None]:
import os
import sys
import torch
import torch.backends.cudnn as cudnn
from os import path, mkdir
import logging
from torch.utils.tensorboard import SummaryWriter

# Add the parent directory to the import search path (once), presumably so the
# project-local modules imported afterwards can be resolved from this notebook.
module_path = os.path.abspath("..")
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

from feature_extractor import FeaturesWriter, get_features_loader
from utils.utils import register_logger, get_torch_device
from utils.load_model import load_feature_extractor
from features_loader import FeaturesLoader
from network.TorchUtils import TorchModel
from network.anomaly_detector_model import (
    AnomalyDetector,
    custom_objective,
    RegularizedLoss,
)
from utils.callbacks import DefaultModelCallback, TensorBoardCallback

# Definitions

## Global definitions

In [None]:
# Notebook-wide settings.
num_workers = 32  # number of workers used for loading the videos
log_file = None  # logging file passed to register_logger
log_every = 50  # the writing of clips is logged every n steps

# Enable the cuDNN benchmark flag, then configure logging.
cudnn.benchmark = True
register_logger(log_file=log_file)

# Per the original note: GPU if available, CPU otherwise.
device = get_torch_device()

## Definitions of feature extraction

In [None]:
# Locations (fill these in before running the extraction cells).
dataset_path = ""  # path to the video dataset
features_dir = ""  # directory for the features

# Clip sampling.
clip_length = 16  # length of each input sample
frame_interval = 1  # sampling interval between frames

# Feature extractor.
model_type = "c3d"
pretrained_3d = ""  # path of the 3d feature extractor
batch_size = 4

## Definitions of training

In [None]:
annotation_path = ""  # path to the train annotation
exps_dir = ""  # directory where models and tensorboard logs are saved
feature_dim = 4096
save_every = 1  # epochs interval for saving the model checkpoints
lr_base = 0.01  # learning rate
iterations_per_epoch = 20000  # number of training iterations
epochs = 2  # number of training epochs


models_dir = path.join(exps_dir, "models")
tb_dir = path.join(exps_dir, "tensorboard")

# os.makedirs creates missing parent directories, so creating the two
# sub-directories also creates exps_dir itself. A separate
# os.makedirs(exps_dir) is redundant and raises FileNotFoundError while
# exps_dir is still the empty placeholder.
os.makedirs(models_dir, exist_ok=True)
os.makedirs(tb_dir, exist_ok=True)

# Feature Extraction

## Create model and dataset

In [None]:
# Build the clip loader and its iterator from the settings defined in the
# definition cells above. The notebook has no `args` namespace, so the plain
# variables are used directly.
data_loader, data_iter = get_features_loader(
    dataset_path, clip_length, frame_interval, batch_size, num_workers, model_type
)

# Load the feature extractor onto `device` and switch it to eval mode.
network = load_feature_extractor(model_type, pretrained_3d, device).eval()

# The writer is told how many videos the loader reports.
features_writer = FeaturesWriter(num_videos=data_loader.video_count)

In [None]:
# Create the features directory. os.makedirs also creates any missing parent
# directories, and exist_ok=True makes the cell safe to re-run without a
# separate existence check.
os.makedirs(features_dir, exist_ok=True)

In [None]:
# Run every batch of clips through the feature extractor and hand each clip's
# feature to the writer. `loop_i` is a counter modulo `log_every`, so a progress
# line is logged once every `log_every` written clips.
loop_i = 0
with torch.no_grad():
    for data, clip_idxs, dirs, vid_names in data_iter:
        outputs = network(data.to(device)).detach().cpu().numpy()

        for i, (sub_dir, vid_name, clip_idx) in enumerate(zip(dirs, vid_names, clip_idxs)):
            if loop_i == 0:
                logging.info(
                    f"Video {features_writer.dump_count} / {features_writer.num_videos} : Writing clip {clip_idx} of video {vid_name}"
                )

            # The notebook has no `args` namespace; use the `log_every` setting.
            loop_i = (loop_i + 1) % log_every

            # Write under `features_dir`, the same directory the training cell
            # later passes to FeaturesLoader. `sub_dir` avoids shadowing the
            # builtin `dir`.
            features_writer.write(
                feature=outputs[i],
                video_name=vid_name,
                idx=clip_idx,
                dir=path.join(features_dir, sub_dir),
            )

# Final dump call on the writer after the last batch.
features_writer.dump()

# Train the Anomaly Detection Model Using the Extracted Features

## Create model, dataset, optimizer and loss function

In [None]:
# Loader over the extracted features and the training annotation.
train_loader = FeaturesLoader(
    features_path=features_dir,
    annotation_path=annotation_path,
    iterations=iterations_per_epoch,
)

# The notebook has no `args` namespace; the hyper-parameters are the plain
# variables set in the training definitions cell.
network = AnomalyDetector(feature_dim)
model = TorchModel(network).to(device).train()

# Callbacks
model.register_callback(DefaultModelCallback(visualization_dir=exps_dir))
model.register_callback(TensorBoardCallback(tb_writer=SummaryWriter(log_dir=tb_dir)))

# Training parameters
# In the original paper:
#     lr = 0.01
#     epsilon = 1e-8
optimizer = torch.optim.Adadelta(model.parameters(), lr=lr_base, eps=1e-8)

criterion = RegularizedLoss(network, custom_objective).to(device)

## Train the model

In [None]:
# Train for `epochs` epochs over the features loader, passing `models_dir` as
# the checkpoint path base and `save_every` as the saving interval.
model.fit(
    train_iter=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    epochs=epochs,
    save_every=save_every,
    network_model_path_base=models_dir,
)