# Driver State Anomaly Detection With Temporal Autoencoders

[https://dagshub.com/matejfric/driver-state](https://dagshub.com/matejfric/driver-state)

TODO: set up a pre-commit hook for [jupytext](https://github.com/mwouts/jupytext) and Ruff.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import datetime
import os
import re
from collections.abc import Mapping
from pathlib import Path
from pprint import pprint
from typing import Literal

import dagshub
import matplotlib.pyplot as plt
import mlflow
import mlflow.pytorch
import numpy as np
import pytorch_lightning as L
import torch
from pytorch_lightning.callbacks import TQDMProgressBar

# Unlike Keras, the PyTorch Lightning EarlyStopping callback does not restore the best weights!
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# https://github.com/Lightning-AI/pytorch-lightning/discussions/10399,
# https://pytorch-lightning.readthedocs.io/en/1.5.10/extensions/generated/pytorch_lightning.callbacks.ModelCheckpoint.html
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from torch.utils.data import DataLoader, random_split

from model import (
    AutoencoderModel,
)
from model.ae import (
    LSTMDecoder,
    LSTMEncoder,
    TemporalAutoencoderModel,
    summarize_model,
)
from model.common import Anomalies, BatchSizeDict
from model.dataset import TemporalAutoencoderDataset
from model.git import get_commit_id, get_current_branch
from model.plot import (
    plot_learning_curves,
    show_examples,  # noqa F401 TODO
    show_random,  # noqa F401 TODO
)

## Configuration

In [None]:
# Global matplotlib look for every figure produced by this notebook.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({'font.size': 14})

# Experiment logging
# NOTE(review): `mlflow=True` presumably points the MLflow tracking used in the
# "Logging" section at this DagsHub repository - confirm against the dagshub docs.
REPO_NAME = 'driver-state'
USER_NAME = 'matejfric'
dagshub.init(REPO_NAME, USER_NAME, mlflow=True)  # type: ignore

# Reproducibility
# https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility
# Seeding happens here, before any model, random tensor or dataset split is created.
SEED = 42
L.seed_everything(SEED, workers=True)

# Record the library versions and CUDA availability in the cell output.
print(
    f'torch: {torch.__version__}, cuda: {torch.cuda.is_available()}, lightning: {L.__version__}'  # type: ignore
)

In [7]:
# HYPERPARAMETERS
# ----------------------------------------
MAX_EPOCHS = 10
MONITOR = 'valid_loss'  # metric watched by early stopping and checkpointing
PATIENCE = 2  # early-stopping patience
IMAGE_SIZE = 256  # 224
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
LOSS_FUNCTION = 'mse'  # 'mae'
TRAIN_NOISE_STD_INPUT = 0.0  # 0.025
TRAIN_NOISE_STD_LATENT = 0.0  # 0.025
TIME_STEPS = 4  # length of the temporal window fed to the LSTM encoder/decoder
TRAIN_SET_RATIO = 0.9  # share of the normal data used for training; the rest is validation

# LOGGING
# ----------------------------------------
LOG_DIR = Path('logs')
EXPERIMENT_NAME = f'{datetime.datetime.now().strftime("%Y-%m-%d-%H%M%S")}-tae'
VERSION = 0
# Same directory layout the CSV logger is configured with (LOG_DIR / name / version_N).
EXPERIMENT_DIR = LOG_DIR / EXPERIMENT_NAME / f'version_{VERSION}'
DATASET_NAME = '2024-10-28-driver-all-frames/2021_08_31_geordi_enyaq'

MLFLOW_ARTIFACT_DIR = 'outputs'
METRICS_CSV_NAME = 'metrics.csv'
LEARNING_CURVES_PDF_NAME = 'learning_curves.pdf'
PREDICTIONS_PNG_NAME = 'predictions.png'
TRAIN_TRANSFORMS_JSON_NAME = 'train_transforms.json'
NOTEBOOK_NAME = 'anomaly_detection.ipynb'

# DATASET
# ----------------------------------------
# `Path.home()` is a classmethod; no throwaway `Path()` instance is needed.
DATASET_DIR = Path.home() / f'source/driver-dataset/{DATASET_NAME}'

NORMAL_MEMORY_MAP = DATASET_DIR / 'normal' / 'memory_maps' / 'depth_256.dat'
ANOMAL_MEMORY_MAP = DATASET_DIR / 'anomal' / 'memory_maps' / 'depth_256.dat'
ANOMALIES_FILE = DATASET_DIR / 'anomal' / 'labels.txt'

# Fail fast when a required input is missing. An explicit exception is used
# instead of `assert`, which is stripped when Python runs with `-O`.
for _description, _required_path in (
    ('Normal memory map', NORMAL_MEMORY_MAP),
    ('Anomal memory map', ANOMAL_MEMORY_MAP),
    ('Anomalies file', ANOMALIES_FILE),
):
    if not _required_path.exists():
        raise FileNotFoundError(f'{_description} does not exist: {_required_path}')

In [None]:
# Simple test case for forward pass
# ---------------------------------
# Builds the encoder/decoder pair (the same objects are passed to the trained
# model later in the notebook) and checks that a random batch is reconstructed
# with the input's shape.

encoder = LSTMEncoder(n_time_steps=TIME_STEPS)
decoder = LSTMDecoder(n_time_steps=TIME_STEPS, n_image_channels=1)

# Test input tensor of size (batch_size, time_steps, channels, height, width)
x = torch.randn(BATCH_SIZE, TIME_STEPS, 1, IMAGE_SIZE, IMAGE_SIZE)

# Forward pass through the encoder and decoder.
# Only the shapes are inspected, so autograd is disabled: otherwise the graph
# and activations of a full batch stay alive through `encoded`/`decoded`.
with torch.no_grad():
    encoded = encoder(x)
    decoded = decoder(encoded)

# Check the shapes
print(f'Input shape: {x.shape}')
print(f'Latent shape: {encoded.shape}')
print(f'Decoded shape: {decoded.shape}')

assert x.shape == decoded.shape, 'Input and output shapes do not match!'

In [None]:
# Summarize the convolutional encoder with torchsummary for a single-channel
# IMAGE_SIZE x IMAGE_SIZE input, evaluated on the CPU.
from torchsummary import summary

from model.ae import ConvEncoder

summary(ConvEncoder(), (1, IMAGE_SIZE, IMAGE_SIZE), device='cpu')

In [None]:
# Print the project's own summary (model.ae.summarize_model) of a freshly built ConvEncoder.
print(summarize_model(ConvEncoder()))

In [None]:
# Print the summary of the LSTM encoder/decoder pair that is trained below.
print(summarize_model([encoder, decoder]))

## Loaders

In [16]:
# The same batch size is used for every stage.
batch_size_dict = BatchSizeDict(
    {'train': BATCH_SIZE, 'valid': BATCH_SIZE, 'test': BATCH_SIZE}
)

# `os.cpu_count()` returns None when the CPU count cannot be determined;
# fall back to a single worker instead of crashing on `int(None)`.
NUM_WORKERS = os.cpu_count() or 1

# Training/validation data come from the "normal" recording,
# the test data from the recording that contains anomalies.
train_val_dataset = TemporalAutoencoderDataset(
    memory_map_file=NORMAL_MEMORY_MAP,
    memory_map_image_shape=(IMAGE_SIZE, IMAGE_SIZE),
)
test_dataset = TemporalAutoencoderDataset(
    memory_map_file=ANOMAL_MEMORY_MAP,
    memory_map_image_shape=(IMAGE_SIZE, IMAGE_SIZE),
)

# Train validation split
train_size = int(TRAIN_SET_RATIO * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size
train_dataset, valid_dataset = random_split(train_val_dataset, [train_size, val_size])

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size_dict['train'],
    shuffle=True,
    num_workers=NUM_WORKERS,
    drop_last=True,
)

# Evaluation loaders keep the dataset order and every sample.
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=batch_size_dict['valid'],
    shuffle=False,
    num_workers=NUM_WORKERS,
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size_dict['test'],
    shuffle=False,
    num_workers=NUM_WORKERS,
)

## Training

In [12]:
# Lightning module that is trained below. It wraps the encoder/decoder pair
# built in the forward-pass test cell and receives the training hyperparameters
# from the configuration cell.
model = TemporalAutoencoderModel(
    encoder=encoder,
    decoder=decoder,
    batch_size_dict=batch_size_dict,
    learning_rate=LEARNING_RATE,
    loss_function=LOSS_FUNCTION,
    train_noise_std_input=TRAIN_NOISE_STD_INPUT,
    train_noise_std_latent=TRAIN_NOISE_STD_LATENT,
)

In [13]:
# Metrics are written below LOG_DIR / EXPERIMENT_NAME / version_{VERSION},
# i.e. into EXPERIMENT_DIR, which is also where the checkpoint is stored.
csv_logger = CSVLogger(LOG_DIR, name=EXPERIMENT_NAME, version=VERSION)

# Stop once the monitored metric has not improved for PATIENCE validation checks.
early_stopping = EarlyStopping(
    monitor=MONITOR,
    mode='min',
    patience=PATIENCE,
)

model_checkpoint = ModelCheckpoint(
    dirpath=EXPERIMENT_DIR,
    # Expands to '{epoch}-{valid_loss:.3f}'. The template must name the metric
    # that is actually logged (MONITOR); the previous 'val_loss' key did not
    # match it, and ':3f' set a field width rather than three decimals.
    filename=f'{{epoch}}-{{{MONITOR}:.3f}}',
    monitor=MONITOR,
    save_top_k=1,  # save only the best model
    mode='min',
)
progress_bar = TQDMProgressBar(refresh_rate=10)

In [None]:
# Trainer for the training run: writes to the CSV logger and uses the
# checkpoint, early-stopping and progress-bar callbacks defined above.
trainer = L.Trainer(
    accelerator='gpu',  # NOTE(review): hard-coded; presumably fails on a CPU-only machine
    logger=csv_logger,
    callbacks=[model_checkpoint, early_stopping, progress_bar],
    max_epochs=MAX_EPOCHS,
    log_every_n_steps=1,  # log every batch
    # https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility
    deterministic=True,
)

In [15]:
# Allow lower-precision float32 matrix multiplications before training starts;
# the linked documentation describes the speed/precision trade-off.
# https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
torch.set_float32_matmul_precision('medium')

In [None]:
# Train on the training split and validate on the held-out part of the normal data.
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
)

## Validation

In [17]:
# Set to True to evaluate a model from the MLflow registry instead of the
# checkpoint written by this run.
LOAD_FROM_MLFLOW = False

if LOAD_FROM_MLFLOW:
    # Load from MLflow
    model_name = 'pytorch-2024-10-14-220152-anomaly-detection-conv-ae'
    model_version = 1
    model_uri = f'models:/{model_name}/{model_version}'
    model_ = mlflow.pytorch.load_model(model_uri)
else:
    # The checkpoint callback keeps a single best checkpoint in EXPERIMENT_DIR.
    model_checkpoint_path = next(EXPERIMENT_DIR.glob('*.ckpt'), None)
    if model_checkpoint_path is None:
        raise FileNotFoundError(f'No checkpoint (*.ckpt) found in {EXPERIMENT_DIR}')
    # Restore with the class that was trained (TemporalAutoencoderModel),
    # not the non-temporal AutoencoderModel.
    model_ = TemporalAutoencoderModel.load_from_checkpoint(
        model_checkpoint_path, encoder=encoder, decoder=decoder
    )

In [None]:
# Separate, logger-free trainer used only by the evaluation cells below.
trainer_ = L.Trainer(logger=False)  # no need to log anything for validation and testing

In [None]:
# Evaluate the loaded model on the validation split. `validate` returns one
# metrics dict per dataloader; a single loader is passed, hence index 0.
valid_metrics = trainer_.validate(model_, dataloaders=valid_dataloader, verbose=False)[
    0
]
pprint(valid_metrics)

In [None]:
# Evaluate the loaded model on the test loader, which is built from the
# recording with anomalies; index 0 again selects the single dataloader's metrics.
test_metrics = trainer_.test(model_, dataloaders=test_dataloader, verbose=False)[0]
pprint(test_metrics)

In [36]:
# Load the anomaly annotations and count the test frames as the number of
# JPEG files in the `images` directory next to the labels file.
anomalies = Anomalies.from_file(ANOMALIES_FILE)
n_test_frames = len(list((ANOMALIES_FILE.parent / 'images').glob('*.jpg')))
# NOTE(review): presumably a per-frame ground-truth label sequence of length
# n_test_frames - confirm against Anomalies.to_ground_truth.
y_true = anomalies.to_ground_truth(n_test_frames)

In [29]:
# Fetch the dataset item around the middle of the first annotated anomaly.
# NOTE(review): dividing by TIME_STEPS assumes one dataset item covers TIME_STEPS
# consecutive frames - TODO confirm against TemporalAutoencoderDataset.
temporal_slice = test_dataset[anomalies[0].middle() // TIME_STEPS]['image']

In [21]:
# TODO:
# plot_autoencoder_reconstruction(
#     model_,
#     test_dataloader,
#     dataset_path=DATASET_DIR,
#     save_path=EXPERIMENT_DIR / PREDICTIONS_PNG_NAME,
#     limit=15,
#     random_shuffle=False,
# )

In [None]:
# Plot the learning curves from this run's metrics CSV and save the figure as
# a PDF in the experiment directory (it is uploaded to MLflow in the "Logging" section).
plot_learning_curves(
    EXPERIMENT_DIR / METRICS_CSV_NAME,
    save_path=EXPERIMENT_DIR / LEARNING_CURVES_PDF_NAME,
    # NOTE(review): presumably metric identifier -> display label; confirm in model.plot.
    metrics={
        'mse': 'Mean Squared Error',
        'fro': 'Frobenius Norm',
        'mae': 'Mean Absolute Error',
    },
    figsize=(22, 5),
)

## Logging

In [23]:
def get_early_stopping_epoch() -> int | None:
    checkpoint = list(EXPERIMENT_DIR.glob('*.ckpt'))[0].stem
    pattern = r'epoch=(\d+)'
    match = re.search(pattern, checkpoint)
    if match:
        return int(match.group(1))
    else:
        return None

In [24]:
def log_dict_to_mlflow(
    dictionary: Mapping[str, int | float], type: Literal['metric', 'param']
) -> None:
    for k, v in dictionary.items():
        if type == 'metric':
            mlflow.log_metric(k, v)
        elif type == 'param':
            mlflow.log_param(k, v)

In [25]:
def get_submodule_param_count(model: L.LightningModule) -> dict[str, int]:
    """Count the parameters of each direct child module of ``model``.

    Children without parameters are left out. Each remaining child is reported
    under ``'<child name>_parameters'``, and the sum of the reported counts is
    stored under ``'total_parameters'``.
    """
    counts = {
        f'{child_name}_parameters': size
        for child_name, child in model.named_children()
        if (size := sum(p.numel() for p in child.parameters())) > 0
    }
    counts['total_parameters'] = sum(counts.values())
    return counts

In [None]:
# Log the whole experiment (tags, metrics, parameters, artifacts and the
# model itself) as a single MLflow run.
with mlflow.start_run(run_name=EXPERIMENT_NAME) as run:
    # The dataset tag does not depend on git, so it is set unconditionally.
    mlflow.set_tag('Dataset', DATASET_NAME)
    try:
        # Best effort: the git lookups may fail, and that must not abort the run.
        mlflow.set_tag('Branch', get_current_branch())
        mlflow.set_tag('Commit ID', get_commit_id())
    except Exception as e:
        print(e)

    log_dict_to_mlflow(valid_metrics, 'metric')
    log_dict_to_mlflow(test_metrics, 'metric')
    log_dict_to_mlflow(get_submodule_param_count(model_), 'param')

    mlflow.log_param('encoder', str(encoder))
    mlflow.log_param('decoder', str(decoder))
    mlflow.log_param('batch_size', BATCH_SIZE)
    mlflow.log_param('max_epochs', MAX_EPOCHS)
    mlflow.log_param('early_stopping', get_early_stopping_epoch())
    mlflow.log_param('monitor', MONITOR)
    mlflow.log_param('patience', PATIENCE)
    mlflow.log_param('image_size', IMAGE_SIZE)
    mlflow.log_param('learning_rate', LEARNING_RATE)
    mlflow.log_param('loss_function', LOSS_FUNCTION)
    mlflow.log_param('train_noise_std_input', TRAIN_NOISE_STD_INPUT)
    mlflow.log_param('train_noise_std_latent', TRAIN_NOISE_STD_LATENT)
    mlflow.log_param('seed', SEED)
    mlflow.log_param('time_steps', TIME_STEPS)
    mlflow.log_param('train_set_ratio', TRAIN_SET_RATIO)

    mlflow.log_artifact(str(EXPERIMENT_DIR / METRICS_CSV_NAME), MLFLOW_ARTIFACT_DIR)
    mlflow.log_artifact(
        str(EXPERIMENT_DIR / LEARNING_CURVES_PDF_NAME), MLFLOW_ARTIFACT_DIR
    )
    # TODO: mlflow.log_artifact(str(EXPERIMENT_DIR / PREDICTIONS_PNG_NAME), MLFLOW_ARTIFACT_DIR)
    mlflow.log_artifact(NOTEBOOK_NAME, MLFLOW_ARTIFACT_DIR)

    # The signature example mirrors what the temporal model consumes and returns:
    # (batch, time_steps, channels, height, width) in float32. The reconstruction
    # has the same shape as the input, so the same array serves as output example.
    input_example = np.random.random(
        (BATCH_SIZE, TIME_STEPS, 1, IMAGE_SIZE, IMAGE_SIZE)
    ).astype(np.float32)
    signature = mlflow.models.infer_signature(  # type: ignore
        input_example, input_example, {'training': False}
    )

    # Models are versioned by default
    mlflow.pytorch.log_model(
        pytorch_model=model_,
        artifact_path='model',
        registered_model_name=f'pytorch-{EXPERIMENT_NAME}',
        signature=signature,
    )