# Implementing and training a Detection Transformer (DETR)

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import ops


# Import the custom COCO dataset loader
from dataloaders.coco_od_pytorch import TorchCOCOLoader, collate_fn
from models.detr import DETR

## Set the experiment configurations

In [None]:
# --- Data pipeline ---
BATCH_SIZE = 4     # Number of samples per batch for the dataloaders
IMAGE_SIZE = 480   # Image size used by the model and the pre-processing
MAX_OBJECTS = 100  # Used as `max_boxes` for the datasets and as the number of model queries

# --- Training schedule ---
EPOCHS = 150
FREEZE_BACKBONE = True
LOG_FREQUENCY = 5    # Training-time losses will be logged according to this frequency
SAVE_FREQUENCY = 20  # Model weights will be saved according to this frequency

# --- Training device: use CUDA when it is available, otherwise fall back to the CPU ---
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Create a PyTorch Dataloader

In [None]:
# NOTE: You can load the COCO dataset information, or that of any other available dataset,
#       if it is registered in the DATASET_CLASSES class map. This map is a lookup dictionary
#       keyed by dataset name, where each entry has the following attributes:
#           - "class_names": The list of class names
#           - "empty_class_id": The ID of the class to be treated as the "empty" class for boxes
#           - "links": Contains some sort of link to download the dataset
# NOTE: All the available datasets are listed in the project README file.
from datasets.info import DATASET_CLASSES

# Single source of truth for the dataset used in this notebook: the name selects both the
# DATASET_CLASSES entry and the dataset folder on disk, so switching datasets is a
# one-line change instead of six coordinated edits.
DATASET_NAME = "people_hq"
DATASET_ROOT = f"../../DL_Datasets/data/{DATASET_NAME}"  # Adjust the path accordingly

dataset_info = DATASET_CLASSES[DATASET_NAME]
CLASSES = dataset_info["class_names"]
EMPTY_CLASS_ID = dataset_info["empty_class_id"]


# Or explicitly set the class labels/empty class ID for your custom dataset if it's not added to the DATASET_CLASSES map...
# CLASSES = ["N/A", "something"]
# EMPTY_CLASS_ID = 0 # ID of the dataset classes to treat as "empty" class


# Load the COCO-format training split (augmentation enabled)
coco_ds_train = TorchCOCOLoader(
    f"{DATASET_ROOT}/train",
    f"{DATASET_ROOT}/train/_annotations.coco.json",
    max_boxes=MAX_OBJECTS,
    empty_class_id=EMPTY_CLASS_ID,
    image_size=IMAGE_SIZE,
    augment=True
)

# Load the COCO-format validation split (the `augment` argument is left at its default)
coco_ds_val = TorchCOCOLoader(
    f"{DATASET_ROOT}/valid",
    f"{DATASET_ROOT}/valid/_annotations.coco.json",
    max_boxes=MAX_OBJECTS,
    empty_class_id=EMPTY_CLASS_ID,
    image_size=IMAGE_SIZE,
)

# Shuffle only the training data; validation batches keep the dataset order
train_loader = DataLoader(
    coco_ds_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

val_loader = DataLoader(
    coco_ds_val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn
)

print(f"Training dataset size: {len(coco_ds_train)}")
print(f"Validation dataset size: {len(coco_ds_val)}")

## Plot some samples

In [None]:
import matplotlib.pyplot as plt
from utils.visualizers import DETRBoxVisualizer

# Create a visualizer. Use the dataset's own "empty" class ID instead of a hard-coded 0
# so it stays consistent with the dataloaders and the trainer.
visualizer = DETRBoxVisualizer(class_labels=CLASSES,
                               empty_class_id=EMPTY_CLASS_ID)

# Visualize batches (increase the range to plot more than one batch)
dataloader_iter = iter(train_loader)
for _ in range(1):
    input_, (classes, boxes, masks, _) = next(dataloader_iter)
    fig = plt.figure(figsize=(10, 10), constrained_layout=True)

    # The figure is a 2x2 grid, so plot at most 4 samples. A batch may hold fewer
    # (smaller BATCH_SIZE or a short final batch), so never index past its end.
    n_samples = min(len(input_), 4)

    for ix in range(n_samples):
        mask = masks[ix].bool()

        # Filter padded classes/boxes using the binary mask...
        t_cl = classes[ix][mask]

        # Denormalize the boxes by scaling them with the image size...
        t_bbox = boxes[ix][mask] * IMAGE_SIZE

        # Convert from cxcywh to x1y1x2y2 for visualization...
        t_bbox = ops.box_convert(
            t_bbox, in_fmt='cxcywh', out_fmt='xyxy')

        ax = fig.add_subplot(2, 2, ix + 1)
        visualizer._visualize_image(input_[ix], t_bbox, t_cl, ax=ax)

## Building the DETR model

**Note**: If you start from pre-trained weights, the `n_classes` argument must match
the number of classes the pre-trained model was trained on (e.g. 92 classes for a model trained on "COCO").

In [None]:
# The model is instantiated with the COCO dataset parameters (92 classes) so that the
# pre-trained weights can be loaded into it before fine-tuning on a new dataset.
detr_model = DETR(
    d_model=256,
    n_classes=92,
    n_tokens=225,
    n_layers=6,
    n_heads=8,
    n_queries=MAX_OBJECTS,  # One query per object slot produced by the dataloaders
    use_frozen_bn=True,
)

## Load pre-trained weights as a starting point to explore Transfer Learning (optional)

Using pre-trained weights can significantly speed up training, as the model doesn't start from scratch.

Training DETR from scratch can take a significant amount of time even with plenty of GPU horsepower, whereas with
fine-tuning you can get reasonably good results within 100-150 epochs, depending on your dataset size.

In [None]:
CHECKPOINT_PATH = "<YOUR_DETR_WEIGHTS.pt>"

# Read the checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a source you trust.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu"))

# Copy the weights stored under the "state" key into the model.
# Matching is non-strict because you might want to use FrozenBatchNorm2D while
# some pre-trained weights come from trainings that used BatchNorm2D.
# The printed result reports the missing/unexpected keys.
load_result = detr_model.load_state_dict(checkpoint['state'], strict=False)
print(load_result)

# Swap the class prediction head for a fresh one sized for our new dataset
head_in_features = detr_model.linear_class.in_features
detr_model.linear_class = nn.Linear(head_in_features, len(CLASSES))

## Start the training

In [None]:
from models.trainer import DETRTrainer

# Create a trainer for DETR. Every setting comes from the experiment configuration
# constants; in particular the logging frequency uses LOG_FREQUENCY rather than a
# hard-coded value, so changing the constant actually takes effect.
trainer = DETRTrainer(model=detr_model,
                      train_loader=train_loader,
                      val_loader=val_loader,
                      device=device,
                      epochs=EPOCHS,
                      batch_size=BATCH_SIZE,
                      log_freq=LOG_FREQUENCY,
                      save_freq=SAVE_FREQUENCY,
                      freeze_backbone=FREEZE_BACKBONE,
                      num_queries=MAX_OBJECTS,
                      empty_class_id=EMPTY_CLASS_ID)

# Start the training
trainer.train()

## Plot the training metrics and save the plots

In [None]:
# Plot the losses recorded by the trainer, using the current directory as the save location
trainer.visualize_losses(save_dir="./")

## Load fine-tuned model and test inference

In [None]:
from utils.visualizers import DETRBoxVisualizer

WEIGHS_PATH = "<YOUR_DETR_WEIGHTS>"
INFERENCE_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Rebuild the architecture with the number of classes of your custom dataset
# (the model was fine-tuned with a class head of that size)
detr_model = DETR(
    d_model=256, n_classes=len(CLASSES), n_tokens=225,
    n_layers=6, n_heads=8, n_queries=MAX_OBJECTS, use_frozen_bn=True
).to(INFERENCE_DEVICE)

# Load the checkpoint.
# The pre-trained checkpoint loaded earlier in this notebook keeps its weights under a
# "state" key, so accept both that layout and a bare state dict. Because loading is
# non-strict, passing the wrapper dict directly would silently load no weights at all.
checkpoint = torch.load(WEIGHS_PATH, map_location=torch.device(INFERENCE_DEVICE))
if isinstance(checkpoint, dict) and "state" in checkpoint:
    state_dict = checkpoint["state"]
else:
    state_dict = checkpoint

# Check the printed result: a long list of missing keys means the weights did not load
print(detr_model.load_state_dict(state_dict, strict=False))

# Switch to evaluation mode so train-only behaviour (e.g. dropout) is disabled for
# inference; this is harmless if the visualizer/evaluator already does it.
detr_model.eval()

# Run inference and check results
visualizer = DETRBoxVisualizer(class_labels=CLASSES,
                               empty_class_id=EMPTY_CLASS_ID)

# This will always run inference on a GPU if one is available...
for _ in range(2):
    visualizer.visualize_validation_inference(detr_model, coco_ds_val, collate_fn=collate_fn, batch_size=1)

## Or run inference on a video

In [None]:
# Run inference on a video file, using the current directory as the save location.
# The image size comes from IMAGE_SIZE so the frames are processed at the same size
# the model was built and trained for, instead of a hard-coded duplicate of that value.
visualizer.visualize_video_inference(
    model=detr_model,
    video_path="<YOUR_VIDEO_PATH>",
    save_dir="./",
    image_size=IMAGE_SIZE,
    batch_size=4,
    nms_threshold=0.5
)

## Evaluate the trained model using the COCO API

In [None]:
from models.evaluator import DETREvaluator

# Build an evaluator for the model on the validation dataset, then run the evaluation
# and keep the returned statistics.
evaluator = DETREvaluator(
    detr_model,
    coco_ds_val,
    device,
    EMPTY_CLASS_ID,
    collate_fn,
    batch_size=4,
)
stats = evaluator.evaluate()