In [1]:
from collections import defaultdict
from pprint import pprint
from pycocotools.coco import COCO
from mmengine import load, dump

In [2]:
# Directory holding the instance/relation annotation files for every split;
# change it here (once) to point the notebook at another copy of the data.
ANNOTATION_DIR = '../../annotations'

coco_ins_train = COCO(f'{ANNOTATION_DIR}/instances_train.json')
coco_rel_train = COCO(f'{ANNOTATION_DIR}/relations_train.json')
coco_ins_test = COCO(f'{ANNOTATION_DIR}/instances_test.json')
coco_rel_test = COCO(f'{ANNOTATION_DIR}/relations_test.json')

loading annotations into memory...
Done (t=0.62s)
creating index...
index created!
loading annotations into memory...
Done (t=0.28s)
creating index...
index created!
loading annotations into memory...
Done (t=0.24s)
creating index...
index created!
loading annotations into memory...
Done (t=0.05s)
creating index...
index created!


In [23]:
label_info = load('label_info.json')
# Pull out the sections used below. The per-category dicts are later mutated
# in place by add_count, and label_info is then written back to disk.
pair2hoi, hoi_info, ins_info, rel_info = (
    label_info[section]
    for section in ('pair2hoi', 'hoi_info', 'ins_info', 'rel_info')
)

## Collect per-split annotation counts (HOI, relation, instance)

In [24]:
# An HOI category with fewer than this many annotations in a split is reported
# as rare. NOTE(review): 10 looks like the standard HICO-DET rare/non-rare cut
# (< 10 training instances) -- confirm this is the intended definition.
rare_threshold = 10
# Total number of categories known to label_info, used to report categories
# that never occur in a split.
num_hoi = len(hoi_info)
num_rel = len(rel_info)
num_ins = len(ins_info)
def collect_counts(coco_ins: COCO, coco_rel: COCO, subset: str):
    """Count annotations per HOI, relation and instance category in one split.

    Also prints how many categories of each kind occur in the split. It also
    prints how many HOI categories are rare, meaning they have fewer than
    ``rare_threshold`` annotations in *this* split. HOI categories that never
    occur in the split are counted as rare.

    Reads the notebook-level globals ``pair2hoi``, ``rare_threshold``,
    ``num_hoi``, ``num_rel`` and ``num_ins``.

    Args:
        coco_ins: COCO index of instance annotations; each annotation's
            ``category_id`` is an instance (object) category.
        coco_rel: COCO index of relation annotations; each has a relation
            ``category_id`` and an ``object_id`` that is an annotation id
            in ``coco_ins``.
        subset: split name, used only in the printed header.

    Returns:
        ``(counts_hoi, counts_rel, counts_ins)``: plain dicts mapping category
        id to annotation count. Categories with no annotations are absent.

    Raises:
        KeyError: if a relation's ``object_id`` is not an annotation in
            ``coco_ins``, or a (relation, object) pair is missing from
            ``pair2hoi``.
    """
    counts_hoi = defaultdict(int)
    counts_ins = defaultdict(int)
    counts_rel = defaultdict(int)

    for ann in coco_ins.anns.values():
        counts_ins[ann['category_id']] += 1

    for ann in coco_rel.anns.values():
        rel = ann['category_id']
        # The HOI category is determined by the relation plus the category of
        # the annotated object it points at.
        obj = coco_ins.anns[ann['object_id']]['category_id']
        # pair2hoi is loaded from JSON, so its keys are strings: the str() of
        # the (relation, object) tuple.
        hoi = pair2hoi[str((rel, obj))]

        counts_rel[rel] += 1
        counts_hoi[hoi] += 1

    exist_hoi, exist_rel, exist_ins = len(counts_hoi), len(counts_rel), len(counts_ins)
    # Rare among HOIs that do occur; the absent ones (count 0) are added in the
    # summary line below.
    rare_hoi = sum(count < rare_threshold for count in counts_hoi.values())
    print(f'--- {subset} ---')
    print(f"exist_hoi: {exist_hoi:<3d}, no_exist_hoi: {num_hoi - exist_hoi}")
    print(f"exist_rel: {exist_rel:<3d}, no_exist_rel: {num_rel - exist_rel}")
    print(f"exist_ins: {exist_ins:<3d}, no_exist_ins: {num_ins - exist_ins}")
    print(f"rare_hoi : {rare_hoi + num_hoi - exist_hoi:<3d}, non-rare_hoi : {exist_hoi - rare_hoi}")
    print(f'----{"-"*len(subset)}----\n')
    return dict(counts_hoi), dict(counts_rel), dict(counts_ins)

In [25]:
# Per-split counts; each call also prints a coverage / rarity summary.
counts_hoi_train, counts_rel_train, counts_ins_train = collect_counts(
    coco_ins_train, coco_rel_train, 'train'
)
counts_hoi_test, counts_rel_test, counts_ins_test = collect_counts(
    coco_ins_test, coco_rel_test, 'test'
)

--- train ---
exist_hoi: 600, no_exist_hoi: 0
exist_rel: 117, no_exist_rel: 0
exist_ins: 80 , no_exist_ins: 0
rare_hoi : 138, non-rare_hoi : 462
-------------

--- test ---
exist_hoi: 600, no_exist_hoi: 0
exist_rel: 117, no_exist_rel: 0
exist_ins: 80 , no_exist_ins: 0
rare_hoi : 198, non-rare_hoi : 402
------------



## Add the per-split counts to `label_info` and save it

In [8]:
def add_count(category_info: dict, counts_train: dict, counts_test: dict) -> None:
    """Attach train/test annotation counts to every category entry, in place.

    Args:
        category_info: mapping from category id (string key, as loaded from
            JSON) to that category's info dict. Each info dict gains
            ``'counts_train'`` and ``'counts_test'`` entries.
        counts_train: mapping from integer category id to its annotation
            count in the train split.
        counts_test: same as ``counts_train``, for the test split.

    A category missing from a split's counts (i.e. it has no annotations
    there) gets a count of 0 instead of raising ``KeyError``.
    """
    for key, value in category_info.items():
        # JSON object keys are strings, while the count dicts use int ids.
        category_id = int(key)
        value['counts_train'] = counts_train.get(category_id, 0)
        value['counts_test'] = counts_test.get(category_id, 0)

In [27]:
# Annotate each category table (HOI, instance, relation) with its per-split counts.
for category_table, train_counts, test_counts in (
    (hoi_info, counts_hoi_train, counts_hoi_test),
    (ins_info, counts_ins_train, counts_ins_test),
    (rel_info, counts_rel_train, counts_rel_test),
):
    add_count(category_table, train_counts, test_counts)

In [28]:
# Overwrites label_info.json (the file loaded at the top) with the count-enriched version.
dump(label_info, file='label_info.json')