<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

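Everything (full COCO 2017 dataframe: every train and val image with its labels and bboxes, saved as a pickle):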

In [None]:
from fastai.vision.all import *

# dataset locations: COCO 2017 root, per-split image folders, per-split annotation files
path = Path('/home/rory/data/coco2017')
train_im_dir = 'train2017'
valid_im_dir = 'val2017'
train_json, valid_json = (
    'annotations/instances_train2017.json',
    'annotations/instances_val2017.json',
)

# get items & annos
def get_annos(path, anno_file, im_folder):
    "Read `path/anno_file` with `get_annotations`; return (`L` of file names prefixed with `path/im_folder`, annotations)."
    fnames, annos = get_annotations(path/anno_file)
    im_dir = path/im_folder
    return L(fnames).map(lambda fname: im_dir/fname), annos
# Load both splits. Each annotation is indexed below as (bboxes, labels) for one image.
train_paths, train_annos = get_annos(path, train_json, train_im_dir)
valid_paths, valid_annos = get_annos(path, valid_json, valid_im_dir)
paths  = train_paths + valid_paths
annos  = train_annos + valid_annos
bboxes = [a[0] for a in annos]
lbls   = [a[1] for a in annos]

# create df and pkl
# `is_valid` records which split a row came from: 0 = train2017, 1 = val2017.
# It must be sized from the lists loaded in THIS cell (`train_paths`/`valid_paths`);
# the previous `train_files`/`valid_files` names only exist in the "Subset" cell and
# raise NameError on a fresh kernel.
df = pd.DataFrame({
    "path": paths,
    "lbl":  lbls,
    "bbox": bboxes,
    "is_valid": [0]*len(train_paths) + [1]*len(valid_paths),
    "n_items": L(len(l) for l in lbls)  # number of annotated objects per image
    })
df.to_pickle(path/'20201027_coco_df.pkl')

df.head(1)

Subset (keep only mid-frequency labels, then rebalance images by number of objects and save as a pickle):

In [None]:
from fastai.vision.all import *
    
    
# where the COCO 2017 data lives: root dir, image sub-folders, annotation json files
path = Path('/home/rory/data/coco2017')
train_im_dir = 'train2017'
valid_im_dir = 'val2017'
train_json = 'annotations/instances_train2017.json'
valid_json = 'annotations/instances_val2017.json'

# get items & annos
def get_annos(path, anno_file, im_folder):
    "Parse `path/anno_file` via `get_annotations` and return (image paths under `path/im_folder` as an `L`, annotations)."
    im_names, im_annos = get_annotations(path/anno_file)
    def _full_path(name): return path/im_folder/name
    return L(im_names).map(_full_path), im_annos
train_files, train_annos = get_annos(path, train_json, train_im_dir)
valid_files, valid_annos = get_annos(path, valid_json, valid_im_dir)
annos = train_annos + valid_annos
# split each (bboxes, labels) pair into two parallel `L`s of per-image `L`s
bboxes, lbls = L(), L()
for im_bboxes, im_lbls in annos:
    bboxes.append(L(im_bboxes))
    lbls.append(L(im_lbls))

# get label subset
def flatten(l):
    "Depth-first unpack of arbitrarily nested `list`/`L` containers in `l` into one flat `L`."
    flat = L()
    def _walk(items):
        for item in items:
            # only `list` and `L` are treated as containers; anything else is a leaf
            if isinstance(item, (list, L)): _walk(item)
            else: flat.append(item)
    _walk(l)
    return flat
from collections import Counter

lbls_flat = flatten(lbls)
# Tally every label in a single pass. The previous `get_count` filtered the whole flat
# list on every call and was invoked twice per unique label (sort key + `lbl_cts`).
_lbl_counter = Counter(lbls_flat)
def get_count(lbl):
    "Number of annotated objects labelled `lbl` across all images (0 if the label never occurs)."
    return _lbl_counter[lbl]
# `most_common()` sorts by count descending and keeps first-seen order for ties,
# matching the original stable `sorted(..., reverse=True)` over the unique labels.
lbl_cts = _lbl_counter.most_common()
lbls_sorted = [l for l,c in lbl_cts]
# Keep labels with strictly between 5000 and 30000 instances, minus the hand-picked
# exclusions. Filtering (rather than `list.remove`) does not raise if an excluded label
# is already outside the count window.
_excluded_lbls = {'traffic light', 'motorcycle', 'bus'}
lbl_ss = [l for l,c in lbl_cts if 5000<c<30000 and l not in _excluded_lbls]

# get subset items & annos
# For every image, build a boolean keep-mask over its labels and apply the same mask
# to its labels and bboxes so the two stay aligned.
ss_idxs, ss_lbls, ss_bboxes, ss_obj = [], [], [], []
for im_lbls, im_bboxes in zip(lbls, bboxes):
    keep = [lbl in lbl_ss for lbl in im_lbls]
    kept_lbls = list(im_lbls[keep])
    ss_idxs.append(keep)
    ss_lbls.append(kept_lbls)
    ss_bboxes.append(list(im_bboxes[keep]))
    ss_obj.append(len(kept_lbls))  # number of subset objects left in this image

# create df
# one row per image: its path, the kept labels/bboxes, and how many objects were kept
df_cols = {}
df_cols["path"]  = train_files + valid_files
df_cols["lbl"]   = ss_lbls
df_cols["bbox"]  = ss_bboxes
df_cols["n_obj"] = ss_obj
df = pd.DataFrame(df_cols)

# Rebalance by objects-per-image:
#   * drop images with `too_many` (8) or more subset objects
#   * cap every remaining n_obj group at 2*minv rows, and the zero-object group at minv,
#     where minv is the size of the smallest non-empty group with n_obj < too_many
too_many = 8
seed = 42  # fixed so the sampled subset (and the pickle written below) is reproducible
rng = list(range(too_many))
# Look counts up by n_obj *value*. `value_counts()` is ordered by frequency, not by n_obj,
# so pairing its values positionally with `rng` can attach the wrong count to a group
# (wrong `minv`, and `.sample(n=...)` may ask for more rows than the group has).
n_obj_vc = df['n_obj'].value_counts()
cts = [int(n_obj_vc.get(i, 0)) for i in rng]
zipd = [(i, c) for i, c in zip(rng, cts) if c > 0]  # skip n_obj values with no images
minv = min(c for _, c in zipd)
df = pd.concat([
    df[df['n_obj']==i].sample(n=min(c, minv if i==0 else minv*2), random_state=seed)  # i==0: limit ims w/zero objs
    for i, c in zipd])

# create pkl
df.to_pickle(path/'20201029_coco_ss_df.pkl')

Create dls (fastai `DataLoaders` built from the subset dataframe pickle):

In [None]:
from fastai.vision.all import *


### Params ###
path = Path('/home/rory/data/coco2017')   # dataset root (also holds the pickled dataframes)
im_size, batch_size = 224, 64             # resize target in px, items per batch
valid_pct = 0.10                          # fraction of items held out for validation


### Items ###
# Load the subset dataframe; the file name matches the pickle written by the "Subset" cell.
# Only the 'path', 'bbox' and 'lbl' columns are read below.
df = pd.read_pickle(path/'20201029_coco_ss_df.pkl')
# get items
def get_cols(df,cols):
    "Return one Python list per column name in `cols`, in the order given, taken from `df`."
    col_lists = []
    for col_name in cols:
        col_lists.append(df[col_name].to_list())
    return col_lists
paths, bboxes, lbls  = get_cols(df, ['path', 'bbox', 'lbl'])
# path -> bboxes / path -> labels lookups backing the DataBlock getters below
p2b = dict(zip(paths, bboxes))
p2l = dict(zip(paths, lbls))
def get_bbox(p):
    "Bounding boxes stored for image path `p`."
    return p2b[p]
def get_lbl(p):
    "Labels stored for image path `p`."
    return p2l[p]


### DataBlock & DataLoaders ###
# Inputs are images; targets are (bboxes, bbox labels), looked up per image path.
split_seed = 42  # fixed so the train/valid split is the same on every run
db = DataBlock(
    blocks=[ImageBlock, BBoxBlock, BBoxLblBlock],
    get_y=[get_bbox, get_lbl],
    splitter=RandomSplitter(valid_pct, seed=split_seed),
    item_tfms=Resize(im_size, method='squish'),
    batch_tfms=Normalize.from_stats(*imagenet_stats),
    n_inp=1)  # first block is the input, the remaining two are targets
# Pass the configured batch size explicitly; previously `batch_size` was defined but
# never used, so changing it in the Params section had no effect.
dls = db.dataloaders(paths, bs=batch_size)

In [None]:
### Inspection (IMPORTANT) ###
# Report split sizes, then the dtype and shape of every tensor in one training batch.
n_train_items = len(dls.train.items)
n_valid_items = len(dls.valid.items)
print("Size of train data:",n_train_items)
print("Size of valid data:",n_valid_items)
batch = dls.one_batch()
for idx,tensor in enumerate(batch):
    print(f"batch[{idx}]:",'\t',tensor.dtype,'\t',tensor.shape)