In [6]:
import datetime
import json
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from pprint import pprint

import clip
import dagshub
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from model.common import Anomalies, crop_driver_image_contains
from model.git import get_commit_id, get_current_branch
from model.plot import plot_roc_chart

In [7]:
# Notebook parameter: short name of the driver whose recording is evaluated.
# It is used as a key into DRIVER_MAP in the configuration cell.
# NOTE(review): presumably overridden by a notebook runner when sweeping over
# drivers - confirm.
driver = 'geordi'

In [None]:
# Plot styling
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({'font.size': 14})

# Experiment logging
REPO_NAME = 'driver-state'
USER_NAME = 'matejfric'
dagshub.init(REPO_NAME, USER_NAME, mlflow=True)  # type: ignore

# Short driver name -> dataset sub-directory of that driver's recording.
DRIVER_MAP = {
    'geordi': '2021_08_31_geordi_enyaq',
    'poli': '2021_09_06_poli_enyaq',
    'michal': '2021_11_05_michal_enyaq',
    'dans': '2021_11_18_dans_enyaq',
    'jakub': '2021_11_18_jakubh_enyaq',
}
DRIVER = driver
# Fail with an actionable message instead of a bare KeyError on a typo.
if DRIVER not in DRIVER_MAP:
    raise KeyError(
        f'Unknown driver {DRIVER!r}; expected one of {sorted(DRIVER_MAP)}'
    )
DATASET_NAME = f'2024-10-28-driver-all-frames/{DRIVER_MAP[DRIVER]}'
# `home` is a classmethod, so no throwaway Path instance is needed.
DATASET_DIR = Path.home() / f'source/driver-dataset/{DATASET_NAME}'

ANOMALIES_FILE = DATASET_DIR / 'anomal' / 'labels.txt'
assert ANOMALIES_FILE.exists(), f'Anomalies file does not exist: {ANOMALIES_FILE}'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'Using device: {device}')

# Text prompts scored against every frame. The dict order matters: the
# position of a prompt is the class index of the corresponding prediction.
PREFIX = 'a photo of a person inside a car'
PROMPTS = {
    'normal': f'{PREFIX} with both hands on the steering wheel',
    'anomal': f'{PREFIX} coughing, scratching, or holding a phone',
}
pprint(PROMPTS)

# Number of frames sent through the model in one forward pass.
BATCH_SIZE = 128

# Sorted so the frame order is deterministic across runs.
ANOMAL_IMAGES_PATHS = sorted((DATASET_DIR / 'anomal/images/').glob('*.jpg'))
# Without this check an empty glob only surfaces later as an opaque
# IndexError when the (empty) probability array is sliced for the ROC chart.
assert ANOMAL_IMAGES_PATHS, f'No *.jpg images found in {DATASET_DIR / "anomal/images/"}'

# LOGGING
# ----------
NOTEBOOK_NAME = 'clip.ipynb'
PREDS_JSON_NAME = 'preds.json'
PROMPTS_JSON_NAME = 'prompts.json'
MLFLOW_ARTIFACT_DIR = 'outputs'
MODEL_NAME = 'CLIP'
LOG_DIR = Path('logs')
# The timestamp makes the experiment directory unique per execution.
EXPERIMENT_NAME = (
    f'{datetime.datetime.now().strftime("%Y-%m-%d-%H%M%S")}-{MODEL_NAME}-{DRIVER}'
)
VERSION = 0
EXPERIMENT_DIR = LOG_DIR / EXPERIMENT_NAME / f'version_{VERSION}'
ROC_CHART_NAME = 'roc_chart.svg'
EXPERIMENT_DIR.mkdir(parents=True, exist_ok=True)

In [9]:
# Load the CLIP ViT-B/32 checkpoint together with its image transform.
model, preprocess = clip.load('ViT-B/32', device=device)

# Tokenize the prompts in dict order, so index 0 is 'normal' and
# index 1 is 'anomal'.
prompt_tokens = clip.tokenize([*PROMPTS.values()])
text = prompt_tokens.to(device)

# Prompt embeddings, computed with gradient tracking disabled.
with torch.no_grad():
    text_features = model.encode_text(text)

# Accumulators filled by the inference loop: the predicted class index and
# the per-class probabilities of every frame.
y_pred, y_pred_proba = [], []


# Function to preprocess a single image
def preprocess_image(image_path: Path) -> torch.Tensor:
    """Square crop and resize the image to 224x224.

    Args:
        image_path: Path of the image file to load.

    Returns:
        The tensor produced by the CLIP `preprocess` transform.
    """
    # This function runs on every frame from worker threads, so close the
    # underlying file deterministically instead of relying on garbage
    # collection. All work on the image happens inside the `with` block,
    # before the file is closed.
    with Image.open(image_path) as image:
        cropped = crop_driver_image_contains(image, image_path)
        # The `preprocess` function resizes to 224x224
        # and normalizes the image.
        return preprocess(cropped)

In [None]:
# Start from empty accumulators so that re-running this cell does not append
# a second copy of the predictions to the lists.
y_pred = []
y_pred_proba = []

# Split ANOMAL_IMAGES into batches.
# A single thread pool is shared by all batches instead of being created and
# torn down once per batch.
with ThreadPoolExecutor() as executor:
    with tqdm(total=len(ANOMAL_IMAGES_PATHS), desc='Processing') as pbar:
        for i in range(0, len(ANOMAL_IMAGES_PATHS), BATCH_SIZE):
            batch = ANOMAL_IMAGES_PATHS[i : i + BATCH_SIZE]

            # Preprocess the images of this batch in parallel;
            # `executor.map` preserves the input order.
            processed_images = list(executor.map(preprocess_image, batch))

            # Stack and move the preprocessed images to the device
            images = torch.stack(processed_images).to(device)

            with torch.no_grad():
                # One forward pass encodes both the images and the prompts.
                # A separate `model.encode_image` call is not needed: its
                # result was unused and it doubled the image-encoder cost.
                logits_per_image, logits_per_text = model(images, text)
                # Softmax over the prompts gives one probability per class.
                proba_batch = logits_per_image.softmax(dim=-1)
                cls_batch = proba_batch.argmax(dim=-1).cpu().detach().tolist()
            proba_batch = proba_batch.cpu().detach().tolist()

            y_pred_proba.extend(proba_batch)  # Append batch probabilities
            y_pred.extend(cls_batch)  # Append batch predictions
            pbar.update(len(batch))

## Evaluation

In [11]:
# Build the ground-truth labels from the annotation file, passing the number
# of predictions that were collected.
n_predictions = len(y_pred)
anomalies = Anomalies.from_file(ANOMALIES_FILE)
y_true = anomalies.to_ground_truth(n_predictions)

In [None]:
# Column 1 holds the probability of the second prompt ('anomal').
anomal_scores = np.array(y_pred_proba)[:, 1]
roc_chart_path = EXPERIMENT_DIR / ROC_CHART_NAME

# Plot the ROC curve, save it to the experiment directory, and keep the
# returned metrics for MLflow logging.
roc_auc, optimal_threshold = plot_roc_chart(
    y_true, anomal_scores, save_path=roc_chart_path
)

In [13]:
# Persist the raw probabilities and the prompts next to the ROC chart so the
# run can be inspected without re-running the model.
for json_name, payload in (
    (PREDS_JSON_NAME, y_pred_proba),
    (PROMPTS_JSON_NAME, PROMPTS),
):
    with open(EXPERIMENT_DIR / json_name, 'w') as f:
        json.dump(payload, f)

In [None]:
with mlflow.start_run(run_name=f'{EXPERIMENT_NAME}') as run:
    # Tags are best-effort, but each one is set independently: a failing git
    # lookup must not prevent the remaining tags (in particular the
    # git-independent dataset tag) from being recorded.
    tag_sources = {
        'Branch': get_current_branch,
        'Commit ID': get_commit_id,
        'Dataset': lambda: DATASET_NAME,
    }
    for tag_name, resolve_tag in tag_sources.items():
        try:
            mlflow.set_tag(tag_name, resolve_tag())
        except Exception as e:
            print(f'Could not set MLflow tag {tag_name!r}: {e}')

    # Metrics produced by the ROC evaluation
    mlflow.log_metric('roc_auc', roc_auc)
    mlflow.log_metric('optimal_threshold', optimal_threshold)

    # Run configuration
    mlflow.log_param('driver', DRIVER)
    mlflow.log_param('model', MODEL_NAME)
    mlflow.log_param('prompts', PROMPTS)

    # For comparison with the previous experiments
    mlflow.log_param('sequence_length', 1)
    mlflow.log_param('time_step', 1)
    mlflow.log_param('image_size', 224)
    mlflow.log_param('use_mask', False)

    # Artifacts
    mlflow.log_artifact(str(EXPERIMENT_DIR / PREDS_JSON_NAME), MLFLOW_ARTIFACT_DIR)
    mlflow.log_artifact(str(EXPERIMENT_DIR / PROMPTS_JSON_NAME), MLFLOW_ARTIFACT_DIR)
    mlflow.log_artifact(NOTEBOOK_NAME, MLFLOW_ARTIFACT_DIR)
    mlflow.log_artifact(str(EXPERIMENT_DIR / ROC_CHART_NAME), MLFLOW_ARTIFACT_DIR)
