<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Data" data-toc-modified-id="Data-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Data</a></span></li><li><span><a href="#Code" data-toc-modified-id="Code-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Code</a></span><ul class="toc-item"><li><span><a href="#RM" data-toc-modified-id="RM-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>RM</a></span></li></ul></li><li><span><a href="#Train" data-toc-modified-id="Train-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Train</a></span><ul class="toc-item"><li><span><a href="#Train" data-toc-modified-id="Train-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>Train</a></span></li><li><span><a href="#View-Results" data-toc-modified-id="View-Results-3.2"><span class="toc-item-num">3.2&nbsp;&nbsp;</span>View Results</a></span></li></ul></li><li><span><a href="#Stepping-Through-a-Batch" data-toc-modified-id="Stepping-Through-a-Batch-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Stepping Through a Batch</a></span><ul class="toc-item"><li><span><a href="#RM" data-toc-modified-id="RM-4.1"><span class="toc-item-num">4.1&nbsp;&nbsp;</span>RM</a></span><ul class="toc-item"><li><span><a href="#Init" data-toc-modified-id="Init-4.1.1"><span class="toc-item-num">4.1.1&nbsp;&nbsp;</span>Init</a></span></li><li><span><a href="#Get-batch,-acts,-item" data-toc-modified-id="Get-batch,-acts,-item-4.1.2"><span class="toc-item-num">4.1.2&nbsp;&nbsp;</span>Get batch, acts, item</a></span></li><li><span><a href="#mAP" data-toc-modified-id="mAP-4.1.3"><span class="toc-item-num">4.1.3&nbsp;&nbsp;</span>mAP</a></span></li><li><span><a href="#ssd-item-loss" data-toc-modified-id="ssd-item-loss-4.1.4"><span class="toc-item-num">4.1.4&nbsp;&nbsp;</span>ssd item loss</a></span></li><li><span><a href="#lbl-loss" data-toc-modified-id="lbl-loss-4.1.5"><span class="toc-item-num">4.1.5&nbsp;&nbsp;</span>lbl loss</a></span></li></ul></li></ul></li></ul></div>

**~ My Code ~**

# Data

In [None]:
### Imports & Paths ###
from fastai.vision.all import *
import pandas as pd


def random_seed(s, use_cuda):
    """Seed NumPy, PyTorch and Python's `random` with `s` so runs are repeatable.

    s:        seed value handed to every generator.
    use_cuda: when truthy, also seed the CUDA generators and put cuDNN into
              deterministic mode with benchmarking switched off.
    """
    # Reminder from the original author: also use num_workers=0 when creating the DataBunch
    for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
        seed_fn(s)
    if not use_cuda:
        return
    for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seed_fn(s)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
random_seed(42,True)


### Params ###
# im_sz   : side length passed to Resize below (images are squished to im_sz x im_sz)
# bs      : batch size handed to db.dataloaders (also read as a global by _get_tp_bbs)
# val_pct : fraction of items RandomSplitter puts in the validation set
# sub_pct : only referenced by the commented-out subset code further down
im_sz   = 224
bs      = 64
val_pct = .2
sub_pct = 1
path = untar_data(URLs.PASCAL_2007)
annos_path = path/'train.json'
ims_path = path/'train'


### Items ###
# get_annotations returns parallel lists of file names and annotations;
# fn2anno maps each file name to its annotation, whose element [0] is used as
# the bounding boxes and element [1] as the labels by the getters below.
fns, annos = get_annotations(annos_path)
fn2anno = {f:a for f,a in zip(fns,annos)}
def get_im(f):   return ims_path/f
def get_bbox(f): return fn2anno[f][0]
def get_lbl(f):  return fn2anno[f][1]


### DataLoaders ###
# Item transform: squish-resize every image to im_sz.
# Batch transforms: augmentations followed by ImageNet normalisation.
itfms = Resize(im_sz, method='squish')
btfms = setup_aug_tfms([Rotate(), Brightness(), Contrast(), Flip(),
                       Normalize.from_stats(*imagenet_stats)])
# One input (the image) and two targets (boxes, box labels). add_na=False means
# no extra '#na#' entry is added to the label vocab.
db = DataBlock(
    blocks=[ImageBlock, BBoxBlock, BBoxLblBlock(add_na=False)],
    get_x=get_im, get_y=[get_bbox, get_lbl], n_inp=1,
    splitter=RandomSplitter(val_pct),
    item_tfms=itfms, batch_tfms=btfms)
# Optional: train on a random subset of the files (sub_pct) instead of all of them.
# subset = L(fns).shuffle()[0:int(len(fns)*sub_pct)]
# dls = db.dataloaders(subset, bs=bs)
dls = db.dataloaders(fns, bs=bs)
# Convenience attributes read as globals by the model, loss and metric code below.
dls.v = dls.vocab
dls.ncls = len(dls.vocab)

In [None]:
### Inspection ###
# Sanity-check the DataLoaders: class vocab, split sizes, and the dtype/shape
# of each tensor in one training batch (image, boxes, labels).
print("Vocab:", dls.v)
print("Size of train data:",len(dls.train.items))
print("Size of valid data:",len(dls.valid.items))
for i,t in enumerate(dls.one_batch()):
    print(f"batch[{i}]:",'\t',t.dtype,'\t',t.shape)

Vocab: (#20) ['aeroplane','bicycle','bird','boat','bottle','bus','car','cat','chair','cow'...]
Size of train data: 2001
Size of valid data: 500
batch[0]: 	 torch.float32 	 torch.Size([64, 3, 224, 224])
batch[1]: 	 torch.float32 	 torch.Size([64, 24, 4])
batch[2]: 	 torch.int64 	 torch.Size([64, 24])


Interpretation of tensor shapes:
- torch.Size([64, 3, 224, 224]): bs, channels (rgb), im_sz, im_sz
- torch.Size([64, 24, 4]): bs, max objs for a single im in batch (24 in the batch printed above; varies per batch), bb coords
- torch.Size([64, 24]): bs, max objs for a single im in batch

# Code

## RM

In [None]:
### Anchors ###
def create_anchors(subdivs, zooms, ratios, device='cuda'):
    """Build the default anchor boxes for every grid listed in `subdivs`.

    A grid with `sd` subdivisions has sd*sd cells, and every cell receives one
    anchor per (zoom, ratio) combination.

    Returns `(anchors, box_sizes)`, both moved to `device`:
      anchors   - float tensor, one row per anchor: the two centre coordinates
                  followed by the two size values (called `hws` in the code)
      box_sizes - single-column tensor holding 1/sd for each anchor
    """
    # every zoom x ratio combination, zooms varying slowest
    perms = [(z*r1, z*r2) for z in zooms for (r1, r2) in ratios]
    k = len(perms)

    ctr_chunks, hw_chunks, sz_chunks = [], [], []
    for sd in subdivs:
        n_cells = sd*sd
        half_cell = 1/(sd*2)                              # distance from the edge to the first centre
        ticks = np.linspace(half_cell, 1-half_cell, sd)   # centre coordinates along one axis
        cell_ctrs = np.stack([np.tile(ticks, sd), np.repeat(ticks, sd)], axis=1)
        ctr_chunks.append(np.repeat(cell_ctrs, k, axis=0))            # k anchors share each centre
        per_cell = np.array([[o/sd, p/sd] for o, p in perms])         # sizes scaled to the cell size
        hw_chunks.append(np.tile(per_cell, (n_cells, 1)))
        sz_chunks.append(np.full(n_cells*k, 1/sd))

    ctrs, hws = np.concatenate(ctr_chunks), np.concatenate(hw_chunks)
    box_sizes = tensor(np.concatenate(sz_chunks), requires_grad=False).unsqueeze(1)
    anchors = tensor(np.concatenate([ctrs, hws], axis=1), requires_grad=False).float()
    return anchors.to(device), box_sizes.to(device)

def hw2pp(ctr, hw):
    """Turn centre/size boxes into corner boxes: [ctr - hw/2, ctr + hw/2] joined along dim 1."""
    half = hw/2
    return torch.cat([ctr-half, ctr+half], dim=1)

In [None]:
### Architecture ###
def flatten_conv(x,k):
    """Reshape conv output (bs, nf, gx, gy) to (bs, gx*gy*k, nf//k).

    Channels are moved last first, so every grid cell's nf activations are
    split into k consecutive groups of nf//k values.
    """
    bs, nf = x.size(0), x.size(1)
    channels_last = x.permute(0,2,3,1).contiguous()
    return channels_last.view(bs, -1, nf//k)

class StdConv(Module):
    """3x3 conv block, applied in the order conv2d → relu → batchnorm → dropout."""
    def __init__(self, nin, nout, stride=2, drop=0.1):
        super().__init__()
        self.conv = nn.Conv2d(nin, nout, kernel_size=3, stride=stride, padding=1)
        self.bn   = nn.BatchNorm2d(nout)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        activated = F.relu(self.conv(x))
        return self.drop(self.bn(activated))

class OutConv(Module):
    """Prediction layer for one feature map; returns [bb acts, lbl acts].

    k:    anchors per grid cell
    nin:  input channels
    bias: value every bias of the label conv is initialised to
    """
    def __init__(self, k, nin, bias):
        super().__init__()
        self.k = k
        # one output per class plus one extra (the loss code uses index
        # dls.ncls as the background class); reads the global `dls`
        n_lbl = dls.ncls + 1
        self.bbs  = nn.Conv2d(nin, 4*k,     3, padding=1)  # 4 box activations per anchor
        self.lbls = nn.Conv2d(nin, n_lbl*k, 3, padding=1)  # n_lbl class activations per anchor
        lbl_bias = self.lbls.bias.data
        lbl_bias.zero_()
        lbl_bias.add_(bias)

    def forward(self, x):
        bb_acts, lbl_acts = self.bbs(x), self.lbls(x)
        return [flatten_conv(bb_acts, self.k), flatten_conv(lbl_acts, self.k)]

class SSDHead(Module):
    """Detection head: three stride-2 StdConv stages, each followed by an OutConv.

    `body` names the backbone the head sits on ('resnet34' by default, or
    'resnet50'). For 'resnet50' the input first goes through `re_sz`, a
    stride-1 StdConv from 2048 to 512 channels.
    """
    def __init__(self, k, bias, drop=0.4, body='resnet34'):
        super().__init__()
        test(body, ['resnet34','resnet50'], operator.in_)  # reject unsupported backbones
        self.body  = body
        self.drop  = nn.Dropout(drop)
        # attribute names and creation order are kept as-is so saved weights still match
        self.conv0 = StdConv( 512, 256, drop=drop)
        self.conv1 = StdConv( 256, 256, drop=drop)
        self.conv2 = StdConv( 256, 256, drop=drop)
        self.out0  = OutConv(k, 256, bias)
        self.out1  = OutConv(k, 256, bias)
        self.out2  = OutConv(k, 256, bias)
        self.re_sz = StdConv(2048, 512, stride=1)

    def forward(self, x):
        if self.body == 'resnet34': x = F.relu(x)
        x = self.drop(x)
        if self.body == 'resnet50': x = self.re_sz(x)
        stages = ((self.conv0, self.out0), (self.conv1, self.out1), (self.conv2, self.out2))
        bbs, lbls = [], []
        for conv, out in stages:
            x = conv(x)          # downsample, then predict from the new feature map
            bb, lbl = out(x)
            bbs.append(bb)
            lbls.append(lbl)
        # concatenate the per-stage predictions along the anchor dimension
        return [torch.cat(bbs, dim=1), torch.cat(lbls, dim=1)]

class CustMod(Module):
    """Chains a body and a head: forward(x) is head(body(x))."""
    def __init__(self, body, head):
        self.body = body
        self.head = head

    def forward(self, x):
        feats = self.body(x)
        return self.head(feats)

In [None]:
### FocalLoss ###
def one_hot_embedding(lbls, ncls, device='cuda'):
    """One-hot encode integer labels `lbls` into `ncls` columns and move the result to `device`.

    The identity matrix is created on the labels' own device so that indexing
    never mixes a CPU matrix with non-CPU indices.
    """
    eye = torch.eye(ncls, device=lbls.device)
    return eye[lbls.data].to(device)

class BCELoss(nn.Module):
    """Summed binary cross-entropy over the class columns, divided by `ncls`.

    Targets are integer labels in [0, ncls]; label `ncls` (background) one-hot
    encodes to an all-zero row once the last column is dropped. The last
    activation column is dropped as well.
    """
    def __init__(self, ncls, device='cuda'):
        super().__init__()
        self.ncls = ncls
        self.device = device

    def forward(self, acts, targs):
        t = one_hot_embedding(targs, self.ncls+1, self.device)
        t = t[:,:-1].contiguous()   # drop the background column from the targets
        a = acts[:,:-1]             # ... and from the activations
        w = self.get_weight(a,t)
        # the base class has no weights (None); only detach real weight tensors
        if w is not None: w = w.detach()
        return F.binary_cross_entropy_with_logits(a,t,w,reduction='sum')/self.ncls

    def get_weight(self,a,t):
        """Per-element loss weights, or None for plain BCE. Overridden by subclasses."""
        return None

class FocalLoss(BCELoss):
    """BCELoss with focal weights alpha_t * (1 - p_t)**gamma."""
    def get_weight(self, a, t):
        alpha, gamma = 0.25, 2.0 # vals from paper
        p = a.sigmoid()
        pt = p*t + (1-p)*(1-t)          # probability assigned to the true value of each element
        w = alpha*t + (1-alpha)*(1-t)   # alpha for positives, 1-alpha for negatives
        return w * (1-pt).pow(gamma)

In [None]:
### IoU ###
def intersxn(b1,b2):
    """Pairwise intersection areas of corner boxes: result[i, j] = area(b1[i] ∩ b2[j])."""
    top_left  = torch.max(b1[:,None,:2],  b2[None,:,:2])
    bot_right = torch.min(b1[:,None,2:4], b2[None,:,2:4])
    sides = torch.clamp(bot_right - top_left, min=0)   # non-overlapping pairs get zero-length sides
    return sides[...,0] * sides[...,1]

def area(b):
    """Area of each corner box in `b` (columns 0-3 are the two corners)."""
    w = b[:,2]-b[:,0]
    h = b[:,3]-b[:,1]
    return w * h

def get_iou(b1, b2):
    """Pairwise intersection-over-union between the boxes in `b1` and the boxes in `b2`."""
    inter = intersxn(b1,b2)
    a1, a2 = area(b1).unsqueeze(1), area(b2).unsqueeze(0)
    return inter / (a1 + a2 - inter)

In [None]:
### ssd_loss ###
def remove_padding(bb, lbl, rescale=True):
    """Drop padded rows (boxes whose columns 0 and 2 are equal) from `bb` and `lbl`.

    With `rescale=True` the surviving boxes are mapped through (1+bb)/2,
    i.e. from the [-1, 1] range to [0, 1].
    """
    keep = ~((bb[:,2]-bb[:,0])==0)
    bb, lbl = bb[keep], lbl[keep]
    if rescale: bb = (1+bb)/2
    return bb, lbl

def get_pred_bbs(abbs, ancs, anc_sz, device):
    """Turn raw box activations into corner-format predicted boxes.

    The activations are squashed with tanh. Each anchor centre can then move
    by up to half of `anc_sz`, and each anchor size is scaled by a factor
    between 0.5 and 1.5. All operands are moved to `device` first.
    """
    ancs, anc_sz = ancs.to(device), anc_sz.to(device)
    offsets = torch.tanh(abbs).to(device)
    ctrs = ancs[:,:2] + (offsets[:,:2]/2 * anc_sz)
    hws  = ancs[:,2:] * (offsets[:,2:]/2+1)
    return hw2pp(ctrs, hws)

def map_to_gt(ious):
    """Assign a ground-truth index to every anchor from a (n_gt, n_anchor) IoU grid.

    Each anchor starts with its highest-IoU gt. Every gt then claims its own
    best anchor: that anchor's IoU is overwritten with 1.99 so it passes any
    threshold, and its gt index is set to the claiming gt (when several gts
    share an anchor, the last one wins).

    Returns (max IoU per anchor, gt index per anchor).
    """
    best_anc_per_gt = ious.max(1)[1]
    max_iou_per_anc, bb_idxs = ious.max(0)
    max_iou_per_anc[best_anc_per_gt] = 1.99
    for gt_i, anc_i in enumerate(best_anc_per_gt):
        bb_idxs[anc_i] = gt_i
    return max_iou_per_anc, bb_idxs

def ssd_item_loss(act_bbs, act_lbls, bbs, lbls, device='cuda'):
    """Bounding-box and label loss for a single item of a minibatch.

    act_bbs, act_lbls: the model's box / label activations for this item
    bbs, lbls:         padded ground-truth boxes / labels for this item
    device:            forwarded to get_pred_bbs

    Reads the globals `anchors`, `box_size`, `anchor_boxes`, `loss_f` and `dls`,
    which must be defined before this is called.

    Returns (bb_loss, lbl_loss): the L1 loss over the anchors matched to a gt
    object, and `loss_f` over all anchors."""
    # prep
    bbs,lbls = remove_padding(bbs,lbls)                      # remove gt padding inserted during training
    pred_bbs = get_pred_bbs(act_bbs,anchors,box_size,device) # make one pred bb per anchor from acts and ancs
    # map gt to preds
    iou_gt_grid = get_iou(bbs.data, anchor_boxes.data)       # get iou(gt_bbs,anc_bbs); used to map gt → ancs
    iou_gt_preds, mapped_gt_idx = map_to_gt(iou_gt_grid)     # per anchor: max gt iou and the index of that gt object
    mapped_bbs  = bbs[mapped_gt_idx]                         # project gt bbs into pred space
    mapped_lbls = lbls[mapped_gt_idx]                        # project gt lbls into pred space
    # remove low-iou bb preds & set mapped lbl to bg
    high_iou = iou_gt_preds > 0.4                            # only include bb preds whose anchor overlaps w/a gt obj and
    incl = torch.nonzero(high_iou)[:,0]                      #  are not predicting background
    mapped_lbls[~high_iou] = dls.ncls                        # assign gt class of bg to preds w/ low max gt iou
    # compute loss
    bb_res  = F.l1_loss(pred_bbs[incl], mapped_bbs[incl])    # box loss only over the matched anchors
    lbl_res = loss_f(act_lbls, mapped_lbls)                  # label loss over every anchor (bg included)
    return bb_res, lbl_res

def ssd_loss(acts, bbs, lbls, device='cuda', print_it=False):
    """Batch loss: the per-item box and label losses from ssd_item_loss, summed together.

    acts is the model output [bb acts, lbl acts]; bbs / lbls are the batch
    targets. With print_it=True the two running totals are printed.
    """
    bb_sum = lbl_sum = 0.
    # walk the batch item by item: (bb acts, lbl acts, gt bbs, gt lbls)
    for item in zip(*acts, bbs, lbls):
        item_bb, item_lbl = ssd_item_loss(*item, device)
        bb_sum  += item_bb
        lbl_sum += item_lbl
    if print_it:
        print(f"bb:{bb_sum:.02f} | lbl: {lbl_sum:.02f}")
    return bb_sum + lbl_sum

###### Metric & Results

In [None]:
### nms ###
def nms(boxes, scores, iou_thresh, top_k=100):
    """Greedy non-maximum suppression over corner-format boxes.

    Only the `top_k` highest-scoring boxes are considered. The best remaining
    box is kept, every remaining box whose IoU with it exceeds `iou_thresh`
    is discarded, and the process repeats.

    Returns `(keep, count)`: `keep` is a long tensor with one slot per score
    whose first `count` entries are the indices of the kept boxes, in
    descending score order (unused slots stay 0). Empty input returns
    `(keep, 0)`, so callers can always unpack two values.
    """
    keep = torch.zeros(scores.size(0), dtype=torch.long, device=scores.device)
    if boxes.numel() == 0: return keep, 0
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    area = (x2 - x1) * (y2 - y1)
    _, idx = scores.sort(0)  # sort asc
    idx = idx[-top_k:]       # indices of k largest vals

    count = 0
    while idx.numel() > 0:
        i = idx[-1]  # index of current largest val
        keep[count] = i
        count += 1
        if idx.size(0) == 1: break
        idx = idx[:-1]  # remove kept element from view
        # intersection of the kept box with each remaining candidate
        w = (torch.min(x2[idx], x2[i]) - torch.max(x1[idx], x1[i])).clamp(min=0.0)
        h = (torch.min(y2[idx], y2[i]) - torch.max(y1[idx], y1[i])).clamp(min=0.0)
        inter = w*h
        # IoU = i / (area(a) + area(b) - i)
        union = (area[idx] - inter) + area[i]
        IoU = inter/union
        # keep only elements with an IoU <= iou_thresh
        idx = idx[IoU.le(iou_thresh)]
    return keep, count

In [None]:
### get_batch_preds ###
def acts_to_preds(abb, albl, ancs, anc_sz, iou_thresh, conf_thresh, device):
    """Turn one item's model acts into preds: abbs use get_pred_bbs, and albls use sigmoid().max().
       Used in ResultShower and mAP (and could possibly be used in loss fxn).

       Returns (pred_bbs, pred_lbls, confs) for the preds that survive nms, are not
       background and have confidence above `conf_thresh`. Boxes are mapped back
       to the [-1, 1] range via 2*bb-1."""
    # convert acts to preds
    pbb = get_pred_bbs(abb, ancs, anc_sz, device)
    conf, plbl = albl.sigmoid().max(1)
    # filter out preds w/ nms
    nms_idxs, nms_n = nms(pbb.data, conf, iou_thresh)
    nms_idxs = nms_idxs[:nms_n]
    pbb, plbl, conf = pbb[nms_idxs], plbl[nms_idxs], conf[nms_idxs]
    # filter out bg and low-conf preds; background is the last label column,
    # so take its index from the activations instead of hard-coding 20
    bg_idx = albl.shape[-1] - 1
    mask = (plbl != bg_idx) & (conf > conf_thresh)
    return 2*pbb[mask]-1, plbl[mask], conf[mask]

def get_batch_preds(abb, albl, ancs, anc_sz, iou_thresh=.5, conf_thresh=.3, device='cpu'):
    """Loop through a batch of activations and turn them into predictions.

    abb, albl: batched box / label activations (first dim is the batch)
    Returns three lists with one entry per image: pred boxes, pred labels, confidences."""
    # Tensor.to returns a new tensor rather than moving in place, so rebind
    ancs, anc_sz = ancs.to(device), anc_sz.to(device)
    pbbs, plbls, confs = [], [], []
    for item_abb, item_albl in zip(abb, albl):
        pbb, plbl, conf = acts_to_preds(item_abb,item_albl,ancs,anc_sz,iou_thresh,conf_thresh,device)
        pbbs.append(pbb)
        plbls.append(plbl)
        confs.append(conf)
    return pbbs, plbls, confs

In [None]:
### Metric ###
def format_inps(acts, batch, anchors, box_size, iou=.5, conf=.3, device='cuda'):
    """Format acts and targs for AP score calc. Input expects learner.acts & learner.batch,
       output format: (im_idx, pred_bbs, pred_cls, cls_conf) and (im_idx, bbs, cls).
       Ex: (46, tensor([0.1, 0.2, 0.9, 0.9]), tensor(3), tensor(0.78))

       `iou` and `conf` are the nms IoU threshold and the confidence threshold
       forwarded to get_batch_preds. Image indices are always plain Python ints."""
    preds = get_batch_preds(acts[0].data, acts[1].data, anchors, box_size, iou, conf, device)
    # one image index per pred; built in Python so an image without preds
    # cannot turn the indices into floats
    p_idxs = [i for i,o in enumerate(preds[0]) for _ in range(len(o))]
    batch_preds = list(zip(p_idxs, *[torch.cat(o) for o in preds]))

    # strip target padding (no rescale: gts stay in the same range as the preds)
    unp = [remove_padding(b,l,False) for b,l in zip(batch[1],batch[2])]
    gt_bbs, gt_lbls = [o[0] for o in unp], [o[1] for o in unp]
    gt_idxs = [i for i,o in enumerate(gt_bbs) for _ in range(len(o))]
    batch_gts = list(zip(gt_idxs, torch.cat(gt_bbs), torch.cat(gt_lbls)))
    return batch_preds, batch_gts

def flatten_list(l, ret_L=False):
    """Flatten arbitrarily nested python `list`s / fastai `L`s into one flat sequence.

    Returns a python list, or an `L` when ret_L is True."""
    flat = []
    def _walk(items):
        for o in items:
            if isinstance(o, (list, L)): _walk(o)   # descend into nested containers
            else: flat.append(o)                    # anything else is a leaf
    _walk(l)
    return L(flat) if ret_L else flat

def _get_tp_bbs(preds_tp):
    """Output list of tp bbs per im in batch. Not used to calculate mAP; only
       used to grab true positive bb preds for visualizing in ResultShower.

       Reads the global batch size `bs` to size the output list."""
    # Each row in preds_tps is a list of formatted preds (see format_inps) and a
    # tensor of 1s and 0s (signifying tps and fps) for a cls. No preds for a cls → empty lists.
    batch_idxs, pred_bbs, tpfps = [], [], []
    for preds,tp in preds_tp:
        batch_idxs.append([o[0] for o in preds])
        pred_bbs.append([o[1] for o in preds])
        tpfps.append(tp)
    # flatten across classes so the three sequences line up one entry per pred
    flat_idxs  = flatten_list(batch_idxs)
    flat_bbs   = flatten_list(pred_bbs)
    flat_tpfps = torch.cat(tpfps)

    # keep (im_idx, bb) for the preds flagged as true positives
    scored_preds = list(zip(flat_idxs, flat_bbs, flat_tpfps))
    true_bbs = [(int(o[0]), o[1]) for o in scored_preds if o[2]==True]

    # start every image with an all-zero placeholder box; a placeholder is
    # recognised below by its coordinates summing to 0
    true_preds = [torch.zeros(4).view(1,4) for i in range(0,bs)]
    for i,bb in true_bbs:
        if true_preds[i].sum()==0: true_preds[i] = bb.view(1,4)  # first tp replaces the placeholder
        else: true_preds[i] = torch.cat([true_preds[int(i)], bb.view(1,4)], dim=0)
    return true_preds

def agg_ten(ten,agg,ifempty=0):
    """Apply `agg` to `ten` and return the result as a python scalar; `ifempty` when `ten` has no rows."""
    if ten.shape[0]==0:
        return ifempty
    return agg(ten).item()

def ap_per_cls(batch_preds, batch_gts, iou_thresh=.5, ncls=None):
    """Calculate AP score per class for one batch.

    batch_preds: list of (im_idx, bb, cls, conf), as produced by format_inps
    batch_gts:   list of (im_idx, bb, cls), as produced by format_inps
    iou_thresh:  a pred counts as a true positive when its best-overlapping gt
                 (same image, same class) has IoU above this value and has not
                 already been claimed by a higher-confidence pred. This is now
                 an explicit argument instead of a global that had to be
                 defined before the call.
    ncls:        number of classes to score; defaults to `dls.ncls`

    Returns (avg_precs, preds_out, (ntps, nfps, fns)):
      avg_precs - one AP tensor per class
      preds_out - per class: [preds sorted by confidence desc, bool tp tensor]
      ntps/nfps/fns - per-class counts of true positives, false positives and
                      unmatched gt objects
    """
    if ncls is None: ncls = dls.ncls
    avg_precs,preds_out,tally,ntps,nfps = [],[],[],[],[]
    for c in range(ncls): # every class index in the vocab
        # store preds and gts for current cls
        preds = [b for b in batch_preds if b[2]==c]
        gts   = [b for b in batch_gts   if b[2]==c]

        # sort preds by conf desc
        preds.sort(key=lambda x: x[3], reverse=True)

        # make dict of im_idx:zeros(n_objs); a 1 marks a gt obj already matched
        n_objs = Counter([gt[0] for gt in gts])
        for k,v in n_objs.items(): n_objs[k] = torch.zeros(v)

        # init tp: a bool tensor over preds s.t. True indicates a pred is a tp
        tp = torch.zeros((len(preds))).bool()
        total_gt_objs = len(gts)

        for pred_idx, pred in enumerate(preds):
            gt_objs = [o for o in gts if o[0] == pred[0]]
            # reset per pred so a previous pred's best match can never leak in
            max_iou, idx_of_max = 0, None

            for idx, gt in enumerate(gt_objs):
                iou = get_iou(pred[1].view(1,4), gt[1].view(1,4))
                if iou > max_iou: max_iou, idx_of_max = iou, idx

            # update idx of gt_obj to indicate it's been used
            if idx_of_max is not None and max_iou > iou_thresh:
                if n_objs[pred[0]][idx_of_max]==0:
                    tp[pred_idx] = 1
                    n_objs[pred[0]][idx_of_max] = 1

        # calc avg_prec and store
        tps = torch.cumsum(tp, dim=0)                # 1. tp csum: [0,1,1,0,0] → [0,1,2,2,2]
        fps = torch.cumsum(~tp, dim=0)               # (basically same steps for fp)
        prec = torch.div(tps, (tps + fps + 1e-6))    # 2. divide each tps item by n_preds
        prec = torch.cat((torch.tensor([1.]), prec)) # 3. slap on a 1 at the beginning (float, to match prec)
        rec = tps / (total_gt_objs + 1e-6)
        rec = torch.cat((torch.tensor([0.]), rec))
        avg_prec = torch.trapz(prec, rec) # calc AP w/ trap rule
        avg_precs.append(avg_prec)        # store AP of this cls in accum

        # store data for reporting
        preds_out.append(([preds,tp]))
        tally.append(n_objs) #needs processing
        ntps.append(agg_ten(tps,max))
        nfps.append(agg_ten(fps,max))

    # gt objs never matched by a pred are the false negatives
    fn_bools = [[~dct[k].bool() for k in dct] for dct in tally]
    fns = [sum([agg_ten(t,sum) for t in tl]) for tl in fn_bools]
    return avg_precs, preds_out, (ntps,nfps,fns)

def get_ap_scores(dls, model, ancs, anc_sz, iou, conf, device='cpu'):
    """Average the per-class AP over every validation batch.

    dls, model:   DataLoaders and model to evaluate (both moved to `device`)
    ancs, anc_sz: anchors and anchor sizes, as returned by create_anchors
    iou, conf:    nms IoU threshold and confidence threshold passed to format_inps

    Returns (mAP, per_class_ap): the mean over classes, and the list of
    per-class APs averaged over batches.
    """
    dls.to(device); model.to(device)
    # tensors are not moved in place, so rebind the results
    ancs, anc_sz = ancs.to(device), anc_sz.to(device)
    model.eval()

    aps = []
    with torch.no_grad():   # evaluation only: no autograd graph needed
        for b in dls.valid:
            batch_preds, batch_gts = format_inps(model(b[0]), b, ancs, anc_sz, iou, conf, device)
            ap,_,_ = ap_per_cls(batch_preds, batch_gts)
            aps.append(ap)
    ap_scores = torch.stack([tensor(o) for o in aps]).sum(axis=0)/len(aps)
    ap_scores = ap_scores.numpy().tolist()
    return sum(ap_scores)/len(ap_scores), ap_scores

class MeanAveragePrecision(Metric):
    """fastai Metric reporting mAP, computed on batches of the final epoch only.

    `func` is called as func(pred, batch, dls.ncls, anchors, box_size) and must
    return a 3-tuple whose first item is the per-class AP list. It reads the
    globals `dls`, `anchors` and `box_size`.
    NOTE(review): `ap_per_cls` as defined in this notebook does not accept
    these arguments — confirm which function is meant to be passed in.
    """
    def __init__(self, func):
        self.func = func

    def reset(self):
        self.res = []

    def accumulate(self, learn):
        # skip every epoch but the last one
        if learn.epoch != learn.n_epoch-1: return
        batch = (*learn.xb,*learn.yb)
        cls_aps,_,_ = self.func(learn.pred,batch,dls.ncls,anchors,box_size)
        self.res.append(cls_aps)

    @property
    def value(self):
        # nothing accumulated yet (any epoch before the last) → report 0
        if self.res==[]: return 0
        per_batch = torch.stack([tensor(o) for o in self.res])
        ap_scores = per_batch.sum(axis=0)/len(self.res)   # per-class AP averaged over batches
        return sum(ap_scores)/len(ap_scores)

    @property
    def name(self):
        return "mAP"

In [None]:
# # get_fresh_mod
# def get_fresh_mod(ncls=dls.ncls, k=k, bias=-4., drop=.4, body='resnet34', device='cuda'):
#     test(body, ['resnet34','resnet50'], operator.in_)
#     arch = resnet34 if body=='resnet34' else resnet50
#     return CustMod(create_body(arch, pretrained=True), SSDHead(ncls, k, bias, drop, body))

In [None]:
# lil' helprs
def pad_strs(strs):
    """Right-pad every string with spaces to the length of the longest one."""
    width = max([len(s) for s in strs])
    return [s + ' '*(width-len(s)) for s in strs]
def print_lists(*lists):
    """Print equal-length lists side by side, one padded column per list, joined with ' | '."""
    for l in lists: test_eq(len(lists[0]), len(l))            # all lists must share one length
    cols = [pad_strs(list(map(str, l))) for l in lists]       # stringify and pad each column
    for row in zip(*cols): print(' | '.join(row))
def batch_info(l):
    """Print idx, shape, type for items in l (a list of tensors)"""
    types  = apply(type, l)
    shapes = apply(lambda t:t.shape, l)
    idxs   = [*range(len(l))]
    print_lists(idxs, shapes, types)

# Train

## Train

In [None]:
# device='cuda'
# subdivs = [4, 2, 1]
# zooms   = [0.75, 1.0, 1.3]
# ratios  = [(1.,1.), (1.,.5), (.5,1)]
# k = len(zooms) * len(ratios)

# anchors, box_size = create_anchors(subdivs, zooms, ratios, device)
# anchor_boxes = hw2pp(anchors[:,:2], anchors[:,2:])
# loss_f = FocalLoss(dls.ncls, device)

# body = create_body(resnet34, pretrained=True)
# head = SSDHead(k, -4., 0.4, 'resnet34')
# mod = CustMod(body, head)

In [None]:
# learner = Learner(dls, mod, loss_func=ssd_loss).to_fp16()
# learner.freeze()
# lr_min, lr_steep = learner.lr_find()

In [None]:
# lr = (lr_min+lr_steep)/2; print("lr:",round(lr,4))

In [None]:
# learner.fit_one_cycle(7, lr=1e-3)

In [None]:
# learner.save('s1')
# learner.export('models/20201215_pascal2007_rory.pkl')

In [None]:
# learner.export('models/20201210_pascal2007.pkl')          # trained for 10 epochs
# learner.export('models/20201210_pascal2007_bad.pkl')      # trained for 5 epochs
# learner.export('models/20201211_pascal2007_focalfix.pkl') # corrected weight in focal loss
# learner.export('models/20201215_pascal2007_rory.pkl')     # corrected bugs in anchor code

## View Results

In [None]:
### Viz Results ###
def show_bb(im, bb=None, lbl=[''], title=None, color='white',
            ctx=None, sz=im_sz, figsize=5):
    """Draw image `im` with boxes `bb` (expected in the [-1, 1] range) and labels `lbl`.

    bb:    box tensor; None or a tensor whose last dim is empty draws a single
           all-zero box instead
    lbl:   one label per box; the default [''] means "no labels"
    ctx:   existing axes to draw on; a new figure is created when omitted
    sz:    image side length used to scale the boxes to pixels
    Returns the ctx that was drawn on.
    """
    # process empties and nones (test for None first: None has no .shape)
    if bb is None or bb.shape[-1]==0: bb  = tensor([[0.,0,0,0]])
    if lbl==['']:                     lbl = ['']*bb.shape[0]

    # process tensors to take advantage of fastai show methods
    bbox = TensorBBox((bb+1)*sz//2)
    labeledbbox = LabeledBBox(bbox,lbl)

    if ctx:     show_image(im, figsize=[figsize,figsize], title=title, ctx=ctx)
    else: ctx = show_image(im, figsize=[figsize,figsize], title=title)

    labeledbbox.show(ctx=ctx)       # first, draw white lbl bbs...
    bbox.show(ctx=ctx, color=color) # ... then overlay color bbs.
    return ctx

def get_im_tpfpfns(res):
    """QnD function to output (tps, fps, fns) per im in a batch given a ResultShower.

    Reads res.tp_bbs (tp boxes per im), res.preds (preds[0] holds the pred
    boxes per im) and res.nobj (per class: {im_idx: tensor, one entry per gt obj})."""
    # TPs: number of tp boxes per im; an all-zero entry is the "no tp" placeholder
    tps = [0 if bbs.sum()==0 else bbs.shape[0] for bbs in res.tp_bbs]
    # FPs: whatever preds are left once the TPs are taken out
    fps = [boxes.shape[0]-tp for boxes,tp in zip(res.preds[0], tps)]
    # FNs (relies on res.nobj which is a hack): for each class, res.nobj maps
    #  im_idx to a tensor with one entry per gt obj of that class in that im;
    #  a truthy entry means the obj was matched (TP), a falsy one that it was
    #  missed (FN). Collect the "missed" flags per im across classes, then count.
    missed = defaultdict(list)
    for cls in res.nobj:
        for im_idx,tp_tensor in cls.items():
            missed[im_idx].extend(~tp_tensor.bool())
    idx_fns = sorted((k,sum(v).item()) for k,v in missed.items())
    fns = [n for _,n in idx_fns]
    return tps, fps, fns

class ResultShower():
    """Runs the model on one validation batch and draws gt boxes next to preds.

    Calling the instance, or show_next(n), displays the next n images of the
    batch; indexing it (`res[i]`) displays image i.

    NOTE(review): the ap_per_cls call in __init__ passes acts/batch plus six
    more positional arguments, while ap_per_cls in this notebook expects the
    already formatted preds/gts (see format_inps) as its first two arguments.
    Its 2nd and 3rd return values are also used here as per-image tp boxes and
    per-class dicts, which is not what that function returns. This looks
    written against an earlier ap_per_cls — confirm before use.
    """
    def __init__(self, dls, lrn, ancs, anc_sz, iou, conf):
        # store init's args (model and anchors are moved to the cpu for the pass below)
        self.dls    = dls
        self.mod    = lrn.model.eval().cpu()
        self.ancs   = ancs.cpu()
        self.anc_sz = anc_sz.cpu()
        self.iou    = iou
        self.conf   = conf
        # compute attrs: first validation batch, its activations and its preds
        self.batch    = next(iter(self.dls.cpu().valid))
        self.acts     = [a.data for a in self.mod(self.batch[0])]
        self.preds    = get_batch_preds(*self.acts,self.ancs,self.anc_sz,self.iou,self.conf,'cpu')
        self.dec_ims  = self.dls.decode(self.batch)[0]
        self.bs       = self.dls.bs
        self.voc      = self.dls.vocab
        self.im_sz    = self.batch[0].shape[-1]
        self.last_res = 0       # how many images show_next has displayed so far
        self.fig_sz   = [8,8]
        # compute metrics (see the NOTE(review) in the class docstring about this call)
        aps,tp_bbs,nobj = ap_per_cls(self.acts,self.batch,len(self.voc),
                                     self.ancs,self.anc_sz,self.iou,self.conf,'cpu')
        self.ap_scores  = [o.item() for o in aps]
        self.tp_bbs     = tp_bbs
        self.nobj       = nobj
        self.im_tpfpfns = get_im_tpfpfns(self)
        # clean up: move the DataLoaders and model back to the gpu
        self.dls.cuda(); self.mod.cuda()
        
    def __call__(self, *args, **kwargs):
        return self.show_next(*args, **kwargs)
    
    def __getitem__(self, i):
        # get everything to draw
        ims             = self.dec_ims
        _,bbs,lbls      = self.batch
        pbbs,plbls,conf = self.preds
        tbbs            = self.tp_bbs
        tps,fps,fns     = self.im_tpfpfns
        # titles
        # NOTE(review): nobj counts labels > 0, but padded label slots are 0 and the
        # vocab is built with add_na=False, so class 0 objects are presumably not
        # counted — confirm.
        t_gt = f"Idx {i} (nobj={(lbls[i] > 0).sum()})"
        t_p  = f"TPs:{tps[i]} | FPs:{fps[i]} | FNs:{fns[i]}"
        # two ctx: gts and all preds. lime bbs drawn over TP preds.
        ctx = get_grid(2, figsize=self.fig_sz)
        show_bb(ims[i],bbs[i], self.voc[lbls[i]], t_gt,'white', ctx[0], self.im_sz)
        show_bb(ims[i],pbbs[i],self.voc[plbls[i]],t_p, 'magenta',ctx[1],self.im_sz)
        show_bb(ims[i],tbbs[i],color='lime',ctx=ctx[1])
                 
    def show_next(self, n=1):
        # display the next n images, wrapping around at the end of the batch
        for i in range(n): self[(i + self.last_res)%self.bs]
        self.last_res += n

In [None]:
# res = ResultShower(dls, learner, anchors, box_size, .5, .30)
# res(8)

In [None]:
# def test_confs(confs):
#     for conf in confs:
#         final,_ = get_ap_scores(dls, learner.model, anchors.cpu(), box_size.cpu(), .5, conf, 'cpu')
#         print(f"mAP for conf {conf:.2f}: {round(final,3)}")
# test_confs([round(.05*x,2) for x in range(1,11)])

In [None]:
# # show ap per cls, desc
# sum_scores, scores = get_ap_scores(dls,learner.model,anchors.cpu(),box_size.cpu(),.5,.3,'cpu')
# print("mAP:",round(sum_scores,3))
# pd.DataFrame({'Class':dls.v, 'AP':[round(o,3) for o in scores]}).sort_values('AP',ascending=False)

# Stepping Through a Batch

## RM

### Init

In [None]:
# Setup
# Everything in this section runs on the cpu.
device='cpu'

# Anchor configuration: grids of 4x4, 2x2 and 1x1 cells, with one anchor per
# (zoom, ratio) pair in every cell → k anchors per cell.
subdivs = [4, 2, 1]
zooms   = [0.75, 1.0, 1.3]
ratios  = [(1.,1.), (1.,.5), (.5,1)]
k = len(zooms) * len(ratios)
# Globals read by ssd_item_loss: anchors (centre/size form), their cell sizes,
# the same anchors in corner form, and the label loss.
anchors, box_size = create_anchors(subdivs, zooms, ratios, device)
anchor_boxes = hw2pp(anchors[:,:2], anchors[:,2:])
loss_f = FocalLoss(dls.ncls, device)

### Get batch, acts, item

In [None]:
# # Get batch, acts
# body = create_body(resnet34, pretrained=True)
# head = SSDHead(k, -4., 0.4, 'resnet34')
# mod = CustMod(body, head)
# mod.eval()

In [None]:
# Load the exported learner from a path relative to the working directory.
# NOTE(review): presumably the export is a pickle, so the classes and functions
# it references (CustMod, SSDHead, ssd_loss, the item getters, ...) must already
# be defined when this runs, and the file must come from a trusted source — confirm.
learner = load_learner('models/20201215_pascal2007_rory.pkl')

In [None]:
mod = learner.model.cpu().eval()

In [None]:
# Grab the first validation batch on the cpu and run the model on its images.
# acts is [bb acts, lbl acts]; batch is (images, gt boxes, gt labels).
batch = next(iter(dls.cpu().valid))
acts = mod(batch[0])
bbs, lbls = batch[1], batch[2]


# Get acts and targs for a single im
b_idx = 0
abb,albl,bb,lbl = list(zip(*acts,bbs,lbls))[b_idx]

In [None]:
abb.shape,albl.shape

(torch.Size([189, 4]), torch.Size([189, 21]))

In [None]:
bb,lbl

(tensor([[-0.9177, -0.7167,  0.5527,  0.9417],
         [-0.5219, -0.7125,  0.2134,  0.1583],
         [ 0.0437,  0.1958,  0.5321,  0.7083],
         [-0.0026,  0.1167,  0.3728,  0.4792],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]),
 tensor([12, 14, 12, 14,  0,  0,  0,  0,  0,  0,  0,  0]))

### mAP

In [None]:
iou_thresh=.5
conf_thresh=.3

###### format inputs

In [None]:
# get batch of preds (pred_bbs, pred_lbls, pred_confs)
preds = get_batch_preds(acts[0].data, acts[1].data, anchors, box_size, iou_thresh, conf_thresh, device)

In [None]:
# out of interest: how many preds?
n_preds = torch.cat(preds[0]).shape[0]
n_preds

135

In [None]:
unp = [remove_padding(b,l,False) for b,l in zip(batch[1],batch[2])]
unp = [o[0] for o in unp], [o[1] for o in unp]
unp_flat = torch.cat(unp[0]), torch.cat(unp[1])

In [None]:
# get list of im idxs containing each pred (used in mAP algorithm)
p_idxs  = torch.cat([torch.tensor([i]*len(o)) for i,o in enumerate(preds[0])]).numpy().tolist()
gt_idxs = torch.cat([torch.tensor([i]*len(o)) for i,o in enumerate(unp[0])]).numpy().tolist()

In [None]:
batch_preds = list(zip(p_idxs, *[torch.cat(o) for o in preds]))
batch_preds

[(0.0,
  tensor([-0.3964, -0.6585,  0.4472,  0.7216]),
  tensor(14),
  tensor(0.4489)),
 (0.0,
  tensor([-0.2058, -0.7428,  0.3694,  0.2123]),
  tensor(14),
  tensor(0.4188)),
 (1.0,
  tensor([-0.8118,  0.0033, -0.5204,  0.3782]),
  tensor(14),
  tensor(0.3651)),
 (2.0,
  tensor([-0.9212, -1.0978,  0.9875,  0.9936]),
  tensor(7),
  tensor(0.6503)),
 (2.0,
  tensor([-0.8433, -0.6127,  1.1810,  0.3928]),
  tensor(7),
  tensor(0.3050)),
 (3.0,
  tensor([-0.1082,  0.2558,  0.5228,  0.9645]),
  tensor(8),
  tensor(0.4751)),
 (3.0,
  tensor([-0.4371,  0.2209,  0.1259,  1.0206]),
  tensor(8),
  tensor(0.4260)),
 (3.0,
  tensor([-0.0981, -0.0264,  0.4014,  0.5821]),
  tensor(8),
  tensor(0.3560)),
 (3.0,
  tensor([-0.3673,  0.0883, -0.0183,  0.6392]),
  tensor(8),
  tensor(0.3467)),
 (3.0,
  tensor([-0.6955,  0.4008, -0.2635,  0.9906]),
  tensor(8),
  tensor(0.3260)),
 (3.0,
  tensor([-0.9313,  0.1074, -0.3092,  0.9473]),
  tensor(8),
  tensor(0.3188)),
 (3.0, tensor([0.1617, 0.1668, 0.8498, 0

In [None]:
batch_gts = list(zip(gt_idxs, *unp_flat))
batch_gts

[(0, tensor([-0.9177, -0.7167,  0.5527,  0.9417]), tensor(12)),
 (0, tensor([-0.5219, -0.7125,  0.2134,  0.1583]), tensor(14)),
 (0, tensor([0.0437, 0.1958, 0.5321, 0.7083]), tensor(12)),
 (0, tensor([-0.0026,  0.1167,  0.3728,  0.4792]), tensor(14)),
 (1, tensor([-0.9000, -0.0190, -0.5960,  0.3688]), tensor(11)),
 (1, tensor([-0.5640,  0.0570, -0.3640,  0.3840]), tensor(11)),
 (1, tensor([ 0.0040, -0.0342,  0.2800,  0.3004]), tensor(11)),
 (1, tensor([-1.0000, -0.3840, -0.8680,  0.2928]), tensor(14)),
 (2, tensor([-0.4640, -0.9893,  0.9880,  1.0000]), tensor(7)),
 (3, tensor([0.4720, 0.0301, 0.7440, 0.6024]), tensor(8)),
 (3, tensor([-0.2480,  0.2831,  0.1080,  1.0000]), tensor(8)),
 (3, tensor([-0.5960,  0.1446, -0.2760,  0.9337]), tensor(8)),
 (3, tensor([-0.4960,  0.1988,  0.4480,  0.9277]), tensor(10)),
 (4, tensor([-0.8040, -0.8187,  1.0000,  0.9147]), tensor(6)),
 (5, tensor([ 0.0259, -0.4200,  0.7752,  0.7400]), tensor(1)),
 (5, tensor([-0.7406, -0.3200,  0.0432,  0.6320]), ten

###### get AP per cls (one batch)

In [None]:
# Format the batch's acts/targets, score every class, then average the
# per-class APs into the mAP of this single batch.
batch_preds, batch_gts = format_inps(acts, batch, anchors, box_size, iou=.5, conf=.3, device='cpu')
aps, pred_bbs, tpfpfns = ap_per_cls(batch_preds, batch_gts)
sum(aps) / len(aps)

tensor(0.1312)

In [None]:
# tpfpfns contains three lists: tps, fps, and fns, each with one count per class (20 entries in the output below)
tpfpfns

([2, 2, 1, 0, 0, 1, 6, 4, 3, 0, 1, 4, 0, 1, 13, 1, 0, 0, 0, 0],
 [0, 2, 1, 1, 1, 0, 16, 3, 14, 0, 0, 8, 1, 1, 41, 3, 0, 0, 4, 0],
 [4, 3, 4, 3, 3, 3, 10, 3, 17, 0, 4, 17, 6, 6, 41, 1, 0, 5, 4, 3])

In [None]:
test_eq(n_preds, sum(tpfpfns[0])+sum(tpfpfns[1])) # total tps + fps must equal n_preds, the number of preds counted earlier

In [None]:
####################
### stopped here ### :3
####################

### ssd item loss

In [None]:
# Item Loss
# assignments: alias abb/albl/bb/lbl (defined in an earlier cell) to the names used in the cells below
act_bbs, act_lbls, bbs, lbls = abb,albl,bb,lbl

In [None]:
# rebinds bbs/lbls to their unpadded versions; the output below shows 4 gt bbs and 4 lbls left
bbs,lbls = remove_padding(bbs,lbls)                      # remove gt padding inserted during training
bbs,lbls

(tensor([[0.0411, 0.1417, 0.7763, 0.9708],
         [0.2391, 0.1438, 0.6067, 0.5792],
         [0.5219, 0.5979, 0.7661, 0.8542],
         [0.4987, 0.5583, 0.6864, 0.7396]]),
 tensor([12, 14, 12, 14]))

In [None]:
pred_bbs = get_pred_bbs(act_bbs, anchors, box_size, device) # make pred bbs from acts and ancs (189 of them in the output below)
pred_bbs.shape, pred_bbs

(torch.Size([189, 4]),
 tensor([[ 0.0388,  0.0548,  0.1889,  0.1897],
         [ 0.0350,  0.0752,  0.1641,  0.1768],
         [ 0.0933,  0.0817,  0.1777,  0.2160],
         [ 0.0510,  0.0223,  0.2184,  0.1888],
         [ 0.0154,  0.0164,  0.1592,  0.1482],
         [ 0.0404,  0.0571,  0.1333,  0.2073],
         [ 0.0240,  0.0167,  0.2563,  0.2961],
         [ 0.0206,  0.0685,  0.2763,  0.1975],
         [ 0.0332,  0.0305,  0.1772,  0.2297],
         [ 0.2547,  0.0517,  0.4217,  0.2185],
         [ 0.2628,  0.0641,  0.3889,  0.1662],
         [ 0.3562,  0.0807,  0.4213,  0.2188],
         [ 0.3027,  0.0767,  0.4978,  0.3198],
         [ 0.2346,  0.0788,  0.3736,  0.1637],
         [ 0.2768,  0.0581,  0.4048,  0.2178],
         [ 0.3098,  0.0982,  0.5987,  0.3073],
         [ 0.3044,  0.0633,  0.4876,  0.1915],
         [ 0.3120,  0.0664,  0.4563,  0.2837],
         [ 0.5207,  0.0499,  0.6875,  0.2086],
         [ 0.5446,  0.0614,  0.6481,  0.1699],
         [ 0.5673,  0.0866,  0.6437, 

In [None]:
# map gt to preds
# result has one row per gt bb and one column per anchor box: (4, 189) in the output below
iou_gt_grid = get_iou(bbs.data, anchor_boxes.data)       # get iou(gt_bbs,anc_bbs); used to map gt → ancs
iou_gt_grid.shape, iou_gt_grid

(torch.Size([4, 189]),
 tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 1.6586e-04, 0.0000e+00, 0.0000e+00,
          4.6010e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.000

In [None]:
# iou_gt_preds: one score per pred; mapped_gt_idx: the gt row each pred is matched to
# NOTE(review): the output shows values of 1.99, which is not a valid iou — presumably map_to_gt
# writes a sentinel for forced gt matches so they always pass the threshold; confirm in map_to_gt
iou_gt_preds, mapped_gt_idx = map_to_gt(iou_gt_grid)     # assign each pred an index of a gt object
iou_gt_preds, mapped_gt_idx

(tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9900e+00, 0.0000e+00, 0.0000e+00,
         1.9900e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e

In [None]:
# index lbls with the per-pred gt index → one gt lbl per pred (189 in the output below)
mapped_lbls = lbls[mapped_gt_idx]                        # project gt lbls into pred space
mapped_lbls.shape, mapped_lbls

(torch.Size([189]),
 tensor([12, 12, 12, 14, 12, 12, 14, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12]))

In [None]:
# remove low-iou bb preds & set mapped lbl to bg
# high_iou is a boolean mask over preds: True where the mapped score exceeds the 0.4 threshold
high_iou = iou_gt_preds > 0.4                            # only include bb preds that overlap w/a gt obj and
high_iou

tensor([False, False, False,  True, False, False,  True, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, 

In [None]:
# positions of the True entries in high_iou; used below to pick the preds that enter the bb loss
incl = torch.nonzero(high_iou)[:,0]                      #  are not predicting background
incl

tensor([3, 6])

In [None]:
# in-place write: every pred outside the high_iou mask gets the index dls.ncls as its lbl
mapped_lbls[~high_iou] = dls.ncls                        # assign gt class of bg to preds w/ low max gt iou
mapped_lbls

tensor([20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 14, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 14, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 20, 20, 20])

In [None]:
# index bbs with the per-pred gt index → one gt bb per pred, (189, 4) in the output below
mapped_bbs  = bbs[mapped_gt_idx]                         # project gt bbs into pred space
mapped_bbs.shape, mapped_bbs

(torch.Size([189, 4]),
 tensor([[-4.0970e-03, -3.1994e-03,  2.4674e-03,  4.2039e-03],
         [-4.0970e-03, -3.1994e-03,  2.4674e-03,  4.2039e-03],
         [-4.0970e-03, -3.1994e-03,  2.4674e-03,  4.2039e-03],
         [-1.1476e-05,  5.2083e-04,  1.6641e-03,  2.1391e-03],
         [-4.0970e-03, -3.1994e-03,  2.4674e-03,  4.2039e-03],
         [-4.0970e-03, -3.1994e-03,  2.4674e-03,  4.2039e-03],
         [-2.3297e-03, -3.1808e-03,  9.5253e-04,  7.0685e-04],
         [-4.0970e-03, -3.1994e-03,  2.4674e-03,  4.2039e-03],
         [-4.0970e-03, -3.1994e-03,  2.4674e-03,  4.2039e-03],
         [-4.0970e-03, -3.1994e-03,  2.4674e-03,  4.2039e-03],
         [-4.0970e-03, -3.1994e-03,  2.4674e-03,  4.2039e-03],
         [-4.0970e-03, -3.1994e-03,  2.4674e-03,  4.2039e-03],
         [-4.0970e-03, -3.1994e-03,  2.4674e-03,  4.2039e-03],
         [-4.0970e-03, -3.1994e-03,  2.4674e-03,  4.2039e-03],
         [-4.0970e-03, -3.1994e-03,  2.4674e-03,  4.2039e-03],
         [-4.0970e-03, -3.1994e-

In [None]:
# compute loss
# L1 distance between the included pred bbs and their mapped gt bbs; F.l1_loss averages over all elements by default
# NOTE(review): if incl is empty the mean is taken over zero elements and yields nan — confirm the caller guards this
bb_res  = F.l1_loss(pred_bbs[incl], mapped_bbs[incl])
bb_res

tensor(0.1364, grad_fn=<L1LossBackward>)

In [None]:
# classification loss on the lbl acts vs. the mapped lbls (bg already assigned); both moved to cpu first
# loss_f is defined in an earlier cell; the next section steps through what is presumably the same computation
lbl_res = loss_f(act_lbls.cpu(), mapped_lbls.cpu())
lbl_res

tensor(9.1986, grad_fn=<DivBackward0>)

### lbl loss

In [None]:
# lbl_loss
# assignments: cpu copies of the lbl acts and the mapped lbls
# NOTE(review): this rebinds `acts`, which the mAP cell above passes to format_inps; re-running
# that cell after this one would receive the lbl acts instead — consider a distinct name
acts, targs = act_lbls.cpu(), mapped_lbls.cpu()

In [None]:
# one-hot encode targs over ncls+1 classes (the extra index is the bg class dls.ncls) → (189, 21) in the output below
t = one_hot_embedding(targs, dls.ncls+1, device)
t.shape, t

(torch.Size([189, 21]),
 tensor([[0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 1.]]))

In [None]:
# drop the last column (index dls.ncls, the bg class), so bg targets become all-zero rows
# NOTE: t is rebound from itself, so this cell is not idempotent — running it twice drops a second column
t = tensor(t[:,:-1].contiguous())
t.shape, t

(torch.Size([189, 20]),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]))

In [None]:
# drop the last column of the lbl acts as well, so a lines up with t: (189, 20)
a = acts[:,:-1]
a.shape, a

(torch.Size([189, 20]),
 tensor([[-3.9891, -4.1691, -4.0815,  ..., -4.1617, -3.9448, -4.2141],
         [-3.9118, -3.8588, -4.1509,  ..., -3.9471, -4.1068, -4.4997],
         [-3.9061, -4.0771, -4.0870,  ..., -4.1831, -3.9755, -3.9879],
         ...,
         [-3.9870, -4.0060, -3.9948,  ..., -3.9976, -4.0073, -3.9924],
         [-3.9988, -3.9908, -4.0037,  ..., -4.0180, -4.0328, -3.9996],
         [-4.0175, -4.0025, -4.0115,  ..., -3.9883, -4.0004, -3.9869]],
        grad_fn=<SliceBackward>))

In [None]:
def get_weight(a, t, alpha=0.25, gamma=2.0):
    """Per-element focal-loss style weight for activations `a` against targets `t`.

    a:     raw activations (logits); the sigmoid is applied here.
    t:     targets broadcastable with `a`, 1 where the class is present and 0 elsewhere.
    alpha: weight on positive targets; negatives are weighted by 1-alpha (default 0.25, val from paper).
    gamma: exponent applied to (1-pt); larger values shrink well-classified entries more (default 2.0, val from paper).

    Returns `w * (1-pt)**gamma` with the broadcast shape of `a` and `t`.
    The result is not detached here; callers that want constant weights must detach it.
    """
    p = a.sigmoid()
    pt = p*t + (1-p)*(1-t)         # prob the model assigns to the correct outcome of each entry
    w = alpha*t + (1-alpha)*(1-t)  # alpha for positives, 1-alpha for negatives
    return w * (1-pt).pow(gamma)

In [None]:
# detach so the weights are treated as constants (no gradient flows back through them)
w = get_weight(a,t).detach()
w.shape, w

(torch.Size([189, 20]),
 tensor([[2.4787e-04, 1.7398e-04, 2.0673e-04,  ..., 1.7655e-04, 2.7037e-04,
          1.5923e-04],
         [2.8849e-04, 3.2002e-04, 1.8033e-04,  ..., 2.6919e-04, 1.9670e-04,
          9.0596e-05],
         [2.9171e-04, 2.0851e-04, 2.0448e-04,  ..., 1.6924e-04, 2.5457e-04,
          2.4844e-04],
         ...,
         [2.4890e-04, 2.3979e-04, 2.4511e-04,  ..., 2.4376e-04, 2.3919e-04,
          2.4626e-04],
         [2.4322e-04, 2.4704e-04, 2.4087e-04,  ..., 2.3420e-04, 2.2750e-04,
          2.4284e-04],
         [2.3443e-04, 2.4144e-04, 2.3722e-04,  ..., 2.4827e-04, 2.4243e-04,
          2.4896e-04]]))