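# Prepare the COCO 2017 zero-shot (seen/unseen) split
Based on the paper: Bansal, Ankan, et al. "Zero-shot object detection." Proceedings of the European Conference on Computer Vision (ECCV). 2018. The split uses 48 seen and 17 unseen classes (65 in total). Every kept category is tagged with its split name and its 300-d GloVe word embedding.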

In [1]:
import json

In [None]:
import numpy as np

In [3]:
# Full COCO 2017 train annotations; later reduced to all seen + unseen split classes.
# JSON is UTF-8 by spec, so decode explicitly instead of relying on the locale default.
with open('../datasets/coco/annotations/instances_train2017.json', 'r', encoding='utf-8') as fin:
    coco_train_anno_all = json.load(fin)

In [None]:
# Independent copy of the train annotations, to be reduced to seen classes only.
# Re-read from disk rather than aliased, because filter_annotation mutates its input in place.
# JSON is UTF-8 by spec, so decode explicitly instead of relying on the locale default.
with open('../datasets/coco/annotations/instances_train2017.json', 'r', encoding='utf-8') as fin:
    coco_train_anno_seen = json.load(fin)

In [None]:
# Independent copy of the train annotations, to be reduced to unseen classes only.
# Re-read from disk rather than aliased, because filter_annotation mutates its input in place.
# JSON is UTF-8 by spec, so decode explicitly instead of relying on the locale default.
with open('../datasets/coco/annotations/instances_train2017.json', 'r', encoding='utf-8') as fin:
    coco_train_anno_unseen = json.load(fin)

In [None]:
# Full COCO 2017 val annotations; its category list also drives the class -> split mapping.
# JSON is UTF-8 by spec, so decode explicitly instead of relying on the locale default.
with open('../datasets/coco/annotations/instances_val2017.json', 'r', encoding='utf-8') as fin:
    coco_val_anno_all = json.load(fin)

In [None]:
# Independent copy of the val annotations, to be reduced to seen classes only.
# Re-read from disk rather than aliased, because filter_annotation mutates its input in place.
# JSON is UTF-8 by spec, so decode explicitly instead of relying on the locale default.
with open('../datasets/coco/annotations/instances_val2017.json', 'r', encoding='utf-8') as fin:
    coco_val_anno_seen = json.load(fin)

In [None]:
# Independent copy of the val annotations, to be reduced to unseen classes only.
# Re-read from disk rather than aliased, because filter_annotation mutates its input in place.
# JSON is UTF-8 by spec, so decode explicitly instead of relying on the locale default.
with open('../datasets/coco/annotations/instances_val2017.json', 'r', encoding='utf-8') as fin:
    coco_val_anno_unseen = json.load(fin)

In [None]:
# List of seen class names for the zero-shot split.
# JSON is UTF-8 by spec, so decode explicitly instead of relying on the locale default.
with open('../datasets/coco/zero-shot/mscoco_seen_classes.json', 'r', encoding='utf-8') as fin:
    labels_seen = json.load(fin)

In [None]:
# List of unseen class names for the zero-shot split.
# JSON is UTF-8 by spec, so decode explicitly instead of relying on the locale default.
with open('../datasets/coco/zero-shot/mscoco_unseen_classes.json', 'r', encoding='utf-8') as fin:
    labels_unseen = json.load(fin)

In [None]:
# Sizes of the seen and unseen label lists, shown as a (seen, unseen) tuple.
tuple(len(names) for names in (labels_seen, labels_unseen))

In [22]:
# Names of every category present in the COCO val annotation file, in file order.
labels_all = [category['name'] for category in coco_val_anno_all['categories']]

In [23]:
# Seen labels that are not COCO categories (sanity check).
set(labels_seen).difference(labels_all)

{'background'}

In [24]:
# Unseen labels that are not COCO categories (sanity check).
set(labels_unseen).difference(labels_all)

set()

In [None]:
# Assign each COCO category in the zero-shot split to 'seen' or 'unseen', keyed both by
# category id and by category name. A name in both lists counts as 'seen'; categories
# in neither list get no entry at all.
seen_names = set(labels_seen)
unseen_names = set(labels_unseen)
class_id_to_split = {}
class_name_to_split = {}
for category in coco_val_anno_all['categories']:
    cat_name = category['name']
    if cat_name in seen_names:
        cat_split = 'seen'
    elif cat_name in unseen_names:
        cat_split = 'unseen'
    else:
        continue
    class_id_to_split[category['id']] = cat_split
    class_name_to_split[cat_name] = cat_split


In [None]:
# Collect the GloVe vector for every class name that takes part in the split.
# Each GloVe line is "<token> <v1> <v2> ..."; only the first whitespace-separated token is
# compared, so only single-word class names can be matched.
class_name_to_emb = {}
# Decode explicitly as UTF-8 rather than with the locale-dependent default, so any
# non-ASCII tokens in the vocabulary cannot trigger a decode error.
with open('../datasets/coco/zero-shot/glove.6B.300d.txt', 'r', encoding='utf-8') as fin:
    for row in fin:
        row_tk = row.split()
        # Skip blank lines instead of failing on row_tk[0].
        if row_tk and row_tk[0] in class_name_to_split:
            class_name_to_emb[row_tk[0]] = [float(num) for num in row_tk[1:]]


In [12]:
# Classes with a found embedding vs. classes in the split; the two should match.
tuple(map(len, (class_name_to_emb, class_name_to_split)))

(65, 65)

In [None]:
def filter_annotation(anno_dict, split_name_list, id_to_split=None, name_to_emb=None):
    """Restrict a COCO-style annotation dict, in place, to the categories of the given splits.

    A category is kept when its split (looked up by category id) is in
    ``split_name_list``. Each kept category dict gets two extra keys:
    ``'embedding'`` (from ``name_to_emb``) and ``'split'``. Annotations whose
    category is not kept are dropped. Images are kept only if at least one kept
    annotation refers to them.

    ``anno_dict`` and its category dicts are modified in place; nothing is returned.

    Args:
        anno_dict: dict holding 'categories', 'annotations' and 'images' lists.
        split_name_list: split names to keep, e.g. ['seen'] or ['seen', 'unseen'].
        id_to_split: category id -> split name. Defaults to the module-level
            ``class_id_to_split``, looked up at call time.
        name_to_emb: category name -> embedding vector. Defaults to the
            module-level ``class_name_to_emb``, looked up at call time.

    Raises:
        KeyError: if a kept category's name has no entry in ``name_to_emb``.
    """
    if id_to_split is None:
        id_to_split = class_id_to_split
    if name_to_emb is None:
        name_to_emb = class_name_to_emb

    kept_categories = []
    for item in anno_dict['categories']:
        split = id_to_split.get(item['id'])  # None for categories outside the split
        if split in split_name_list:
            item['embedding'] = name_to_emb[item['name']]
            item['split'] = split
            kept_categories.append(item)
    anno_dict['categories'] = kept_categories

    kept_annotations = [
        ann for ann in anno_dict['annotations']
        if id_to_split.get(ann['category_id']) in split_name_list
    ]
    # Only images that still carry at least one kept annotation survive.
    useful_image_ids = {ann['image_id'] for ann in kept_annotations}
    anno_dict['annotations'] = kept_annotations
    anno_dict['images'] = [img for img in anno_dict['images'] if img['id'] in useful_image_ids]

In [15]:
# In place: keep only seen-class categories, annotations and their images.
filter_annotation(coco_train_anno_seen, split_name_list=['seen'])

In [16]:
# In place: keep only unseen-class categories, annotations and their images.
filter_annotation(coco_train_anno_unseen, split_name_list=['unseen'])

In [17]:
# In place: keep categories from both splits, dropping classes outside the zero-shot split.
filter_annotation(coco_train_anno_all, split_name_list=['seen', 'unseen'])

In [18]:
# In place: keep only seen-class categories, annotations and their images.
filter_annotation(coco_val_anno_seen, split_name_list=['seen'])

In [19]:
# In place: keep only unseen-class categories, annotations and their images.
filter_annotation(coco_val_anno_unseen, split_name_list=['unseen'])

In [20]:
# In place: keep categories from both splits, dropping classes outside the zero-shot split.
filter_annotation(coco_val_anno_all, split_name_list=['seen', 'unseen'])

In [21]:
# Category counts after filtering, as a (seen, unseen, all) tuple.
tuple(
    len(anno['categories'])
    for anno in (coco_val_anno_seen, coco_val_anno_unseen, coco_val_anno_all)
)

(48, 17, 65)

In [22]:
# Write the seen-only training annotations.
with open('../datasets/coco/zero-shot/instances_train2017_seen.json', 'w') as out_file:
    json.dump(coco_train_anno_seen, fp=out_file)

In [23]:
# Write the unseen-only training annotations.
with open('../datasets/coco/zero-shot/instances_train2017_unseen.json', 'w') as out_file:
    json.dump(coco_train_anno_unseen, fp=out_file)

In [24]:
# Write the seen + unseen training annotations.
with open('../datasets/coco/zero-shot/instances_train2017_all.json', 'w') as out_file:
    json.dump(coco_train_anno_all, fp=out_file)

In [25]:
# Write the seen-only validation annotations.
with open('../datasets/coco/zero-shot/instances_val2017_seen.json', 'w') as out_file:
    json.dump(coco_val_anno_seen, fp=out_file)

In [26]:
# Write the unseen-only validation annotations.
with open('../datasets/coco/zero-shot/instances_val2017_unseen.json', 'w') as out_file:
    json.dump(coco_val_anno_unseen, fp=out_file)

In [27]:
# Write the seen + unseen validation annotations.
with open('../datasets/coco/zero-shot/instances_val2017_all.json', 'w') as out_file:
    json.dump(coco_val_anno_all, fp=out_file)