In [None]:
import os
from os.path import join as pjoin

import albumentations as A
import numpy as np
import torch
from accelerate import Accelerator
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from checkpointer import CheckpointSaver, load_checkpoint
from dataset import (
    ID_LABELS_MAP,
    CustomVOCDetectionDataset,
    detection_collate_fn,
)
from loss import YOLOLoss
from metric import CustomMeanAveragePrecision
from train import train
from utils import add_bboxes_on_img, non_max_suppression, seed_everything
from yolo_custom import yolo_v8_n

In [None]:
# Fix the random seed used for the whole run.
SEED = 42
seed_everything(SEED, torch_deterministic=False)

## Аугментации

In [None]:
IMAGE_SIZE = 640
MOSAIC = True

# Boxes use the YOLO format: [x_center, y_center, width, height], normalized.
# NOTE(review): min_area / min_visibility presumably filter out boxes that end
# up tiny or mostly cut off after a transform - confirm in the albumentations docs.
bbox_params = A.BboxParams(
    format="yolo",
    label_fields=["label_ids"],
    min_area=16,
    min_visibility=0.1,
)

# Spatial step: crop a fixed-size window when mosaic is enabled,
# otherwise rescale by the smallest image side.
if MOSAIC:
    spatial_transform = A.AtLeastOneBBoxRandomCrop(
        height=IMAGE_SIZE, width=IMAGE_SIZE, p=1.0
    )
else:
    spatial_transform = A.SmallestMaxSize(max_size=IMAGE_SIZE, p=1.0)

# Colour step: with probability 0.5 apply one of the two photometric jitters.
color_transform = A.OneOf(
    [
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.HueSaturationValue(
            hue_shift_limit=10,
            sat_shift_limit=20,
            val_shift_limit=10,
            p=0.5,
        ),
    ],
    p=0.5,
)

# Training pipeline: geometric -> spatial -> colour -> tensor conversion.
train_transforms = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, p=0.5),
        spatial_transform,
        color_transform,
        A.ToTensorV2(),
    ],
    bbox_params=bbox_params,
)

# Validation pipeline: deterministic resize to a square plus tensor conversion.
val_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        A.ToTensorV2(),
    ],
    bbox_params=bbox_params,
)

## Dataset

Набор данных Pascal VOC. Рассмотрим его версию для задачи детекции объектов.

Сайт: http://host.robots.ox.ac.uk/pascal/VOC/

Лидерборд за 2012 год: http://host.robots.ox.ac.uk:8080/leaderboard/displaylb_main.php?challengeid=11&compid=3

При проблемах со скачиванием с сайта соревнования скачайте архив отсюда и распакуйте его в папку `data` (должен получиться путь `data/VOCdevkit`): https://disk.yandex.ru/d/1jS3yBBN7YdZ-w

In [None]:
# Arguments that are identical for the train and val splits.
voc_common_kwargs = dict(
    root="data",
    year="2012",
    download=False,  # original note: "True" - presumably flip it to fetch the data
    img_size=IMAGE_SIZE,
)

# Training split: random augmentations, mosaic controlled by the MOSAIC flag.
train_dataset = CustomVOCDetectionDataset(
    image_set="train",
    transform=train_transforms,  # transform, not transforms!
    mosaic=MOSAIC,
    **voc_common_kwargs,
)

# Validation split: deterministic transforms, mosaic always off.
val_dataset = CustomVOCDetectionDataset(
    image_set="val",
    transform=val_transforms,
    mosaic=False,
    **voc_common_kwargs,
)

In [None]:
# Take the first training sample and show it with and without its boxes.
img, target = train_dataset[0]

# Move the first tensor axis last (presumably CHW -> HWC) and cast for plotting.
img = img.permute(1, 2, 0).numpy().astype(np.uint8)

# Draw every target box onto a copy so the clean image stays untouched.
img_bboxes = img.copy()
for box, class_id in zip(target["bboxes"], target["label_ids"]):
    img_bboxes = add_bboxes_on_img(
        img=img_bboxes,
        bbox=box,
        bbox_format="xcycwh",
        denormalize_bbox=True,
        label=ID_LABELS_MAP[class_id.item()],
    )

fig, ax = plt.subplots(1, 2, figsize=(12, 24))
plain_ax, boxes_ax = ax
plain_ax.imshow(img)
boxes_ax.imshow(img_bboxes)

## Обучаем модель

См. `train.py`

In [None]:
# --- model ---
CLASSES_NUM = 20

# --- optimizer settings (SGD and Adam variants) ---
LEARNING_RATE_SGD = 1e-2
MOMENTUM_SGD = 0.93
LEARNING_RATE_ADAM = 1e-4
BETAS_ADAM = (0.9, 0.999)
WEIGHT_DECAY = 1e-5

# --- learning-rate scheduler ---
SCHEDULER_PATIENCE = 5
SCHEDULER_GAMMA = 0.5
MIN_LEARNING_RATE = 1e-5

# --- data loading and training loop ---
BATCH_SIZE = 32
NUM_WORKERS = 4
GRAD_ACCUMULATION_STEPS = 1
EPOCH_NUM = 100

# --- output locations ---
CHECKPOINTS_DIR = "checkpoints"
RM_CHECKPOINTS_DIR = False
TENSORBOARD_DIR = "tensorboard"

# --- device: use the GPU when one is available ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Pinned host memory only speeds up host -> GPU copies, so request it only
# when training on CUDA (avoids a pointless setting on CPU-only runs).
pin_memory = DEVICE == "cuda"

# Training loader: shuffled, incomplete last batch dropped.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    collate_fn=detection_collate_fn,
    pin_memory=pin_memory,
    shuffle=True,
    drop_last=True,
)
# Validation loader: fixed order, every sample kept.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    collate_fn=detection_collate_fn,
    pin_memory=pin_memory,
    shuffle=False,
    drop_last=False,
)

# Use the shared CLASSES_NUM constant instead of a duplicated literal.
model = yolo_v8_n(classes_num=CLASSES_NUM)
# To continue from a saved checkpoint, uncomment:
# model = load_checkpoint(
#     model=model,
#     load_path=pjoin(CHECKPOINTS_DIR, "model_checkpoint_best.pt"),
# )

metric_fn = CustomMeanAveragePrecision(box_format="cxcywh", iou_type="bbox")

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LEARNING_RATE_SGD,
    weight_decay=WEIGHT_DECAY,
    momentum=MOMENTUM_SGD,
    nesterov=True,
)

# mode="max": the LR is reduced when the monitored value stops increasing.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode="max",
    factor=SCHEDULER_GAMMA,
    patience=SCHEDULER_PATIENCE,
    min_lr=MIN_LEARNING_RATE,
)

accelerator = Accelerator(
    cpu=(DEVICE == "cpu"),
    mixed_precision="no",
    gradient_accumulation_steps=GRAD_ACCUMULATION_STEPS,
)
model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, val_dataloader, lr_scheduler
)

# Must be built after accelerator.prepare(): the loss receives the prepared model.
loss_fn = YOLOLoss(model=model)

os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
# should_minimize=False: a higher "map" value counts as a better checkpoint.
checkpointer = CheckpointSaver(
    accelerator=accelerator,
    model=model,
    metric_name="map",
    save_dir=CHECKPOINTS_DIR,
    rm_save_dir=RM_CHECKPOINTS_DIR,
    max_history=5,
    should_minimize=False,
)

In [None]:
os.makedirs(TENSORBOARD_DIR, exist_ok=True)
# SummaryWriter is imported explicitly at the top of the notebook:
# `import torch` alone does not load the `torch.utils.tensorboard` submodule,
# so reaching it through attribute access can raise AttributeError.
tensorboard_logger = SummaryWriter(log_dir=TENSORBOARD_DIR)

# Uncomment in Google Colab
# %load_ext tensorboard
# %tensorboard --logdir "tensorboard"  --port 6006

In [None]:
# Launch the training loop (implemented in train.py) with everything built above.
train(
    # model and optimisation
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    accelerator=accelerator,
    # data
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    # evaluation, checkpointing and logging
    metric_fn=metric_fn,
    checkpointer=checkpointer,
    tb_logger=tensorboard_logger,
    # schedule
    epoch_num=EPOCH_NUM,
)

## Загрузим и протестируем обученную модель

In [None]:
# Re-declared here so this section can run without the training cells.
CHECKPOINTS_DIR = "checkpoints"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Build a fresh network and load the best saved checkpoint into it.
best_checkpoint_path = pjoin(CHECKPOINTS_DIR, "model_checkpoint_best.pt")
model = load_checkpoint(
    model=yolo_v8_n(classes_num=20),
    load_path=best_checkpoint_path,
)
model = model.to(DEVICE)
# Switch to evaluation mode; kept as the last expression so the cell still
# displays the returned module.
model.eval()

In [None]:
sample_idx = 0
img, target = val_dataset[sample_idx]

# Inference only: disable autograd so no graph is built for the forward pass.
with torch.no_grad():
    outputs = model(img.unsqueeze(0).to(DEVICE))
    outputs = non_max_suppression(preds=outputs, conf_thd=0.3, iou_thd=0.5)
# A single image was passed, so take the detections of the first batch entry.
outputs = outputs[0]

# Move the first tensor axis last (presumably CHW -> HWC) and cast for plotting.
img = img.numpy().transpose(1, 2, 0).astype(np.uint8)

# Predictions: each row is read as [box (4 values), <index 4>, class id].
# NOTE(review): index 4 presumably holds the confidence score - it is not drawn.
img_bboxes_pred = img.copy()
for output in outputs:
    bbox = output[:4].tolist()
    # .item() may yield a float class id; cast so the lookup key is an int.
    label = ID_LABELS_MAP[int(output[5].item())]
    img_bboxes_pred = add_bboxes_on_img(
        img=img_bboxes_pred,
        bbox=bbox,
        bbox_format="xyxy",
        denormalize_bbox=False,
        label=label,
    )

# Ground truth: boxes come from the dataset target in normalized xcycwh form.
img_bboxes_target = img.copy()
for bbox, label_id in zip(target["bboxes"], target["label_ids"]):
    label = ID_LABELS_MAP[label_id.item()]
    img_bboxes_target = add_bboxes_on_img(
        img=img_bboxes_target,
        bbox=bbox,
        bbox_format="xcycwh",
        denormalize_bbox=True,
        label=label,
    )

# Side by side: raw image, model predictions, ground-truth boxes.
fig, ax = plt.subplots(1, 3, figsize=(12, 24))
ax[0].set_title("Image")
ax[0].imshow(img)
ax[1].set_title("Predictions")
ax[1].imshow(img_bboxes_pred)
ax[2].set_title("Ground truth")
ax[2].imshow(img_bboxes_target)