In [1]:
import json
import os
import re
import fnmatch
from pathlib import Path
import numpy as np
import cv2
import copy
import random
import xml.etree.ElementTree as ET
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook as tqdm
from pycocotools.coco import COCO

In [2]:
# Four groups of 20 COCO category ids each. Every later cell picks one of
# these lists as `novel_cls_inds` and keeps only the categories whose 'id'
# appears in it. The element order of each list is preserved as written.
cls_ind_in_coco_set1 = [
    1, 2, 3, 4, 5, 6, 7, 9, 16, 17,
    18, 19, 20, 21, 44, 62, 63, 64, 67, 72,
]
cls_ind_in_coco_set2 = [
    85, 43, 78, 87, 47, 11, 35, 53, 52, 46,
    38, 51, 41, 58, 32, 82, 24, 37, 73, 22,
]
cls_ind_in_coco_set3 = [
    60, 65, 13, 48, 79, 77, 61, 54, 76, 34,
    50, 74, 25, 86, 15, 31, 80, 14, 84, 28,
]
cls_ind_in_coco_set4 = [
    8, 10, 23, 27, 33, 36, 39, 40, 42, 49,
    55, 56, 57, 59, 70, 75, 81, 88, 89, 90,
]

In [4]:
# Load the val2014 instance annotations and index them per image.
cwd = os.getcwd()
coco_json_path = os.path.join(cwd, 'data', 'coco', 'annotations', 'instances_val2014.json')
with open(coco_json_path, 'r') as f:
    data = json.load(f)

# im_summary maps str(image id) -> {
#     'im_dict':     the image record from data['images'],
#     'annotations': the non-crowd annotation records of that image,
#     'categories':  the distinct category ids among those annotations,
# }
# Keys are strings; the later cells index im_summary with these string keys.
im_summary = {}
# Wrap the list itself rather than enumerate(...): enumerate has no len(),
# so tqdm could not show a total and the bar stayed indeterminate.
for im_d in tqdm(data['images']):
    im_summary[str(im_d['id'])] = {
        'im_dict': im_d,
        'annotations': [],
        'categories': [],
    }

for a_d in data['annotations']:
    if a_d['iscrowd'] != 0:  # only keep non-crowd annotations
        continue
    summary = im_summary[str(a_d['image_id'])]
    summary['annotations'].append(a_d)
    summary['categories'].append(a_d['category_id'])

# Collapse the per-image category ids to one entry per category.
for summary in im_summary.values():
    summary['categories'] = list(set(summary['categories']))

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




### set1

In [5]:
# Restrict the category records to the ids listed for set1.
novel_cls_inds = cls_ind_in_coco_set1
new_categories = [cat for cat in data['categories'] if cat['id'] in novel_cls_inds]

# For every selected category, collect the keys of the images that contain at
# least one annotation of that category whose bbox[2] and bbox[3] (width and
# height in the [x, y, w, h] layout) are both >= MIN_SIZE.
MIN_SIZE = 64
cat_and_their_im_id = []
for c_d in new_categories:
    cat_id = c_d['id']
    cat_dict = {'cat_id': cat_id, 'im_ids': []}
    for _key, im_dict in im_summary.items():  # for every im
        if cat_id not in im_dict['categories']:  # this im does not have this class
            continue
        has_valid_box = any(
            an['category_id'] == cat_id
            and not (an['bbox'][2] < MIN_SIZE or an['bbox'][3] < MIN_SIZE)
            for an in im_dict['annotations']
        )
        if has_valid_box:  # keep ims with at least one valid box of this class
            cat_dict['im_ids'].append(_key)
    cat_and_their_im_id.append(cat_dict)

# Categories with the fewest candidate images come first.
cat_and_their_im_id.sort(key=lambda s: len(s['im_ids']))

In [6]:
# Write N_EPISODE annotation files (ep0.json, ep1.json, ...) for set1.
# Each episode holds IM_PER_CLASS images per category, and no image is used
# for more than one category within the same episode.
cwd = os.getcwd()
dump_dir = os.path.join(cwd, 'data', 'coco', 'annotations', 'coco20_test', 'set1')
# open(..., 'w') below does not create missing directories.
os.makedirs(dump_dir, exist_ok=True)
N_EPISODE = 10
IM_PER_CLASS = 10

for epi in tqdm(range(N_EPISODE)):
    ### generate query json
    id_have_selected = set()  # ims already picked by an earlier class in this episode
    episode = []
    for cat_dict in cat_and_their_im_id:
        cat_id = cat_dict['cat_id']
        # Build a filtered copy instead of calling .remove() on
        # cat_dict['im_ids']: removing in place shrank the shared candidate
        # lists, so later episodes (and re-runs of this cell) sampled from a
        # different pool than the one computed by the selection cell.
        ids = [_id for _id in cat_dict['im_ids'] if _id not in id_have_selected]
        if len(ids) < IM_PER_CLASS:
            raise ValueError(
                'category {} has only {} candidate images left, need {}'.format(
                    cat_id, len(ids), IM_PER_CLASS))
        # Re-seeded for every class (as before), so a class's draw depends only
        # on the episode index and on its own candidate list.
        random.seed(epi)
        selected_ids = random.sample(ids, k=IM_PER_CLASS)
        id_have_selected.update(selected_ids)
        episode.append({'cat_id': cat_id, 'im_ids': selected_ids})

    # Assemble a COCO-style dict holding only the selected images and, for each
    # image, only the annotations of the class it was selected for.
    new_data = {
        'info': data['info'],
        'images': [],
        'licenses': data['licenses'],
        'annotations': [],
        'categories': new_categories,
    }
    for cat_dict in episode:
        cat_id = cat_dict['cat_id']
        for _id in cat_dict['im_ids']:
            new_data['images'].append(im_summary[_id]['im_dict'])
            new_data['annotations'].extend(
                an for an in im_summary[_id]['annotations']
                if an['category_id'] == cat_id)
    dump_path = os.path.join(dump_dir, 'ep' + str(epi) + '.json')
    with open(dump_path, 'w') as f:
        json.dump(new_data, f)
        

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




In [11]:
# Sanity check: load the episode file that `dump_path` currently points to and
# print the number of images, annotations and categories it indexes.
_COCO = COCO(dump_path)
for _index in (_COCO.imgs, _COCO.anns, _COCO.cats):
    print(len(_index))

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
200
475
20


### set2

In [5]:
# Build the candidate image pool for the set2 categories, then write
# N_EPISODE annotation files (ep0.json, ep1.json, ...) for set2.
cwd = os.getcwd()
novel_cls_inds = cls_ind_in_coco_set2
dump_dir = os.path.join(cwd, 'data', 'coco', 'annotations', 'coco20_test', 'set2')
# open(..., 'w') below does not create missing directories.
os.makedirs(dump_dir, exist_ok=True)
N_EPISODE = 10
IM_PER_CLASS = 10

# Restrict the category records to the ids listed for this set.
new_categories = [cat for cat in data['categories'] if cat['id'] in novel_cls_inds]

# For every selected category, collect the keys of the images that contain at
# least one annotation of that category whose bbox[2] and bbox[3] (width and
# height in the [x, y, w, h] layout) are both >= MIN_SIZE.
MIN_SIZE = 64
cat_and_their_im_id = []
for c_d in new_categories:
    cat_id = c_d['id']
    cat_dict = {'cat_id': cat_id, 'im_ids': []}
    for _key, im_dict in im_summary.items():  # for every im
        if cat_id not in im_dict['categories']:  # this im does not have this class
            continue
        has_valid_box = any(
            an['category_id'] == cat_id
            and not (an['bbox'][2] < MIN_SIZE or an['bbox'][3] < MIN_SIZE)
            for an in im_dict['annotations']
        )
        if has_valid_box:  # keep ims with at least one valid box of this class
            cat_dict['im_ids'].append(_key)
    cat_and_their_im_id.append(cat_dict)
# Categories with the fewest candidate images come first.
cat_and_their_im_id.sort(key=lambda s: len(s['im_ids']))

for epi in tqdm(range(N_EPISODE)):
    ### generate query json
    id_have_selected = set()  # ims already picked by an earlier class in this episode
    episode = []
    for cat_dict in cat_and_their_im_id:
        cat_id = cat_dict['cat_id']
        # Build a filtered copy instead of calling .remove() on
        # cat_dict['im_ids']: removing in place shrank the shared candidate
        # lists, so later episodes sampled from a smaller pool than the one
        # computed above.
        ids = [_id for _id in cat_dict['im_ids'] if _id not in id_have_selected]
        if len(ids) < IM_PER_CLASS:
            raise ValueError(
                'category {} has only {} candidate images left, need {}'.format(
                    cat_id, len(ids), IM_PER_CLASS))
        # Re-seeded for every class (as before), so a class's draw depends only
        # on the episode index and on its own candidate list.
        random.seed(epi)
        selected_ids = random.sample(ids, k=IM_PER_CLASS)
        id_have_selected.update(selected_ids)
        episode.append({'cat_id': cat_id, 'im_ids': selected_ids})

    # Assemble a COCO-style dict holding only the selected images and, for each
    # image, only the annotations of the class it was selected for.
    new_data = {
        'info': data['info'],
        'images': [],
        'licenses': data['licenses'],
        'annotations': [],
        'categories': new_categories,
    }
    for cat_dict in episode:
        cat_id = cat_dict['cat_id']
        for _id in cat_dict['im_ids']:
            new_data['images'].append(im_summary[_id]['im_dict'])
            new_data['annotations'].extend(
                an for an in im_summary[_id]['annotations']
                if an['category_id'] == cat_id)
    dump_path = os.path.join(dump_dir, 'ep' + str(epi) + '.json')
    with open(dump_path, 'w') as f:
        json.dump(new_data, f)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




### set3

In [6]:
# Build the candidate image pool for the set3 categories, then write
# N_EPISODE annotation files (ep0.json, ep1.json, ...) for set3.
cwd = os.getcwd()
novel_cls_inds = cls_ind_in_coco_set3
dump_dir = os.path.join(cwd, 'data', 'coco', 'annotations', 'coco20_test', 'set3')
# open(..., 'w') below does not create missing directories.
os.makedirs(dump_dir, exist_ok=True)
N_EPISODE = 10
IM_PER_CLASS = 10

# Restrict the category records to the ids listed for this set.
new_categories = [cat for cat in data['categories'] if cat['id'] in novel_cls_inds]

# For every selected category, collect the keys of the images that contain at
# least one annotation of that category whose bbox[2] and bbox[3] (width and
# height in the [x, y, w, h] layout) are both >= MIN_SIZE.
MIN_SIZE = 64
cat_and_their_im_id = []
for c_d in new_categories:
    cat_id = c_d['id']
    cat_dict = {'cat_id': cat_id, 'im_ids': []}
    for _key, im_dict in im_summary.items():  # for every im
        if cat_id not in im_dict['categories']:  # this im does not have this class
            continue
        has_valid_box = any(
            an['category_id'] == cat_id
            and not (an['bbox'][2] < MIN_SIZE or an['bbox'][3] < MIN_SIZE)
            for an in im_dict['annotations']
        )
        if has_valid_box:  # keep ims with at least one valid box of this class
            cat_dict['im_ids'].append(_key)
    cat_and_their_im_id.append(cat_dict)
# Categories with the fewest candidate images come first.
cat_and_their_im_id.sort(key=lambda s: len(s['im_ids']))

for epi in tqdm(range(N_EPISODE)):
    ### generate query json
    id_have_selected = set()  # ims already picked by an earlier class in this episode
    episode = []
    for cat_dict in cat_and_their_im_id:
        cat_id = cat_dict['cat_id']
        # Build a filtered copy instead of calling .remove() on
        # cat_dict['im_ids']: removing in place shrank the shared candidate
        # lists, so later episodes sampled from a smaller pool than the one
        # computed above.
        ids = [_id for _id in cat_dict['im_ids'] if _id not in id_have_selected]
        if len(ids) < IM_PER_CLASS:
            raise ValueError(
                'category {} has only {} candidate images left, need {}'.format(
                    cat_id, len(ids), IM_PER_CLASS))
        # Re-seeded for every class (as before), so a class's draw depends only
        # on the episode index and on its own candidate list.
        random.seed(epi)
        selected_ids = random.sample(ids, k=IM_PER_CLASS)
        id_have_selected.update(selected_ids)
        episode.append({'cat_id': cat_id, 'im_ids': selected_ids})

    # Assemble a COCO-style dict holding only the selected images and, for each
    # image, only the annotations of the class it was selected for.
    new_data = {
        'info': data['info'],
        'images': [],
        'licenses': data['licenses'],
        'annotations': [],
        'categories': new_categories,
    }
    for cat_dict in episode:
        cat_id = cat_dict['cat_id']
        for _id in cat_dict['im_ids']:
            new_data['images'].append(im_summary[_id]['im_dict'])
            new_data['annotations'].extend(
                an for an in im_summary[_id]['annotations']
                if an['category_id'] == cat_id)
    dump_path = os.path.join(dump_dir, 'ep' + str(epi) + '.json')
    with open(dump_path, 'w') as f:
        json.dump(new_data, f)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




### set4

In [7]:
# Build the candidate image pool for the set4 categories, then write
# N_EPISODE annotation files (ep0.json, ep1.json, ...) for set4.
cwd = os.getcwd()
novel_cls_inds = cls_ind_in_coco_set4
dump_dir = os.path.join(cwd, 'data', 'coco', 'annotations', 'coco20_test', 'set4')
# open(..., 'w') below does not create missing directories.
os.makedirs(dump_dir, exist_ok=True)
N_EPISODE = 10
IM_PER_CLASS = 10

# Restrict the category records to the ids listed for this set.
new_categories = [cat for cat in data['categories'] if cat['id'] in novel_cls_inds]

# For every selected category, collect the keys of the images that contain at
# least one annotation of that category whose bbox[2] and bbox[3] (width and
# height in the [x, y, w, h] layout) are both >= MIN_SIZE.
MIN_SIZE = 64
cat_and_their_im_id = []
for c_d in new_categories:
    cat_id = c_d['id']
    cat_dict = {'cat_id': cat_id, 'im_ids': []}
    for _key, im_dict in im_summary.items():  # for every im
        if cat_id not in im_dict['categories']:  # this im does not have this class
            continue
        has_valid_box = any(
            an['category_id'] == cat_id
            and not (an['bbox'][2] < MIN_SIZE or an['bbox'][3] < MIN_SIZE)
            for an in im_dict['annotations']
        )
        if has_valid_box:  # keep ims with at least one valid box of this class
            cat_dict['im_ids'].append(_key)
    cat_and_their_im_id.append(cat_dict)
# Categories with the fewest candidate images come first.
cat_and_their_im_id.sort(key=lambda s: len(s['im_ids']))

for epi in tqdm(range(N_EPISODE)):
    ### generate query json
    id_have_selected = set()  # ims already picked by an earlier class in this episode
    episode = []
    for cat_dict in cat_and_their_im_id:
        cat_id = cat_dict['cat_id']
        # Build a filtered copy instead of calling .remove() on
        # cat_dict['im_ids']: removing in place shrank the shared candidate
        # lists, so later episodes sampled from a smaller pool than the one
        # computed above.
        ids = [_id for _id in cat_dict['im_ids'] if _id not in id_have_selected]
        if len(ids) < IM_PER_CLASS:
            raise ValueError(
                'category {} has only {} candidate images left, need {}'.format(
                    cat_id, len(ids), IM_PER_CLASS))
        # Re-seeded for every class (as before), so a class's draw depends only
        # on the episode index and on its own candidate list.
        random.seed(epi)
        selected_ids = random.sample(ids, k=IM_PER_CLASS)
        id_have_selected.update(selected_ids)
        episode.append({'cat_id': cat_id, 'im_ids': selected_ids})

    # Assemble a COCO-style dict holding only the selected images and, for each
    # image, only the annotations of the class it was selected for.
    new_data = {
        'info': data['info'],
        'images': [],
        'licenses': data['licenses'],
        'annotations': [],
        'categories': new_categories,
    }
    for cat_dict in episode:
        cat_id = cat_dict['cat_id']
        for _id in cat_dict['im_ids']:
            new_data['images'].append(im_summary[_id]['im_dict'])
            new_data['annotations'].extend(
                an for an in im_summary[_id]['annotations']
                if an['category_id'] == cat_id)
    dump_path = os.path.join(dump_dir, 'ep' + str(epi) + '.json')
    with open(dump_path, 'w') as f:
        json.dump(new_data, f)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




In [14]:
# Display the category ids indexed by the annotation file loaded into _COCO.
_COCO.getCatIds()

[1, 2, 3, 4, 5, 6, 7, 9, 16, 17, 18, 19, 20, 21, 44, 62, 63, 64, 67, 72]

In [None]:
# print(coco_cat_id_to_class_ind)
# {1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 9: 8, 16: 9, 17: 10, 18: 11, 19: 12, 20: 13, 21: 14, 44: 15, 62: 16, 63: 17, 64: 18, 67: 19, 72: 20}

In [17]:
# Fetch the full category records for every category id in _COCO and display them.
cats = _COCO.loadCats(_COCO.getCatIds())
cats

[{'supercategory': 'person', 'id': 1, 'name': 'person'},
 {'supercategory': 'vehicle', 'id': 2, 'name': 'bicycle'},
 {'supercategory': 'vehicle', 'id': 3, 'name': 'car'},
 {'supercategory': 'vehicle', 'id': 4, 'name': 'motorcycle'},
 {'supercategory': 'vehicle', 'id': 5, 'name': 'airplane'},
 {'supercategory': 'vehicle', 'id': 6, 'name': 'bus'},
 {'supercategory': 'vehicle', 'id': 7, 'name': 'train'},
 {'supercategory': 'vehicle', 'id': 9, 'name': 'boat'},
 {'supercategory': 'animal', 'id': 16, 'name': 'bird'},
 {'supercategory': 'animal', 'id': 17, 'name': 'cat'},
 {'supercategory': 'animal', 'id': 18, 'name': 'dog'},
 {'supercategory': 'animal', 'id': 19, 'name': 'horse'},
 {'supercategory': 'animal', 'id': 20, 'name': 'sheep'},
 {'supercategory': 'animal', 'id': 21, 'name': 'cow'},
 {'supercategory': 'kitchen', 'id': 44, 'name': 'bottle'},
 {'supercategory': 'furniture', 'id': 62, 'name': 'chair'},
 {'supercategory': 'furniture', 'id': 63, 'name': 'couch'},
 {'supercategory': 'furni

In [47]:
# Crop up to N_SHOT support patches per category and save them as
# '<cat_id>_<image key>.jpg' under dump_dir.
# NOTE(review): this reads `cat_and_their_im_id` as left behind by whichever
# set cell was executed last - confirm the intended set before running.
cwd = os.getcwd()
dump_dir = os.path.join(cwd, 'data', 'coco', 'annotations', 'supports')
os.makedirs(dump_dir, exist_ok=True)
im_dir = os.path.join(cwd, 'data', 'coco', 'images', 'val2014')
N_SHOT = 30
MIN_SIZE = 64   # minimum box width and height
RATIO = 2.      # maximum allowed width/height (or height/width) ratio

for cat_dict in tqdm(cat_and_their_im_id):
    cat_id = cat_dict['cat_id']
    # Shuffle a copy: shuffling cat_dict['im_ids'] in place reordered the shared
    # candidate list, so every re-run of this cell picked different images.
    ids = list(cat_dict['im_ids'])
    random.seed(0)
    random.shuffle(ids)
    shot_cnt = 0
    for _id in ids:
        # Boxes of this class that are large enough and not too elongated.
        valid_anns = []
        for an in im_summary[_id]['annotations']:
            if an['category_id'] != cat_id:
                continue
            box = an['bbox']
            if box[2] < MIN_SIZE or box[3] < MIN_SIZE:
                continue
            if box[2] / box[3] > RATIO or box[3] / box[2] > RATIO:
                continue
            valid_anns.append(an)
        if not valid_anns:
            continue

        im_dict = im_summary[_id]['im_dict']
        im_path = os.path.join(im_dir, im_dict['file_name'])
        im = cv2.imread(im_path)
        if im is None:
            # cv2.imread returns None instead of raising when the file is
            # missing or unreadable; skip it and try the next candidate.
            print('warning: could not read ' + im_path + ', skipping')
            continue

        random.seed(0)
        chosen_ann = random.sample(valid_anns, k=1)[0]
        box = [int(i) for i in chosen_ann['bbox']]  # [x, y, w, h] truncated to ints
        # Rows y..y+h and columns x..x+w, inclusive (hence the +1).
        im_cropped = im[box[1]:box[1] + box[3] + 1, box[0]:box[0] + box[2] + 1, :]
        file_name = os.path.join(dump_dir, str(cat_id) + '_' + _id + '.jpg')
        cv2.imwrite(file_name, im_cropped)
        shot_cnt += 1
        if shot_cnt == N_SHOT:
            break


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


