# Aleket Faster R-CNN training notebook

In [None]:
%pip install pillow
%pip install numpy<2.0
%pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
%pip install matplotlib
%pip install gdown
%pip install tqdm

from IPython.display import clear_output
clear_output(wait=False)

print("ALL DEPENDENCIES INSTALLED")


In [None]:
# IMPORTS

# Standard Library
import os

# Third-Party Libraries
from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from IPython.display import clear_output

# Torch
import torch

# Utils
from aleket_dataset import AleketDataset, download_dataset, split_dataset
from utils import get_model, load_checkpoint
from training_and_evaluation import train
from dataset_statisics import visualize_samples, count_analyze
from run_params import RunParams, parse_params
from predictor import Predictor
from metrics import Evaluator, VALIDATION_METRICS

In [None]:
# Helper functions
def augment_example(ds):
    """Preview the samples with image ids 0-3 of ``ds`` in a single row.

    The images come from ``visualize_samples`` (presumably already rendered
    with their annotations -- confirm in dataset_statisics) and are shown in a
    1x4 grid on a new 40x10 inch figure.
    """
    n_rows, n_cols = 1, 4
    samples = visualize_samples(ds, image_ids_to_visualize=list(range(4)))
    figure = plt.figure(figsize=(40, 10))
    for slot in range(n_rows * n_cols):
        # Subplot positions are 1-based while sample indices are 0-based.
        figure.add_subplot(n_rows, n_cols, slot + 1)
        plt.imshow(samples[slot])
    plt.show()

def draw_heat_map(name: str, values: np.ndarray, ax: Axes, x_ticks: np.ndarray, y_ticks: np.ndarray):
    """Draw a metric grid as an annotated heat map on ``ax``.

    Cells equal to -1 are treated as "not evaluated": they are masked out of
    the image and receive no text annotation. Every other cell is labelled
    with its value formatted to three decimals.

    Args:
        name: Title placed above the heat map.
        values: 2-D grid read as ``values[row, col]``, with one row per entry
            of ``y_ticks`` and one column per entry of ``x_ticks``.
        ax: Axes object to draw on.
        x_ticks: Tick labels for the x axis (labelled 'Score Threshold').
        y_ticks: Tick labels for the y axis (labelled 'NMS Threshold').
    """
    masked_results = np.ma.masked_where(values == -1, values)
    ax.imshow(masked_results, cmap='viridis', vmin=0, interpolation='nearest')

    X = len(x_ticks)
    Y = len(y_ticks)

    ax.set_title(name)
    ax.set_xlabel('Score Threshold')
    ax.set_ylabel('NMS Threshold')
    ax.set_xticks(np.arange(X))
    ax.set_yticks(np.arange(Y))
    ax.set_xticklabels(x_ticks)
    ax.set_yticklabels(y_ticks)

    # Full boolean mask, also when nothing is masked (plain `.mask` may be a scalar).
    hidden = np.ma.getmaskarray(masked_results)
    max_val = values.max()

    for i in range(Y):
        for j in range(X):
            if hidden[i, j]:
                # Placeholder cell: leave it blank instead of printing "-1.000".
                continue
            value = values[i, j]
            # Black text on the upper half of the value range, white on the lower half.
            color = 'black' if value > max_val / 2 else 'white'
            ax.text(j, i, f'{value:.3f}', ha="center", va="center", color=color)

In [None]:
# Device Selection
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Random Seed for Dataset split
# Only this NumPy generator is seeded here; it is handed to split_dataset below.
SEED = 1
np_generator = np.random.default_rng(SEED)

# Dataset split
# Both fractions are forwarded to split_dataset; presumably the share of the
# dataset to use and the share held out for validation -- confirm in aleket_dataset.
DATASET_FRACTION = 1
VALIDATION_FRACTION = 0.2
# Two datasets are loaded: "dataset_patched" (split into train/val below) and
# "dataset_full_images" (used by later cells for statistics and evaluation).
dataset = AleketDataset(download_dataset("dataset_patched", ""))
full_dataset = AleketDataset(download_dataset("dataset_full_images", ""))
train_set, val_set = split_dataset(dataset, DATASET_FRACTION, VALIDATION_FRACTION, np_generator)

# Model
# NOTE(review): the RUN_NAME used by the training cells ends in "tb=5" while
# trainable_backbone_layers=3 is requested here -- confirm which is intended.
model = get_model(device, trainable_backbone_layers=3)

print(f"Using model: {model._get_name()}")

In [None]:
# Identifier of this run; later cells build their paths as results/<RUN_NAME>/...
RUN_NAME = "ac_run6_v2_tb=5"

# NOTE(review): TRAIN_COMPLETE is not read by any cell visible in this notebook -- confirm it is still needed.
TRAIN_COMPLETE = False
# Hyperparameters and data split for this run, bundled into a RunParams object.
params = RunParams(
    run_name=RUN_NAME,
    batch_size=4,
    dataloader_workers=4, 
    total_epochs=150,
    # Augmentation settings keyed by transform name; each value holds that transform's options.
    augmentation={  
        "horizontal_flip": {
            "p": 0.5
        },
        "vertical_flip": {
            "p": 0.5
        },
        "scale_jitter": {
            "target_size": (1024, 1024),
            "scale_range": (0.5, 1.3)
        },
    },
    optimizer={
        "lr": 0.005,
        "weight_decay": 0.00009
    },
    # presumably a reduce-on-plateau style schedule (factor / patience / min_lr) -- confirm in run_params
    lr_scheduler={
        "factor": 0.1,
        "patience": 15,
        "min_lr": 0.0001
    },
    # The split produced by split_dataset in the setup cell.
    validation_set=val_set,
    train_set=train_set
    )


# Only announces the run name; the parameter values themselves are not printed.
print(f"Train parameters for '{RUN_NAME}'")

In [None]:
# Resolve the run parameters into the train/validation loaders and the augmentation object.
parsed_params = parse_params(params, model, dataset)
train_dataloader, val_dataloader, augmentation = (
    parsed_params[key] for key in ("train_loader", "val_loader", "augmentation")
)

# One count_analyze pass per dataset view; each view writes into its own folder.
statistics_jobs = (
    (full_dataset, {"save_folder": "full_dataset_statistics"}),
    (dataset, {"save_folder": "patched_dataset_statistics"}),
    (dataset, {"indices": train_dataloader.dataset.indices,
               "save_folder": "patched_train_dataset_statistics"}),
    (dataset, {"indices": val_dataloader.dataset.indices,
               "save_folder": "patched_val_dataset_statistics"}),
)
for target_dataset, analyze_kwargs in statistics_jobs:
    count_analyze(target_dataset, **analyze_kwargs)

# Attach the augmentation to the patched dataset, then preview samples from the train subset.
dataset.augmentation = augmentation
augment_example(train_dataloader.dataset)

In [None]:
# Start training with the parameters defined above (no resume flag is passed).
# checkpoints=True presumably enables checkpoint saving -- confirm in training_and_evaluation.
train(model, dataset, params, device, checkpoints=True)
# Wipe the training output so only the completion message remains in the cell.
clear_output(wait=False)
print("TRAIN COMPLETE")

In [None]:

# Resume an earlier run: replace the in-memory parameters with the ones stored
# at results/<RUN_NAME>/params.json, then call train with resume=True.
params = RunParams()
params.load(os.path.join('results', RUN_NAME, "params.json")) # override parameters
train(model, dataset, params, device, checkpoints=True, resume=True, verbose=True)
# Wipe the training output so only the completion message remains in the cell.
clear_output(wait=False)
print("TRAIN COMPLETE")

In [None]:
# Run to evaluate; defaults to the run configured above.
RUN_NAME_TO_TEST = RUN_NAME


# Files of the selected run: results/<run>/params.json and results/<run>/checkpoints/best.pth
run_dir = os.path.join("results", RUN_NAME_TO_TEST)
params_path = os.path.join(run_dir, "params.json")
checkpoint_path = os.path.join(run_dir,"checkpoints", "best.pth")

# Rebinds the notebook-level `model` and `params` with a fresh model and the run's saved parameters.
model = get_model(device)
params = RunParams()
params.load(params_path)

# Map the keys of the saved validation set to indices into the full-image dataset.
# presumably the keys are image names (the inference cell builds "<name>.jpeg" paths from them) -- confirm
val_indices = full_dataset.to_indices(params.validation_set.keys())
# Only the first element returned by load_checkpoint is kept and used as the model.
model = load_checkpoint(model, checkpoint_path)[0]

In [None]:
# Predictor used for the NMS / score threshold sweep over the validation images.
predictor = Predictor(
    model,
    device,
    detections_per_patch=150,
    detections_per_image=300,
    images_per_batch=2,
    image_size_factor=1,
    patches_per_batch=8,
)

# Threshold grids: NMS 0.5 -> 0.2 in steps of 0.1 (descending) and score
# 0.10 -> 0.80 in steps of 0.05. The small epsilon keeps the upper bound inside
# np.arange's half-open range; rounding strips floating-point noise.
nms_thrs = np.round(np.flip(np.arange(0.2, 0.5 + 1e-3, 0.1)),2)
score_thrs = np.round(np.arange(0.1, 0.8 + 1e-4, 0.05),2)

# Store the grids next to the run so the saved result matrices can be interpreted later.
np.savetxt(os.path.join(run_dir,"nms_thrs.csv"), nms_thrs, delimiter=',', fmt='%.2f')
np.savetxt(os.path.join(run_dir,"score_thrs.csv"), score_thrs, delimiter=',', fmt='%.2f')

N = len(nms_thrs)
S = len(score_thrs)

# Named `evaluator` rather than `eval` so the built-in eval() is not shadowed.
evaluator = Evaluator(full_dataset, val_indices)

# One cell per (NMS, score) pair; -1.0 marks a cell that was never evaluated.
results_ap  = np.full((N,S), -1.0)
results_aad = np.full((N,S), -1.0)
results_acd = np.full((N,S), -1.0)

for i, n in enumerate(nms_thrs):
    for j, s in tqdm(enumerate(score_thrs), total=S):
        try:
            stats = predictor.eval_dataset(full_dataset, val_indices, n, s, evaluator)
            # Read all three metrics before writing any, so a failure cannot
            # leave a cell filled in one matrix and still -1.0 in another.
            ap_value = stats[VALIDATION_METRICS[0]]
            acd_value = stats[VALIDATION_METRICS[-2]]
            aad_value = stats[VALIDATION_METRICS[-1]]
        except Exception as e:
            # Best-effort sweep: give up on the rest of this NMS row (its
            # remaining cells stay at -1.0), but report the reason instead of
            # dropping it silently.
            tqdm.write(f"Evaluation failed at nms_thresh={n}, score_thresh={s}: {e!r}")
            break
        results_ap[i, j] = ap_value
        results_acd[i, j] = acd_value
        results_aad[i, j] = aad_value

In [None]:
# Write each metric grid to its own CSV file inside the run directory.
for csv_name, grid in (
    ("ap_analysis.csv", results_ap),
    ("aad_analysis.csv", results_aad),
    ("acd_analysis.csv", results_acd),
):
    np.savetxt(os.path.join(run_dir, csv_name), grid, delimiter=',', fmt='%.4f')

# One heat map per metric, stacked vertically: score thresholds on x, NMS thresholds on y.
fig, axes = plt.subplots(3, 1, figsize=(60, 10))
panels = (("AP", results_ap), ("AAD", results_aad), ("ACD", results_acd))
for panel_ax, (title, grid) in zip(axes, panels):
    draw_heat_map(title, grid, panel_ax, score_thrs, nms_thrs)

plt.tight_layout()
plt.show()

In [None]:
# Rebuild the predictor for inference with smaller batches than the sweep cell
# (images_per_batch=1, patches_per_batch=4).
predictor = Predictor(
    model,
    device,
    detections_per_patch=150,
    detections_per_image=300,
    image_size_factor=1,
    images_per_batch=1,
    patches_per_batch=4,
)

# One "<img_dir>/<name>.jpeg" path per key of the run's validation set.
image_list = [os.path.join(f"{full_dataset.img_dir}",f"{name}.jpeg") for name in params.validation_set.keys()]

# Inference outputs go to results/<run>/infer (created if missing).
infer_dir = os.path.join(run_dir, "infer")
os.makedirs(infer_dir, exist_ok=True)

# NOTE(review): nms_thresh=0.2 and score_thresh=0.8 are fixed here; presumably
# picked from the heat maps produced by the sweep -- confirm.
predictor.infer(
    images=image_list,
    output_dir=infer_dir,
    nms_thresh=0.2,
    score_thresh=0.8,
    num_of_annotated_images_to_save=10,
    save_bboxes=False,    
)
