In [1]:
from mmdet.models import ResNet, FPN
import torch
import numpy as np
import random


def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True
setup_seed(0)

# out_indices=(0, 1, 2, 3)
backbone = ResNet( depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
#         norm_cfg = dict(type='BN', requires_grad=True),
#         dcn=dict(type='DCN', deform_groups=1, fallback_on_stride=False),
#         stage_with_dcn=(False, False, True, True)
        )
backbone.eval()
inputs = torch.rand(1, 3, 128, 128)
# net.to("cuda")
# inputs = inputs.to("cuda")
backbone_outputs = backbone(inputs)
print("resnet level_outputs:")
for backbone_output in backbone_outputs:
    print(tuple(backbone_output.shape))
neck = FPN(in_channels=[256, 512, 1024, 2048],
             out_channels=128,
             start_level=1,
             end_level=-1,
             num_outs=5,
             init_cfg=dict(type='Xavier', layer='Conv2d', distribution='uniform'))

neck.eval()
neck_outputs = neck(backbone_outputs)
print("resnet level_outputs:")
for neck_output in neck_outputs:
    print(tuple(neck_output.shape))

resnet level_outputs:
(1, 256, 32, 32)
(1, 512, 16, 16)
(1, 1024, 8, 8)
(1, 2048, 4, 4)
resnet level_outputs:
(1, 128, 16, 16)
(1, 128, 8, 8)
(1, 128, 4, 4)
(1, 128, 2, 2)
(1, 128, 1, 1)


In [2]:
from mmdet.models import Res2Net, BiFPN
import torch
import numpy as np
import random


def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True
setup_seed(0)

# out_indices=(0, 1, 2, 3)
backbone = Res2Net( depth=101,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
#         norm_cfg = dict(type='BN', requires_grad=True),
#         dcn=dict(type='DCN', deform_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, False, True, True))
backbone.eval()
inputs = torch.rand(1, 3, 800, 1152)
# net.to("cuda")
# inputs = inputs.to("cuda")
backbone_outputs = backbone(inputs)
print("resnet level_outputs:")
for backbone_output in backbone_outputs:
    print(tuple(backbone_output.shape))
neck = BiFPN(in_channels=[256, 512, 1024, 2048],
             out_channels=128,
             start_level=1,
             end_level=3,
             num_outs=5,
             num_bifpn=1,
             init_cfg=dict(type='Xavier', layer='Conv2d', distribution='uniform'))
neck.init_weights()
neck.eval()
neck_outputs = neck(backbone_outputs)
print("neck level_outputs:")
for neck_output in neck_outputs:
    print(tuple(neck_output.shape))

resnet level_outputs:
(1, 256, 200, 288)
(1, 512, 100, 144)
(1, 1024, 50, 72)
(1, 2048, 25, 36)
neck level_outputs:
(1, 128, 100, 144)
(1, 128, 50, 72)
(1, 128, 25, 36)
(1, 128, 13, 18)
(1, 128, 7, 9)


In [3]:
from mmcv.cnn import ConvModule, Scale
from mmdet.core import multi_apply
import torch.nn as nn
from mmdet.models.builder import HEADS, build_loss
from mmdet.models.dense_heads.base_dense_head import BaseDenseHead
from mmdet.models.dense_heads.dense_test_mixins import BBoxTestMixin

INF = 1e8

class CenterNetV2Head(BaseDenseHead, BBoxTestMixin):
    """Objects as Points Head. CenterV2Head use center_point to indicate object's
    position. Paper link <https://arxiv.org/abs/2103.07461>

    Args:
        in_channels (int): Number of channel in the input feature map.
        feat_channels (int): Number of channel in the intermediate feature map.
        stacked_convs (int):
        strides (list or tuple[int])
        regress_ranges (tuple[tuple[int, int]]):
        dcn_on_last_conv (bool):
        loss_heatmap:
        loss_bbox:
        conv_cfg (dict):
        norm_cfg (dict):
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 in_channels,
                 feat_channels,
                 stacked_convs=4,
                 strides=(8, 16, 32, 64, 128),
                 regress_ranges=((-1, 64), (64, 128), (128, 256),
                                 (256, 512), (512, INF),),
                 dcn_on_last_conv=True,
                 loss_heatmap=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox=dict(type='IoULoss', loss_weight=1.0),
                 conv_cfg=None,
                 norm_cfg=None,
                 init_cfg=None):
        super(CenterNetV2Head, self).__init__()
        self.regress_ranges = regress_ranges
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.regress_ranges = regress_ranges
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dcn_on_last_conv = dcn_on_last_conv
        self._init_layers()
        self.loss_heatmap = build_loss(loss_heatmap)
        self.loss_bbox = build_loss(loss_bbox)

    def _init_layers(self):
        """Initialize layers of the head."""
        self._init_convs()
        self._init_predictor()
        self.scales = nn.ModuleList([Scale(1.0) for _ in self.regress_ranges])

    def _init_convs(self):
        """Build head for each branch."""
        
        for name in ['cls', 'reg']:
            tower = nn.ModuleList()
            for i in range(self.stacked_convs):
                in_channel = self.in_channels if i == 0 else self.feat_channels
                if self.dcn_on_last_conv and i == self.stacked_convs - 1:
                    conv_cfg = dict(type='DCNv2')
                else:
                    conv_cfg = self.conv_cfg
                tower.append(
                    ConvModule(
                        in_channel,
                        self.feat_channels,
                        3,
                        stride=1,
                        padding=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=self.norm_cfg))
            self.add_module(f'{name}_convs', tower)

    def _init_predictor(self):
        # background or foreground, class agnostic heatmap
        self.heatmap_head = nn.Conv2d(self.feat_channels, 1, 3, padding=1)
        self.bbox_head = nn.Conv2d(self.feat_channels, 4, 3, padding=1)

    def forward(self, feats):
        """Forward features.
        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
        Returns:
            center_heatmap_preds (List[Tensor]): center predict heatmaps for
                all levels, class agnostic, which means out channel is 1.
            bbox_preds (List[Tensor]): Box energies / deltas for all scale
            levels, each is a 4D-tensor, the channel number is num_points * 4.
        """
        return multi_apply(self.forward_single, feats)

    def forward_single(self, feat):
        """Forward a single level feature.
        Args:
            feat (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
        Returns:
            center_heatmap_preds (List[Tensor]): center predict heatmaps for
                one level, class agnostic, which means out channel is 1.
            bbox_preds (List[Tensor]): Box energies / deltas for one scale
            level, each is a 4D-tensor, the channel number is num_points * 4.
        """
        center_heatmap_preds = self.heatmap_head(feat).sigmoid()
        bbox_preds = self.bbox_head(feat)
        return center_heatmap_preds, bbox_preds
    
    def get_bboxes():
        NotImplementedError
        
    def get_targets():
        NotImplementedError    
    
    def loss():
        NotImplementedError

In [4]:
rpn = CenterNetV2Head(128,128)
center_heatmap_preds, bbox_preds = rpn(neck_outputs)
for center_heatmap_pred, bbox_pred in zip(center_heatmap_preds, bbox_preds):
    print(center_heatmap_pred.shape, bbox_pred.shape)

torch.Size([1, 1, 100, 144]) torch.Size([1, 4, 100, 144])
torch.Size([1, 1, 50, 72]) torch.Size([1, 4, 50, 72])
torch.Size([1, 1, 25, 36]) torch.Size([1, 4, 25, 36])
torch.Size([1, 1, 13, 18]) torch.Size([1, 4, 13, 18])
torch.Size([1, 1, 7, 9]) torch.Size([1, 4, 7, 9])


In [5]:
img_metas = [{'filename': '/amiintellect/turing/young/competition/mmdetection/works/data/oil/images/qx-307.png', 'ori_filename': 'qx-307.png', 'ori_shape': (495, 701, 3), 'img_shape': (800, 1133, 3), 'pad_shape': (800, 1152, 3), 'scale_factor': np.array([1.6162624, 1.6161616, 1.6162624, 1.6161616], dtype=np.float32), 'flip': False, 'flip_direction': None, 'img_norm_cfg': {'mean': np.array([102.9801, 115.9465, 122.7717], dtype=np.float32), 'std': np.array([1., 1., 1.], dtype=np.float32), 'to_rgb': False}}]
neck_outputs = neck_outputs
featmap_sizes = []
device = 'cpu'
gt_bboxes = [torch.Tensor([[ 408.0027,  315.4001,  611.0715,  372.3802],
        [ 645.2617,  332.2336,  850.4027,  387.9202],
        [ 885.6290,  345.4442, 1093.8781,  402.4242],
        [ 346.2424,  328.8053,  366.4457,  422.1526],
        [ 398.5853,  388.6484,  601.4492,  445.4065],
        [ 637.1537,  403.9441,  840.6491,  461.7406],
        [ 876.7070,  416.5464, 1083.5500,  477.1525],
        [ 390.0223,  460.6703,  598.7449,  520.6074],
        [ 627.0999,  475.6545,  832.0861,  535.1974],
        [ 871.8967,  489.6585, 1077.7776,  547.3785],
        [ 863.1270,  564.4899, 1066.9784,  629.7786],
        [ 209.1294,  639.4113,  306.5333,  675.3246],
        [ 833.9768,  591.5006,  854.5640,  608.3724],
        [ 797.9386,  602.4206,  813.9556,  626.0806]])]

agn_hms = center_heatmap_preds
bbox_preds = bbox_preds
strides = (8, 16, 32, 64, 128) 
regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 512), (512, INF),)
center_sample_radius = 1.5
norm_on_bbox = True

In [6]:
 def get_points(featmap_sizes, dtype, device, flatten=False):
        """Get points according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            dtype (torch.dtype): Type of points.
            device (torch.device): Device of points.

        Returns:
            tuple: points of each image.
        """
        mlvl_points = []
        for i in range(len(featmap_sizes)):
            mlvl_points.append(
                _get_points_single(featmap_sizes[i], strides[i], dtype, device, flatten))
        return mlvl_points


def _get_points_single(featmap_size, stride, dtype, device, flatten=False):
    """Get points of a single scale level."""
    h, w = featmap_size
    # First create Range with the default dtype, than convert to
    # target `dtype` for onnx exporting.
    x_range = torch.arange(w, device=device).to(dtype)
    y_range = torch.arange(h, device=device).to(dtype)
    y, x = torch.meshgrid(y_range, x_range)
    if flatten:
        y = y.flatten()
        x = x.flatten()
    points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride), dim=-1) + stride // 2
    return points

In [7]:
from mmdet.core import distance2bbox, multi_apply, multiclass_nms, reduce_mean

def _get_label_indices(gt_bboxes_list, feature_map_sizes):
    
    self_strides = (8, 16, 32, 64, 128) 
    
    pos_indices = []
    L = len(self_strides)   ############ self
    B = len(gt_bboxes_list)
    shapes_per_level = torch.tensor(feature_map_sizes).long()
    loc_per_level = (shapes_per_level[:, 0] * shapes_per_level[:, 1]).long() # L
    level_bases = []
    s = 0
    for l in range(L):
        level_bases.append(s)
        s = s + B * loc_per_level[l]
    level_bases = shapes_per_level.new_tensor(level_bases) # L    
    strides_default = level_bases.new_tensor(self_strides).float() # L ############ self
    for im_i in range(B):
        bboxes = gt_bboxes_list[im_i]
        n = bboxes.shape[0]
        centers = ((bboxes[:, [0, 1]] + bboxes[:, [2, 3]]) / 2)
        centers = centers.view(n, 1, 2).expand(n, L, 2)
        strides = strides_default.view(1, L, 1).expand(n, L, 2)
        centers_inds = (centers / strides).long()
        Ws = shapes_per_level[:, 1].view(1, L).expand(n, L)
        pos_ind = level_bases.view(1, L).expand(n, L) + \
                   im_i * loc_per_level.view(1, L).expand(n, L) + \
                   centers_inds[:, :, 1] * Ws + \
                   centers_inds[:, :, 0] # n x L
        is_cared_in_the_level = _assign_fpn_level(bboxes)  # ################# self
        pos_ind = pos_ind[is_cared_in_the_level].view(-1)
        pos_indices.append(pos_ind)
    pos_indices = torch.cat(pos_indices, dim=0).long()
    return pos_indices
    
def _assign_fpn_level(boxes):
    
    self_regress_ranges = ((-1, 64), (64, 128), (128, 256), (256, 512), (512, INF),)
    
    '''
    Inputs:
        boxes: n x 4
        size_ranges: L x 2
    Return:
        is_cared_in_the_level: n x L
    '''
    size_ranges = boxes.new_tensor(
        self_regress_ranges).view(len(self_regress_ranges), 2) 
    crit = ((boxes[:, 2:] - boxes[:, :2]) **2).sum(dim=1) ** 0.5 / 2 # n
    n, L = crit.shape[0], size_ranges.shape[0]
    crit = crit.view(n, 1).expand(n, L)
    size_ranges_expand = size_ranges.view(1, L, 2).expand(n, L, 2)
    is_cared_in_the_level = (crit >= size_ranges_expand[:, :, 0]) & \
        (crit <= size_ranges_expand[:, :, 1])
    return is_cared_in_the_level

def get_center3x3(locations, centers, strides):
        '''
        Inputs:
            locations: M x 2
            centers: N x 2
            strides: M
        '''
        M, N = locations.shape[0], centers.shape[0]
        locations_expanded = locations.view(M, 1, 2).expand(M, N, 2) # M x N x 2
        centers_expanded = centers.view(1, N, 2).expand(M, N, 2) # M x N x 2
        strides_expanded = strides.view(M, 1, 1).expand(M, N, 2) # M x N
        centers_discret = ((centers_expanded / strides_expanded).int() * \
            strides_expanded).float() + strides_expanded / 2 # M x N x 2
        dist_x = (locations_expanded[:, :, 0] - centers_discret[:, :, 0]).abs()
        dist_y = (locations_expanded[:, :, 1] - centers_discret[:, :, 1]).abs()
        return (dist_x <= strides_expanded[:, :, 0]) & \
            (dist_y <= strides_expanded[:, :, 0])

def assign_reg_fpn(reg_targets_per_im, size_ranges):
        '''
        TODO (Xingyi): merge it with assign_fpn_level
        Inputs:
            reg_targets_per_im: M x N x 4
            size_ranges: M x 2
        '''
        crit = ((reg_targets_per_im[:, :, :2] + \
            reg_targets_per_im[:, :, 2:])**2).sum(dim=2) ** 0.5 / 2 # M x N
        is_cared_in_the_level = (crit >= size_ranges[:, [0]]) & \
            (crit <= size_ranges[:, [1]])
        return is_cared_in_the_level

def _get_reg_targets(reg_targets, dist, mask, area):
        '''
          reg_targets (M x N x 4): long tensor
          dist (M x N)
          is_*: M x N
        '''
        dist[mask == 0] = INF * 1.0
        min_dist, min_inds = dist.min(dim=1) # M
        reg_targets_per_im = reg_targets[
            range(len(reg_targets)), min_inds] # M x N x 4 --> M x 4
        reg_targets_per_im[min_dist == INF] = - INF
        return reg_targets_per_im
    
def _create_agn_heatmaps_from_dist(dist):
        '''
        TODO (Xingyi): merge it with _create_heatmaps_from_dist
        dist: M x N
        return:
          heatmaps: M x 1
        '''
        heatmaps = dist.new_zeros((dist.shape[0], 1))
        heatmaps[:, 0] = torch.exp(-dist.min(dim=1)[0])
        zeros = heatmaps < 1e-4
        heatmaps[zeros] = 0
        return heatmaps 
    
def _transpose(training_targets, num_loc_list):
    '''
    This function is used to transpose image first training targets to 
        level first ones
    :return: level first training targets
    '''
    for im_i in range(len(training_targets)):
        training_targets[im_i] = torch.split(
            training_targets[im_i], num_loc_list, dim=0)

    targets_level_first = []
    for targets_per_level in zip(*training_targets):
        targets_level_first.append(
            torch.cat(targets_per_level, dim=0))
    return targets_level_first    
    
def get_targets(points, feature_map_sizes, gt_bboxes_list):
    
        self_strides = (8, 16, 32, 64, 128) 
        self_regress_ranges = ((-1, 64), (64, 128), (128, 256), (256, 512), (512, INF),)
        self_get_center3x3 = get_center3x3
        self_assign_reg_fpn = assign_reg_fpn
        self_hm_min_overlap = 0.8
        self_min_radius = 4
        self__get_reg_targets = _get_reg_targets
        self__create_agn_heatmaps_from_dist = _create_agn_heatmaps_from_dist
    
        """Compute regression and classification for points in multiple images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            feature_map_sizes (list[torch.Size[2]]):
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
                each has shape (num_gt, 4).

        Returns:
            tuple:
                
        """
        
        # Get positive pixel index
        pos_indices = _get_label_indices(gt_bboxes_list, feature_map_sizes)
        L = len(points)
        num_loc_list = [len(loc) for loc in points]
        strides = torch.cat([
            points[0].new_ones(num_loc_list[l]) * self_strides[l] for l in range(L)]).float() # M
        reg_size_ranges = torch.cat([points[0].new_tensor(self_regress_ranges[l]).float().view(
            1, 2).expand(num_loc_list[l], 2) for l in range(L)])
        points = torch.cat(points, dim=0) # M x 2
        M =  points.shape[0]
        reg_targets = []
        flattened_hms = []
        for i in range(len(gt_bboxes_list)):
            boxes = gt_bboxes_list[i]
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
            N = boxes.shape[0]
            if N == 0:
                reg_targets.append(points.new_zeros((M, 4)) - INF)
                flattened_hms.append(points.new_zeros((M, 1)))
                continue
            l = points[:, 0].view(M, 1) - boxes[:, 0].view(1, N) # M x N
            t = points[:, 1].view(M, 1) - boxes[:, 1].view(1, N) # M x N
            r = boxes[:, 2].view(1, N) - points[:, 0].view(M, 1) # M x N
            b = boxes[:, 3].view(1, N) - points[:, 1].view(M, 1) # M x N
            reg_target = torch.stack([l, t, r, b], dim=2) # M x N x 4
            
            centers = ((boxes[:, [0, 1]] + boxes[:, [2, 3]]) / 2) # N x 2
            centers_expanded = centers.view(1, N, 2).expand(M, N, 2) # M x N x 2
            strides_expanded = strides.view(M, 1, 1).expand(M, N, 2)
            centers_discret = ((centers_expanded / strides_expanded).int() * \
                strides_expanded).float() + strides_expanded / 2 # M x N x 2
            
            is_peak = (((points.view(M, 1, 2).expand(M, N, 2) - \
                centers_discret) ** 2).sum(dim=2) == 0) # M x N
            is_in_boxes = reg_target.min(dim=2)[0] > 0 # M x N
            is_center3x3 = self_get_center3x3(points, centers, strides) & is_in_boxes # M x N
            is_cared_in_the_level = self_assign_reg_fpn(reg_target, reg_size_ranges) 
            reg_mask = is_center3x3 & is_cared_in_the_level
            
            dist2 = ((points.view(M, 1, 2).expand(M, N, 2) - \
                centers_expanded) ** 2).sum(dim=2) # M x N
            dist2[is_peak] = 0
            delta = (1 - self_hm_min_overlap) / (1 + self_hm_min_overlap)
            radius2 = delta ** 2 * 2 * area # N
            radius2 = torch.clamp(radius2, min=self_min_radius ** 2)
            weighted_dist2 = dist2 / radius2.view(1, N).expand(M, N)
            
            reg_target = self__get_reg_targets(reg_target, weighted_dist2.clone(), reg_mask, area) 
            flattened_hm = self__create_agn_heatmaps_from_dist(weighted_dist2.clone())
            
            reg_targets.append(reg_target)
            flattened_hms.append(flattened_hm)
        
        # transpose im first training_targets to level first ones
        reg_targets = _transpose(reg_targets, num_loc_list)
        flattened_hms = _transpose(flattened_hms, num_loc_list)
        
        for l in range(len(reg_targets)):
            reg_targets[l] = reg_targets[l] / float(self_strides[l])
        reg_targets = torch.cat([x for x in reg_targets], dim=0) # MB x 4
        flattened_hms = torch.cat([x for x in flattened_hms], dim=0) # MB x C
        
        return pos_indices,  reg_targets, flattened_hms

In [8]:
def losses(pos_indices, reg_targets, flattened_hms, reg_pred, agn_hm_pred):
    
        self_not_norm_reg = False
        from detectron2.utils.comm import get_world_size
        from mmdet.models.builder import build_loss
        from mmdet.core import distance2bbox, multi_apply, multiclass_nms, reduce_mean
        
        loss_hm=dict(type='BinaryFocalLoss',
              alpha = 0.25,
              beta = 4,
              gamma = 2,
              pos_weight = 0.5,
              neg_weight = 0.5,
              sigmoid_clamp = 1e-4,
              ignore_high_fp = 0.85)
        loss_bbox=dict(type='GIoULoss',
                      loss_weight=1.0)
        self_loss_hm = build_loss(loss_hm)
        self_loss_bbox = build_loss(loss_bbox)
        
        def reduce_sum(tensor):
            world_size = get_world_size()
            if world_size < 2:
                return tensor
            tensor = tensor.clone()
            torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
            return tensor
        
        '''
        Inputs:
            pos_indices: N
            reg_targets: M x 4
            flattened_hms: M x C
            reg_pred: M x 4
            agn_hm_pred: M x 1 or None
            N: number of positive locations in all images
            M: number of pixels from all FPN levels
            C: number of classes
        '''
        
        agn_hm_pred = torch.cat([x.permute(0, 2, 3, 1).reshape(-1) for x in agn_hm_pred], dim=0) 
        reg_pred = torch.cat([x.permute(0, 2, 3, 1).reshape(-1, 4) for x in reg_pred], dim=0)
        
    
        num_pos_local = pos_indices.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(
            pos_indices.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)
        
        reg_inds = torch.nonzero(reg_targets.max(dim=1)[0] >= 0).squeeze(1)
        reg_pred = reg_pred[reg_inds]
        
        reg_targets_pos = reg_targets[reg_inds]
        reg_weight_map = flattened_hms.max(dim=1)[0]
        reg_weight_map = reg_weight_map[reg_inds]
        reg_weight_map = reg_weight_map * 0 + 1 \
            if self_not_norm_reg else reg_weight_map
        reg_norm = max(reduce_sum(reg_weight_map.sum()).item() / num_gpus, 1)
        reg_loss = self_loss_bbox(reg_pred, 
                                  reg_targets_pos, 
                                  reg_weight_map,
                                  avg_factor=reg_norm)

        
        cat_agn_heatmap = flattened_hms.max(dim=1)[0] # M
        loss_hm = self_loss_hm(agn_hm_pred, cat_agn_heatmap, pos_indices, avg_factor=num_pos_avg)
        
        return dict(loss_hm=loss_hm, loss_bbox=reg_loss)

In [9]:
feature_map_sizes = [featmap.size()[-2:] for featmap in agn_hms]
all_level_points = get_points(feature_map_sizes, bbox_preds[0].dtype, bbox_preds[0].device)
pos_indices,  reg_targets, flattened_hms = get_targets(all_level_points, feature_map_sizes, gt_bboxes)
print(f'{int((flattened_hms.reshape(-1)>0).int().sum())} hm samples' )
print(f'{int((reg_targets.max(-1)[0].reshape(-1)>= 0).int().sum())} bbox samples')
losses(pos_indices,  reg_targets, flattened_hms, bbox_preds, agn_hms)

1682 hm samples
120 bbox samples


{'loss_hm': tensor(190.4738, grad_fn=<DivBackward0>),
 'loss_bbox': tensor(1.9135, grad_fn=<MulBackward0>)}

In [164]:
def predict_single_level(points, heatmap, reg_pred, img_metas, cfg=None):
    
    cfg = dict(
            nms_pre=4000,
            max_per_img=2000,
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0)
    
    N, C, H, W = heatmap.shape
    image_sizes = [img_meta['img_shape'] for img_meta in img_metas]
    heatmap = heatmap.reshape(N, -1, C) # N x HW x C
    box_regression = reg_pred.view(N, 4, H, W).permute(0, 2, 3, 1) # N x H x W x 4 
    box_regression = box_regression.reshape(N, -1, 4)
    
    candidate_inds = heatmap > cfg.get('score_thr', -1)
    pre_nms_top_n = candidate_inds.view(N, -1).sum(1) # N
    pre_nms_topk = cfg.get('nms_pre', 1000)
    pre_nms_top_n = pre_nms_top_n.clamp(max=pre_nms_topk)
      
    result = []
    for i in range(N):
        per_box_cls = heatmap[i] # HW x C
        per_candidate_inds = candidate_inds[i] # n
        per_box_cls = per_box_cls[per_candidate_inds] # n

        per_candidate_nonzeros = per_candidate_inds.nonzero() # n
        per_box_loc = per_candidate_nonzeros[:, 0] # n
        per_class = per_candidate_nonzeros[:, 1] # n

        per_box_regression = box_regression[i] # HW x 4
        per_box_regression = per_box_regression[per_box_loc] # n x 4
        per_points = points[per_box_loc] # n x 2

        pre_nms_top_n = pre_nms_top_n[i] # 1

        if per_candidate_inds.sum().item() > pre_nms_top_n.item():
            per_box_cls, top_k_indices = \
                per_box_cls.topk(pre_nms_top_n, sorted=False)
            per_class = per_class[top_k_indices]
            per_box_regression = per_box_regression[top_k_indices]
            per_points = per_points[top_k_indices]

        detections = torch.stack([
            per_points[:, 0] - per_box_regression[:, 0],
            per_points[:, 1] - per_box_regression[:, 1],
            per_points[:, 0] + per_box_regression[:, 2],
            per_points[:, 1] + per_box_regression[:, 3],
        ], dim=1) # n x 4
        
        # avoid invalid boxes in RoI heads
        detections[:, 2] = torch.max(detections[:, 2], detections[:, 0] + 0.01)
        detections[:, 3] = torch.max(detections[:, 3], detections[:, 1] + 0.01)
        
        
        result.append(dict(proposals=detections,score=per_box_cls))
    return result   
        
            
reg_pred_per_level = bbox_preds
agn_hm_pred_per_level = agn_hms
points = all_level_points
sampled_boxes = []

for l in range(len(points)):
    sampled_boxes.append(predict_single_level(points[l],agn_hms[l], bbox_preds[l], img_metas))     
    break

In [153]:
from mmcv.ops import batched_nms
boxes = torch.randn(10, 4)*1000
scores = torch.randn(10).clamp(min=0.5)
idxs = torch.cat([torch.zeros(5), torch.ones(5)]).long()
nms_cfg = dict(max_num=10, iou_threshold=0.7)
batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False)

(tensor([[-1.4262e+03,  1.6850e+03,  1.7673e+02, -1.0880e+03,  1.5962e+00],
         [ 1.2151e+03, -2.9039e+03, -1.6933e+03,  2.3907e+02,  9.5319e-01],
         [ 1.2713e+02,  2.0146e+03,  2.9388e+02, -7.0277e+02,  9.2745e-01],
         [-9.1963e+02, -2.6573e+02,  1.2963e+03,  1.5221e+03,  5.0000e-01],
         [-1.6174e+02, -1.1654e+03,  4.3114e+01, -2.5149e+03,  5.0000e-01],
         [ 1.1953e+03, -1.7077e+03, -1.0109e+03,  1.3001e+03,  5.0000e-01],
         [ 2.4223e+03,  1.2864e+01,  8.8145e+02,  9.2308e+01,  5.0000e-01],
         [-1.6230e+02,  2.1452e+03, -2.8130e+02, -1.9562e+03,  5.0000e-01],
         [ 1.9333e+03, -1.0536e+03,  3.5697e+02,  5.3240e+02,  5.0000e-01],
         [-5.0509e+02,  1.5034e+02,  6.1563e+02,  7.6775e+02,  5.0000e-01]]),
 tensor([9, 5, 6, 0, 1, 2, 3, 4, 7, 8]))

In [203]:
gt_bboxes = torch.Tensor([[ 408.0027,  315.4001,  611.0715,  372.3802],
        [ 645.2617,  332.2336,  850.4027,  387.9202],
        [ 885.6290,  345.4442, 1093.8781,  402.4242],
        [ 346.2424,  328.8053,  366.4457,  422.1526],
        [ 398.5853,  388.6484,  601.4492,  445.4065],
        [ 637.1537,  403.9441,  840.6491,  461.7406],
        [ 876.7070,  416.5464, 1083.5500,  477.1525],
        [ 390.0223,  460.6703,  598.7449,  520.6074],
        [ 627.0999,  475.6545,  832.0861,  535.1974],
        [ 871.8967,  489.6585, 1077.7776,  547.3785],
        [ 863.1270,  564.4899, 1066.9784,  629.7786],
        [ 209.1294,  639.4113,  306.5333,  675.3246],
        [ 833.9768,  591.5006,  854.5640,  608.3724],
        [ 797.9386,  602.4206,  813.9556,  626.0806]])
scores = torch.randn(gt_bboxes.shape[0])
nms_cfg = dict(
                type='nms', 
                iou_threshold=0.5,
                max_num=256)
idxs = torch.cat([torch.zeros(gt_bboxes.shape[0]//2), torch.ones(gt_bboxes.shape[0]//2)]).long()
print(idxs)
_, keep = batched_nms(gt_bboxes, scores, idxs, nms_cfg, class_agnostic=False)

tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1])


In [135]:
gt_bboxes.split(7, dim=0)

(tensor([[ 408.0027,  315.4001,  611.0715,  372.3802],
         [ 645.2617,  332.2336,  850.4027,  387.9202],
         [ 885.6290,  345.4442, 1093.8781,  402.4242],
         [ 346.2424,  328.8053,  366.4457,  422.1526],
         [ 398.5853,  388.6484,  601.4492,  445.4065],
         [ 637.1537,  403.9441,  840.6491,  461.7406],
         [ 876.7070,  416.5464, 1083.5500,  477.1525]]),
 tensor([[ 390.0223,  460.6703,  598.7449,  520.6074],
         [ 627.0999,  475.6545,  832.0861,  535.1974],
         [ 871.8967,  489.6585, 1077.7776,  547.3785],
         [ 863.1270,  564.4899, 1066.9784,  629.7786],
         [ 209.1294,  639.4113,  306.5333,  675.3246],
         [ 833.9768,  591.5006,  854.5640,  608.3724],
         [ 797.9386,  602.4206,  813.9556,  626.0806]]))

In [205]:
ranked_scores, rank_inds = scores.sort(descending=True)

In [208]:
gt_bboxes[rank_inds][:3,:]

tensor([[ 876.7070,  416.5464, 1083.5500,  477.1525],
        [ 408.0027,  315.4001,  611.0715,  372.3802],
        [ 398.5853,  388.6484,  601.4492,  445.4065]])

In [14]:
import torch
from mmdet.core import reduce_mean

def get_world_size() -> int:
    if not torch.distributed.is_available():
        return 1
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()

def reduce_sum(tensor):
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    tensor = tensor.clone()
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return tensor

pos_inds = (torch.randn(10) > 0.5).nonzero().squeeze(-1)
num_pos_local = pos_inds.numel()
num_gpus = get_world_size()
total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
num_pos_avg = max(total_num_pos / num_gpus, 1.0)
print(num_pos_avg)

print(max(reduce_mean(num_pos_local), 1.0))

3.0
3


In [17]:
pos_inds = (torch.randn(10) > 0.5).nonzero().squeeze(-1)
print(pos_inds)
print(pos_inds.max(dim=1)[0])

tensor([1, 2, 8, 9])


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)