# Implementing and training a Detection Transformer (DETR)

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import ops
import numpy as np


# Import the custom COCO dataset loader
from dataloaders.coco_od_pytorch import TorchCOCOLoader, collate_fn
from models.detr import DETR
from models.losses.detr_loss import compute_batch_loss

## Set the experiment configurations

In [None]:
# Experiment configuration shared by every cell below.
BATCH_SIZE = 2  # Samples per batch for both dataloaders
IMAGE_SIZE = 480  # Side length passed to the dataset loaders for pre-processing
# `torch.cuda.is_available()` already returns a bool, no ternary needed.
CUDA_ENABLED = torch.cuda.is_available()
MAX_OBJECTS = 100  # Max boxes per image for the loaders / number of object queries
FREEZE_BACKBONE = True  # When True the CNN backbone parameters get requires_grad=False
EPOCHS = 1
LOG_FREQUENCY = 1  # Training-time losses are logged every LOG_FREQUENCY epochs
SAVE_FREQUENCY = 20  # Model weights are saved every SAVE_FREQUENCY epochs
device = torch.device("cuda" if CUDA_ENABLED else "cpu")  # Training device

## Create a PyTorch Dataloader

In [None]:
# Label names for "vehicles_dataset"; index 0 doubles as the "empty" class.
CLASSES = ["N/A", "vehicle"]
EMPTY_CLASS_ID = 0  # ID of the dataset class to treat as the "empty" class

# NOTE: To train on the COCO dataset instead, load the COCO classes with
#       "from datasets_info.available_classes import DATASET_CLASSES". It is a
#       lookup dictionary in the format:
#       DATASET_CLASSES = {
#           "coco" : {
#               "class_names" : COCO_CLASSES,
#               "empty_class_id": 91
#           }
#           ....
#       }

# Options shared by the training and validation datasets.
_dataset_options = {
    "max_boxes": MAX_OBJECTS,
    "empty_class_id": EMPTY_CLASS_ID,
    "image_size": IMAGE_SIZE,
}

# Load the COCO-format datasets (adjust the paths accordingly)
coco_ds_train = TorchCOCOLoader(
    '../data/vehicles_dataset/train',
    '../data/vehicles_dataset/train/_vehicle_annotations.json',
    **_dataset_options,
)
coco_ds_val = TorchCOCOLoader(
    '../data/vehicles_dataset/valid',
    '../data/vehicles_dataset/valid/_vehicle_annotations.json',
    **_dataset_options,
)

# Only the training loader shuffles; validation keeps the dataset order.
train_loader = DataLoader(
    coco_ds_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    coco_ds_val,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)

print(f"Training dataset size: {len(coco_ds_train)}")
print(f"Validation dataset size: {len(coco_ds_val)}")

## Plot some samples

In [None]:
import matplotlib.pyplot as plt
from utils.visualizers import DETRBoxVisualizer

# Create a visualizer. Reuse EMPTY_CLASS_ID so the plot stays in sync with the
# dataset configuration instead of repeating the literal.
visualizer = DETRBoxVisualizer(class_labels=CLASSES,
                               empty_class_id=EMPTY_CLASS_ID)

# Subplot grid for the preview figure (at most GRID_ROWS * GRID_COLS samples).
GRID_ROWS, GRID_COLS = 2, 2

# Visualize batches
dataloader_iter = iter(train_loader)
for i in range(1):
    input_, (classes, boxes, masks) = next(dataloader_iter)
    fig = plt.figure(figsize=(10, 10), constrained_layout=True)

    # Never index past the batch: with BATCH_SIZE < 4 a fixed range(4) would
    # raise an IndexError on the first missing sample.
    n_samples = min(GRID_ROWS * GRID_COLS, len(input_))

    for ix in range(n_samples):
        mask = masks[ix].bool()

        # Filter padded classes/boxes using the binary mask...
        t_cl = classes[ix][mask]

        # ...and denormalize the remaining boxes to pixel coordinates.
        t_bbox = boxes[ix][mask] * IMAGE_SIZE

        # Convert from cxcywh to x1y1x2y2 for visualization.
        t_bbox = ops.box_convert(
            t_bbox, in_fmt='cxcywh', out_fmt='xyxy')

        im = input_[ix]

        ax = fig.add_subplot(GRID_ROWS, GRID_COLS, ix + 1)
        visualizer._visualize_image(im, t_bbox, t_cl, ax=ax)

## Building the DETR model

In [None]:
# The model is instantiated with the COCO dataset parameters (92 classes) so that
# the pre-trained COCO weights can be loaded before fine-tuning on a new dataset.
# NOTE(review): n_tokens=225 presumably corresponds to a 15x15 feature map for a
# 480x480 input; confirm against models.detr before changing IMAGE_SIZE.
detr_model = DETR(
    d_model=256, n_classes=92, n_tokens=225,
    n_layers=6, n_heads=8, n_queries=MAX_OBJECTS
)

# Inspect shapes of outputs from the last layer. This is only a shape probe, so
# skip autograd bookkeeping and build the dummy input from IMAGE_SIZE.
with torch.no_grad():
    x = torch.randn((1, 3, IMAGE_SIZE, IMAGE_SIZE))
    outs = detr_model(x)

# Read the heads by key rather than relying on dict ordering; the training loop
# accesses the same 'cl' / 'bbox' entries.
last_layer_out = outs['layer_5']
pred_cl = last_layer_out['cl']
pred_boxes = last_layer_out['bbox']

print()
print("*****************************************")
print(f"Predicted Classes shape: {pred_cl.shape}")
print(f"Predicted Boxes shape: {pred_boxes.shape}")

## Load pre-trained weights on COCO as a starting point 

In [None]:
CHECKPOINT_PATH = "<YOUR_COCO_WEIGHTS.pt>"

# Load the checkpoint on CPU; the model is moved to the training device later.
# NOTE(security): torch.load unpickles the file, so only load checkpoints
# that come from a trusted source.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu"))

# Load the weights into the model (prints the missing/unexpected keys report)
print(detr_model.load_state_dict(checkpoint['state']))

# Adapt the class prediction head to our new dataset. The replacement layer is
# freshly initialized, so it must be trained before its outputs are meaningful.
detr_model.linear_class = nn.Linear(detr_model.linear_class.in_features, len(CLASSES))

# Verify output shapes with the new configurations. This is only a shape probe,
# so skip autograd bookkeeping and build the dummy input from IMAGE_SIZE.
with torch.no_grad():
    x = torch.randn((1, 3, IMAGE_SIZE, IMAGE_SIZE))
    outs = detr_model(x)

# Read the heads by key rather than relying on dict ordering; the training loop
# accesses the same 'cl' / 'bbox' entries.
last_layer_out = outs['layer_5']
pred_cl = last_layer_out['cl']
pred_boxes = last_layer_out['bbox']

print()
print("*****************************************")
print(f"Predicted Classes shape: {pred_cl.shape}")
print(f"Predicted Boxes shape: {pred_boxes.shape}")

## Start the training

In [None]:
from torch.optim import AdamW

# Learning rates. The Transformer/prediction heads are fine-tuned at 1e-5; when
# the backbone is trained as well it gets a 10x smaller rate, mirroring the
# backbone/Transformer ratio used in the official DETR training recipe.
TRANSFORMER_LR = 1e-5
BACKBONE_LR = TRANSFORMER_LR / 10

# Enable GPU if available
if CUDA_ENABLED:
    print("Model moved to GPU for training...")
    detr_model.to(device)
else:
    print("No GPU found, training will start on CPU...")

# Freeze backbone
if FREEZE_BACKBONE:
    for p in detr_model.backbone.parameters():
        p.requires_grad = False
print(f"CNN backbone is frozen for training: {FREEZE_BACKBONE}")

# Split parameters by name so each group can get its own learning rate.
backbone_params = [
    p for n, p in detr_model.named_parameters() if 'backbone.' in n]
transformer_params = [
    p for n, p in detr_model.named_parameters() if 'backbone.' not in n]

param_groups = [{'params': transformer_params, 'lr': TRANSFORMER_LR}]
if not FREEZE_BACKBONE:
    # Without this group the backbone would silently never be updated even
    # though it was left unfrozen.
    param_groups.append({
        'params': [p for p in backbone_params if p.requires_grad],
        'lr': BACKBONE_LR,
    })

optimizer = AdamW(param_groups, weight_decay=1e-4)


# Log the number of parameters. Only parameters with requires_grad=True are
# counted as trainable, so a frozen backbone is excluded from that figure.
total_params = sum(p.nelement() for p in detr_model.parameters()) / 1e6
nparams = sum(
    p.nelement() for p in detr_model.parameters() if p.requires_grad) / 1e6
print(f'DETR total parameters: {total_params:.1f}M')
print(f'DETR trainable parameters: {nparams:.1f}M')

In [None]:
# Ensure saving directory exists for the model checkpoints...
# Done in plain Python (instead of a shell magic) so the cell also works outside
# IPython and on platforms without a POSIX shell.
from pathlib import Path

Path("ckpts").mkdir(parents=True, exist_ok=True)

In [None]:
# Fine-tuning loop: for every batch the loss is accumulated over all entries of
# the model output dict and over all samples, then back-propagated once.
# Set the model to training mode
torch.set_grad_enabled(True)
detr_model.train()

NUM_BATCHES = len(train_loader)

# Relative weights of the class / bbox / GIoU terms in the total loss.
CLASS_LOSS_WEIGHT = 1
BBOX_LOSS_WEIGHT = 5
GIOU_LOSS_WEIGHT = 2

# Per-batch loss logs (one entry appended per training step).
losses = torch.tensor([], device=device)
class_losses = torch.tensor([], device=device)
box_losses = torch.tensor([], device=device)
giou_losses = torch.tensor([], device=device)

# Per-epoch total loss history
hist = []

# More detailed logs for all losses(tuple: (class, bbox, giou))
hist_detail = []


for epoch in range(EPOCHS):
    for input_, (tgt_cl, tgt_bbox, tgt_mask) in train_loader:
        # Move data to device
        input_ = input_.to(device)
        tgt_cl = tgt_cl.to(device)
        tgt_bbox = tgt_bbox.to(device)
        tgt_mask = tgt_mask.bool().to(device)

        # Forward pass
        outs = detr_model(input_)

        # Average over the samples actually present in this batch (the last
        # batch can be smaller than BATCH_SIZE) and over the output entries.
        norm = input_.shape[0] * len(outs)

        # Accumulate losses
        loss = torch.tensor(0.0, device=device)
        loss_class_batch = torch.tensor(0.0, device=device)
        loss_bbox_batch = torch.tensor(0.0, device=device)
        loss_giou_batch = torch.tensor(0.0, device=device)

        for out in outs.values():
            # Box predictions are squashed to [0, 1] before the loss.
            layer_bbox = out['bbox'].sigmoid()
            layer_cl = out['cl']

            for o_bbox, t_bbox, o_cl, t_cl, t_mask in zip(
                    layer_bbox, tgt_bbox, layer_cl, tgt_cl, tgt_mask):

                loss_class, loss_bbox, loss_giou = compute_batch_loss(
                    o_bbox,
                    t_bbox,
                    o_cl,
                    t_cl,
                    t_mask,
                    n_queries=MAX_OBJECTS,
                    empty_class_id=EMPTY_CLASS_ID,
                    device=device
                )

                sample_loss = (
                    CLASS_LOSS_WEIGHT * loss_class
                    + BBOX_LOSS_WEIGHT * loss_bbox
                    + GIOU_LOSS_WEIGHT * loss_giou
                )

                loss = loss + sample_loss / norm

                # Track individual losses per batch
                loss_class_batch = loss_class_batch + loss_class / norm
                loss_bbox_batch = loss_bbox_batch + loss_bbox / norm
                loss_giou_batch = loss_giou_batch + loss_giou / norm

        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        # Clip gradient norms
        nn.utils.clip_grad_norm_(detr_model.parameters(), .1)
        optimizer.step()

        # Gather batch-level losses. Detach first: the logs are only used for
        # reporting and must not keep every step's autograd graph alive.
        losses = torch.cat((losses, loss.detach().unsqueeze(0)))
        class_losses = torch.cat(
            (class_losses, loss_class_batch.detach().unsqueeze(0)))
        box_losses = torch.cat(
            (box_losses, loss_bbox_batch.detach().unsqueeze(0)))
        giou_losses = torch.cat(
            (giou_losses, loss_giou_batch.detach().unsqueeze(0)))

    # Logging every LOG_FREQUENCY epochs
    if (epoch + 1) % LOG_FREQUENCY == 0:
        # Get the epoch losses (averaged from this epoch's batch losses)
        loss_avg = losses[-NUM_BATCHES:].mean().item()
        epoch_loss_class = class_losses[-NUM_BATCHES:].mean().item()
        epoch_loss_bbox = box_losses[-NUM_BATCHES:].mean().item()
        epoch_loss_giou = giou_losses[-NUM_BATCHES:].mean().item()

        # Log the losses
        print(f'Epoch: {epoch+1}/{EPOCHS}, DETR Loss: {loss_avg:.4f}')
        print(f'→ Class Loss: {epoch_loss_class:.4f}, BBox Loss: {epoch_loss_bbox:.4f}, GIoU Loss: {epoch_loss_giou:.4f}')

        # Save history per logged epoch...
        hist.append(loss_avg)
        hist_detail.append((epoch_loss_class, epoch_loss_bbox, epoch_loss_giou))

    # Save every SAVE_FREQUENCY epochs, and always after the final epoch so a
    # run shorter than SAVE_FREQUENCY still leaves a checkpoint behind.
    if (epoch + 1) % SAVE_FREQUENCY == 0 or (epoch + 1) == EPOCHS:
        torch.save(detr_model.state_dict(), f'ckpts/model_epoch{epoch+1}.pt')
        np.save(f'ckpts/hist_epoch{epoch+1}.npy', hist)

## Plot the training metrics

In [None]:
from utils.visualizers import visualize_losses

# Plot and save the loss plots.
# `hist` holds the total loss and `hist_detail` the (class, bbox, giou) tuples;
# both get one entry per logged epoch of the training loop.
visualize_losses(hist, hist_detail, save_dir='./')

## Load fine-tuned model and test inference

In [None]:
from utils.visualizers import DETRBoxVisualizer

# NOTE: the training loop writes checkpoints to 'ckpts/model_epoch<N>.pt';
# adjust this path to wherever your fine-tuned weights actually live.
WEIGHS_PATH = "../weights/model_epoch20.pt"
INFERENCE_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Rebuild the model with the class count of the custom dataset so that the
# fine-tuned state dict (with the replaced classification head) fits.
detr_model = DETR(
    d_model=256, n_classes=len(CLASSES), n_tokens=225,
    n_layers=6, n_heads=8, n_queries=MAX_OBJECTS
).to(INFERENCE_DEVICE)

# Load the checkpoint (prints the missing/unexpected keys report)
print(detr_model.load_state_dict(torch.load(WEIGHS_PATH, map_location=torch.device(INFERENCE_DEVICE))))

# A freshly constructed module starts in training mode; switch to eval mode so
# train-only layers such as dropout are disabled during inference.
detr_model.eval()

# Run inference and check results
visualizer = DETRBoxVisualizer(class_labels=CLASSES,
                               empty_class_id=EMPTY_CLASS_ID)

# This will always run inference on a GPU if one is available...
visualizer.visualize_validation_inference(detr_model, coco_ds_val, collate_fn=collate_fn)