# Implementing and training a Detection Transformer (DETR)

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import ops
import numpy as np


# Import the custom COCO dataset loader
from dataloaders.coco_od_pytorch import TorchCOCOLoader, collate_fn
from models.detr import DETR
from models.losses.sample_detr_loss import compute_sample_loss

## Create a PyTorch Dataloader

In [None]:
# --- Dataset / loader configuration -----------------------------------------
BATCH_SIZE = 4
IMAGE_SIZE = 480
# ID of the dataset class that is treated as the "empty" class
EMPTY_CLASS_ID = 0

# Class labels for the "vehicles dataset"
CLASSES = ["N/A", "vehicle"]

# Training split: two positional paths are handed to the loader
# (presumably the image directory and its COCO-style annotation file).
coco_ds_train = TorchCOCOLoader(
    '../data/vehicles_dataset/train',
    '../data/vehicles_dataset/train/_vehicle_annotations.json',
)

# Validation split, same layout as the training split.
coco_ds_val = TorchCOCOLoader(
    '../data/vehicles_dataset/valid',
    '../data/vehicles_dataset/valid/_vehicle_annotations.json',
)

# Options shared by both loaders; only the shuffling differs between splits.
_loader_opts = {"batch_size": BATCH_SIZE, "collate_fn": collate_fn}

train_loader = DataLoader(coco_ds_train, shuffle=True, **_loader_opts)
val_loader = DataLoader(coco_ds_val, shuffle=False, **_loader_opts)

# Report how many samples each split contains.
for _split_name, _split_ds in (("Training", coco_ds_train),
                               ("Validation", coco_ds_val)):
    print(f"{_split_name} dataset size: {len(_split_ds)}")

## Plot some samples

In [None]:
import matplotlib.pyplot as plt
from utils.visualizer import DETRBoxVisualizer

# Create a visualizer. The shared EMPTY_CLASS_ID constant is used (instead of
# a repeated literal) so the visualizer and the loss agree on the empty class.
visualizer = DETRBoxVisualizer(class_labels=CLASSES,
                               empty_class_id=EMPTY_CLASS_ID)

# Fetch a single batch from the training loader and draw it on a 2x2 grid.
dataloader_iter = iter(train_loader)
input_, (classes, boxes) = next(dataloader_iter)
fig = plt.figure(figsize=(10, 10), constrained_layout=True)

# The grid has four cells, but a batch can hold fewer samples (small dataset
# or a smaller BATCH_SIZE); never index past the end of the batch.
MAX_PANELS = 4
for ix in range(min(len(classes), MAX_PANELS)):
    t_cl = classes[ix]

    # Scale the target boxes by IMAGE_SIZE and convert them from
    # (cx, cy, w, h) to (x1, y1, x2, y2) before drawing.
    # NOTE(review): assumes the loader yields boxes normalized to [0, 1] — confirm.
    t_bbox = ops.box_convert(
        boxes[ix] * IMAGE_SIZE, in_fmt='cxcywh', out_fmt='xyxy')

    ax = fig.add_subplot(2, 2, ix + 1)
    visualizer._visualize_image(input_[ix], t_bbox, t_cl, ax=ax)

## Building the DETR model

In [None]:
# The model is instantiated with the COCO configuration (92 class outputs) so
# that the COCO pre-trained weights can be loaded before fine-tuning on the
# new dataset.
detr_model = DETR(
    d_model=256, n_classes=92, n_tokens=225,
    n_layers=6, n_heads=8, n_queries=100
)

# Inspect shapes of outputs from the last layer.
# This forward pass is only a shape check, so no autograd graph is needed.
# The dummy input uses IMAGE_SIZE so it matches the size used for the data.
with torch.no_grad():
    x = torch.randn((1, 3, IMAGE_SIZE, IMAGE_SIZE))
    outs = detr_model(x)

# Read the predictions by key rather than relying on dict ordering.
last_layer_out = outs['layer_5']
pred_cl, pred_boxes = last_layer_out['cl'], last_layer_out['bbox']

print()
print("*****************************************")
print(f"Predicted Classes shape: {pred_cl.shape}")
print(f"Predicted Boxes shape: {pred_boxes.shape}")

## Load pre-trained weights as a starting point 

In [None]:
CHECKPOINT_PATH = "../weights/model_ep150.pt" # COCO weights

# Load the checkpoint on CPU; the model weights are stored under the 'state' key.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu"))

# Load the weights into the model (printing the result reports any
# missing / unexpected keys).
print(detr_model.load_state_dict(checkpoint['state']))

# Adapt the class prediction head to our new dataset: a fresh linear layer
# with the same input width but len(CLASSES) outputs. Its weights are newly
# initialized and are learned during fine-tuning.
detr_model.linear_class = nn.Linear(detr_model.linear_class.in_features, len(CLASSES))

# Verify output shapes with the new configuration.
# This forward pass is only a shape check, so no autograd graph is needed.
with torch.no_grad():
    x = torch.randn((1, 3, IMAGE_SIZE, IMAGE_SIZE))
    outs = detr_model(x)

# Read the predictions by key rather than relying on dict ordering.
last_layer_out = outs['layer_5']
pred_cl, pred_boxes = last_layer_out['cl'], last_layer_out['bbox']

print()
print("*****************************************")
print(f"Predicted Classes shape: {pred_cl.shape}")
print(f"Predicted Boxes shape: {pred_boxes.shape}")

## Start the training

In [None]:
from torch.optim import AdamW

CUDA_ENABLED = False
FREEZE_BACKBONE = True

# Enable GPU if available
if torch.cuda.is_available():
    detr_model.cuda()
    CUDA_ENABLED = True
    print("Model moved to GPU for training...")
else:
    print("No GPU found, training will start on CPU...")

# DETR in the official paper is trained with different learning rates for the
# backbone and the Transformer/prediction heads. By default the backbone is
# frozen here (it is already pre-trained) and only the Transformer/prediction
# heads are trained.
backbone_params = [
    p for n, p in detr_model.named_parameters() if 'backbone.' in n]
transformer_params = [
    p for n, p in detr_model.named_parameters() if 'backbone.' not in n]

# Freeze backbone
if FREEZE_BACKBONE:
    for p in detr_model.backbone.parameters():
        p.requires_grad = False
print(f"CNN backbone is frozen for training: {FREEZE_BACKBONE}")

param_groups = [{'params': transformer_params, 'lr': 1e-5}]
if not FREEZE_BACKBONE:
    # Without this group an unfrozen backbone would never be updated, because
    # the optimizer only steps the parameters it was given. The backbone gets
    # a 10x smaller learning rate than the Transformer/heads.
    param_groups.append({'params': backbone_params, 'lr': 1e-6})

optimizer = AdamW(param_groups, weight_decay=1e-4)

# Log the number of parameters that will actually receive gradients
# (frozen parameters are excluded).
nparams = sum(
    p.numel() for p in detr_model.parameters() if p.requires_grad) / 1e6
print(f'DETR trainable parameters: {nparams:.1f}M')

In [None]:
# Ensure saving directory exists for the model checkpoints...
# os.makedirs is used instead of a shell one-liner so the cell behaves the
# same on every platform and is a no-op when the directory already exists.
import os

os.makedirs("ckpts", exist_ok=True)

In [None]:
# Set the model to training mode
torch.set_grad_enabled(True)
detr_model.train()

EPOCHS = 30
LOG_FREQUENCY = 10
SAVE_FREQUENCY = 20
NUM_BATCHES = len(train_loader)

# Train on whichever device the model currently lives on, so the loop works
# both with and without a GPU.
device = next(detr_model.parameters()).device

# Per-batch losses (one entry per optimizer step)
losses = []
class_losses = []
box_losses = []
giou_losses = []

# Per-epoch history of the total loss
hist = []

# More detailed per-epoch logs for all losses (tuple: (class, bbox, giou))
hist_detail = []


for epoch in range(EPOCHS):
    for input_, (tgt_cl, tgt_bbox) in train_loader:
        input_ = input_.to(device)
        tgt_cl = tgt_cl.to(device)
        tgt_bbox = tgt_bbox.to(device)

        # `outs` maps a layer name to a dict with 'cl' and 'bbox' predictions
        outs = detr_model(input_)

        loss = torch.zeros(1, device=device)

        # Average over the samples actually present in this batch (the last
        # batch of an epoch can be smaller than BATCH_SIZE) and over all
        # decoder-layer outputs.
        norm = len(tgt_cl) * len(outs)

        # For detailed logging of all the losses
        loss_class_batch = 0.0
        loss_bbox_batch = 0.0
        loss_giou_batch = 0.0

        for out in outs.values():
            # Box predictions are passed through a sigmoid before the loss
            out['bbox'] = out['bbox'].sigmoid()

            for o_bbox, t_bbox, o_cl, t_cl in zip(
                    out['bbox'], tgt_bbox, out['cl'], tgt_cl):

                loss_class, loss_bbox, loss_giou = compute_sample_loss(
                    o_bbox, t_bbox, o_cl, t_cl, empty_class_id=EMPTY_CLASS_ID)

                # Weighted sum of the three terms (class: 1, bbox: 5, GIoU: 2)
                sample_loss = 1 * loss_class + 5 * loss_bbox + 2 * loss_giou

                loss = loss + sample_loss / norm

                # Track individual losses per batch
                loss_class_batch += loss_class.item() / norm
                loss_bbox_batch += loss_bbox.item() / norm
                loss_giou_batch += loss_giou.item() / norm

        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        # Clip gradient norms
        nn.utils.clip_grad_norm_(detr_model.parameters(), .1)
        optimizer.step()

        losses.append(loss.item())
        class_losses.append(loss_class_batch)
        box_losses.append(loss_bbox_batch)
        giou_losses.append(loss_giou_batch)

    # Epoch losses: mean over the batch losses of the epoch that just finished
    loss_avg = np.mean(losses[-NUM_BATCHES:])
    epoch_loss_class = np.mean(class_losses[-NUM_BATCHES:])
    epoch_loss_bbox = np.mean(box_losses[-NUM_BATCHES:])
    epoch_loss_giou = np.mean(giou_losses[-NUM_BATCHES:])

    # Save history for every epoch so the loss curves have one point per epoch
    hist.append(loss_avg)
    hist_detail.append((epoch_loss_class, epoch_loss_bbox, epoch_loss_giou))

    # Logging every LOG_FREQUENCY epochs
    if (epoch + 1) % LOG_FREQUENCY == 0:
        print(f'Epoch: {epoch+1}/{EPOCHS}, DETR Loss: {loss_avg:.4f}')
        print(f'→ Class Loss: {epoch_loss_class:.4f}, BBox Loss: {epoch_loss_bbox:.4f}, GIoU Loss: {epoch_loss_giou:.4f}')

    # Save every SAVE_FREQUENCY epochs
    if (epoch + 1) % SAVE_FREQUENCY == 0:
        torch.save(detr_model.state_dict(), f'ckpts/model_epoch{epoch+1}.pt')
        np.save(f'ckpts/hist_epoch{epoch+1}.npy', hist)

## Plot the training metrics

In [None]:
import os

# Figures are written next to the model checkpoints in "ckpts" (the directory
# the training loop saves to); make sure it exists.
os.makedirs("ckpts", exist_ok=True)

# Once training is done plot the loss curve (natural log of the total loss,
# one point per entry recorded in `hist`)...
plt.plot(np.log(hist))
plt.xlabel('Epochs')
plt.ylabel('DETR Loss (log scale)')
plt.savefig("ckpts/loss_curve.png")
plt.show()

# Plot the detailed losses as well: split the (class, bbox, giou) tuples into
# three parallel series.
class_losses_plot, bbox_losses_plot, giou_losses_plot = zip(*hist_detail)

# One subplot per loss component, stacked vertically
fig, axes = plt.subplots(3, 1, figsize=(8, 12))

# (series, legend label, line color) for each subplot, top to bottom
loss_panels = [
    (class_losses_plot, "Class Loss", 'b'),
    (bbox_losses_plot, "BBox Loss", 'g'),
    (giou_losses_plot, "GIoU Loss", 'r'),
]

for ax, (series, label, color) in zip(axes, loss_panels):
    ax.plot(series, label=label, color=color)
    ax.set_title(f"{label} Over Epochs")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid()

# Adjust layout and show plot
plt.tight_layout()
plt.savefig("ckpts/detr_losses.png")
plt.show()

## Load fine-tuned model and test inference

In [None]:
from utils.visualizer import DETRBoxVisualizer

# Rebuild the model with the class count of the custom dataset so that the
# fine-tuned weights (including the replaced class head) fit.
detr_model = DETR(
    d_model=256, n_classes=len(CLASSES), n_tokens=225,
    n_layers=6, n_heads=8, n_queries=100
)

# Load the checkpoint
# NOTE(review): the training loop writes checkpoints to 'ckpts/model_epoch{N}.pt';
# this path points at '../weights/' — confirm the file was copied there.
print(detr_model.load_state_dict(torch.load('../weights/model_epoch20.pt', map_location=torch.device('cpu'))))

# A freshly constructed nn.Module is in training mode. Switch to evaluation
# mode so that any train-time-only layers (e.g. dropout, batch norm) use
# their inference behaviour.
detr_model.eval()

# Run inference and check results
visualizer = DETRBoxVisualizer(class_labels=CLASSES,
                               empty_class_id=EMPTY_CLASS_ID)

visualizer.visualize_validation_inference(detr_model, coco_ds_val, collate_fn=collate_fn)