Coded by Lujia Zhong @lujiazho<br>
Reference: https://github.com/facebookresearch/detr

In [1]:
import time
import torch
from torch import nn
from torchvision.ops.boxes import box_area
from scipy.optimize import linear_sum_assignment

################################################################################################
#                                       Transformer Part                                       #
################################################################################################

def padding_mask(seq_q_shape, seq_k):
    """Build a boolean attention mask that hides padded key positions.

    Args:
        seq_q_shape: two-element shape ``(batch, len_q)`` of the query sequence.
        seq_k: 2-D ``(batch, len_k)`` tensor whose zero entries mark padding.

    Returns:
        Bool tensor (an expanded view) of shape ``(batch, len_q, len_k)``; an
        entry is True when that key position is padding and must be masked.
    """
    _, query_len = seq_q_shape
    batch_size, key_len = seq_k.shape
    is_padding = seq_k.data == 0  # (batch, len_k); value 0 means padding
    return is_padding[:, None, :].expand(batch_size, query_len, key_len)

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with residual connection and post-LayerNorm.

    ``config`` must provide ``d_model``, ``d_key``, ``d_value``, ``n_heads`` and
    ``dropout_rate``. Positional information (``posemb`` / ``queries``) is added
    to the query/key streams only; values and the residual stay un-embedded.
    """

    def __init__(self, config):
        super().__init__()

        heads, dk, dv, dm = config.n_heads, config.d_key, config.d_value, config.d_model
        self.d_key, self.d_value, self.n_heads = dk, dv, heads

        # all heads packed into one projection per stream
        self.W_Q = nn.Linear(dm, dk * heads)
        self.W_K = nn.Linear(dm, dk * heads)
        self.W_V = nn.Linear(dm, dv * heads)

        # merge heads back to d_model, then post-norm over (output + residual)
        self.linear = nn.Linear(heads * dv, dm)
        self.layer_norm = nn.LayerNorm(dm)

        self.attn_dropout = nn.Dropout(config.dropout_rate)
        self.proj_dropout = nn.Dropout(config.dropout_rate)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V, attn_mask, posemb, queries):
        """Attend from ``Q`` to ``K``/``V``.

        Args:
            Q, K, V: batch-first tensors whose last dim is ``d_model``.
            attn_mask: bool ``(batch, len_q, len_k)``, True = masked; or None.
            posemb: positional embedding added to K (and to Q in encoder self-attention), or None.
            queries: query embedding added to Q (and to K in decoder self-attention), or None.

        Returns:
            ``(LayerNorm(proj(attention) + Q), attention_probs)`` where
            ``attention_probs`` is ``(batch, n_heads, len_q, len_k)`` after dropout.
        """
        batch_size = Q.shape[0]
        residual = Q

        if posemb is None:
            # decoder self-attention: Q and K share the query embedding
            Q = K = Q + queries
        elif queries is None:
            # encoder self-attention: Q and K share the positional embedding
            Q = K = Q + posemb
        else:
            # cross-attention: queries on the Q side, positions on the K side
            Q, K = Q + queries, K + posemb

        def split_heads(projected, width):
            # (batch, len, heads * width) -> (batch, heads, len, width)
            return projected.view(batch_size, -1, self.n_heads, width).transpose(1, 2)

        q = split_heads(self.W_Q(Q), self.d_key)
        k = split_heads(self.W_K(K), self.d_key)
        v = split_heads(self.W_V(V), self.d_value)

        scores = torch.matmul(q, k.transpose(-1, -2)) / (self.d_key ** 0.5)
        if attn_mask is not None:
            # broadcast the same mask over every head; True positions are suppressed
            scores.masked_fill_(attn_mask.unsqueeze(1), -1e9)

        probs = self.attn_dropout(self.softmax(scores))

        merged = torch.matmul(probs, v).transpose(1, 2).contiguous()
        merged = merged.view(batch_size, -1, self.n_heads * self.d_value)

        projected = self.proj_dropout(self.linear(merged))
        return self.layer_norm(projected + residual), probs

class MLP(nn.Module):
    """Position-wise feed-forward sub-layer with residual connection and post-LayerNorm.

    Computes ``LayerNorm(x + Dropout(fc2(Dropout(ReLU(fc1(x))))))`` where ``fc1``
    maps ``d_model -> d_mlp`` and ``fc2`` maps back to ``d_model``.
    """

    def __init__(self, config):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, config.d_mlp)
        self.fc2 = nn.Linear(config.d_mlp, config.d_model)
        self.layer_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout_rate)

        self._init_weights()

    def _init_weights(self):
        # Xavier-uniform weights first, then near-zero Gaussian biases (std 1e-6).
        layers = (self.fc1, self.fc2)
        for fc in layers:
            nn.init.xavier_uniform_(fc.weight)
        for fc in layers:
            nn.init.normal_(fc.bias, std=1e-6)

    def forward(self, inputs):
        hidden = self.dropout(torch.relu(self.fc1(inputs)))
        projected = self.dropout(self.fc2(hidden))
        return self.layer_norm(projected + inputs)

class EncoderLayer(nn.Module):
    """One encoder layer: self-attention (positional embedding on Q/K) then the MLP sub-layer."""

    def __init__(self, config):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention(config)
        self.ffn = MLP(config)

    def forward(self, enc_inputs, enc_self_attn_mask, posemb):
        """Return ``(layer_output, self_attention_probs)``."""
        attended, attn = self.enc_self_attn(Q=enc_inputs, K=enc_inputs, V=enc_inputs,
                                            attn_mask=enc_self_attn_mask,
                                            posemb=posemb, queries=None)
        return self.ffn(attended), attn

class DecoderLayer(nn.Module):
    """One decoder layer: query self-attention, cross-attention to the encoder output, then the MLP."""

    def __init__(self, config):
        super().__init__()
        self.dec_self_attn = MultiHeadAttention(config)
        self.dec_enc_attn = MultiHeadAttention(config)
        self.ffn = MLP(config)

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_pad_mask, dec_enc_attn_pad_mask, posemb, queries):
        """Return ``(layer_output, self_attention_probs, cross_attention_probs)``."""
        # self-attention among object queries (query embedding on Q and K, no positional embedding)
        hidden, self_probs = self.dec_self_attn(Q=dec_inputs, K=dec_inputs, V=dec_inputs,
                                                attn_mask=dec_self_attn_pad_mask,
                                                posemb=None, queries=queries)
        # cross-attention: queries attend over encoder memory, padding masked out
        hidden, cross_probs = self.dec_enc_attn(Q=hidden, K=enc_outputs, V=enc_outputs,
                                                attn_mask=dec_enc_attn_pad_mask,
                                                posemb=posemb, queries=queries)
        return self.ffn(hidden), self_probs, cross_probs

class Encoder(nn.Module):
    """Stack of ``config.n_layers`` encoder layers sharing one padding mask."""

    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.n_layers)])

    def forward(self, enc_inputs, mask, posemb):
        """Return ``(final_output, [self_attention_probs per layer])``.

        ``mask`` holds 0 at padded positions; it is turned into a self-attention mask once.
        """
        self_mask = padding_mask(mask.shape, mask)

        hidden, attn_maps = enc_inputs, []
        for layer in self.layers:
            hidden, attn = layer(hidden, self_mask, posemb)
            attn_maps.append(attn)
        return hidden, attn_maps

class Decoder(nn.Module):
    """Stack of ``config.n_layers`` decoder layers that returns every layer's output."""

    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs, mask, posemb, queries):
        """Decode the query sequence against the encoder memory.

        Returns ``(stacked_outputs, self_attn_list, cross_attn_list)`` where
        ``stacked_outputs`` stacks each layer's output along a new leading dim
        (used for auxiliary losses). ``enc_inputs`` is accepted but unused.
        """
        # cross-attention mask: (batch, num_queries, len_k), True on padded memory positions
        cross_mask = padding_mask(dec_inputs.shape[:-1], mask)

        hidden = dec_inputs
        per_layer, self_maps, cross_maps = [], [], []
        for layer in self.layers:
            hidden, self_attn, cross_attn = layer(hidden, enc_outputs, None,
                                                  cross_mask, posemb, queries)
            per_layer.append(hidden)
            self_maps.append(self_attn)
            cross_maps.append(cross_attn)
        return torch.stack(per_layer), self_maps, cross_maps

class Transformer(nn.Module):
    """Encoder-decoder transformer used by DETR."""

    def __init__(self, config):
        super().__init__()
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

    def forward(self, enc_inputs, dec_inputs, mask, posemb, queries):
        """Return ``(decoder_outputs_per_layer, enc_self_attns, dec_self_attns, dec_enc_attns)``."""
        memory, enc_attn = self.encoder(enc_inputs, mask, posemb)
        decoded, dec_self, dec_cross = self.decoder(dec_inputs, enc_inputs, memory,
                                                    mask, posemb, queries)
        return decoded, enc_attn, dec_self, dec_cross


################################################################################################
#                                         Resnet Part                                          #
################################################################################################

class BottleNeck(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand.

    The block outputs ``out_channels * expansion`` channels. ``stride`` is applied
    by the 3x3 conv; a strided 1x1 projection shortcut is used whenever the stride
    or the channel count changes, otherwise the shortcut is the identity.
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Use the class attribute rather than a literal 4 so subclasses and
        # ResNet.conv_layer (which reads block.expansion) agree on channel counts.
        expanded = out_channels * self.expansion

        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, expanded, kernel_size=1, bias=False),
            nn.BatchNorm2d(expanded),
        )

        # identity shortcut unless spatial size or channel count changes
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != expanded:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, expanded, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(expanded),
            )

    def forward(self, x):
        return torch.relu(self.residual_function(x) + self.shortcut(x))

class ResNet(nn.Module):
    """ResNet feature extractor: stem plus four residual stages, no pooling/classifier head.

    ``config.block`` is the residual block class (it must expose ``expansion``) and
    ``config.num_block`` gives the number of blocks in each of the four stages.
    """

    def __init__(self, config):
        super().__init__()

        # channel count entering the next stage; conv_layer advances it
        self.in_channels = 64

        # stem: 7x7 stride-2 conv -> BN -> ReLU, followed by a 3x3 stride-2 max-pool
        stem_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.conv1 = nn.Sequential(stem_conv, nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (width, stride of the first block) for each stage; later blocks use stride 1
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        stages = []
        for idx, (width, first_stride) in enumerate(stage_specs):
            stages.append(self.conv_layer(config.block, width, config.num_block[idx],
                                          ini_stride=first_stride))
        self.conv2_x, self.conv3_x, self.conv4_x, self.conv5_x = stages

    def conv_layer(self, block, out_channels, num_blocks, ini_stride):
        """Build one stage (always at least one block); only the first block uses ``ini_stride``.

        Side effect: sets ``self.in_channels`` to ``out_channels * block.expansion``.
        """
        strides = [ini_stride] + [1] * (num_blocks - 1)
        blocks = []
        for stride in strides:
            blocks.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        for stage in (self.conv1, self.max_pool,
                      self.conv2_x, self.conv3_x, self.conv4_x, self.conv5_x):
            x = stage(x)
        return x


################################################################################################
#                                          LOSS Part                                           #
################################################################################################

class Criterion(nn.Module):
    """DETR set-prediction loss: Hungarian matching followed by per-match losses.

    Predictions are matched one-to-one to ground-truth objects by
    ``hungarian_matcher``; the requested losses are then computed on that matching:

    * ``'labels'``: weighted cross-entropy over all queries. Unmatched queries are
      supervised towards the extra "no object" class ``num_classes``.
    * ``'boxes'``: L1 and (1 - GIoU) on matched boxes only, normalized by the number
      of ground-truth boxes in the batch (at least 1).

    Boxes are in ``(cx, cy, w, h)`` format. If ``outputs`` carries ``'aux_outputs'``
    (one prediction dict per intermediate decoder layer), each one is matched and
    penalized separately with loss names suffixed ``_<i>``.
    """

    def __init__(self, num_classes, eos_coef, losses,
                 cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
        super().__init__()

        # weights of the matching-cost terms (used only by the Hungarian matcher)
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"

        # names of the losses to compute ('labels' and/or 'boxes')
        self.losses = losses
        self.num_classes = num_classes

        # Classes 0..num_classes-1 are real, index num_classes is "no object". Most
        # queries end up unmatched, so that class is down-weighted by eos_coef.
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = eos_coef
        self.register_buffer('empty_weight', empty_weight)

    @staticmethod
    def _flat_indices(indices, num_queries, device):
        """Map per-image matched query indices onto the flattened (batch * num_queries) predictions."""
        return torch.cat([src + i * num_queries for i, (src, _) in enumerate(indices)]).to(device)

    # all label predictions get a loss, including the "no object" ones
    def loss_labels(self, outputs: dict, targets: list, indices: list, num_boxes):
        assert 'pred_logits' in outputs

        src_logits = outputs['pred_logits']
        num_queries = src_logits.shape[1]
        src_logits = src_logits.reshape(-1, src_logits.shape[-1])  # (batch * num_queries, num_classes + 1)
        device = src_logits.device

        # ground-truth classes of the matched targets, in matching order
        label = torch.cat([t["labels"][j] for t, (_, j) in zip(targets, indices)]).to(device)

        # every query defaults to "no object"; matched queries get their target class
        global_idx = self._flat_indices(indices, num_queries, device)
        target_classes = torch.full(src_logits.shape[:1], self.num_classes,
                                    dtype=torch.int64, device=device)
        target_classes[global_idx] = label

        loss_ce = nn.functional.cross_entropy(src_logits, target_classes, self.empty_weight)

        return {'loss_ce': loss_ce}

    # only matched boxes (those that have an object) get a loss
    def loss_boxes(self, outputs, targets, indices, num_boxes):
        assert 'pred_boxes' in outputs

        src_boxes = outputs['pred_boxes']
        num_queries = src_boxes.shape[1]
        global_idx = self._flat_indices(indices, num_queries, src_boxes.device)
        src_boxes = src_boxes.reshape(-1, src_boxes.shape[-1])[global_idx]

        # ground-truth boxes in the same (matching) order
        target_boxes = torch.cat([t['boxes'][j] for t, (_, j) in zip(targets, indices)],
                                 dim=0).to(src_boxes.device)

        loss_bbox = nn.functional.l1_loss(src_boxes, target_boxes, reduction='none')
        losses = {'loss_bbox': loss_bbox.sum() / num_boxes}

        loss_giou = 1 - self.giou(self.box_cxcywh_to_xyxy(src_boxes), self.box_cxcywh_to_xyxy(target_boxes))
        losses['loss_giou'] = loss_giou.sum() / num_boxes

        return losses

    def get_loss(self, loss, outputs, targets, indices, num_boxes):
        loss_map = {'labels': self.loss_labels, 'boxes': self.loss_boxes}
        assert loss in loss_map, f'The {loss} loss not founded.'

        return loss_map[loss](outputs, targets, indices, num_boxes)

    def forward(self, outputs, targets):
        """Compute all requested losses for the final and (if present) auxiliary predictions.

        Args:
            outputs: dict with ``pred_logits`` (B, Q, num_classes + 1), ``pred_boxes``
                (B, Q, 4) and optionally ``aux_outputs`` (list of such dicts).
            targets: list of B dicts with ``labels`` (N_i,) and ``boxes`` (N_i, 4).

        Returns:
            dict of scalar loss tensors.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # matching between the outputs of the last layer and the targets
        indices = self.hungarian_matcher(outputs_without_aux, targets)

        # Normalize by the number of target boxes in the batch; clamp to 1 so a
        # batch without any objects does not divide by zero.
        num_boxes = max(sum(len(t["labels"]) for t in targets), 1)

        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # auxiliary losses: each intermediate decoder layer is matched independently
        for i, aux_outputs in enumerate(outputs.get('aux_outputs', [])):
            indices = self.hungarian_matcher(aux_outputs, targets)
            for loss in self.losses:
                aux_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes)
                losses.update({k + f'_{i}': v for k, v in aux_dict.items()})

        return losses

    @torch.no_grad()
    def hungarian_matcher(self, outputs, targets):
        """Optimal one-to-one assignment between predictions and targets, per image.

        Returns a list with one ``(query_idx, target_idx)`` pair of int64 tensors per
        image; both have length ``min(num_queries, num_targets_i)``.
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]
        assert bs == len(targets), f"Batch {bs} and targets {len(targets)} number not matched."

        # flatten the batch so the cost matrices are computed in one go
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [bs * num_queries, num_classes + 1]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [bs * num_queries, 4]

        # concatenated target labels and boxes of the whole batch
        tgt_ids = torch.cat([v["labels"] for v in targets]).to(out_prob.device)
        tgt_bbox = torch.cat([v["boxes"] for v in targets]).to(out_bbox.device)

        # classification cost: -proba[target class] (the constant 1 of 1 - p is omitted)
        cost_class = -out_prob[:, tgt_ids]

        # L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # -giou instead of 1 - giou: the constant does not change the matching
        cost_giou = -self.matrix_giou(self.box_cxcywh_to_xyxy(out_bbox), self.box_cxcywh_to_xyxy(tgt_bbox))

        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        # Explicit last dim (not -1) so a batch with zero targets still reshapes;
        # scipy works on host memory, so move the matrix to the CPU.
        C = C.view(bs, num_queries, tgt_bbox.shape[0]).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        # per-image Hungarian matching on that image's slice of target columns
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        # numpy arrays -> int64 tensors
        return [(torch.as_tensor(i, dtype=torch.int64),
                 torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

    def box_cxcywh_to_xyxy(self, x):
        x_c, y_c, w, h = x.unbind(-1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
             (x_c + 0.5 * w), (y_c + 0.5 * h)]
        return torch.stack(b, dim=-1)

    @staticmethod
    def _box_area(boxes):
        """Area of (x1, y1, x2, y2) boxes; same upcasting and formula as torchvision.ops.box_area."""
        if boxes.is_floating_point():
            if boxes.dtype not in (torch.float32, torch.float64):
                boxes = boxes.float()
        elif boxes.dtype not in (torch.int32, torch.int64):
            boxes = boxes.int()
        return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    def matrix_iou(self, boxes1, boxes2):
        """Pairwise IoU and union of (N, 4) and (M, 4) xyxy boxes -> two (N, M) tensors."""
        area1 = self._box_area(boxes1)
        area2 = self._box_area(boxes2)

        lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
        rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

        wh = (rb - lt).clamp(min=0)  # [N,M,2]
        inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

        union = area1[:, None] + area2 - inter

        iou = inter / union
        return iou, union

    def matrix_giou(self, boxes1, boxes2):
        """Pairwise generalized IoU of (N, 4) and (M, 4) xyxy boxes -> (N, M)."""
        assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
        assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
        iou, union = self.matrix_iou(boxes1, boxes2)

        # smallest enclosing box
        lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
        rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

        wh = (rb - lt).clamp(min=0)  # [N,M,2]
        area = wh[:, :, 0] * wh[:, :, 1]

        return iou - (area - union) / area

    def iou(self, boxes1, boxes2):
        """Element-wise IoU and union of two (N, 4) xyxy box tensors."""
        area1 = self._box_area(boxes1)
        area2 = self._box_area(boxes2)

        lt = torch.max(boxes1[:, :2], boxes2[:, :2])
        rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])

        wh = (rb - lt).clamp(min=0)
        inter = wh[:, 0] * wh[:, 1]

        union = area1 + area2 - inter

        iou = inter / union
        return iou, union

    def giou(self, boxes1, boxes2):
        """Element-wise generalized IoU of two (N, 4) xyxy box tensors."""
        assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
        assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
        iou, union = self.iou(boxes1, boxes2)

        lt = torch.min(boxes1[:, :2], boxes2[:, :2])
        rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])

        wh = (rb - lt).clamp(min=0)
        area = wh[:, 0] * wh[:, 1]

        return iou - (area - union) / area

    
################################################################################################
#                                          DETR Part                                           #
################################################################################################

class DETR(nn.Module):
    """DETR detector: ResNet backbone, transformer encoder/decoder and prediction heads.

    ``forward`` takes a list of ``(3, H_i, W_i)`` images of possibly different
    sizes, zero-pads them into one batch (with a padding mask), and returns
    class logits and sigmoid boxes for every object query, for the last decoder
    layer and (as ``aux_outputs``) for every earlier one.
    """

    def __init__(self, config):
        super().__init__()
        # convolutional stages of the ResNet configured by config.backboneConfig
        self.backbone = ResNet(config.backboneConfig)
        # 1x1 projection from the backbone's 2048 output channels to the transformer width
        self.conv = nn.Conv2d(2048, config.hidden_dim, 1)
        self.transformer = Transformer(config.transformerConfig)

        # learned object queries; config.num_queries is optional (default 100)
        num_queries = getattr(config, 'num_queries', 100)
        self.queries = nn.Parameter(torch.rand(num_queries, config.hidden_dim))

        # num_classes real classes + 1 "no object" logit
        self.linear_class = nn.Linear(config.hidden_dim, config.num_classes + 1)
        # 3-layer box MLP. ReLU between layers is applied in `_predict_boxes`, which
        # keeps this Sequential's state_dict keys (0, 1, 2) unchanged.
        self.linear_bbox = nn.Sequential(*[nn.Linear(n, k) for n, k in zip([config.hidden_dim] * 3,
                                                                           [config.hidden_dim] * 2 + [4])])

        # learned positional encoding: half the channels for columns, half for rows (max 50 each)
        self.row_embed = nn.Parameter(torch.rand(50, config.hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, config.hidden_dim // 2))

    def forward(self, inputs: list):
        """Run the detector on a list of 3-channel image tensors.

        Returns:
            dict with ``pred_logits`` (B, num_queries, num_classes + 1),
            ``pred_boxes`` (B, num_queries, 4) in [0, 1], and ``aux_outputs``: a
            list of the same two predictions for each earlier decoder layer.

        Raises:
            ValueError: if the backbone feature map is larger than the learned
                positional-embedding tables.
        """
        assert inputs[0].shape[0] == 3, "Not supported Channel"
        max_h, max_w = self.find_max_hw(inputs)
        inputs, mask = self.preprocess(inputs, max_h, max_w)

        B = inputs.shape[0]

        x = self.backbone(inputs)
        h = self.conv(x)

        H, W = h.shape[-2:]
        if H > self.row_embed.shape[0] or W > self.col_embed.shape[0]:
            raise ValueError(
                f"Feature map {H}x{W} exceeds the {self.row_embed.shape[0]}x{self.col_embed.shape[0]} "
                "learned positional embeddings; use smaller images.")
        # (H, W, d) grid of [col | row] embeddings, flattened row-major to match h.flatten(2)
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(0).repeat(B, 1, 1)
        # downsample the padding mask (1 = pixel, 0 = padding) to the feature-map size
        mask = nn.functional.interpolate(mask[None], size=x.shape[-2:])[0].flatten(-2)

        queries = self.queries.unsqueeze(0).repeat(B, 1, 1)
        h, _, _, _ = self.transformer(h.flatten(2).permute(0, 2, 1), queries, mask, pos, queries)

        # h stacks every decoder layer's output: (n_layers, B, num_queries, hidden_dim)
        logits, bboxes = self.linear_class(h), self._predict_boxes(h)

        outputs = {'pred_logits': logits[-1], 'pred_boxes': bboxes[-1]}
        outputs['aux_outputs'] = [{'pred_logits': a, 'pred_boxes': b}
                                  for a, b in zip(logits[:-1], bboxes[:-1])]
        return outputs

    def _predict_boxes(self, h):
        """Box MLP with ReLU between hidden layers; sigmoid maps outputs to [0, 1]."""
        last = len(self.linear_bbox) - 1
        for i, layer in enumerate(self.linear_bbox):
            h = layer(h)
            if i < last:
                h = torch.relu(h)
        return h.sigmoid()

    def preprocess(self, inputs, max_h, max_w):
        """Zero-pad images to (max_h, max_w) and stack them.

        Returns:
            ``(tensor, mask)``: ``tensor`` is (B, C, max_h, max_w) with the images'
            dtype and device; ``mask`` is float (B, max_h, max_w) on the same device,
            1 on real pixels and 0 on padding.
        """
        B = len(inputs)
        c, dtype, device = inputs[0].shape[0], inputs[0].dtype, inputs[0].device

        tensor = torch.zeros((B, c, max_h, max_w), dtype=dtype, device=device)
        mask = torch.zeros((B, max_h, max_w), device=device)
        for img, pad_img, m in zip(inputs, tensor, mask):
            pad_img[:img.shape[0], :img.shape[1], :img.shape[2]].copy_(img)
            m[:img.shape[1], :img.shape[2]] = 1

        return tensor, mask

    def find_max_hw(self, data):
        """Largest height and width over (C, H, W) images; (0, 0) for an empty list."""
        max_h = max((img.shape[1] for img in data), default=0)
        max_w = max((img.shape[2] for img in data), default=0)
        return max_h, max_w


class Resnet50Config:
    # ResNet-50 layout: bottleneck blocks with stage depths 3-4-6-3
    num_block = [3, 4, 6, 3]
    block = BottleNeck

class TransformerConfig:
    # hyper-parameters read by Transformer / Encoder / Decoder / MultiHeadAttention / MLP
    d_model = 256               # embedding size
    d_mlp = d_model * 4         # hidden width of the position-wise MLP
    d_key = 32                  # per-head query/key dimension
    d_value = 32                # per-head value dimension (may differ from d_key)
    n_layers = 6                # encoder layers and decoder layers (each)
    n_heads = 8                 # attention heads

    dropout_rate = 0.1

class detrConfig:
    # 91 class ids; the model adds one extra "no object" logit on top
    num_classes = 91
    # transformer width; must equal TransformerConfig.d_model
    hidden_dim = 256

    backboneConfig = Resnet50Config()
    transformerConfig = TransformerConfig()

model = DETR(config=detrConfig())

In [2]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

# Loss weights: the final decoder layer uses the base names and every intermediate
# decoder layer i contributes auxiliary losses named "<loss>_<i>". The auxiliary
# keys are generated from the model's actual decoder depth because keys missing
# from weight_dict are silently left out of the total loss in the training loop.
base_weights = {'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2}
num_aux_layers = len(model.transformer.decoder.layers) - 1
weight_dict = dict(base_weights)
for aux_idx in range(num_aux_layers):
    weight_dict.update({f'{name}_{aux_idx}': w for name, w in base_weights.items()})
# num_classes must match the model's classifier (num_classes + 1 logits incl. "no object")
criterion = Criterion(num_classes=detrConfig.num_classes, eos_coef=0.1, losses=['labels', 'boxes'])

batch = 3
iterarions = 2
begin = time.time()

def genData():
    """Return ``batch`` random 3-channel images.

    The first ``batch - 1`` images have a random height in [350, 500) and width in
    [200, 500); the last one is always 384x512.
    """
    images = []
    for _ in range(batch - 1):
        height = torch.randint(350, 500, (1,)).item()
        width = torch.randint(200, 500, (1,)).item()
        images.append(torch.rand(3, height, width))
    images.append(torch.rand(3, 384, 512))
    return images

# Training (smoke test on random images and random targets)
for iterarion in range(iterarions):
    optimizer.zero_grad()
    # 1-99 random cxcywh boxes per image. Labels are drawn from the real classes
    # 0..num_classes-1 only: index num_classes is the reserved "no object" class.
    targets = [{'boxes': torch.rand(n.item(), 4),
                'labels': torch.randint(0, criterion.num_classes, (n.item(),))}
               for n in torch.randint(1, 100, (batch,))]

    outputs = model(genData())

    loss_dict = criterion(outputs, targets)
    # weighted sum; losses without an entry in weight_dict are ignored
    losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

    if iterarion % 1 == 0:
        print('Iterarion:', '%2d,' % (iterarion + 1), 'loss =', '{:.4f}'.format(losses))

    losses.backward()
    optimizer.step()
print(f"{(time.time() - begin)/iterarions:.4f}s / iterarion")

Iterarion:  1, loss = 66.0089
Iterarion:  2, loss = 66.7258
5.6567s / iterarion
