# Aleket Faster R-CNN training notebook

In [None]:
# IMPORTS

# Standard Library
import os

# Third-Party Libraries
import numpy as np
from IPython.display import clear_output

# Torch
import torch

# Utils
from finetuning.aleket_dataset import AleketDataset, download_dataset
from finetuning.checkpoints import get_default_model, RunParams
from finetuning.training_and_evaluation import train

In [None]:
# Device selection: use CUDA when torch reports it as available, else the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

# Seeded NumPy generator that is handed to the dataset split below.
SEED = 1
np_generator = np.random.default_rng(seed=SEED)

# Split configuration forwarded to `split_dataset`.
# NOTE(review): presumably the share of the dataset to use and the share held
# out for validation -- confirm against AleketDataset.split_dataset.
DATASET_FRACTION = 1
VALIDATION_FRACTION = 0.2

# Two dataset variants are fetched and wrapped; only the patched one is split.
patched_dataset = AleketDataset(download_dataset("dataset_patched", ""))
full_dataset = AleketDataset(download_dataset("dataset_full_images", ""))

split_args = (DATASET_FRACTION, VALIDATION_FRACTION, np_generator)
train_set, val_set = patched_dataset.split_dataset(*split_args)

# Model: built by the project helper directly for the selected device.
model = get_default_model(device, trainable_backbone_layers=5)
print(f"Using model: {model._get_name()}")

In [None]:
RUN_NAME = "final"

# Augmentation settings, one entry per transform name.
augmentation_config = dict(
    horizontal_flip=dict(p=0.5),
    vertical_flip=dict(p=0.5),
    scale_jitter=dict(target_size=(1024, 1024), scale_range=(0.8, 1.2)),
    perspective=dict(distortion_scale=0.2, p=0.5),
    rotation=dict(degrees=30, expand=True),
    color_jitter=dict(brightness=0.1, contrast=0.1),
)

# Optimizer hyper-parameters.
optimizer_config = dict(lr=0.01, weight_decay=0.00009)

# Learning-rate scheduler hyper-parameters.
lr_scheduler_config = dict(factor=0.1, patience=10, min_lr=0.0001)

# Bundle everything for the run named RUN_NAME, including the train/validation
# subsets produced by the dataset split.
params = RunParams(
    run_name=RUN_NAME,
    batch_size=8,
    dataloader_workers=8,
    total_epochs=150,
    augmentation=augmentation_config,
    optimizer=optimizer_config,
    lr_scheduler=lr_scheduler_config,
    validation_set=val_set,
    train_set=train_set,
)

print(f"Train parameters for '{RUN_NAME}'")

In [None]:
# START TRAINING
# Launch a run on the patched dataset with the parameters built above,
# with checkpointing switched on.
train(model, patched_dataset, params, device, checkpoints=True)

In [None]:
# CONTINUE TRAINING FROM CHECKPOINT
import traceback

# Upper bound on resume attempts. Unexpected bbox-related errors may occur
# inside the PyTorch code, so a failed call is retried; the bound keeps a
# persistent failure from retrying forever.
MAX_RESUME_ATTEMPTS = 20

params = RunParams()
params.load(os.path.join("results", RUN_NAME, "params.json"))  # override parameters

FINISHED = False
for attempt in range(1, MAX_RESUME_ATTEMPTS + 1):
    try:
        train(
            model,
            patched_dataset,
            params,
            device,
            checkpoints=True,
            resume=True,
            verbose=True,
        )
    except Exception as e:
        # Report which attempt failed and show the full traceback so the
        # failure can be diagnosed instead of only seeing its message.
        print(f"Resume attempt {attempt}/{MAX_RESUME_ATTEMPTS} failed: {e!r}")
        traceback.print_exc()
        if attempt == MAX_RESUME_ATTEMPTS:
            # Out of attempts: surface the last error rather than hiding it.
            raise
    else:
        # train() returned without raising, so the run is complete.
        FINISHED = True
        break


In [None]:
# Save the model's state dict to "model.pth" in the working directory.
model_state = model.state_dict()
torch.save(model_state, "model.pth")