In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. the local `model` package) are reloaded before each cell executes
# and source edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [66]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from model.common import Anomalies
from model.plot import plot_roc_charts


In [67]:
# Global matplotlib look for every figure in this notebook. The style sheet
# is applied first; the font-size override comes second so it is set last.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['font.size'] = 24

In [68]:
# Driver key -> name of that driver's recording directory under DATASET_DIR.
DRIVER_MAP = {
    'geordi': '2021_08_31_geordi_enyaq',
    'poli': '2021_09_06_poli_enyaq',
    'michal': '2021_11_05_michal_enyaq',
    'dans': '2021_11_18_dans_enyaq',
    'jakub': '2021_11_18_jakubh_enyaq',
}

# Driver key -> short letter label used in the plot titles ("Driver A", ...).
DRIVER_ID_MAP = {
    'geordi': 'A',
    'poli': 'B',
    'michal': 'C',
    'dans': 'D',
    'jakub': 'E',
}

# NOTE: DRIVER_PREDS_MAP is intentionally not defined here; it has a single
# definition in its own cell directly below, which avoids keeping two copies
# of the same paths in sync.

# Root of the driver dataset, located under the current user's home directory.
# `Path.home()` is a classmethod, so no throwaway `Path()` instance is needed.
DATASET_DIR = Path.home() / 'source/driver-dataset/2024-10-28-driver-all-frames'

In [69]:
# Driver key -> path of the JSON file holding that driver's CLIP predictions.
# Paths are relative, so they resolve against the notebook's working directory.
DRIVER_PREDS_MAP = dict(
    geordi='logs/2024-12-24-114153-CLIP-geordi/version_0/preds.json',
    poli='logs/2024-12-24-114711-CLIP-poli/version_0/preds.json',
    michal='logs/2024-12-24-114918-CLIP-michal/version_0/preds.json',
    dans='logs/2024-12-24-115042-CLIP-dans/version_0/preds.json',
    jakub='logs/2024-12-24-115234-CLIP-jakub/version_0/preds.json',
)

In [70]:
def get_gt(driver: str, length: int) -> list[int]:
    """Load the ground-truth anomaly labels for one driver.

    Reads ``<DATASET_DIR>/<DRIVER_MAP[driver]>/anomal/labels.txt`` through
    ``Anomalies.from_file`` and converts it with ``to_ground_truth(length)``.

    Args:
        driver: Key into ``DRIVER_MAP`` selecting the recording directory.
        length: Forwarded unchanged to ``Anomalies.to_ground_truth``; the
            notebook passes the number of predictions for the same driver.

    Returns:
        Whatever ``Anomalies.to_ground_truth(length)`` returns (annotated as
        a list of ints; presumably one label per frame — TODO confirm).

    Raises:
        KeyError: If ``driver`` is not a key of ``DRIVER_MAP``.
        AssertionError: If the labels file does not exist.
    """
    labels_path = DATASET_DIR.joinpath(DRIVER_MAP[driver], 'anomal', 'labels.txt')
    assert labels_path.exists(), f'Anomalies file does not exist: {labels_path}'
    return Anomalies.from_file(labels_path).to_ground_truth(length)

In [71]:
def get_pred(driver: str) -> np.ndarray:
    """Return the anomaly scores predicted for one driver.

    The JSON file registered for ``driver`` in ``DRIVER_PREDS_MAP`` is parsed
    and converted to a NumPy array; column index 1 of that array is returned
    (the column that holds the anomaly score).

    Args:
        driver: Key into ``DRIVER_PREDS_MAP``.

    Returns:
        Column 1 of the loaded predictions, one value per row.

    Raises:
        KeyError: If ``driver`` is not a key of ``DRIVER_PREDS_MAP``.
    """
    raw_preds = json.loads(Path(DRIVER_PREDS_MAP[driver]).read_text())
    pred_matrix = np.array(raw_preds)
    # Keep only the second column: the anomaly score.
    return pred_matrix[:, 1]

In [72]:
# Anomaly scores per driver, keyed in DRIVER_MAP order.
y_preds = dict(zip(DRIVER_MAP, map(get_pred, DRIVER_MAP)))
# Ground truth per driver, requested at the same length as that driver's scores.
y_trues = {name: get_gt(name, len(scores)) for name, scores in y_preds.items()}

In [None]:
# Collect the per-driver series in DRIVER_MAP order so that scores, labels
# and titles line up index by index.
y_pred_list = [y_preds[name] for name in DRIVER_MAP]
y_true_list = [y_trues[name] for name in DRIVER_MAP]
titles = [f'Driver {DRIVER_ID_MAP[name]}' for name in DRIVER_MAP]

# Draw one ROC chart per driver and write the figure to a PDF under logs/.
roc_options = {
    'titles': titles,
    'cmap': 'rainbow',
    'cbar_text': 'Thresholds',
    'save_path': 'logs/clip-roc-charts.pdf',
}
metrics = plot_roc_charts(y_true_list, y_pred_list, **roc_options)