In [16]:
import numpy as np
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as plticker
import torch
import torch.nn as nn
import torch.nn.functional as F

from matplotlib.lines import Line2D
from matplotlib.patches import Patch

from src.anchor.anchor_generator import (gen_base_anchors, get_anchors, 
                              grid_anchors, meshgrid)
from src.anchor.assigner import assign_wrt_overlaps, bbox_overlaps
from src.anchor.loss import binary_cross_entropy, smooth_l1_loss
from src.anchor.prediction import predict_anchors
from src.anchor.transforms import bbox2delta, delta2bbox
from src.anchor.visualize import (draw_anchor_gt_overlaps, draw_anchor_samples_on_image, 
                       draw_base_anchor_on_grid, draw_pos_assigned_bboxes)
from src.datasets.loader.build_loader import build_dataloader
from src.models.builder import build_backbone, build_neck, build_head
from mmcv.runner import obj_from_dict
from mmcv.utils.config import Config
from src.core import multi_apply, weighted_smoothl1, weighted_sigmoid_focal_loss
from src.core.anchor import anchor_target_single, images_to_levels, unmap, anchor_inside_flags, expand_binary_labels

from src.core.bbox import assign_and_sample, build_assigner, PseudoSampler, bbox2delta

# Restore `anchor_target_variable` from IPython's persistent store (it must have
# been saved earlier with `%store anchor_target_variable`, presumably from a run
# that captured the arguments of an anchor-target call — confirm the source).
%store -r anchor_target_variable

In [7]:
# Peek at the tuple restored by `%store -r` in cell 16.
# NOTE(review): the output below was produced at In[7], i.e. before cell 16
# re-ran `%store -r`. It shows 6 elements, while the unpacking in cell 22
# expects 7 (the last being `train_cfg`). Restart & Run All to refresh it.
anchor_target_variable

([tensor([[ -19.,   -7.,   26.,   14.],
          [ -25.,  -10.,   32.,   17.],
          [ -32.,  -14.,   39.,   21.],
          ...,
          [1163.,  342., 1524., 1065.],
          [1116.,  248., 1571., 1159.],
          [1057.,  129., 1630., 1278.]])],
 [tensor([1, 1, 1,  ..., 1, 1, 1], device='cuda:0', dtype=torch.uint8)],
 [tensor([[ 748.0838,  304.4447,  980.2133,  747.1882],
          [ 707.9063,   46.1551, 1026.3268,  670.4366],
          [ 982.3378,  359.9517, 1055.0696,  458.0521],
          [1012.2678,  381.8004, 1073.9814,  452.5743]])],
 [None],
 [None],
 [{'ori_shape': (360, 640, 3),
   'img_shape': (750, 1333, 3),
   'pad_shape': (768, 1344, 3),
   'scale_factor': 2.0828125,
   'flip': False}])

In [22]:
# Compute anchor targets for a single image step by step, inline, so every
# intermediate tensor can be inspected (presumably a copy of
# `anchor_target_single` imported above — keep the two in sync).
#
# The unpacking below expects 7 stored items. An older 6-element version
# (without `train_cfg`) also exists, so fail with a readable message instead of
# a bare "not enough values to unpack" error.
assert len(anchor_target_variable) == 7, (
    f"expected 7 stored items, got {len(anchor_target_variable)}; "
    "re-run the code that does `%store anchor_target_variable`")
(anchor_list, valid_flag_list, gt_bboxes_list, gt_bboxes_ignore_list,
 gt_labels_list, img_metas, train_cfg) = anchor_target_variable

# Only the first image of the stored batch is examined.
# `gt_bboxes_list` keeps the per-image list so that `gt_bboxes` has one meaning
# (the tensor of this image's gt boxes).
flat_anchors = anchor_list[0]
valid_flags = valid_flag_list[0]
gt_bboxes = gt_bboxes_list[0]
gt_bboxes_ignore = gt_bboxes_ignore_list[0]
gt_labels = gt_labels_list[0]
img_meta = img_metas[0]
cfg = train_cfg
label_channels = 1  # values > 1 trigger expand_binary_labels after unmapping

# Keep only the anchors that anchor_inside_flags marks (presumably valid anchors
# lying within cfg.allowed_border of img_shape — confirm in src.core.anchor).
inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                   img_meta['img_shape'][:2],
                                   cfg.allowed_border)
anchors = flat_anchors[inside_flags, :]

# Assign gt boxes to the kept anchors, then "sample" with PseudoSampler
# (presumably keeps every assigned anchor, no random subsampling).
bbox_assigner = build_assigner(cfg.assigner)
assign_result = bbox_assigner.assign(anchors, gt_bboxes,
                                     gt_bboxes_ignore, gt_labels)
bbox_sampler = PseudoSampler()
sampling_result = bbox_sampler.sample(assign_result, anchors,
                                      gt_bboxes)

# Per-kept-anchor targets; everything not positive/negative keeps weight 0.
num_valid_anchors = anchors.shape[0]
bbox_targets = torch.zeros_like(anchors)
bbox_weights = torch.zeros_like(anchors)
labels = anchors.new_zeros(num_valid_anchors, dtype=torch.long)
label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)

pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0:
    # NOTE(review): this relies on bbox2delta's default means/stds. If the
    # library anchor_target_single passes target means/stds, check they match.
    # Also note `bbox2delta` is imported twice in the first cell; the
    # `src.core.bbox` import (the later one) is the one in effect here.
    pos_bbox_targets = bbox2delta(sampling_result.pos_bboxes,
                                  sampling_result.pos_gt_bboxes)
    bbox_targets[pos_inds, :] = pos_bbox_targets
    bbox_weights[pos_inds, :] = 1.0
    if gt_labels is None:
        # No class labels stored (binary/objectness setting): positives are 1.
        labels[pos_inds] = 1
    else:
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
    # A non-positive cfg.pos_weight means "use weight 1.0".
    label_weights[pos_inds] = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
if len(neg_inds) > 0:
    label_weights[neg_inds] = 1.0

# Scatter the per-kept-anchor targets back onto the full anchor set; entries
# for anchors dropped by inside_flags are filled by `unmap` (presumably 0).
num_total_anchors = flat_anchors.size(0)
labels = unmap(labels, num_total_anchors, inside_flags)
label_weights = unmap(label_weights, num_total_anchors, inside_flags)
if label_channels > 1:
    labels, label_weights = expand_binary_labels(
        labels, label_weights, label_channels)
bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

# Invariant: every output is indexed by the full (unfiltered) anchor set.
assert labels.shape[0] == label_weights.shape[0] == num_total_anchors
assert bbox_targets.shape[0] == bbox_weights.shape[0] == num_total_anchors

# Rich display of the results (same order as anchor_target_single's outputs).
(labels, label_weights, bbox_targets, bbox_weights, pos_inds, neg_inds)

tensor([0, 0, 0,  ..., 0, 0, 0]) tensor([1., 1., 1.,  ..., 1., 1., 1.]) tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        ...,
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]]) tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        ...,
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]]) tensor([ 76775,  78278,  78281,  78287,  78290,  78296,  78299,  79790,  79793,
         79799,  79802,  79808,  79811,  81302,  81311,  81320, 163115, 163866,
        163867, 163868, 163870, 163871, 163876, 163877, 163879, 163880, 163888,
        164614, 164615, 164622, 164623, 164624, 164625, 164626, 164627, 164628,
        164631, 164632, 164633, 164634, 164635, 164636, 164637, 164640, 164641,
        164643, 164644, 165371, 165378, 165379, 165380, 165382, 165383, 165384,
        165387, 165388, 165389, 165390, 165391, 165392, 165393, 165396, 165397,
        165399, 165400, 1661