In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW, lr_scheduler
import torchvision.transforms as T
from torchvision import datasets, ops
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models.feature_extraction import create_feature_extractor
from einops import rearrange
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import linear_sum_assignment
from tqdm import tqdm

# Category names from the COCO dataset, listed in category order and followed by
# a final "empty" placeholder entry (81 entries in total)
CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
    "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
    "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
    "backpack", "umbrella", "handbag", "tie", "suitcase",
    "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
    "skateboard", "surfboard", "tennis racket",
    "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
    "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
    "donut", "cake",
    "chair", "couch", "potted plant", "bed", "dining table", "toilet",
    "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
    "microwave", "oven", "toaster", "sink", "refrigerator",
    "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
    "empty",
]

# Undoes Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) so images can be shown in the demo:
# normalising with mean' = -mean/std and std' = 1/std maps y back to y*std + mean
revertNormalisation = T.Normalize(
    mean=[-m/s for m, s in zip((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))],
    std=[1/s for s in (0.229, 0.224, 0.225)],
)

# Prefer the Apple Metal Performance Shaders (MPS) backend; when it is not available
# fall back to CUDA and finally to the CPU so the notebook still runs on other machines
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

## Data preprocessing (COCO format)

In [None]:
def preprocessDataTargets(ann, w, h):
    """
    Turn the object annotations of one image into training targets

    Args:
        ann: iterable of per-object annotation dicts; the 'bbox' (xywh, in pixels)
             and 'category_id' entries of each dict are read
        w: width of the image
        h: height of the image
    Returns:
        classes: int64 tensor with the category id of every retained object
        bboxes: float32 tensor (n, 4) with the retained boxes as cxcywh scaled to [0, 1]
    """
    objects = list(ann)
    # reshape keeps a (0, 4) layout when the image has no annotations
    boxes = torch.as_tensor([obj['bbox'] for obj in objects], dtype=torch.float32).reshape(-1, 4)
    labels = torch.tensor([obj['category_id'] for obj in objects], dtype=torch.int64)

    # xywh -> x1y1x2y2, then clip the corners to the image area
    boxes[:, 2:] += boxes[:, :2]
    boxes[:, 0::2].clamp_(min=0, max=w)
    boxes[:, 1::2].clamp_(min=0, max=h)

    # Keep only objects whose clipped box still has a positive width and height
    keep = (boxes[:, 3]>boxes[:, 1])&(boxes[:, 2]>boxes[:, 0])
    boxes = boxes[keep]
    labels = labels[keep]

    # Normalise the coordinates by the image size into [0, 1]
    boxes[:, 0::2]/=w
    boxes[:, 1::2]/=h
    boxes.clamp_(min=0, max=1)

    # x1y1x2y2 -> cxcywh
    return labels, ops.box_convert(boxes, in_fmt='xyxy', out_fmt='cxcywh')

class customCocoDetection(datasets.CocoDetection):
    """
    COCO detection dataset prepared for the dataloader

    Each item is (image tensor, (classes, bboxes)): the image is normalised and resized
    to size*size, and the targets come from preprocessDataTargets. Category ids are
    remapped to contiguous indices 0..K-1 (K = number of categories in the annotation file).
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Size of images to standardise (480*480)
        self.size = 480
        # Image normalisation and resize
        self.T = T.Compose([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 
                            T.Resize((self.size, self.size), antialias=True)])
        self.T_target = preprocessDataTargets
        # The category ids in the standard COCO 2017 instance annotations are not contiguous
        # (80 categories with ids running up to 90), so using them directly as class indices
        # overflows an 81-way classifier. Map them, in ascending id order, onto 0..K-1.
        self.categoryToIndex = {catId: index for index, catId in enumerate(sorted(self.coco.getCatIds()))}
    def __getitem__(self, idx):
        image, target = super().__getitem__(idx)
        # PIL reports (width, height); the boxes are normalised by the original image size
        width, height = image.size
        _input = self.T(image)
        # Remap on shallow copies so the annotations cached by the COCO API stay untouched
        target = [{**obj, 'category_id': self.categoryToIndex[obj['category_id']]} for obj in target]
        classes, bboxes = self.T_target(target, width, height)
        return _input, (classes, bboxes)
        
def collateFunction(inputs):
    """
    Assemble a list of dataset samples into one batch for the dataloader

    Args:
        inputs: samples of the form (image, (classes, bboxes))
    Returns:
        _input: the images stacked along a new leading batch dimension
        (classes, bboxes): one tuple with the per-sample category tensors and one tuple
                           with the per-sample bounding-box tensors (left unstacked)
    """
    images, labels, boxes = [], [], []
    for sample in inputs:
        images.append(sample[0])
        labels.append(sample[1][0])
        boxes.append(sample[1][1])
    return torch.stack(images), (tuple(labels), tuple(boxes))

In [None]:
# Training split: shuffled batches of 16 samples
trainDataset = customCocoDetection(
    root="/Users/ivanng/dataset/coco-2017/train/data",
    annFile="/Users/ivanng/dataset/coco-2017/raw/instances_train2017.json",
)
trainDataLoader = DataLoader(dataset=trainDataset, batch_size=16, shuffle=True, collate_fn=collateFunction)

# Validation split: fixed order, batches of 2 samples
valDataset = customCocoDetection(
    root="/Users/ivanng/dataset/coco-2017/validation/data",
    annFile="/Users/ivanng/dataset/coco-2017/raw/instances_val2017.json",
)
valDataLoader = DataLoader(dataset=valDataset, batch_size=2, shuffle=False, collate_fn=collateFunction)

## DETR model architecture

In [None]:
def getHook(outputs, name):
    """
    Build a forward hook that stores a module's output in a shared dict

    Args:
        outputs: dict that receives the captured output
        name: key under which the output is stored (overwritten on every call)
    Returns:
        hook: callable taking (module, inputs, result); it returns None
    """
    def hook(module, inputs, result):
        # Only the output is recorded; the module and its inputs are ignored
        outputs[name] = result
    return hook
    
class DETR(nn.Module):
    """
    DETR-style detector: ResNet50 feature map -> 1*1 convolution -> transformer encoder/decoder
    -> linear category and bounding-box heads applied to the output of every decoder layer

    forward(x) returns
        outputs: {'layer<i>': {'category': logits, 'bbox': raw box outputs}} for each decoder layer
        embeddingsDict: {'encoderEmbeddings', 'decoderEmbeddings'} after the shared linear projection
    """
    def __init__(self, dimensions=256, numClasses=81, numTokens=225, numLayers=6, numHeads=8, numQueries=100):
        super().__init__()
        # Backbone: ResNet50 with IMAGENET1K_V2 weights, exposing the 'layer4' feature map
        pretrained = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.backbone = create_feature_extractor(pretrained, return_nodes={'layer4': 'layer4'})
        # 1*1 convolution: 2048 backbone channels -> model width
        self.conv11 = nn.Conv2d(2048, dimensions, kernel_size=1, stride=1)
        # Learnable positional encoding, one vector per token (numTokens must match h*w of the feature map)
        self.positionEncoding = nn.Parameter(torch.rand(1, numTokens, dimensions))

        # Settings shared by the encoder and decoder layers
        layerKwargs = dict(d_model=dimensions, nhead=numHeads, dim_feedforward=4*dimensions,
                           batch_first=True, dropout=0.1)
        # Transformer encoder
        self.encoderLayer = nn.TransformerEncoderLayer(**layerKwargs)
        self.transformerEncoder = nn.TransformerEncoder(self.encoderLayer, num_layers=numLayers)
        # Learnable object queries
        self.queries = nn.Parameter(torch.rand(1, numQueries, dimensions))
        # Transformer decoder
        self.decoderLayer = nn.TransformerDecoderLayer(**layerKwargs)
        self.transformerDecoder = nn.TransformerDecoder(self.decoderLayer, num_layers=numLayers)

        # Prediction heads
        self.linearClass = nn.Linear(dimensions, numClasses)
        self.linearBbox = nn.Linear(dimensions, 4)

        # Forward hooks capture the output of every decoder layer under 'layer<idx>'
        self.decoder_outputs = {}
        for layerIndex, decoderBlock in enumerate(self.transformerDecoder.layers):
            decoderBlock.register_forward_hook(getHook(self.decoder_outputs, f'layer{layerIndex}'))
        self.projection = nn.Linear(dimensions, dimensions)
    def forward(self, x):
        # Backbone feature map, e.g. (1, 3, 480, 480) input -> (1, 2048, 15, 15), reduced to (1, 256, 15, 15)
        featureMap = self.conv11(self.backbone(x)['layer4'])
        # Flatten the spatial grid into a token sequence: (1 256 15 15) -> (1 225 256)
        tokens = rearrange(featureMap, 'b c h w -> b (h w) c')
        # Encoder output keeps the token layout (1 225 256)
        memory = self.transformerEncoder(tokens+self.positionEncoding)
        # One copy of the queries per batch element; decoder output is (1 100 256)
        batchQueries = self.queries.repeat(len(memory), 1, 1)
        decoded = self.transformerDecoder(batchQueries, memory)
        # The decoder call above refreshed self.decoder_outputs through the hooks;
        # apply both heads to the output of every decoder layer
        outputs = {
            layerName: {'category': self.linearClass(hidden), 'bbox': self.linearBbox(hidden)}
            for layerName, hidden in self.decoder_outputs.items()
        }
        embeddingsDict = {
            'encoderEmbeddings': self.projection(memory),
            'decoderEmbeddings': self.projection(decoded),
        }
        return outputs, embeddingsDict
        

## Loss computation for DETR

In [None]:
def lossCompute(outputBbox, gtBbox, outputClass, gtClass, numQueries=100):
    """
    Bipartite-matched detection losses for a single image

    Predictions and ground truths are paired with the Hungarian algorithm on a cost of
    2*class + 5*L1 + 2*GIoU; unmatched queries are trained towards the no-object class.

    Args:
        outputBbox: predicted boxes, one row per query, in cxcywh
        gtBbox: ground-truth boxes in cxcywh (may be empty)
        outputClass: category logits, one row per query; the last logit is the no-object class
        gtClass: ground-truth class indices, one per ground-truth box
        numQueries: kept for backward compatibility; the number of queries is read from outputClass
    Returns:
        classLoss: cross-entropy over all queries
        bboxLoss: summed L1 distance of the matched pairs divided by the number of ground-truth boxes
        giouLoss: 1 - mean GIoU of the matched pairs
        (bboxLoss and giouLoss are zero when the image has no ground-truth box)
    """
    # Work on the device of the predictions instead of relying on a notebook-level global
    targetDevice = outputClass.device
    # No-object target = last logit index. A hard-coded 81 is out of range for an 81-way head
    # (valid targets are 0..80) and makes F.cross_entropy raise.
    # NOTE(review): gtClass is assumed to hold contiguous indices below this value - confirm
    # against the dataset.
    noObject = outputClass.shape[-1]-1
    queriesClassesLabel = torch.full((outputClass.shape[0],), noObject, dtype=torch.long, device=targetDevice)
    if len(gtBbox)>0:
        gtBbox = gtBbox.to(targetDevice)
        gtClass = gtClass.to(targetDevice)
        outputProbs = outputClass.softmax(dim=-1)
        outputXyxy = ops.box_convert(outputBbox, in_fmt='cxcywh', out_fmt='xyxy')
        gtXyxy = ops.box_convert(gtBbox, in_fmt='cxcywh', out_fmt='xyxy')
        # Compute costs for categories, bounding boxes and giou (queries x ground truths)
        classCost = -outputProbs[..., gtClass]
        bboxCost = torch.cdist(outputBbox, gtBbox, p=1)
        giouCost = -ops.generalized_box_iou(outputXyxy, gtXyxy)
        totalCost = 2*classCost+5*bboxCost+2*giouCost
        # Optimal pairs compute minimal cost, returns pairs of indices
        outputIndices, gtIndices = linear_sum_assignment(totalCost.detach().cpu().numpy())
        # int64 index tensors on the prediction device
        outputIndices = torch.as_tensor(outputIndices, dtype=torch.long, device=targetDevice)
        gtIndices = torch.as_tensor(gtIndices, dtype=torch.long, device=targetDevice)
        # Index both sides with their matched indices; this stays aligned even when an image
        # has more ground-truth boxes than queries (then only some ground truths are matched)
        matchedGtBbox = gtBbox[gtIndices]
        numBboxes = len(gtBbox)
        bboxLoss = F.l1_loss(outputBbox[outputIndices], matchedGtBbox, reduction='sum')/numBboxes
        giou = ops.generalized_box_iou(outputXyxy[outputIndices], gtXyxy[gtIndices])
        # Diagonal contains Bipartite pairs, 1-giou = giouloss
        giouLoss = 1-torch.diag(giou).mean()
        # Matched queries get their ground-truth class, all others stay no-object
        queriesClassesLabel[outputIndices] = gtClass[gtIndices]
    else:
        # No objects: every query is no-object; box losses are float zeros on the same device
        bboxLoss = giouLoss = torch.zeros((), device=targetDevice)
    classLoss = F.cross_entropy(outputClass, queriesClassesLabel)
    return classLoss, bboxLoss, giouLoss

## Training DETR

In [None]:
# Build the model and move it to the selected device
detr = DETR(dimensions=256, numClasses=81, numTokens=225, numLayers=6, numHeads=8, numQueries=100).to(device)

# Freeze every backbone parameter
detr.backbone.requires_grad_(False)

# Split the parameters by name into the backbone group and everything else
backboneParameters, transformerParameters = [], []
for paramName, param in detr.named_parameters():
    if 'backbone.' in paramName:
        backboneParameters.append(param)
    else:
        transformerParameters.append(param)

# NOTE(review): the backbone group is still handed to the optimizer although its parameters
# are frozen above, so they receive no gradients - confirm the 1e-5 group is intentional
optimizer = AdamW(
    [{'params': transformerParameters, 'lr': 1e-4}, {'params': backboneParameters, 'lr': 1e-5}],
    weight_decay=1e-4,
)


In [None]:
torch.set_grad_enabled(True)
detr.train()
# The backbone weights are frozen (requires_grad=False); keep it in eval mode as well so the
# BatchNorm running statistics of the pretrained ResNet50 are not overwritten during training
detr.backbone.eval()
numEpochs = 141
# Nominal batch size; the loss below is normalised by the actual size of each batch
batchSize=16
losses = []
history = []

for i in tqdm(range(len(history)+1, numEpochs)):
    for _input, (gtClasses, gtBboxes) in trainDataLoader:
        _input = _input.to(device)
        outputs, _ = detr(_input)
        # The last batch of an epoch can be smaller than batchSize
        numSamples = len(_input)
        loss = torch.zeros(1, device=device)
        # Sum the loss over the predictions of every decoder layer (auxiliary losses)
        for name, output in outputs.items():
            # Squash the raw box outputs into [0, 1] to match the normalised targets
            output['bbox'] = output['bbox'].sigmoid()
            for predictBbox, gtBbox, predictClass, gtClass in zip(output['bbox'], gtBboxes, output['category'], gtClasses):
                classLoss, bboxLoss, giouLoss = lossCompute(predictBbox, gtBbox, predictClass, gtClass)
                sampleTotalLoss = 1*classLoss+5*bboxLoss+2*giouLoss
                loss = loss+sampleTotalLoss/numSamples
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(detr.parameters(), 0.1)
        optimizer.step()
        losses.append(loss.item())
    # Logging period in epochs (1 = every epoch)
    if i%1==0:
        avgLoss = np.mean(losses)
        print(f'epoch: {i}, loss: {avgLoss:.4f}')
        # These three values come from the last sample processed in the epoch only
        print(f'classLoss: {classLoss.item():.4f}, bboxLoss: {bboxLoss.item():.4f}, giouLoss: {giouLoss.item():.4f}')
        history.append(avgLoss)
        losses = []
    # Checkpoint period in epochs (1 = every epoch)
    if i%1==0:
        torch.save({'state': detr.state_dict(), 'optimizerState': optimizer.state_dict()}, f'/Users/ivanng/model/RT_DETR_epoch{i}.pt')
        np.save(f'/Users/ivanng/model/history_epoch{i}.npy', history)