# Aleket Faster R-CNN training notebook

In [1]:
# IMPORTS

# Standard Library
import os

# Third-Party Libraries
import numpy as np
from IPython.display import clear_output

# Torch
import torch

# Utils
from data.aleket_dataset import AleketDataset, download_dataset
from data.checkpoints import get_default_model, RunParams
from train.training_and_evaluation import train

In [2]:
# Device Selection: use the GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Random seed.
# The train/validation split below is read from pre-computed index files, so
# nothing in this notebook consumed SEED. It is applied here to the NumPy and
# torch global RNGs so that repeated runs start from the same random state.
# NOTE(review): presumably weight init / augmentations draw from these RNGs —
# confirm against the `train` and dataset implementations.
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dataset split
# NOTE(review): these two fractions are not referenced by any cell in this
# notebook (the split comes from the autosplit_*.txt files); kept for
# backward compatibility.
DATASET_FRACTION = 1
VALIDATION_FRACTION = 0.2

# Full-size images and their patched counterparts.
# NOTE(review): the meaning of the empty second argument of `download_dataset`
# is not visible here — verify against data/aleket_dataset.py.
full_dataset = AleketDataset(download_dataset("../datasets/orobanche_cummana/images", ""))
patched_dataset = AleketDataset(download_dataset("../datasets/orobanche_cummana/patched/images", ""))

# Pre-computed train/validation indices for the patched dataset.
patched_train_indices = patched_dataset.load_split_indices("../datasets/orobanche_cummana/patched/autosplit_train.txt")
patched_val_indices = patched_dataset.load_split_indices("../datasets/orobanche_cummana/patched/autosplit_val.txt")

# Model
model = get_default_model(device, trainable_backbone_layers=5)

# Public equivalent of the private `nn.Module._get_name()` helper.
print(f"Using model: {type(model).__name__}")

Using device: cuda
Using model: FasterRCNN


In [3]:
RUN_NAME = "final"

# Augmentation settings, keyed by transform name; each value holds that
# transform's keyword options.
augmentation_config = {
    "horizontal_flip": {"p": 0.5},
    "vertical_flip": {"p": 0.5},
    "scale_jitter": {
        "target_size": (1024, 1024),
        "scale_range": (0.7, 1.2),
    },
    "perspective": {"distortion_scale": 0.25, "p": 0.5},
    "rotation": {"degrees": 50, "expand": True},
    "color_jitter": {"brightness": 0.1, "contrast": 0.1},
}

# Optimizer hyper-parameters.
optimizer_config = {"lr": 0.015, "weight_decay": 0.00009}

# Learning-rate scheduler hyper-parameters.
lr_scheduler_config = {"factor": 0.1, "patience": 10, "min_lr": 0.0001}

# Bundle everything describing this run into a single RunParams object.
params = RunParams(
    run_name=RUN_NAME,
    batch_size=8,
    dataloader_workers=8,
    total_epochs=150,
    augmentation=augmentation_config,
    optimizer=optimizer_config,
    lr_scheduler=lr_scheduler_config,
    val_indices=patched_val_indices,
    train_indices=patched_train_indices,
)

print(f"Train parameters for '{RUN_NAME}'")

Train parameters for 'final'


In [6]:
# START TRAINING
# The run parameters are written to disk before training begins, then a fresh
# training run is launched on the patched dataset with checkpointing enabled.
params.save("train_params.json")
train(model, patched_dataset, params, device, checkpoints=True)


Epoch 1/150; Learning rate: 0.015


Training batches:   0%|          | 0/4512 [00:08<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# CONTINUE TRAINING FROM CHECKPOINT
params = RunParams()
params.load(os.path.join("results", RUN_NAME, "params.json"))  # override parameters

# Unexpected bbox-related errors may occasionally occur inside the PyTorch
# code, so a failed run is retried with resume=True. The number of attempts is
# capped: a persistent failure (missing checkpoint, out-of-memory, ...) would
# otherwise keep this cell looping forever while only printing the message.
MAX_RESUME_ATTEMPTS = 20  # tune to how often sporadic failures are expected

FINISHED = False
for attempt in range(1, MAX_RESUME_ATTEMPTS + 1):
    try:
        train(
            model,
            patched_dataset,
            params,
            device,
            checkpoints=True,
            resume=True,
            verbose=True,
        )
        FINISHED = True
        break
    except Exception as e:
        # KeyboardInterrupt is not an Exception subclass, so a manual
        # interrupt still stops the cell immediately.
        print(f"Training attempt {attempt}/{MAX_RESUME_ATTEMPTS} failed: {e!r}")
        if attempt == MAX_RESUME_ATTEMPTS:
            # Out of retries: surface the last error with its full traceback.
            raise


In [None]:
from data.checkpoints import load_checkpoint

# Export the weights of the "best" and then the "last" checkpoint of this run
# as standalone state-dict files in the working directory. After the loop,
# `model` is whatever load_checkpoint returned for the "last" checkpoint.
checkpoint_dir = os.path.join("results", RUN_NAME, "checkpoints")
for tag in ("best", "last"):
    # load_checkpoint returns a sequence whose first item is the model.
    model = load_checkpoint(model, os.path.join(checkpoint_dir, f"{tag}.pth"))[0]
    torch.save(model.state_dict(), f"{tag}_model.pth")