In [1]:
# Clone the official DETR repository; later cells import helpers from `detr.util`.
!git clone https://github.com/facebookresearch/detr.git

Cloning into 'detr'...
remote: Enumerating objects: 260, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 260 (delta 0), reused 2 (delta 0), pack-reused 257[K
Receiving objects: 100% (260/260), 12.85 MiB | 5.47 MiB/s, done.
Resolving deltas: 100% (140/140), done.


## 0. Positional Encoding

Various positional encodings for the transformer.

논문에서는 sin-cos 함수를 이용한 fixed positional encoding 활용

In [2]:
import math
import torch
from torch import nn

from detr.util.misc import NestedTensor

In [3]:
class PositionEmbeddingSine(nn.Module):
    """Fixed (non-learned) 2-D sine/cosine positional encoding.

    Args:
        num_pos_feats: number of encoding channels produced per spatial axis;
            the output has ``2 * num_pos_feats`` channels (y part, then x part).
        temperature: base of the geometric frequency progression.
        normalize: when true, positions are rescaled to ``[0, scale]`` before encoding.
        scale: range used when ``normalize`` is true; defaults to ``2 * pi``.

    Raises:
        ValueError: if ``scale`` is given while ``normalize`` is ``False``.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, tensor_list: NestedTensor):
        """Encode every position of the (rank-3) padding mask carried by ``tensor_list``.

        ``tensor_list`` bundles an image tensor (``.tensors``, used only for its
        device) with a boolean padding mask (``.mask``). Returns a tensor whose
        dim 1 holds the ``2 * num_pos_feats`` encoding channels.
        """
        images = tensor_list.tensors
        padding_mask = tensor_list.mask
        assert padding_mask is not None
        valid = ~padding_mask

        # Running count of non-masked entries along dim 1 (y) and dim 2 (x).
        row_pos = valid.cumsum(1, dtype=torch.float32)
        col_pos = valid.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6  # guards against division by zero
            row_pos = row_pos / (row_pos[:, -1:, :] + eps) * self.scale
            col_pos = col_pos / (col_pos[:, :, -1:] + eps) * self.scale

        # Channels 2k and 2k+1 share one frequency.
        freq_idx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=images.device)
        denom = self.temperature ** (2 * (freq_idx // 2) / self.num_pos_feats)

        def _interleave_sin_cos(coords):
            # sin on even channels, cos on odd channels, interleaved back together
            angles = coords[:, :, :, None] / denom
            pairs = torch.stack((angles[:, :, :, 0::2].sin(), angles[:, :, :, 1::2].cos()), dim=4)
            return pairs.flatten(3)

        enc_x = _interleave_sin_cos(col_pos)
        enc_y = _interleave_sin_cos(row_pos)

        # y encoding first, then x; move the channel axis to dim 1.
        return torch.cat((enc_y, enc_x), dim=3).permute(0, 3, 1, 2)


In [4]:
# 256 // 2 = 128 encoding channels per spatial axis; forward() concatenates the
# y and x encodings, so the resulting positional encoding has 256 channels.
position_embedding = PositionEmbeddingSine(256 // 2, normalize=True)

## 1. Backbone

In [5]:
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List

from detr.util.misc import NestedTensor, is_main_process

In [6]:
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.
    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.

    All four tensors are registered as buffers (not parameters), so they are
    saved/loaded with the state dict but never receive gradients.
    """

    def __init__(self, n):
        super().__init__()
        # Identity transform by default: weight=1, bias=0, mean=0, var=1.
        buffer_specs = (
            ("weight", torch.ones),
            ("bias", torch.zeros),
            ("running_mean", torch.zeros),
            ("running_var", torch.ones),
        )
        for buffer_name, make in buffer_specs:
            self.register_buffer(buffer_name, make(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # This module has no `num_batches_tracked` buffer; drop the entry (if any)
        # so the parent loader does not see it as an unexpected key.
        state_dict.pop(prefix + 'num_batches_tracked', None)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        """Apply the frozen per-channel affine transform ``x * scale + bias``."""
        # Reshape everything up front (kept first to stay fuser-friendly);
        # (1, C, 1, 1) broadcasts over the channel axis at dim 1.
        w, b, rv, rm = (
            buf.reshape(1, -1, 1, 1)
            for buf in (self.weight, self.bias, self.running_var, self.running_mean)
        )
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        shift = b - rm * scale

        return x * scale + shift

In [7]:
class BackboneBase(nn.Module):
    """Wraps a backbone and returns its feature maps paired with resized padding masks.

    Args:
        backbone: network whose stages are exposed through ``IntermediateLayerGetter``.
        train_backbone: when false, every backbone parameter is frozen.
        num_channels: stored on ``self.num_channels`` for downstream consumers.
        return_interm_layers: when truthy, ``layer1``..``layer4`` are returned under
            keys "0".."3"; otherwise only ``layer4`` under key "0".
    """

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        # A parameter stays trainable only if the backbone is being trained AND its
        # name mentions layer2, layer3 or layer4; everything else is frozen.
        trainable_stages = ('layer2', 'layer3', 'layer4')
        for param_name, param in backbone.named_parameters():
            in_trainable_stage = any(stage in param_name for stage in trainable_stages)
            if not (train_backbone and in_trainable_stage):
                param.requires_grad_(False)

        if return_interm_layers:
            stage_names = ["layer1", "layer2", "layer3", "layer4"]
        else:
            stage_names = ['layer4']
        return_layers = {stage: str(position) for position, stage in enumerate(stage_names)}

        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        """Run the backbone on ``tensor_list.tensors`` and attach a mask to each feature map."""
        feature_maps = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for key, feat in feature_maps.items():
            pad_mask = tensor_list.mask
            assert pad_mask is not None
            # Resize the padding mask to this feature map's spatial size: add a
            # leading dim for interpolate, convert back to bool, drop the dim again.
            resized = F.interpolate(pad_mask[None].float(), size=feat.shape[-2:])
            out[key] = NestedTensor(feat, resized.to(torch.bool)[0])

        return out


# ResNet-based backbone definition
class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm.

    Args:
        name: attribute name of the constructor in ``torchvision.models``.
        train_backbone: whether backbone parameters may be trained (see BackboneBase).
        return_interm_layers: whether intermediate stages are returned as well.
        dilation: passed as the third entry of ``replace_stride_with_dilation``;
            the first two entries are always ``False``.
    """
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        resnet_factory = getattr(torchvision.models, name)
        # Pretrained weights are requested only when is_main_process() is true.
        resnet = resnet_factory(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(),
            norm_layer=FrozenBatchNorm2d)

        # Channel count reported for the backbone output: 512 for resnet18/34, else 2048.
        if name in ('resnet18', 'resnet34'):
            out_channels = 512
        else:
            out_channels = 2048
        super().__init__(resnet, train_backbone, out_channels, return_interm_layers)


# Backbone output + positional encodings
class Joiner(nn.Sequential):
    """Two-element container: ``self[0]`` is the backbone, ``self[1]`` the positional encoder."""

    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        """Return ``(features, encodings)`` as two parallel lists.

        ``features`` holds the backbone outputs in iteration order; ``encodings``
        holds the positional encoding of each one, cast to the dtype of that
        feature's ``.tensors``.
        """
        backbone_out = self[0](tensor_list)
        out: List[NestedTensor] = list(backbone_out.values())
        pos = [self[1](feature).to(feature.tensors.dtype) for feature in out]

        return out, pos

In [8]:
# backbone build of official code
def build_backbone(args):
    """Assemble backbone + positional encoding from an ``args`` namespace.

    Reads ``args.lr_backbone``, ``args.masks``, ``args.backbone`` and
    ``args.dilation``. The returned ``Joiner`` also carries ``num_channels``
    copied from the backbone.

    NOTE(review): ``build_position_encoding`` is not defined in the visible part
    of this notebook; it must be in scope before this function is called.
    """
    pos_encoder = build_position_encoding(args)
    train_backbone = args.lr_backbone > 0  # train the backbone only with a positive learning rate
    return_interm_layers = args.masks

    resnet = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)

    joined = Joiner(resnet, pos_encoder)
    joined.num_channels = resnet.num_channels

    return joined

In [9]:
backbone_name = 'resnet50'
lr_backbone = 1e-5
train_backbone = lr_backbone > 0  # True
# Both switches are plain booleans. 'store_true' is the name of an argparse
# action, not a value; as a non-empty string it merely happened to be truthy.
return_interm_layers = True  # expose layer1..layer4 feature maps
# NOTE(review): True keeps the previous (truthy) behaviour, i.e. dilation is applied
# in the last ResNet stage. The upstream DETR flag is off unless passed explicitly;
# confirm that the dilated variant is really intended here.
dilation = True

resnet_backbone = Backbone(backbone_name, train_backbone, return_interm_layers, dilation)
backbone = Joiner(resnet_backbone, position_embedding)  # the positional-encoding module is appended
# Take the channel count from the backbone itself instead of hard-coding 2048.
backbone.num_channels = resnet_backbone.num_channels

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

In [None]:
#backbone  # returns layer1-4

## 2. Transformer

* positional encodings are passed into the multi-head attention layers

* extra LN at the end of encoder is removed

* decoder returns a stack of activations from all decoding layers

In [11]:
import copy
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn, Tensor

### 2.1 Encoder 

In [15]:
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")

In [16]:
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep-copied encoder layers with an optional final norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Feed ``src`` through every layer in order and return the final representation.

        ``mask`` is forwarded to each layer as ``src_mask``; ``src_key_padding_mask``
        and ``pos`` are forwarded unchanged.
        """
        encoded = src

        # Each layer consumes the previous layer's output.
        for encoder_layer in self.layers:
            encoded = encoder_layer(encoded, src_mask=mask,
                                    src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is None:
            return encoded
        return self.norm(encoded)

In [17]:
class TransformerEncoderLayer(nn.Module):
    """One encoder block: multi-head self-attention followed by a feed-forward network.

    Args:
        d_model: embedding size used by the attention and the layer norms.
        nhead: number of attention heads.
        dim_feedforward: hidden size of the feed-forward network.
        dropout: dropout probability used in attention, FFN and both residual branches.
        activation: name understood by ``_get_activation_fn``.
        normalize_before: if truthy, LayerNorm is applied before each sub-layer
            (pre-norm, ``forward_pre``); otherwise after the residual addition
            (post-norm, ``forward_post``).
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()

        # Multi-head self-attention
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        # Selects between forward_pre and forward_post in forward().
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add the positional encoding to ``tensor``; return it unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(self, src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        """Post-norm variant: sub-layer -> residual add -> LayerNorm."""
        # query & key carry the positional encoding, the value does not
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)

        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        """Pre-norm variant: LayerNorm -> sub-layer -> residual add."""
        src2 = self.norm1(src)

        # query & key = normalized src + positional encoding; value = normalized src only
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]

        # src is updated by the attention branch, then by the FFN branch
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)

        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Dispatch to the pre-norm or post-norm path according to ``normalize_before``."""
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)

### 2.2 Decoder

Object query를 입력으로 받음 (small fixed # of learned positional embeddings)

  - Self attention 수행

  - Encoder output과 attention 수행 (pair-wise relationship 학습, 이미지에 대한 global context 반영)

In [None]:
class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` deep-copied decoder layers.

    Args:
        decoder_layer: layer that is deep-copied ``num_layers`` times.
        num_layers: number of copies.
        norm: optional module applied to the output (and to every intermediate
            output when ``return_intermediate`` is set).
        return_intermediate: if true, ``forward`` returns the stacked outputs of
            all layers instead of only the last one.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Run ``tgt`` (object queries) through all layers, attending to ``memory``.

        ``pos`` and ``query_pos`` are handed to every layer unchanged.

        Returns:
            With ``return_intermediate``: ``torch.stack`` of every layer's
            (normalized, if ``norm`` is set) output, one entry per layer.
            Otherwise: the final output with a new leading dimension of size 1.
        """
        output = tgt
        intermediate = []

        # Each layer refines the object queries produced by the previous one.
        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)

            if self.return_intermediate:
                # Without a norm module the raw layer output is collected
                # (calling a None norm used to raise TypeError here).
                intermediate.append(output if self.norm is None else self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                # Replace the last collected entry with the final normalized output.
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)

In [None]:
class TransformerDecoderLayer(nn.Module):
    """One decoder block: query self-attention, cross-attention to the encoder output, FFN.

    Args:
        d_model: embedding size used by both attention modules and the layer norms.
        nhead: number of attention heads.
        dim_feedforward: hidden size of the feed-forward network.
        dropout: dropout probability used in attention, FFN and the residual branches.
        activation: name understood by ``_get_activation_fn``.
        normalize_before: if truthy, LayerNorm is applied before each sub-layer
            (pre-norm, ``forward_pre``); otherwise after the residual addition
            (post-norm, ``forward_post``).
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()

        # Multi-head attention
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)  # self-attention
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)  # cross-attention

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        # Selects between forward_pre and forward_post in forward().
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add the positional encoding to ``tensor``; return it unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm variant: sub-layer -> residual add -> LayerNorm."""
        # self-attention among the object queries
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # cross-attention: queries attend to the encoder output
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)

        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm variant: LayerNorm -> sub-layer -> residual add."""
        # normalize the object queries first
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)  # object query + its positional embedding

        # self-attention among the object queries
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)

        # cross-attention: queries attend to the encoder output
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),  # query + query_pos
                                   key=self.with_pos_embed(memory, pos),  # encoder output + pos
                                   value=memory, attn_mask=memory_mask,  # encoder output without pos
                                   key_padding_mask=memory_key_padding_mask)[0]

        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)

        return tgt  # output embedding

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to the pre-norm or post-norm path according to ``normalize_before``."""
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)

### 2.3 Transformer Network

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder transformer operating on a flattened feature map.

    Args:
        d_model: embedding size shared by encoder and decoder.
        nhead: number of attention heads.
        num_encoder_layers / num_decoder_layers: depth of each stack.
        dim_feedforward: hidden size of the feed-forward networks.
        dropout: dropout probability passed to every layer.
        activation: activation name passed to every layer.
        normalize_before: passed to every layer; when truthy the encoder also
            gets a final LayerNorm.
        return_intermediate_dec: passed to the decoder as ``return_intermediate``.
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        # Encoder: the final norm exists only in the pre-norm configuration.
        enc_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                            dropout, activation, normalize_before)
        final_enc_norm = None
        if normalize_before:
            final_enc_norm = nn.LayerNorm(d_model)
        self.encoder = TransformerEncoder(enc_layer, num_encoder_layers, final_enc_norm)

        # Decoder: always ends with a LayerNorm.
        dec_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                            dropout, activation, normalize_before)
        final_dec_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(dec_layer, num_decoder_layers, final_dec_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        for weight in filter(lambda p: p.dim() > 1, self.parameters()):
            nn.init.xavier_uniform_(weight)

    def forward(self, src, mask, query_embed, pos_embed):
        """Encode the feature map and decode the object queries.

        Args:
            src: backbone feature map, NxCxHxW.
            mask: padding mask, flattened from dim 1 onward before use.
            query_embed: object-query embeddings; repeated once per batch element.
            pos_embed: positional encoding, flattened the same way as ``src``.

        Returns:
            ``(hs, memory)``: the decoder output with dims 1 and 2 swapped, and
            the encoder output reshaped back to NxCxHxW.
        """
        batch, channels, height, width = src.shape

        # flatten NxCxHxW to HWxNxC
        tokens = src.flatten(2).permute(2, 0, 1)
        pos_tokens = pos_embed.flatten(2).permute(2, 0, 1)
        queries = query_embed.unsqueeze(1).repeat(1, batch, 1)
        pad_mask = mask.flatten(1)

        # The decoder input starts at zero; the queries enter as query_pos.
        decoder_in = torch.zeros_like(queries)
        memory = self.encoder(tokens, src_key_padding_mask=pad_mask, pos=pos_tokens)
        hs = self.decoder(decoder_in, memory, memory_key_padding_mask=pad_mask,
                          pos=pos_tokens, query_pos=queries)

        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(batch, channels, height, width)

In [None]:
# Transformer Parameters
d_model = hidden_dim = 256  # backbone output (2048 ch) is reduced to 256 ch by a 1x1 conv
dropout = 0.1
nhead = 8
dim_feedforward = 2048  # intermediate size of the FFN inside each transformer layer
num_encoder_layers = enc_layers = 6
num_decoder_layers = dec_layers = 6
# Plain boolean: 'store_true' is the name of an argparse action, not a value
# (as a non-empty string it was merely truthy).
# NOTE(review): True keeps that truthy meaning (pre-norm). Upstream DETR exposes this
# as the `--pre_norm` flag, which is off unless passed; confirm pre-norm is intended.
normalize_before = True
return_intermediate_dec = True

In [None]:
# Build the transformer from the hyper-parameters defined in the previous cell.
# d_model must be passed explicitly: the class default (512) does not match the
# 256-channel positional encoding that is added to the features inside attention.
# NOTE: this rebinds the name `Transformer` from the class to the instance (a later
# cell reads it as `transformer = Transformer`), so this cell cannot be re-run
# without first re-running the class definition.
Transformer = Transformer(d_model=d_model,
                          nhead=nhead,
                          num_encoder_layers=num_encoder_layers,
                          num_decoder_layers=num_decoder_layers,
                          dim_feedforward=dim_feedforward,
                          dropout=dropout,
                          normalize_before=normalize_before,
                          return_intermediate_dec=return_intermediate_dec)

## 3. Hungarian Matcher

Modules to compute the matching cost and solve the corresponding LSAP.

In [None]:
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn

from detr.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

In [None]:
class HungarianMatcher(nn.Module):
    """Computes a one-to-one assignment between the network predictions and the targets.

    For efficiency the targets do not include the no-object class, so in general
    there are more predictions than targets. The best predictions are matched
    1-to-1 and the remaining ones stay un-matched (and are thus treated as
    non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
        """
        Params: weights of the individual terms that make up the matching cost
            cost_class : relative weight of the classification error
            cost_bbox : relative weight of the L1 error of the bounding box coordinates
            cost_giou : relative weight of the giou loss of the bounding box
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert any(weight != 0 for weight in (cost_class, cost_bbox, cost_giou)), "all costs cant be 0"

    @torch.no_grad()
    def forward(self, outputs, targets):
        '''Params:
               outputs: output of the DETR network
                   "pred_logits": classification logits - [batch_size, num_queries, num_classes]
                   "pred_boxes": predicted box coordinates - [batch_size, num_queries, 4]

               targets: GT targets (len(targets) = batch_size)
                   "labels": class labels - [num_target_boxes]
                   "boxes": target box coordinates - [num_target_boxes, 4]

            Returns:
                A list of size batch_size, containing tuples of (index_i, index_j) where:
                    - index_i is the indices of the selected predictions (in order)
                    - index_j is the indices of the corresponding selected targets (in order)
                For each batch element, it holds:
                    len(index_i) = len(index_j) = min(num_queries, num_target_boxes)'''

        batch_size, num_queries = outputs["pred_logits"].shape[:2]

        # Flatten batch and query dims so all costs are computed in one shot.
        probs = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        pred_boxes = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Targets of the whole batch, concatenated.
        gt_labels = torch.cat([target["labels"] for target in targets])
        gt_boxes = torch.cat([target["boxes"] for target in targets])

        # Classification cost: -proba[target class] approximates the NLL; the
        # constant 1 of "1 - proba" does not change the matching and is omitted.
        class_term = -probs[:, gt_labels]

        # L1 cost between boxes
        l1_term = torch.cdist(pred_boxes, gt_boxes, p=1)

        # GIoU cost between boxes
        giou_term = -generalized_box_iou(box_cxcywh_to_xyxy(pred_boxes), box_cxcywh_to_xyxy(gt_boxes))

        # Final cost matrix, one [num_queries, total_targets] slab per image
        cost = self.cost_bbox * l1_term + self.cost_class * class_term + self.cost_giou * giou_term
        cost = cost.view(batch_size, num_queries, -1).cpu()

        # Split the target axis per image and solve each image's own assignment.
        boxes_per_image = [len(target["boxes"]) for target in targets]
        matches = []
        for image_idx, cost_chunk in enumerate(cost.split(boxes_per_image, -1)):
            pred_idx, tgt_idx = linear_sum_assignment(cost_chunk[image_idx])
            matches.append((torch.as_tensor(pred_idx, dtype=torch.int64),
                            torch.as_tensor(tgt_idx, dtype=torch.int64)))

        return matches

In [None]:
# Relative weights of the three terms of the matching cost; passed to
# HungarianMatcher as cost_class / cost_bbox / cost_giou in the next cell.
set_cost_class = 1  # classification term
set_cost_bbox = 5  # L1 box-coordinate term
set_cost_giou = 2  # GIoU term

In [None]:
# Instantiate the matcher with the weights defined above.
# NOTE: this rebinds the name `HungarianMatcher` from the class to the instance, so
# re-running this cell without first re-running the class definition calls the
# instance instead of the class.
HungarianMatcher = HungarianMatcher(cost_class=set_cost_class, cost_bbox=set_cost_bbox, cost_giou=set_cost_giou)

## 4. DETR

DETR module that performs object detection (bbox reg + class label pred)

In [None]:
import torch
import torch.nn.functional as F
from torch import nn

from detr.util import box_ops
from detr.util.misc import (NestedTensor, nested_tensor_from_tensor_list,
                            accuracy, get_world_size, interpolate,
                            is_dist_avail_and_initialized)

### 4.1 Network

In [None]:
class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN).

    Builds a chain of linear layers input_dim -> hidden_dim -> ... -> output_dim;
    ReLU follows every layer except the last one.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Consecutive pairs of this list are the (in, out) sizes of each layer.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )

    def forward(self, x):
        """Apply the layers in order, with ReLU between them (none after the last)."""
        last_index = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            if index < last_index:
                x = F.relu(x)
        return x

In [None]:
class DETR(nn.Module):
    """DETR detection head: backbone -> transformer -> class / box prediction heads."""

    def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):
        """
        Parameters:
            backbone: 1. Backbone (must expose ``num_channels``)
            transformer: 2. Transformer (must expose ``d_model``)
            num_classes: # of object classes
            num_queries: # of object queries, ie detection slot. (recommend=100)

            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model

        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)  # +1 for the no-object class
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)  # 3-layer MLP producing 4 box values

        self.query_embed = nn.Embedding(num_queries, hidden_dim)  # object query embedding
        # 1x1 conv projecting backbone channels to the transformer width
        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.backbone = backbone
        self.aux_loss = aux_loss  # default False: only the last decoder layer is supervised

    def forward(self, samples: NestedTensor):
        """ NestedTensor :
               - samples.tensors: minibatch - [batch_size x 3 x H x W]
               - samples.mask: mask - [batch_size x H x W]

            Return :
               - "pred_logits": the classification logits (including no-object) for all object queries.
                                Shape= [batch_size x num_queries x (num_classes + 1)]
               - "pred_boxes": The normalized boxes coordinates for all object queries -
                               (center_x, center_y, width, height), values in [0, 1].
                               A post-processing step is needed to get absolute coordinates.
               - "aux_outputs": only when aux_loss is set; a list of dicts with the two
                                keys above, one per decoder layer except the last.
        """
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)  # image batch with padding mask

        # 1. Backbone
        features, pos = self.backbone(samples)

        # 2. Transformer (uses only the last feature level)
        src, mask = features[-1].decompose()
        assert mask is not None
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]  # output embedding

        # 3. Prediction FFNs
        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)

        return out

    def _set_aux_loss(self, outputs_class, outputs_coord):
        """Pair up the predictions of every decoder layer except the last one.

        The last layer is already reported under the top-level keys of the
        output dict, so it is left out here.
        """
        return [{'pred_logits': logits, 'pred_boxes': boxes}
                for logits, boxes in zip(outputs_class[:-1], outputs_coord[:-1])]

### 4.2 Loss

In [None]:
class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        # Per-class cross-entropy weights: 1 everywhere, eos_coef for the trailing no-object class.
        class_weights = torch.ones(self.num_classes + 1)
        class_weights[-1] = self.eos_coef
        self.register_buffer('empty_weight', class_weights)

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        logits = outputs['pred_logits']

        matched = self._get_src_permutation_idx(indices)
        matched_labels = torch.cat([tgt["labels"][tgt_ids] for tgt, (_, tgt_ids) in zip(targets, indices)])
        # Every query starts as no-object (index num_classes); matched ones get their target label.
        labels = torch.full(logits.shape[:2], self.num_classes,
                            dtype=torch.int64, device=logits.device)
        labels[matched] = matched_labels

        losses = {'loss_ce': F.cross_entropy(logits.transpose(1, 2), labels, self.empty_weight)}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(logits[matched], matched_labels)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        logits = outputs['pred_logits']
        num_targets = torch.as_tensor([len(tgt["labels"]) for tgt in targets], device=logits.device)
        # Count the predictions whose argmax is NOT "no-object" (the last class).
        no_object_class = logits.shape[-1] - 1
        num_predicted = (logits.argmax(-1) != no_object_class).sum(1)
        return {'cardinality_error': F.l1_loss(num_predicted.float(), num_targets.float())}

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        matched = self._get_src_permutation_idx(indices)
        pred_boxes = outputs['pred_boxes'][matched]
        gt_boxes = torch.cat([tgt['boxes'][tgt_ids] for tgt, (_, tgt_ids) in zip(targets, indices)], dim=0)

        l1_per_coord = F.l1_loss(pred_boxes, gt_boxes, reduction='none')
        # Only the diagonal (prediction i vs. its own target i) of the pairwise GIoU matrix is needed.
        pairwise_giou = box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(pred_boxes),
            box_ops.box_cxcywh_to_xyxy(gt_boxes))

        return {
            'loss_bbox': l1_per_coord.sum() / num_boxes,
            'loss_giou': (1 - torch.diag(pairwise_giou)).sum() / num_boxes,
        }

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the masks: the focal loss and the dice loss.
           targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]

           NOTE(review): `sigmoid_focal_loss` and `dice_loss` are not defined in the
           visible part of this notebook; they must be in scope before 'masks' is requested.
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        pred_masks = outputs["pred_masks"][src_idx]
        # TODO use valid to mask invalid areas due to padding in loss
        gt_masks, valid = nested_tensor_from_tensor_list([tgt["masks"] for tgt in targets]).decompose()
        gt_masks = gt_masks.to(pred_masks)[tgt_idx]

        # upsample predictions to the target size
        pred_masks = interpolate(pred_masks[:, None], size=gt_masks.shape[-2:],
                                 mode="bilinear", align_corners=False)
        pred_masks = pred_masks[:, 0].flatten(1)

        gt_masks = gt_masks.flatten(1).view(pred_masks.shape)
        return {
            "loss_mask": sigmoid_focal_loss(pred_masks, gt_masks, num_boxes),
            "loss_dice": dice_loss(pred_masks, gt_masks, num_boxes),
        }

    def _get_src_permutation_idx(self, indices):
        """Return (batch_idx, query_idx) index tensors selecting the matched predictions."""
        batch_parts, query_parts = [], []
        for batch_i, (src, _) in enumerate(indices):
            batch_parts.append(torch.full_like(src, batch_i))
            query_parts.append(src)
        return torch.cat(batch_parts), torch.cat(query_parts)

    def _get_tgt_permutation_idx(self, indices):
        """Return (batch_idx, target_idx) index tensors selecting the matched targets."""
        batch_parts, target_parts = [], []
        for batch_i, (_, tgt) in enumerate(indices):
            batch_parts.append(torch.full_like(tgt, batch_i))
            target_parts.append(tgt)
        return torch.cat(batch_parts), torch.cat(target_parts)

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        """Dispatch to the loss function registered under the name ``loss``."""
        loss_map = dict(
            labels=self.loss_labels,
            cardinality=self.loss_cardinality,
            boxes=self.loss_boxes,
            masks=self.loss_masks,
        )
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        loss_fn = loss_map[loss]
        return loss_fn(outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        final_outputs = {key: value for key, value in outputs.items() if key != 'aux_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(final_outputs, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(tgt["labels"]) for tgt in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        for layer_idx, aux_outputs in enumerate(outputs.get('aux_outputs', ())):
            aux_indices = self.matcher(aux_outputs, targets)
            for loss in self.losses:
                if loss == 'masks':
                    # Intermediate masks losses are too costly to compute, we ignore them.
                    continue
                # Logging is enabled only for the last layer
                extra = {'log': False} if loss == 'labels' else {}
                layer_losses = self.get_loss(loss, aux_outputs, targets, aux_indices, num_boxes, **extra)
                # Suffix the keys with the layer index so they do not collide.
                losses.update({f'{name}_{layer_idx}': value for name, value in layer_losses.items()})

        return losses

### 4.3 PostProcess

In [None]:
class PostProcess(nn.Module):
    """Turn raw model outputs into the per-image result format expected by the COCO API."""

    @torch.no_grad()
    def forward(self, outputs, target_sizes):
        """Convert predictions into scored, labelled boxes in absolute coordinates.

        Parameters:
            outputs: raw outputs of the model; the 'pred_logits' and 'pred_boxes' entries are read
            target_sizes: tensor of dimension [batch_size x 2] holding (height, width) for each image
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augment, but before padding
        Returns:
            A list with one dict per image, each with the keys 'scores', 'labels' and 'boxes'.
        """
        pred_logits = outputs['pred_logits']
        pred_boxes = outputs['pred_boxes']

        assert len(pred_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        # Best score / label per query; the last logit column is excluded from the max.
        class_probs = F.softmax(pred_logits, -1)
        scores, labels = class_probs[..., :-1].max(-1)

        # (cx, cy, w, h) -> (x0, y0, x1, y1)
        corner_boxes = box_ops.box_cxcywh_to_xyxy(pred_boxes)

        # Scale x coordinates by the image width and y coordinates by the image height.
        heights, widths = target_sizes.unbind(1)
        per_image_scale = torch.stack([widths, heights, widths, heights], dim=1)
        scaled_boxes = corner_boxes * per_image_scale[:, None, :]

        results = []
        for img_scores, img_labels, img_boxes in zip(scores, labels, scaled_boxes):
            results.append({'scores': img_scores, 'labels': img_labels, 'boxes': img_boxes})

        return results

## 5. Build DETR

### 5.1 Setting

In [None]:
num_classes = 90 + 1  # COCO
# Use the GPU when one is present; otherwise fall back to the CPU so that the
# later `.to(device)` calls do not fail on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# `backbone` is expected to have been built in an earlier cell. The former
# `backbone = backbone` self-assignment was a no-op and has been dropped.

# The model needs a Transformer *instance*: DETRsegm reads
# `detr.transformer.d_model` / `detr.transformer.nhead`, which the bare class
# does not provide.
# NOTE(review): d_model=256 follows the "# 256, 8" annotations used elsewhere in
# this notebook and assumes the Transformer class defined above accepts it as a
# keyword — confirm against that cell and against the positional-encoding width
# used when `backbone` was built.
transformer = Transformer(d_model=256)

model = DETR(backbone,
             transformer,
             num_classes=num_classes,  # reuse the variable instead of repeating the literal 91
             num_queries=100,  # number of object queries
             aux_loss=False)

# SetCriterion.forward calls `self.matcher(outputs, targets)` and expects the
# matching indices back, so an instance (not the class) is required here.
# NOTE(review): assumes HungarianMatcher's constructor arguments all have
# defaults — confirm against the cell that defines it.
matcher = HungarianMatcher()

### 5.2 Loss

In [None]:
# Relative weight of each individual loss term
weight_dict = {
    'loss_ce': 1,
    'loss_bbox': 5,
    'loss_giou': 2,
}

# Names of the loss components the criterion is asked to compute
losses = ['labels', 'boxes', 'cardinality']

criterion = SetCriterion(
    num_classes,
    matcher=matcher,
    weight_dict=weight_dict,
    eos_coef=0.1,
    losses=losses,
)
criterion.to(device)

postprocessors = {'bbox': PostProcess()}

### 5.3 Main.py

In [None]:
import argparse
import datetime
import json
import random
import time
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler

import detr.datasets  # detr/datasets/coco.py에서 import detr.datasets.transforms as T로 수정
import detr.util.misc as utils  # datasets/transforms.py에서 util 앞에 detr. 붙이기
from detr.datasets import build_dataset, get_coco_api_from_dataset
from detr.engine import evaluate, train_one_epoch  # util/engine.py, util/eval.py, panoptic_eval.py에서 turil, datasets앞에 detr. 붙이기
#from detr.models import build_model

In [None]:
# Move the model to the selected device
model.to(device)

# Count only the parameters that will receive gradients
n_parameters = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
print('number of params:', n_parameters)

In [None]:
# Trainable parameters, split by whether their name contains "backbone":
# the backbone group gets its own, smaller learning rate.
trainable_named_params = [
    (name, param) for name, param in model.named_parameters() if param.requires_grad
]
param_dicts = [
    {"params": [param for name, param in trainable_named_params if "backbone" not in name]},
    {
        "params": [param for name, param in trainable_named_params if "backbone" in name],
        "lr": 1e-5,
    },
]

optimizer = torch.optim.AdamW(param_dicts, lr=1e-4, weight_decay=1e-4)

# Step the learning rate every 200 scheduler steps
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200)

## Appendix. DETR Segmentation

## A.1 Segmentation

Convolutional heads used to predict masks, as well as the associated losses.

In [None]:
import io
from collections import defaultdict
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from PIL import Image

import detr.util.box_ops as box_ops
from detr.util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list

try:
    from panopticapi.utils import id2rgb, rgb2id
except ImportError:
    pass

### A.1.1 Multi Head Attention Map

2D attention module, which only returns the attention softmax (no multiplication by value)

In [None]:
class MHAttentionMap(nn.Module):
    """2D multi-head attention that returns only the attention softmax.

    The attention weights are not multiplied by any value tensor; the module
    just scores every spatial location of `k` against every query in `q`.
    """

    def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True):
        """
        Parameters:
            query_dim: size of the last dimension of `q` and number of channels of `k`
            hidden_dim: size of the projected query/key space, split evenly across the heads
            num_heads: number of attention heads
            dropout: dropout probability applied to the attention weights
            bias: whether the query/key projections carry a bias term
        """
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

        # query, key embedding
        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)

        # With bias=False nn.Linear leaves `.bias` as None, and nn.init.zeros_
        # cannot be applied to None, so the bias init is guarded.
        if self.k_linear.bias is not None:
            nn.init.zeros_(self.k_linear.bias)
        if self.q_linear.bias is not None:
            nn.init.zeros_(self.q_linear.bias)
        nn.init.xavier_uniform_(self.k_linear.weight)
        nn.init.xavier_uniform_(self.q_linear.weight)
        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5

    def forward(self, q, k, mask: Optional[Tensor] = None):
        """Compute the attention weights of each query over the spatial map `k`.

        Parameters:
            q: queries, viewed below as [batch, num_queries, query_dim]
            k: key feature map, viewed below as [batch, query_dim, height, width]
            mask: optional boolean tensor; positions set to True are excluded.
                  It is unsqueezed twice at dim 1, so it is presumably
                  [batch, height, width] — confirm against the caller.
        Returns:
            Tensor of shape [batch, num_queries, num_heads, height, width]. For each
            (batch, query) pair the weights sum to 1 over all heads and locations
            (before dropout).
        """
        q = self.q_linear(q)
        # Apply the key projection as a 1x1 convolution so the spatial layout is kept.
        k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)

        # Scaled dot product, one head at a time
        head_dim = self.hidden_dim // self.num_heads
        qh = q.view(q.shape[0], q.shape[1], self.num_heads, head_dim)
        kh = k.view(k.shape[0], self.num_heads, head_dim, k.shape[-2], k.shape[-1])
        weights = torch.einsum("bqnc, bnchw -> bqnhw", qh * self.normalize_fact, kh)

        if mask is not None:
            # Masked locations get -inf so that they receive zero weight after the softmax.
            weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))

        # Joint softmax over heads and spatial positions
        weights = F.softmax(weights.flatten(2), dim=-1).view(weights.size())
        weights = self.dropout(weights)

        return weights

### A.2 Multi Head Small Conv

Segmentation mask 예측

Simple convolutional head, using group norm. Upsampling is done using an FPN approach.

In [None]:
def _expand(tensor, length: int):
    return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)


class MaskHeadSmallConv(nn.Module):
    """Small convolutional mask head using group norm and FPN-style upsampling."""

    def __init__(self, dim, fpn_dims, context_dim):
        """
        Parameters:
            dim: number of input channels (DETRsegm passes hidden_dim + nheads, e.g. 256 + 8)
            fpn_dims: channel counts of the three feature maps given to `forward` as `fpns`
                      (DETRsegm passes [1024, 512, 256])
            context_dim: base width from which the intermediate channel counts are derived
                         (DETRsegm passes hidden_dim, e.g. 256)
        """
        super().__init__()

        # With dim=264 and context_dim=256 this is [264, 128, 64, 32, 16, 4].
        inter_dims = [dim] + [context_dim // div for div in (2, 4, 8, 16, 64)]

        # Refinement path: 3x3 conv + group norm, narrowing the channels step by step.
        self.lay1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        self.gn1 = nn.GroupNorm(8, dim)

        self.lay2 = nn.Conv2d(dim, inter_dims[1], kernel_size=3, padding=1)
        self.gn2 = nn.GroupNorm(8, inter_dims[1])

        self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], kernel_size=3, padding=1)
        self.gn3 = nn.GroupNorm(8, inter_dims[2])

        self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], kernel_size=3, padding=1)
        self.gn4 = nn.GroupNorm(8, inter_dims[3])

        self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], kernel_size=3, padding=1)
        self.gn5 = nn.GroupNorm(8, inter_dims[4])

        # Final projection to a single output channel.
        self.out_lay = nn.Conv2d(inter_dims[4], 1, kernel_size=3, padding=1)

        self.dim = dim

        # 1x1 adapters bringing each FPN input to the width of the stage it is added to.
        self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], kernel_size=1)
        self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], kernel_size=1)
        self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], kernel_size=1)

        # Re-initialise every convolution: Kaiming-uniform weights, zero bias.
        for conv in (m for m in self.modules() if isinstance(m, nn.Conv2d)):
            nn.init.kaiming_uniform_(conv.weight, a=1)
            nn.init.constant_(conv.bias, 0)

    def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]):
        """Predict one single-channel map per (image, query) pair.

        Parameters:
            x: image feature map; it is repeated `bbox_mask.shape[1]` times along the
               batch dimension (DETRsegm passes the projected backbone feature)
            bbox_mask: per-query attention maps (DETRsegm passes the MHAttentionMap
                       output); its first two dimensions are flattened together and the
                       result is concatenated with `x` along the channel dimension
            fpns: three feature maps, indexed 0..2, fused in that order; each sets the
                  spatial size the running tensor is upsampled to
        Returns:
            Tensor with a single channel whose spatial size is that of `fpns[2]`.
        """
        per_query_feats = _expand(x, bbox_mask.shape[1])
        x = torch.cat([per_query_feats, bbox_mask.flatten(0, 1)], 1)

        x = F.relu(self.gn1(self.lay1(x)))
        x = F.relu(self.gn2(self.lay2(x)))

        # Each stage: adapt the FPN input, upsample the running tensor to its size,
        # add the two, then refine with conv + group norm + ReLU.
        stages = (
            (self.adapter1, self.lay3, self.gn3),
            (self.adapter2, self.lay4, self.gn4),
            (self.adapter3, self.lay5, self.gn5),
        )
        for level, (adapter, conv, norm) in enumerate(stages):
            lateral = adapter(fpns[level])
            if lateral.size(0) != x.size(0):
                # The FPN input has one entry per image; repeat it to match the per-query batch.
                lateral = _expand(lateral, x.size(0) // lateral.size(0))
            x = lateral + F.interpolate(x, size=lateral.shape[-2:], mode="nearest")
            x = F.relu(norm(conv(x)))

        return self.out_lay(x)

### A.3 DETR Segmentation

In [None]:
class DETRsegm(nn.Module):
    """Wraps a DETR detector and adds an attention-based mask head on top of it."""

    def __init__(self, detr, freeze_detr=False):
        """
        Parameters:
            detr: the detection model to wrap
            freeze_detr: if True, gradients are disabled for the wrapped detector's parameters
        """
        super().__init__()
        self.detr = detr

        if freeze_detr:
            # Only `self.detr` is registered so far, so this touches the detector's
            # parameters and leaves the mask-specific modules created below trainable.
            for param in self.parameters():
                param.requires_grad_(False)

        hidden_dim = detr.transformer.d_model  # e.g. 256
        nheads = detr.transformer.nhead  # e.g. 8
        self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0)
        # Arguments: dim, fpn_dims, context_dim
        self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim)

    def forward(self, samples: NestedTensor):
        """Run detection and mask prediction.

        Returns a dict with 'pred_logits', 'pred_boxes' and 'pred_masks', plus
        'aux_outputs' when the wrapped detector has `aux_loss` enabled.
        """
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)

        # 1. Backbone
        features, pos = self.detr.backbone(samples)

        last_level = features[-1]
        bs = last_level.tensors.shape[0]
        src, mask = last_level.decompose()  # image feature, mask
        assert mask is not None

        # 2. Transformer
        src_proj = self.detr.input_proj(src)
        hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1])

        # 3. Class and box prediction heads
        outputs_class = self.detr.class_embed(hs)
        outputs_coord = self.detr.bbox_embed(hs).sigmoid()
        out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
        if self.detr.aux_loss:
            out['aux_outputs'] = self.detr._set_aux_loss(outputs_class, outputs_coord)

        # 4. Segmentation
        # FIXME h_boxes takes the last one computed, keep this in mind
        bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)  # (query, key, mask)

        # Backbone feature levels 2, 1 and 0 (in that order) serve as the FPN inputs.
        fpn_inputs = [features[2].tensors, features[1].tensors, features[0].tensors]
        seg_masks = self.mask_head(src_proj, bbox_mask, fpn_inputs)

        # Un-flatten the (batch * queries) dimension produced by the mask head.
        mask_h, mask_w = seg_masks.shape[-2], seg_masks.shape[-1]
        out["pred_masks"] = seg_masks.view(bs, self.detr.num_queries, mask_h, mask_w)

        return out

In [None]:
def dice_loss(inputs, targets, num_boxes):
    """
    DICE loss for masks (plays a role similar to generalized IoU for boxes).

    Args:
        inputs: A float tensor of logits; a sigmoid is applied and everything
                after the first dimension is flattened.
        targets: A float tensor that lines up with the flattened inputs. Stores the
                 binary label for each element
                 (0 for the negative class and 1 for the positive class).
        num_boxes: Normalizer; the summed per-example loss is divided by it.
    Returns:
        The sum of the per-example losses divided by num_boxes.
    """
    probs = inputs.sigmoid().flatten(1)
    overlap = (probs * targets).sum(1)
    total = probs.sum(-1) + targets.sum(-1)
    # The +1 terms smooth the ratio so it stays defined when both sums are zero.
    per_example = 1 - (2 * overlap + 1) / (total + 1)
    return per_example.sum() / num_boxes


def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Focal loss as used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of logits with at least two dimensions.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        num_boxes: Normalizer; the reduced loss is divided by it.
        alpha: Weighting factor in range (0,1) to balance positive vs negative
               examples. Default = 0.25; any negative value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        The loss averaged over dim 1, summed over the rest, and divided by num_boxes.
    """
    prob = inputs.sigmoid()
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

    # Probability the model assigns to the true label of each element
    p_true = prob * targets + (1 - prob) * (1 - targets)
    focal = bce * ((1 - p_true) ** gamma)

    if alpha >= 0:
        class_weight = alpha * targets + (1 - alpha) * (1 - targets)
        focal = class_weight * focal

    return focal.mean(1).sum() / num_boxes


class PostProcessSegm(nn.Module):
    """Attach thresholded, resized binary masks to already post-processed results."""

    def __init__(self, threshold=0.5):
        """
        Parameters:
            threshold: a mask pixel is kept when its sigmoid probability exceeds this value
        """
        super().__init__()
        self.threshold = threshold

    @torch.no_grad()
    def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
        """Fill `results[i]["masks"]` for every image and return `results`.

        Parameters:
            results: list of per-image dicts; each receives a "masks" entry (modified in place)
            outputs: model outputs; only "pred_masks" is read
            orig_target_sizes: per-image (height, width) the final masks are resized to
            max_target_sizes: per-image (height, width) used to crop the upsampled masks;
                              the column-wise maximum sets the common upsampling size
        """
        assert len(orig_target_sizes) == len(max_target_sizes)
        max_h, max_w = max_target_sizes.max(0)[0].tolist()

        mask_logits = outputs["pred_masks"].squeeze(2)
        mask_logits = F.interpolate(mask_logits, size=(max_h, max_w), mode="bilinear", align_corners=False)
        binary_masks = (mask_logits.sigmoid() > self.threshold).cpu()

        per_image = zip(binary_masks, max_target_sizes, orig_target_sizes)
        for i, (cur_mask, crop_size, final_size) in enumerate(per_image):
            img_h, img_w = crop_size[0], crop_size[1]
            # Crop to this image's own extent, then resize to the requested size.
            cropped = cur_mask[:, :img_h, :img_w].unsqueeze(1)
            resized = F.interpolate(cropped.float(), size=tuple(final_size.tolist()), mode="nearest")
            results[i]["masks"] = resized.byte()

        return results


class PostProcessPanoptic(nn.Module):
    """Converts the output of the model to the final panoptic result, in the format expected by the
    coco panoptic API."""

    def __init__(self, is_thing_map, threshold=0.85):
        """
        Parameters:
           is_thing_map: mapping indexed by class id whose values are booleans indicating whether
                         the class is a thing (True) or a stuff (False) class
           threshold: confidence threshold: segments with confidence lower than this will be deleted
        """
        super().__init__()
        self.threshold = threshold
        self.is_thing_map = is_thing_map

    def forward(self, outputs, processed_sizes, target_sizes=None):
        """Compute the panoptic prediction from the model's predictions.

        Parameters:
            outputs: This is a dict coming directly from the model; "pred_logits", "pred_masks" and
                     "pred_boxes" are read.
            processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the
                             model, ie the size after data augmentation but before batching.
            target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size
                          of each prediction. If left to None, it will default to the processed_sizes
        Returns:
            A list with one dict per image, holding "png_string" (the PNG-encoded segmentation image)
            and "segments_info" (one dict per segment with "id", "isthing", "category_id" and "area").
        """
        if target_sizes is None:
            target_sizes = processed_sizes
        assert len(processed_sizes) == len(target_sizes)
        out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"]
        assert len(out_logits) == len(raw_masks) == len(target_sizes)
        preds = []

        # Sizes may arrive as tuples or as tensors; normalise them to plain tuples.
        def to_tuple(tup):
            if isinstance(tup, tuple):
                return tup
            return tuple(tup.cpu().tolist())

        for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
            out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
        ):
            # Filter out queries whose best class is the last one (the "empty" slot) and
            # detections whose score is not above the threshold.
            scores, labels = cur_logits.softmax(-1).max(-1)
            keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold)
            # NOTE(review): this recomputes the same softmax/max as two lines above;
            # cur_scores / cur_classes start out equal to scores / labels.
            cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
            cur_scores = cur_scores[keep]
            cur_classes = cur_classes[keep]
            cur_masks = cur_masks[keep]
            # Resize the kept masks to the size the image had when it was passed to the model.
            cur_masks = interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
            cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep])

            h, w = cur_masks.shape[-2:]
            assert len(cur_boxes) == len(cur_classes)

            # It may be that we have several predicted masks for the same stuff class.
            # In the following, we track the list of masks ids for each stuff class (they are merged later on)
            cur_masks = cur_masks.flatten(1)
            stuff_equiv_classes = defaultdict(lambda: [])
            for k, label in enumerate(cur_classes):
                if not self.is_thing_map[label.item()]:
                    stuff_equiv_classes[label.item()].append(k)

            def get_ids_area(masks, scores, dedup=False):
                # This helper function creates the final panoptic segmentation image
                # It also returns the area of the masks that appears on the image
                # It reads h, w, target_size and stuff_equiv_classes from the enclosing loop iteration.

                # Softmax across the masks for every pixel (masks is [num_masks, h * w] here).
                m_id = masks.transpose(0, 1).softmax(-1)

                if m_id.shape[-1] == 0:
                    # We didn't detect any mask :(
                    m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
                else:
                    # Each pixel is assigned the index of its highest-scoring mask.
                    m_id = m_id.argmax(-1).view(h, w)

                if dedup:
                    # Merge the masks corresponding to the same stuff class
                    # (every id in the group is rewritten to the group's first id)
                    for equiv in stuff_equiv_classes.values():
                        if len(equiv) > 1:
                            for eq_id in equiv:
                                m_id.masked_fill_(m_id.eq(eq_id), equiv[0])

                final_h, final_w = to_tuple(target_size)

                # id2rgb / rgb2id come from the optional panopticapi import at the top of this
                # section; if that import failed, these calls raise NameError.
                seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy()))
                # Nearest-neighbour resampling, so ids are never blended when resizing.
                seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST)

                # Read the resized image back as an id map at the final resolution.
                np_seg_img = (
                    torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy()
                )
                m_id = torch.from_numpy(rgb2id(np_seg_img))

                # Pixel count of every mask id in the final image
                area = []
                for i in range(len(scores)):
                    area.append(m_id.eq(i).sum().item())
                return area, seg_img

            area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
            if cur_classes.numel() > 0:
                # We now filter out empty masks for as long as we find some:
                # dropping a mask changes the per-pixel argmax, so areas are recomputed each round.
                while True:
                    filtered_small = torch.as_tensor(
                        [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device
                    )
                    if filtered_small.any().item():
                        cur_scores = cur_scores[~filtered_small]
                        cur_classes = cur_classes[~filtered_small]
                        cur_masks = cur_masks[~filtered_small]
                        area, seg_img = get_ids_area(cur_masks, cur_scores)
                    else:
                        break

            else:
                # Nothing was kept: `area` is empty here, so this placeholder class is
                # never read by the loop below.
                cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)

            segments_info = []
            for i, a in enumerate(area):
                cat = cur_classes[i].item()
                segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a})
            del cur_classes

            # Serialise the segmentation image as PNG bytes.
            with io.BytesIO() as out:
                seg_img.save(out, format="PNG")
                predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
            preds.append(predictions)
        return preds