# Driver State Anomaly Detection With STAE

[https://dagshub.com/matejfric/driver-state](https://dagshub.com/matejfric/driver-state)

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import copy
import datetime
import json
import os
import tempfile
import time
from pathlib import Path
from typing import Literal, cast

import dagshub
import matplotlib.pyplot as plt
import mlflow
import mlflow.pytorch
import onnx
import pytorch_lightning as L
import torch
import torchview
from mlflow.models.signature import infer_signature
from pytorch_lightning.callbacks import TQDMProgressBar

# PyTorch Lightning's EarlyStopping callback does not restore the best weights the way Keras does!
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# https://github.com/Lightning-AI/pytorch-lightning/discussions/10399,
# https://pytorch-lightning.readthedocs.io/en/1.5.10/extensions/generated/pytorch_lightning.callbacks.ModelCheckpoint.html
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from torch.utils.data import DataLoader, random_split

from model.ae import evaluate_model_parallel, summarize_model
from model.ae.temporal_3d import RegularizationType, STAEModel
from model.common import Anomalies, BatchSizeDict
from model.dataset import STAEDataset, TemporalAutoencoderDatasetDMD
from model.dmd import DRIVER_SESSION_MAPPING
from model.eval import compute_best_roc_auc
from model.fonts import set_cmu_serif_font
from model.git import get_commit_id, get_current_branch
from model.logging import (
    get_early_stopping_epoch,
    get_submodule_param_count,
    log_dict_to_mlflow,
)
from model.plot import (
    plot_error_and_anomalies,
    plot_learning_curves,
    plot_pr_chart,
    plot_roc_chart,
    plot_temporal_autoencoder_reconstruction,
)

## Configuration

In [None]:
# Plot appearance: style sheet first, then the font-size override and the
# CMU Serif font helper.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['font.size'] = 18
set_cmu_serif_font()

# Experiment tracking: connect to the DagsHub repository with MLflow enabled.
USER_NAME = 'matejfric'
REPO_NAME = 'driver-stae'
dagshub.init(REPO_NAME, USER_NAME, mlflow=True)  # type: ignore

# Reproducibility: seed everything, including dataloader workers.
# https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility
SEED = 42
L.seed_everything(SEED, workers=True)

# Report library versions and CUDA availability for the record.
versions = (
    f'torch: {torch.__version__}, cuda: {torch.cuda.is_available()}, lightning: {L.__version__}'  # type: ignore
)
print(versions)

In [None]:
# Run-level overrides. A value left as `None` falls back to the default chosen
# further down in the notebook (`DRIVER`, `SOURCE_TYPE`, `DATASET`).
# NOTE(review): presumably filled in by an external notebook runner — confirm.
driver = None
source_type = None
dataset = None

In [None]:
# HYPERPARAMETERS
# ----------------------------------------
# Training schedule and early stopping
MAX_EPOCHS = 100
MIN_EPOCHS = 10
MONITOR = 'valid_total_loss'
MIN_DELTA = 5e-5  # Smallest change of the monitored loss that counts as an improvement
PATIENCE = 7

# Input geometry
# A different image size needs freshly generated memory maps (`run_memory_map_conversion.py`)
IMAGE_SIZE: Literal[64, 128, 224, 256] = 64
BATCH_SIZE = 64
SEQUENCE_LENGTH = 16
TIME_STEP = 1
TIME_DIM_INDEX = 1
TRAIN_SET_RATIO = 0.9
USE_MASK = True

# Optimisation
LEARNING_RATE = 1e-4
# Only the DMD + masks combination uses the smaller regularization weight.
# The decision is taken from the raw run overrides, not the resolved constants.
_dmd_masks_run = dataset == 'dmd' and source_type == 'masks'
LAMBDA_REG = 1e-6 if _dmd_masks_run else 1e-4

# Architecture
USE_2D_BOTTLENECK: list[int] | None = [64, 128]
USE_EXTRA_3DCONV = True
REGULARIZATION: RegularizationType | None = 'l2_model_weights'
USE_PREDICTION_BRANCH = True

# Input modality and its channel count; any source type not listed is single-channel.
SOURCE_TYPE = source_type or 'depth'
CHANNELS = {
    'images': 3,
    'rgb': 3,
    'rgbd': 4,
    'rgb_source_depth': 4,
    'rgbdm': 5,
}.get(SOURCE_TYPE, 1)

# LOGGING
# ----------------------------------------
# Short driver alias -> recording directory used by the 'mrl' dataset branch.
DRIVER_MAP = dict(
    geordi='2021_08_31_geordi_enyaq',
    poli='2021_09_06_poli_enyaq',
    michal='2021_11_05_michal_enyaq',
    dans='2021_11_18_dans_enyaq',
    jakub='2021_11_18_jakubh_enyaq',
)
NOW = datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')
DRIVER = driver or 'geordi'
# Non-string driver identifiers get a 'driver_' prefix.
if isinstance(DRIVER, str):
    EXPERIMENT_NAME = DRIVER
else:
    EXPERIMENT_NAME = f'driver_{DRIVER}'
LOG_DIR = Path('logs')
VERSION = 0
RUN_NAME = f'STAE-{DRIVER}-{SOURCE_TYPE}-{IMAGE_SIZE}x{IMAGE_SIZE}'
# Directory holding this run's checkpoint, metrics, and figures.
EXPERIMENT_DIR = LOG_DIR.joinpath(f'{NOW}-{RUN_NAME}', f'version_{VERSION}')

# ARTIFACTS
# ----------------------------------------
# Target directory inside the MLflow run
MLFLOW_ARTIFACT_DIR = 'outputs'

# Training records
METRICS_CSV_NAME = 'metrics.csv'
TRAIN_TRANSFORMS_JSON_NAME = 'train_transforms.json'
NOTEBOOK_NAME = 'train.ipynb'

# Model description and export
ARCHITECTURE_VISUALIZATION_NAME = 'architecture'
MODEL_SUMMARY_NAME = 'model_summary.txt'
MODEL_ONNX_NAME = 'model.onnx'

# Figures
LEARNING_CURVES_PDF_NAME = 'learning_curves.pdf'
PREDICTIONS_NAME = 'predictions.pdf'
ROC_CHART_NAME = 'roc_chart.pdf'
PR_CHART_NAME = 'pr_chart.pdf'
ERROR_CHART_NAME = 'error_chart.pdf'

# Raw evaluation output
PREDICTIONS_JSON_NAME = 'predictions.json'

# DATASET
# ----------------------------------------
DATASET: Literal['mrl', 'dmd'] = dataset or 'mrl'
DATASET_NAME = 'dmd' if DATASET != 'mrl' else '2024-10-28-driver-all-frames'
DATASET_DIR = Path.home() / f'source/driver-dataset/{DATASET_NAME}'
assert DATASET_DIR.exists(), f'Dataset directory does not exist: {DATASET_DIR}'

if DATASET == 'mrl':
    # One recording per driver with separate 'normal' and 'anomal' parts.
    driver_dir = DRIVER_MAP[DRIVER]
    root_dir = DATASET_DIR / driver_dir
    mask_suffix = '' if USE_MASK else '_no_mask'
    memory_map_filename = f'{SOURCE_TYPE}_{IMAGE_SIZE}{mask_suffix}.dat'

    normal_dir = root_dir / 'normal'
    anomal_dir = root_dir / 'anomal'
    NORMAL_MEMORY_MAP = normal_dir / 'memory_maps' / memory_map_filename
    ANOMAL_MEMORY_MAP = anomal_dir / 'memory_maps' / memory_map_filename
    ANOMALIES_FILE = anomal_dir / 'labels.txt'

    assert NORMAL_MEMORY_MAP.exists(), (
        f'Normal memory map does not exist: {NORMAL_MEMORY_MAP}'
    )
    assert ANOMAL_MEMORY_MAP.exists(), (
        f'Anomal memory map does not exist: {ANOMAL_MEMORY_MAP}'
    )
    assert ANOMALIES_FILE.exists(), f'Anomalies file does not exist: {ANOMALIES_FILE}'

elif DATASET == 'dmd':
    # Sessions whose name contains 's1' are held out for testing; the rest are
    # used for training.
    DMD_TEST_SESSION = 's1'
    driver_sessions = sorted(DRIVER_SESSION_MAPPING[DRIVER])
    TEST_SESSIONS = [name for name in driver_sessions if DMD_TEST_SESSION in name]
    TRAIN_SESSIONS = [
        name for name in driver_sessions if DMD_TEST_SESSION not in name
    ]
    assert set(TRAIN_SESSIONS).isdisjoint(TEST_SESSIONS), (
        'Training session(s) cannot be in the test set!'
    )

    TRAIN_DATASETS = sorted(
        DATASET_DIR / session_name / 'normal' for session_name in TRAIN_SESSIONS
    )
    assert all(train_dir.exists() for train_dir in TRAIN_DATASETS), (
        'Training datasets do not exist!'
    )

    TEST_DATASETS = sorted(
        DATASET_DIR / session_name / 'memory_maps' for session_name in TEST_SESSIONS
    )
    assert all(test_dir.exists() for test_dir in TEST_DATASETS), (
        'Test datasets do not exist!'
    )

    # One '<session>.json' label file per test session, stored in the session directory.
    ANOMALIES_FILES = [
        memory_map_dir.parent / (memory_map_dir.parent.name + '.json')
        for memory_map_dir in TEST_DATASETS
        if memory_map_dir.is_dir()
    ]
    assert all(anom_file.exists() for anom_file in ANOMALIES_FILES), (
        f'Anomalies file does not exist: {ANOMALIES_FILES}'
    )

## Loaders

In [None]:
# Per-split batch sizes passed to the model. The test loader below iterates
# one sequence at a time regardless of the 'test' entry.
batch_size_dict = BatchSizeDict(
    {'train': BATCH_SIZE, 'valid': BATCH_SIZE, 'test': BATCH_SIZE}
)

# Build the normal-only train/validation dataset and the test dataset for the
# selected data source.
if DATASET == 'dmd':
    train_val_dataset = TemporalAutoencoderDatasetDMD(
        dataset_directories=TRAIN_DATASETS,
        memory_map_image_shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS),
        window_size=SEQUENCE_LENGTH,
        time_step=TIME_STEP,
        time_dim_index=TIME_DIM_INDEX,
        source_type=SOURCE_TYPE,
        model_type='stae',
    )
    test_dataset = TemporalAutoencoderDatasetDMD(
        dataset_directories=TEST_DATASETS,
        memory_map_image_shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS),
        window_size=SEQUENCE_LENGTH,
        time_step=TIME_STEP,
        time_dim_index=TIME_DIM_INDEX,
        source_type=SOURCE_TYPE,
        model_type='stae',
    )
elif DATASET == 'mrl':
    train_val_dataset = STAEDataset(
        memory_map_file=NORMAL_MEMORY_MAP,
        memory_map_image_shape=(IMAGE_SIZE, IMAGE_SIZE),
        window_size=SEQUENCE_LENGTH,
        time_step=TIME_STEP,
    )
    test_dataset = STAEDataset(
        memory_map_file=ANOMAL_MEMORY_MAP,
        memory_map_image_shape=(IMAGE_SIZE, IMAGE_SIZE),
        window_size=SEQUENCE_LENGTH,
        time_step=TIME_STEP,
    )

# Train validation split
train_size = int(TRAIN_SET_RATIO * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size
train_dataset, valid_dataset = random_split(train_val_dataset, [train_size, val_size])

# `os.cpu_count()` returns `None` when the number of CPUs cannot be determined;
# fall back to a single worker instead of failing on `int(None)`.
num_workers = os.cpu_count() or 1

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size_dict['train'],
    shuffle=False,  # YOU DON'T WANT TO SHUFFLE TEMPORAL DATA!
    num_workers=num_workers,
    drop_last=True,
)

valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=batch_size_dict['valid'],
    shuffle=False,
    num_workers=num_workers,
    drop_last=True,
)

# The test set is read in order, one sequence per step, in the main process.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

In [None]:
# Dataset sizes: train / validation / test
print(len(train_dataset), len(valid_dataset), len(test_dataset))

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes = axes.flatten()  # type: ignore

# Show the first frame of both tensors stored in the first test sample.
# NOTE(review): assumes each tensor is (channels, time, height, width) — confirm.
colormap = 'gray' if CHANNELS == 1 else None
panels = (('image', 'Mask'), ('mask', 'Future Mask'))
for axis, (sample_key, panel_title) in zip(axes, panels):
    channels_last = test_dataloader.dataset[0][sample_key].permute(1, 2, 3, 0)
    axis.imshow(channels_last[0], cmap=colormap)
    axis.axis('off')
    axis.set_title(panel_title)

plt.tight_layout()
plt.show()

## Training

In [None]:
# Gather the constructor arguments first so the whole model configuration is
# visible in one place, then build the model from it.
model_config = dict(
    batch_size_dict=batch_size_dict,
    learning_rate=LEARNING_RATE,
    lambda_reg=LAMBDA_REG,
    regularization=REGULARIZATION,
    use_2d_bottleneck=USE_2D_BOTTLENECK,
    use_extra_3dconv=USE_EXTRA_3DCONV,
    use_prediction_branch=USE_PREDICTION_BRANCH,
)
model = STAEModel(**model_config)

### Check forward pass

In [None]:
# Random test input of size (batch_size, channels, time_steps/depth, height, width)
INPUT_SAMPLE = torch.randn(
    BATCH_SIZE, CHANNELS, SEQUENCE_LENGTH, IMAGE_SIZE, IMAGE_SIZE
)

# Single forward pass without gradient tracking.
with torch.no_grad():
    output = model(INPUT_SAMPLE)
# The first output is the reconstruction; the second (prediction) is only read
# when the prediction branch is enabled.
reconstruction = output[0]
prediction = output[1] if USE_PREDICTION_BRANCH else reconstruction

# Report the shapes of the tensors that are actually compared. `output` is the
# container indexed above, so reading `output.shape` in the message would fail
# with an AttributeError and hide the real mismatch.
assert INPUT_SAMPLE.shape == reconstruction.shape == prediction.shape, (
    f'Input and output shapes do not match! Expected: {INPUT_SAMPLE.shape}, '
    f'got: {reconstruction.shape} and {prediction.shape}'
)

print(summarize_model(model))

In [None]:
# Early stopping and checkpointing watch the same quantity; lower is better.
monitored = dict(monitor=MONITOR, mode='min')

# Metrics are written as CSV under LOG_DIR / '<NOW>-<RUN_NAME>' / 'version_<VERSION>'.
csv_logger = CSVLogger(LOG_DIR, name=f'{NOW}-{RUN_NAME}', version=VERSION)

early_stopping = EarlyStopping(patience=PATIENCE, min_delta=MIN_DELTA, **monitored)

# Keep only the single best checkpoint, stored in the experiment directory.
# NOTE(review): ':3f' is a minimum field width of 3, not 3 decimal places
# ('.3f') — confirm which one is intended.
model_checkpoint = ModelCheckpoint(
    dirpath=EXPERIMENT_DIR,
    filename='{epoch}-{valid_total_loss:3f}',
    save_top_k=1,
    **monitored,
)

progress_bar = TQDMProgressBar(refresh_rate=10)

In [None]:
training_callbacks = [model_checkpoint, early_stopping, progress_bar]

trainer = L.Trainer(
    accelerator='gpu',
    logger=csv_logger,
    callbacks=training_callbacks,
    min_epochs=MIN_EPOCHS,
    max_epochs=MAX_EPOCHS,
    log_every_n_steps=1,  # one log entry per batch
    # https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility
    deterministic=True,
)

# Applied after creating the trainer: only warn for operations without a
# deterministic implementation, because torch 2.5 does not have one for
# `max_pool3d_with_indices_backward_cuda`.
torch.use_deterministic_algorithms(True, warn_only=True)

In [None]:
# https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
# 'medium' trades float32 matrix-multiplication precision for speed; see the
# linked documentation for what each setting allows.
torch.set_float32_matmul_precision('medium')

In [None]:
# Train the model and measure the wall-clock duration of the fit.
# `elapsed_time` (seconds) is logged to MLflow at the end of the notebook.
start_time = time.time()
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
)
elapsed_time = time.time() - start_time
print(f'Training took {elapsed_time / 60:.2f} minutes.')

## Validation

In [None]:
# Set to True to evaluate a model registered in MLflow instead of the
# checkpoint written by the training run above.
LOAD_FROM_MLFLOW = False

if LOAD_FROM_MLFLOW:
    # Load from MLflow
    model_name = f'pytorch-{DRIVER}'
    model_version = 7
    model_uri = f'models:/{model_name}/{model_version}'
    model_ = mlflow.pytorch.load_model(model_uri)
else:
    checkpoint_candidates = list(EXPERIMENT_DIR.glob('*.ckpt'))
    if not checkpoint_candidates:
        # Fail with a clear message instead of a bare IndexError.
        raise FileNotFoundError(
            f'No checkpoint (*.ckpt) found in: {EXPERIMENT_DIR}'
        )
    # Normally there is exactly one file (`save_top_k=1`). If several exist,
    # take the most recently written one rather than an arbitrary glob result.
    model_checkpoint_path = max(
        checkpoint_candidates, key=lambda path: path.stat().st_mtime
    )
    print(f'Loading model from: {model_checkpoint_path}')
    model_ = STAEModel.load_from_checkpoint(model_checkpoint_path)
# Switch to evaluation mode for validation and testing.
model_.eval()

In [None]:
# Separate trainer used only for evaluating the loaded model.
trainer_ = L.Trainer(logger=False)  # no need to log anything for validation and testing

In [None]:
# Run validation on the loaded model and keep the first entry of the returned
# results (a single dataloader is passed); it is logged to MLflow later.
valid_metrics = trainer_.validate(model_, dataloaders=valid_dataloader, verbose=True)[0]

In [None]:
# Load the anomaly annotations for the test data.
if DATASET == 'dmd':
    # Anomaly files are sorted so that they correspond to the video lengths.
    video_lengths = [video.length for video in test_dataset.videos]  # type: ignore
    anomalies = Anomalies.from_json(ANOMALIES_FILES, video_lengths)
elif DATASET == 'mrl':
    anomalies = Anomalies.from_file(ANOMALIES_FILE)

# Ground-truth labels derived from the annotations.
y_true = anomalies.to_ground_truth()

In [None]:
# Error values on the test set, keyed by metric name.
errors = evaluate_model_parallel(model_, test_dataloader, model_type='stae')

# Average each error metric over all of its values.
test_metrics = {
    name: sum(errors[name]) / len(errors[name]) for name in ('mse', 'mae', 'fro')
}
for metric_name, value in test_metrics.items():
    print(f'{metric_name}: {value:.4f}')

In [None]:
# `compute_best_roc_auc` reports the name of the selected error metric and the
# matching scores. NOTE(review): presumably the metric with the highest ROC AUC — confirm.
res_dict = compute_best_roc_auc(y_true, errors)
best_metric = str(res_dict['best_metric'])
y_proba: list[float] = cast(list[float], res_dict['y_proba'])
# Trim the labels to the number of scores so both sequences have equal length.
y_true = y_true[: len(y_proba)]

In [None]:
# Draw the ROC chart and save it to the experiment directory. The returned AUC
# and threshold are reused for the error chart and logged to MLflow.
roc_auc, optimal_threshold = plot_roc_chart(
    y_true,
    y_proba,
    save_path=EXPERIMENT_DIR / ROC_CHART_NAME,
    cbar_text=f'Thresholds ({best_metric.upper()})',
)

In [None]:
# Draw the precision-recall chart and save it to the experiment directory.
# The returned AUC and threshold are logged to MLflow later.
pr_auc, pr_threshold = plot_pr_chart(
    y_true,
    y_proba,
    save_path=EXPERIMENT_DIR / PR_CHART_NAME,
    cbar_text=f'Thresholds ({best_metric.upper()})',
)

In [None]:
# Plot the scores against the ground-truth anomalies, using the threshold
# returned by the ROC chart, and save the figure to the experiment directory.
plot_error_and_anomalies(
    y_true=y_true,
    y_pred=y_proba,
    threshold=optimal_threshold,
    save_path=EXPERIMENT_DIR / ERROR_CHART_NAME,
)

In [None]:
# Samples to visualise: the first video frame, followed by the middle of the
# first three and the last two annotated anomalies.
anomaly_positions = (0, 1, 2, -2, -1)
preview_indices = [0] + [
    anomalies[position].middle() for position in anomaly_positions
]

plot_temporal_autoencoder_reconstruction(
    model_,
    test_dataloader,
    save_path=EXPERIMENT_DIR / PREDICTIONS_NAME,
    time_dim_index=1,
    indices=preview_indices,
    show_heatmap=True,
    show_metrics=True,
    model_type='stae',
)

In [None]:
# Curves to draw, in display order.
# NOTE(review): presumably metric key -> panel title — confirm against `plot_learning_curves`.
curve_titles = {
    'fro': 'Frobenius Norm',
    'mae': 'Mean Absolute Error',
    'reconstruction_loss': 'Reconstruction Loss',  # MSE
    'prediction_loss': 'Prediction Loss',
    'regularization_loss': 'Regularization Loss',
}

# Read the CSV metrics of this run and save the learning curves as a PDF.
plot_learning_curves(
    EXPERIMENT_DIR / METRICS_CSV_NAME,
    save_path=EXPERIMENT_DIR / LEARNING_CURVES_PDF_NAME,
    metrics=curve_titles,
    figsize=(38, 6),
    loss_name='total_loss',
)

In [None]:
# Save the predictions
predictions = {
    'y_true': y_true,
    'y_proba': y_proba,
    'errors': errors,
    'roc_auc': roc_auc,
    'best_metric': best_metric,
    'optimal_threshold': optimal_threshold,
}
with open(EXPERIMENT_DIR / PREDICTIONS_JSON_NAME, 'w') as f:
    json.dump(predictions, f)

In [None]:
# Build the architecture graph for the loaded model. `save_graph=False` because
# the graph is rendered explicitly below, once per output format.
architecture_visualization = torchview.draw_graph(
    model_,
    input_size=INPUT_SAMPLE.shape,
    graph_name='Temporal Autoencoder',
    graph_dir='TB',
    depth=3,
    roll=True,
    expand_nested=True,
    save_graph=False,
    filename=ARCHITECTURE_VISUALIZATION_NAME,
    directory=str(EXPERIMENT_DIR),
)
for image_format in ('svg', 'png'):
    architecture_visualization.visual_graph.render(format=image_format)

## Logging

In [None]:
def remove_prediction_branch(model: STAEModel) -> STAEModel:
    """Return an evaluation-mode copy of ``model`` built without the prediction branch.

    A new ``STAEModel`` is created with ``use_prediction_branch=False`` and the
    encoder and decoder of ``model`` are deep-copied into it, so ``model`` itself
    is not modified. The copy is moved to the device of ``model``.

    NOTE(review): the bottleneck and extra-3D-conv settings come from the
    notebook constants ``USE_2D_BOTTLENECK`` / ``USE_EXTRA_3DCONV``, not from
    ``model`` — confirm they match when ``model`` stems from another run.
    """
    stripped = STAEModel(
        use_prediction_branch=False,
        batch_size_dict=model.batch_size_dict,
        learning_rate=model.lr,
        lambda_reg=model.lambda_reg,
        regularization=model.regularization,  # type: ignore
        use_2d_bottleneck=USE_2D_BOTTLENECK,
        use_extra_3dconv=USE_EXTRA_3DCONV,
    )

    # Carry over the trained weights of both halves of the autoencoder.
    for part in ('encoder', 'decoder'):
        setattr(stripped, part, copy.deepcopy(getattr(model, part)))

    stripped.eval()
    stripped.to(model.device)

    return stripped


onnx_path = EXPERIMENT_DIR / MODEL_ONNX_NAME
# Export with a batch of one: the first sample with its batch dimension restored.
single_sample = INPUT_SAMPLE[0].unsqueeze(0)

# Export the model without the prediction branch.
mod_model = remove_prediction_branch(model_)
mod_model.to_onnx(
    onnx_path,
    single_sample,
    input_names=['input'],
    output_names=['output'],
    opset_version=20,
    dynamo=True,
    export_params=True,
    do_constant_folding=True,
    external_data=False,  # This prevents creating external .data file
)

# Reload the exported file and validate it.
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

In [None]:
# Write the summary followed by a newline, then the model's string representation.
with open(EXPERIMENT_DIR / MODEL_SUMMARY_NAME, 'w') as f:
    print(summarize_model(model_), file=f)
    f.write(str(model_))

In [None]:
with mlflow.start_run(run_name=RUN_NAME) as run:
    # Tags are best-effort: a failure is printed and the rest of the run is
    # still logged. Values are resolved lazily, one tag at a time.
    tag_sources = (
        ('Branch', get_current_branch),
        ('Commit ID', get_commit_id),
        ('Dataset', lambda: DATASET_NAME),
    )
    try:
        for tag_name, resolve_tag in tag_sources:
            mlflow.set_tag(tag_name, resolve_tag())
    except Exception as e:
        print(e)

    # Scalar evaluation results
    evaluation_scalars = (
        ('roc_auc', roc_auc),
        ('pr_auc', pr_auc),
        ('pr_threshold', pr_threshold),
        ('optimal_threshold', optimal_threshold),
    )
    for scalar_name, scalar_value in evaluation_scalars:
        mlflow.log_metric(scalar_name, scalar_value)
    log_dict_to_mlflow(valid_metrics, 'metric')
    log_dict_to_mlflow(test_metrics, 'metric')
    log_dict_to_mlflow(get_submodule_param_count(model_), 'param')

    # Run configuration, logged one parameter at a time in this order
    run_params = {
        'batch_size': BATCH_SIZE,
        'max_epochs': MAX_EPOCHS,
        'min_epochs': MIN_EPOCHS,
        'min_delta': MIN_DELTA,
        'early_stopping': get_early_stopping_epoch(EXPERIMENT_DIR),
        'monitor': MONITOR,
        'patience': PATIENCE,
        'image_size': IMAGE_SIZE,
        'learning_rate': LEARNING_RATE,
        'lambda_regularization': LAMBDA_REG,
        'seed': SEED,
        'sequence_length': SEQUENCE_LENGTH,
        'time_step': TIME_STEP,
        'train_set_ratio': TRAIN_SET_RATIO,
        'train_sequences': len(train_dataset),
        'driver': DRIVER,
        'best_metric': best_metric,
        'use_mask': USE_MASK,
        'training_time_s': elapsed_time,
        'training_time_m': round(elapsed_time / 60),
        'use_2d_bottleneck': USE_2D_BOTTLENECK,
        'use_extra_3dconv': USE_EXTRA_3DCONV,
        'regularization': REGULARIZATION,
        'use_prediction_branch': USE_PREDICTION_BRANCH,
        'source_type': SOURCE_TYPE,
    }
    if DATASET == 'dmd':
        run_params['train_sessions'] = TRAIN_SESSIONS
        run_params['test_sessions'] = TEST_SESSIONS
    for param_name, param_value in run_params.items():
        mlflow.log_param(param_name, param_value)

    # CSV metrics, figures, model summary, notebook, raw predictions, and the
    # `torchview` network visualization
    artifact_paths = [
        EXPERIMENT_DIR / METRICS_CSV_NAME,
        EXPERIMENT_DIR / LEARNING_CURVES_PDF_NAME,
        EXPERIMENT_DIR / PREDICTIONS_NAME,
        EXPERIMENT_DIR / ROC_CHART_NAME,
        EXPERIMENT_DIR / MODEL_SUMMARY_NAME,
        Path().cwd() / NOTEBOOK_NAME,
        EXPERIMENT_DIR / ERROR_CHART_NAME,
        EXPERIMENT_DIR / PR_CHART_NAME,
        EXPERIMENT_DIR / PREDICTIONS_JSON_NAME,
        EXPERIMENT_DIR / (ARCHITECTURE_VISUALIZATION_NAME + '.svg'),
        EXPERIMENT_DIR / (ARCHITECTURE_VISUALIZATION_NAME + '.png'),
    ]
    for artifact_path in artifact_paths:
        mlflow.log_artifact(str(artifact_path), MLFLOW_ARTIFACT_DIR)

    # Thresholds and the error range of the selected metric, stored next to
    # the model for later use during inference.
    with tempfile.TemporaryDirectory() as temp_dir:
        file_path = Path(temp_dir) / 'inference.json'
        with open(file_path, 'w') as f:
            inference_params = {
                'roc_threshold': optimal_threshold,
                'pr_threshold': pr_threshold,
                'max_error': max(errors[best_metric]),
                'min_error': min(errors[best_metric]),
                'best_metric': best_metric,
            }
            json.dump(inference_params, f)
        mlflow.log_artifact(str(file_path), artifact_path='model')

    # Models are versioned by default
    model_signature = infer_signature(
        INPUT_SAMPLE.numpy(), INPUT_SAMPLE.numpy(), dict(training=False)
    )
    mlflow.pytorch.log_model(
        pytorch_model=model_,
        artifact_path='model',
        registered_model_name=f'pytorch-{DRIVER}',
        signature=model_signature,
    )
    mlflow.log_artifact(str(EXPERIMENT_DIR / MODEL_ONNX_NAME), 'model')