# Обучение моделей в PyTorch

In [None]:
import os
import random
from os.path import join as pjoin
from shutil import rmtree

import albumentations
import numpy as np
import torch
from PIL import Image
from accelerate import Accelerator
from albumentations.pytorch.transforms import ToTensorV2
from matplotlib import pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets

from checkpointer import CheckpointSaver, load_checkpoint
from dataset import CustomVOCSegmentation
from deeplabv3plus.modeling import deeplabv3_resnet50
from loss import CEDiceLoss, CrossEntropyLoss, DiceLoss, FocalLoss
from metric import IoUMetric
from train import count_pytorch_model_params, train
from unet import UNet
from unet_custom import CustomUNet
from utils import seed_everything

In [None]:
# Seed the random number generators (helper from utils.py) with a fixed value of 42.
# torch_deterministic=False presumably leaves non-deterministic (faster) torch
# kernels enabled — NOTE(review): confirm in utils.seed_everything.
seed_everything(42, torch_deterministic=False)

## Аугментации

Трансформации/аугментации для исходных изображений и масок/таргетов.

Аугментации для задач компьютерного зрения: https://albumentations.ai/

In [None]:
IMAGE_SIZE = 256


def _resize_to_square():
    """Return a fresh Resize to IMAGE_SIZE x IMAGE_SIZE (aspect ratio is not preserved)."""
    # Alternatives kept from earlier experiments:
    # albumentations.SmallestMaxSize(max_size=IMAGE_SIZE)
    # albumentations.CropNonEmptyMaskIfExists(height=IMAGE_SIZE, width=IMAGE_SIZE)
    # albumentations.PadIfNeeded(min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, p=1.0)  # train only
    return albumentations.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE, p=1.0)


def _to_model_input():
    """Return fresh final steps shared by both pipelines.

    `Normalize()` (albumentations' default mean/std) is mandatory: the model is
    trained on normalized inputs, and anything that displays dataset samples must
    undo it. `ToTensorV2(transpose_mask=True)` converts the image and the mask from
    HWC arrays to CHW tensors.
    """
    return [
        albumentations.Normalize(),
        ToTensorV2(transpose_mask=True),
    ]


# Training: resize, then random flip / colour / blur / contrast augmentations
# (each with probability 0.5), then normalization and tensor conversion.
train_transforms = albumentations.Compose(
    [
        _resize_to_square(),
        albumentations.ChannelShuffle(p=0.5),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.AdvancedBlur(p=0.5),
        # albumentations.GaussNoise(p=0.1, std_range=[0.01, 0.1]),
        albumentations.CLAHE(p=0.5),
        albumentations.RandomBrightnessContrast(p=0.5),
        albumentations.RandomGamma(p=0.5),
        albumentations.ColorJitter(p=0.5),
        *_to_model_input(),
    ]
)

# Validation: no randomness — only resize, normalization and tensor conversion.
val_transforms = albumentations.Compose([_resize_to_square(), *_to_model_input()])

## Dataset

Набор данных Pascal VOC. Рассмотрим его версию для задачи сегментации. 

Сайт: http://host.robots.ox.ac.uk/pascal/VOC/

Лидерборд за 2012 год: http://host.robots.ox.ac.uk:8080/leaderboard/displaylb_main.php?challengeid=11&compid=5

При тех или иных проблемах со скачиванием с сайта соревнования, скачайте и распакуйте архив в папку `data` (`data/VOCdevkit`) отсюда: https://disk.yandex.ru/d/1jS3yBBN7YdZ-w

In [None]:
# Pascal VOC 2012 segmentation splits, read from ./data (expected layout: data/VOCdevkit).
# Set "download" to True during the very first run to fetch the archive.
# NOTE: the pipeline is passed as `transform` (singular), not `transforms` —
# presumably CustomVOCSegmentation applies it to image and mask jointly; confirm in dataset.py.
_voc_common = {"root": "data", "year": "2012", "download": False}

train_dataset = CustomVOCSegmentation(
    image_set="train", transform=train_transforms, **_voc_common
)
val_dataset = CustomVOCSegmentation(
    image_set="val", transform=val_transforms, **_voc_common
)

In [None]:
# Show the first training sample. The dataset output has already gone through
# `albumentations.Normalize()` (default ImageNet mean/std, pixels scaled by 1/255),
# so undo that before converting to uint8 — casting the normalized floats
# directly produces a meaningless image.
image, target = train_dataset[0]
_norm_mean = np.array([0.485, 0.456, 0.406])
_norm_std = np.array([0.229, 0.224, 0.225])
_image_hwc = image.numpy().transpose(1, 2, 0) * _norm_std + _norm_mean
Image.fromarray(np.clip(_image_hwc * 255.0, 0, 255).astype(np.uint8))

In [None]:
# Show a single target channel as a black/white mask. The target is indexed here as
# (channel, H, W); channel 15 is presumably "person" in the standard VOC class
# order — NOTE(review): confirm against CustomVOCSegmentation.
_mask_channel = target[15].numpy().astype(np.uint8)
Image.fromarray(_mask_channel * 255)

## UNet model

Статья: https://arxiv.org/abs/1505.04597

In [None]:
# Plain UNet: RGB input (3 channels) and 21 output channels, one per
# Pascal VOC class including background.
model = UNet(
    in_channels=3,
    out_channels=21,
)
print(model)

In [None]:
# Report the model's parameter count (helper imported from train.py).
count_pytorch_model_params(model)

## Accelerator

"Accelerate — это библиотека, которая позволяет запускать один и тот же код PyTorch в любой распределенной конфигурации, добавляя всего четыре строки кода! Короче говоря, обучение и вывод в больших масштабах стали простыми, эффективными и адаптируемыми". (c)

Сайт: https://huggingface.co/docs/accelerate/index

In [None]:
# Demo Accelerator instance. fp16 mixed precision requires a GPU: Accelerate
# raises a ValueError for mixed_precision="fp16" on a CPU device, so only
# request it when CUDA is available and fall back to full precision otherwise.
accelerator = Accelerator(
    cpu=False,
    mixed_precision="fp16" if torch.cuda.is_available() else "no",
)

## Checkpointer

Класс для сохранения наилучших версий модели в процессе обучения.

См. класс `CheckpointSaver` в `checkpointer.py`.

## Обучаем модель

См. `train.py`

In [None]:
# --- Task / architecture ---
CLASSES_NUM = 21
BACKBONE_NAME = "resnet50"

# --- Optimizer (AdamW) ---
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
BETAS = (0.9, 0.999)

# --- Data loading and schedule ---
BATCH_SIZE = 32
NUM_WORKERS = 8
EPOCH_NUM = 100
SCHEDULER_STEP_SIZE = 50  # StepLR: multiply the LR by SCHEDULER_GAMMA every step_size scheduler steps
SCHEDULER_GAMMA = 0.1

# --- Output locations ---
CHECKPOINTS_DIR = "checkpoints"
TENSORBOARD_DIR = "tensorboard"
RM_CHECKPOINTS_DIR = False

# Train on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
# Training loader: shuffled each epoch; the last incomplete batch is dropped so
# every optimization step sees a full batch.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
    drop_last=True,
)
# Validation loader: fixed order and no dropped samples, so the validation
# metric covers the whole split and is comparable between epochs.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
    drop_last=False,
)

# fp16 mixed precision requires a GPU (Accelerate raises a ValueError for fp16
# on a CPU device), so use full precision when training on the CPU.
accelerator = Accelerator(
    cpu=DEVICE == "cpu",
    mixed_precision="fp16" if DEVICE == "cuda" else "no",
)

# Alternative architectures:
# model = UNet(in_channels=3, out_channels=CLASSES_NUM, bilinear=True)
# model = deeplabv3_resnet50()
model = CustomUNet(backbone_name=BACKBONE_NAME, classes_num=CLASSES_NUM)

# Pixels labelled 255 are excluded from the loss — presumably the VOC "void"
# label; NOTE(review): confirm how CustomVOCSegmentation encodes it.
loss_fn = CrossEntropyLoss(ignore_index=255, reduction="sum")
metric_fn = IoUMetric(classes_num=CLASSES_NUM, reduction="macro")
metric_fns = {metric_fn.__class__.__name__: metric_fn}

optimizer = torch.optim.AdamW(
    model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, betas=BETAS
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer, step_size=SCHEDULER_STEP_SIZE, gamma=SCHEDULER_GAMMA
)

# Checkpoints are ranked by the IoU metric; higher is better, hence should_minimize=False.
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
checkpointer = CheckpointSaver(
    accelerator=accelerator,
    model=model,
    metric_name=metric_fn.__class__.__name__,
    save_dir=CHECKPOINTS_DIR,
    rm_save_dir=RM_CHECKPOINTS_DIR,
    max_history=5,
    should_minimize=False,
)

In [None]:
os.makedirs(TENSORBOARD_DIR, exist_ok=True)
tensorboard_logger = torch.utils.tensorboard.SummaryWriter(log_dir=TENSORBOARD_DIR)

%load_ext tensorboard
%tensorboard --logdir "tensorboard"  --port 6006

In [None]:
# Hand everything to Accelerate: `prepare` returns wrapped versions of the model,
# optimizer, dataloaders and scheduler that use the accelerator's device and
# precision settings. The originals are rebound to the prepared objects.
model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, val_dataloader, lr_scheduler
)

In [None]:
# Full training loop (imported from train.py). With save_on_val=True the
# checkpointer is presumably invoked after validation passes — see train.py.
train(
    # what is trained and how
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    loss_fn=loss_fn,
    metric_fns=metric_fns,
    # data
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epoch_num=EPOCH_NUM,
    # infrastructure
    accelerator=accelerator,
    checkpointer=checkpointer,
    tb_logger=tensorboard_logger,
    save_on_val=True,
)

## Загрузим и протестируем обученную модель

In [None]:
# Rebuild the trained architecture and restore the best checkpoint.
# The architecture must match the one that was trained; alternatives:
# model = UNet(in_channels=3, out_channels=CLASSES_NUM)
# model = deeplabv3_resnet50()
# (other checkpoint names used with other models, presumably: "custom_unet" "deeplabv3plus")
_best_checkpoint_path = pjoin(CHECKPOINTS_DIR, "model_checkpoint_best.pt")
model = load_checkpoint(
    model=CustomUNet(backbone_name=BACKBONE_NAME, classes_num=CLASSES_NUM),
    load_path=_best_checkpoint_path,
).to(DEVICE)
model.eval()

In [None]:
# Run the trained model on one validation sample and show the image, the
# ground truth and the prediction side by side.
sample_idx = 0
image, target = val_dataset[sample_idx]
# The target is indexed as (classes, H, W); argmax over the class axis gives a label map.
target = torch.argmax(target, dim=0)

# Inference only, so no autograd graph is needed. Softmax is monotonic, so the
# argmax over raw logits yields the same labels as argmax over probabilities.
with torch.no_grad():
    logits = model(image.unsqueeze(0).to(DEVICE))
preds = torch.argmax(logits.squeeze(0), dim=0)

# Undo albumentations.Normalize() (default ImageNet mean/std, pixels scaled by 1/255)
# before display; casting normalized floats straight to uint8 gives a garbage image.
_norm_mean = np.array([0.485, 0.456, 0.406])
_norm_std = np.array([0.229, 0.224, 0.225])
_image_hwc = image.numpy().transpose(1, 2, 0) * _norm_std + _norm_mean
image_rgb = np.clip(_image_hwc * 255.0, 0, 255).astype(np.uint8)

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].imshow(image_rgb)
ax[0].set_title("image")
# Fixed vmin/vmax so the same class gets the same colour in both label maps.
ax[1].imshow(target.numpy(), vmin=0, vmax=CLASSES_NUM - 1)
ax[1].set_title("target")
ax[2].imshow(preds.cpu().numpy(), vmin=0, vmax=CLASSES_NUM - 1)
ax[2].set_title("prediction")
plt.show()

## Разметка данных с помощью CVAT

Сайт: https://www.cvat.ai/

## Обзоры бекбонов

- Обзор до ~2020: https://arxiv.org/pdf/2206.08016.pdf
- Чуть поновее: https://arxiv.org/pdf/2310.19909.pdf
- Трансформеры и VLM как-нибудь потом