In [162]:
import torch
import torch.nn.functional as F
from torch import nn

import torchvision
from torchvision.ops import boxes as box_ops

In [97]:
    preds = [
        torch.FloatTensor(10, 20, 4).random_(0, 1),
        torch.LongTensor(10, 20).random_(0, 21),
        torch.FloatTensor(20, 4).random_(0, 201) / 200
    ]

    targets = {
       'boxes': torch.FloatTensor(10, 12, 4).random_(0, 200) / 200,
       'labels':  torch.LongTensor(10, 12).random_(1, 21),
    }

    preds_loc_delta, preds_conf, anchors = preds
    gt_boxes = targets['boxes']
    labels = targets['labels']


# self.box_similarity(gt_boxes, anchors)
- Iou_mat by gt_box vs anchors 

In [98]:
def box_similarity(truths, anchors):
    IoUs_mat = box_ops.box_iou(truths, anchors)
    
    return IoUs_mat

In [104]:
l = labels[0][:]
b = gt_boxes[0][:, :]
b[:, 2:] += b[:, :2]
anchors[:, 2:] += anchors[:, :2]
print(l.shape, b.shape, anchors.shape)

mqm = box_similarity(b, anchors)

# (obj_N, anchors_N)
# obj_N.max <- mqm.max(dim=1)
# anchor_N.max <- mqm.max(dim=0)

torch.Size([12]) torch.Size([12, 4]) torch.Size([20, 4])


# self.proposal_matcher

In [101]:
class Matcher(object):
    BELOW_LOW_THRESHOLD = -1
    BETWEEN_THRESHOLDS = -2

    def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
        assert low_threshold <= high_threshold
        self.high_threshold = high_threshold
        self.low_threshold = low_threshold
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix):
        matched_vals, matches = match_quality_matrix.max(dim=0)

        if self.allow_low_quality_matches:
            all_matches = matches.clone()
            
        below_low_threshold = matched_vals < self.low_threshold
        between_thresholds = (matched_vals >= self.low_threshold) & ( matched_vals < self.high_threshold )

        matches[below_low_threshold] = Matcher.BELOW_LOW_THRESHOLD
        matches[between_thresholds] = Matcher.BETWEEN_THRESHOLDS
        
        if self.allow_low_quality_matches:
            self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

        return matches

    def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
        highest_quality_foreach_gt, hoge = match_quality_matrix.max(dim=1)
        gt_pred_pairs_of_highest_quality = torch.nonzero( match_quality_matrix == highest_quality_foreach_gt[:, None] )

        pred_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1]
        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
        

In [114]:
max_tresh = .7
min_tresh = .3
match = Matcher(max_tresh, min_tresh, allow_low_quality_matches=True)

mqm_ids = match(mqm)
matched_gt_boxes = anchors[mqm_ids.clamp(min=0)]

print(mqm_ids)
print(l)
bg_ind = mqm_ids < 0
matched_label = l[mqm_ids]
matched_label[bg_ind] = 0
print(matched_label)

tensor([ 8, -1, -2, -2, -2, -2,  6, 11, -2,  4,  1, -2, -2, -2, -2, -2,  2, -2,
         9, -2])
tensor([ 8, 18,  2, 15,  3,  2, 12, 12, 13,  2,  7, 17])
tensor([13,  0,  0,  0,  0,  0, 12, 17,  0,  3, 18,  0,  0,  0,  0,  0,  2,  0,
         2,  0])


In [55]:
hoge = torch.FloatTensor([[0.8508, 0.7609, 0.8174, 0.6254, 0.7831, 0.7430, 0.7154, 0.3022, 0.4388,
        0.6559, 0.6186, 0.5114], [0.8508, 0.7609, 0.8174, 0.6254, 0.7831, 0.7430, 0.7154, 0.3022, 0.4388,
        0.6559, 0.6186, 0.5114]])
hoge[:, ] # 転置処理

tensor([[0.8174, 0.6254, 0.7831, 0.7430, 0.7154, 0.3022, 0.4388, 0.6559, 0.6186,
         0.5114],
        [0.8174, 0.6254, 0.7831, 0.7430, 0.7154, 0.3022, 0.4388, 0.6559, 0.6186,
         0.5114]])

# compute loss

In [148]:
def log_sum_exp(x):
    """Utility function for computing log_sum_exp while determining
    This will be used to determine unaveraged confidence loss across
    all examples in a batch.
    Args:
        x (Variable(tensor)): conf_preds from conf layers
    """
    x_max = x.data.max()
    return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max

In [275]:
    def compute_loss(preds_loc, preds_conf, match_gt_box, match_gt_label, negapos_ratio=3, device='cpu'):
        pos_anchors = match_gt_label > 0 # [B, A]
        num_batch = preds_loc.shape[0] # B
        num_anchor = preds_loc.shape[1]

        # loss loc
        loc_pos_idx = pos_anchors.unsqueeze(pos_anchors.dim()).expand_as(preds_loc) # [B, A, 4]
        loss_loc = F.smooth_l1_loss(preds_loc[loc_pos_idx].view(-1, 4), match_gt_box[loc_pos_idx].view(-1, 4)) # [BxA, 4], [BxA, 4]

        # loss conf
        num_pos = pos_anchors.long().sum(dim=1, keepdim=True)
        num_hard_nega = negapos_ratio * num_pos
        
        cross_entropy = nn.CrossEntropyLoss(reduce=False)
        loss_conf_all = cross_entropy(preds_conf.view(-1, 21), match_gt_label.view(-1))
        loss_conf_all = loss_conf_all.view(num_batch, -1)
        
        #Hard Negative Mining
        loss_conf_pos = loss_conf_all[pos_anchors].sum()        
        loss_conf_neg = loss_conf_all.clone()
        loss_conf_neg[pos_anchors] = 0
        loss_conf_neg, _ = loss_conf_neg.sort(1, descending=True)
        
        hardness_rank = torch.LongTensor(range(num_anchor)).unsqueeze(0).expand_as(loss_conf_neg).to(device)
        hard_negas = hardness_rank < num_hard_nega.unsqueeze(1)
        print(hard_negas.shape, loss_conf_neg.shape)
        
        loss_conf_hard_neg = loss_conf_neg[hard_negas].sum()
        
        # total loss
        N = num_pos.sum()
        loss_conf = (loss_conf_neg + loss_conf_pos) / N
        loss_loc /= N
        
        return loss_loc, loss_conf


In [276]:
# test_load
conf_p = torch.FloatTensor(5, 15, 21).random_(0, 100) / 100
loc_delta = torch.FloatTensor(5, 15, 4).random_(0, 200) / 200
loc_delta[:, :, 2:] += loc_delta[:, :, :2]

conf_t = torch.LongTensor(5, 15).random_(0, 21)
loc_t = torch.FloatTensor(5, 15, 4).random_(0, 200) / 200
loc_t[:, :, 2:] += loc_t[:, :, :2]

print(loc_t.shape, loc_delta.shape)
print(conf_t.shape, conf_p.shape)
loc, conf = compute_loss(loc_delta, conf_p, loc_t, conf_t)

torch.Size([5, 15, 4]) torch.Size([5, 15, 4])
torch.Size([5, 15]) torch.Size([5, 15, 21])
torch.Size([5, 5, 15]) torch.Size([5, 15])




IndexError: too many indices for tensor of dimension 2

In [266]:
pos_anchors = torch.Tensor(10, 4).random_(0, 2) # [B, A]
num_anchor = pos_anchors.shape[0]
hoge = range(num_anchor)
print(hoge)

range(0, 10)
