# Обучение моделей в PyTorch

In [None]:

# %pip install accelerate

In [1]:
import os
import random
from os.path import join as pjoin
from shutil import rmtree

import numpy as np
import torch
from accelerate import Accelerator
from dataset import CustomVOCSegmentation
from matplotlib import pyplot as plt
from PIL import Image
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from train import (
    CheckpointSaver,
    IoUMetric,
    MulticlassCrossEntropyLoss,
    MulticlassDiceLoss,
    load_checkpoint,
    train,
)
from unet import UNet, count_model_params

ModuleNotFoundError: No module named 'dataset'

In [None]:
def seed_everything(seed: int = 314159, torch_deterministic: bool = False) -> None:
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Value applied to ``PYTHONHASHSEED``, ``random``, NumPy and
            PyTorch (CPU and all CUDA devices).
        torch_deterministic: Forwarded to
            ``torch.use_deterministic_algorithms``. When true, cuDNN
            autotuning is also disabled and ``CUBLAS_WORKSPACE_CONFIG`` gets a
            default value, because deterministic CUDA matmuls require it.
    """
    # Only takes effect in interpreters started after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every GPU, not just the current one; silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)

    if torch_deterministic:
        # Keep a value the user already exported; otherwise use a documented one.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        # Autotuning may pick a different convolution algorithm on each run.
        torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(torch_deterministic)


seed_everything(42, torch_deterministic=False)

## Dataset

Набор данных Pascal VOC. Рассмотрим его версию для задачи сегментации. 

Сайт: http://host.robots.ox.ac.uk/pascal/VOC/

Лидерборд за 2012 год: http://host.robots.ox.ac.uk:8080/leaderboard/displaylb_main.php?challengeid=11&compid=5

In [None]:
# Build the train/val splits of Pascal VOC 2012 with the stock loader.
# NOTE(review): neither `datasets` nor `transforms` is defined in any visible
# cell - presumably `torchvision.datasets` and a joint image/mask transform;
# confirm both exist before this cell on a fresh kernel.
voc_common_kwargs = dict(
    root="data",
    year="2012",
    download=True,
    transforms=transforms,
)

train_dataset = datasets.VOCSegmentation(image_set="train", **voc_common_kwargs)
val_dataset = datasets.VOCSegmentation(image_set="val", **voc_common_kwargs)

# Number of samples in each split.
len(train_dataset), len(val_dataset)

In [None]:
def _make_voc_split(image_set):
    """Create the project-local VOC 2012 dataset for one split, without downloading."""
    return CustomVOCSegmentation(
        root="data",
        year="2012",
        image_set=image_set,
        download=False,
        # The keyword is singular here: `transform`, not `transforms`.
        # NOTE(review): `transforms` is not defined in any visible cell - confirm.
        transform=transforms,
    )


train_dataset = _make_voc_split("train")
val_dataset = _make_voc_split("val")

In [None]:
# Take the first training sample and preview its image.
image, target = train_dataset[0]
# Cast to uint8, then move axis 0 to the end before handing the array to PIL
# (assumes `image` is channels-first with values already in 0..255 - TODO confirm).
image_pixels = image.numpy().astype(np.uint8)
Image.fromarray(np.transpose(image_pixels, (1, 2, 0)))

In [None]:
# Preview channel 15 of the target as a grayscale mask, scaled by 255
# (assumes the channel holds 0/1 values; index 15 is presumably the Pascal VOC
# "person" class - confirm against CustomVOCSegmentation).
class_mask = target[15, :, :].numpy().astype(np.uint8)
Image.fromarray(255 * class_mask)

## UNet model

![UNet](unet.jpg)

Визуализация разных типов сверток: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

См. `unet.py`

In [None]:
# Instantiate the network with 3 input and 21 output channels and print its layers
# (presumably RGB input and 20 VOC classes + background - confirm in unet.py).
n_inputs, n_outputs = 3, 21
model = UNet(in_channels=n_inputs, out_channels=n_outputs)
print(model)

In [None]:
# Show the helper's result for the model built above (see unet.py).
param_count = count_model_params(model)
param_count

## Accelerator

"Accelerate — это библиотека, которая позволяет запускать один и тот же код PyTorch в любой распределенной конфигурации, добавляя всего четыре строки кода! Короче говоря, обучение и вывод в больших масштабах стали простыми, эффективными и адаптируемыми". (c)

Сайт: https://huggingface.co/docs/accelerate/index

In [None]:
# fp16 mixed precision needs a GPU. Fall back to full precision on CPU-only
# machines so the notebook still runs there, mirroring the CUDA check that the
# DEVICE constant uses later in the notebook.
accelerator = Accelerator(
    cpu=False,
    mixed_precision="fp16" if torch.cuda.is_available() else "no",
)

## Checkpointer

Класс для сохранения наилучших версий модели в процессе обучения.

См. класс `CheckpointSaver` в `train.py`

## Обучаем модель

См. `train.py`

In [None]:
# --- Optimisation ---
LEARNING_RATE = 1e-4
EPOCH_NUM = 20

# --- Data loading ---
BATCH_SIZE = 4
NUM_WORKERS = 2

# --- Output locations ---
CHECKPOINTS_DIR = "checkpoints"
TENSORBOARD_DIR = "tensorboard"
RM_CHECKPOINTS_DIR = False  # handed to CheckpointSaver as `rm_save_dir`

# Device for the stand-alone inference cells after training.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
# Training batches are reshuffled every epoch. Validation keeps the dataset
# order: shuffling adds nothing to evaluation and makes per-batch logs
# non-comparable between runs.
train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True
)
val_dataloader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False
)

# Fresh, untrained network for this training run.
model = UNet(in_channels=3, out_channels=21)

loss_fn = MulticlassCrossEntropyLoss(ignore_index=0)  # alternative: MulticlassDiceLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# Multiply the learning rate by 0.8 after every 5 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer, step_size=5, gamma=0.8
)
# The validation metric is the training loss itself.
# NOTE(review): the checkpointer below labels this metric "DICE" although the
# active loss is cross-entropy - confirm the intended metric or label.
metric_fn = loss_fn

os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
# Keeps up to `max_history` checkpoints; lower metric values count as better.
checkpointer = CheckpointSaver(
    accelerator=accelerator,
    model=model,
    metric_name="DICE",
    save_dir=CHECKPOINTS_DIR,
    rm_save_dir=RM_CHECKPOINTS_DIR,
    max_history=5,
    should_minimize=True,
)

In [None]:
# TensorBoard logging; needs the `tensorboard` package (pip install tensorboard).
# Presumably `tensorboard_logger = None` can be used instead to train without
# logging - confirm in train.py.
os.makedirs(TENSORBOARD_DIR, exist_ok=True)
tensorboard_logger = SummaryWriter(log_dir=TENSORBOARD_DIR)

In [None]:
# Hand the training objects to Accelerate and rebind the names to what it returns.
# NOTE: re-running this cell would prepare the already-prepared objects again;
# re-run the setup cell first.
prepared_objects = accelerator.prepare(
    model, optimizer, train_dataloader, val_dataloader, lr_scheduler
)
model, optimizer, train_dataloader, val_dataloader, lr_scheduler = prepared_objects

In [None]:
# Run the training loop from train.py with everything configured above.
train(
    # what is being optimised
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    # data
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    # objective and reported metric
    loss_function=loss_fn,
    metric_function=metric_fn,
    # run length and infrastructure
    epoch_num=EPOCH_NUM,
    accelerator=accelerator,
    checkpointer=checkpointer,
    tb_logger=tensorboard_logger,
    save_on_val=True,
)

## Загрузим и протестируем обученную модель

In [None]:
# Rebuild the architecture and load the checkpoint stored as
# "model_checkpoint_best.pt" in the checkpoints directory.
best_checkpoint_path = pjoin(CHECKPOINTS_DIR, "model_checkpoint_best.pt")
model = load_checkpoint(
    model=UNet(in_channels=3, out_channels=21),
    load_path=best_checkpoint_path,
)
model = model.to(DEVICE)
# Switch to evaluation mode; as the last expression this also displays the module.
model.eval()

In [None]:
sample_idx = 0
image, target = train_dataset[sample_idx]
# Collapse axis 0 of the target into one label index per pixel
# (assumes a per-class channel layout - TODO confirm).
target = torch.argmax(target, dim=0)

# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    logits = model(image.unsqueeze(0).to(DEVICE))
# Softmax is monotonic, so taking argmax of the raw outputs picks the same class.
preds = torch.argmax(logits, dim=1).squeeze(0)

# Wide figure for the 1x3 grid: input, ground truth, prediction.
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].imshow(image.numpy().transpose(1, 2, 0).astype(np.uint8))
ax[0].set_title("Image")
ax[1].imshow(target.numpy())
ax[1].set_title("Target")
ax[2].imshow(preds.cpu().numpy())
ax[2].set_title("Prediction");