In [1]:
import json
import os
import re
import fnmatch
from pathlib import Path
import numpy as np
import cv2
import copy
import random
import xml.etree.ElementTree as ET
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook as tqdm
from pycocotools.coco import COCO

In [2]:
# COCO category ids grouped into four splits of 20 ids each.
# Later cells assign one of these to `novel_cls_inds` and use it to filter
# data['categories'] by id.
cls_ind_in_coco_set1 = [
    1, 2, 3, 4, 5, 6, 7, 9, 16, 17,
    18, 19, 20, 21, 44, 62, 63, 64, 67, 72,
]
cls_ind_in_coco_set2 = [
    85, 43, 78, 87, 47, 11, 35, 53, 52, 46,
    38, 51, 41, 58, 32, 82, 24, 37, 73, 22,
]
cls_ind_in_coco_set3 = [
    60, 65, 13, 48, 79, 77, 61, 54, 76, 34,
    50, 74, 25, 86, 15, 31, 80, 14, 84, 28,
]
cls_ind_in_coco_set4 = [
    8, 10, 23, 27, 33, 36, 39, 40, 42, 49,
    55, 56, 57, 59, 70, 75, 81, 88, 89, 90,
]

In [3]:
# Load the COCO val2014 instance annotations and index them per image.
cwd = os.getcwd()
coco_json_path = cwd + '/data/coco/annotations/instances_val2014.json'
with open(coco_json_path, 'r') as f:
    data = json.load(f)

# im_summary maps str(image id) -> {
#     'im_dict':     the image record from data['images'],
#     'annotations': non-crowd annotation records of that image,
#     'categories':  de-duplicated category ids of those annotations,
# }
im_summary = {}
# Wrap the list itself rather than enumerate(...): tqdm needs len() of the
# iterable to display a total, and a generator does not provide one.
for im_d in tqdm(data['images']):
    im_summary[str(im_d['id'])] = {
        'im_dict': im_d,
        'annotations': [],
        'categories': [],
    }
for a_d in data['annotations']:
    if a_d['iscrowd'] == 0:  # only keep non-crowd annotations
        entry = im_summary[str(a_d['image_id'])]
        entry['annotations'].append(a_d)
        entry['categories'].append(a_d['category_id'])
# Collapse repeated category ids; the result is only used for membership tests.
for entry in im_summary.values():
    entry['categories'] = list(set(entry['categories']))

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




## Pure

### set 1

In [4]:
novel_cls_inds = cls_ind_in_coco_set1
output_set = 'set1'

# Category records whose id belongs to the selected split.
new_categories = [cat for cat in data['categories'] if cat['id'] in novel_cls_inds]

MIN_SIZE = 64  # minimum bbox[2] and bbox[3] for an annotation to count

# One record per category: its id, its name, and the keys of every image that
# holds at least one annotation of that category passing the MIN_SIZE check.
cat_and_their_im_id = []
for c_d in new_categories:
    cat_id = c_d['id']
    im_ids = [
        _key
        for _key, im_dict in im_summary.items()
        if cat_id in im_dict['categories']
        and any(
            an['category_id'] == cat_id
            and not (an['bbox'][2] < MIN_SIZE or an['bbox'][3] < MIN_SIZE)
            for an in im_dict['annotations']
        )
    ]
    cat_and_their_im_id.append({'cat_id': cat_id, 'cat_name': c_d['name'], 'im_ids': im_ids})

# Categories with the fewest qualifying images come first.
cat_and_their_im_id.sort(key=lambda s: len(s['im_ids']))

In [14]:
cwd = os.getcwd()
dump_dir = cwd + '/data/supports/'
os.makedirs(dump_dir, exist_ok=True)
N_SHOT = 30     # stop after this many crops have been written for a category
MIN_SIZE = 128  # reject a box whose bbox[2] or bbox[3] is below this
RATIO = 2.      # reject a box whose side ratio (either way) exceeds this

# For each category, write up to N_SHOT cropped boxes to
# <dump_dir>/<output_set>/<cat_name>/<cat_id>_<image key>.jpg
for cat_dict in tqdm(cat_and_their_im_id):
    cat_id = cat_dict['cat_id']
    cat_name = cat_dict['cat_name']
    # Shuffle a copy: shuffling cat_dict['im_ids'] in place leaves it permuted,
    # so re-running this cell would select a different set of images.
    ids = list(cat_dict['im_ids'])
    random.seed(0)
    random.shuffle(ids)
    shot_cnt = 0
    for _id in ids:
        # Annotations of this category that pass the size and ratio filters.
        valid_anns = []
        for an in im_summary[_id]['annotations']:
            if an['category_id'] != cat_id:
                continue
            box = an['bbox']
            if box[2] < MIN_SIZE or box[3] < MIN_SIZE:
                continue
            # Division is safe: both sides are >= MIN_SIZE > 0 at this point.
            if box[2]/box[3] > RATIO or box[3]/box[2] > RATIO:
                continue
            valid_anns.append(an)
        if not valid_anns:
            continue

        im_dict = im_summary[_id]['im_dict']
        im_path = cwd + '/data/coco/images/val2014/' + im_dict['file_name']
        im = cv2.imread(im_path)
        if im is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError('cannot read image: ' + im_path)

        # Re-seeding before the draw makes the pick depend only on valid_anns.
        random.seed(0)
        chosen_ann = random.sample(valid_anns, k=1)[0]
        box = [int(i) for i in chosen_ann['bbox']]
        # box[1]/box[3] index rows and box[0]/box[2] index columns.
        # NOTE(review): the +1 makes the crop one pixel larger than the
        # truncated box[2] x box[3] -- confirm this is intended.
        im_cropped = im[box[1]:box[1]+box[3]+1, box[0]:box[0]+box[2]+1, :]

        file_name = str(cat_id) + '_' + _id + '.jpg'
        output_dir = os.path.join(dump_dir, output_set, cat_name)
        output_path = os.path.join(output_dir, file_name)
        os.makedirs(output_dir, exist_ok=True)
        if not cv2.imwrite(output_path, im_cropped):
            raise IOError('failed to write crop: ' + output_path)
        shot_cnt += 1
        if shot_cnt == N_SHOT:
            break

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




### set 2

In [8]:
novel_cls_inds = cls_ind_in_coco_set2
output_set = 'set2'

# Category records whose id belongs to the selected split.
new_categories = [cat for cat in data['categories'] if cat['id'] in novel_cls_inds]

MIN_SIZE = 64  # minimum bbox[2] and bbox[3] for an annotation to count

# One record per category: its id, its name, and the keys of every image that
# holds at least one annotation of that category passing the MIN_SIZE check.
cat_and_their_im_id = []
for c_d in new_categories:
    cat_id = c_d['id']
    im_ids = [
        _key
        for _key, im_dict in im_summary.items()
        if cat_id in im_dict['categories']
        and any(
            an['category_id'] == cat_id
            and not (an['bbox'][2] < MIN_SIZE or an['bbox'][3] < MIN_SIZE)
            for an in im_dict['annotations']
        )
    ]
    cat_and_their_im_id.append({'cat_id': cat_id, 'cat_name': c_d['name'], 'im_ids': im_ids})

# Categories with the fewest qualifying images come first.
cat_and_their_im_id.sort(key=lambda s: len(s['im_ids']))

In [10]:
cwd = os.getcwd()
dump_dir = cwd + '/data/supports/'
os.makedirs(dump_dir, exist_ok=True)
N_SHOT = 30     # stop after this many crops have been written for a category
MIN_SIZE = 128  # reject a box whose bbox[2] or bbox[3] is below this
RATIO = 2.      # reject a box whose side ratio (either way) exceeds this

# For each category, write up to N_SHOT cropped boxes to
# <dump_dir>/<output_set>/<cat_name>/<cat_id>_<image key>.jpg
for cat_dict in tqdm(cat_and_their_im_id):
    cat_id = cat_dict['cat_id']
    cat_name = cat_dict['cat_name']
    # Shuffle a copy: shuffling cat_dict['im_ids'] in place leaves it permuted,
    # so re-running this cell would select a different set of images.
    ids = list(cat_dict['im_ids'])
    random.seed(0)
    random.shuffle(ids)
    shot_cnt = 0
    for _id in ids:
        # Annotations of this category that pass the size and ratio filters.
        valid_anns = []
        for an in im_summary[_id]['annotations']:
            if an['category_id'] != cat_id:
                continue
            box = an['bbox']
            if box[2] < MIN_SIZE or box[3] < MIN_SIZE:
                continue
            # Division is safe: both sides are >= MIN_SIZE > 0 at this point.
            if box[2]/box[3] > RATIO or box[3]/box[2] > RATIO:
                continue
            valid_anns.append(an)
        if not valid_anns:
            continue

        im_dict = im_summary[_id]['im_dict']
        im_path = cwd + '/data/coco/images/val2014/' + im_dict['file_name']
        im = cv2.imread(im_path)
        if im is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError('cannot read image: ' + im_path)

        # Re-seeding before the draw makes the pick depend only on valid_anns.
        random.seed(0)
        chosen_ann = random.sample(valid_anns, k=1)[0]
        box = [int(i) for i in chosen_ann['bbox']]
        # box[1]/box[3] index rows and box[0]/box[2] index columns.
        # NOTE(review): the +1 makes the crop one pixel larger than the
        # truncated box[2] x box[3] -- confirm this is intended.
        im_cropped = im[box[1]:box[1]+box[3]+1, box[0]:box[0]+box[2]+1, :]

        file_name = str(cat_id) + '_' + _id + '.jpg'
        output_dir = os.path.join(dump_dir, output_set, cat_name)
        output_path = os.path.join(output_dir, file_name)
        os.makedirs(output_dir, exist_ok=True)
        if not cv2.imwrite(output_path, im_cropped):
            raise IOError('failed to write crop: ' + output_path)
        shot_cnt += 1
        if shot_cnt == N_SHOT:
            break


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




### set 3

In [11]:
novel_cls_inds = cls_ind_in_coco_set3
output_set = 'set3'

# Category records whose id belongs to the selected split.
new_categories = [cat for cat in data['categories'] if cat['id'] in novel_cls_inds]

MIN_SIZE = 64  # minimum bbox[2] and bbox[3] for an annotation to count

# One record per category: its id, its name, and the keys of every image that
# holds at least one annotation of that category passing the MIN_SIZE check.
cat_and_their_im_id = []
for c_d in new_categories:
    cat_id = c_d['id']
    im_ids = [
        _key
        for _key, im_dict in im_summary.items()
        if cat_id in im_dict['categories']
        and any(
            an['category_id'] == cat_id
            and not (an['bbox'][2] < MIN_SIZE or an['bbox'][3] < MIN_SIZE)
            for an in im_dict['annotations']
        )
    ]
    cat_and_their_im_id.append({'cat_id': cat_id, 'cat_name': c_d['name'], 'im_ids': im_ids})

# Categories with the fewest qualifying images come first.
cat_and_their_im_id.sort(key=lambda s: len(s['im_ids']))

In [12]:
cwd = os.getcwd()
dump_dir = cwd + '/data/supports/'
os.makedirs(dump_dir, exist_ok=True)
N_SHOT = 30     # stop after this many crops have been written for a category
MIN_SIZE = 128  # reject a box whose bbox[2] or bbox[3] is below this
RATIO = 2.      # reject a box whose side ratio (either way) exceeds this

# For each category, write up to N_SHOT cropped boxes to
# <dump_dir>/<output_set>/<cat_name>/<cat_id>_<image key>.jpg
for cat_dict in tqdm(cat_and_their_im_id):
    cat_id = cat_dict['cat_id']
    cat_name = cat_dict['cat_name']
    # Shuffle a copy: shuffling cat_dict['im_ids'] in place leaves it permuted,
    # so re-running this cell would select a different set of images.
    ids = list(cat_dict['im_ids'])
    random.seed(0)
    random.shuffle(ids)
    shot_cnt = 0
    for _id in ids:
        # Annotations of this category that pass the size and ratio filters.
        valid_anns = []
        for an in im_summary[_id]['annotations']:
            if an['category_id'] != cat_id:
                continue
            box = an['bbox']
            if box[2] < MIN_SIZE or box[3] < MIN_SIZE:
                continue
            # Division is safe: both sides are >= MIN_SIZE > 0 at this point.
            if box[2]/box[3] > RATIO or box[3]/box[2] > RATIO:
                continue
            valid_anns.append(an)
        if not valid_anns:
            continue

        im_dict = im_summary[_id]['im_dict']
        im_path = cwd + '/data/coco/images/val2014/' + im_dict['file_name']
        im = cv2.imread(im_path)
        if im is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError('cannot read image: ' + im_path)

        # Re-seeding before the draw makes the pick depend only on valid_anns.
        random.seed(0)
        chosen_ann = random.sample(valid_anns, k=1)[0]
        box = [int(i) for i in chosen_ann['bbox']]
        # box[1]/box[3] index rows and box[0]/box[2] index columns.
        # NOTE(review): the +1 makes the crop one pixel larger than the
        # truncated box[2] x box[3] -- confirm this is intended.
        im_cropped = im[box[1]:box[1]+box[3]+1, box[0]:box[0]+box[2]+1, :]

        file_name = str(cat_id) + '_' + _id + '.jpg'
        output_dir = os.path.join(dump_dir, output_set, cat_name)
        output_path = os.path.join(output_dir, file_name)
        os.makedirs(output_dir, exist_ok=True)
        if not cv2.imwrite(output_path, im_cropped):
            raise IOError('failed to write crop: ' + output_path)
        shot_cnt += 1
        if shot_cnt == N_SHOT:
            break


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




### set 4

In [15]:
novel_cls_inds = cls_ind_in_coco_set4
output_set = 'set4'

# Category records whose id belongs to the selected split.
new_categories = [cat for cat in data['categories'] if cat['id'] in novel_cls_inds]

MIN_SIZE = 64  # minimum bbox[2] and bbox[3] for an annotation to count

# One record per category: its id, its name, and the keys of every image that
# holds at least one annotation of that category passing the MIN_SIZE check.
cat_and_their_im_id = []
for c_d in new_categories:
    cat_id = c_d['id']
    im_ids = [
        _key
        for _key, im_dict in im_summary.items()
        if cat_id in im_dict['categories']
        and any(
            an['category_id'] == cat_id
            and not (an['bbox'][2] < MIN_SIZE or an['bbox'][3] < MIN_SIZE)
            for an in im_dict['annotations']
        )
    ]
    cat_and_their_im_id.append({'cat_id': cat_id, 'cat_name': c_d['name'], 'im_ids': im_ids})

# Categories with the fewest qualifying images come first.
cat_and_their_im_id.sort(key=lambda s: len(s['im_ids']))

In [16]:
cwd = os.getcwd()
dump_dir = cwd + '/data/supports/'
os.makedirs(dump_dir, exist_ok=True)
N_SHOT = 30     # stop after this many crops have been written for a category
MIN_SIZE = 128  # reject a box whose bbox[2] or bbox[3] is below this
RATIO = 2.      # reject a box whose side ratio (either way) exceeds this

# For each category, write up to N_SHOT cropped boxes to
# <dump_dir>/<output_set>/<cat_name>/<cat_id>_<image key>.jpg
for cat_dict in tqdm(cat_and_their_im_id):
    cat_id = cat_dict['cat_id']
    cat_name = cat_dict['cat_name']
    # Shuffle a copy: shuffling cat_dict['im_ids'] in place leaves it permuted,
    # so re-running this cell would select a different set of images.
    ids = list(cat_dict['im_ids'])
    random.seed(0)
    random.shuffle(ids)
    shot_cnt = 0
    for _id in ids:
        # Annotations of this category that pass the size and ratio filters.
        valid_anns = []
        for an in im_summary[_id]['annotations']:
            if an['category_id'] != cat_id:
                continue
            box = an['bbox']
            if box[2] < MIN_SIZE or box[3] < MIN_SIZE:
                continue
            # Division is safe: both sides are >= MIN_SIZE > 0 at this point.
            if box[2]/box[3] > RATIO or box[3]/box[2] > RATIO:
                continue
            valid_anns.append(an)
        if not valid_anns:
            continue

        im_dict = im_summary[_id]['im_dict']
        im_path = cwd + '/data/coco/images/val2014/' + im_dict['file_name']
        im = cv2.imread(im_path)
        if im is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError('cannot read image: ' + im_path)

        # Re-seeding before the draw makes the pick depend only on valid_anns.
        random.seed(0)
        chosen_ann = random.sample(valid_anns, k=1)[0]
        box = [int(i) for i in chosen_ann['bbox']]
        # box[1]/box[3] index rows and box[0]/box[2] index columns.
        # NOTE(review): the +1 makes the crop one pixel larger than the
        # truncated box[2] x box[3] -- confirm this is intended.
        im_cropped = im[box[1]:box[1]+box[3]+1, box[0]:box[0]+box[2]+1, :]

        file_name = str(cat_id) + '_' + _id + '.jpg'
        output_dir = os.path.join(dump_dir, output_set, cat_name)
        output_path = os.path.join(output_dir, file_name)
        os.makedirs(output_dir, exist_ok=True)
        if not cv2.imwrite(output_path, im_cropped):
            raise IOError('failed to write crop: ' + output_path)
        shot_cnt += 1
        if shot_cnt == N_SHOT:
            break


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




## not pure

### set 1

In [5]:
novel_cls_inds = cls_ind_in_coco_set1
output_set = 'set1_random'

# Category records whose id belongs to the selected split.
new_categories = [cat for cat in data['categories'] if cat['id'] in novel_cls_inds]

MIN_SIZE = 64  # minimum bbox[2] and bbox[3] for an annotation to count

# One record per category: its id, its name, and the keys of every image that
# holds at least one annotation of that category passing the MIN_SIZE check.
cat_and_their_im_id = []
for c_d in new_categories:
    cat_id = c_d['id']
    im_ids = [
        _key
        for _key, im_dict in im_summary.items()
        if cat_id in im_dict['categories']
        and any(
            an['category_id'] == cat_id
            and not (an['bbox'][2] < MIN_SIZE or an['bbox'][3] < MIN_SIZE)
            for an in im_dict['annotations']
        )
    ]
    cat_and_their_im_id.append({'cat_id': cat_id, 'cat_name': c_d['name'], 'im_ids': im_ids})

# Categories with the fewest qualifying images come first.
cat_and_their_im_id.sort(key=lambda s: len(s['im_ids']))

In [6]:
cwd = os.getcwd()
dump_dir = cwd + '/data/supports/'
N_SHOT = 15    # stop after this many crops have been written for a category
MIN_SIZE = 64  # reject a box whose bbox[2] or bbox[3] is below this
RATIO = 2.     # reject a box whose side ratio (either way) exceeds this

# For each category, write up to N_SHOT cropped boxes to
# <dump_dir>/<output_set>/<cat_name>/<cat_id>_<image key>.jpg
for cat_dict in tqdm(cat_and_their_im_id):
    cat_id = cat_dict['cat_id']
    cat_name = cat_dict['cat_name']
    # Shuffle a copy: shuffling cat_dict['im_ids'] in place leaves it permuted,
    # so re-running this cell would select a different set of images.
    ids = list(cat_dict['im_ids'])
    random.seed(0)
    random.shuffle(ids)
    shot_cnt = 0
    for _id in ids:
        # Annotations of this category that pass the size and ratio filters.
        valid_anns = []
        for an in im_summary[_id]['annotations']:
            if an['category_id'] != cat_id:
                continue
            box = an['bbox']
            if box[2] < MIN_SIZE or box[3] < MIN_SIZE:
                continue
            # Division is safe: both sides are >= MIN_SIZE > 0 at this point.
            if box[2]/box[3] > RATIO or box[3]/box[2] > RATIO:
                continue
            valid_anns.append(an)
        if not valid_anns:
            continue

        im_dict = im_summary[_id]['im_dict']
        im_path = cwd + '/data/coco/images/val2014/' + im_dict['file_name']
        im = cv2.imread(im_path)
        if im is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError('cannot read image: ' + im_path)

        # Re-seeding before the draw makes the pick depend only on valid_anns.
        random.seed(0)
        chosen_ann = random.sample(valid_anns, k=1)[0]
        box = [int(i) for i in chosen_ann['bbox']]
        # box[1]/box[3] index rows and box[0]/box[2] index columns.
        # NOTE(review): the +1 makes the crop one pixel larger than the
        # truncated box[2] x box[3] -- confirm this is intended.
        im_cropped = im[box[1]:box[1]+box[3]+1, box[0]:box[0]+box[2]+1, :]

        file_name = str(cat_id) + '_' + _id + '.jpg'
        output_dir = os.path.join(dump_dir, output_set, cat_name)
        output_path = os.path.join(output_dir, file_name)
        os.makedirs(output_dir, exist_ok=True)
        if not cv2.imwrite(output_path, im_cropped):
            raise IOError('failed to write crop: ' + output_path)
        shot_cnt += 1
        if shot_cnt == N_SHOT:
            break

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




In [7]:
novel_cls_inds = cls_ind_in_coco_set2
output_set = 'set2_random'

# Category records whose id belongs to the selected split.
new_categories = [cat for cat in data['categories'] if cat['id'] in novel_cls_inds]

MIN_SIZE = 64  # minimum bbox[2] and bbox[3] for an annotation to count

# One record per category: its id, its name, and the keys of every image that
# holds at least one annotation of that category passing the MIN_SIZE check.
cat_and_their_im_id = []
for c_d in new_categories:
    cat_id = c_d['id']
    im_ids = [
        _key
        for _key, im_dict in im_summary.items()
        if cat_id in im_dict['categories']
        and any(
            an['category_id'] == cat_id
            and not (an['bbox'][2] < MIN_SIZE or an['bbox'][3] < MIN_SIZE)
            for an in im_dict['annotations']
        )
    ]
    cat_and_their_im_id.append({'cat_id': cat_id, 'cat_name': c_d['name'], 'im_ids': im_ids})

# Categories with the fewest qualifying images come first.
cat_and_their_im_id.sort(key=lambda s: len(s['im_ids']))

In [8]:
cwd = os.getcwd()
dump_dir = cwd + '/data/supports/'
N_SHOT = 15    # stop after this many crops have been written for a category
MIN_SIZE = 64  # reject a box whose bbox[2] or bbox[3] is below this
RATIO = 2.     # reject a box whose side ratio (either way) exceeds this

# For each category, write up to N_SHOT cropped boxes to
# <dump_dir>/<output_set>/<cat_name>/<cat_id>_<image key>.jpg
for cat_dict in tqdm(cat_and_their_im_id):
    cat_id = cat_dict['cat_id']
    cat_name = cat_dict['cat_name']
    # Shuffle a copy: shuffling cat_dict['im_ids'] in place leaves it permuted,
    # so re-running this cell would select a different set of images.
    ids = list(cat_dict['im_ids'])
    random.seed(0)
    random.shuffle(ids)
    shot_cnt = 0
    for _id in ids:
        # Annotations of this category that pass the size and ratio filters.
        valid_anns = []
        for an in im_summary[_id]['annotations']:
            if an['category_id'] != cat_id:
                continue
            box = an['bbox']
            if box[2] < MIN_SIZE or box[3] < MIN_SIZE:
                continue
            # Division is safe: both sides are >= MIN_SIZE > 0 at this point.
            if box[2]/box[3] > RATIO or box[3]/box[2] > RATIO:
                continue
            valid_anns.append(an)
        if not valid_anns:
            continue

        im_dict = im_summary[_id]['im_dict']
        im_path = cwd + '/data/coco/images/val2014/' + im_dict['file_name']
        im = cv2.imread(im_path)
        if im is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError('cannot read image: ' + im_path)

        # Re-seeding before the draw makes the pick depend only on valid_anns.
        random.seed(0)
        chosen_ann = random.sample(valid_anns, k=1)[0]
        box = [int(i) for i in chosen_ann['bbox']]
        # box[1]/box[3] index rows and box[0]/box[2] index columns.
        # NOTE(review): the +1 makes the crop one pixel larger than the
        # truncated box[2] x box[3] -- confirm this is intended.
        im_cropped = im[box[1]:box[1]+box[3]+1, box[0]:box[0]+box[2]+1, :]

        file_name = str(cat_id) + '_' + _id + '.jpg'
        output_dir = os.path.join(dump_dir, output_set, cat_name)
        output_path = os.path.join(output_dir, file_name)
        os.makedirs(output_dir, exist_ok=True)
        if not cv2.imwrite(output_path, im_cropped):
            raise IOError('failed to write crop: ' + output_path)
        shot_cnt += 1
        if shot_cnt == N_SHOT:
            break

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




In [9]:
novel_cls_inds = cls_ind_in_coco_set3
output_set = 'set3_random'

# Category records whose id belongs to the selected split.
new_categories = [cat for cat in data['categories'] if cat['id'] in novel_cls_inds]

MIN_SIZE = 64  # minimum bbox[2] and bbox[3] for an annotation to count

# One record per category: its id, its name, and the keys of every image that
# holds at least one annotation of that category passing the MIN_SIZE check.
cat_and_their_im_id = []
for c_d in new_categories:
    cat_id = c_d['id']
    im_ids = [
        _key
        for _key, im_dict in im_summary.items()
        if cat_id in im_dict['categories']
        and any(
            an['category_id'] == cat_id
            and not (an['bbox'][2] < MIN_SIZE or an['bbox'][3] < MIN_SIZE)
            for an in im_dict['annotations']
        )
    ]
    cat_and_their_im_id.append({'cat_id': cat_id, 'cat_name': c_d['name'], 'im_ids': im_ids})

# Categories with the fewest qualifying images come first.
cat_and_their_im_id.sort(key=lambda s: len(s['im_ids']))

In [10]:
cwd = os.getcwd()
dump_dir = cwd + '/data/supports/'
N_SHOT = 15    # stop after this many crops have been written for a category
MIN_SIZE = 64  # reject a box whose bbox[2] or bbox[3] is below this
RATIO = 2.     # reject a box whose side ratio (either way) exceeds this

# For each category, write up to N_SHOT cropped boxes to
# <dump_dir>/<output_set>/<cat_name>/<cat_id>_<image key>.jpg
for cat_dict in tqdm(cat_and_their_im_id):
    cat_id = cat_dict['cat_id']
    cat_name = cat_dict['cat_name']
    # Shuffle a copy: shuffling cat_dict['im_ids'] in place leaves it permuted,
    # so re-running this cell would select a different set of images.
    ids = list(cat_dict['im_ids'])
    random.seed(0)
    random.shuffle(ids)
    shot_cnt = 0
    for _id in ids:
        # Annotations of this category that pass the size and ratio filters.
        valid_anns = []
        for an in im_summary[_id]['annotations']:
            if an['category_id'] != cat_id:
                continue
            box = an['bbox']
            if box[2] < MIN_SIZE or box[3] < MIN_SIZE:
                continue
            # Division is safe: both sides are >= MIN_SIZE > 0 at this point.
            if box[2]/box[3] > RATIO or box[3]/box[2] > RATIO:
                continue
            valid_anns.append(an)
        if not valid_anns:
            continue

        im_dict = im_summary[_id]['im_dict']
        im_path = cwd + '/data/coco/images/val2014/' + im_dict['file_name']
        im = cv2.imread(im_path)
        if im is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError('cannot read image: ' + im_path)

        # Re-seeding before the draw makes the pick depend only on valid_anns.
        random.seed(0)
        chosen_ann = random.sample(valid_anns, k=1)[0]
        box = [int(i) for i in chosen_ann['bbox']]
        # box[1]/box[3] index rows and box[0]/box[2] index columns.
        # NOTE(review): the +1 makes the crop one pixel larger than the
        # truncated box[2] x box[3] -- confirm this is intended.
        im_cropped = im[box[1]:box[1]+box[3]+1, box[0]:box[0]+box[2]+1, :]

        file_name = str(cat_id) + '_' + _id + '.jpg'
        output_dir = os.path.join(dump_dir, output_set, cat_name)
        output_path = os.path.join(output_dir, file_name)
        os.makedirs(output_dir, exist_ok=True)
        if not cv2.imwrite(output_path, im_cropped):
            raise IOError('failed to write crop: ' + output_path)
        shot_cnt += 1
        if shot_cnt == N_SHOT:
            break

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




In [11]:
novel_cls_inds = cls_ind_in_coco_set4
output_set = 'set4_random'

# Category records whose id belongs to the selected split.
new_categories = [cat for cat in data['categories'] if cat['id'] in novel_cls_inds]

MIN_SIZE = 64  # minimum bbox[2] and bbox[3] for an annotation to count

# One record per category: its id, its name, and the keys of every image that
# holds at least one annotation of that category passing the MIN_SIZE check.
cat_and_their_im_id = []
for c_d in new_categories:
    cat_id = c_d['id']
    im_ids = [
        _key
        for _key, im_dict in im_summary.items()
        if cat_id in im_dict['categories']
        and any(
            an['category_id'] == cat_id
            and not (an['bbox'][2] < MIN_SIZE or an['bbox'][3] < MIN_SIZE)
            for an in im_dict['annotations']
        )
    ]
    cat_and_their_im_id.append({'cat_id': cat_id, 'cat_name': c_d['name'], 'im_ids': im_ids})

# Categories with the fewest qualifying images come first.
cat_and_their_im_id.sort(key=lambda s: len(s['im_ids']))

In [12]:
cwd = os.getcwd()
dump_dir = cwd + '/data/supports/'
N_SHOT = 15    # stop after this many crops have been written for a category
MIN_SIZE = 64  # reject a box whose bbox[2] or bbox[3] is below this
RATIO = 2.     # reject a box whose side ratio (either way) exceeds this

# For each category, write up to N_SHOT cropped boxes to
# <dump_dir>/<output_set>/<cat_name>/<cat_id>_<image key>.jpg
for cat_dict in tqdm(cat_and_their_im_id):
    cat_id = cat_dict['cat_id']
    cat_name = cat_dict['cat_name']
    # Shuffle a copy: shuffling cat_dict['im_ids'] in place leaves it permuted,
    # so re-running this cell would select a different set of images.
    ids = list(cat_dict['im_ids'])
    random.seed(0)
    random.shuffle(ids)
    shot_cnt = 0
    for _id in ids:
        # Annotations of this category that pass the size and ratio filters.
        valid_anns = []
        for an in im_summary[_id]['annotations']:
            if an['category_id'] != cat_id:
                continue
            box = an['bbox']
            if box[2] < MIN_SIZE or box[3] < MIN_SIZE:
                continue
            # Division is safe: both sides are >= MIN_SIZE > 0 at this point.
            if box[2]/box[3] > RATIO or box[3]/box[2] > RATIO:
                continue
            valid_anns.append(an)
        if not valid_anns:
            continue

        im_dict = im_summary[_id]['im_dict']
        im_path = cwd + '/data/coco/images/val2014/' + im_dict['file_name']
        im = cv2.imread(im_path)
        if im is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError('cannot read image: ' + im_path)

        # Re-seeding before the draw makes the pick depend only on valid_anns.
        random.seed(0)
        chosen_ann = random.sample(valid_anns, k=1)[0]
        box = [int(i) for i in chosen_ann['bbox']]
        # box[1]/box[3] index rows and box[0]/box[2] index columns.
        # NOTE(review): the +1 makes the crop one pixel larger than the
        # truncated box[2] x box[3] -- confirm this is intended.
        im_cropped = im[box[1]:box[1]+box[3]+1, box[0]:box[0]+box[2]+1, :]

        file_name = str(cat_id) + '_' + _id + '.jpg'
        output_dir = os.path.join(dump_dir, output_set, cat_name)
        output_path = os.path.join(output_dir, file_name)
        os.makedirs(output_dir, exist_ok=True)
        if not cv2.imwrite(output_path, im_cropped):
            raise IOError('failed to write crop: ' + output_path)
        shot_cnt += 1
        if shot_cnt == N_SHOT:
            break

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))


