<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Data" data-toc-modified-id="Data-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Data</a></span></li><li><span><a href="#Arch" data-toc-modified-id="Arch-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Arch</a></span></li><li><span><a href="#Loss" data-toc-modified-id="Loss-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Loss</a></span><ul class="toc-item"><li><span><a href="#JH-code" data-toc-modified-id="JH-code-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>JH code</a></span></li></ul></li><li><span><a href="#Train" data-toc-modified-id="Train-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Train</a></span></li><li><span><a href="#Stepping-Through-a-Batch" data-toc-modified-id="Stepping-Through-a-Batch-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Stepping Through a Batch</a></span></li></ul></div>

**~ My Code ~**

# Data

In [None]:
### Imports & Paths ###
from fastai.vision.all import *
import pandas as pd


### Params ###
im_sz   = 224   # side length passed to Resize below
bs      = 64    # batch size passed to db.dataloaders
val_pct = .2    # fraction handed to RandomSplitter for the validation set
sub_pct = 1     # fraction of the annotated filenames to keep (1 keeps all of them)
path = untar_data(URLs.PASCAL_2007)
annos_path = path/'train.json'
ims_path = path/'train'


### Items ###
# fn2anno maps an image filename to its annotation; the getters below read
# element [0] of the annotation as the boxes and element [1] as the labels.
fns, annos = get_annotations(annos_path)
fn2anno = {f:a for f,a in zip(fns,annos)}
def get_im(f):   return ims_path/f
def get_bbox(f): return fn2anno[f][0]
def get_lbl(f):  return fn2anno[f][1]


### DataLoaders ###
itfms = Resize(im_sz, method='squish')
btfms = setup_aug_tfms([Rotate(), Brightness(), Contrast(), Flip(),
                       Normalize.from_stats(*imagenet_stats)])
db = DataBlock(
    blocks=[ImageBlock, BBoxBlock, BBoxLblBlock(add_na=False)],
    get_x=get_im, get_y=[get_bbox, get_lbl], n_inp=1,
    splitter=RandomSplitter(val_pct),
    item_tfms=itfms, batch_tfms=btfms)
# NOTE(review): no seed is passed to shuffle() or RandomSplitter here, so unless one is
# set elsewhere the subset order and the train/valid split change from run to run.
subset = L(fns).shuffle()[0:int(len(fns)*sub_pct)]
dls = db.dataloaders(subset, bs=bs)
# Ad-hoc shortcuts attached to the DataLoaders object; later cells read dls.ncls
# as the number of classes and also use it as the index of the background class.
dls.v = dls.vocab
dls.ncls = len(dls.vocab)

In [None]:
### Inspection ###
# Print the vocab, the split sizes, and the dtype/shape of each tensor in one batch.
print("Vocab:", dls.v)
print("Size of train data:",len(dls.train.items))
print("Size of valid data:",len(dls.valid.items))
for i,t in enumerate(dls.one_batch()):
    print(f"batch[{i}]:",'\t',t.dtype,'\t',t.shape)

Vocab: (#20) ['aeroplane','bicycle','bird','boat','bottle','bus','car','cat','chair','cow'...]
Size of train data: 2001
Size of valid data: 500
batch[0]: 	 torch.float32 	 torch.Size([64, 3, 224, 224])
batch[1]: 	 torch.float32 	 torch.Size([64, 13, 4])
batch[2]: 	 torch.int64 	 torch.Size([64, 13])


Interpretation of tensor shapes (for the batch printed above):
- torch.Size([64, 3, 224, 224]): bs, channels (rgb), im_sz, im_sz
- torch.Size([64, 13, 4]): bs, max objs for a single im in batch, bb coords
- torch.Size([64, 13]): bs, max objs for a single im in batch

# Arch

In [None]:
### Anchors ###
def create_anchors(subdivs, zooms, ratios, device='cuda'):
    """Build the default (anchor) boxes for every cell of every grid.

    Args:
        subdivs: grid sizes; each `sd` contributes an `sd x sd` grid of cells.
        zooms:   scale factors; each is combined with every entry of `ratios`.
        ratios:  pairs of multipliers giving the two side lengths of a box
                 relative to the cell size.
        device:  device the returned tensors are moved to.

    Returns:
        anchors:   float32 tensor of shape (n_anchors, 4). Columns 0-1 are the
                   cell centre, columns 2-3 the box side lengths along the same
                   two axes, all as fractions of the image side.
        box_sizes: float32 tensor of shape (n_anchors, 1) holding the side
                   length `1/sd` of the grid cell each anchor belongs to.

    Both outputs are float32 so that arithmetic mixing them with float32
    activations is not promoted to float64.
    """
    # every zoom/ratio combination applied to each grid cell
    perms = [(z*r1, z*r2) for z in zooms for (r1, r2) in ratios]
    k = len(perms)
    xs, ys, hws, cell_sizes = [], [], [], []
    for sd in subdivs:
        o = 1/(sd*2)                          # offset of the first cell centre from the edge
        centres = np.linspace(o, 1-o, sd)
        xs.append(np.tile(centres, sd))       # x varies fastest within a grid
        ys.append(np.repeat(centres, sd))
        hws.append(np.array([[a/sd, b/sd] for _ in range(sd*sd) for a, b in perms]))
        cell_sizes.append(np.full(sd*sd*k, 1/sd))
    # each centre is repeated k times so it lines up with its k box shapes
    ctrs = np.repeat(np.stack([np.concatenate(xs), np.concatenate(ys)], axis=1), k, axis=0)
    anchors = torch.tensor(np.concatenate([ctrs, np.concatenate(hws)], axis=1), dtype=torch.float32)
    box_sizes = torch.tensor(np.concatenate(cell_sizes), dtype=torch.float32).unsqueeze(1)
    return anchors.to(device), box_sizes.to(device)

def create_anchor_boxes(ctr, hw):
    """Turn centres + side lengths into corner boxes, then map x -> 2*x - 1.

    The result concatenates `ctr - hw/2` and `ctr + hw/2` along dim 1.
    """
    half = hw/2
    near_corner = tensor(ctr - half)
    far_corner  = tensor(ctr + half)
    corners = torch.cat((near_corner, far_corner), axis=1)
    return 2*corners - 1

# Anchor configuration for this cell; the anchors are built on `device`.
# NOTE(review): `k`, `anchors`, `box_size` and `anchor_boxes` are assigned again by later cells.
device = 'cpu'
subdivs = [4, 2, 1]
zooms   = [0.75, 1.0, 1.3]
ratios  = [(1.,1.), (1.,.5), (.5,1)]
k = len(zooms) * len(ratios)   # box shapes per grid cell
anchors, box_size = create_anchors(subdivs, zooms, ratios, device)
anchor_boxes = create_anchor_boxes(anchors[:,:2], anchors[:,2:])

In [None]:
# Sanity check: total number of anchors and one sample row.
anchors.shape, anchors[100]

(torch.Size([189, 4]), tensor([0.8750, 0.6250, 0.1875, 0.0938]))

In [None]:
### Architecture ###
def flatten_conv(x,k):
    """Move channels last and reshape a 4-d conv output to (bs, -1, n_channels//k)."""
    n_items, n_chan, _, _ = x.size()
    chan_last = x.permute(0,2,3,1).contiguous()
    return chan_last.view(n_items, -1, n_chan//k)

class StdConv(Module):
    """3x3 conv block applied as conv2d → relu → batchnorm → dropout."""
    def __init__(self, nin, nout, stride=2, drop=0.1):
        super().__init__()
        self.conv = nn.Conv2d(nin, nout, kernel_size=3, stride=stride, padding=1)
        self.bn   = nn.BatchNorm2d(nout)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        activated = F.relu(self.conv(x))
        normed = self.bn(activated)
        return self.drop(normed)

class OutConv(Module):
    """Output layer with two 3x3 convs: one for bb acts (4*k channels) and one
       for lbl acts ((ncls+1)*k channels). The lbl conv bias is initialised to `bias`."""
    def __init__(self, ncls, k, nin, bias):
        super().__init__()
        self.k = k
        self.bb_acts  = nn.Conv2d(nin, 4*k,        kernel_size=3, padding=1)
        self.lbl_acts = nn.Conv2d(nin, (ncls+1)*k, kernel_size=3, padding=1)
        # every lbl bias element starts at `bias`
        self.lbl_acts.bias.data.fill_(bias)

    def forward(self, x):
        bb_out  = flatten_conv(self.bb_acts(x),  self.k)
        lbl_out = flatten_conv(self.lbl_acts(x), self.k)
        return [bb_out, lbl_out]

class SSDHead(Module):
    """SSD head: three stride-2 StdConv stages, each followed by an OutConv.

       `body` must be 'resnet34' or 'resnet50'. For 'resnet34' the input is
       passed through relu first; for 'resnet50' it is first reduced from 2048
       to 512 channels by `re_sz`. `re_sz` is constructed in both cases.
       forward returns [bb acts, lbl acts], each concatenated over the three
       stages along dim 1."""
    def __init__(self, ncls, k, bias, drop, body='resnet34'):
        super().__init__()
        test(body, ['resnet34','resnet50'], operator.in_)
        self.body  = body
        self.drop  = nn.Dropout(drop)
        self.re_sz = StdConv(2048, 512, stride=1)
        self.conv0 = StdConv( 512, 256, drop=drop)
        self.conv1 = StdConv( 256, 256, drop=drop)
        self.conv2 = StdConv( 256, 256, drop=drop)
        self.out0  = OutConv(ncls, k, 256, bias)
        self.out1  = OutConv(ncls, k, 256, bias)
        self.out2  = OutConv(ncls, k, 256, bias)

    def forward(self, x):
        if self.body == 'resnet34': x = F.relu(x)
        x = self.drop(x)
        if self.body == 'resnet50': x = self.re_sz(x)
        stages = ((self.conv0, self.out0), (self.conv1, self.out1), (self.conv2, self.out2))
        bb_parts, lbl_parts = [], []
        for conv, out in stages:
            x = conv(x)
            bb, lbl = out(x)
            bb_parts.append(bb)
            lbl_parts.append(lbl)
        return [torch.cat(bb_parts, dim=1),
                torch.cat(lbl_parts, dim=1)]

class CustMod(Module):
    """Chains a body and a head: forward(x) is head(body(x)).

       NOTE(review): __init__ does not call super().__init__(); this presumably
       relies on the `Module` base class doing it - confirm."""
    def __init__(self, body, head):
        self.body = body
        self.head = head

    def forward(self, x):
        feats = self.body(x)
        return self.head(feats)

# Loss

In [None]:
### FocalLoss ###
def one_hot_embedding(lbls, ncls, device='cuda'):
    """Return one-hot rows of width `ncls` for the integer labels in `lbls`, on `device`.

    The identity matrix is created on `device` and the labels are moved there
    before indexing, so the lookup never mixes a CPU matrix with GPU indices.
    """
    return torch.eye(ncls, device=device)[lbls.data.to(device)]

class BCELoss(nn.Module):
    """Summed binary cross-entropy over the first `ncls` columns, divided by `ncls`.

    Targets are one-hot encoded with `ncls+1` columns and the last column is
    dropped, so a target equal to `ncls` becomes an all-zero row. The last
    column of `acts` is dropped as well. Subclasses override `get_weight` to
    supply per-element weights.
    """
    def __init__(self, ncls, device='cuda'):
        super().__init__()
        self.ncls = ncls
        self.device = device

    def forward(self, acts, targs):
        t = one_hot_embedding(targs, self.ncls+1, self.device)
        t = t[:,:-1].contiguous()
        a = acts[:,:-1]
        w = self.get_weight(a,t)
        # the base class has no weights (None); only detach when a tensor was returned
        if w is not None: w = w.detach()
        return F.binary_cross_entropy_with_logits(a,t,w,reduction='sum')/self.ncls

    def get_weight(self,a,t):
        """Per-element weights for the BCE, or None for an unweighted loss."""
        return None
    
class FocalLoss(BCELoss):
    """BCELoss whose per-element weight is (alpha*t + (1-alpha)*(1-t)) * (1-pt)**gamma."""
    def get_weight(self, a, t):
        alpha, gamma = 0.25, 2.0 # vals from paper
        prob = a.sigmoid()
        class_w = alpha*t + (1-alpha)*(1-t)
        # pt: prob where t is 1, (1 - prob) where t is 0
        pt = prob*t + (1-prob)*(1-t)
        return class_w * (1-pt).pow(gamma)

In [None]:
### IoU ###
def intersxn(b1,b2):
    """Pairwise intersection area between every box in b1 and every box in b2.

    Columns 0-1 are used as the lower corner and columns 2-3 as the upper
    corner; the result has shape (len(b1), len(b2))."""
    lower = torch.max(b1[:,None,:2],  b2[None,:,:2])
    upper = torch.min(b1[:,None,2:4], b2[None,:,2:4])
    extent = torch.clamp(upper - lower, min=0)   # non-overlapping pairs clamp to 0
    return extent[...,0] * extent[...,1]

def area(b):
    """Area of each box: (col2 - col0) * (col3 - col1)."""
    side_a = b[:,2] - b[:,0]
    side_b = b[:,3] - b[:,1]
    return side_a * side_b

def get_iou(b1, b2):
    """Pairwise intersection-over-union, shape (len(b1), len(b2))."""
    overlap = intersxn(b1,b2)
    total = area(b1).unsqueeze(1) + area(b2).unsqueeze(0) - overlap
    return overlap / total

In [None]:
### ssd_loss ###
def remove_padding(bb, lbl):
    """Drop rows whose box has col2 - col0 == 0, together with their labels.

    `bb` is viewed as (-1, 4) first; returns the kept boxes and labels."""
    flat = bb.view(-1,4)
    is_real = (flat[:,2] - flat[:,0]) != 0
    return flat[is_real], lbl[is_real]

def get_pred_bbs(act_bb, ancs, anc_sz, device):
    """Turn bb activations into predicted corner boxes relative to the anchors.

    Activations are squashed with tanh; the first two columns shift the anchor
    centre by up to half of `anc_sz`, the last two scale the anchor sides by a
    factor between 0.5 and 1.5."""
    ancs, anc_sz = ancs.to(device), anc_sz.to(device)
    acts = torch.tanh(act_bb).to(device)       # squashed to (-1, 1)
    shift = acts[:,:2]/2 * anc_sz
    scale = acts[:,2:]/2+1
    ctrs = ancs[:,:2] + shift
    hws  = ancs[:,2:] * scale
    return create_anchor_boxes(ctrs, hws)

def map_to_gt(ious):
    """Assign a row index of `ious` to every column of `ious`.

    Each column first takes the row with the highest value. The best column
    of each row is then forced onto that row and its score set to 1.99.
    Returns (score per column, row index per column)."""
    best_col_per_row = ious.max(1)[1]
    max_iou_per_anc, bb_idxs = ious.max(0)
    max_iou_per_anc[best_col_per_row] = 1.99
    for row_i, col_i in enumerate(best_col_per_row):
        bb_idxs[col_i] = row_i
    return max_iou_per_anc, bb_idxs

def ssd_item_loss(act_bbs, act_lbls, bbs, lbls, device='cuda'):
    """SSD item loss takes single items from a minibatch, creates one pred per anchor, maps gt
       to the preds, prunes the preds, then calcs & returns the bb and lbl loss for that item.

       Reads the notebook-level names `anchors`, `box_size`, `anchor_boxes`, `dls` and `loss_f`.
       Returns `(bb_res, lbl_res)`: the L1 loss over the kept boxes and the output of `loss_f`."""
    # prep
    bbs,lbls = remove_padding(bbs,lbls)                      # remove gt padding inserted during training
    pred_bbs = get_pred_bbs(act_bbs,anchors,box_size,device) # make one pred bb per anchor from acts and ancs
    # map gt to preds
    iou_gt_grid = get_iou(bbs.data, anchor_boxes.data)       # get iou(gt_bbs,anc_bbs); used to map gt → ancs
    iou_gt_preds, mapped_gt_idx = map_to_gt(iou_gt_grid)     # assign each pred an index of a gt object
    mapped_bbs  = bbs[mapped_gt_idx]                         # project gt bbs into pred space
    mapped_lbls = lbls[mapped_gt_idx]                        # project gt lbls into pred space
    # remove low-iou bb preds & set mapped lbl to bg
    high_iou = iou_gt_preds > 0.4                            # only include bb preds that overlap w/a gt obj and
    incl = torch.nonzero(high_iou)[:,0]                      #  are not predicting background
    mapped_lbls[~high_iou] = dls.ncls                        # assign gt class of bg to preds w/ low max gt iou
    # compute loss
    bb_res  = F.l1_loss(pred_bbs[incl], mapped_bbs[incl])    # box loss only over the kept preds
    lbl_res = loss_f(act_lbls, mapped_lbls)                  # label loss over every pred, bg included
    return bb_res, lbl_res

def ssd_loss(acts, bbs, lbls, device='cuda', print_it=True):
    """Sum `ssd_item_loss` over every item of a batch.

    `acts` is unpacked alongside `bbs` and `lbls` so each item's activations
    are paired with its targets. Prints the two running sums when `print_it`
    is true and returns their total."""
    bb_sum, lbl_sum = 0., 0.
    per_item = zip(*acts, bbs, lbls)
    for item in per_item:
        item_bb, item_lbl = ssd_item_loss(*item, device)
        bb_sum  += item_bb
        lbl_sum += item_lbl
    if print_it:
        print(f"bb:{bb_sum:.02f} | lbl: {lbl_sum:.02f}")
    return bb_sum + lbl_sum

## JH code

In [None]:
### Anchors ###
# Reference anchor construction. The commented-out lines are single-grid /
# single-shape alternatives.
# NOTE(review): this cell reassigns `k` and `anchors`, which the earlier cells also define.
anc_grids = [4,2,1]
# anc_grids = [4]
anc_zooms = [0.75, 1., 1.3]
# anc_zooms = [1.]
anc_ratios = [(1.,1.), (1.,0.5), (0.5,1.)]
# anc_ratios = [(1.,1.)]
anchor_scales = [(anz*i,anz*j) for anz in anc_zooms for (i,j) in anc_ratios]
k = len(anchor_scales)   # box shapes per grid cell
anc_offsets = [1/(o*2) for o in anc_grids]   # distance of the first cell centre from the edge
anc_x = np.concatenate([np.tile(np.linspace(ao, 1-ao, ag), ag)
                        for ao,ag in zip(anc_offsets,anc_grids)])
anc_y = np.concatenate([np.repeat(np.linspace(ao, 1-ao, ag), ag)
                        for ao,ag in zip(anc_offsets,anc_grids)])
anc_ctrs = np.repeat(np.stack([anc_x,anc_y], axis=1), k, axis=0)   # each centre repeated k times
anc_sizes  =   np.concatenate([np.array([[o/ag,p/ag] for i in range(ag*ag) for o,p in anchor_scales])
               for ag in anc_grids])
grid_sizes = tensor(np.concatenate([np.array([ 1/ag       for i in range(ag*ag) for o,p in anchor_scales])
               for ag in anc_grids]), requires_grad=False).unsqueeze(1)
anchors = tensor(np.concatenate([anc_ctrs, anc_sizes], axis=1), requires_grad=False).float()

# Convert centre + size to corners: [ctr - hw/2, ctr + hw/2] concatenated along dim 1.
def hw2corners(ctr, hw): return torch.cat([ctr-hw/2,ctr+hw/2], dim=1)
anchor_cnr = hw2corners(anchors[:,:2], anchors[:,2:])

### Architecture ###
class StdConv(nn.Module):
    """3x3 conv block applied as conv2d → relu → batchnorm → dropout."""
    def __init__(self, n_in,n_out,stride=2,dp = 0.1):
        super().__init__()
        self.conv = nn.Conv2d(n_in, n_out, kernel_size=3, stride=stride, padding=1)
        self.bn = nn.BatchNorm2d(n_out)
        self.dropout = nn.Dropout(dp)

    def forward(self,x):
        activated = F.relu(self.conv(x))
        normed = self.bn(activated)
        return self.dropout(normed)
    
def flatten_conv(x,k):
    """Move channels last and reshape a 4-d conv output to (bs, -1, n_channels//k)."""
    n_items, n_chan, _, _ = x.size()
    return x.permute(0,2,3,1).contiguous().view(n_items, -1, n_chan//k)

class OutConv(nn.Module):
    """Output layer with two 3x3 convs: `bbs` (4*k channels) and `lbls`
       ((dls.ncls+1)*k channels, bias initialised to `bias`). Reads the global `dls`."""
    def __init__(self, k, n_in, bias):
        super().__init__()
        self.k = k
        self.bbs  = nn.Conv2d(n_in, 4*k,            kernel_size=3, padding=1)
        self.lbls = nn.Conv2d(n_in, (dls.ncls+1)*k, kernel_size=3, padding=1)
        self.lbls.bias.data.fill_(bias)

    def forward(self,x):
        bb_out  = flatten_conv(self.bbs(x),  self.k)
        lbl_out = flatten_conv(self.lbls(x), self.k)
        return [bb_out, lbl_out]
    
drop=0.4
class SSD_MultiHead(nn.Module):
    """Three stride-2 StdConv stages, each followed by an OutConv.

       Dropout probability comes from the module-level `drop`. forward returns
       [bb acts, lbl acts], each concatenated over the three stages along dim 1."""
    def __init__(self, k, bias):
        super().__init__()
        self.drop = nn.Dropout(drop)
        self.sconv1 = StdConv(512,256, dp=drop)
        self.sconv2 = StdConv(256,256, dp=drop)
        self.sconv3 = StdConv(256,256, dp=drop)
        self.out1 = OutConv(k, 256, bias)
        self.out2 = OutConv(k, 256, bias)
        self.out3 = OutConv(k, 256, bias)

    def forward(self, x):
        x = self.drop(F.relu(x))
        stages = ((self.sconv1, self.out1), (self.sconv2, self.out2), (self.sconv3, self.out3))
        bb_parts, lbl_parts = [], []
        for conv, out in stages:
            x = conv(x)
            stage_bbs, stage_lbls = out(x)
            bb_parts.append(stage_bbs)
            lbl_parts.append(stage_lbls)
        return [torch.cat(bb_parts, dim=1),
                torch.cat(lbl_parts, dim=1)]
    
class CustMod(Module):
    """Chains a body and a head: forward(x) is head(body(x)).

       NOTE(review): __init__ does not call super().__init__(); this presumably
       relies on the `Module` base class doing it - confirm."""
    def __init__(self, body, head):
        self.body = body
        self.head = head

    def forward(self, x):
        feats = self.body(x)
        return self.head(feats)

### FocalLoss ###
def one_hot_embedding(labels, num_classes, device='cuda'):
    """Return one-hot rows of width `num_classes` for the integer `labels`, on `device`.

    `device` defaults to 'cuda', matching the previous hard-coded behaviour.
    The identity matrix is created on `device` and the labels are moved there
    before indexing, so the lookup never mixes a CPU matrix with GPU indices.
    """
    return torch.eye(num_classes, device=device)[labels.data.to(device)]

class BCE_Loss(nn.Module):
    """Summed binary cross-entropy over the first `num_classes` columns, divided by `num_classes`.

    Targets are one-hot encoded with `num_classes+1` columns and the last
    column is dropped, so a target equal to `num_classes` becomes an all-zero
    row. The last column of `preds` is dropped as well. Subclasses override
    `get_weight` to supply per-element weights.
    """
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, preds, targets):
        t = one_hot_embedding(targets, self.num_classes+1)
        t = t[:,:-1].contiguous()
        x = preds[:,:-1]
        w = self.get_weight(x,t)
        # the base class has no weights (None); only detach when a tensor was returned
        if w is not None: w = w.detach()
        return F.binary_cross_entropy_with_logits(x, t, w, reduction='sum') / self.num_classes

    def get_weight(self,x,t):
        """Per-element weights for the BCE, or None for an unweighted loss."""
        return None

class FocalLoss(BCE_Loss):
    """BCE_Loss whose per-element weight is (alpha*t + (1-alpha)*(1-t)) * (1-pt)**gamma."""
    def get_weight(self,x,t):
        alpha,gamma = 0.25,2.
        prob = x.sigmoid()
        class_w = alpha*t + (1-alpha)*(1-t)
        # pt: prob where t is 1, (1 - prob) where t is 0
        pt = prob*t + (1-prob)*(1-t)
        return class_w * (1-pt).pow(gamma)

# label loss used by the ssd loss functions in this cell
loss_f = FocalLoss(dls.ncls)

### IoU ###
def intersection(box_a,box_b):
    """Pairwise intersection area, shape (len(box_a), len(box_b)).

    Columns 0-1 are the lower corner and the remaining columns the upper corner."""
    a, b = box_a[:,None], box_b[None,:]
    lower = torch.max(a[...,:2], b[...,:2])
    upper = torch.min(a[...,2:], b[...,2:])
    extent = torch.clamp(upper-lower, min=0)   # non-overlapping pairs clamp to 0
    return extent[:,:,0] * extent[:,:,1]

def get_size(box):
    """Area of each box: (col2 - col0) * (col3 - col1)."""
    side_a = box[:,2] - box[:,0]
    side_b = box[:,3] - box[:,1]
    return side_a * side_b

def jaccard(box_a,box_b):
    """Pairwise intersection-over-union, shape (len(box_a), len(box_b))."""
    overlap = intersection(box_a,box_b)
    total = get_size(box_a).unsqueeze(1) + get_size(box_b).unsqueeze(0) - overlap
    return overlap/total

### ssd_loss ###
def get_y(bbox,clas):
    """Scale boxes by the global `size` and keep only rows with col2 - col0 > 0.

    NOTE(review): `size` is not defined anywhere in the visible notebook;
    confirm where it comes from before running this cell."""
    scaled = bbox.view(-1,4)/size
    widths = scaled[:,2] - scaled[:,0]
    keep = (widths>0.).nonzero()[:,0]
    return scaled[keep], clas[keep]
    
def actn_to_bb(actn, anchors):
    """Turn bb activations into corner boxes relative to `anchors` (all on cuda).

    Activations are squashed with tanh; the first two columns shift the anchor
    centre by up to half of the global `grid_sizes`, the last two scale the
    anchor sides by a factor between 0.5 and 1.5."""
    acts = torch.tanh(actn).cuda()
    ancs = anchors.cuda()
    ctrs = (acts[:,:2] * grid_sizes.cuda()/2) + ancs[:,:2]
    hw = (1 + acts[:,2:]/2) * ancs[:,2:]
    return hw2corners(ctrs, hw)

def map_to_ground_truth(overlaps, print_it=False):
    """Assign a row index of `overlaps` to every column of `overlaps`.

    Each column first takes the row with the highest value. The best column
    of each row is then forced onto that row and its score set to 1.99.
    With `print_it` the per-row maxima are printed.
    Returns (score per column, row index per column)."""
    row_best, row_best_col = overlaps.max(1)
    if print_it:
        print(row_best)
    col_best, col_best_row = overlaps.max(0)
    col_best[row_best_col] = 1.99
    for row_i, col_i in enumerate(row_best_col):
        col_best_row[col_i] = row_i
    return col_best, col_best_row

def ssd_1_loss(b_bb, b_c, bbox, clas, print_it=False, use_ab=True):
    """Loss for a single item: returns (loc_loss, clas_loss).

    b_bb / b_c are the box and class activations; bbox / clas are the targets.
    With use_ab the targets are matched against the anchor corners
    (`anchor_cnr`), otherwise against the predicted boxes. Reads the globals
    `anchors`, `anchor_cnr`, `dls` and `loss_f`.
    NOTE(review): `print_it` is accepted but not used inside this function."""
    bbox,clas = get_y(bbox,clas)                 # drop rows filtered out by get_y
    a_ic = actn_to_bb(b_bb, anchors)             # predicted corner boxes
    overlaps = jaccard(bbox.data, (anchor_cnr.cuda() if use_ab else a_ic).data)
    gt_overlap,gt_idx = map_to_ground_truth(overlaps)
    gt_clas = clas[gt_idx]
    pos = gt_overlap > 0.4                       # matches above this overlap count as positives
    pos_idx = torch.nonzero(pos)[:,0]
    gt_clas[~pos] = dls.ncls                     # everything else gets index dls.ncls
    gt_bbox = bbox[gt_idx]
    loc_loss = ((a_ic[pos_idx] - gt_bbox[pos_idx]).abs()).mean()   # mean abs error over positives only
    clas_loss  = loss_f(b_c, gt_clas)
    return loc_loss, clas_loss

def ssd_loss(pred, targ_bb, targ_lbl, print_it=True):
    """Sum `ssd_1_loss` over every item of a batch.

    Prints the two running sums when `print_it` is true and returns their total.
    NOTE(review): an earlier cell defines another `ssd_loss`; whichever cell
    ran last is the one in effect."""
    bb_sum,lbl_sum = 0.,0.
    per_item = zip(*pred, targ_bb, targ_lbl)
    for item_bb, item_lbl, bbox, clas in per_item:
        bb_loss, lbl_loss = ssd_1_loss(item_bb, item_lbl, bbox, clas, print_it)
        bb_sum += bb_loss
        lbl_sum += lbl_loss
    if print_it:
        print(f"bb:{bb_sum:.02f} | lbl: {lbl_sum:.02f}")
    return bb_sum + lbl_sum

In [None]:
### Batch n Acts ###
# Get batch, acts
# Build the JH model (resnet34 body + SSD_MultiHead) on the CPU in eval mode.
head_reg4_JH = SSD_MultiHead(k, -4.)
mod_JH = CustMod(create_body(resnet34, pretrained=True), head_reg4_JH)
mod_JH.cpu().eval()

# One validation batch on the CPU, pushed through the JH model built above.
# (This previously called `mod`, a different model that is only created in a later cell.)
batch_JH = next(iter(dls.cpu().valid))
acts_JH = mod_JH(batch_JH[0])
batch_bbs_JH, batch_lbls_JH = batch_JH[1], batch_JH[2]

# Get acts and targs for a single im
# NOTE(review): `bidx` is not assigned in this cell; it must already exist in the kernel.
act_bbs_JH,act_lbls_JH,bbs_JH,lbls_JH = list(zip(*acts_JH,batch_bbs_JH,batch_lbls_JH))[bidx]

###### Not Used

In [None]:
### get_preds ###
def nms(boxes, scores, iou_thresh, top_k=100):
    """Greedy non-maximum suppression.

    Args:
        boxes:      (n, 4) tensor; columns are read as x1, y1, x2, y2.
        scores:     (n,) tensor of confidences.
        iou_thresh: a candidate is dropped when its IoU with a kept box exceeds this.
        top_k:      only the `top_k` highest-scoring boxes are considered.

    Returns:
        (keep, count): `keep` is a long tensor of length n whose first `count`
        entries are the indices of the kept boxes in descending score order;
        the remaining entries are zero. Empty input returns `(keep, 0)` so
        callers can always unpack two values.
    """
    keep = torch.zeros(scores.size(0), dtype=torch.long, device=scores.device)
    if boxes.numel() == 0: return keep, 0
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    area = (x2 - x1) * (y2 - y1)
    idx = scores.sort(0)[1][-top_k:]   # ascending sort, so the best candidate is last

    count = 0
    while idx.numel() > 0:
        i = idx[-1]                    # index of current largest val
        keep[count] = i
        count += 1
        if idx.size(0) == 1: break
        idx = idx[:-1]                 # remove kept element from view
        # overlap of the kept box with every remaining candidate
        w = torch.clamp(torch.min(x2[idx], x2[i]) - torch.max(x1[idx], x1[i]), min=0.0)
        h = torch.clamp(torch.min(y2[idx], y2[i]) - torch.max(y1[idx], y1[i]), min=0.0)
        inter = w*h
        # IoU = i / (area(a) + area(b) - i)
        union = (area[idx] - inter) + area[i]
        # keep only elements with an IoU <= iou_thresh
        idx = idx[(inter/union).le(iou_thresh)]
    return keep, count

def acts_to_preds(abb, albl, ancs, anc_sz, iou_thresh, conf_thresh, device):
    """Turn one item's model acts into preds: boxes via get_pred_bbs, labels and
       confidences via sigmoid().max(1), then nms and a confidence filter.
       Used in ResultShower and mAP (and could possibly be used in loss fxn).

       NOTE(review): label 0 is treated as background here, while ssd_item_loss
       uses index dls.ncls for background - confirm which convention is intended."""
    # convert acts to preds
    pbb = get_pred_bbs(abb, ancs, anc_sz, device)
    conf, plbl = albl.sigmoid().max(1)
    # filter out preds w/ nms
    nms_idxs, nms_n = nms(pbb.data, conf, iou_thresh)
    kept = nms_idxs[:nms_n]
    pbb, plbl, conf = pbb[kept], plbl[kept], conf[kept]
    # filter out bg and low-conf preds
    mask = (plbl!=0) & (conf > conf_thresh)
    return pbb[mask], plbl[mask], conf[mask]

def get_batch_preds(abb, albl, ancs, anc_sz, iou_thresh=.5, conf_thresh=.25, device='cpu'):
    """Loop through a batch of activations and turn them into predictions.

    Returns three lists (pred boxes, pred labels, confidences) with one entry
    per item of the batch, each produced by `acts_to_preds`."""
    # Tensor.to returns a new tensor, so the result has to be kept
    ancs, anc_sz = ancs.to(device), anc_sz.to(device)
    pbbs, plbls, confs = [], [], []
    for item_abb, item_albl in zip(abb, albl):
        pbb, plbl, conf = acts_to_preds(item_abb,item_albl,ancs,anc_sz,iou_thresh,conf_thresh,device)
        pbbs.append(pbb)
        plbls.append(plbl)
        confs.append(conf)
    return pbbs, plbls, confs

In [None]:
### Metric ###
def _format_inps(acts, batch, anchors, box_size, iou_thresh, conf_thresh, device):
    """Format acts and targs for AP score calc. Input expects learner.acts & learner.batch,
       output format: (im_idx, pred_bbs, pred_cls, cls_conf) and (im_idx, bbs, cls).
       Ex: (46.0, tensor([0.1, 0.2, 0.9, 0.9]), tensor(3), tensor(0.78))"""
    preds = get_batch_preds(*acts, anchors, box_size, iou_thresh, conf_thresh, device)

    def _im_idxs(per_im):
        # image index repeated once for every entry belonging to that image
        return torch.cat([torch.tensor([i]*len(o)) for i,o in enumerate(per_im)]).numpy().tolist()

    p_idxs  = _im_idxs(preds[0])
    gt_idxs = _im_idxs(batch[1])
    flat_preds = [torch.cat(o) for o in preds]
    flat_gts   = [o.flatten(end_dim=1) for o in batch[1:]]
    return list(zip(p_idxs, *flat_preds)), list(zip(gt_idxs, *flat_gts))

def _flatten_list(l, ret_L=False):
    """Flatten a list-of-lists; lists can be python `list`s or a fastai `L`s."""
    def _recur(l,res):
        for o in l:
            if   isinstance(o,list): _recur(o,res)
            elif isinstance(o,L)   : _recur(o,res)
            else: res.append(o)
        return res
    res = _recur(l, [])
    return res if not ret_L else L(res)

def _get_tp_bbs(preds_tp):
    """Output list of tp bbs per im in batch. Not used to calculate mAP; only
       used to grab true positive bb preds for visualizing in ResultShower.

       The returned list has `bs` entries (the notebook-level batch size); an
       entry that received no true positive stays a (1, 4) tensor of zeros."""
    # Each row in preds_tps is a formatted pred (see format_ap_inputs) and a list of
    # 1s and 0s (signifying tps and fps) for a cls. No preds for a cls → empty lists.
    batch_idxs, pred_bbs, tpfps = [], [], []
    for preds,tp in preds_tp:
        batch_idxs.append([o[0] for o in preds])   # image index of every pred
        pred_bbs.append([o[1] for o in preds])     # box of every pred
        tpfps.append(tp)
    flat_idxs  = _flatten_list(batch_idxs)
    flat_bbs   = _flatten_list(pred_bbs)
    flat_tpfps = torch.cat(tpfps)

    scored_preds = list(zip(flat_idxs, flat_bbs, flat_tpfps))
    true_bbs = [(int(o[0]), o[1]) for o in scored_preds if o[2]==True]

    # start every image with an all-zero placeholder box; a zero sum marks "nothing stored yet"
    true_preds = [torch.zeros(4).view(1,4) for i in range(0,bs)]
    for i,bb in true_bbs:
        if true_preds[i].sum()==0: true_preds[i] = bb.view(1,4)
        else: true_preds[i] = torch.cat([true_preds[int(i)], bb.view(1,4)], dim=0)
    return true_preds

def ap_per_cls(acts, batch, n_cls, ancs, anc_sz, iou_thresh, conf_thresh, device='cuda'):
    """Calculate the AP score of each class for one batch.

       Returns a 3-tuple: the list of per-class AP values, the true-positive
       boxes per image (from _get_tp_bbs), and a per-class list of dicts that
       map an image index to a 0/1 tensor marking which gt objects were matched.
       NOTE(review): `iou_thresh` is used both as the nms threshold (via
       _format_inps) and as the matching threshold below - confirm this is intended."""
    batch_preds, batch_gts = _format_inps(acts,batch,ancs,anc_sz,iou_thresh,conf_thresh,device)
    # avg_precs holds ap per cls; other accumulators only used for vizing results
    avg_precs, preds_tp, n_objs_accum = [],[],[] 
    # NOTE(review): class index 0 is skipped here; presumably it is meant as background - confirm.
    for c in range(1,n_cls): # start at 1 to ignore gt
        # store preds and gts for current cls
        preds = [b for b in batch_preds if b[2]==c]
        gts   = [b for b in batch_gts   if b[2]==c]
                
        # sort preds by conf desc
        preds.sort(key=lambda x: x[3], reverse=True)
        
        # make dict of im_idx:zeros(n_objs)
        n_objs = Counter([gt[0] for gt in gts])
        for k,v in n_objs.items(): n_objs[k] = torch.zeros(v)
        
        # init tp: a bool tensor for each im s.t. 1s indicate a pred is a tp
        tp = torch.zeros((len(preds))).bool()
        total_gt_objs = len(gts)
        
        for pred_idx, pred in enumerate(preds):
            # gt objects of this class that live in the same image as the pred
            gt_objs = [o for o in gts if o[0] == pred[0]]
            n_gt_objs = len(gt_objs)
            max_iou = 0
            
            # find the gt object that overlaps this pred the most
            for idx, gt in enumerate(gt_objs):
                iou = get_iou(pred[1].view(1,4), gt[1].view(1,4))
                if iou > max_iou: max_iou, idx_of_max = iou, idx
                    
            # update idx of gt_obj to indicate it's been used
            if max_iou > iou_thresh:
                # a gt object can only be claimed once; later preds on it stay false positives
                if n_objs[pred[0]][idx_of_max]==0:
                    tp[pred_idx] = 1
                    n_objs[pred[0]][idx_of_max] = 1
                    
        # store tp_bbs: use tp ask mask on preds and take 1th item (the bb) from each
        preds_tp.append(([preds,tp]))
        n_objs_accum.append(n_objs)
        
        # calc avg_prec and store
        tps = torch.cumsum(tp, dim=0)               # 1. tp csum: [0,1,1,0,0] → [0,1,2,2,2]
        fps = torch.cumsum(~tp, dim=0)              # (basically same steps for fp)
        prec = torch.div(tps, (tps + fps + 1e-6))   # 2. divide each tps item by n_preds
        prec = torch.cat((torch.tensor([1]), prec)) # 3. slap on a 1 at the beginning
        rec = tps / (total_gt_objs + 1e-6)          # recall after each pred; 1e-6 avoids dividing by zero
        rec = torch.cat((torch.tensor([0]), rec))
        avg_prec = torch.trapz(prec, rec) # calc AP w/ trap rule
        avg_precs.append(avg_prec)        # store AP of this cls in accum
    return avg_precs, _get_tp_bbs(preds_tp), n_objs_accum

def get_ap_scores(dls, model, ancs, anc_sz, iou, conf, device='cpu'):
    """Average the per-class AP over every batch of `dls.valid`.

    Args:
        dls, model:   dataloaders and model; both are moved to `device`.
        ancs, anc_sz: anchors and anchor cell sizes passed on to `ap_per_cls`.
        iou, conf:    thresholds passed on to `ap_per_cls`.
        device:       device used for the evaluation.

    Returns:
        (mean of the per-class scores, list of per-class scores).
    """
    dls.to(device); model.to(device)
    # Tensor.to is not in-place: keep the moved copies of the arguments that were passed in
    ancs, anc_sz = ancs.to(device), anc_sz.to(device)
    model.eval()
    n_cls = len(dls.vocab)

    res=[]
    with torch.no_grad():   # evaluation only, no graph needed
        for b in dls.valid:
            scores,_,_ = ap_per_cls(model(b[0]),b,n_cls,ancs,anc_sz,iou,conf,device)
            res.append(scores)
    ap_scores = torch.stack([tensor(o) for o in res]).sum(axis=0)/len(res)
    ap_scores = ap_scores.numpy().tolist()
    return sum(ap_scores)/len(ap_scores), ap_scores

class MeanAveragePrecision(Metric):
    """Metric that stores `func`'s per-class AP for each batch of the last epoch
       and reports their overall mean; reports 0 while nothing has been stored.

       NOTE(review): `func` is called with five positional args and the global
       `anchors` / `box_size`; `ap_per_cls` also requires iou_thresh and
       conf_thresh, so confirm which function is meant to be passed here."""
    def __init__(self, func, n_cls):
        self.func = func
        self.n_cls = n_cls

    def reset(self):
        self.res = []

    def accumulate(self, learn):
        # only the final epoch is scored
        if learn.epoch != learn.n_epoch-1: return
        batch = (*learn.xb, *learn.yb)
        cls_aps,_,_ = self.func(learn.pred, batch, self.n_cls, anchors, box_size)
        self.res.append(cls_aps)

    @property
    def value(self):
        if self.res==[]: return 0
        per_batch = torch.stack([tensor(o) for o in self.res])
        ap_scores = per_batch.sum(axis=0)/len(self.res)
        return sum(ap_scores)/len(ap_scores)

    @property
    def name(self):
        return "mAP"

In [None]:
### Viz Results ###
def show_bb(im, bb=None, lbl=[''], title=None, color='white',
            ctx=None, sz=im_sz, figsize=5):
    """Show `im` with bounding boxes `bb` (and optional labels `lbl`) drawn on it.

    `bb` is rescaled with (bb+1)*sz//2 before drawing. A missing `bb`, or one
    whose last dimension is empty, is replaced by a single all-zero box. When
    `ctx` is given the image is drawn on it, otherwise a new one is created.
    Returns the ctx that was drawn on."""
    # process empties and nones (test for None first: None has no .shape)
    if bb is None or bb.shape[-1]==0: bb  = tensor([[0.,0,0,0]])
    if lbl==['']:                     lbl = ['']*bb.shape[0]

    # process tensors to take advantage of fastai show methods
    bbox = TensorBBox((bb+1)*sz//2)
    labeledbbox = LabeledBBox(bbox,lbl)

    if ctx:     show_image(im, figsize=[figsize,figsize], title=title, ctx=ctx)
    else: ctx = show_image(im, figsize=[figsize,figsize], title=title)

    labeledbbox.show(ctx=ctx)       # first, draw white lbl bbs...
    bbox.show(ctx=ctx, color=color) # ... then overlay color bbs.
    return ctx

def get_im_tpfpfns(res):
    """QnD function to output tpfpfns for each im in a batch given a ResultShower.

       Reads res.tp_bbs, res.preds and res.nobj; returns three lists (tps, fps, fns)."""
    # TPs are found by counting the rows of each image's tp boxes
    # (an all-zero entry is the "no true positives" placeholder)
    tps = []
    for bbs in res.tp_bbs:
        if bbs.sum()==0: tps.append(0)
        else: tps.append(bbs.shape[0])
    # FPs are found by subtracing tps from preds_per_im
    preds_per_im = [o.shape[0] for o in res.preds[0]]
    fps = [pred-tp for pred,tp in zip(preds_per_im, tps)]
    # FNs (relies on res.nobj which is a hack)
    # res.nobj is a dict where each key is a cls and the values are some combo of
    #  im_idx and a bool list st. the len of the list is the number of gt_objs for
    #  that class in that im, and the truth value of the bool represents whether
    #  the obj is a TP or FN.
    im_dict = defaultdict(lambda: [])
    for cls in res.nobj:
        for im_idx,tp_tensor in cls.items():
            im_dict[im_idx] += ~tp_tensor.bool()   # unmatched gt objects count as FNs
    # NOTE(review): only images that appear in res.nobj get an entry, so `fns` may be
    # shorter than `tps`/`fps` and its positions may not line up with image indices - confirm.
    idx_fns = [(k,sum(v).item()) for k,v in im_dict.items()]
    idx_fns.sort()
    fns = [o[1] for o in idx_fns]
    return tps, fps, fns

class ResultShower():
    """Runs the model on one validation batch (on the CPU) and shows, per image,
       the ground truth next to the predictions with true positives highlighted.

       Index it (`shower[i]`) to draw image i, or call it / `show_next(n)` to
       step through the batch n images at a time."""
    def __init__(self, dls, lrn, ancs, anc_sz, iou, conf):
        # store init's args
        self.dls    = dls
        self.mod    = lrn.model.eval().cpu()
        self.ancs   = ancs.cpu()
        self.anc_sz = anc_sz.cpu()
        self.iou    = iou
        self.conf   = conf
        # compute attrs
        self.batch    = next(iter(self.dls.cpu().valid))
        self.acts     = [a.data for a in self.mod(self.batch[0])]
        self.preds    = get_batch_preds(*self.acts,self.ancs,self.anc_sz,self.iou,self.conf,'cpu')
        self.dec_ims  = self.dls.decode(self.batch)[0]
        self.bs       = self.dls.bs
        self.voc      = self.dls.vocab
        self.im_sz    = self.batch[0].shape[-1]
        self.last_res = 0       # index of the next image show_next will draw
        self.fig_sz   = [8,8]
        # compute metrics
        aps,tp_bbs,nobj = ap_per_cls(self.acts,self.batch,len(self.voc),
                                     self.ancs,self.anc_sz,self.iou,self.conf,'cpu')
        self.ap_scores  = [o.item() for o in aps]
        self.tp_bbs     = tp_bbs
        self.nobj       = nobj
        # must come after tp_bbs / preds / nobj are set: get_im_tpfpfns reads them from self
        self.im_tpfpfns = get_im_tpfpfns(self)
        # clean up: move the dataloaders and the model back to the GPU
        self.dls.cuda(); self.mod.cuda()
        
    def __call__(self, *args, **kwargs):
        return self.show_next(*args, **kwargs)
    
    def __getitem__(self, i):
        # get everything to draw
        ims             = self.dec_ims
        _,bbs,lbls      = self.batch
        pbbs,plbls,conf = self.preds
        tbbs            = self.tp_bbs
        tps,fps,fns     = self.im_tpfpfns
        # titles
        # NOTE(review): nobj counts labels > 0, so objects with label 0 are not counted - confirm.
        t_gt = f"Idx {i} (nobj={(lbls[i] > 0).sum()})"
        t_p  = f"TPs:{tps[i]} | FPs:{fps[i]} | FNs:{fns[i]}"
        # two ctx: gts and all preds. lime bbs drawn over TP preds.
        ctx = get_grid(2, figsize=self.fig_sz)
        show_bb(ims[i],bbs[i], self.voc[lbls[i]], t_gt,'white', ctx[0], self.im_sz)
        show_bb(ims[i],pbbs[i],self.voc[plbls[i]],t_p, 'magenta',ctx[1],self.im_sz)
        show_bb(ims[i],tbbs[i],color='lime',ctx=ctx[1])
                 
    def show_next(self, n=1):
        # draw the next n images, wrapping around at the end of the batch
        for i in range(n): self[(i + self.last_res)%self.bs]
        self.last_res += n

In [None]:
# lil' helpers
def pad_strs(strs):
    """Right-pad every string with spaces to the length of the longest one."""
    width = max([len(s) for s in strs])
    return [s.ljust(width) for s in strs]
def print_lists(*lists):
    """Print equal-length lists side by side, one row per index, columns joined with ' | '."""
    for cur in lists: test_eq(len(lists[0]), len(cur))              # test lens eq
    padded_cols = [pad_strs([str(o) for o in col]) for col in lists] # pad each column to its own width
    for row in zip(*padded_cols):
        print(' | '.join(row))
def batch_info(l):
    """Print idx, shape, type for items in l (a list of tensors)"""
    print_lists(list(range(len(l))),
                apply(lambda t:t.shape, l),
                apply(type, l))

# Train

In [None]:
### Init Anchors ###
# Anchors for training; create_anchors is called without a device, so its default is used.
subdivs = [4, 2, 1]
zooms   = [0.75, 1.0, 1.3]
ratios  = [(1.,1.), (1.,.5), (.5,1)]
k = len(zooms) * len(ratios)   # box shapes per grid cell
anchors, box_size = create_anchors(subdivs, zooms, ratios)
anchor_boxes = create_anchor_boxes(anchors[:,:2], anchors[:,2:])

### Label Loss, Metric ###
# NOTE(review): two FocalLoss classes are defined in this notebook with different
# signatures; this two-argument call matches the (ncls, device) version.
loss_f = FocalLoss(dls.ncls, 'cuda')
# met = MeanAveragePrecision(ap_per_cls, n_cls)

In [None]:
def get_fresh_mod(ncls=dls.ncls, k=k, bias=-4., drop=.4, body='resnet34', device='cuda'):
    """Build a new model: pretrained `body` plus an untrained SSDHead, placed on `device`.

    `body` must be 'resnet34' or 'resnet50'. The defaults for `ncls` and `k`
    are the values of `dls.ncls` and `k` at the time this cell is run."""
    test(body, ['resnet34','resnet50'], operator.in_)
    arch = resnet34 if body=='resnet34' else resnet50
    model = CustMod(create_body(arch, pretrained=True), SSDHead(ncls, k, bias, drop, body))
    # `device` was previously accepted but ignored
    return model.to(device)
mod = get_fresh_mod()

In [None]:
# Mixed-precision learner around the fresh model, frozen before training.
# NOTE(review): `ssd_loss` is defined in two cells; the one executed last is used here.
learner = Learner(dls, mod, loss_func=ssd_loss).to_fp16()
learner.freeze()
# lr_min, lr_steep = learner.lr_find()

In [None]:
# lr = (lr_min+lr_steep)/2; print("lr:",round(lr,4))

lr: 0.0022


In [None]:
# NOTE(review): fastai documents this keyword as `lr_max`; confirm that `lr=` is accepted
# by the installed version.
learner.fit_one_cycle(5, lr=2e-3)

epoch,train_loss,valid_loss,time
0,42.545944,83.527405,00:09
1,36.724854,36.02985,00:09
2,31.668453,22.878479,00:09
3,28.094711,22.700785,00:09
4,25.52347,20.782619,00:09


bb:16.82 | lbl: 27.95
bb:16.85 | lbl: 25.12
bb:17.32 | lbl: 32.56
bb:17.04 | lbl: 27.40
bb:17.56 | lbl: 28.57
bb:17.68 | lbl: 27.87
bb:16.84 | lbl: 25.60
bb:16.79 | lbl: 27.23
bb:17.87 | lbl: 27.12
bb:17.68 | lbl: 25.02
bb:17.72 | lbl: 26.32
bb:17.65 | lbl: 27.91
bb:17.85 | lbl: 27.26
bb:17.64 | lbl: 24.85
bb:16.31 | lbl: 25.95
bb:17.14 | lbl: 28.21
bb:16.83 | lbl: 26.55
bb:15.66 | lbl: 26.40
bb:16.88 | lbl: 28.57
bb:16.10 | lbl: 25.90
bb:15.88 | lbl: 27.10
bb:16.49 | lbl: 27.01
bb:16.42 | lbl: 27.35
bb:16.89 | lbl: 23.59
bb:15.63 | lbl: 22.65
bb:15.49 | lbl: 26.69
bb:15.75 | lbl: 25.40
bb:14.70 | lbl: 23.87
bb:15.03 | lbl: 20.64
bb:14.56 | lbl: 21.53
bb:14.46 | lbl: 28.38
bb:21.36 | lbl: 71.12
bb:19.58 | lbl: 69.91
bb:20.18 | lbl: 49.46
bb:20.70 | lbl: 56.05
bb:20.19 | lbl: 61.92
bb:22.01 | lbl: 62.38
bb:20.82 | lbl: 74.60
bb:16.91 | lbl: 59.73
bb:13.50 | lbl: 26.17
bb:14.23 | lbl: 23.29
bb:14.24 | lbl: 22.13
bb:13.75 | lbl: 22.82
bb:13.11 | lbl: 24.65
bb:13.30 | lbl: 22.30
bb:12.78 |

In [None]:
# Checkpoint after the first training stage (the recorded output is models/s1.pth).
learner.save('s1')

Path('models/s1.pth')

In [None]:
# Export the whole learner to a pickle file.
learner.export('models/20201214_pascal2007_rory.pkl')

In [None]:
# learner.export('models/20201210_pascal2007.pkl')      # trained for 10 epochs
# learner.export('models/20201210_pascal2007_bad.pkl')  # trained for 5 epochs
# learner.export('models/20201211_pascal2007_focalfix.pkl') # corrected weight in focal loss

# Stepping Through a Batch

###### batch loss

In [None]:
# Setup
# device='cpu'

# subdivs = [4, 2, 1]
# zooms   = [0.75, 1.0, 1.3]
# ratios  = [(1.,1.), (1.,.5), (.5,1)]
# k = len(zooms) * len(ratios)
# anchors, box_size = create_anchors(subdivs, zooms, ratios, device)
# anchor_boxes = create_anchor_boxes(anchors[:,:2], anchors[:,2:])
# loss_f = FocalLoss(dls.ncls, device)

In [None]:
# Get batch, acts
# NOTE(review): this rebinds `mod` to a new, untrained model and relies on `device`
# from an earlier cell (the setup cell above is commented out).
mod = get_fresh_mod(device=device)
mod.eval()

# One validation batch on the CPU and the model's activations for it.
batch = next(iter(dls.cpu().valid))
acts = mod(batch[0])
batch_bbs, batch_lbls = batch[1], batch[2]

In [None]:
# Get acts and targs for a single im
bidx = 0   # which item of the batch to step through
# pair up the item's bb acts, lbl acts, gt boxes and gt labels
act_bbs,act_lbls,bbs,lbls = list(zip(*acts,batch_bbs,batch_lbls))[bidx]

(tensor(0.1392, grad_fn=<L1LossBackward>),
 tensor(1.1336, grad_fn=<DivBackward0>))

###### item loss

In [None]:
bbs,lbls = remove_padding(bbs,lbls)                      # remove gt padding inserted during training
bbs,lbls                                                 # display what is left

In [None]:
pred_bbs = get_pred_bbs(act_bbs,anchors,box_size,device) # make one pred bb per anchor from acts and ancs

In [None]:
# map gt to preds
iou_gt_grid = get_iou(bbs.data, anchor_boxes.data)       # iou of every gt bb (rows) with every anchor box (cols); used to map gt → ancs

In [None]:
iou_gt_preds, mapped_gt_idx = map_to_gt(iou_gt_grid)     # per anchor: its match score and the index of its gt object

In [None]:
mapped_bbs  = bbs[mapped_gt_idx]                         # one gt bb per pred (gt bbs projected into pred space)

In [None]:
mapped_lbls = lbls[mapped_gt_idx]                        # one gt lbl per pred (gt lbls projected into pred space)

In [None]:
# remove low-iou bb preds & set mapped lbl to bg
high_iou = iou_gt_preds > 0.4                            # mask of preds whose match score exceeds 0.4

In [None]:
incl = torch.nonzero(high_iou)[:,0]                      # indices of the preds kept for the bb loss

In [None]:
mapped_lbls[~high_iou] = dls.ncls                        # assign gt class of bg (index dls.ncls) to preds w/ low max gt iou

In [None]:
# compute loss
bb_res  = F.l1_loss(pred_bbs[incl], mapped_bbs[incl])    # L1 loss over the kept preds only

In [None]:
lbl_res = loss_f(act_lbls, mapped_lbls)                  # label loss over every pred, bg included