# Aleket Faster R-CNN training notebook

In [1]:
# IMPORTS

# Standard Library
import os

# Third-Party Libraries
from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
from IPython.display import clear_output

# Torch
import torch
from torchvision.models.detection import (
    FasterRCNN,
    fasterrcnn_resnet50_fpn,
    fasterrcnn_resnet50_fpn_v2,
    fasterrcnn_mobilenet_v3_large_fpn,
)
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import v2
from torch.utils.data import Subset

# Utils
from finetuning.aleket_dataset import AleketDataset, download_dataset
from finetuning.checkpoints import load_checkpoint, RunParams
from finetuning.metrics import Evaluator
from finetuning.training_and_evaluation import train
from utils.analyze import count_analyze
from infer import infer
from utils.visualize import visualize_bboxes, draw_heat_map
from utils.predictor import Predictor
from utils.consts import NUM_TO_CLASSES, VALIDATION_METRICS

In [2]:
def visualize_samples(dataset, images_to_visualize=4):
    """
    Visualizes samples from the dataset with bounding boxes and labels.

    Walks the dataset annotations in order, keeps the first
    ``images_to_visualize`` samples that contain at least one box, and shows
    them side by side in a single row. If fewer annotated samples exist, only
    those are shown; if none exist, a message is printed and nothing is drawn.

    Args:
        dataset (AleketDataset): The dataset to visualize samples from.
        images_to_visualize (int, optional): The maximum number of images to visualize. Defaults to 4.
    """
    visualized_images = []

    for idx, annot in enumerate(dataset.get_annots(None)):
        # Checked before loading so a limit of 0 (or less) loads nothing.
        if len(visualized_images) >= images_to_visualize:
            break
        if len(annot["boxes"]) == 0:
            continue

        img, target = dataset[idx]
        img = v2.functional.to_pil_image(img)

        bboxes = target["boxes"].cpu().tolist()
        labels = [NUM_TO_CLASSES[label.item()] for label in target["labels"]]

        visualized_images.append(visualize_bboxes(img, bboxes, labels))

    if not visualized_images:
        print("No annotated samples to visualize.")
        return

    # One column per collected image; 15 x 20 inches per panel reproduces the
    # previous 60 x 20 figure when four images are shown.
    columns = len(visualized_images)
    _, axes = plt.subplots(1, columns, figsize=(15 * columns, 20), squeeze=False)
    for ax, image in zip(axes[0], visualized_images):
        ax.imshow(image)
    plt.show()


def get_model(device, trainable_backbone_layers=3, num_classes=3):
    """
    Loads a pretrained Faster R-CNN ResNet-50 FPN v2 model and replaces its box
    classification head with one sized for the dataset's classes.

    Args:
        device (torch.device): The device to move the model to (e.g., 'cuda' or 'cpu').
        trainable_backbone_layers (int, optional): Number of trainable backbone layers. Defaults to 3.
        num_classes (int, optional): Number of output classes of the new head,
            including background. Defaults to 3.

    Returns:
        FasterRCNN: The Faster R-CNN model with the modified classification head.
    """
    model = fasterrcnn_resnet50_fpn_v2(
        weights="DEFAULT", trainable_backbone_layers=trainable_backbone_layers
    )
    # Reuse the input width of the pretrained head so the new predictor fits
    # the existing box head.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model.to(device)

In [3]:
# --- Split configuration ---
SEED = 1  # seeds the generator handed to split_dataset
DATASET_FRACTION = 1
VALIDATION_FRACTION = 0.2
np_generator = np.random.default_rng(SEED)

# --- Device ---
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# --- Datasets: patched images for training, full images for testing ---
patched_dataset = AleketDataset(download_dataset("dataset_patched", ""))
full_dataset = AleketDataset(download_dataset("dataset_full_images", ""))
train_set, val_set = patched_dataset.split_dataset(
    DATASET_FRACTION,
    VALIDATION_FRACTION,
    np_generator,
)

# --- Model ---
model = get_model(device, trainable_backbone_layers=5)
print(f"Using model: {model._get_name()}")

Using device: cuda
Dataset loaded from dataset_patched
Dataset loaded from dataset_full_images
Using model: FasterRCNN


In [5]:
# Name of the run; later cells use it to locate results/<RUN_NAME>.
RUN_NAME = "ONEFINALROUND_v2_tb=5"

TRAIN_COMPLETE = False

# Hyper-parameters of the run: data loading, augmentation, optimizer,
# LR scheduler, and the train/validation split produced earlier.
params = RunParams(
    run_name=RUN_NAME,
    batch_size=8,
    dataloader_workers=8,
    total_epochs=150,
    augmentation={
        "horizontal_flip": {"p": 0.5},
        "vertical_flip": {"p": 0.5},
        "scale_jitter": {"target_size": (1024, 1024), "scale_range": (0.8, 1.2)},
        "perspective": {"distortion_scale": 0.2, "p": 0.5},
        "rotation": {"degrees": 30, "expand": True},
        "color_jitter": {"brightness": 0.1, "contrast": 0.1, "saturation": 0.05},
    },
    optimizer={"lr": 0.01, "weight_decay": 0.00009},
    lr_scheduler={"factor": 0.1, "patience": 10, "min_lr": 0.0001},
    validation_set=val_set,
    train_set=train_set,
)

print(f"Train parameters for '{RUN_NAME}'")

Train parameters for 'ONEFINALROUND_v2_tb=5'


In [None]:
# Turn the run parameters into data loaders and the augmentation pipeline.
parsed_params = params.parse(model, patched_dataset)
train_dataloader = parsed_params["train_loader"]
val_dataloader = parsed_params["val_loader"]
augmentation = parsed_params["augmentation"]

# Annotation statistics for the full dataset and for each split of the
# patched dataset; every call writes into its own folder.
for source, index_args, folder in (
    (full_dataset, (), "full_dataset_statistics"),
    (
        patched_dataset,
        (train_dataloader.dataset.indices,),
        "patched_dataset__train_statistics",
    ),
    (
        patched_dataset,
        (val_dataloader.dataset.indices,),
        "patched_dataset_val_statistics",
    ),
):
    count_analyze(source.get_annots(*index_args), save_folder=folder)

# Preview a few samples with the augmentation attached to the dataset.
patched_dataset.augmentation = augmentation
visualize_samples(patched_dataset)

In [None]:
# START TRAINING
# Runs training on the patched dataset with checkpointing enabled.
train(
    model,
    patched_dataset,
    params,
    device,
    checkpoints=True,
)

In [6]:
# CONTINUE TRAINING FROM CHECKPOINT
params = RunParams()
params.load(os.path.join("results", RUN_NAME, "params.json"))  # override parameters

# Unexpected bounding-box errors may occur inside the PyTorch code, so a
# failed run is resumed from the last checkpoint. The number of attempts is
# bounded so that a persistent error cannot loop forever.
MAX_RESUME_ATTEMPTS = 50
FINISHED = False
last_error = None

for attempt in range(1, MAX_RESUME_ATTEMPTS + 1):
    try:
        train(
            model,
            patched_dataset,
            params,
            device,
            checkpoints=True,
            resume=True,
            verbose=True,
        )
    except Exception as e:
        # Keep a reference: the name bound by `except ... as` is cleared when
        # the handler exits.
        last_error = e
        print(
            f"[resume attempt {attempt}/{MAX_RESUME_ATTEMPTS}] "
            f"{type(e).__name__}: {e}"
        )
    else:
        FINISHED = True
        break

if not FINISHED:
    raise RuntimeError(
        f"Training did not complete after {MAX_RESUME_ATTEMPTS} resume attempts"
    ) from last_error

clear_output(wait=False)
print("TRAIN COMPLETE")

TRAIN COMPLETE


In [8]:
# TESTING
# Locate the artefacts of the run under test.
RUN_NAME_TO_TEST = RUN_NAME
run_dir = os.path.join("results", RUN_NAME_TO_TEST)
checkpoint_path = os.path.join(run_dir, "checkpoints", "best.pth")
params_path = os.path.join(run_dir, "params.json")

# Reload the parameters saved with the run.
params = RunParams()
params.load(params_path)

# Validation subset of the full-image dataset, selected by the keys of the
# stored validation set.
val_indices = full_dataset.to_indices(params.validation_set.keys())
val_subset = Subset(full_dataset, val_indices)

# Fresh model restored from the best checkpoint; load_checkpoint's first
# returned item is used as the model.
fresh_model = get_model(device)
model = load_checkpoint(fresh_model, checkpoint_path)[0]

# Export the weights and the list of validation image paths.
torch.save(model.state_dict(), "model.pth")
with open("val_image_list.txt", "w") as f:
    for name in params.validation_set.keys():
        image_path = os.path.join(full_dataset.img_dir, name)
        f.write(f"{image_path}.jpeg\n")

In [None]:
# Predictor used for the threshold sweep on the validation subset.
predictor = Predictor(
    model,
    device,
    detections_per_patch=150,
    detections_per_image=300,
    images_per_batch=4,
    image_size_factor=1,
    patches_per_batch=4,
)

# Threshold grids: IoU from 0.5 down to 0.1, score from 0.1 up to 0.8.
# The small epsilon keeps the upper bound inside np.arange's half-open range.
iou_thrs = np.round(np.flip(np.arange(0.1, 0.5 + 1e-3, 0.1)), 2)
score_thrs = np.round(np.arange(0.1, 0.8 + 1e-4, 0.05), 2)

np.savetxt(os.path.join(run_dir, "nms_thrs.csv"), iou_thrs, delimiter=",", fmt="%.2f")
np.savetxt(
    os.path.join(run_dir, "score_thrs.csv"), score_thrs, delimiter=",", fmt="%.2f"
)

N = len(iou_thrs)
S = len(score_thrs)

# Named `evaluator` rather than `eval` so the builtin `eval` is not shadowed
# for the rest of the kernel session.
evaluator = Evaluator(full_dataset.get_annots(val_indices))

# One cell per (iou, score) pair; -1.0 marks combinations that were not
# evaluated because the attempt raised.
results_ap = np.full((N, S), -1.0)
results_aad = np.full((N, S), -1.0)
results_acd = np.full((N, S), -1.0)


for i, n in enumerate(iou_thrs):
    for j, s in tqdm(enumerate(score_thrs), total=S):
        try:
            preds = predictor.get_predictions(val_subset, n, s)
            stats = evaluator.eval(preds)
            results_ap[i, j] = stats[VALIDATION_METRICS[0]]
            results_acd[i, j] = stats[VALIDATION_METRICS[-2]]
            results_aad[i, j] = stats[VALIDATION_METRICS[-1]]
        except Exception as e:
            # Best effort: report which combination failed and keep sweeping.
            print(
                f"iou_thresh={n}, score_thresh={s} failed: "
                f"{type(e).__name__}: {e}"
            )

In [None]:
# Persist each metric grid of the threshold sweep as a CSV in the run folder.
for filename, grid in (
    ("ap_analysis.csv", results_ap),
    ("aad_analysis.csv", results_aad),
    ("acd_analysis.csv", results_acd),
):
    np.savetxt(os.path.join(run_dir, filename), grid, delimiter=",", fmt="%.4f")

# One heat map per metric, stacked vertically.
fig, axes = plt.subplots(3, 1, figsize=(60, 10))

heat_maps = (("AP", results_ap), ("AAD", results_aad), ("ACD", results_acd))
for ax, (title, grid) in zip(axes, heat_maps):
    draw_heat_map(title, "score_thresh", "iou_thresh", grid, ax, score_thrs, iou_thrs)

plt.tight_layout()
plt.show()

In [11]:
# Output folder for the inference results of this run.
infer_dir = os.path.join(run_dir, "infer2")
os.makedirs(infer_dir, exist_ok=True)

# Paths of the validation images inside the full-image dataset folder.
image_list = [
    os.path.join(f"{full_dataset.img_dir}", f"{name}.jpeg")
    for name in params.validation_set.keys()
]

# Predictor for inference: one image per batch at half the image size factor.
predictor = Predictor(
    model,
    device,
    detections_per_patch=150,
    detections_per_image=300,
    images_per_batch=1,
    patches_per_batch=4,
    image_size_factor=0.5,
)

infer(
    predictor,
    images=image_list,
    classes=NUM_TO_CLASSES,
    output_dir=infer_dir,
    iou_thresh=0.2,
    score_thresh=0.85,
    num_of_annotations_to_save=-1,
    save_annotated_images=True,
)

KeyboardInterrupt: 