In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell runs and edits to local packages are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [73]:
from collections import Counter, defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from model.common import Anomalies

In [74]:
# Global figure styling for this notebook: white-grid seaborn look and a
# larger base font size for every text element.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['font.size'] = 16

In [75]:
# Driver short name -> base name of that driver's annotation file.
# The loaders below resolve it as DATASET_DIR / '<base name>.txt'.
DRIVER_MAP = {
    'geordi': '2021_08_31_geordi_enyaq',
    'poli': '2021_09_06_poli_enyaq',
    'michal': '2021_11_05_michal_enyaq',
    'dans': '2021_11_18_dans_enyaq',
    'jakub': '2021_11_18_jakubh_enyaq',
}
# Driver short name -> single-letter identifier.
# NOTE(review): not referenced by any cell visible here; presumably used to
# label drivers anonymously elsewhere — confirm or remove.
DRIVER_ID_MAP = {
    'geordi': 'A',
    'poli': 'B',
    'michal': 'C',
    'dans': 'D',
    'jakub': 'E',
}

# Directory containing the per-driver annotation files, relative to the
# notebook's working directory.
# Alternative location (disabled):
#   Path().home() / 'source/driver-dataset/2024-10-28-driver-all-frames'
DATASET_DIR = Path('annotations')

In [76]:
def get_gt(driver: str) -> list[int]:
    """Return the ground-truth sequence derived from one driver's annotations.

    Loading is delegated to ``get_anomalies`` so that the path construction
    and the file-existence check are implemented in exactly one place.

    Args:
        driver: Key of ``DRIVER_MAP`` identifying the recording.

    Returns:
        The result of ``Anomalies.to_ground_truth()`` for that recording.
        NOTE(review): annotated as ``list[int]``; presumably one value per
        frame — confirm against ``model.common.Anomalies``.

    Raises:
        KeyError: If ``driver`` is not a key of ``DRIVER_MAP``.
        AssertionError: If the annotation file does not exist.
    """
    return get_anomalies(driver).to_ground_truth()


def get_anomalies(driver: str) -> Anomalies:
    """Load the annotated anomalies of a single driver.

    Args:
        driver: Key of ``DRIVER_MAP`` identifying the recording.

    Returns:
        The ``Anomalies`` object parsed from
        ``DATASET_DIR / '<DRIVER_MAP[driver]>.txt'``.

    Raises:
        KeyError: If ``driver`` is not a key of ``DRIVER_MAP``.
        AssertionError: If the annotation file does not exist.
    """
    recording_name = DRIVER_MAP[driver]
    anomalies_file = DATASET_DIR / f'{recording_name}.txt'
    # Fail early with the offending path rather than inside the parser.
    assert anomalies_file.exists(), f'Anomalies file does not exist: {anomalies_file}'
    return Anomalies.from_file(anomalies_file)

In [77]:
# Collect per-label statistics over the annotations of every driver.

all_anomalies = {name: get_anomalies(name) for name in DRIVER_MAP}

# label -> list of clip durations (end - start), one entry per labelled clip.
durations = defaultdict(list)

for anomalies in all_anomalies.values():
    for anomaly in anomalies:
        duration = anomaly.end - anomaly.start
        # A clip carrying several labels contributes its duration to each one.
        for label in anomaly.labels:
            durations[label].append(duration)

# label -> number of clips carrying it; equals the number of durations
# recorded for that label above.
label_counts = Counter({label: len(values) for label, values in durations.items()})

# label -> (mean, standard deviation) of the clip durations.
# np.std is called with its defaults, i.e. the population std (ddof=0).
stats = {
    label: (np.mean(values), np.std(values)) for label, values in durations.items()
}

In [78]:
# Order in which the labels are laid out along the shared y-axis of the figure.
custom_order = [
    'only right',
    'only left',
    'hands off',
    'cough',
    'yawn',
    'sneezing',
    'scratch',
    'phone',
    'radio',
    'eyes closed',
    'not looking road',
    'safety belts',
]
# The hand-written order must cover exactly the labels found in the
# annotations; on mismatch, report which labels differ in each direction so
# the list can be fixed without further debugging.
assert set(label_counts) == set(custom_order), (
    'custom_order is out of sync with the annotated labels: '
    f'missing from custom_order={sorted(set(label_counts) - set(custom_order))}, '
    f'absent from annotations={sorted(set(custom_order) - set(label_counts))}'
)

In [79]:
# Raw annotation label -> name shown on the plot's y-axis.
# Labels absent from this mapping are displayed unchanged.
label_renames = {
    'only left': 'only left hand',
    'only right': 'only right hand',
}

In [80]:
# Build the plotted series, all aligned position-by-position with custom_order.
# A label without data falls back to 0 for its count, mean and std.
counts = [label_counts[name] for name in custom_order]
means = [stats.get(name, (0, 0))[0] for name in custom_order]
stds = [stats.get(name, (0, 0))[1] for name in custom_order]
labels = [label_renames.get(name, name) for name in custom_order]

In [None]:
# Two panels sharing the label axis: clip counts on the left, clip-duration
# mean +/- standard deviation on the right.
fig, axes = plt.subplots(
    1, 2, figsize=(12, 8), sharey=True, gridspec_kw={'wspace': 0.1}
)

# Left panel: number of clips per label, with the count written next to each bar.
bars = axes[0].barh(labels, counts)
axes[0].set_xlabel('Number of Clips')
# bar_label pads in points, so the gap between bar and number stays constant
# regardless of the magnitude of the counts.
axes[0].bar_label(bars, fmt='%d', padding=3)

# Right panel: mean clip duration per label, error bars show one std.
axes[1].errorbar(
    means,
    labels,
    xerr=stds,
    fmt='o',
    ecolor='black',
    capsize=5,
    capthick=2,
)
axes[1].set_xlabel('Clip Duration (frames)')

# The y tick labels are already drawn on the left panel.
axes[1].tick_params(labelleft=False)

# Create the output directory on demand so a fresh checkout does not fail
# with FileNotFoundError when saving.
output_path = Path('logs/anomaly_stats.pdf')
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, bbox_inches='tight')
plt.show()