<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [None]:
from fastai.vision.all import *

from pathlib import Path
import json
from PIL import ImageDraw, ImageFont
from matplotlib import patches, patheffects
from ipywidgets import FloatProgress
from IPython.display import display
import time

In [None]:
def random_seed(s, use_cuda):
    """Seed the NumPy, PyTorch and Python RNGs with `s`.

    When `use_cuda` is truthy the CUDA generators are seeded too, cuDNN is
    put in deterministic mode and its benchmark mode is switched off.
    """
    # Reminder kept from the original author: also use num_workers=0 when
    # creating the DataBunch.
    for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
        seed_fn(s)
    if not use_cuda:
        return
    for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seed_fn(s)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


random_seed(42, use_cuda=True)

In [None]:
def dict2list(d):
    """Return the items of mapping `d` as a list of `(key, value)` tuples."""
    pairs = []
    for key, value in d.items():
        pairs.append((key, value))
    return pairs

In [None]:
# Dataset root as returned by fastai's `untar_data` for Pascal VOC 2007.
path = untar_data(URLs.PASCAL_2007)
# Training annotation file and training image folder under that root.
annos_path, ims_path = path/'train.json', path/'train'

In [None]:
# Parse the annotation file. `read_text()` opens, reads and closes the file
# itself, so no file handle is left dangling in the kernel.
trn_json = json.loads(annos_path.read_text())                  # {'images':[img data], ...
# Lookup tables built from the parsed JSON:
cats = {o['id']:o['name'] for o in trn_json['categories']}     # {1: 'aeroplane', ...
imgs_fn = {o['id']:o['file_name'] for o in trn_json['images']} # {12:'000012.jpg', 17: ...
# Image ids in the order they appear in the file; later cells iterate in this order.
imgs_id = [o['id'] for o in trn_json['images']]                # [12, 17, 23, ...

In [None]:
# Usable annotations grouped per image id: {12: [(bb, clsid)], ...
trn_anno = collections.defaultdict(list)
for annot in trn_json['annotations']:
    if annot['ignore'] != 0:
        continue  # only annotations with ignore == 0 are kept
    left, top, width, height = annot['bbox']
    # The stored box is (left, top, width, height); it is kept here as
    # (top, left, bottom, right).
    # NOTE(review): fastai v2's TensorBBox/PointScaler appear to default to
    # x-first boxes (left, top, right, bottom) - confirm that this y-first
    # order is what the downstream pipeline expects.
    bb = [top, left, top + height, left + width]
    trn_anno[annot['image_id']].append((bb, annot['category_id']))

In [None]:
# Backbone architecture plus image size and batch size; each value is exposed
# under both a long and a short alias.
f_model = resnet34
im_sz = 224
size = im_sz
bs = 32
batch_size = bs

In [None]:
# Category names of every kept annotation, one list per image (imgs_id order).
annot_cats = [[cats[int(cat_id)] for _, cat_id in trn_anno[img_id]]
              for img_id in imgs_id]
# Index -> name list and its inverse name -> index mapping, following the
# iteration order of `cats`.
id2cats = [name for name in cats.values()]
cats2id = dict(zip(id2cats, range(len(id2cats))))
# One integer array of class indices per image, wrapped in an object array.
model_cats = np.array(
    [np.array([cats2id[name] for name in names]) for names in annot_cats],
    dtype=object,
)
model_cats[0:3] # obj classes per im; encoded from [0,19]

array([array([6]), array([14, 12]), array([ 1,  1, 14, 14, 14])],
      dtype=object)

In [None]:
# All boxes of an image flattened into a single 1-D coordinate array.
model_bbs = [np.concatenate([bb for bb, _ in trn_anno[img_id]])
             for img_id in imgs_id]
# The same coordinates as one space-separated string per image.
model_bbsc = [' '.join(map(str, coords)) for coords in model_bbs]
model_bbsc[0:3] # bb coords per im

['96 155 270 351',
 '61 184 199 279 77 89 336 403',
 '229 8 500 245 219 229 500 334 0 1 369 117 1 2 462 243 0 224 486 334']

**dls** — building the fastai `Datasets` and `TfmdDL` from the file names, boxes and classes

In [None]:
# File names in the iteration order of `imgs_fn`.
fnames = list(imgs_fn.values())
# Pair each image's flat box coordinates with its class indices, both as
# plain Python lists.
labels = [(list(bb), list(cls)) for bb, cls in zip(model_bbs, model_cats)]

# File name -> (box coordinates, class indices).
fn2label = dict(zip(fnames, labels))
fn2label[fnames[0]]

([96, 155, 270, 351], [6])

In [None]:
def get_path(f):
    """Return the path of file name `f` inside `ims_path`."""
    return ims_path / f
def get_bb(f):
    """Return the box coordinates stored for file name `f` in `fn2label`."""
    label = fn2label[f]
    return label[0]
def get_cls(f):
    """Return the class indices stored for file name `f` in `fn2label`."""
    label = fn2label[f]
    return label[1]

In [None]:
# Sanity check: run the three getters on a single file name.
fnidx = 0
tuple(getter(fnames[fnidx]) for getter in (get_path, get_bb, get_cls))

(Path('/home/rory/.fastai/data/pascal_2007/train/000012.jpg'),
 [96, 155, 270, 351],
 [6])

In [None]:
# One transform pipeline per output: the image, its boxes, its class labels.
# `n_inp=1` marks only the first of them (the image) as model input.
dss = Datasets(
    items=fnames,
    tfms=[
        [get_path, PILImage.create, ToTensor(), Resize(im_sz, method='squish')],
        [get_bb, TensorBBox.create],
        [get_cls, MultiCategorize(add_na=False)],
    ],
    n_inp=1,
)

In [None]:
# Shapes of (image, boxes, classes) for the first five samples.
[[part.shape for part in sample] for sample in dss[0:5]]

[[torch.Size([3, 224, 224]), torch.Size([1, 4]), torch.Size([1])],
 [torch.Size([3, 224, 224]), torch.Size([2, 4]), torch.Size([2])],
 [torch.Size([3, 224, 224]), torch.Size([5, 4]), torch.Size([5])],
 [torch.Size([3, 224, 224]), torch.Size([1, 4]), torch.Size([1])],
 [torch.Size([3, 224, 224]), torch.Size([4, 4]), torch.Size([4])]]

In [None]:
# Transforms handed to `after_batch`: int->float conversion plus augmentations.
aug_tfms = setup_aug_tfms([
    IntToFloatTensor(),
    Rotate(),
    Brightness(),
    Contrast(),
    Flip(),
])

# NOTE(review): `Resize` is part of the Datasets image pipeline while
# `PointScaler` runs here as an item transform - confirm the boxes are scaled
# against the intended image size.
tfmdl = TfmdDL(
    dss,
    bs=bs,
    after_item=[BBoxLabeler(), PointScaler()],
    before_batch=bb_pad,
    after_batch=aug_tfms,
)

# Pull a single batch for the shape checks in the following cells.
b = next(iter(tfmdl))

In [None]:
# Shape of the first element of the batch (the images).
b[0].size()

torch.Size([32, 3, 224, 224])

In [None]:
# Shape of the second element of the batch (the boxes, after `bb_pad`).
b[1].size()

torch.Size([32, 6, 4])

In [None]:
# Shape of the third element of the batch (the class labels, after `bb_pad`).
b[2].size()

torch.Size([32, 6])