# Driver State Anomaly Detection With Temporal Autoencoders

[https://dagshub.com/matejfric/driver-state](https://dagshub.com/matejfric/driver-state)

TODO: set up a pre-commit hook for [Jupytext](https://github.com/mwouts/jupytext) and Ruff.

In [None]:
# Enable IPython's autoreload extension (mode 2) so edits to the project
# modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import datetime
import json
import os
import sys
from pathlib import Path
from typing import Literal, cast

# Make the directory two levels above the working directory importable.
# Presumably this is the repository root that contains the local `model`
# package imported below — confirm if the notebook is moved.
repo_root = str(Path.cwd().parent.parent)
if repo_root not in sys.path:
    sys.path.append(repo_root)

import dagshub
import matplotlib.pyplot as plt
import mlflow
import mlflow.pytorch
import numpy as np
import pytorch_lightning as L
import torch
import torchview
from mlflow.models.signature import infer_signature
from pytorch_lightning.callbacks import TQDMProgressBar

# Pytorch Lightning EarlyStopping callback does not recover the best weights as in Keras!
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# https://github.com/Lightning-AI/pytorch-lightning/discussions/10399,
# https://pytorch-lightning.readthedocs.io/en/1.5.10/extensions/generated/pytorch_lightning.callbacks.ModelCheckpoint.html
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from torch.utils.data import DataLoader, random_split

from model import dmd
from model.ae import (  # noqa F401
    LSTMDecoder,
    LSTMEncoder,
    TemporalAutoencoderModel,
    evaluate_model_parallel,
    summarize_model,
)
from model.ae.iscv2023 import (  # noqa F401
    EfficientNetEncoder,
    ISVC23DecoderV1,
    ISVC23DecoderV2,
    ISVC23DecoderV3,
    ISVC23DecoderV4,
    ISVC23EncoderV1,
    ISVC23EncoderV2,
    ISVC23EncoderV3,
    ISVC23EncoderV4,
)
from model.ae.temporal_3d import (
    Conv3dDecoder,
    Conv3dEncoder,
)
from model.common import Anomalies, BatchSizeDict
from model.dataset import TemporalAutoencoderDataset, TemporalAutoencoderDatasetDMD
from model.eval import compute_best_roc_auc
from model.fonts import set_cmu_serif_font
from model.git import get_commit_id, get_current_branch
from model.logging import (
    get_early_stopping_epoch,
    get_experiment_id,
    get_submodule_param_count,
    log_dict_to_mlflow,
)
from model.plot import (
    plot_error_and_anomalies,
    plot_learning_curves,
    plot_pr_chart,
    plot_roc_chart,
    plot_temporal_autoencoder_reconstruction,
    show_examples,  # noqa F401 TODO
    show_random,  # noqa F401 TODO
)

## Configuration

In [None]:
# Figure styling: whitegrid theme, large base font, CMU Serif typeface.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['font.size'] = 24
set_cmu_serif_font()

# Experiment tracking: initialise the DagsHub repository with MLflow enabled.
REPO_NAME = 'driver-tae'
USER_NAME = 'matejfric'
dagshub.init(REPO_NAME, USER_NAME, mlflow=True)  # type: ignore

# Seed every RNG (including dataloader workers) for reproducible runs, see
# https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility
SEED = 42
L.seed_everything(SEED, workers=True)

# Report the library versions and GPU availability of this session.
print(
    f'torch: {torch.__version__}, cuda: {torch.cuda.is_available()}, lightning: {L.__version__}'  # type: ignore
)

In [None]:
# Notebook parameters. Each one defaults to None, so the configuration below
# falls back to its hard-coded default (`value or default`). Presumably this
# cell is tagged `parameters` for papermill-style injection — confirm in the
# notebook metadata before changing its layout.
driver = None
source_type = None
image_size = None
latent_dim = None
dataset = None
batch_size = None

In [None]:
# HYPERPARAMETERS
# ----------------------------------------
# Training schedule and early stopping.
MAX_EPOCHS = 100
MIN_EPOCHS = 15
MONITOR = 'valid_loss'
PATIENCE = 10

# Input geometry. Only sizes with a pre-built memory map are usable; run the
# memory map script (`run_memory_map_conversion.py`) for a different size.
IMAGE_SIZE: Literal[64, 128, 224, 256] = image_size or 64
BATCH_SIZE = batch_size or 128
SEQUENCE_LENGTH = 2
TIME_STEP = 1

# Optimisation.
LEARNING_RATE = 0.0005  # alternative tried: 1e-4
LOSS_FUNCTION = 'mse'  # alternative: 'mae'
TRAIN_SET_RATIO = 0.9  # the rest of the normal data is used for validation
USE_MASK = True

# Model and input source.
MODEL_NAME: Literal['tae', 'tae3d'] = 'tae'
LATENT_DIM = latent_dim or 2 * IMAGE_SIZE
SOURCE_TYPE: Literal[
    'depth',
    'video_depth_anything',
    'depth_realsense',
    'masks',
    'images',  # alias for 'rgb'
    'rgb',
    'rgbd',
    'rgbdm',
] = source_type or 'depth'

# Input channels per source type; every source not listed here is single-channel.
_CHANNELS_BY_SOURCE_TYPE = {
    'images': 3,
    'rgb': 3,
    'rgbd': 4,
    'rgb_source_depth': 4,
    'rgbdm': 5,
}
CHANNELS = _CHANNELS_BY_SOURCE_TYPE.get(SOURCE_TYPE, 1)

# LOGGING
# ----------------------------------------
# Short driver names -> MRL recording directory names (used by the dataset section).
DRIVER_MAP = {
    'geordi': '2021_08_31_geordi_enyaq',
    'poli': '2021_09_06_poli_enyaq',
    'michal': '2021_11_05_michal_enyaq',
    'dans': '2021_11_18_dans_enyaq',
    'jakub': '2021_11_18_jakubh_enyaq',
    'radovan': '2024_07_02_radovan_enyaq',
}
NOW = datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')
DRIVER = driver or 'geordi'

# One MLflow experiment per driver; non-string driver ids get a 'driver_' prefix.
if isinstance(DRIVER, str):
    EXPERIMENT_NAME = DRIVER
else:
    EXPERIMENT_NAME = f'driver_{DRIVER}'

# Local outputs of this run: logs/<timestamp>-<run name>/version_<VERSION>.
LOG_DIR = Path('logs')
RUN_NAME = (
    f'{MODEL_NAME}-{DRIVER}-{SOURCE_TYPE}-{IMAGE_SIZE}x{IMAGE_SIZE}-latent{LATENT_DIM}'
)
VERSION = 0
EXPERIMENT_DIR = LOG_DIR.joinpath(f'{NOW}-{RUN_NAME}', f'version_{VERSION}')

# Artifact file names (mostly written to EXPERIMENT_DIR and uploaded to MLflow
# under MLFLOW_ARTIFACT_DIR).
MLFLOW_ARTIFACT_DIR = 'outputs'
METRICS_CSV_NAME = 'metrics.csv'
LEARNING_CURVES_PDF_NAME = 'learning_curves.pdf'
PREDICTIONS_NAME = 'predictions.pdf'
PREDICTIONS_JSON_NAME = 'predictions.json'
TRAIN_TRANSFORMS_JSON_NAME = 'train_transforms.json'
NOTEBOOK_NAME = 'train.ipynb'
ARCHITECTURE_VISUALIZATION_NAME = 'architecture'
MODEL_SUMMARY_NAME = 'model_summary.txt'
ROC_CHART_NAME = 'roc_chart.pdf'
ERROR_CHART_NAME = 'error_chart.pdf'
PR_CHART_NAME = 'pr_chart.pdf'
TEST_SESSION = None  # overwritten by the MRL branch of the dataset section
# MODEL_ONNX_NAME = 'model.onnx'
# MODEL_ONNX_NAME = 'model.onnx'

# DATASET
# ----------------------------------------
# Resolve the input files of the selected dataset and check that they exist.
#   'mrl' -> NORMAL_MEMORY_MAPS (train/validation), ANOMAL_MEMORY_MAPS (test),
#            ANOMALIES_FILES (one labels.txt per driver) and TEST_SESSION.
#   'dmd' -> TRAIN_SESSIONS / TEST_SESSIONS, TRAIN_DATASETS / TEST_DATASETS and
#            ANOMALIES_FILES (one JSON per test session).
DATASET: Literal['mrl', 'dmd'] = dataset or 'mrl'
if DATASET not in ('mrl', 'dmd'):
    # Fail here instead of with a NameError several cells later.
    raise ValueError(f"Unknown dataset {DATASET!r}; expected 'mrl' or 'dmd'.")
DATASET_NAME = '2024-10-28-driver-all-frames' if DATASET == 'mrl' else 'dmd'
DATASET_DIR = Path.home() / f'source/driver-dataset/{DATASET_NAME}'
assert DATASET_DIR.exists(), f'Dataset directory does not exist: {DATASET_DIR}'

if DATASET == 'mrl':
    if DRIVER == 'all':
        # Every driver in DRIVER_MAP except 'radovan' (presumably deliberate — confirm).
        driver_dirs = [
            DRIVER_MAP[name] for name in ['geordi', 'poli', 'michal', 'dans', 'jakub']
        ]
    else:
        driver_dirs = [DRIVER_MAP[DRIVER]]

    TEST_SESSION = 'anomal'  # ~, 181149, 182201
    memory_map_filename = (
        f'{SOURCE_TYPE}_{IMAGE_SIZE}{"" if USE_MASK else "_no_mask"}.dat'
    )

    NORMAL_MEMORY_MAPS = []
    ANOMAL_MEMORY_MAPS = []
    ANOMALIES_FILES = []
    for driver_dir in driver_dirs:
        root_dir = DATASET_DIR / driver_dir
        normal_mem_map = root_dir / 'normal' / 'memory_maps' / memory_map_filename
        anomal_mem_map = root_dir / TEST_SESSION / 'memory_maps' / memory_map_filename
        anomalies_file = root_dir / TEST_SESSION / 'labels.txt'

        assert normal_mem_map.exists(), (
            f'Normal memory map does not exist: {normal_mem_map}'
        )
        assert anomal_mem_map.exists(), (
            f'Anomal memory map does not exist: {anomal_mem_map}'
        )
        assert anomalies_file.exists(), (
            f'Anomalies file does not exist: {anomalies_file}'
        )
        NORMAL_MEMORY_MAPS.append(normal_mem_map)
        ANOMAL_MEMORY_MAPS.append(anomal_mem_map)
        ANOMALIES_FILES.append(anomalies_file)

elif DATASET == 'dmd':
    TRAIN_SESSIONS = sorted(dmd.DRIVER_SESSION_MAPPING[DRIVER])

    # Hold out session 's1' for testing and train on the remaining sessions.
    # NOTE(review): `'s1' in x` is a substring test, so any session name that
    # merely contains 's1' (e.g. 's10') is held out too — confirm the naming scheme.
    TEST_SESSIONS = [x for x in TRAIN_SESSIONS if 's1' in x]
    TRAIN_SESSIONS = [x for x in TRAIN_SESSIONS if 's1' not in x]
    # `all(...)` over an empty list is True, so reject empty splits explicitly.
    assert TRAIN_SESSIONS, f'No training sessions for driver {DRIVER}!'
    assert TEST_SESSIONS, f"No 's1' test session for driver {DRIVER}!"

    TRAIN_DATASETS = sorted(
        DATASET_DIR / session / 'normal' for session in TRAIN_SESSIONS
    )
    _missing = [str(path) for path in TRAIN_DATASETS if not path.exists()]
    assert not _missing, f'Training datasets do not exist: {_missing}'

    TEST_DATASETS = sorted(
        DATASET_DIR / session / 'memory_maps' for session in TEST_SESSIONS
    )
    # Require directories: skipping one silently would misalign ANOMALIES_FILES
    # with the test videos they are matched against during evaluation.
    _missing = [str(path) for path in TEST_DATASETS if not path.is_dir()]
    assert not _missing, f'Test datasets do not exist: {_missing}'

    # One annotation JSON per test session, in the same (sorted) order as
    # TEST_DATASETS so the labels line up with the test videos.
    ANOMALIES_FILES = [
        path.parent / f'{path.parent.name}.json' for path in TEST_DATASETS
    ]
    _missing = [str(path) for path in ANOMALIES_FILES if not path.exists()]
    assert not _missing, f'Anomalies file does not exist: {_missing}'

# The 3D-convolution model only accepts windows of at least 16 frames.
assert MODEL_NAME != 'tae3d' or SEQUENCE_LENGTH >= 16, (
    'Number of time steps must be at least 16 for the 3D convolution model'
)

### LSTM

In [None]:
# Simple test case for forward pass, also used as model signature in Mlflow.

if MODEL_NAME == 'tae':
    # encoder = LSTMEncoder(n_time_steps=SEQUENCE_LENGTH, bidirectional=True)
    # decoder = LSTMDecoder(
    #     n_time_steps=SEQUENCE_LENGTH,
    #     n_image_channels=1,
    #     image_size=IMAGE_SIZE,
    #     bidirectional=True,
    # )

    # encoder = EfficientNetEncoder(
    #     n_time_steps=SEQUENCE_LENGTH,
    #     bidirectional=True,
    #     image_size=IMAGE_SIZE,
    #     latent_dim=LATENT_DIM,
    # )

    # NOTE(review): unlike the decoder, the encoder is not given the channel
    # count (CHANNELS); confirm it handles multi-channel sources such as 'rgbd'.
    encoder = ISVC23EncoderV1(
        n_time_steps=SEQUENCE_LENGTH,
        bidirectional=True,
        image_size=IMAGE_SIZE,
        latent_dim=LATENT_DIM,
    )

    decoder = ISVC23DecoderV1(
        n_time_steps=SEQUENCE_LENGTH,
        bidirectional=True,
        image_size=IMAGE_SIZE,
        latent_dim=LATENT_DIM,
        n_image_channels=CHANNELS,
    )

    # encoder = ISVC23EncoderV2(
    #     n_time_steps=SEQUENCE_LENGTH, image_size=IMAGE_SIZE, latent_dim=LATENT_DIM
    # )
    # decoder = ISVC23DecoderV2(
    #     n_time_steps=SEQUENCE_LENGTH, image_size=IMAGE_SIZE, latent_dim=LATENT_DIM
    # )

    # encoder = ISVC23EncoderV3(
    #     n_time_steps=SEQUENCE_LENGTH, image_size=IMAGE_SIZE, latent_dim=LATENT_DIM
    # )
    # decoder = ISVC23DecoderV3(
    #     n_time_steps=SEQUENCE_LENGTH, image_size=IMAGE_SIZE, latent_dim=LATENT_DIM
    # )

    # encoder = ISVC23EncoderV4(
    #     n_time_steps=SEQUENCE_LENGTH, image_size=IMAGE_SIZE, latent_dim=LATENT_DIM
    # )
    # decoder = ISVC23DecoderV4(
    #     n_time_steps=SEQUENCE_LENGTH, image_size=IMAGE_SIZE, latent_dim=LATENT_DIM
    # )

    # Test input tensor of size (batch_size, time_steps, channels, height, width)
    INPUT_SAMPLE = torch.randn(
        BATCH_SIZE, SEQUENCE_LENGTH, CHANNELS, IMAGE_SIZE, IMAGE_SIZE
    )

    # Forward pass through the encoder and decoder without autograd. `encoded`
    # and `decoded` remain notebook globals; with gradients enabled they would
    # keep the whole computation graph (every intermediate activation for a
    # batch of BATCH_SIZE samples) alive for the rest of the session.
    # NOTE(review): the modules are still in training mode here, so any
    # BatchNorm layers update their running statistics on this random batch.
    with torch.no_grad():
        encoded = encoder(INPUT_SAMPLE)
        decoded = decoder(encoded)

    print(f'Encoder: {encoder.__class__.__name__}')
    print(f'Decoder: {decoder.__class__.__name__}')

    # Check the shapes
    print(f'Input shape: {INPUT_SAMPLE.shape}')
    print(f'Latent shape: {encoded.shape}')
    print(f'Decoded shape: {decoded.shape}')

    assert INPUT_SAMPLE.shape == decoded.shape, 'Input and output shapes do not match!'

    print(summarize_model([encoder, decoder]))

    # torchinfo.summary(
    #     encoder,
    #     input_size=(BATCH_SIZE, SEQUENCE_LENGTH, 1, IMAGE_SIZE, IMAGE_SIZE),
    #     depth=4,
    # )

### Conv3D

In [None]:
# Simple test case for forward pass, also used as model signature in Mlflow.

if MODEL_NAME == 'tae3d':
    encoder = Conv3dEncoder()
    decoder = Conv3dDecoder()

    # Test input tensor of size (batch_size, channels, time_steps/depth, height, width)
    INPUT_SAMPLE = torch.randn(
        BATCH_SIZE, CHANNELS, SEQUENCE_LENGTH, IMAGE_SIZE, IMAGE_SIZE
    )

    # Forward pass through the encoder and decoder without autograd. `encoded`
    # and `decoded` remain notebook globals; with gradients enabled they would
    # keep the whole computation graph (every intermediate activation for a
    # batch of BATCH_SIZE samples) alive for the rest of the session.
    with torch.no_grad():
        encoded = encoder(INPUT_SAMPLE)
        decoded = decoder(encoded)

    # Check the shapes
    print(f'Input shape: {INPUT_SAMPLE.shape}')
    print(f'Latent shape: {encoded.shape}')
    print(f'Decoded shape: {decoded.shape}')

    assert INPUT_SAMPLE.shape == decoded.shape, 'Input and output shapes do not match!'

    print(summarize_model([encoder, decoder]))

## Loaders

In [None]:
# Per-sample index of the time axis: (time, channels, H, W) for 'tae' and
# (channels, time, H, W) for 'tae3d', i.e. the smoke-test inputs without the
# batch axis.
TIME_DIM_INDEX = 0 if MODEL_NAME == 'tae' else 1
# 'depth_realsense' memory maps are read as float32, every other source as uint8.
DTYPE = np.float32 if SOURCE_TYPE == 'depth_realsense' else np.uint8
# `os.cpu_count()` returns None when the CPU count cannot be determined; fall
# back to loading in the main process (0 workers) instead of failing on int(None).
NUM_WORKERS = os.cpu_count() or 0

batch_size_dict = BatchSizeDict(
    {'train': BATCH_SIZE, 'valid': BATCH_SIZE, 'test': BATCH_SIZE}
)

# Windowing options shared by the train/validation and the test datasets.
_dataset_kwargs = dict(
    memory_map_image_shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS),
    window_size=SEQUENCE_LENGTH,
    time_step=TIME_STEP,
    time_dim_index=TIME_DIM_INDEX,
)

if DATASET == 'dmd':
    train_val_dataset = TemporalAutoencoderDatasetDMD(
        dataset_directories=TRAIN_DATASETS,
        source_type=SOURCE_TYPE,
        **_dataset_kwargs,
    )
    test_dataset = TemporalAutoencoderDatasetDMD(
        dataset_directories=TEST_DATASETS,
        source_type=SOURCE_TYPE,
        **_dataset_kwargs,
    )
elif DATASET == 'mrl':
    train_val_dataset = TemporalAutoencoderDataset(
        memory_map_file=NORMAL_MEMORY_MAPS,
        dtype=DTYPE,
        **_dataset_kwargs,
    )
    test_dataset = TemporalAutoencoderDataset(
        memory_map_file=ANOMAL_MEMORY_MAPS,
        dtype=DTYPE,
        **_dataset_kwargs,
    )
else:
    raise ValueError(f"Unknown dataset {DATASET!r}; expected 'mrl' or 'dmd'.")

# Train/validation split of the normal windows.
# `random_split` hands each subset a random permutation of window indices, so
# both subsets are already shuffled; `shuffle=False` below merely keeps that
# fixed order across epochs. Frames inside a window keep their order.
# NOTE(review): if neighbouring windows overlap (sliding windows — confirm in
# the dataset class), validation windows share frames with training windows
# and the validation loss will be optimistic.
train_size = int(TRAIN_SET_RATIO * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size
train_dataset, valid_dataset = random_split(train_val_dataset, [train_size, val_size])

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size_dict['train'],
    shuffle=False,
    num_workers=NUM_WORKERS,
    drop_last=True,
)

valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=batch_size_dict['valid'],
    shuffle=False,
    num_workers=NUM_WORKERS,
)

# The test set is evaluated one window at a time, in dataset order, in the main
# process (batch_size_dict['test'] / NUM_WORKERS are the batched alternative).
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

In [None]:
tuple(len(split) for split in (train_dataset, valid_dataset, test_dataset))  # sample counts: train, valid, test

In [None]:
# Show an example: the first frame of the first test sample.
# Index the time axis explicitly (TIME_DIM_INDEX is 0 for 'tae', 1 for 'tae3d');
# a plain `[0]` would take the first *channel* of a channel-first 'tae3d' sample.
first_frame = test_dataloader.dataset[0]['image'].select(TIME_DIM_INDEX, 0)
# `imshow` accepts only 1, 3 or 4 channels, so show at most the first three;
# for 'rgbd'/'rgbdm' these are presumably the RGB channels — confirm the order.
plt.imshow(first_frame[:3].permute(1, 2, 0))
plt.axis('off')
plt.show()

## Training

In [None]:
# Lightning module wrapping the encoder/decoder pair. Its `time_dim_index` refers
# to batched tensors, hence one more than the per-sample TIME_DIM_INDEX used by
# the datasets (1 for 'tae', 2 for 'tae3d').
model = TemporalAutoencoderModel(
    encoder=encoder,
    decoder=decoder,
    learning_rate=LEARNING_RATE,
    loss_function=LOSS_FUNCTION,
    batch_size_dict=batch_size_dict,
    time_dim_index=TIME_DIM_INDEX + 1,
)

In [None]:
# CSV metrics go to LOG_DIR/<NOW>-<RUN_NAME>/version_<VERSION>, the same
# directory as EXPERIMENT_DIR (metrics.csv is read from there later).
csv_logger = CSVLogger(LOG_DIR, name=f'{NOW}-{RUN_NAME}', version=VERSION)

# Stop when MONITOR has not improved for PATIENCE epochs. EarlyStopping does not
# restore the best weights, so the best model is taken from the checkpoint below.
early_stopping = EarlyStopping(
    monitor=MONITOR,
    mode='min',
    patience=PATIENCE,
)

# Keep only the checkpoint with the lowest MONITOR value. The filename template
# renders as '{epoch}-{<MONITOR>:.3f}'; `.3f` gives three decimals (the former
# `:3f` meant "field width 3" and produced six decimals).
model_checkpoint = ModelCheckpoint(
    dirpath=EXPERIMENT_DIR,
    filename=f'{{epoch}}-{{{MONITOR}:.3f}}',
    monitor=MONITOR,
    save_top_k=1,  # save only the best model
    mode='min',
)
progress_bar = TQDMProgressBar(refresh_rate=10)

In [None]:
# GPU trainer. Per the Lightning docs, `deterministic=True` makes Lightning
# request deterministic PyTorch algorithms; the explicit call right after relaxes
# that to warn-only, because torch 2.5 does not implement a deterministic
# `max_pool3d_with_indices_backward_cuda`.
trainer = L.Trainer(
    accelerator='gpu',
    max_epochs=MAX_EPOCHS,
    min_epochs=MIN_EPOCHS,
    logger=csv_logger,
    log_every_n_steps=1,  # log every batch
    callbacks=[model_checkpoint, early_stopping, progress_bar],
    # https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility
    deterministic=True,
)
torch.use_deterministic_algorithms(True, warn_only=True)

In [None]:
# Trade float32 matmul precision for speed ('medium'); see
# https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
torch.set_float32_matmul_precision(precision='medium')

In [None]:
# Fit the model and report the wall-clock training duration.
training_start = datetime.datetime.now()

trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
)

training_time = datetime.datetime.now() - training_start
training_time_minutes = training_time.total_seconds() / 60
print(f'Training time: {training_time_minutes:.2f} minutes')

## Validation and Testing

In [None]:
# Model to evaluate: either a model registered in MLflow (set the flag and edit
# the name/version) or the best checkpoint of the training run above.
LOAD_FROM_MLFLOW = False

if LOAD_FROM_MLFLOW:
    # Load from MLflow
    model_name = 'pytorch-2025-02-16-201321-tae-radovan-isvc23v1-L515-test'
    model_version = 1
    model_uri = f'models:/{model_name}/{model_version}'
    model_ = mlflow.pytorch.load_model(model_uri)
else:
    # Use the path recorded by the checkpoint callback. Globbing EXPERIMENT_DIR
    # alone picks an arbitrary file when it holds several checkpoints (e.g. after
    # re-running training in the same session with a new callback instance).
    model_checkpoint_path = model_checkpoint.best_model_path
    if not model_checkpoint_path:
        # No checkpoint recorded by the callback: fall back to the directory.
        checkpoints = sorted(EXPERIMENT_DIR.glob('*.ckpt'))
        if not checkpoints:
            raise FileNotFoundError(f'No checkpoint (*.ckpt) found in {EXPERIMENT_DIR}')
        model_checkpoint_path = checkpoints[0]
    print(f'Loading model from: {model_checkpoint_path}')
    model_ = TemporalAutoencoderModel.load_from_checkpoint(
        model_checkpoint_path, encoder=encoder, decoder=decoder
    )

In [None]:
# Separate trainer used only to evaluate the loaded model.
trainer_ = L.Trainer(logger=False)  # no need to log anything for validation and testing

In [None]:
# Validate the loaded model on the validation split. `validate` returns a list
# with one metrics dict per dataloader; a single dataloader is passed here.
valid_metrics = trainer_.validate(model_, dataloaders=valid_dataloader, verbose=True)[0]

In [None]:
# Load ground truth: annotated anomalies of the test data, expanded to labels.
if DATASET == 'dmd':
    # The anomaly files are sorted so that they line up with the test videos,
    # whose frame counts `from_json` needs as well.
    video_lengths = [video.length for video in test_dataset.videos]  # type: ignore
    anomalies = Anomalies.from_json(ANOMALIES_FILES, video_lengths)
elif DATASET == 'mrl':
    anomalies = Anomalies.from_file(ANOMALIES_FILES)

y_true = anomalies.to_ground_truth()

In [None]:
# Per-sample reconstruction errors on the test set, averaged for each metric.
errors = evaluate_model_parallel(model_, test_dataloader)
test_metrics = {
    metric_name: sum(errors[metric_name]) / len(errors[metric_name])
    for metric_name in ('mse', 'mae', 'fro')
}
for metric_name, value in test_metrics.items():
    print(f'{metric_name}: {value:.4f}')

In [None]:
# Presumably selects the error metric with the best ROC AUC and returns its
# errors as anomaly scores (`y_proba`) — see `compute_best_roc_auc`.
res_dict = compute_best_roc_auc(y_true, errors)
y_proba = cast(list[float], res_dict['y_proba'])
best_metric = str(res_dict['best_metric'])
# The scores can be shorter than the labels; keep only the leading labels so
# both sequences have the same length.
y_true = y_true[:len(y_proba)]

In [None]:
# ROC chart of the anomaly scores (saved as a PDF in EXPERIMENT_DIR); returns the
# AUC and the threshold the helper selects as optimal.
roc_auc, optimal_threshold = plot_roc_chart(
    save_path=EXPERIMENT_DIR / ROC_CHART_NAME,
    y_true=y_true,
    y_pred_proba=y_proba,
    cbar_text=f'Thresholds ({best_metric.upper()})',
)

In [None]:
# Save the predictions (labels, scores, raw errors and ROC summary) as JSON.


def _to_json_builtin(obj):
    """Convert NumPy / PyTorch values that `json` cannot serialize natively.

    Used as `json.dump(default=...)`, which is only consulted for objects the
    encoder cannot handle, so plain Python values are written unchanged.

    Raises:
        TypeError: for any other unsupported type, like `json` itself.
    """
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


predictions = {
    'y_true': y_true,
    'y_proba': y_proba,
    'errors': errors,
    'roc_auc': roc_auc,
    'best_metric': best_metric,
    'optimal_threshold': optimal_threshold,
}
with open(EXPERIMENT_DIR / PREDICTIONS_JSON_NAME, 'w', encoding='utf-8') as f:
    json.dump(predictions, f, default=_to_json_builtin)

In [None]:
# Precision-recall chart of the anomaly scores; returns its AUC and the
# threshold chosen by the helper.
pr_auc, pr_threshold = plot_pr_chart(
    save_path=EXPERIMENT_DIR / PR_CHART_NAME,
    figsize=(8, 6),
    y_true=y_true,
    y_pred_proba=y_proba,
)

In [None]:
# Plot the anomaly scores against the ground-truth labels with the ROC-optimal
# threshold.
plot_error_and_anomalies(
    save_path=EXPERIMENT_DIR / ERROR_CHART_NAME,
    threshold=optimal_threshold,
    y_true=y_true,
    y_pred=y_proba,
)

In [None]:
# Frames to visualise: the first frame plus the middle frame of the first three
# and the last two annotated anomalies. Positions that do not exist (fewer than
# five anomalies) are skipped and repeated frames are shown once, instead of
# failing with an IndexError or plotting duplicates.
# NOTE(review): assumes `Anomalies` indexing raises IndexError like a list.
reconstruction_indices = [0]  # first video frame
for position in (0, 1, 2, -2, -1):
    try:
        frame_index = anomalies[position].middle()
    except IndexError:
        continue
    if frame_index not in reconstruction_indices:
        reconstruction_indices.append(frame_index)

plot_temporal_autoencoder_reconstruction(
    model_,
    test_dataloader,
    save_path=EXPERIMENT_DIR / PREDICTIONS_NAME,
    time_dim_index=TIME_DIM_INDEX,  # per-sample time axis: 0 for 'tae', 1 for 'tae3d'
    indices=reconstruction_indices,
    show_heatmap=True,
    show_metrics=True,
)

In [None]:
# Learning curves from the CSV log. Other metric entries used before:
# 'mse': 'Mean Squared Error', 'fro': 'Frobenius Norm'.
plot_learning_curves(
    EXPERIMENT_DIR / METRICS_CSV_NAME,
    figsize=(16, 4.5),
    save_path=EXPERIMENT_DIR / LEARNING_CURVES_PDF_NAME,
    metrics={'mae': 'MAE'},
)

In [None]:
# Architecture diagram built with torchview and rendered explicitly to SVG and
# PNG in EXPERIMENT_DIR (automatic saving is disabled).
architecture_visualization = torchview.draw_graph(
    model_,
    input_size=INPUT_SAMPLE.shape,
    graph_name='Temporal Autoencoder',
    graph_dir='TB',
    depth=3,
    expand_nested=True,
    roll=True,
    save_graph=False,
    directory=str(EXPERIMENT_DIR),
    filename=ARCHITECTURE_VISUALIZATION_NAME,
)
visual_graph = architecture_visualization.visual_graph
visual_graph.render(format='svg')
visual_graph.render(format='png')

## Logging

In [None]:
# model_.to_onnx(
#     EXPERIMENT_DIR / MODEL_ONNX_NAME,
#     INPUT_SAMPLE,
#     export_params=False,
#     dynamo=False,
#     opset_version=11,
# )

In [None]:
# Save the text summary of the encoder/decoder alongside the other run outputs.
# An explicit UTF-8 encoding makes the write independent of the platform's
# default (locale) encoding, in case the summary contains non-ASCII characters.
with open(EXPERIMENT_DIR / MODEL_SUMMARY_NAME, 'w', encoding='utf-8') as f:
    f.write(summarize_model([encoder, decoder]))

In [None]:
# Log the run to MLflow: git/dataset tags, evaluation metrics, hyperparameters,
# the files produced in EXPERIMENT_DIR, and the trained model.
with mlflow.start_run(
    run_name=f'{RUN_NAME}', experiment_id=get_experiment_id(EXPERIMENT_NAME)
) as run:
    # Tags are best-effort: any failure is printed and logging continues.
    try:
        mlflow.set_tag('branch', get_current_branch())
        mlflow.set_tag('commit_id', get_commit_id())
        mlflow.set_tag('dataset_name', DATASET_NAME)
    except Exception as e:
        print(e)

    # Evaluation results.
    for metric_key, metric_value in (
        ('roc_auc', roc_auc),
        ('optimal_threshold', optimal_threshold),
        ('pr_auc', pr_auc),
        ('pr_threshold', pr_threshold),
    ):
        mlflow.log_metric(metric_key, metric_value)
    log_dict_to_mlflow(valid_metrics, 'metric')
    log_dict_to_mlflow(test_metrics, 'metric')
    log_dict_to_mlflow(get_submodule_param_count(model_), 'param')

    # Hyperparameters and run configuration, logged in insertion order.
    run_params = {
        'dataset': DATASET,
        'annotations': ANOMALIES_FILES,
        'encoder': str(encoder),
        'decoder': str(decoder),
        'encoder_name': encoder.__class__.__name__,
        'decoder_name': decoder.__class__.__name__,
        'batch_size': BATCH_SIZE,
        'max_epochs': MAX_EPOCHS,
        'min_epochs': MIN_EPOCHS,
        'early_stopping': get_early_stopping_epoch(EXPERIMENT_DIR),
        'monitor': MONITOR,
        'patience': PATIENCE,
        'image_size': IMAGE_SIZE,
        'learning_rate': LEARNING_RATE,
        'loss_function': LOSS_FUNCTION,
        'seed': SEED,
        'sequence_length': SEQUENCE_LENGTH,
        'time_step': TIME_STEP,
        'train_set_ratio': TRAIN_SET_RATIO,
        'train_sequences': len(train_dataset),
        'driver': DRIVER,
        'best_metric': best_metric,
        'use_mask': USE_MASK,
        'source_type': SOURCE_TYPE,
        'latent_dim': LATENT_DIM,
        'training_time_minutes': training_time_minutes,
        'model_name': MODEL_NAME,
        'channels': CHANNELS,
    }
    if DATASET == 'mrl':
        run_params['test_session'] = TEST_SESSION
    elif DATASET == 'dmd':
        run_params['train_sessions'] = TRAIN_SESSIONS
        run_params['test_sessions'] = TEST_SESSIONS
        run_params['train_datasets'] = [str(dataset) for dataset in TRAIN_DATASETS]
        run_params['test_datasets'] = [str(dataset) for dataset in TEST_DATASETS]
    for param_key, param_value in run_params.items():
        mlflow.log_param(param_key, param_value)

    # CSV metrics, learning curves, predictions, charts, model summary, notebook
    for artifact_name in (
        METRICS_CSV_NAME,
        LEARNING_CURVES_PDF_NAME,
        PREDICTIONS_NAME,
        ROC_CHART_NAME,
        MODEL_SUMMARY_NAME,
        ERROR_CHART_NAME,
        PR_CHART_NAME,
        PREDICTIONS_JSON_NAME,
    ):
        mlflow.log_artifact(str(EXPERIMENT_DIR / artifact_name), MLFLOW_ARTIFACT_DIR)
    mlflow.log_artifact(str(Path.cwd() / NOTEBOOK_NAME), MLFLOW_ARTIFACT_DIR)

    # Network visualization rendered with `torchview`
    for extension in ('.svg', '.png'):
        mlflow.log_artifact(
            str(EXPERIMENT_DIR / (ARCHITECTURE_VISUALIZATION_NAME + extension)),
            MLFLOW_ARTIFACT_DIR,
        )

    # Models are versioned by default
    mlflow.pytorch.log_model(
        pytorch_model=model_,
        artifact_path='model',
        registered_model_name=f'pytorch-{DRIVER}',
        signature=infer_signature(
            INPUT_SAMPLE.numpy(), INPUT_SAMPLE.numpy(), dict(training=False)
        ),
    )