# Driver State Analysis

[https://dagshub.com/matejfric/driver-state](https://dagshub.com/matejfric/driver-state)

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import datetime
import os
import re
import sys
from pathlib import Path
from pprint import pprint

repo_root = str(Path.cwd().parent.parent)
if repo_root not in sys.path:
    sys.path.append(repo_root)

import albumentations as albu
import dagshub
import matplotlib.pyplot as plt
import mlflow
import mlflow.pytorch
import numpy as np
import onnx
import pytorch_lightning as L
import torch
from mlflow.models.signature import infer_signature

# NOTE: Unlike Keras, the PyTorch Lightning EarlyStopping callback does not restore the best weights!
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# https://github.com/Lightning-AI/pytorch-lightning/discussions/10399,
# https://pytorch-lightning.readthedocs.io/en/1.5.10/extensions/generated/pytorch_lightning.callbacks.ModelCheckpoint.html
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

from model import (
    BatchSizeDict,
    DatasetPathsLoader,
    DatasetSplit,
    SegmentationModel,
)
from model.augmentation import (
    compose_transforms,
    hard_transforms,
    post_transforms,
    pre_transforms,
)
from model.git import get_commit_id, get_current_branch
from model.plot import (
    plot_learning_curves,
    plot_predictions_compact,
)


## Configuration

In [None]:
# Plot appearance used by the figures in this notebook
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['font.size'] = 14

# Experiment logging (DagsHub repository with MLflow enabled)
REPO_NAME = 'driver-seg'
USER_NAME = 'matejfric'
dagshub.init(REPO_NAME, USER_NAME, mlflow=True)  # type: ignore

# Reproducibility
# https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility
SEED = 42
L.seed_everything(SEED, workers=True)

# Report the library versions and CUDA availability of this environment
_env_info = {
    'torch': torch.__version__,
    'cuda': torch.cuda.is_available(),
    'lightning': L.__version__,  # type: ignore
}
print(', '.join(f'{label}: {value}' for label, value in _env_info.items()))

| Model            | Channel Multiplier | Depth Multiplier | Resolution | Dropout Rate |
|-----------------|------------------|----------------|------------|--------------|
| efficientnet-b0 | 1.0              | 1.0            | 224        | 0.2          |
| efficientnet-b1 | 1.0              | 1.1            | 240        | 0.2          |
| efficientnet-b2 | 1.1              | 1.2            | 260        | 0.3          |
| efficientnet-b3 | 1.2              | 1.4            | 300        | 0.3          |
| efficientnet-b4 | 1.4              | 1.8            | 380        | 0.4          |
| efficientnet-b5 | 1.6              | 2.2            | 456        | 0.4          |

In [None]:
# HYPERPARAMETERS
# ----------------------------------------
# Architecture (passed to SegmentationModel)
DECODER = 'unet'  # 'unet', 'unetplusplus', 'deeplabv3', 'deeplabv3plus', 'fpn', ...
ENCODER = 'efficientnet-b0'  # 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'efficientnet-b1', 'mit_b1', ...
ENCODER_WEIGHTS = 'imagenet'
FREEZE_ENCODER = True

# Input pipeline
IMAGE_SIZE = 224  # forwarded to pre_transforms(image_size=...)
BATCH_SIZE = 32  # same batch size for the train/valid/test loaders
AUGMENTATION = True  # when True, hard_transforms() is added to the training pipeline

# Optimisation and stopping criteria
LEARNING_RATE = 1e-4
MAX_EPOCHS = 10
MONITOR = 'valid_loss'  # metric watched by EarlyStopping and ModelCheckpoint
# NOTE(review): PATIENCE equals MAX_EPOCHS, so early stopping presumably never
# fires before the epoch limit is reached - confirm this is intended.
PATIENCE = 10

# LOGGING
# ----------------------------------------
LOG_DIR = Path('logs')

# The experiment name is made unique per run by a second-resolution timestamp.
_RUN_TIMESTAMP = datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')
EXPERIMENT_NAME = f'{_RUN_TIMESTAMP}-{DECODER}-{ENCODER}-test-new-dataset'
VERSION = 0
# Same <LOG_DIR>/<name>/version_<n> location that CSVLogger is configured with.
EXPERIMENT_DIR = LOG_DIR.joinpath(EXPERIMENT_NAME, f'version_{VERSION}')

# Previously used: '2024-09-15-driver-segmentation-dataset'
DATASET_NAME = '2025-02-26-driver-segmentation-dataset'

# Names of the artifacts written to EXPERIMENT_DIR and uploaded to MLflow
MLFLOW_ARTIFACT_DIR = 'outputs'
METRICS_CSV_NAME = 'metrics.csv'
LEARNING_CURVES_PDF_NAME = 'learning_curves.pdf'
PREDICTIONS_PNG_NAME = 'predictions.png'
TRAIN_TRANSFORMS_JSON_NAME = 'train_transforms.json'
MODEL_ONNX_NAME = 'model.onnx'

# DATASET
# ----------------------------------------
DATASET_DIR = Path.home() / f'source/driver-dataset/{DATASET_NAME}'
# Explicit check instead of `assert`: assertions are stripped under `python -O`.
if not DATASET_DIR.exists():
    raise FileNotFoundError(f'Dataset directory does not exist: {DATASET_DIR}')

TRAIN_SET_DIR = 'train'
VALID_SET_DIR = 'validation'
TEST_SET_DIR = 'test'

IMAGES_DIR = 'images'
MASKS_DIR = 'masks'


def _list_split_files(split_dir: str, sub_dir: str, pattern: str) -> list[Path]:
    """Return the sorted paths matching *pattern* in DATASET_DIR/split_dir/sub_dir."""
    return sorted((DATASET_DIR / split_dir / sub_dir).glob(pattern))


# Images are '*.jpg' files and masks are '*.png' files; both lists are sorted
# by path (presumably so that images and masks pair up by index - verify
# against DatasetSplit).
TRAIN_IMAGES = _list_split_files(TRAIN_SET_DIR, IMAGES_DIR, '*.jpg')
TRAIN_MASKS = _list_split_files(TRAIN_SET_DIR, MASKS_DIR, '*.png')

VALID_IMAGES = _list_split_files(VALID_SET_DIR, IMAGES_DIR, '*.jpg')
VALID_MASKS = _list_split_files(VALID_SET_DIR, MASKS_DIR, '*.png')

TEST_IMAGES = _list_split_files(TEST_SET_DIR, IMAGES_DIR, '*.jpg')
TEST_MASKS = _list_split_files(TEST_SET_DIR, MASKS_DIR, '*.png')

# Summary of the number of files found per split
print(
    f' Train: {len(TRAIN_IMAGES)} images, {len(TRAIN_MASKS)} masks\n',
    f'Valid: {len(VALID_IMAGES)} images, {len(VALID_MASKS)} masks\n',
    f'Test: {len(TEST_IMAGES)} images, {len(TEST_MASKS)} masks',
)

In [None]:
# Distinct training-image name prefixes: each file name with everything from
# its last '_' onwards removed.
_train_prefixes = [image_path.name.rpartition('_')[0] for image_path in TRAIN_IMAGES]
im_names = list(np.unique(_train_prefixes))
pprint(im_names)

In [None]:
# Exclude images from the training set
# TRAIN_IMAGES = [img for img in TRAIN_IMAGES if 'stribny' not in img.stem]
# TRAIN_MASKS = [mask for mask in TRAIN_MASKS if 'stribny' not in mask.stem]

## Augmentations and Transforms

In [None]:
def _make_eval_transforms():
    """Build a fresh pre_transforms -> post_transforms pipeline (no augmentation)."""
    return compose_transforms(
        [
            pre_transforms(image_size=IMAGE_SIZE),
            post_transforms(),
        ]
    )


# Training pipeline: pre_transforms, then hard_transforms only when
# AUGMENTATION is enabled, then post_transforms.
_train_steps = [pre_transforms(image_size=IMAGE_SIZE)]
if AUGMENTATION:
    _train_steps.append(hard_transforms())
_train_steps.append(post_transforms())
train_transforms = compose_transforms(_train_steps)

# Validation and test data are never augmented.
valid_transforms = _make_eval_transforms()
test_transforms = _make_eval_transforms()

## Loaders

In [None]:
# Bundle the image and mask path lists of each split.
_train_split = DatasetSplit(images=TRAIN_IMAGES, masks=TRAIN_MASKS)
_valid_split = DatasetSplit(images=VALID_IMAGES, masks=VALID_MASKS)
_test_split = DatasetSplit(images=TEST_IMAGES, masks=TEST_MASKS)

dataset_loader = DatasetPathsLoader(
    train=_train_split,
    valid=_valid_split,
    test=_test_split,
)

In [None]:
# The same batch size is used for all three splits.
BATCH_SIZE_DICT = BatchSizeDict(
    {'train': BATCH_SIZE, 'valid': BATCH_SIZE, 'test': BATCH_SIZE}
)
loaders = dataset_loader.get_loaders(
    # set to zero if RuntimeError: Trying to resize storage that is not resizable
    # os.cpu_count() returns None when the count cannot be determined; fall
    # back to 0 workers instead of failing with a TypeError in int(None).
    num_workers=os.cpu_count() or 0,
    batch_size_dict=BATCH_SIZE_DICT,
    train_transforms=train_transforms,
    valid_transforms=valid_transforms,
    test_transforms=test_transforms,
)

# Per-split data loaders used for training, validation and testing below.
train_dataloader = loaders['train']
valid_dataloader = loaders['valid']
test_dataloader = loaders['test']

## Training

In [None]:
# Keyword configuration of the segmentation network (3 input channels,
# a single output class), driven by the hyperparameters defined above.
_model_kwargs = dict(
    in_channels=3,
    out_classes=1,
    batch_size_dict=BATCH_SIZE_DICT,
    freeze_encoder=FREEZE_ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    learning_rate=LEARNING_RATE,
)
model = SegmentationModel(DECODER, ENCODER, **_model_kwargs)

In [None]:
# Metrics are written as CSV under LOG_DIR/EXPERIMENT_NAME/version_<VERSION>,
# i.e. EXPERIMENT_DIR.
csv_logger = CSVLogger(LOG_DIR, name=EXPERIMENT_NAME, version=VERSION)

# Stop training once MONITOR has not decreased for PATIENCE checks.
early_stopping = EarlyStopping(
    monitor=MONITOR,
    mode='min',
    patience=PATIENCE,
)

# Keep only the checkpoint with the lowest MONITOR value in EXPERIMENT_DIR.
model_checkpoint = ModelCheckpoint(
    dirpath=EXPERIMENT_DIR,
    # '.3f' = three decimals; the previous ':3f' meant "minimum width 3" and
    # produced the default six decimals in the checkpoint file name.
    filename='{epoch}-{valid_loss:.3f}',
    monitor=MONITOR,
    save_top_k=1,  # save only the best model
    mode='min',
)

In [None]:
# Checkpointing and early stopping are both attached to the training run.
_training_callbacks = [model_checkpoint, early_stopping]

trainer = L.Trainer(
    max_epochs=MAX_EPOCHS,
    callbacks=_training_callbacks,
    logger=csv_logger,
    log_every_n_steps=1,  # log every batch
    # https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility
    deterministic=True,
)

In [None]:
# Float32 matrix-multiplication precision setting; see the PyTorch docs for
# the trade-off between the available levels:
# https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
_MATMUL_PRECISION = 'medium'
torch.set_float32_matmul_precision(_MATMUL_PRECISION)

In [None]:
# Run training on the training loader, validating on the validation loader.
_fit_loaders = {
    'train_dataloaders': train_dataloader,
    'val_dataloaders': valid_dataloader,
}
trainer.fit(model, **_fit_loaders)

## Validation

In [None]:
# Load from MLflow
# model_name = 'pytorch-unet-resnet18'
# model_version = 2
# model_uri = f'models:/{model_name}/{model_version}'
# model_ = mlflow.pytorch.load_model(model_uri)

In [None]:
# Reload the weights saved by ModelCheckpoint (save_top_k=1, so at most one
# '*.ckpt' file is expected in EXPERIMENT_DIR).
_checkpoint_paths = list(EXPERIMENT_DIR.glob('*.ckpt'))
if not _checkpoint_paths:
    # Fail with an actionable message instead of a bare IndexError.
    raise FileNotFoundError(
        f'No checkpoint (*.ckpt) found in {EXPERIMENT_DIR}; run the training cell first.'
    )
model_checkpoint_path = _checkpoint_paths[0]
model_ = SegmentationModel.load_from_checkpoint(model_checkpoint_path)
trainer_ = L.Trainer(logger=False)  # no need to log anything for validation and testing

In [None]:
# Evaluate the reloaded model on the validation loader; validate() returns a
# list of which the first entry holds the metrics used here.
_valid_results = trainer_.validate(model_, dataloaders=valid_dataloader, verbose=False)
valid_metrics = _valid_results[0]
pprint(valid_metrics)

In [None]:
# Evaluate the reloaded model on the test loader; test() returns a list of
# which the first entry holds the metrics used here.
_test_results = trainer_.test(model_, dataloaders=test_dataloader, verbose=False)
test_metrics = _test_results[0]
pprint(test_metrics)

In [None]:
# Plot test-set predictions (n_cols=4, limit=12) and save the figure to
# EXPERIMENT_DIR so it can be uploaded to MLflow later.
_predictions_path = EXPERIMENT_DIR / PREDICTIONS_PNG_NAME
_prediction_plot_options = dict(n_cols=4, limit=12, seed=SEED, cmap='jet')
plot_predictions_compact(
    model_,
    test_dataloader,
    save_path=_predictions_path,
    **_prediction_plot_options,
)

In [None]:
# Plot learning curves from the CSVLogger metrics file and save them as PDF.
# The mapping pairs metric keys with the labels passed for plotting.
_curve_metrics = {'jaccard_index': 'Jaccard Index', 'f1_score': 'F1 Score'}
_metrics_csv_path = EXPERIMENT_DIR / METRICS_CSV_NAME
_curves_pdf_path = EXPERIMENT_DIR / LEARNING_CURVES_PDF_NAME
plot_learning_curves(_metrics_csv_path, save_path=_curves_pdf_path, metrics=_curve_metrics)

## Logging

In [None]:
# Sample tensors for the ONNX export and the MLflow model signature.
# Their spatial size follows IMAGE_SIZE (the size given to pre_transforms)
# instead of a hard-coded 256, so the exported model and the logged signature
# match the resolution the pipeline actually produces.
# Shapes: input (1, 3, H, W) for in_channels=3, output (1, 1, H, W) for
# out_classes=1.
INPUT_SAMPLE = torch.randn((1, 3, IMAGE_SIZE, IMAGE_SIZE), dtype=torch.float32)
OUTPUT_SAMPLE = torch.randn((1, 1, IMAGE_SIZE, IMAGE_SIZE), dtype=torch.float32)
model_.to_onnx(
    EXPERIMENT_DIR / MODEL_ONNX_NAME, INPUT_SAMPLE, export_params=True, dynamo=False
)

In [None]:
# Reload the exported ONNX file and run the ONNX checker on it.
_onnx_path = EXPERIMENT_DIR / MODEL_ONNX_NAME
onnx_model = onnx.load(_onnx_path)
onnx.checker.check_model(onnx_model)

In [None]:
# Serialize the training transforms into EXPERIMENT_DIR so they can be
# attached to the MLflow run as an artifact.
_transforms_json_path = EXPERIMENT_DIR / TRAIN_TRANSFORMS_JSON_NAME
albu.save(train_transforms, _transforms_json_path)

In [None]:
def get_early_stopping_epoch() -> int | None:
    checkpoint = list(EXPERIMENT_DIR.glob('*.ckpt'))[0].stem
    pattern = r'epoch=(\d+)'
    match = re.search(pattern, checkpoint)
    if match:
        return int(match.group(1))
    else:
        return None

In [None]:
def log_dict_to_mlflow(dictionary: dict[str, float]) -> None:
    """Log every entry of *dictionary* to MLflow as an individual metric.

    Args:
        dictionary: Mapping of metric name to metric value.
    """
    for metric_name in dictionary:
        mlflow.log_metric(metric_name, dictionary[metric_name])

In [None]:
# Record the whole experiment (tags, metrics, hyperparameters, artifacts and
# the model itself) in a single MLflow run.
with mlflow.start_run(run_name=EXPERIMENT_NAME) as run:
    mlflow.set_tag('Dataset', DATASET_NAME)
    # Git information is best effort: a failure is printed, not raised.
    try:
        for _tag, _getter in (
            ('Branch', get_current_branch),
            ('Commit ID', get_commit_id),
        ):
            mlflow.set_tag(_tag, _getter())
    except Exception as e:
        print(e)

    # Validation metrics first, then test metrics.
    for _metrics in (valid_metrics, test_metrics):
        log_dict_to_mlflow(dict(_metrics))

    # Hyperparameters, logged one by one in this order.
    _params = (
        ('encoder', ENCODER),
        ('decoder', DECODER),
        ('batch_size', BATCH_SIZE),
        ('max_epochs', MAX_EPOCHS),
        ('early_stopping', get_early_stopping_epoch()),
        ('monitor', MONITOR),
        ('patience', PATIENCE),
        ('image_size', IMAGE_SIZE),
        ('frozen_encoder', FREEZE_ENCODER),
        ('encoder_weights', ENCODER_WEIGHTS),
        ('learning_rate', LEARNING_RATE),
        ('augmentation', AUGMENTATION),
    )
    for _param_name, _param_value in _params:
        mlflow.log_param(_param_name, _param_value)

    # Files produced earlier in EXPERIMENT_DIR, followed by the notebook itself.
    for _artifact_name in (
        METRICS_CSV_NAME,
        LEARNING_CURVES_PDF_NAME,
        PREDICTIONS_PNG_NAME,
        TRAIN_TRANSFORMS_JSON_NAME,
        MODEL_ONNX_NAME,
    ):
        mlflow.log_artifact(str(EXPERIMENT_DIR / _artifact_name), MLFLOW_ARTIFACT_DIR)
    mlflow.log_artifact('train_segmentation.ipynb', MLFLOW_ARTIFACT_DIR)

    # Models are versioned by default
    # NOTE(review): the registered name contains the run timestamp, so each run
    # presumably registers a new model name rather than a new version - confirm.
    _signature = infer_signature(
        INPUT_SAMPLE.numpy(), OUTPUT_SAMPLE.numpy(), dict(training=False)
    )
    mlflow.pytorch.log_model(
        pytorch_model=model_,
        artifact_path='model',
        registered_model_name=f'pytorch-{EXPERIMENT_NAME}',
        signature=_signature,
    )