In [1]:
from collections import defaultdict
from pprint import pprint
from pycocotools.coco import COCO
from mmengine import load, dump

In [12]:
coco_ins_train = COCO('../../annotations/instances_train.json')
coco_ins_test  = COCO('../../annotations/instances_test.json')
coco_rel_train = COCO('relations_train.json')
coco_rel_test  = COCO('relations_test.json')
# coco_rel_train = COCO('../../annotations/relations_train.json')
# coco_rel_test  = COCO('../../annotations/relations_test.json')

loading annotations into memory...
Done (t=0.77s)
creating index...
index created!
loading annotations into memory...
Done (t=0.04s)
creating index...
index created!
loading annotations into memory...
Done (t=0.65s)
creating index...
index created!
loading annotations into memory...
Done (t=0.03s)
creating index...
index created!


In [13]:
train_pairs = set()
for ann in coco_rel_train.anns.values():
    rel = ann['category_id']
    if ann['object_id'] == -1:
        obj = -1
    else:
        obj = coco_ins_train.anns[ann['object_id']]['category_id']
    train_pairs.add((rel, obj))
len(train_pairs)

242

In [14]:
test_pairs = set()
for ann in coco_rel_test.anns.values():
    rel = ann['category_id']
    if ann['object_id'] == -1:
        obj = -1
    else:
        obj = coco_ins_test.anns[ann['object_id']]['category_id']
    test_pairs.add((rel, obj))
len(test_pairs)

238

In [15]:
train_spatial_pairs = train_pairs - test_pairs
test_spatial_pairs = test_pairs - train_pairs
print(len(train_pairs), len(train_spatial_pairs))
print(len(test_pairs), len(test_spatial_pairs))

242 22
238 18


In [16]:
label_info = load('label_info.json')
pair2hoi = label_info['pair2hoi']
hoi_info = label_info['hoi_info']
ins_info = label_info['ins_info']
rel_info = label_info['rel_info']

In [17]:
for (rel, ins) in test_spatial_pairs:
    print((rel-1, ins_info[str(ins)]['valid_id']), end=' ')
    print(rel_info[str(rel)]['name'] + ' ' + ins_info[str(ins)]['name'])

(5, 58) look obj potted plant
(0, 19) hold obj cow
(5, 8) look obj boat
(17, 49) cut obj orange
(17, 73) cut obj book
(13, 18) carry obj sheep
(17, 52) cut obj hot dog
(0, 59) hold obj bed
(13, 2) carry obj car
(17, 63) cut obj laptop
(11, 61) lay instr toilet
(17, 36) cut obj skateboard
(13, 63) carry obj laptop
(5, 9) look obj traffic light
(13, 16) carry obj dog
(13, 1) carry obj bicycle
(13, 54) carry obj donut
(0, 60) hold obj dining table


## collect counts

In [34]:
rare_threshold = 10
num_hoi, num_rel, num_ins = len(hoi_info), len(rel_info), len(ins_info)
def collect_counts(coco_ins: COCO, coco_rel: COCO, subset: str):
    counts_hoi = defaultdict(lambda: 0)
    counts_ins = defaultdict(lambda: 0)
    counts_rel = defaultdict(lambda: 0)

    num_no_obj = 0

    for ann in coco_ins.anns.values():
        counts_ins[ann['category_id']] += 1

    for ann in coco_rel.anns.values():
        rel = ann['category_id']
        if ann['object_id'] == -1:
            num_no_obj += 1
            obj = 0
        else:
            obj = coco_ins.anns[ann['object_id']]['category_id']
        hoi = pair2hoi[str((rel, obj))]

        counts_rel[rel] += 1
        counts_hoi[hoi] += 1

    counts_ins[0] = num_no_obj

    exist_hoi, exist_rel, exist_ins = len(counts_hoi), len(counts_rel), len(counts_ins)
    rare_hoi = sum([count < rare_threshold for count in counts_hoi.values()])
    print(f'--- {subset} ---')
    print(f"exist_hoi: {exist_hoi:<3d}, no_exist_hoi: {num_hoi - exist_hoi}")
    print(f"exist_rel: {exist_rel:<3d}, no_exist_rel: {num_rel - exist_rel}")
    print(f"exist_ins: {exist_ins:<3d}, no_exist_ins: {num_ins - exist_ins}")
    print(f"rare_hoi : {rare_hoi + num_hoi - exist_hoi:<3d}, non-rare_hoi : {exist_hoi - rare_hoi}")
    print(f"num no-obj_hoi: {num_no_obj}, num valid_hoi: {len(coco_rel.anns) - num_no_obj}")
    print(f'----{"-"*len(subset)}----\n')
    return dict(counts_hoi), dict(counts_rel), dict(counts_ins)

In [35]:
counts_hoi_train, counts_rel_train, counts_ins_train = \
    collect_counts(coco_ins_train, coco_rel_train, 'train')
counts_hoi_test, counts_rel_test, counts_ins_test = \
    collect_counts(coco_ins_test, coco_rel_test, 'test')

--- train ---
exist_hoi: 242, no_exist_hoi: 18
exist_rel: 24 , no_exist_rel: 0
exist_ins: 81 , no_exist_ins: 0
rare_hoi : 113, non-rare_hoi : 147
num no-obj_hoi: 4534, num valid_hoi: 13817
-------------

--- test ---
exist_hoi: 238, no_exist_hoi: 22
exist_rel: 24 , no_exist_rel: 0
exist_ins: 81 , no_exist_ins: 0
rare_hoi : 117, non-rare_hoi : 143
num no-obj_hoi: 3959, num valid_hoi: 12402
------------



## update label_info

In [38]:
def add_count(category_info: dict, counts_train: dict, counts_test: dict):
    key_not_in_train, key_not_in_test = [], []
    for key, value in category_info.items():
        key = int(key)
        if key not in counts_train:
            key_not_in_train.append(key)
        if key not in counts_test:
            key_not_in_test.append(key)
        value['counts_train'] = counts_train.get(key, 0)
        value['counts_test'] = counts_test.get(key, 0)
    print(len(key_not_in_train), len(key_not_in_test))
    return key_not_in_train, key_not_in_test

In [39]:
key_not_in_train_hoi, key_not_in_test_hoi = add_count(hoi_info, counts_hoi_train, counts_hoi_test)
key_not_in_train_ins, key_not_in_test_ins = add_count(ins_info, counts_ins_train, counts_ins_test)
key_not_in_train_rel, key_not_in_test_rel = add_count(rel_info, counts_rel_train, counts_rel_test)

18 22
0 0
0 0


In [40]:
dump(label_info, 'label_info.json')