# Training our bounding box and corner models

We first make any necessary imports:

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models, transforms
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torchvision.ops as ops
from transformers import SamModel, SamProcessor
from PIL import Image, ImageDraw
import numpy as np
import json
import os
import timm
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import typing
from itertools import chain, groupby
from datetime import datetime

!pip install --quiet optuna
import optuna

In [None]:
from google.colab import drive
import sys

# Make Google Drive visible inside this Colab runtime.
drive.mount('/content/drive')
# Project root on Drive; its 'modules' folder is put on the import path so that
# the project helpers (`solarutils`) can be imported.
BASE_PROJECT_PATH = '/content/drive/MyDrive/Erdos Institute - Solar Panels Project'
sys.path.append(os.path.join(BASE_PROJECT_PATH, 'modules'))
from solarutils import *

We define our five-phase multitask training routine:

In [None]:
def train_multitask_backbone(
    json_file: str, image_dir: str, device: torch.device,
    train_val_split_ratio: float = 0.8, phase0_12_split_ratio: float = 0.5,
    phase0_epochs: int = 5, phase0_lr: float = 1e-4,
    phase1_epochs: int = 5, phase1_lr: float = 1e-4, sigma: float = 0.5,
    phase2_epochs: int = 10, phase2_lr: float = 1e-5,
    phase3_epochs: int = 5, phase3_lr: float = 1e-4,
    phase4_epochs: int = 10, phase4_lr: float = 1e-5,
    loss_bbox_weight: float = 1.0, loss_corner_weight: float = 1.0,
    batch_sizes: typing.Optional[typing.List[int]] = None, gamma: float = 1.0,
    criterion_corner: typing.Optional[nn.Module] = None,
    corner_strategy: str = 'basic',
    save_path_prefix: typing.Optional[str] = None, verbose: bool = False, seed: int = 42
) -> typing.Dict[str, typing.List[float]]:
    """
    Trains bbox and corner models in five phases (0-4) with validation.

    The corner model is built on top of the bbox model's backbone, so the schedule alternates
    between training heads with the backbone frozen and fine-tuning the backbone:

    * Phase 0: corner model, backbone frozen, on the first split of the training images.
    * Phase 1: both models, backbone frozen, on the second split.
    * Phase 2: backbone only, on a weighted bbox + corner loss, on the second split (no validation).
    * Phase 3: both models, backbone frozen, on the full training set.
    * Phase 4: every parameter of both models on the full training set.

    Args:
        json_file: Path to the COCO-like JSON annotation file.
        image_dir: Path to the directory containing the images.
        device: The device to train the models on (e.g., 'cuda' or 'cpu').
        train_val_split_ratio: Ratio for splitting the dataset into training and validation sets.
        phase0_12_split_ratio: Fraction of the training images used in phase 0; the rest are used in phases 1 and 2.
        phase0_epochs: Number of epochs for phase 0.
        phase0_lr: Learning rate for phase 0.
        phase1_epochs: Number of epochs for phase 1.
        phase1_lr: Learning rate for phase 1.
        sigma: Learning rate adjustment factor for the bbox model in phase 1 (bbox lr = phase1_lr * sigma).
        phase2_epochs: Number of epochs for phase 2.
        phase2_lr: Learning rate for the last backbone group (the FPN) in phase 2; each earlier layer gets half the lr of the next.
        phase3_epochs: Number of epochs for phase 3.
        phase3_lr: Learning rate for phase 3.
        phase4_epochs: Number of epochs for phase 4.
        phase4_lr: Learning rate for phase 4.
        loss_bbox_weight: Weight for the bounding box loss in the phase 2 combined loss.
        loss_corner_weight: Weight for the corner loss in the phase 2 combined loss.
        batch_sizes: [bbox loader batch size, corner loader batch size]. Defaults to [2, 4].
        gamma: Forwarded unchanged to ``train_bbox_corner_together`` in phases 1, 3 and 4.
        criterion_corner: The loss function for the corner model. Defaults to a fresh ``nn.MSELoss()``.
        corner_strategy: The strategy for cropping the images for the corner model ('basic' or 'crop').
        save_path_prefix: Prefix for saving the trained model weights and loss plot.
        verbose: If True, print detailed training progress.
        seed: Random seed for both the train/val split and the phase 0 / phase 1-2 split.

    Returns:
        A dictionary with per-epoch lists 'bbox_train', 'corner_train', 'bbox_val', 'corner_val'
        (NaN where a phase produces no such value), 'combined_train_phase2', and
        'phase_boundaries' (the cumulative epoch count at the end of each phase).
    """
    # Fresh per-call defaults instead of objects shared between calls.
    if batch_sizes is None:
        batch_sizes = [2, 4]
    if criterion_corner is None:
        criterion_corner = nn.MSELoss()

    # --- Setup Models and Share Backbone ---
    bbox_model = get_bbox_detection_model(num_classes=2).to(device)
    corner_model = CornerPredictor(device, backbone=bbox_model.backbone, strategy = corner_strategy).to(device)
    if verbose: print("ResNet50 with FPN backbone shared between Faster R-CNN and Corner Predictor.")

    # --- Split image IDs into train / val, then train into phase 0 / phases 1-2 ---
    with open(json_file, 'r') as f:
        all_image_ids = [img['id'] for img in json.load(f)['images']]

    train_size = int(len(all_image_ids) * train_val_split_ratio)
    train_subset, val_subset = random_split(all_image_ids, [train_size, len(all_image_ids) - train_size], generator=torch.Generator().manual_seed(seed))
    # Subset.indices are positions into all_image_ids; resolve them to the actual image IDs.
    train_ids = [all_image_ids[i] for i in train_subset.indices]
    val_ids = [all_image_ids[i] for i in val_subset.indices]

    p0_size = int(len(train_ids) * phase0_12_split_ratio) #size of dataset for phase 0
    p12_size = len(train_ids) - p0_size #size of dataset for phases 1 and 2
    p0_subset, p12_subset = random_split(train_ids, [p0_size, p12_size], generator=torch.Generator().manual_seed(seed))
    # These indices are positions *within train_ids*, so they must index train_ids
    # (indexing all_image_ids here would mix validation images into training).
    p0_ids = [train_ids[i] for i in p0_subset.indices]
    p12_ids = [train_ids[i] for i in p12_subset.indices]

    if verbose: print("Validation Datasets:")
    val_bbox_dataset = BoundingboxDataset(json_file, image_dir, transform=bbox_transform, image_ids=val_ids, verbose = verbose)
    val_corner_dataset = SinglePanelDataset(json_file, image_dir, image_transform=corner_img_transform, mask_transform=corner_mask_transform, image_ids=val_ids, verbose = verbose)
    val_bbox_loader = DataLoader(val_bbox_dataset, batch_size=batch_sizes[0], shuffle=False, collate_fn=collate_fn_bbox, num_workers = 2)
    val_corner_loader = DataLoader(val_corner_dataset, batch_size=batch_sizes[1], shuffle=False, collate_fn=collate_fn_corner, num_workers = 2)

    # Initialize loss history
    losses: typing.Dict[str, typing.List[float]] = {
        'bbox_train': [], 'corner_train': [],
        'bbox_val': [], 'corner_val': [],
        'combined_train_phase2': [],
        'phase_boundaries': []
    }

    # For phases 0 and 1, we freeze the backbone
    for param in bbox_model.backbone.parameters(): param.requires_grad = False

    # --- Phase 0: Train corner model head on first split ---
    print("\n" + "="*50 + f"\n--- Phase 0: Training Corner Head on {len(p0_ids)} images ---\n" + "="*50)

    if verbose: print("Training Datasets:")
    p0_corner_dataset = SinglePanelDataset(json_file, image_dir, image_transform=corner_img_transform, mask_transform=corner_mask_transform, image_ids=p0_ids, verbose = verbose)
    train_corner_loader_p0 = DataLoader(p0_corner_dataset, batch_size=batch_sizes[1], shuffle=True, collate_fn=collate_fn_corner, num_workers = 2)

    optimizer_corner_p0 = torch.optim.AdamW([p for p in corner_model.parameters() if p.requires_grad], lr=phase0_lr)

    p0_train_corner, p0_val_corner = train_corner_model(corner_model, phase0_epochs, train_corner_loader_p0, optimizer_corner_p0, criterion_corner, val_corner_loader, device, verbose = verbose)

    losses['corner_train'].extend(p0_train_corner)
    losses['corner_val'].extend(p0_val_corner)
    # The bbox model is not trained in phase 0; pad so all series stay epoch-aligned.
    losses['bbox_train'].extend([np.nan] * phase0_epochs)
    losses['bbox_val'].extend([np.nan] * phase0_epochs)
    losses['phase_boundaries'].append(len(losses['corner_train']))

    # --- Phase 1: Train Heads of both models on second split ---
    print("\n" + "="*50 + f"\n--- Phase 1: Training Heads on {len(p12_ids)} images ---\n" + "="*50)

    if verbose: print("Training Datasets:")
    p1_bbox_dataset = BoundingboxDataset(json_file, image_dir, transform=bbox_transform, image_ids=p12_ids, verbose = verbose)
    p1_corner_dataset = SinglePanelDataset(json_file, image_dir, image_transform=corner_img_transform, mask_transform=corner_mask_transform, image_ids=p12_ids, verbose = verbose)

    train_bbox_loader_p1 = DataLoader(p1_bbox_dataset, batch_size=batch_sizes[0], shuffle=True, collate_fn=collate_fn_bbox, num_workers = 2)
    train_corner_loader_p1 = DataLoader(p1_corner_dataset, batch_size=batch_sizes[1], shuffle=True, collate_fn=collate_fn_corner, num_workers = 2)

    optimizer_bbox_p1 = torch.optim.AdamW([p for p in bbox_model.parameters() if p.requires_grad], lr=phase1_lr*sigma) #sigma is a parameter adjusting learning rate of bbox model vs corner model
    optimizer_corner_p1 = torch.optim.AdamW([p for p in corner_model.parameters() if p.requires_grad], lr=phase1_lr)

    p1_train_bbox, p1_train_corner, p1_val_bbox, p1_val_corner = train_bbox_corner_together(
        bbox_model, corner_model, phase1_epochs, train_bbox_loader_p1, train_corner_loader_p1,
        optimizer_bbox_p1, optimizer_corner_p1, criterion_corner, val_bbox_loader, val_corner_loader, device, gamma = gamma, verbose = verbose
    )
    losses['bbox_train'].extend(p1_train_bbox)
    losses['corner_train'].extend(p1_train_corner)
    losses['bbox_val'].extend(p1_val_bbox)
    losses['corner_val'].extend(p1_val_corner)
    losses['phase_boundaries'].append(len(losses['bbox_train']))
    del p1_train_bbox, p1_train_corner, p1_val_bbox, p1_val_corner, train_bbox_loader_p1, train_corner_loader_p1, optimizer_bbox_p1, optimizer_corner_p1, p1_bbox_dataset, p1_corner_dataset

    # --- Phase 2: Fine-tune Shared Backbone on second split ---
    print("\n" + "="*50 + f"\n--- Phase 2: Fine-tuning Shared Backbone on {len(p12_ids)} images ---\n" + "="*50)
    for param in bbox_model.parameters(): param.requires_grad = False
    for param in corner_model.parameters(): param.requires_grad = False
    for param in bbox_model.backbone.parameters(): param.requires_grad = True

    # Layer-wise lr decay: the last group (FPN) gets phase2_lr, each earlier layer half of the next.
    # NOTE(review): backbone.body parameters outside the layer* children (e.g. a stem conv, if present)
    # receive gradients here but are in no param group, so they are not updated.
    backbone_layers = [module for name, module in bbox_model.backbone.body.named_children() if name.startswith('layer')]+[bbox_model.backbone.fpn]
    param_groups = [{'params': layer.parameters(), 'lr': phase2_lr / (2**(len(backbone_layers)-1-i))} for i, layer in enumerate(backbone_layers)]
    optimizer_backbone_p2 = torch.optim.AdamW(param_groups)

    if verbose: print("Training Datasets:")
    p2_bbox_dataset = BoundingboxDataset(json_file, image_dir, transform=bbox_transform, image_ids=p12_ids, return_ids = True, only_pos = True, verbose = verbose)
    p2_corner_dataset = SinglePanelDataset(json_file, image_dir, image_transform=corner_img_transform, mask_transform=corner_mask_transform, image_ids=p12_ids, verbose = verbose)

    train_bbox_loader_p2 = DataLoader(p2_bbox_dataset, batch_size=batch_sizes[0], shuffle=True, collate_fn=collate_fn_bbox, num_workers = 2)

    for epoch in range(phase2_epochs):
        train_loss_bbox = 0
        train_loss_corner = 0
        train_loss_combined = 0
        n_batches = 0  # batches actually trained on (some may be skipped below)
        bbox_model.train(); corner_model.train()
        for images, targets, image_ids in train_bbox_loader_p2:
            if not images: continue
            # Build the corner batch from every annotation of the images in this bbox batch,
            # so both losses are computed on the same images.
            corner_data = collate_fn_corner([p2_corner_dataset.process_annotation(ann) for image_id in image_ids for ann in p2_corner_dataset.img2ann[image_id]])
            if corner_data[0] is None: continue #skip if collate function returns None, None, None, None
            corner_imgs, corner_bboxes, corner_masks, corner_keypoints = corner_data
            n_single_panels = len(corner_imgs)
            optimizer_backbone_p2.zero_grad(set_to_none = True)
            loss_dict_bbox = bbox_model([img.to(device) for img in images], [{k: v.to(device) for k, v in t.items()} for t in targets])
            losses_bbox = sum(loss for loss in loss_dict_bbox.values())
            train_loss_bbox += losses_bbox.item()
            outputs_corner = corner_model(corner_imgs.to(device), corner_bboxes.to(device), corner_masks.to(device))
            loss_corner = criterion_corner(outputs_corner, corner_keypoints.to(device))
            train_loss_corner += loss_corner.item()/n_single_panels
            combined_loss = (loss_bbox_weight * losses_bbox) + (loss_corner_weight/n_single_panels * loss_corner)
            train_loss_combined += combined_loss.item()
            combined_loss.backward()
            optimizer_backbone_p2.step()
            n_batches += 1
            del images, targets, image_ids, corner_data, corner_imgs, corner_bboxes, corner_masks, corner_keypoints, loss_dict_bbox, losses_bbox, outputs_corner, loss_corner, combined_loss

        # Average over trained batches only; NaN if every batch was skipped (or the loader is empty).
        denom = n_batches if n_batches else float('nan')
        avg_train_loss_bbox = train_loss_bbox / denom
        avg_train_loss_corner = train_loss_corner / denom
        avg_train_loss_combined = train_loss_combined / denom
        if verbose: print(f"Backbone Finetune - Epoch {epoch+1}, BBox Train Loss: {avg_train_loss_bbox:.4f}, Corner Train Loss: {avg_train_loss_corner:.4f}, Combined Train Loss: {avg_train_loss_combined:.4f}.")
        losses['bbox_train'].append(avg_train_loss_bbox)
        losses['corner_train'].append(avg_train_loss_corner)
        losses['combined_train_phase2'].append(avg_train_loss_combined)
        # No validation in phase 2, so append NaN
        losses['bbox_val'].append(float('nan'))
        losses['corner_val'].append(float('nan'))
    losses['phase_boundaries'].append(len(losses['bbox_train']))
    del p2_bbox_dataset, p2_corner_dataset, train_bbox_loader_p2, optimizer_backbone_p2

    # --- Phase 3: Re-train Heads on Full Training Data ---
    print("\n" + "="*50 + f"\n--- Phase 3: Re-training Heads on {len(train_ids)} images ---\n" + "="*50)
    for param in bbox_model.parameters(): param.requires_grad = True
    for param in corner_model.parameters(): param.requires_grad = True
    for param in bbox_model.backbone.parameters(): param.requires_grad = False

    optimizer_bbox_p3 = torch.optim.AdamW([p for p in bbox_model.parameters() if p.requires_grad], lr=phase3_lr)
    optimizer_corner_p3 = torch.optim.AdamW([p for p in corner_model.parameters() if p.requires_grad], lr=phase3_lr)

    if verbose: print("Training Datasets:")
    full_train_bbox_dataset = BoundingboxDataset(json_file, image_dir, transform=bbox_transform, image_ids=train_ids, verbose = verbose)
    full_train_corner_dataset = SinglePanelDataset(json_file, image_dir, image_transform=corner_img_transform, mask_transform=corner_mask_transform, image_ids=train_ids, verbose = verbose)
    full_train_bbox_loader = DataLoader(full_train_bbox_dataset, batch_size=batch_sizes[0], shuffle=True, collate_fn=collate_fn_bbox, num_workers = 2)
    full_train_corner_loader = DataLoader(full_train_corner_dataset, batch_size=batch_sizes[1], shuffle=True, collate_fn=collate_fn_corner, num_workers = 2)

    p3_train_bbox, p3_train_corner, p3_val_bbox, p3_val_corner = train_bbox_corner_together(
        bbox_model, corner_model, phase3_epochs, full_train_bbox_loader, full_train_corner_loader,
        optimizer_bbox_p3, optimizer_corner_p3, criterion_corner, val_bbox_loader, val_corner_loader, device, gamma = gamma, verbose = verbose
    )
    losses['bbox_train'].extend(p3_train_bbox)
    losses['corner_train'].extend(p3_train_corner)
    losses['bbox_val'].extend(p3_val_bbox)
    losses['corner_val'].extend(p3_val_corner)
    losses['phase_boundaries'].append(len(losses['bbox_train']))
    del p3_train_bbox, p3_train_corner, p3_val_bbox, p3_val_corner, optimizer_bbox_p3, optimizer_corner_p3

    # --- Phase 4: Fine-tune Everything on Full Training Data ---
    print("\n" + "="*50 + f"\n--- Phase 4: Fine-tuning All Layers on {len(train_ids)} images ---\n" + "="*50)
    for param in bbox_model.parameters(): param.requires_grad = True
    for param in corner_model.parameters(): param.requires_grad = True

    # NOTE(review): corner_model was built with backbone=bbox_model.backbone, so (presumably) both
    # parameter lists below contain the shared backbone weights and both optimizers may step them.
    # Confirm this matches how train_bbox_corner_together calls the two optimizers.
    optimizer_bbox_p4 = torch.optim.AdamW(list(bbox_model.parameters()), lr=phase4_lr)
    optimizer_corner_p4 = torch.optim.AdamW(list(corner_model.parameters()), lr=phase4_lr)

    if verbose: print("Training Datasets: Same as in Phase 3")

    p4_train_bbox, p4_train_corner, p4_val_bbox, p4_val_corner = train_bbox_corner_together(
        bbox_model, corner_model, phase4_epochs, full_train_bbox_loader, full_train_corner_loader,
        optimizer_bbox_p4, optimizer_corner_p4, criterion_corner, val_bbox_loader, val_corner_loader, device, gamma=gamma, verbose = verbose
    )
    losses['bbox_train'].extend(p4_train_bbox)
    losses['corner_train'].extend(p4_train_corner)
    losses['bbox_val'].extend(p4_val_bbox)
    losses['corner_val'].extend(p4_val_corner)
    losses['phase_boundaries'].append(len(losses['bbox_train']))
    del p4_train_bbox, p4_train_corner, p4_val_bbox, p4_val_corner, optimizer_bbox_p4, optimizer_corner_p4, full_train_bbox_dataset, full_train_corner_dataset, full_train_bbox_loader, full_train_corner_loader

    if save_path_prefix:
        torch.save(bbox_model.state_dict(), f"{save_path_prefix}_bbox.pth")
        torch.save(corner_model.state_dict(), f"{save_path_prefix}_corner.pth")
        print(f"Multi-task trained models saved with prefix: {save_path_prefix}")
    plot_losses(losses, save_path_prefix)
    del bbox_model, corner_model, val_bbox_loader, val_corner_loader
    return losses

We define the Optuna objective used for hyperparameter tuning:

In [None]:
def _relative_val_loss_decrease(values: typing.List[float]) -> float:
    """
    Return (first - last) / first over the non-NaN entries of ``values``.

    NaN entries mark epochs without validation (e.g. phase 2, or bbox in phase 0) and are ignored.
    Returns NaN when no non-NaN entry exists or the first one is zero, since the relative
    change is undefined in those cases.
    """
    valid = [v for v in values if not np.isnan(v)]
    if not valid or valid[0] == 0:
        return float('nan')
    return (valid[0] - valid[-1]) / valid[0]


def objective(trial: optuna.Trial, json_file: str, image_dir: str, device: torch.device, path: str) -> float:
    """
    Objective function for Optuna; create the study with direction="maximize".

    Samples a full set of training hyperparameters, runs ``train_multitask_backbone`` with them,
    and scores the run by how much the validation losses fell during training.

    Args:
        trial: An Optuna trial object.
        json_file: Path to the COCO-like JSON annotation file.
        image_dir: Path to the directory containing the images.
        device: The device to train the models on (e.g., 'cuda' or 'cpu').
        path: Directory in which each trial's model weights and loss plot are saved.

    Returns:
        float: Sum of the relative decreases (first recorded validation epoch vs. last) of the
        bbox and corner validation losses; larger is better. NaN when either decrease is
        undefined (no finite validation loss, or a first validation loss of zero), which
        Optuna records as a failed trial instead of aborting the study.
    """
    # Define hyperparameters to tune using trial suggestions
    phase0_lr = trial.suggest_float("phase0_lr", 1e-4, 1e-2, log=True)
    phase1_lr = trial.suggest_float("phase1_lr", 1e-4, 1e-2, log=True)
    phase2_lr = trial.suggest_float("phase2_lr", 1e-6, 1e-4, log=True)
    phase3_lr = trial.suggest_float("phase3_lr", 1e-5, 1e-4, log=True)
    phase4_lr = trial.suggest_float("phase4_lr", 1e-6, 1e-4, log=True)
    sigma = trial.suggest_float("sigma", 0.01, 1)
    phase0_epochs = trial.suggest_int("phase0_epochs", 3, 25)
    phase1_epochs = trial.suggest_int("phase1_epochs", 3, 25)
    phase2_epochs = trial.suggest_int("phase2_epochs", 3, 25)
    phase3_epochs = trial.suggest_int("phase3_epochs", 3, 20)
    phase4_epochs = trial.suggest_int("phase4_epochs", 3, 20)
    loss_bbox_weight = trial.suggest_float("loss_bbox_weight", 0.1, 0.9)
    loss_corner_weight = 1.0 - loss_bbox_weight # Ensure weights sum to 1
    phase0_12_split_ratio = trial.suggest_float("phase0_12_split_ratio", 0.4, 0.6)
    batch_size_bbox = trial.suggest_int("batch_size_bbox", 1, 8)
    batch_size_corner = trial.suggest_int("batch_size_corner", 1, 8)
    batch_sizes = [batch_size_bbox, batch_size_corner]

    # Fixed parameters
    criterion_corner = nn.MSELoss()

    # The timestamp only has minute resolution, so include the trial number to keep trials
    # that finish within the same minute from overwriting each other's weights and plots.
    timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M")
    model_save_path_prefix = os.path.join(path, f"{timestamp}_trial{trial.number}_model")

    # Call the training function
    losses = train_multitask_backbone(
        json_file=json_file,
        image_dir=image_dir,
        device=device,
        train_val_split_ratio=0.8,
        phase0_12_split_ratio=phase0_12_split_ratio,
        phase0_epochs=phase0_epochs, phase0_lr=phase0_lr,
        phase1_epochs=phase1_epochs, sigma = sigma,
        phase1_lr=phase1_lr,
        phase2_epochs=phase2_epochs,
        phase2_lr=phase2_lr,
        phase3_epochs=phase3_epochs,
        phase3_lr=phase3_lr,
        phase4_epochs=phase4_epochs,
        phase4_lr=phase4_lr,
        loss_bbox_weight=loss_bbox_weight,
        loss_corner_weight=loss_corner_weight,
        batch_sizes=batch_sizes,
        criterion_corner=criterion_corner,
        corner_strategy='basic',
        save_path_prefix=model_save_path_prefix, verbose = True
    )

    # Relative decrease of each validation loss; NaN epochs (no validation) are skipped.
    bbox_val_loss_decrease = _relative_val_loss_decrease(losses['bbox_val'])
    corner_val_loss_decrease = _relative_val_loss_decrease(losses['corner_val'])

    # Return the combined decrease, which we will aim to maximize
    return bbox_val_loss_decrease + corner_val_loss_decrease

Main code block:

In [None]:
# ==========================================================================================
# Global definitions
# ==========================================================================================
# --- Configuration ---
# Minute-resolution timestamp that tags every artifact written by this run.
TIME = datetime.now().strftime("%Y_%m_%d_%H_%M")

# Input data (an alternative annotation file is 'annotations.json').
JSON_FILE = os.path.join(BASE_PROJECT_PATH, 'data', 'only_high_quality_training.json')
IMAGE_DIR = os.path.join(BASE_PROJECT_PATH, 'data', 'images')
# SAM checkpoint name. Other options: "facebook/sam-vit-huge" (big model),
# or the SlimSAM variants "Zigeng/SlimSAM-uniform-50" / "Zigeng/SlimSAM-uniform-77".
SAM_CHECKPOINT = "Zigeng/SlimSAM-uniform-77"

# Output locations, all stamped with TIME.
OUTPUT_JSON_PATH = os.path.join(BASE_PROJECT_PATH, 'data', f'{TIME}_inf_predictions.json')
MODEL_SAVE_PATH_PREFIX = os.path.join(BASE_PROJECT_PATH, 'models', f"{TIME}_model")
OPTUNA_PATH_PREFIX = os.path.join(BASE_PROJECT_PATH, 'models', f"hyperparameter_tuning_{TIME}")

# Use the GPU whenever one is visible.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# Image pipelines: resize to 224x224 and convert to a tensor; RGB inputs are additionally
# normalized with the ImageNet channel mean/std, masks are not.
bbox_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
corner_img_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
corner_mask_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)


# ==========================================================================================
# Main Execution Block
# ==========================================================================================

def main(tune: bool = False, train: bool = False, n_trials: int = 10) -> None:
    """
    Entry point for the notebook/script.

    With the default arguments this does nothing, so running the cell (or the module as a
    script) never starts a long job by accident. Opt in with the flags below.

    Args:
        tune: If True, run an Optuna study (direction="maximize") over ``objective`` and
            print the best trial. Artifacts go under OPTUNA_PATH_PREFIX.
        train: If True, run one training with the hand-picked hyperparameters below and
            save the models under MODEL_SAVE_PATH_PREFIX.
        n_trials: Number of Optuna trials when ``tune`` is True.
    """
    if tune:
        print("Starting hyperparameter tuning...")
        os.makedirs(OPTUNA_PATH_PREFIX, exist_ok=True)
        study = optuna.create_study(direction="maximize")
        try:
            study.optimize(lambda trial: objective(trial, JSON_FILE, IMAGE_DIR, DEVICE, OPTUNA_PATH_PREFIX), n_trials=n_trials)
        except RuntimeError as e:
            # Print an allocator summary before re-raising to help diagnose CUDA OOMs.
            if "out of memory" in str(e):
                print(">>> CUDA out of memory. Triggering memory summary...")
                print(torch.cuda.memory_summary(device=None, abbreviated=False))
            raise

        print("\nHyperparameter tuning finished.")
        print("Best trial:")
        trial = study.best_trial

        print(f"  Value: {trial.value}")
        print("  Params: ")
        for key, value in trial.params.items():
            print(f"    {key}: {value}")

    if train:
        # --- Training ---
        train_multitask_backbone(
            json_file=JSON_FILE,
            image_dir=IMAGE_DIR,
            device=DEVICE,
            train_val_split_ratio=0.8,
            phase0_12_split_ratio=0.5,
            phase0_epochs=10, phase0_lr=1e-2,
            phase1_epochs=5, sigma = 0.5,
            phase1_lr=1e-3,
            phase2_epochs=5,
            phase2_lr=1e-4,
            phase3_epochs=12,
            phase3_lr=1e-3,
            phase4_epochs=5,
            phase4_lr=1e-3,
            loss_bbox_weight=0.01,   #useful for balancing the different scales (empirically found)
            loss_corner_weight=0.99,
            batch_sizes=[4,16],
            gamma = 0.5,
            criterion_corner=nn.MSELoss(), #nn.SmoothL1Loss(beta = 100) could perform better than MSE loss, since it exaggerates outliers less which is good in the initial training steps when the params are random
            corner_strategy='basic',
            save_path_prefix=MODEL_SAVE_PATH_PREFIX, verbose = True
        )


if __name__ == '__main__':
    main()