In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import copy
import torch
import h5py
import numpy as np
import seaborn as sns
from tqdm import tqdm
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
from maskrcnn_benchmark.data.datasets.visual_genome import load_graphs, load_info
from maskrcnn_benchmark.modeling.roi_heads.relation_head.utils_motifs import obj_edge_vectors
from maskrcnn_benchmark.data.datasets.visual_genome import VGDataset
%matplotlib inline

In [2]:
def convert_obj_class(obj_classes, rel):
    for index, (i_gt_class, i_relationships) in enumerate(zip(obj_classes, rel)):
        for index_rel in range(len(i_relationships)):
            i_relationships[index_rel, 0] = i_gt_class[i_relationships[index_rel, 0]]
            i_relationships[index_rel, 1] = i_gt_class[i_relationships[index_rel, 1]]
        rel[index] = i_relationships
    return rel

# data path
path = '/irip/lijiankai_2020/dataset/vg'
roidb_file = os.path.join(path, "VG-SGG-with-attri.h5")
dict_file = os.path.join(path, "VG-SGG-dicts-with-attri.json")
ind_to_classes, ind_to_predicates, ind_to_attributes = load_info(dict_file)
rel_categories = {index: pd for index, pd in enumerate(ind_to_predicates)}
obj_categories = {index: pd for index, pd in enumerate(ind_to_classes)}

_, train_gt_boxes, train_gt_classes, _, train_relationships = load_graphs(
    roidb_file, split="train", num_im=-1, num_val_im=5000,
    filter_empty_rels=True, filter_non_overlap=False
)

_, _, test_gt_classes, _, test_relations = load_graphs(
    roidb_file, split="test", num_im=-1, num_val_im=5000,
    filter_empty_rels=True, filter_non_overlap=False
)

# mapper entity labels
test_triplets = convert_obj_class(test_gt_classes, test_relations)
train_triplets = convert_obj_class(train_gt_classes, train_relationships)


train_triplets = np.concatenate(np.array(train_triplets), axis=0)
test_triplets = np.concatenate(np.array(test_triplets), axis=0)
train_triplet_set = set()
test_triplet_set = set()


for i in range(len(train_triplets)):
    item = train_triplets[i]
    train_triplet_set.add((item[0], item[1], item[2]))

for i in range(len(test_triplets)):
    item = test_triplets[i]
    test_triplet_set.add((item[0], item[1], item[2]))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 57723/57723 [00:01<00:00, 29683.24it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26446/26446 [00:00<00:00, 29970.33it/s]


In [3]:
zs_set = test_triplet_set.difference(train_triplet_set)
print(len(train_triplet_set), len(test_triplet_set), len(zs_set))

29283 17659 5426


In [4]:
# zeroshot file in https://github.com/KaihuaTang/Scene-Graph-Benchmark.pytorch
zs_prior_path = "/irip/lijiankai_2020/databackup114/Ready/maskrcnn_benchmark/data/datasets/evaluation/vg/zeroshot_triplet.pytorch"
zs_prior_data =  torch.load(zs_prior_path).long().numpy()

zs_prior_set = set()
for i in range(len(zs_prior_data)):
    item = zs_prior_data[i]
    zs_prior_set.add((item[0], item[1], item[2]))

inter_set = zs_prior_set.intersection(zs_set)
union_set = zs_prior_set.union(zs_set)
train_inter_set = zs_prior_set.intersection(train_triplet_set)
print(len(zs_prior_set), len(inter_set), len(union_set), len(train_inter_set))

5971 5426 5971 545


In [5]:
# zeroshot file in tcar test
zs_tcar_path = "/irip/lijiankai_2020/tcar/Scene-Graph-Benchmark.pytorch/zeroshot_triplet_new.pytorch"
zs_tcar_data =  torch.load(zs_tcar_path).long().numpy()
zs_tcar_set = set()
for i in range(len(zs_tcar_data)):
    item = zs_tcar_data[i]
    zs_tcar_set.add((item[0], item[1], item[2]))

inter_set = zs_tcar_set.intersection(zs_set)
union_set = zs_tcar_set.union(zs_set)
train_inter_set = zs_tcar_set.intersection(train_triplet_set)
print(len(zs_prior_set), len(inter_set), len(union_set), len(train_inter_set))

5971 5426 5426 0
