In [1]:
from collections import OrderedDict, defaultdict
from pprint import pprint

import numpy as np
from mmengine import load, dump

In [2]:
# 由GEN-VLKT提供
from hico_text_label import hico_hoi_text_label, hico_obj_text_label  # no-inter 排在中间

## instance

In [3]:
# COCO category names indexed by (COCO category id - 1); '' marks the 10 ids that
# the COCO id space leaves unused.
ins_labels = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 
    'traffic light', 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 
    'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 
    'backpack', 'umbrella', '', '', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 
    'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 
    'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 
    'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 
    'dining table', '', '', 'toilet', '', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', '', 'book', 
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
ann_ins_ids = (  # COCO-format category ids of the 80 used classes
    1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 
    23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 
    48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 
    73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90)
# The id tuple is hand-written: make sure it lists exactly the non-empty names above.
assert list(ann_ins_ids) == [i for i, name in enumerate(ins_labels, start=1) if name]
print(f'len_ins: {len(ins_labels)}, len_ins_valid: {len(ann_ins_ids)}')

len_ins: 90, len_ins_valid: 80


In [4]:
# Object metadata keyed by COCO ann_id. Entries of hico_obj_text_label are unpacked
# as (valid_id, text) pairs and paired positionally with ann_ins_ids.
ins_info = {
    coco_id: dict(
        name=ins_labels[coco_id - 1],
        text=obj_text,
        valid_id=obj_idx,
        ann_id=coco_id,
    )
    for coco_id, (obj_idx, obj_text) in zip(ann_ins_ids, hico_obj_text_label)
}
ins_info[1]

{'name': 'person', 'text': 'a photo of a person', 'valid_id': 0, 'ann_id': 1}

## relation

In [5]:
# HICO-DET relation (verb) names in alphabetical order; the list index is the raw
# valid_id and index + 1 is the ann_id. 'no_interaction' (index 57) is moved to the
# end of the valid_id order in a later cell. No name contains whitespace, so the
# list is written as one whitespace-separated block.
rel_labels = '''
    adjust assemble block blow board break brush_with buy carry
    catch chase check clean control cook cut cut_with direct
    drag dribble drink_with drive dry eat eat_at exit feed
    fill flip flush fly greet grind groom herd hit hold
    hop_on hose hug hunt inspect install jump kick kiss lasso
    launch lick lie_on lift light load lose make milk move
    no_interaction
    open operate pack paint park pay peel pet
    pick pick_up point pour pull push race read release repair
    ride row run sail scratch serve set shear sign sip
    sit_at sit_on slide smell spin squeeze stab stand_on
    stand_under stick stir stop_at straddle swing tag talk_on
    teach text_on throw tie toast train turn type_on walk
    wash watch wave wear wield zip
'''.split()
print(len(rel_labels))

117


In [6]:
# Relation metadata keyed by ann_id (= list index + 1). At this point valid_id is
# still the raw position in rel_labels.
rel_info = {
    idx + 1: dict(name=rel_name, valid_id=idx, ann_id=idx + 1)
    for idx, rel_name in enumerate(rel_labels)
}
rel_info[58]

{'name': 'no_interaction', 'valid_id': 57, 'ann_id': 58}

把no-interaction移到最后，仅改变valid_id

原本的排序非常影响loss计算和后处理

In [7]:
bg_rel_valid_id = 57  # raw valid_id of 'no_interaction' (its index in rel_labels)
assert rel_labels[bg_rel_valid_id] == 'no_interaction'

# Reassign rel_info valid_ids so 'no_interaction' becomes the last class.
# Relations before it keep their id, and relations after it shift down by one.
# The old id is derived from the fixed ann_id instead of the mutable
# info['valid_id'], so re-running this cell gives the same result.
num_rel = len(rel_info)
rel_valid_id_map = [num_rel] * num_rel  # new valid_id -> old (raw) valid_id
for info in rel_info.values():
    old_valid_id = info['ann_id'] - 1
    if old_valid_id < bg_rel_valid_id:
        new_valid_id = old_valid_id
    elif old_valid_id > bg_rel_valid_id:
        new_valid_id = old_valid_id - 1
    else:  # the background relation goes to the end
        new_valid_id = num_rel - 1
    info['valid_id'] = new_valid_id
    rel_valid_id_map[new_valid_id] = old_valid_id

# valid_id of the background relation after reordering
new_bg_rel_valid_id = num_rel - 1

In [62]:
np.save(file='valid_id_new2old_rel.npy', arr=rel_valid_id_map)  # new valid_id -> old valid_id

In [None]:
# Relation categories ordered by the reordered valid_id ('no_interaction' last).
rel_categories = sorted(
    (dict(id=int(ann_id), name=info['name']) for ann_id, info in rel_info.items()),
    key=lambda cat: rel_info[cat['id']]['valid_id'])
len(rel_categories)

117

In [None]:
dump(rel_categories, file='rel_categories.json')  # list ordered by new valid_id

## hoi

把no-interaction移到最后，并按[obj_name, rel_name]的字典序重新排列，仅改变valid_id

In [13]:
# Keys of hico_hoi_text_label are (raw valid rel id, raw valid obj id) pairs.
# New order: 'no_interaction' HOIs last, the rest sorted by (object name, relation name).
raw_keys = list(hico_hoi_text_label.keys())
new_keys = sorted(
    raw_keys,
    key=lambda x: [
        x[0] == bg_rel_valid_id,            # no_interaction HOIs go last
        ins_labels[ann_ins_ids[x[1]] - 1],  # then by object name
        rel_labels[x[0]],                   # then by relation name
    ])
# new valid_id -> old valid_id. A position dict avoids the O(n^2) list.index scan.
raw_key_pos = {key: pos for pos, key in enumerate(raw_keys)}
hoi_valid_id_map = [raw_key_pos[key] for key in new_keys]

# Re-key the text labels by (ann rel id, COCO ann obj id), in the new order.
new_hico_hoi_text_label = OrderedDict()
for key in new_keys:
    raw_valid_rel, raw_valid_obj = key
    new_key = (raw_valid_rel + 1, ann_ins_ids[raw_valid_obj])
    new_hico_hoi_text_label[new_key] = hico_hoi_text_label[key]

len(new_hico_hoi_text_label)

600

In [61]:
np.save(file='valid_id_new2old_hoi.npy', arr=hoi_valid_id_map)  # new valid_id -> old valid_id

In [16]:
# hoi_info: original HICO HOI ann_id (1-based) -> metadata. Entries are created in
# the new order, so the enumeration index is the new valid_id.
hoi_info = {}
for new_idx, ((rel_ann, obj_ann), hoi_text) in enumerate(new_hico_hoi_text_label.items()):
    hoi_ann = hoi_valid_id_map[new_idx] + 1  # map back to the original HOI index
    rel_entry = rel_info[rel_ann]
    obj_entry = ins_info[obj_ann]
    rel_valid = rel_entry['valid_id']  # rel_info valid_ids as reordered above
    obj_valid = obj_entry['valid_id']

    hoi_info[hoi_ann] = {
        'name': f"{rel_entry['name']}-{obj_entry['name']}",
        'text': hoi_text,

        'valid_id'      : new_idx,
        'ann_id'        : hoi_ann,

        'ann_hoi_pair'  : (rel_ann, obj_ann),
        'valid_hoi_pair': (rel_valid, obj_valid),

        'valid_rel_id': rel_valid,
        'ann_rel_id'  : rel_ann,

        'valid_obj_id': obj_valid,
        'ann_obj_id'  : obj_ann,
    }
hoi_info[1]

{'name': 'board-airplane',
 'text': 'a photo of a person boarding an airplane',
 'valid_id': 0,
 'ann_id': 1,
 'ann_hoi_pair': (5, 5),
 'valid_hoi_pair': (4, 4),
 'valid_rel_id': 4,
 'ann_rel_id': 5,
 'valid_obj_id': 4,
 'ann_obj_id': 5}

In [17]:
# Re-key hoi_info in ascending HOI ann_id order (keys are unique, so values are never compared).
hoi_info = OrderedDict(sorted(hoi_info.items()))

## 部分汇总

In [18]:
# Assemble the label metadata. 'bg_rel' describes the no_interaction relation:
# ann_id is its raw valid_id + 1, and valid_id is its position after being moved last.
# Deriving ann_id from bg_rel_valid_id keeps it in sync with the relocation cell.
label_info = {
    'bg_rel' : {'ann_id': bg_rel_valid_id + 1, 'valid_id': new_bg_rel_valid_id},
    'ins_info': ins_info,
    'rel_info': rel_info,
    'hoi_info': hoi_info,
}

## hoi到rel、obj的映射

- 用于GEN-VLKT等直接用hoi标签做linear的模型

In [19]:
# Lookup tables indexed by HOI valid_id that give the relation / object valid ids.
hoi_info = label_info['hoi_info']
valid_id_hoi2rel = [None] * len(hoi_info)
valid_id_hoi2obj = [None] * len(hoi_info)
for hoi in hoi_info.values():
    slot = hoi['valid_id']
    valid_id_hoi2rel[slot] = hoi['valid_rel_id']
    valid_id_hoi2obj[slot] = hoi['valid_obj_id']
# Converting with dtype=int raises if any slot was left as None.
valid_id_hoi2rel, valid_id_hoi2obj = (
    np.array(table, dtype=int) for table in (valid_id_hoi2rel, valid_id_hoi2obj))

In [41]:
# HOI valid_id -> relation / object valid_id
np.save(file='valid_id_hoi2rel.npy', arr=valid_id_hoi2rel)
np.save(file='valid_id_hoi2obj.npy', arr=valid_id_hoi2obj)

## 组合标签和组成标签的对应

In [23]:
def collect_utils(hoi_info, rel_info, ins_info):
    """Build lookup tables between HOI labels and their component labels.

    Only annotation ids (``ann_id``) are used, never valid ids.

    Args:
        hoi_info: mapping HOI ann_id -> info dict; each info must provide
            'ann_hoi_pair', 'ann_obj_id' and 'ann_rel_id'.
        rel_info: mapping keyed by relation ann_id (only the keys are used).
        ins_info: mapping keyed by object ann_id (only the keys are used).

    Returns:
        Tuple ``(pair2hoi, hois_per_obj, hois_per_rel, objs_per_rel, rels_per_obj)``.
        ``pair2hoi`` maps ``str(ann_hoi_pair)`` to the HOI ann_id. The other four
        have one entry per object / relation ann_id (in sorted key order, including
        ids with no HOI), each holding a sorted list of HOI, object or relation ann_ids.
    """
    obj_ids = sorted(ins_info.keys())
    rel_ids = sorted(rel_info.keys())

    pair2hoi = {}
    hois_per_obj = {obj_id: [] for obj_id in obj_ids}
    hois_per_rel = {rel_id: [] for rel_id in rel_ids}
    objs_per_rel = {rel_id: [] for rel_id in rel_ids}
    rels_per_obj = {obj_id: [] for obj_id in obj_ids}

    for hoi_id in sorted(hoi_info.keys()):
        entry = hoi_info[hoi_id]
        pair2hoi[str(entry['ann_hoi_pair'])] = hoi_id

        obj_id, rel_id = entry['ann_obj_id'], entry['ann_rel_id']
        hois_per_obj[obj_id].append(hoi_id)
        hois_per_rel[rel_id].append(hoi_id)
        objs_per_rel[rel_id].append(obj_id)
        rels_per_obj[obj_id].append(rel_id)

    tables = (hois_per_obj, hois_per_rel, objs_per_rel, rels_per_obj)
    for table in tables:
        for k in table:  # only values are replaced, so iterating is safe
            table[k] = sorted(table[k])

    return (pair2hoi,) + tables

In [24]:
# Attach the composite <-> component lookup tables (ann_id based) to label_info.
(pair2hoi, hois_per_obj, hois_per_rel,
 objs_per_rel, rels_per_obj) = collect_utils(hoi_info, rel_info, ins_info)
label_info.update(
    pair2hoi=pair2hoi,
    hois_per_obj=hois_per_obj,
    hois_per_rel=hois_per_rel,
    objs_per_rel=objs_per_rel,
    rels_per_obj=rels_per_obj,
)

In [56]:
dump(label_info, file='label_info.json')  # JSON output: int dict keys become strings

In [25]:
# Valid-id based lookups only. Keys are (subject, rel, obj) valid ids; the subject
# is always 0 (person) and the value is the HOI valid_id.
valid_id_triplet2label = {
    (0, hoi['valid_rel_id'], hoi['valid_obj_id']): hoi['valid_id']
    for hoi in label_info['hoi_info'].values()
}
# Triplets listed in HOI valid_id order.
valid_id_triplets = sorted(valid_id_triplet2label, key=valid_id_triplet2label.get)

In [46]:
# Both artifacts use 0-based valid ids.
dump(valid_id_triplet2label, file='valid_id_triplet2label.pkl')
np.save(file='valid_id_triplets.npy', arr=valid_id_triplets)  # presumably read by triplet_parse

## 用于排除错误组合的label_matrix

In [4]:
# Reload the dumped label info (JSON object keys come back as strings).
# Alternative source: '../../configs/label_info.json'
label_info = load(file='label_info.json')
hoi_info, ins_info, rel_info = (
    label_info[section] for section in ('hoi_info', 'ins_info', 'rel_info'))

In [5]:
# correct_sg_mat[obj, rel] is True exactly for the (obj, rel) valid-id pairs that form an HOI class.
correct_sg_mat = np.zeros((len(ins_info), len(rel_info)), dtype=bool)
hoi_obj_ids = [hoi['valid_obj_id'] for hoi in hoi_info.values()]
hoi_rel_ids = [hoi['valid_rel_id'] for hoi in hoi_info.values()]
correct_sg_mat[hoi_obj_ids, hoi_rel_ids] = True

In [6]:
np.save(file='correct_sg_mat.npy', arr=correct_sg_mat)  # shape (num_obj, num_rel)

In [51]:
# 3-D variant indexed as [subject, rel, obj]; only subject valid id 0 (person) is filled.
# NOTE(review): this rebinds correct_sg_mat, replacing the 2-D matrix built above.
correct_sg_mat = np.zeros((len(ins_info), len(rel_info), len(ins_info)), dtype=bool)
triplet_rel_ids = [hoi['valid_rel_id'] for hoi in hoi_info.values()]
triplet_obj_ids = [hoi['valid_obj_id'] for hoi in hoi_info.values()]
correct_sg_mat[0, triplet_rel_ids, triplet_obj_ids] = True

In [52]:
np.save(file='correct_sg_mat_triplet.npy', arr=correct_sg_mat)  # shape (num_obj, num_rel, num_obj)

## plugins

### zero-shot

In [42]:
# Load the label info shipped in configs (JSON, so dict keys are strings).
# Alternative source: 'label_info.json' in the working directory.
label_info = load(file='../../configs/label_info.json')
hoi_info, ins_info, rel_info = (
    label_info[section] for section in ('hoi_info', 'ins_info', 'rel_info'))

In [37]:
hico_unseen_index = {  # HOI ann ids held out in each zero-shot split
    "rare_first": [  # 120
        510, 280, 281, 403, 505, 287, 500, 499, 290, 486, 304, 312, 326,
        440, 352, 359,  67, 428, 380, 419,  71, 417, 390,  91, 396,  77,
        398,  85, 136, 263, 402, 593, 561, 587, 549, 594, 527, 182, 258,
        540, 536, 261, 597, 346, 190, 206, 207, 430, 180, 351, 406, 523,
        450, 262, 256, 547, 548,  45,  23, 335, 600, 240, 316, 318, 230,
        159, 196, 239, 365, 223, 282, 150, 400,  84, 128, 255, 399, 404,
        556, 553, 521, 532, 441, 437, 483, 275,   9, 189, 217, 598,  78,
        408, 557, 470, 475, 108, 391, 411,  28, 382, 464, 100, 185, 101,
        293, 518,  81, 334,  63, 355, 105,  56,  51, 199, 169, 392, 193,
        596, 137, 582],
    "non_rare_first": [  # 120
         39,  42,  21,  19, 246,  12,  20, 155, 460,  43, 156, 140,  61,
        462, 578, 154, 583,  90, 142, 577,  76, 213, 473,  62, 458, 147,
        209,  95, 472, 132, 249, 545, 516, 567, 371, 482, 227, 251, 471,
        324, 170, 481, 480, 231, 386,  74, 160, 191, 378, 177, 250, 372,
        285,  49, 584,  54, 163, 141, 186, 107, 295,  57, 321, 153, 375,
        339,  30, 595, 347, 457, 590,  46,  24,  68, 479, 224, 494, 229,
        241, 216,  92, 116, 338, 560,   8, 219, 519, 298, 192, 267, 305,
          7, 573, 530, 313,  10, 309, 418, 198, 194, 164, 456,  26,  55,
        576, 447, 388, 484, 535, 341, 509, 111, 330, 247, 174, 507, 384,
         94, 517,  65],
    "unseen_object": [  # 100
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124,
        125, 126, 127, 128, 129, 225, 226, 227, 228, 229, 230, 231, 232,
        291, 292, 293, 294, 295, 314, 315, 316, 317, 318, 319, 320, 321,
        322, 323, 324, 325, 337, 338, 339, 340, 341, 342, 419, 420, 421,
        422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434,
        454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466,
        467, 468, 469, 470, 471, 472, 473, 474, 534, 535, 536, 537, 538,
        559, 560, 561, 562, 596, 597, 598, 599, 600],
    "unseen_verb": [  # 84
          5,   7,  13,  16,  19,  26,  35,  39,  41,  50,  59,  61,  69,
         70,  73,  74,  78,  83,  97,  98, 105, 114, 117, 119, 123, 130,
        140, 148, 151, 154, 166, 167, 173, 176, 177, 182, 191, 203, 211,
        213, 220, 228, 229, 234, 236, 244, 299, 314, 316, 321, 327, 337,
        343, 346, 355, 373, 402, 405, 410, 432, 437, 460, 467, 471, 473,
        480, 482, 489, 492, 495, 499, 505, 520, 524, 536, 537, 542, 545,
        563, 566, 570, 573, 592, 596],
}
# The lists are hand-copied: check their sizes match the comments above, that they
# contain no duplicates, and that every id is a valid HICO HOI ann id (1..600).
_expected_split_sizes = {
    'rare_first': 120, 'non_rare_first': 120, 'unseen_object': 100, 'unseen_verb': 84}
for _split, _ids in hico_unseen_index.items():
    assert len(_ids) == len(set(_ids)) == _expected_split_sizes[_split], _split
    assert all(1 <= hoi_id <= 600 for hoi_id in _ids), _split


In [40]:
# Object ann_ids touched by the 'unseen_object' split (hoi_info keys are strings here).
pprint(sorted({hoi_info[str(hoi_id)]['ann_obj_id']
               for hoi_id in hico_unseen_index['unseen_object']}),
       compact=True)

[18, 22, 24, 34, 35, 39, 41, 43, 59, 82, 85, 90]


In [41]:
pprint(
    sorted(list({hoi_info[str(i)]['ann_rel_id'] for i in 
                 hico_unseen_index['unseen_verb']})), 
    compact=True)

[3, 20, 27, 35, 39, 42, 43, 57, 63, 77, 80, 81, 85, 92, 93, 98, 100, 101, 107,
 115]


#### correct mat

In [44]:
correct_sg_mat = np.load(file='../../configs/correct_sg_mat.npy')  # [obj, rel] bool matrix

In [45]:
# Mask out the (obj, rel) pairs of the HOIs held out in the 'rare_first' split.
rare_first_mat = np.copy(correct_sg_mat)
for ann_hoi in hico_unseen_index['rare_first']:
    # hoi_info may be int-keyed (built in memory) or str-keyed (loaded from JSON).
    # The str fallback must be evaluated lazily: a dict.get default is computed
    # before the call and would raise KeyError for int-keyed hoi_info.
    info = hoi_info[ann_hoi] if ann_hoi in hoi_info else hoi_info[str(ann_hoi)]
    rare_first_mat[info['valid_obj_id'], info['valid_rel_id']] = False
(correct_sg_mat != rare_first_mat).sum()

120

In [53]:
np.save(file='correct_sg_mat_rare_first.npy', arr=rare_first_mat)

In [47]:
# Mask out the (obj, rel) pairs of the HOIs held out in the 'non_rare_first' split.
non_rare_first_mat = np.copy(correct_sg_mat)
for ann_hoi in hico_unseen_index['non_rare_first']:
    # hoi_info may be int-keyed (built in memory) or str-keyed (loaded from JSON).
    # The str fallback must be evaluated lazily: a dict.get default is computed
    # before the call and would raise KeyError for int-keyed hoi_info.
    info = hoi_info[ann_hoi] if ann_hoi in hoi_info else hoi_info[str(ann_hoi)]
    non_rare_first_mat[info['valid_obj_id'], info['valid_rel_id']] = False
(correct_sg_mat != non_rare_first_mat).sum()

120

In [54]:
np.save(file='correct_sg_mat_non_rare_first.npy', arr=non_rare_first_mat)

In [49]:
# Mask out the (obj, rel) pairs of the HOIs held out in the 'unseen_object' split.
unseen_object_mat = np.copy(correct_sg_mat)
for ann_hoi in hico_unseen_index['unseen_object']:
    # hoi_info may be int-keyed (built in memory) or str-keyed (loaded from JSON).
    # The str fallback must be evaluated lazily: a dict.get default is computed
    # before the call and would raise KeyError for int-keyed hoi_info.
    info = hoi_info[ann_hoi] if ann_hoi in hoi_info else hoi_info[str(ann_hoi)]
    unseen_object_mat[info['valid_obj_id'], info['valid_rel_id']] = False
(correct_sg_mat != unseen_object_mat).sum()

100

In [55]:
np.save(file='correct_sg_mat_unseen_object.npy', arr=unseen_object_mat)

In [51]:
# Mask out the (obj, rel) pairs of the HOIs held out in the 'unseen_verb' split.
unseen_verb_mat = np.copy(correct_sg_mat)
for ann_hoi in hico_unseen_index['unseen_verb']:
    # hoi_info may be int-keyed (built in memory) or str-keyed (loaded from JSON).
    # The str fallback must be evaluated lazily: a dict.get default is computed
    # before the call and would raise KeyError for int-keyed hoi_info.
    info = hoi_info[ann_hoi] if ann_hoi in hoi_info else hoi_info[str(ann_hoi)]
    unseen_verb_mat[info['valid_obj_id'], info['valid_rel_id']] = False
(correct_sg_mat != unseen_verb_mat).sum()

84

In [56]:
np.save(file='correct_sg_mat_unseen_verb.npy', arr=unseen_verb_mat)

### pasta

In [None]:
pastas_train = load(file='../../annotations/pastas_train.json')  # PaSta training annotations

In [None]:
# Peek at the first raw PaSta annotation record.
pastas_train['annotations'][0]

{'id': 1, 'category_id': '6-0', 'rel_ann_id': 1, 'image_id': 1}

In [None]:
# Body parts in PaSta order; the list index is the part id used in pasta
# category ids such as '6-0' (part 6, state 0).
part_labels = (
    'right_foot right_leg left_leg left_foot '
    'hip head '
    'right_hand right_arm left_arm left_hand'
).split()
len(part_labels)

10

In [None]:
# Part metadata: key is valid_id + 1, and valid_id is the index in part_labels.
# NOTE(review): unlike ins_info / rel_info (where the dict key equals 'ann_id'),
# here 'ann_id' equals valid_id (0-based). That appears to match the 0-based part
# index in pasta category ids like '6-0', but it disagrees with the dict key.
# Confirm which convention consumers of label_info['part_info'] expect.
part_info = {
    idx + 1: dict(name=part_name, valid_id=idx, ann_id=idx)
    for idx, part_name in enumerate(part_labels)
}
part_info[5]

{'name': 'hip', 'valid_id': 4, 'ann_id': 4}

In [None]:
state_info = defaultdict(list)
with open('../../configs/Part_State_76.txt') as f:
    for line in f.readlines():
        part, state = line.strip().split(': ')
        state_info[part].append(state)
state_info = dict(state_info)
pprint(state_info, compact=True)

{'arm': ['shoulder carry', 'be close to', 'hug', 'swing', 'no_interaction'],
 'foot': ['stand on', 'tread on', 'walk with', 'walk to', 'run with', 'run to',
          'dribble', 'kick', 'jump down', 'jump with', 'walk away',
          'no_interaction'],
 'hand': ['hold', 'carry', 'reach for', 'touch', 'put on', 'twist', 'wear',
          'throw', 'throw out', 'write on', 'point with', 'point to',
          'use something to point to', 'press', 'squeeze', 'scratch', 'pinch',
          'gesture to', 'push', 'pull', 'pull with something', 'wash',
          'wash with something', 'hold in both hands', 'lift',
          'raise(over head)', 'feed', 'cut with something',
          'catch with something', 'pour into', 'no_interaction'],
 'head': ['eat', 'inspect', 'talk with something', 'talk to', 'be close with',
          'kiss', 'put something over head', 'lick', 'blow', 'drink with',
          'smell', 'wear', 'no_interaction'],
 'hip': ['sit on', 'sit in', 'sit beside', 'be close with', '

In [None]:
# One category per (body part, part state). Left/right parts share the state list
# of their trailing token (e.g. 'right_foot' -> 'foot').
pasta_categories = [
    dict(
        id=f'{part_id}-{state_id}',
        part_name=part_label.replace('_', ' '),
        state_name=state_name,
        part_id=part_id,
        state_id=state_id,
        name=f"{part_label.replace('_', ' ')}-{state_name}",
    )
    for part_id, part_label in enumerate(part_labels)
    for state_id, state_name in enumerate(state_info[part_label.split('_')[-1]])
]
len(pasta_categories)

134

In [None]:
# Preview the first two PaSta categories.
pasta_categories[:2]

[{'id': '0-0',
  'part_name': 'right foot',
  'state_name': 'stand on',
  'part_id': 0,
  'state_id': 0,
  'name': 'right foot-stand on'},
 {'id': '0-1',
  'part_name': 'right foot',
  'state_name': 'tread on',
  'part_id': 0,
  'state_id': 1,
  'name': 'right foot-tread on'}]

In [None]:
dump(pasta_categories, file='pasta_categories.json')  # ordered list of PaSta categories

In [None]:
# Index PaSta categories by their '<part_id>-<state_id>' id string.
pasta_info = dict((cat['id'], cat) for cat in pasta_categories)

In [None]:
# Look up the category for part 0, state 0.
pasta_info['0-0']

{'id': '0-0',
 'part_name': 'right foot',
 'state_name': 'stand on',
 'part_id': 0,
 'state_id': 0,
 'name': 'right foot-stand on'}

In [None]:
# Add the PaSta part/state tables to the saved label_info.json and write it back.
# Alternative source: '../../configs/label_info.json'
label_info = load(file='label_info.json')
label_info.update(
    part_info=part_info,
    state_info=state_info,
    pasta_info=pasta_info,
)
dump(label_info, file='label_info.json')