# Demo training code

This notebook is a simplified version of the training code, intended as a demo. The complete version, with more options, is the `train.py` CLI script. Checkpoints and TensorBoard summaries are saved in `notebooks/runs/`.

## Arguments

In [None]:
# Path to COCO formatted object detection dataset
# (passed to get_coco() for both the 'train' and the 'val' image sets)
data_path = '../data/all_but_ws_and_fb_fixed/'  

# Optional arguments: the defaults below can be left as they are
epochs = 35  # Number of iterations of the training loop
save_every_num_epochs = None  # Optional; when set, a checkpoint is saved whenever epoch % value == 0
evaluate_every_num_epochs = 2  # Evaluation runs whenever epoch % value == 0
lr = 0.01  # Learning rate given to the SGD optimizer
momentum = 0.9  # Momentum given to the SGD optimizer
weight_decay = 1e-4  # Weight decay given to the SGD optimizer
lr_steps = [10, 11]  # Milestones given to the MultiStepLR scheduler (stepped once per epoch)
lr_gamma = 0.1  # Gamma given to the MultiStepLR scheduler
batch_size = 3  # Batch size of the training sampler (the validation loader uses batch_size=1)
workers = 8  # num_workers for both the training and the validation dataloader
run_name = None  # Optional, str used to name Tensorboard summaries
num_draw_predictions = 5  # Forwarded to evaluate(); presumably how many predictions are drawn -- TODO confirm
draw_threshold = 0.5  # Forwarded to evaluate(); presumably the minimum score for drawing a prediction -- TODO confirm

## Code

In [None]:
# Notebooks are stored in 'notebooks/' which breaks my imports
import sys
sys.path.insert(0, '..')

import os
import datetime
import time
import shutil

import torch
import torch.utils.data
from models import detection
from torch.utils.tensorboard import SummaryWriter

from coco_utils import get_coco  # get_coco_kp

from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
from engine import train_one_epoch, evaluate

import utils

In [None]:
# Create summary writer for Tensorboard
# With a run name the summaries go to 'runs/<run_name>'; without one, log_dir
# is None and SummaryWriter chooses its own default directory.
log_dir_path = f"runs/{run_name}" if run_name else None
if log_dir_path is not None and os.path.isdir(log_dir_path):
    # Never mix a new run into an existing summary folder without asking first
    delete = input(f"Summary folder '{log_dir_path}' already exists. Overwrite it [yes, y / no, n]?")
    if delete.strip().lower() in ('yes', 'y'):
        shutil.rmtree(log_dir_path)
    else:
        print("Choose another run name or delete the folder then!")
        # exit() is a helper meant for the interactive interpreter (and shuts
        # down the kernel in a notebook); raising SystemExit stops execution
        # explicitly and works the same way in a plain script.
        raise SystemExit
writer = SummaryWriter(log_dir=log_dir_path)

# Datasets: the 'train' split also yields the class count and the label names,
# while only the dataset itself is kept from the 'val' split
dataset, num_classes, label_names = get_coco(data_path, image_set='train')
print(f"Categorizing into {num_classes} classes")
dataset_test, _, _ = get_coco(data_path, image_set='val')

# Samplers: random order for training, sequential order for validation
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)

# Training batches come from a GroupedBatchSampler fed with the group ids
# returned by create_aspect_ratio_groups
group_ids = create_aspect_ratio_groups(dataset)
train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, batch_size)

# Dataloaders: both share the worker count and the collate function
shared_loader_kwargs = {
    'num_workers': workers,
    'collate_fn': utils.collate_fn,
}

data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_sampler=train_batch_sampler,
    **shared_loader_kwargs,
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    sampler=test_sampler,
    **shared_loader_kwargs,
)

# Create model
# Use torch.cuda.is_available() rather than torch.has_cuda: the latter is
# deprecated and only tells whether this PyTorch build includes CUDA support,
# so it can select 'cuda' on a machine without a usable GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = detection.fasterrcnn_resnet50_fpn(num_classes=num_classes, pretrained=False)
model.to(device)

# Hand only the parameters that require gradients to the optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params, lr=lr, momentum=momentum, weight_decay=weight_decay)

# Scheduler driven by the lr_steps milestones and the lr_gamma factor
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=lr_steps, gamma=lr_gamma
)

# Train
print("Start training")
start_time = time.time()
for epoch in range(epochs):
    start_epoch = time.time()
    # NOTE(review): the literal 20 is passed positionally right after `epoch`;
    # presumably the print/log frequency -- confirm against engine.train_one_epoch
    train_one_epoch(
        model, optimizer, data_loader, device, epoch, 20, writer, label_names
    )
    print(f"Epoch time {time.time() - start_epoch}")
    # Log the learning rate used for this epoch before the scheduler steps
    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step=epoch)
    lr_scheduler.step()

    # Periodic checkpoint; skipped entirely when no interval is configured
    if save_every_num_epochs and epoch % save_every_num_epochs == 0:
        # The model is not wrapped in DistributedDataParallel in this notebook,
        # so its state dict is read directly (the name `model_without_ddp`
        # used previously is not defined anywhere here).
        utils.save_on_master({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'label_names': label_names},
            os.path.join(writer.log_dir, f'model_{epoch}.pth')
        )

    # Periodic evaluation; a falsy interval (None or 0) disables it instead of
    # failing on the modulo
    if evaluate_every_num_epochs and epoch % evaluate_every_num_epochs == 0:
        evaluate(
            model, data_loader_test, epoch, writer, draw_threshold,
            label_names, num_draw_predictions, device=device
        )

# Save final checkpoint after training is done
# The model is not wrapped in DistributedDataParallel in this notebook, so its
# state dict is read directly; the name `model_without_ddp` used previously is
# not defined anywhere here and raised a NameError once training had finished.
utils.save_on_master({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'lr_scheduler': lr_scheduler.state_dict(),
    'label_names': label_names},
    os.path.join(writer.log_dir, 'model_finished.pth')
)

# Close the summary writer, then report the total wall-clock time as H:MM:SS
writer.close()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f'Training time {total_time_str}')