In [34]:
import os
import pickle as pkl

# data_name = 'YAGO3-10'
# data_name = 'WN18RR_new'
data_name = 'FB237_new'
data_dir = os.path.join('datasets', 'data', data_name)

train_data_dir = os.path.join(data_dir, 'train.pickle')
valid_data_dir = os.path.join(data_dir, 'valid.pickle')
test_data_dir = os.path.join(data_dir, 'test.pickle')

with open(train_data_dir, "rb") as in_file:
        train_examples = pkl.load(in_file)

with open(valid_data_dir, "rb") as in_file:
        valid_examples = pkl.load(in_file)

with open(test_data_dir, "rb") as in_file:
        test_examples = pkl.load(in_file)


print('Total test num: ', len(test_examples))

Total test num:  20466


In [31]:
import numpy as np
class graph:
    def __init__(self, data, max_hop) -> None:
        self.data = data
        self.max_hop = max_hop
        self.nodes = set()
        self.rels = set()

        # {ent1: {relA: [ent10, ent11...]}}
        self.paths = {} 

        # {ent1: [ent10, ent11, ...]}, i.e. multi-relations are not considered
        self.connects = {}
        
        self.relNum = 1000

        self.addElement()
        self.nodeNum = len(self.nodes)
        self.relNum = len(self.rels)

        self.addPaths()
        self.addConnectivity()
        self.getCompositionPattern(max_hop)
      
    
    def addElement(self):
        for tri in self.data:         
            h,r,t = tri
            if h not in self.nodes:
               self.nodes.add(h)
            if t not in self.nodes:
               self.nodes.add(t)
            if r not in self.rels:
               self.rels.add(r)

    def addPaths(self):
        for tri in self.data:
            h,r,t = tri
            if h not in self.paths:
                self.paths[h] = {r:[t]}
            else:
                if r not in self.paths[h]:
                    self.paths[h][r] = [t]
                else:
                    self.paths[h][r].append(t)
            rev_r = r + self.relNum
            if t not in self.paths:
                self.paths[t] = {rev_r:[h]}
            else:
                if rev_r not in self.paths[t]:
                    self.paths[t][rev_r] = [h]
                else:
                    self.paths[t][rev_r].append(h)

    def addConnectivity(self):
        for tri in self.data:
            h,r,t = tri
            if h not in self.connects:
                self.connects[h] = [t]
            else:
                self.connects[h].append(t)
            if t not in self.connects:
                self.connects[t] = [h]
            else:
                self.connects[t].append(h)
            



    def findConnectivity(self, head, tail, hop):
        if hop > 0:
            candidates = self.connects[head]
            if tail in candidates:
                return True
            for ent in candidates:
                if self.findConnectivity(ent, tail, hop-1):
                    return True
        return False
    
    def getCompositionPattern(self, hop):
        self.compos_pattern_full = {}
        # self.compos_pattern = {}
        # {rel1: {[rel2,rel3]: [[full path1], [full path2]..], [rel2, rel4],...]: [[full path1]] }, rel2: {}}

        for tri in self.data:
            h0,r0,t0 = tri
            if r0 not in self.compos_pattern_full:
                self.compos_pattern_full[r0] = {}
            collect_rel_paths = []
            tmp_rel_path = ''
            tmp_ent_in_path = []
            tmp_full_paths = []
            self.compositionPath(collect_rel_paths, tmp_rel_path, tmp_ent_in_path, tmp_full_paths, h0, t0, hop)

            for rel_path, full_path in zip(collect_rel_paths, tmp_full_paths):
                if rel_path not in self.compos_pattern_full[r0]:
                    self.compos_pattern_full[r0][rel_path] = [full_path]
                else:
                    self.compos_pattern_full[r0][rel_path].append(full_path)
        

    def compositionPath(self, collect_rel_paths, tmp_rel_path, tmp_ent_in_path, tmp_full_paths, head, tail, hop):
        # print(tmp_ent_in_path)
        # print(type(tmp_ent_in_path))
        if hop < 1:
            return 
        tmp_ent_in_path.append(head)
        for rel in self.paths[head]:
            for ent in self.paths[head][rel]:
                if ent == tail:   
                    # rel_path = tmp_rel_path.copy()
                    # rel_path.append(rel)
                    rel_path = tmp_rel_path +'-'+str(rel)
                    collect_rel_paths.append(rel_path)
                    rel_list = [int(i) for i in rel_path.split('-')[1:]]
                    tmp_full_path = [val for pair in zip(tmp_ent_in_path, rel_list) for val in pair]
                    tmp_full_path.append(tail)
                    tmp_full_paths.append(tmp_full_path)
                else:    
                    if ent not in tmp_ent_in_path:
                    # new_tmp_rel_path = tmp_rel_path.copy()
                    # new_tmp_rel_path.append(rel)
                        new_tmp_rel_path = tmp_rel_path +'-'+str(rel)
                        self.compositionPath(collect_rel_paths, new_tmp_rel_path, tmp_ent_in_path, tmp_full_paths, ent, tail, hop-1)
                    
            
        return
    
    def searchComposPattern(self, query, hop = 6, threshold = 5):
        h,r,t = query
        try:
            rel_paths = [pstr for pstr in self.compos_pattern_full[r].keys() if len(self.compos_pattern_full[r][pstr])>=threshold]
        except:
            print(len(self.compos_pattern_full[r]))
        for rp in rel_paths:
            rel_list = [int(i) for i in rp.split('-')[1:]]
            if len(rel_list)<=hop and self.checkComposPatthern(rel_list, h, t, 0):
                return True
        return False

            
    def checkComposPatthern(self, rel_path, head, tail, i):
        if i >= len(rel_path):
            return False
        if rel_path[i] not in self.paths[head]:
            return False
        if tail in self.paths[head][rel_path[i]]:
            return True
        for ent in self.paths[head][rel_path[i]]:
            return self.checkComposPatthern(rel_path, ent, tail, i+1)
        
    def searchConnectivity(self, query, threshold = 5):
        h,r,t = query
        return self.findConnectivity(h,t,threshold)
    
    def getComposPattern(self):
        cp = {}
        for rel in self.compos_pattern_full:
            cp[rel] = {}
            for rel_path in self.compos_pattern_full[rel]:
                cp[rel][rel_path] = len(self.compos_pattern_full[rel][rel_path])
        return cp
    
    def getComposPatternFull(self):
        return self.compos_pattern_full
        

# small_train_examples = [
# [1,1,2],
# [5,2,2],
# [5,3,3],
# [1,4,3],
# [1,3,5],
# [1,1,6],
# [5,2,6]
# ]
# g = graph(small_train_examples, max_hop=2)
g = graph(train_examples, max_hop=2)

# test = test_examples[0]
# g.searchComposPattern(test)


In [35]:
# load graph
graph_name = 'FB237_new_graph_full.pkl'
with open(os.path.join(data_dir,  graph_name), "rb") as in_file:
        g = pkl.load(in_file)

# compos_pattern_all = g.getComposPattern()
# compos_pattern_all = g.compos_pattern

In [36]:
ratio_threshold = 0.0003
pattern_freq_threshold = int(ratio_threshold * len(train_examples))
compos_pattern_all = g.getComposPatternFull()
# pattern_freq_threshold = 3
# pattern_freq_threshold = 24
print('pattern_freq_threshold: ', pattern_freq_threshold)

compos_rel_set = set()
compos_tri = {}
compos_num = 0
for r3 in compos_pattern_all:
    paths = compos_pattern_all[r3]
    for path in paths:
        if len(path.split('-'))==3:
            # compos_rel_set.add(r3)
            # r1, r2 = [int(r) for r in path.split('-')[1:3]]
            # compos_rel_set.add(r1)
            # compos_rel_set.add(r2)
            if r3 not in compos_tri:
                compos_tri[r3] = {}
            # if path not in compos_tri[r3]:
            cp_tri=set()
            for p in paths[path]:
                a, r1, b, r2, c = p
                tri1 = (a, r1, b)
                tri2 = (b, r2, c)
                tri3 = (a, r3, c)
                cp_tri.add(tri1)
                cp_tri.add(tri2)
                cp_tri.add(tri3)
            compos_tri[r3][path] = len(cp_tri)
                
            # compos_num += 1

total_cp_tri = 0
rel_tri = {}
for r3 in compos_tri:
    for path in compos_tri[r3]:
        if compos_tri[r3][path]>=pattern_freq_threshold:
            compos_num += 1
            total_cp_tri += compos_tri[r3][path]
            compos_rel_set.add(r3)
            r1, r2 = [int(r) for r in path.split('-')[1:3]]
            compos_rel_set.add(r1)
            compos_rel_set.add(r2)
            if r3 not in rel_tri:
                rel_tri[r3] = compos_tri[r3][path]

print('compos rel #: ', len(compos_rel_set)//2)
print('compos_num: ', compos_num)
print('avg. compos tri: ', total_cp_tri/compos_num)

cp_rel_tri = 0
for tri in train_examples:
    h,r,t = tri
    if r in compos_rel_set:
        cp_rel_tri+= 1

print('cp_rel_tri: ', cp_rel_tri)

pattern_freq_threshold:  81
compos rel #:  211
compos_num:  4105
avg. compos tri:  635.7144945188794
cp_rel_tri:  268684


In [12]:
compos_tri

{3: {'-9-20': 964,
  '-2-3': 415,
  '-13-3': 461,
  '-12-1': 976,
  '-12-12': 971,
  '-1-1': 999,
  '-1-12': 1033,
  '-5-3': 210,
  '-0-3': 28,
  '-16-3': 345,
  '-20-3': 10,
  '-16-5': 421,
  '-18-7': 52,
  '-13-2': 134,
  '-3-3': 37,
  '-10-3': 134,
  '-21-3': 150,
  '-2-4': 10,
  '-9-14': 28,
  '-14-11': 28,
  '-14-9': 14,
  '-13-5': 6,
  '-9-1': 103,
  '-9-12': 96,
  '-13-14': 6,
  '-2-9': 6,
  '-14-3': 28,
  '-10-11': 2,
  '-9-2': 18,
  '-3-5': 2,
  '-20-9': 10,
  '-3-14': 37,
  '-10-14': 4,
  '-10-0': 2,
  '-0-11': 10,
  '-0-14': 2,
  '-17-6': 23,
  '-15-4': 10,
  '-11-14': 28,
  '-11-10': 2,
  '-11-21': 2,
  '-11-3': 30,
  '-12-20': 2,
  '-2-13': 18,
  '-1-3': 6,
  '-12-3': 2,
  '-12-9': 26,
  '-9-3': 18,
  '-21-11': 2,
  '-3-13': 2,
  '-14-0': 2,
  '-14-20': 4,
  '-1-20': 4,
  '-11-0': 14,
  '-20-14': 4,
  '-21-21': 4,
  '-1-9': 30,
  '-9-9': 4,
  '-14-13': 6,
  '-17-7': 4,
  '-1-5': 2,
  '-16-12': 2,
  '-16-1': 2,
  '-13-13': 4,
  '-11-11': 4,
  '-3-16': 2,
  '-16-2': 8,
  '-1

In [None]:
# save graph
# stats_out = data_name + '_graph.pkl'

# with open(os.path.join(data_dir, stats_out), 'wb') as file:
#     pkl.dump(g, file) 

In [136]:
# load graph
# with open(os.path.join(data_dir, 'YAGO3-10_graph.pkl'), "rb") as in_file:
#         g = pkl.load(in_file)

with open(os.path.join(data_dir, 'FB237_new_graph_full.pkl'), "rb") as in_file:
        g = pkl.load(in_file)

In [137]:
print(len(train_examples)*2)
threshold1 = 3
threshold2 = 100
compos_rel = set()

well_rep_compos_num = 0
less_rep_compos_num = 0

compos_pattern = g.getComposPattern()
# compos_pattern=g.compos_pattern

wel_compos_rel = []
less_compos_rel = []
for item in compos_pattern.items():
    rel = item[0]
    patterns = item[1]
    for p in patterns.items():
        rel_chain = p[0].split('-')[1:]
        if len(rel_chain)>1:
            if p[1] > threshold2:
                wel_compos_rel.append([rel, p[0]])
                well_rep_compos_num += 1
            elif p[1] > threshold1:
                # print(rel, p[0], p[1])
                less_compos_rel.append([rel, p[0]])
                less_rep_compos_num += 1

            compos_rel.add(rel)
            for rel_str in rel_chain:
                compos_rel.add(int(rel_str))
print('compos pattern num: ', well_rep_compos_num+less_rep_compos_num)
print('well compos pattern num: ', well_rep_compos_num)
print('less compos pattern num: ', less_rep_compos_num)
print(len(compos_rel))


544230
compos pattern num:  12867
well compos pattern num:  2599
less compos pattern num:  10268
468


In [7]:
# find composition triples in training data for Exp5
compos_pattern_full = g.getComposPatternFull()

compos_pattern_full_flatt = {}
compos_pattern_num_flatt = {}

for item in compos_pattern_full.items():
    r3 = str(item[0]) + ':'
    for rel_chain in item[1]:
        compos = r3+rel_chain
        x_tuples = [tuple(sublist) for sublist in item[1][rel_chain]]
        unique_tuples = set(x_tuples)
        compos_pattern_full_flatt[compos] = [list(subtuple) for subtuple in unique_tuples]
        compos_pattern_num_flatt[compos] = len(unique_tuples)

# compos_pattern_full_flatt =dict(sorted(compos_pattern_full_flatt.items(), key=lambda item: len(item[1]), reverse=True))
compos_pattern_num_flatt = dict(sorted(compos_pattern_num_flatt.items(), key=lambda item: item[1], reverse=True))
print(compos_pattern_num_flatt)

# compos0 '10:-191-428' 10429
# compos1 '10:-9-246': 8911
# compos2  '96:-78-372': 5678
# compos3  '10:-247-10': 9364
# compos4  '96:-78-370': 5612 
# compos5  '92:-337-12': 25 
# compos6, '135:-144-132': 26
# '10:-188-425': 7181
# '10:-247-247': 9370
# '10:-9-246': 8911
# '12:-191-428': 5431
# '96:-78-371': 5132


{'10:-10': 15989, '10:-247': 12950, '96:-96': 12893, '9:-9': 12157, '191:-191': 10945, '10:-191-428': 10429, '68:-68': 9494, '3:-3': 9465, '10:-247-247': 9370, '10:-247-10': 9364, '10:-10-10': 9162, '10:-10-247': 9158, '10:-9-246': 8911, '12:-12': 8423, '96:-96-365': 7774, '96:-96-128': 7607, '85:-85': 7268, '10:-188-425': 7181, '12:-249': 6860, '10:-186-423': 6424, '96:-96-144': 6418, '96:-96-381': 6394, '11:-11': 6277, '12:-247': 6272, '10:-249': 6272, '12:-10': 6262, '10:-12': 6262, '149:-149': 5880, '96:-78-372': 5678, '4:-4': 5673, '96:-78-370': 5612, '12:-191-428': 5431, '89:-89': 5305, '129:-129': 5201, '96:-78-371': 5132, '68:-188-310': 4997, '10:-9-4': 4920, '96:-78-369': 4849, '10:-196-433': 4774, '10:-247-12': 4418, '10:-247-249': 4361, '3:-246-11': 4238, '188:-188': 4197, '10:-10-12': 4188, '10:-10-249': 4185, '12:-9-246': 4093, '12:-247-10': 4037, '10:-11-248': 4033, '12:-247-247': 4022, '9:-241': 3882, '4:-246': 3882, '158:-158': 3795, '68:-188-333': 3780, '12:-10-10': 37

In [259]:
# Generate data for Exp5 (only with a-r1-b-r2-c)
# '10:-191-428', '68:-188-310', '12:-9-4'
compos_combination = '68:-188-310'
exp5_dataname = 'compos_0'
new_train = []
new_test = []
for path in compos_pattern_full_flatt[compos_combination]:
    a, r1, b, r2, c = path
    tri1 = [a, r1, b]
    tri2 = [b, r2, c]
    new_train.append(tri1)
    new_train.append(tri2)
    new_test.append([a,r1,r2, c])

print('new test: ', len(new_test))
print('new train: ', len(new_train))

new_data_dir = os.path.join(data_dir, exp5_dataname)
if not os.path.exists(new_data_dir):
    os.makedirs(new_data_dir)
with open(os.path.join(new_data_dir, 'trainRaw.pickle'), 'wb') as file:
    pkl.dump(new_train, file) 

with open(os.path.join(new_data_dir, 'testRaw.pickle'), 'wb') as file:
    pkl.dump(new_test, file) 
data_raw = compos_pattern_full_flatt[compos_combination]

with open(os.path.join(new_data_dir, 'dataRaw.pickle'), 'wb') as file:
    pkl.dump(data_raw, file) 

new test:  4997
new train:  9994


In [110]:
# Generate data for Exp5 (only with a-r1-b-r2-c)
compos_combination_list = ['10:-191-428', '10:-9-246', '96:-78-372', '10:-247-10','92:-337-12', '135:-144-132', '12:-329-307']
# compos_combination_list = ['92:-337-12', '135:-144-132', '96:-78-372', '12:-329-307']
exp5_dataname = 'compos_101'
new_train = []
new_test = []
data_raw = []
for compos_combination in compos_combination_list:
    for path in compos_pattern_full_flatt[compos_combination]:
        a, r1, b, r2, c = path
        tri1 = [a, r1, b]
        tri2 = [b, r2, c]
        new_train.append(tri1)
        new_train.append(tri2)
        new_test.append([a,r1,r2, c])
    
    data_raw += compos_pattern_full_flatt[compos_combination]
    
print('new test: ', len(new_test))
print('new train: ', len(new_train))

new_data_dir = os.path.join(data_dir, exp5_dataname)
if not os.path.exists(new_data_dir):
    os.makedirs(new_data_dir)
with open(os.path.join(new_data_dir, 'trainRaw.pickle'), 'wb') as file:
    pkl.dump(new_train, file) 

with open(os.path.join(new_data_dir, 'testRaw.pickle'), 'wb') as file:
        pkl.dump(new_test, file) 

with open(os.path.join(new_data_dir, 'dataRaw.pickle'), 'wb') as file:
    pkl.dump(data_raw, file) 

new test:  34458
new train:  68916


In [209]:
# exp6: generate data from FB for rare r3 of compos pattern
# compos0 '10:-191-428' 10429
# compos1 '10:-9-246': 8911
# '158:-157-395': 2055
# compos2  '96:-78-372': 5678
# compos3  '10:-247-10': 9364
# compos4  '96:-78-370': 5612 
# compos5  '92:-337-12': 25 
# compos6, '135:-144-132': 26
# '68:-188-310': 4997
# '12:-9-4': 3310
compos_combination_list = ['10:-191-428', '68:-188-310', '12:-9-4']
exp6_dataname = 'compos_exp6_34_dedup'
new_train = []
new_test = []
data_raw = []
compos_combination = '68:-188-310'
length = len(compos_pattern_full_flatt[compos_combination])
r3_num = 0.1 * length
r1_tr_freq=0
r3_tr_freq=0
r3_tst_freq=0
for path in compos_pattern_full_flatt[compos_combination]:
    a, r1, b, r2, c = path
    tri1 = [a, r1, b]
    tri2 = [b, r2, c]
    tri3 = [a, r3, c]
    new_train.append(tri1)
    new_train.append(tri2)
    r1_tr_freq += 1
    if r3_num >0:
        new_train.append(tri3)
        r3_tr_freq += 1
        r3_num -= 1
    else:
        r3_tst_freq+=1
        new_test.append(tri3)
    
print('new test: ', len(new_test))
print('new train: ', len(new_train))
print(r1_tr_freq, r3_tr_freq, r3_tst_freq)

new_data_dir = os.path.join(data_dir, exp6_dataname)
if not os.path.exists(new_data_dir):
    os.makedirs(new_data_dir)
with open(os.path.join(new_data_dir, 'trainRaw.pickle'), 'wb') as file:
    pkl.dump(new_train, file) 

with open(os.path.join(new_data_dir, 'testRaw.pickle'), 'wb') as file:
        pkl.dump(new_test, file) 

new test:  4497
new train:  10494
4997 500 4497


In [197]:
# # H2: rare composition pattern, multiple
# compos_combination_list = ['10:-191-428', '68:-188-310', '12:-9-4']
# exp6_dataname = 'compos_exp6_200'
# new_train = []
# new_test = []
# data_raw = []
# for compos_combination in compos_combination_list:
#     length = len(compos_pattern_full_flatt[compos_combination])
#     r3_num = 0.9 * length
#     for path in compos_pattern_full_flatt[compos_combination]:
#         a, r1, b, r2, c = path
#         tri1 = [a, r1, b]
#         tri2 = [b, r2, c]
#         tri3 = [a, r3, c]
#         new_train.append(tri1)
#         new_train.append(tri2)
#         if r3_num >0:
#             new_train.append(tri3)
#             r3_num -= 1
#         else:
#             new_test.append(tri3)

        
# print('new test: ', len(new_test))
# print('new train: ', len(new_train))

# new_data_dir = os.path.join(data_dir, exp6_dataname)
# if not os.path.exists(new_data_dir):
#     os.makedirs(new_data_dir)
# with open(os.path.join(new_data_dir, 'trainRaw.pickle'), 'wb') as file:
#     pkl.dump(new_train, file) 

# with open(os.path.join(new_data_dir, 'testRaw.pickle'), 'wb') as file:
#         pkl.dump(new_test, file) 

new test:  1872
new train:  54336


In [242]:
# compos_combination_list = ['68:-188-310', '10:-191-428', '12:-9-4']
r1=188
r2=310
# aug=0
new_train = g1train
new_test = g1test
ent_set=set()
tr_set = set()
tr_r1_set = set()
tr_r2_set = set()
tr_r3_set = set()
tst_tr_set = set()
for tri in new_train:
    h,r,t=tri
    ent_set.add(h)
    ent_set.add(t)
    tri_set = (h,r,t)
    tr_set.add(tri_set)
    if r == r1:
        tr_r1_set.add(tri_set)
    elif r== r2:
        tr_r2_set.add(tri_set)
    else:
        tr_r3_set.add(tri_set)
for tri in new_test:
    h,r,t=tri
    tri_set=(h,r,t)
    tst_tr_set.add(tri_set)

print('ent: ', len(ent_set))
print('r1: ', len(tr_r1_set), ' r2: ', len(tr_r2_set), ' r3: ', len(tr_r3_set), ' aug: ', aug)
# print('train: ', len(tr_r1_set)+len(tr_r2_set)+len(tr_r3_set), ' test: ', len(tst_tr_set))
print('train: ', len(tr_set), ' test: ', len(tst_tr_set))


ent:  3301
r1:  1568  r2:  4957  r3:  500  aug:  0
train:  7025  test:  4497


In [251]:
# H2: rare composition pattern, multiple
import pickle as pkl

with open(os.path.join(data_dir, 'compos_exp6_14', 'trainRaw.pickle'), 'rb') as file:
    g1train = pkl.load(file) 

with open(os.path.join(data_dir, 'compos_exp6_14', 'testRaw.pickle'), 'rb') as file:
    g1test = pkl.load(file) 

with open(os.path.join(data_dir, 'compos_exp6_30', 'trainRaw.pickle'), 'rb') as file:
    g2train = pkl.load(file) 

with open(os.path.join(data_dir, 'compos_exp6_30', 'testRaw.pickle'), 'rb') as file:
    g2test = pkl.load(file) 

with open(os.path.join(data_dir, 'compos_exp6_23', 'trainRaw.pickle'), 'rb') as file:
    g3train = pkl.load(file) 

with open(os.path.join(data_dir, 'compos_exp6_23', 'testRaw.pickle'), 'rb') as file:
    g3test = pkl.load(file) 

g4train = g1train+g2train+g3train
g4test = g1test+g2test+g3test

# exp6_dataname = 'compos_exp6_400'

# new_data_dir = os.path.join(data_dir, exp6_dataname)
# if not os.path.exists(new_data_dir):
#     os.makedirs(new_data_dir)
# with open(os.path.join(new_data_dir, 'trainRaw.pickle'), 'wb') as file:
#     pkl.dump(g4train, file) 

# with open(os.path.join(new_data_dir, 'testRaw.pickle'), 'wb') as file:
#         pkl.dump(g4test, file) 

In [3]:
import pickle as pkl
import os 

with open(os.path.join('datasets', 'data', 'countries_S1', 'train.pickle'), 'rb') as file:
    tr_countries = pkl.load(file) 

with open(os.path.join('datasets', 'data', 'countries_S1', 'test.pickle'), 'rb') as file:
    tst_countries = pkl.load(file) 

sub_countries = graph(tr_countries, max_hop=2)

compos = sub_countries.getComposPattern()
print(compos)

# used_compos_pattern = {}
# for i in range(len(tst_countries)):
#     query = tst_countries[i]
#     h,r,t = query
#     if h in g.nodes and t in g.nodes:
#         if not g.searchComposPattern(query, hop=1, threshold = threshold1):
#             if g.searchComposPattern(query, hop=2, threshold = threshold2):
#                 well_compos_test_num += 1
#             elif g.searchComposPattern(query, hop=2, threshold = threshold1):
#                 less_compos_test_num += 1

{0: {'-0': 465, '-2-0': 207, '-0-0': 201, '-0-2': 203, '-3-0': 607, '-1-0': 161, '-2': 1}, 1: {'-1': 648, '-1-3': 179, '-1-1': 180, '-3': 640, '-0-2': 760, '-3-1': 862, '-3-3': 860}}


In [253]:
subgraph_train = g4train
subgraph_test = g4test
duplicate_num = 0
deduplicate = []
for tri in subgraph_train:
    tri_tuple=(tri[0], tri[1], tri[2])
    if tri_tuple in deduplicate:
        duplicate_num += 1
    else:
        deduplicate.append(tri_tuple)
print('train triples: ', len(deduplicate))
print('before deduplicate:', len(subgraph_train))
# print(len(subgraph_train)-len(deduplicate))
print('duplication in raw: ', duplicate_num)

# subgraph_train = deduplicate

exp6_dataname = 'compos_exp6_400_dedup'
new_data_dir = os.path.join(data_dir, exp6_dataname)
if not os.path.exists(new_data_dir):
    os.makedirs(new_data_dir)
with open(os.path.join(new_data_dir, 'trainRaw.pickle'), 'wb') as file:
    pkl.dump(deduplicate, file) 

with open(os.path.join(new_data_dir, 'testRaw.pickle'), 'wb') as file:
        pkl.dump(subgraph_test, file) 

train triples:  30827
before deduplicate: 46166
duplication in raw:  15339


In [250]:
# compos_combination_list = ['68:-188-310', '10:-191-428', '12:-9-4']
sub_g = graph(deduplicate, max_hop=2)
sub_composPattern = sub_g.getComposPattern()
print(sub_composPattern.keys())
sub_composPattern[10]

for rel3 in sub_composPattern:
    for path in sub_composPattern[rel3]:
        if len(path.split('-'))>2:
            print(rel3, ':', path, ' have ',sub_composPattern[rel3][path])

print(len(deduplicate))


dict_keys([9, 4, 10])
9 : -10-7  have  2252
9 : -9-7  have  176
9 : -9-4  have  179
9 : -10-9  have  56
9 : -10-10  have  339
9 : -10-4  have  188
9 : -7-4  have  420
9 : -7-10  have  116
9 : -13-10  have  211
9 : -13-4  have  39
9 : -13-9  have  33
9 : -13-7  have  414
9 : -9-13  have  52
9 : -10-13  have  470
9 : -12-7  have  60
9 : -12-4  have  18
9 : -12-10  have  38
9 : -4-4  have  110
9 : -4-13  have  440
9 : -4-10  have  162
9 : -4-7  have  381
9 : -7-12  have  111
9 : -9-10  have  49
9 : -10-12  have  40
9 : -7-7  have  218
9 : -9-9  have  8
9 : -7-9  have  27
9 : -9-12  have  8
9 : -7-13  have  107
9 : -13-13  have  95
9 : -12-13  have  15
9 : -4-9  have  25
9 : -4-12  have  22
9 : -12-12  have  6
9 : -12-9  have  7
9 : -13-12  have  4
4 : -12-10  have  2555
4 : -12-13  have  894
4 : -12-4  have  437
4 : -4-7  have  2774
4 : -4-4  have  2616
4 : -4-9  have  438
4 : -4-10  have  4503
4 : -7-10  have  692
4 : -7-4  have  2307
4 : -12-9  have  116
4 : -7-7  have  294
4 : -13-7  h

In [102]:
import torch
test_name = 'compos_5'
with open(os.path.join('datasets', 'data', 'FB237_new', test_name, 'trainRaw.pickle'), "rb") as f:
        train_examples = pkl.load(f)
train_examples = np.array(train_examples)
train_examples = torch.from_numpy(train_examples.astype("int64"))

total_filter_dir = os.path.join(data_dir, 'to_skip.pickle')
with open(total_filter_dir, "rb") as f:
        filters = pkl.load(f)

  from .autonotebook import tqdm as notebook_tqdm


In [96]:
# exp5: generate data for unseen relations
import random
length = 1000
r3_num=10
ent_id=0
r1,r2=0,1

new_train = []
new_test = []
new_raw = []

for i in range(length):
    x = random.randint(0,9)
    if i==0 or x<7:
        a,b,c = ent_id, ent_id+1, ent_id+2 
        ent_id = ent_id + 3
    elif x<8:
        b,c=ent_id,ent_id+1
        ent_id+=2
    elif x<9:
        a,b=ent_id,ent_id+1
        ent_id+=2
    else:
        b=ent_id
        ent_id+=1
    new_raw.append([a,r1,b,r2,c])
    new_train.append([a,r1,b])
    new_train.append([b,r2,c])
    new_test.append([a,r1,r2,c])
    

    
exp5_dataname = 'compos_10'
new_data_dir = os.path.join(data_dir, exp5_dataname)
if not os.path.exists(new_data_dir):
    os.makedirs(new_data_dir)
with open(os.path.join(new_data_dir, 'trainRaw.pickle'), 'wb') as file:
    pkl.dump(new_train, file) 

with open(os.path.join(new_data_dir, 'testRaw.pickle'), 'wb') as file:
    pkl.dump(new_test, file) 
data_raw = new_raw
with open(os.path.join(new_data_dir, 'dataRaw.pickle'), 'wb') as file:
    pkl.dump(data_raw, file) 


In [146]:
# exp6: generate data for rare relations in composition pattern
import random
length = 1000
r3_num=100
ent_id=0
r1,r2,r3=0,1,2

new_train = []
new_test = []
new_raw = []

for i in range(length):
    x = random.randint(0,9)
    if i==0 or x<7:
        a,b,c = ent_id, ent_id+1, ent_id+2 
        ent_id = ent_id + 3
    elif x<8:
        b,c=ent_id,ent_id+1
        ent_id+=2
    elif x<9:
        a,b=ent_id,ent_id+1
        ent_id+=2
    else:
        b=ent_id
        ent_id+=1

    new_train.append([a,r1,b])
    new_train.append([b,r2,c])

    if r3_num>0:
        new_train.append([a,r3,c])
        r3_num -= 1
    else:   
        new_test.append([a,r3,c])
    

    
exp6_dataname = 'compos_exp6_2'
data_dir = os.path.join('datasets', 'data', exp6_dataname)
if not os.path.exists(data_dir):
    os.makedirs(data_dir)

# with open(os.path.join(data_dir, 'trainRaw.pickle'), 'wb') as file:
#     pkl.dump(new_train, file) 

# with open(os.path.join(data_dir, 'testRaw.pickle'), 'wb') as file:
#     pkl.dump(new_test, file) 



In [147]:
print(len(new_train),len(new_test))
n=[0,0,0]
for tri in new_train:
    r=tri[1]
    if r==1:
        n[0]+=1
    elif r==2:
        n[1]+=1
    else:
        n[2]+=1
print(n)


2100 900
[1000, 100, 1000]


In [85]:
class graph:
    def __init__(self, data, max_hop) -> None:
        self.data = data
        self.max_hop = max_hop
        self.nodes = set()
        self.rels = set()

        # {ent1: {relA: [ent10, ent11...]}}
        self.paths = {} 

        # {ent1: [ent10, ent11, ...]}, i.e. multi-relations are not considered
        self.connects = {}
        
        self.relNum = 1000

        self.addElement()
        self.nodeNum = len(self.nodes)
        self.relNum = len(self.rels)

        self.addPaths()
        self.addConnectivity()
        self.getCompositionPattern(max_hop)
      
    
    def addElement(self):
        for tri in self.data:         
            h,r,t = tri
            if h not in self.nodes:
               self.nodes.add(h)
            if t not in self.nodes:
               self.nodes.add(t)
            if r not in self.rels:
               self.rels.add(r)

    def addPaths(self):
        for tri in self.data:
            h,r,t = tri
            if h not in self.paths:
                self.paths[h] = {r:[t]}
            else:
                if r not in self.paths[h]:
                    self.paths[h][r] = [t]
                else:
                    self.paths[h][r].append(t)
            rev_r = r + self.relNum
            if t not in self.paths:
                self.paths[t] = {rev_r:[h]}
            else:
                if rev_r not in self.paths[t]:
                    self.paths[t][rev_r] = [h]
                else:
                    self.paths[t][rev_r].append(h)

    def addConnectivity(self):
        for tri in self.data:
            h,r,t = tri
            if h not in self.connects:
                self.connects[h] = [t]
            else:
                self.connects[h].append(t)
            if t not in self.connects:
                self.connects[t] = [h]
            else:
                self.connects[t].append(h)
            



    def findConnectivity(self, head, tail, hop):
        if hop > 0:
            candidates = self.connects[head]
            if tail in candidates:
                return True
            for ent in candidates:
                if self.findConnectivity(ent, tail, hop-1):
                    return True
        return False
    
    def getCompositionPattern(self, hop):
        self.compos_pattern = {}
        # {rel1: {[rel2,rel3]: num, [rel2, rel4],...]: num }, rel2: {}}

        for tri in self.data:
            h0,r0,t0 = tri
            if r0 not in self.compos_pattern:
                self.compos_pattern[r0] = {}
            collect_rel_paths = []
            tmp_rel_path = ''
            tmp_ent_in_path = set()
            self.compositionPath(collect_rel_paths, tmp_rel_path, tmp_ent_in_path, h0, t0, hop)

            for rel_path in collect_rel_paths:
                if rel_path not in self.compos_pattern[r0]:
                    self.compos_pattern[r0][rel_path] = 1
                else:
                    self.compos_pattern[r0][rel_path] += 1
        



    def compositionPath(self, collect_rel_paths, tmp_rel_path, tmp_ent_in_path, head, tail, hop):
        # print(tmp_ent_in_path)
        # print(type(tmp_ent_in_path))
        if hop < 1:
            return 
        tmp_ent_in_path.add(head)
        for rel in self.paths[head]:
            for ent in self.paths[head][rel]:
                if ent == tail:   
                    # rel_path = tmp_rel_path.copy()
                    # rel_path.append(rel)
                    rel_path = tmp_rel_path +'-'+str(rel)
                    collect_rel_paths.append(rel_path)
                else:    
                    if ent not in tmp_ent_in_path:
                    # new_tmp_rel_path = tmp_rel_path.copy()
                    # new_tmp_rel_path.append(rel)
                        new_tmp_rel_path = tmp_rel_path +'-'+str(rel)
                        self.compositionPath(collect_rel_paths, new_tmp_rel_path, tmp_ent_in_path, ent, tail, hop-1)
                    
            
        return
    
    def searchComposPattern(self, query, hop = 6, threshold = 5):
        h,r,t = query
        try:
            rel_paths = [pstr for pstr in self.compos_pattern[r].keys() if self.compos_pattern[r][pstr]>=threshold]
        except:
            print(self.compos_pattern[r])
        for rp in rel_paths:
            rel_list = [int(i) for i in rp.split('-')[1:]]
            if len(rel_list)<=hop and self.checkComposPatthern(rel_list, h, t, 0):
                return True
        return False

            
    def checkComposPatthern(self, rel_path, head, tail, i):
        if i >= len(rel_path):
            return False
        if rel_path[i] not in self.paths[head]:
            return False
        if tail in self.paths[head][rel_path[i]]:
            return True
        for ent in self.paths[head][rel_path[i]]:
            return self.checkComposPatthern(rel_path, ent, tail, i+1)
        
    def searchConnectivity(self, query, threshold = 5):
        h,r,t = query
        return self.findConnectivity(h,t,threshold)

In [40]:
print('total test: ', len(test_examples))

threshold1 = 3
threshold2 = 100

num_query_seen = 0

well_compos_test_num = 0
less_compos_test_num = 0

for i in range(len(test_examples)):
    query = test_examples[i]
    h,r,t = query
    if h in g.nodes and t in g.nodes:
        num_query_seen += 1
        if not g.searchComposPattern(query, hop=1, threshold = threshold1):
            if g.searchComposPattern(query, hop=2, threshold = threshold2):
                well_compos_test_num += 1
            elif g.searchComposPattern(query, hop=2, threshold = threshold1):
                less_compos_test_num += 1


print('seen test: ', num_query_seen)
print('well_compos lead: ', well_compos_test_num)
print('less_compos lead: ', less_compos_test_num)

total test:  3134
seen test:  2923
well_compos lead:  197
less_compos lead:  50


In [24]:
sum = 0
for tri in test_examples:
    h,r,t = tri
    r_rev = r
    if t in g.paths and r_rev in g.paths[t] and h in g.paths[t][r_rev]:
        sum+=1
print(sum)

1086


In [4]:
# check connectivity
num_query_seen = 0
connect_1hop = 0
connect_2hop = 0
connect_3hop = 0
connect_4hop = 0
connect_5hop = 0
for query in test_examples:
    h,r,t = query
    if h in g.nodes and t in g.nodes:
        num_query_seen += 1
        if g.findConnectivity(h,t, hop=1):
            connect_1hop += 1
        if g.findConnectivity(h,t, hop=2):
            connect_2hop += 1
        if g.findConnectivity(h,t, hop=3):
            connect_3hop += 1
        if g.findConnectivity(h,t, hop=4):
            connect_4hop += 1
        if g.findConnectivity(h,t, hop=5):
            connect_5hop += 1

connect_hop = [connect_1hop, connect_2hop, connect_3hop, connect_4hop, connect_5hop]
connect_ratio_seen = [i/num_query_seen for i in connect_hop]
connect_ratio_total = [i/len(test_examples) for i in connect_hop]
print(connect_hop)
print(connect_ratio_seen)
print(connect_ratio_total)

[1096, 1387, 2060, 2295, 2572]
[0.3749572357167294, 0.47451248717071504, 0.7047553882996921, 0.7851522408484434, 0.8799178925761204]
[0.3497128270580728, 0.4425654116145501, 0.6573069559668155, 0.7322910019144863, 0.8206764518187619]


In [5]:
# check composition
num_query_seen = 0
compos_1hop = 0
compos_2hop = 0
compos_3hop = 0
compos_4hop = 0
compos_5hop = 0
for i in range(len(test_examples)):
    query = test_examples[i]
    h,r,t = query
    if h in g.nodes and t in g.nodes:
        num_query_seen += 1
        if g.searchComposPattern(query, hop=1, threshold = 100):
            compos_1hop += 1
        if g.searchComposPattern(query, hop=2, threshold = 100):
            compos_2hop += 1
        if g.searchComposPattern(query, hop=3, threshold = 100):
            compos_3hop += 1
        if g.searchComposPattern(query, hop=4, threshold = 100):
            compos_4hop += 1
        if g.searchComposPattern(query, hop=5, threshold = 100):
            compos_5hop += 1

compos_hop = [compos_1hop, compos_2hop, compos_3hop, compos_4hop, compos_5hop]
compos_ratio_seen = [i/num_query_seen for i in compos_hop]

total_test_num = len(test_examples)
compos_ratio_total = [i/total_test_num for i in compos_hop]

print(compos_hop)
print(compos_ratio_seen)
print(compos_ratio_total)


[1083, 1281, 1281, 1281, 1281]
[0.3705097502565857, 0.43824837495723573, 0.43824837495723573, 0.43824837495723573, 0.43824837495723573]
[0.3455647734524569, 0.40874282067645185, 0.40874282067645185, 0.40874282067645185, 0.40874282067645185]


In [38]:
# model = 'HolmE_WN_64'

model = 'AttH_WN_32'

# model = 'HolmE_FB_64'

# model = 'AttH_FB_32'

result_path = os.path.join('AAAI_models', model,'results.pkl')

with open(result_path, 'rb') as f:
    ranks_total = pkl.load(f)

# print(ranks_total['rhs'])
mrr_list = [((1/ranks_total['rhs'][i] + 1/ranks_total['lhs'][i])/2).item() for i in range(len(ranks_total['rhs']))]


In [39]:
print('total mrr: ', np.mean(mrr_list))
threshold1 = 3
threshold2 = 100

num_query_seen = 0

# compos_test_mrr_list = []
compos_test_r_list = []
# well_compos_test_mrr_list = []
well_compos_test_r_list = []
# less_compos_test_mrr_list = []
less_compos_test_r_list = []

for i in range(len(test_examples)):
    query = test_examples[i]
    h,r,t = query
    if h in g.nodes and t in g.nodes:
        num_query_seen += 1
        if not g.searchComposPattern(query, hop=1, threshold = threshold2):
            if g.searchComposPattern(query, hop=2, threshold = threshold2):
                # well_compos_test_mrr_list.append(mrr_list[i])
                well_compos_test_r_list.append(ranks_total['rhs'][i])
                well_compos_test_r_list.append(ranks_total['lhs'][i])
            elif g.searchComposPattern(query, hop=2, threshold = threshold1):
                # less_compos_test_mrr_list.append(mrr_list[i])
                less_compos_test_r_list.append(ranks_total['rhs'][i])
                less_compos_test_r_list.append(ranks_total['lhs'][i])
        
# compos_test_mrr_list = well_compos_test_mrr_list + less_compos_test_mrr_list
compos_test_r_list = well_compos_test_r_list + less_compos_test_r_list

def get_hit(rank_list):
    tt = len(rank_list)
    h1=np.sum([1 if i==1 else 0 for i in rank_list])
    h3=np.sum([1 if i<=3 else 0 for i in rank_list])
    h10=np.sum([1 if i<=10 else 0 for i in rank_list])
    return h1/tt,h3/tt,h10/tt

print('mrr for test derived by composition: ', np.mean([1/i for i in compos_test_r_list]), ' hit1/3/10: ', get_hit(compos_test_r_list))

print('mrr for test derived by well composition: ', np.mean([1/i for i in well_compos_test_r_list]), ' hit1/3/10: ', get_hit(well_compos_test_r_list))
print('mrr for test derived by less composition: ', np.mean([1/i for i in less_compos_test_r_list]), ' hit1/3/10: ', get_hit(less_compos_test_r_list))


total mrr:  0.4620115704755502
mrr for test derived by composition:  0.34798574  hit1/3/10:  (0.2548262548262548, 0.3841698841698842, 0.5212355212355212)
mrr for test derived by well composition:  0.36135334  hit1/3/10:  (0.2676767676767677, 0.3939393939393939, 0.5454545454545454)
mrr for test derived by less composition:  0.30459577  hit1/3/10:  (0.21311475409836064, 0.3524590163934426, 0.4426229508196721)


In [96]:
num_query_seen = 0
compos_1hop = 0
compos_2hop = 0
compos_3hop = 0
compos_4hop = 0
compos_5hop = 0
mrr_dict = {1:[], 2:[], 3:[], 4:[], 5:[]}
for i in range(len(test_examples)):
    query = test_examples[i]
    h,r,t = query
    if h in g.nodes and t in g.nodes:
        num_query_seen += 1
        if g.searchComposPattern(query, hop=1, threshold = 100):
            compos_1hop += 1
            mrr_dict[1].append(mrr_list[i])
        elif g.searchComposPattern(query, hop=2, threshold = 100):
            compos_2hop += 1
            mrr_dict[2].append(mrr_list[i])
        # elif g.searchComposPattern(query, hop=3, threshold = 100):
        #     compos_3hop += 1
        #     mrr_dict[3].append(mrr_list[i])
        # elif g.searchComposPattern(query, hop=4, threshold = 100):
        #     compos_4hop += 1
        #     mrr_dict[4].append(mrr_list[i])
        # elif g.searchComposPattern(query, hop=5, threshold = 100):
        #     compos_5hop += 1
        #     mrr_dict[5].append(mrr_list[i])

compos_hop = [compos_1hop, compos_2hop, compos_3hop, compos_4hop, compos_5hop]
compos_ratio_seen = [i/num_query_seen for i in compos_hop]

total_test_num = len(test_examples)
compos_ratio_total = [i/total_test_num for i in compos_hop]

print(compos_hop)
print(compos_ratio_seen)
print(compos_ratio_total)

print('if we calculate MRR for queries in hop number')
for i in mrr_dict:
    mrr = np.mean(mrr_dict[i])
    print(i, ': ', mrr)

print('if we calculate MRR for queries within hop number')
tmp_list = []
for i in mrr_dict:
    tmp_list.extend(mrr_dict[i])
    mrr = np.mean(tmp_list)
    print(i, ': ', mrr)


[0, 10411, 0, 0, 0]
[0.0, 0.5093942655837166, 0.0, 0.0, 0.0]
[0.0, 0.5086973517052673, 0.0, 0.0, 0.0]
if we calculate MRR for queries in hop number
1 :  nan
2 :  0.35002595092072525
3 :  nan
4 :  nan
5 :  nan
if we calculate MRR for queries within hop number
1 :  nan
2 :  0.35002595092072525
3 :  0.35002595092072525
4 :  0.35002595092072525
5 :  0.35002595092072525


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [74]:
count = 0
except_num = 0
for test in test_examples:
    h,r,t = test
    query = [t,r,h]
    if query in train_examples.tolist():
        count += 1
        # if not g.searchComposPattern(query, hop=1, threshold = 100):
        #     print(query)
        #     except_num += 1
        
print(count)
print(except_num)

0
0


In [97]:
num_query_seen = 0
compos_1hop = 0
compos_2hop = 0
compos_3hop = 0
compos_4hop = 0
compos_5hop = 0
mrr_dict = {1:[], 2:[], 3:[], 4:[], 5:[]}
for i in range(len(test_examples)):
    query = test_examples[i]
    h,r,t = query
    if h in g.nodes and t in g.nodes:
        num_query_seen += 1
        if g.searchComposPattern(query, hop=1, threshold = 10):
            compos_1hop += 1
            mrr_dict[1].append(mrr_list[i])
        elif g.searchComposPattern(query, hop=2, threshold = 10):
            compos_2hop += 1
            mrr_dict[2].append(mrr_list[i])
        # elif g.searchComposPattern(query, hop=3, threshold = 10):
        #     compos_3hop += 1
        #     mrr_dict[3].append(mrr_list[i])
        # elif g.searchComposPattern(query, hop=4, threshold = 10):
        #     compos_4hop += 1
        #     mrr_dict[4].append(mrr_list[i])
        # elif g.searchComposPattern(query, hop=5, threshold = 10):
        #     compos_5hop += 1
        #     mrr_dict[5].append(mrr_list[i])

compos_hop = [compos_1hop, compos_2hop, compos_3hop, compos_4hop, compos_5hop]
compos_ratio_seen = [i/num_query_seen for i in compos_hop]

total_test_num = len(test_examples)
compos_ratio_total = [i/total_test_num for i in compos_hop]

print(compos_hop)
print(compos_ratio_seen)
print(compos_ratio_total)

print('if we calculate MRR for queries in hop number')
for i in mrr_dict:
    mrr = np.mean(mrr_dict[i])
    print(i, ': ', mrr)

print('if we calculate MRR for queries within hop number')
tmp_list = []
for i in mrr_dict:
    tmp_list.extend(mrr_dict[i])
    mrr = np.mean(tmp_list)
    print(i, ': ', mrr)


[0, 11289, 0, 0, 0]
[0.0, 0.5523534592425874, 0.0, 0.0, 0.0]
[0.0, 0.5515977719143946, 0.0, 0.0, 0.0]
if we calculate MRR for queries in hop number
1 :  nan
2 :  0.34748206912853885
3 :  nan
4 :  nan
5 :  nan
if we calculate MRR for queries within hop number
1 :  nan
2 :  0.34748206912853885
3 :  0.34748206912853885
4 :  0.34748206912853885
5 :  0.34748206912853885


In [80]:
dict_mrr_7class = {'rare-rare':[], 'rare-freq':[], 'rare-non':[], 'freq-rare':[],'freq-freq':[],'freq-non':[],'unseen':[]}
dict_num_8class = {'rare-rare':0, 'rare-freq':0, 'rare-non':0, 'freq-rare':0,'freq-freq':0,'freq-non':0,'rare-unseen':0,'freq-unseen':0}
wn_r_trNum = {1:29715, 3:34796, 5:7402, 2: 4816, 9:3116, 4:2921, 0:1299, 10:1138, 6:923, 7:629, 8:80}

num_rare_rel = 0
num_freq_rel = 0

num_rare_compos_rel = 0
num_freq_compos_rel = 0
num_non_compos_rel = 0
num_unseen_rel = 0

freq_rel_thh = 1000

freq_compos_rel_thh = 1000
rare_compos_rel_thh = 1

for i in range(len(test_examples)):
    query = test_examples[i]
    h,r,t = query
    if wn_r_trNum[r]>freq_rel_thh: # freq relation
        num_freq_rel += 1
        if h not in g.nodes or t not in g.nodes:
            dict_num_8class['freq-unseen'] += 1
            dict_mrr_7class['unseen'].append(mrr_list[i])
            num_unseen_rel += 1
        elif g.searchComposPattern(query, hop=5, threshold = freq_compos_rel_thh):
            dict_num_8class['freq-freq'] += 1
            dict_mrr_7class['freq-freq'].append(mrr_list[i])
            num_freq_compos_rel += 1
        elif g.searchComposPattern(query, hop=5, threshold = rare_compos_rel_thh):
            dict_num_8class['freq-rare'] += 1
            dict_mrr_7class['freq-rare'].append(mrr_list[i])
            num_rare_compos_rel += 1
        else:
            dict_num_8class['freq-non'] += 1
            dict_mrr_7class['freq-non'].append(mrr_list[i])
            num_non_compos_rel += 1
    else:
        num_rare_rel += 1
        if h not in g.nodes or t not in g.nodes:
            dict_num_8class['rare-unseen'] += 1
            dict_mrr_7class['unseen'].append(mrr_list[i])
            num_unseen_rel += 1
        if g.searchComposPattern(query, hop=5, threshold = freq_compos_rel_thh):
            dict_num_8class['rare-freq'] += 1
            dict_mrr_7class['rare-freq'].append(mrr_list[i])
            num_freq_compos_rel += 1
        elif g.searchComposPattern(query, hop=5, threshold = rare_compos_rel_thh):
            dict_num_8class['rare-rare'] += 1
            dict_mrr_7class['rare-rare'].append(mrr_list[i])
            num_rare_compos_rel += 1
        else:
            dict_num_8class['rare-non'] += 1
            dict_mrr_7class['rare-non'].append(mrr_list[i])
            num_non_compos_rel += 1

print(model)
print(dict_num_8class)
for i in dict_mrr_7class:
    mrr = np.mean(dict_mrr_7class[i])
    print(i, ' ', mrr)



print('num_rare_rel: ', num_rare_rel, 'num_freq_rel:', num_freq_rel)
print('num_rare_compos_rel: ', num_rare_compos_rel, ' num_freq_compos_rel:', num_freq_compos_rel, 
      ' num_non_compos_rel:', num_non_compos_rel, ' num_unseen_rel:', num_unseen_rel)


KeyError: 219

In [41]:
FB_r_num_mrr = {}

FB_tr_r_num = {}
for tri in train_examples:
    h,r,t = tri
    if r not in FB_tr_r_num:
        FB_tr_r_num[r] = 1
    else:
        FB_tr_r_num[r] += 1

for i in range(len(test_examples)):
    h,r,t = test_examples[i]
    if r not in FB_r_num_mrr:
        FB_r_num_mrr[r] = [FB_tr_r_num[r], 1, [mrr_list[i]]]
    else:
        FB_r_num_mrr[r][1] += 1
        FB_r_num_mrr[r][2].append(mrr_list[i])

FB_r_num_mrr = dict(sorted(FB_r_num_mrr.items(), key=lambda item: item[1][0], reverse=True))

print(model)
print('relID, NumInTr, NumInTst, MRR')
for item in FB_r_num_mrr.items():
    mrr = np.mean(item[1][2])
    print(item[0], item[1][0], item[1][1], mrr)

AttH_WN_32
relID, NumInTr, NumInTst, MRR
3 34796 1251 0.14907810927816595
1 29715 1074 0.9398705968401296
5 7402 253 0.20609871951764872
2 4816 172 0.17845494088305713
9 3116 114 0.33494927447968065
4 2921 122 0.3812662858899934
0 1299 56 0.6041683383352522
10 1138 39 0.8824876867972004
6 923 26 0.23163773830478582
7 629 24 0.3007320079244285
8 80 3 1.0


In [42]:
# test with freq/rare relations vs test derived by composition patterns
print('total test: ', len(test_examples))

threshold1 = 3
threshold2 = 100
rare_rel_threshold = 1000

num_query_seen = 0

well_compos_test_num_dict = {'freq_rel': 0, 'rare_rel': 0}
less_compos_test_num_dict = {'freq_rel': 0, 'rare_rel': 0}
cannot_compos_test_num_dict =  {'freq_rel': 0, 'rare_rel': 0}

for i in range(len(test_examples)):
    query = test_examples[i]
    h,r,t = query
    if h in g.nodes and t in g.nodes:
        num_query_seen += 1
        if g.searchComposPattern(query, hop=2, threshold = threshold2):
            if FB_r_num_mrr[r][0] > rare_rel_threshold:
                well_compos_test_num_dict['freq_rel'] += 1
            else:
                well_compos_test_num_dict['rare_rel'] += 1

        elif g.searchComposPattern(query, hop=2, threshold = threshold1):
            if FB_r_num_mrr[r][0] > rare_rel_threshold:
                less_compos_test_num_dict['freq_rel'] += 1
            else:
                less_compos_test_num_dict['rare_rel'] += 1
        else:
            if FB_r_num_mrr[r][0] > rare_rel_threshold:
                cannot_compos_test_num_dict['freq_rel'] += 1
            else:
                cannot_compos_test_num_dict['rare_rel'] += 1


print('well_compos_test_num_dict: ', well_compos_test_num_dict)
print('less_compos_test_num_dict: ', less_compos_test_num_dict)
print('cannot_compos_test_num_dict: ', cannot_compos_test_num_dict)

total test:  3134
well_compos_test_num_dict:  {'freq_rel': 1281, 'rare_rel': 0}
less_compos_test_num_dict:  {'freq_rel': 58, 'rare_rel': 3}
cannot_compos_test_num_dict:  {'freq_rel': 1533, 'rare_rel': 48}


In [16]:
FB_atth = FB_r_num_mrr

In [19]:
FB_holmE=FB_r_num_mrr

In [22]:
threshold = 1000

mrr_list_atth_freq = []
mrr_list_atth_rare = []
mrr_list_holme_freq = []
mrr_list_holme_rare = []
for i in FB_atth:
    if FB_atth[i][0] >= threshold:
        mrr_list_atth_freq += FB_atth[i][2]
    else:
        mrr_list_atth_rare += FB_atth[i][2]

for i in FB_holmE:
    if FB_holmE[i][0] >= threshold:
        mrr_list_holme_freq += FB_holmE[i][2]
    else:
        mrr_list_holme_rare += FB_holmE[i][2]


atth_freq = np.mean(mrr_list_atth_freq)
atth_rare = np.mean(mrr_list_atth_rare)

holme_freq = np.mean(mrr_list_holme_freq)
holme_rare = np.mean(mrr_list_holme_rare)

print('mrr_atth_freq: ', atth_freq, ', mrr_holmE_freq: ', holme_freq, ' diff= ', holme_freq-atth_freq)
print('mrr_atth_rare: ', atth_rare, ', mrr_holmE_rare: ', holme_rare, ' diff= ', holme_rare-atth_rare )



mrr_atth_freq:  0.29847904233019384 , mrr_holmE_freq:  0.3082343687956936  diff=  0.009755326465499758
mrr_atth_rare:  0.38809156562559666 , mrr_holmE_rare:  0.4022287941420288  diff=  0.014137228516432121


In [110]:
print(len(mrr_list_atth_freq))
print(len(mrr_list_atth_rare))
print(len(mrr_list_holme_freq))
print(len(mrr_list_holme_rare))

15919
4547
15919
4547


In [104]:
FB_holmE

{10: [15989,
  214,
  [0.08012820780277252,
   0.1666666716337204,
   0.2666666805744171,
   0.3499999940395355,
   0.2666666805744171,
   0.17499999701976776,
   0.25,
   0.006797706708312035,
   0.25,
   0.1147058829665184,
   0.0012083658948540688,
   0.375,
   0.18333333730697632,
   0.1964285671710968,
   0.4166666865348816,
   0.1339285671710968,
   0.5,
   0.4166666865348816,
   0.2611111104488373,
   0.2916666865348816,
   0.026455026119947433,
   0.06470588594675064,
   0.20000000298023224,
   0.1979166716337204,
   0.2083333432674408,
   0.02055381052196026,
   0.2083333432674408,
   0.014504504390060902,
   0.13851352035999298,
   0.006642861757427454,
   0.13846154510974884,
   0.2666666805744171,
   0.4166666865348816,
   0.3333333432674408,
   0.17045454680919647,
   0.375,
   0.019611075520515442,
   0.11666667461395264,
   0.17499999701976776,
   0.5,
   0.375,
   0.4166666865348816,
   0.25,
   0.5,
   0.11274509876966476,
   0.12777778506278992,
   0.4166666865348816,

In [105]:
FB_atth

{10: [15989,
  214,
  [0.22499999403953552,
   0.1666666716337204,
   0.2916666865348816,
   0.4166666865348816,
   0.17142857611179352,
   0.11688312143087387,
   0.1666666716337204,
   0.0046126991510391235,
   0.25,
   0.09552845358848572,
   0.0018463707529008389,
   0.3125,
   0.18333333730697632,
   0.0466531440615654,
   0.4166666865348816,
   0.04296066612005234,
   0.5,
   0.4166666865348816,
   0.1369047611951828,
   0.07720588147640228,
   0.07107843458652496,
   0.02862069010734558,
   0.15555556118488312,
   0.0875576063990593,
   0.2666666805744171,
   0.010125046595931053,
   0.3499999940395355,
   0.0572916679084301,
   0.10862068831920624,
   0.028679654002189636,
   0.21212121844291687,
   0.4166666865348816,
   0.4166666865348816,
   0.4166666865348816,
   0.22499999403953552,
   0.02814938686788082,
   0.006004140712320805,
   0.13333334028720856,
   0.17499999701976776,
   0.5,
   0.2083333432674408,
   0.13846154510974884,
   0.22499999403953552,
   0.5,
   0.0455

In [55]:
23067 in g.nodes

False

In [1]:
import argparse
import json
import os
import pickle as pkl
import numpy as np

import torch

import models
from datasets.kg_dataset import KGDataset
from utils.train import avg_both, format_metrics


data_dir = os.path.join('datasets', 'data', 'FB237-exp3-4','Train')

train_data_dir = os.path.join(data_dir, 'train.pickle')
valid_data_dir = os.path.join(data_dir, 'valid.pickle')
test_data_dir = os.path.join(data_dir, 'test.pickle')

with open(train_data_dir, "rb") as in_file:
        train_examples = pkl.load(in_file)

with open(valid_data_dir, "rb") as in_file:
        valid_examples = pkl.load(in_file)

with open(test_data_dir, "rb") as in_file:
        test_examples = pkl.load(in_file)


data_size = np.max(train_examples, axis= 0)
n_entity = int(max(data_size[0], data_size[2]) + 1)
n_relation = int(data_size[1] + 1)

print('#training triples: ', len(train_examples))
print('#valid triples: ', len(valid_examples))
print('#test triples: ', len(test_examples))
print('#entities: ', n_entity)
print('#relations: ', n_relation)

train_facts = {} # each query (h,r) has a list of tails [t]: {(h,r): [t]}
ent_rel = {} # each entity has a set of rel: {ent: {set}}
ent_ent = {} # each connecting entity pair (ordered) has a rel: {(ent1, ent2): rel}
for triple in train_examples:
    head, rel, tail = triple
    if (head, rel) not in train_facts:
        train_facts[(head, rel)] = [tail]
    else:
        train_facts[(head, rel)].append(tail)

    if (tail, rel+n_relation) not in train_facts:
        train_facts[(tail, rel+n_relation)] = [head]
    else:
        train_facts[(tail, rel+n_relation)].append(head)

    if head not in ent_rel:
        ent_rel[head] = {rel}
    else:
        ent_rel[head].add(rel)
    if tail not in ent_rel:
        ent_rel[tail] = {rel+n_relation}
    else:
        ent_rel[tail].add(rel+n_relation)

    ent_ent[(head, tail)] = rel

  from .autonotebook import tqdm as notebook_tqdm


#training triples:  254916
#valid triples:  17237
#test triples:  20101
#entities:  14541
#relations:  237


In [2]:
symmertry_pattern = {} # {rel: [{ent1, ent2}]}
inversion_pattern = {} # {(rel1, rel2): [{ent1, ent2}]}
composition_pattern = {} # {(rel1, rel2, rel3): [(ent1, ent2, ent3)]}
transitivity_pattern = {} # {rel: [{ent1, ent2, ent3}] }
for query in train_facts:
    h, r = query
    for t in train_facts[query]:
        if (t, r) in train_facts:
            for t2 in train_facts[(t, r)]:
                if h == t2:
                    if r not in symmertry_pattern:
                        symmertry_pattern[r] = [{h,t}]
                    elif {h,t} not in symmertry_pattern[r]:
                        symmertry_pattern[r].append({h,t})
                elif h != t and t != t2 and t2 in train_facts[query]:
                    if r not in transitivity_pattern:
                        transitivity_pattern[r] = [(h,t,t2)]
                    else:
                        transitivity_pattern[r].append((h,t,t2))
                

        for rel in ent_rel[t]:
            if rel != r:
                if abs(rel-r) !=n_relation:
                    if h in train_facts[(t, rel)]:
                        r_pair = (r,rel) if r<rel else (rel,r)
                        if r_pair not in inversion_pattern:
                            inversion_pattern[r_pair] = [{h,t}]
                        elif {h,t} not in inversion_pattern[r_pair]:
                            inversion_pattern[r_pair].append({h,t})

                for t2 in train_facts[(t, rel)]:
                    if (h, t2) in ent_ent:
                        r_tuple = (r, rel, ent_ent[(h, t2)])
                        if r_tuple not in composition_pattern:
                            composition_pattern[r_tuple] = [(h,t,t2)]
                        else:
                            composition_pattern[r_tuple].append((h,t,t2))

# for query in train_facts:
#     h, r = query
#     for t in train_facts[query]:
#         if t != h:
#             if (t,r) in train_facts:
#                 for t2 in train_facts[(t,r)]:
#                     if t2 != h and t2 != t and t2 in train_facts[query]:
#                         if r not in symmertry_pattern:
#                             transitivity_pattern[r] = [(h,t,t2)]
#                         else:
#                             transitivity_pattern[r].append((h,t,t2))
                
symmertry_pattern_num = {} # {rel: num}
transitivity_pattern_num = {} # {rel: num}
inversion_pattern_num = {} # {(rel1, rel2): num}
composition_pattern_num = {} # {(rel1, rel2, rel3): num}
for rel in symmertry_pattern:
    symmertry_pattern_num[rel] = len(symmertry_pattern[rel])
for rel in transitivity_pattern:
    transitivity_pattern_num[rel] = len(transitivity_pattern[rel])
for rel_pair in inversion_pattern:
    inversion_pattern_num[rel_pair] = len(inversion_pattern[rel_pair])
for rel_tuple in composition_pattern:
    composition_pattern_num[rel_tuple] = len(composition_pattern[rel_tuple])
                            

In [3]:
symmertry_pattern_test_num = {}
transitivity_pattern_test_num = {}
inversion_pattern_test_num = {} # {rel1:num, rel2: num}
composition_pattern_test_num = {} # {rel3: num}
rel_test_num = {}

for triple in test_examples:
    r = triple[1]
    if r not in rel_test_num:
        rel_test_num[r] = 1
        rel_test_num[r+n_relation] = 1
    else:
        rel_test_num[r] += 1
        rel_test_num[r+n_relation] += 1


for rel in symmertry_pattern:
    if rel not in symmertry_pattern_test_num and rel in rel_test_num:
        symmertry_pattern_test_num[rel] = rel_test_num[rel]


for rel in transitivity_pattern:
    if rel not in transitivity_pattern_test_num and rel in rel_test_num:
        transitivity_pattern_test_num[rel] = rel_test_num[rel]

for rel_pair in inversion_pattern:
    r1, r2 = rel_pair
    if r1 not in inversion_pattern_test_num and r1 in rel_test_num:
        inversion_pattern_test_num[r1] = rel_test_num[r1]
    if r2 not in inversion_pattern_test_num and r2 in rel_test_num:
        inversion_pattern_test_num[r2] = rel_test_num[r2]

for rel_tuple in composition_pattern:
    _, _, r = rel_tuple
    if r not in composition_pattern_test_num and r in rel_test_num:
        composition_pattern_test_num[r] = rel_test_num[r]

In [4]:
symmertric_num = len(symmertry_pattern_num.keys())
avg_symmertric_comb = np.sum(list(symmertry_pattern_num.values()))/symmertric_num

transitivity_num = len(transitivity_pattern_num.keys())
avg_transitivity_comb = np.sum(list(transitivity_pattern_num.values()))/transitivity_num

inversion_num = len(inversion_pattern_num.keys())
avg_inversion_comb = np.sum(list(inversion_pattern_num.values()))/inversion_num

composition_num = len(composition_pattern_num.keys())
avg_compos_comb = np.sum(list(composition_pattern_num.values()))/composition_num

print('symmertric_num: ', symmertric_num, 'avg_symmertric_comb: ', avg_symmertric_comb)
print('transitivity_num: ', transitivity_num, 'avg_transitivity_comb: ', avg_transitivity_comb)
print('inversion_num: ', inversion_num, 'avg_inversion_comb: ', avg_inversion_comb)
print('composition_num: ', composition_num, 'avg_compos_comb: ', avg_compos_comb)

symmertric_num:  86 avg_symmertric_comb:  282.2325581395349
transitivity_num:  70 avg_transitivity_comb:  5391.571428571428
inversion_num:  1158 avg_inversion_comb:  89.29188255613126
composition_num:  24535 avg_compos_comb:  140.9808844507846


In [5]:
symmertric_test_rel_num = len(symmertry_pattern_test_num.keys())
symmertric_test_num = np.sum(list(symmertry_pattern_test_num.values()))
avg_symmertric_test_query_num = symmertric_test_num/symmertric_test_rel_num

transitivity_test_rel_num = len(transitivity_pattern_test_num.keys())
transitivity_test_num = np.sum(list(transitivity_pattern_test_num.values()))
avg_transitivity_test_query_num = transitivity_test_num/transitivity_test_rel_num

inversion_test_rel_num = len(inversion_pattern_test_num.keys())
inversion_test_num = np.sum(list(inversion_pattern_test_num.values()))
avg_inversion_test_query_num = inversion_test_num/inversion_test_rel_num

composition_test_rel_num = len(composition_pattern_test_num.keys())
composition_test_num = np.sum(list(composition_pattern_test_num.values()))
avg_compos_test_query_num = composition_test_num/composition_test_rel_num



print('symmertric_test_rel_num: ', symmertric_test_rel_num, 
'avg_symmertric_test_query_num: ', avg_symmertric_test_query_num,
' total related query_num: ', symmertric_test_num)
print('transitivity_test_rel_num: ', transitivity_test_rel_num, 'avg_transitivity_test_query_num: ', avg_transitivity_test_query_num,
' total related query_num: ', transitivity_test_num)
print('inversion_test_rel_num: ', inversion_test_rel_num, 'avg_inversion_test_query_num: ', avg_inversion_test_query_num,
' total related query_num: ', inversion_test_num)
print('composition_test_rel_num: ', composition_test_rel_num, 'avg_compos_test_query_num: ', avg_compos_test_query_num,
' total related query_num: ', composition_test_num)

symmertric_test_rel_num:  76 avg_symmertric_test_query_num:  81.34210526315789  total related query_num:  6182
transitivity_test_rel_num:  64 avg_transitivity_test_query_num:  62.78125  total related query_num:  4018
inversion_test_rel_num:  332 avg_inversion_test_query_num:  91.855421686747  total related query_num:  30496
composition_test_rel_num:  221 avg_compos_test_query_num:  90.83257918552036  total related query_num:  20074


In [181]:
stats_file = os.path.join(data_dir, 'patternStats.txt')

if os.path.exists(stats_file):
    os.remove(stats_file)

with open(stats_file, "a") as f:
    f.write('#entiy: {}\t#relation: {}\n'.format(n_entity, n_relation))
    f.write('#training triples: {}\t#valid triples: {}\t#test triples: {}\n'.format(len(train_examples), len(valid_examples), len(test_examples)))
    f.write('In Training:\n')
    f.write('\t#Symmertric Pattern: {}\t\t#Average supported triple-combination: {}\n'.format(symmertric_num, avg_symmertric_comb))
    f.write('\t#Transitivity Pattern: {}\t\t#Average supported triple-combination: {}\n'.format(transitivity_num, avg_transitivity_comb))
    f.write('\t#Inversion Pattern: {}\t\t#Average supported triple-combination: {}\n'.format(inversion_num, avg_inversion_comb))
    f.write('\t#Composition Pattern: {}\t\t#Average supported triple-combination: {}\n\n'.format(composition_num, avg_compos_comb))

    f.write('In Test:\n')
    f.write('\t#Symmertric Relation: {}\t\t#Average symmertry-related query: {}\t\t#Total related query: {}\n'
    .format(symmertric_test_rel_num, avg_symmertric_test_query_num, symmertric_test_num))
    f.write('\t#Transitiv Relation: {}\t\t#Average transitiv-related query: {}\t\t#Total related query: {}\n'
    .format(transitivity_test_rel_num, avg_transitivity_test_query_num, transitivity_test_num))
    f.write('\t#Inversion Relation: {}\t\t#Average inversion-related query: {}\t\t#Total related query: {}\n'
    .format(inversion_test_rel_num, avg_inversion_test_query_num, inversion_test_num))
    f.write('\t#Composed Relations: {}\t\t#Average composition-related query: {}\t\t#Total related query: {}\n'
    .format(composition_test_rel_num, avg_compos_test_query_num, composition_test_num))

In [26]:
ranks = under_rep_rel_rnk['rhs'] + under_rep_rel_rnk['lhs']+other_rel_rnk['rhs']+other_rel_rnk['lhs']
print('#total test triples: ', len(ranks)//2)
print('#total query: ', len(ranks))
print('total MRR: ', np.mean([1/i for i in ranks]))

#total test triples:  20466
#total query:  40932
total MRR:  0.29181558


In [27]:
print('\t#rhs-query with under-represented rel: ', len(under_rep_rel_rnk['rhs']))
print('\t#lhs-query with under-represented rel: ', len(under_rep_rel_rnk['lhs']))
print('#query with under-represented rel: ', len(under_rep_rel_rnk['rhs'])+len(under_rep_rel_rnk['lhs']))

print('===============================')
print('\trhs MRR of under-represented rel: ', np.mean([1/i for i in under_rep_rel_rnk['rhs']]))
print('\tlhs MRR of under-represented rel: ', np.mean([1/i for i in under_rep_rel_rnk['lhs']]))
print('MRR of under-represented rel: ', np.mean([1/i for i in under_rep_rel_rnk['rhs']+under_rep_rel_rnk['lhs']]))

print('===============================')
print('\trhs MRR of other rel: ', np.mean([1/i for i in other_rel_rnk['rhs']]))
print('\tlhs MRR of other rel: ', np.mean([1/i for i in other_rel_rnk['lhs']]))
print('MRR of other rel: ', np.mean([1/i for i in other_rel_rnk['rhs']+other_rel_rnk['lhs']]))

	#rhs-query with under-represented rel:  1503
	#lhs-query with under-represented rel:  1503
#query with under-represented rel:  3006
	rhs MRR of under-represented rel:  0.34478772
	lhs MRR of under-represented rel:  0.020680781
MRR of under-represented rel:  0.18273427
	rhs MRR of other rel:  0.38739136
	lhs MRR of other rel:  0.21353129
MRR of other rel:  0.30046135


In [3]:
# to verify the splitted mrrs correspond to total mrr 
ranks = []
for freq in B_freq_rnk:
    ranks += B_freq_rnk[freq]['lhs']
    ranks += B_freq_rnk[freq]['rhs']
print(len(ranks))
rr = [1/i for i in ranks]
print(np.mean(rr))
print(len(test_examples))

6268
0.4859449
3134


In [4]:
# to get the freq and mrr lists, in order to make the scatter figure below (Indicate the relation between freq and mrr )
B_rhs_freq = []
B_rhs_mrr = []
B_lhs_freq = []
B_lhs_mrr = []

for freq in B_freq_rnk:
    rhs_rr = np.array([1/i for i in B_freq_rnk[freq]['rhs']])
    lhs_rr = np.array([1/i for i in B_freq_rnk[freq]['lhs']])
    if len(rhs_rr) != 0:
        B_rhs_freq.append(freq)
        B_rhs_mrr.append(np.mean(rhs_rr))
    if len(lhs_rr) != 0:
        B_lhs_freq.append(freq)
        B_lhs_mrr.append(np.mean(lhs_rr))  

In [13]:
with open(os.path.join(model_dir, 'B_freq_rnk.pickle'), 'wb') as save_file:
    pkl.dump(B_freq_rnk, save_file)

In [15]:
# set bias to 0 manually, if the model is trained with constant bias, then the following is not needed
ub_model = getattr(models, args.model)(args)
device = 'cuda'
ub_model.to(device)
ub_model.load_state_dict(torch.load(os.path.join(model_dir, 'model.pt')))
bt_shape = model.bt.weight.shape
ub_model.bt.weight.data = torch.zeros(bt_shape).cuda()
ub_model.bt.weight.data


ub_freq_rnk = {}
rhs_ranks = ub_model.get_ranking(test_examples, filters['rhs'], 100)

lhs_test_examples = test_examples.clone()
tmp = torch.clone(lhs_test_examples[:, 0])
lhs_test_examples[:, 0] = lhs_test_examples[:, 2]
lhs_test_examples[:, 2] = tmp
lhs_test_examples[:, 1] += model.sizes[1] // 2

lhs_ranks = ub_model.get_ranking(lhs_test_examples, filters['lhs'], 100)


for t_f, h_f, rhs_r, lhs_r in zip(tail_freq, head_freq, rhs_ranks, lhs_ranks):
    if t_f not in ub_freq_rnk:
        ub_freq_rnk[t_f] = {'rhs':[], 'lhs':[]}
    if h_f not in ub_freq_rnk:
        ub_freq_rnk[h_f] = {'rhs':[], 'lhs':[]}
    ub_freq_rnk[t_f]['rhs'].append(rhs_r)
    ub_freq_rnk[h_f]['lhs'].append(lhs_r)

In [16]:
with open(os.path.join(model_dir, 'ub_freq_rnk.pickle'), 'wb') as save_file:
    pkl.dump(ub_freq_rnk, save_file)

In [17]:
# to get the freq and mrr lists, in order to make the scatter figure below (Indicate the relation between freq and mrr )
uB_rhs_freq = []
uB_rhs_mrr = []
uB_lhs_freq = []
uB_lhs_mrr = []

for freq in ub_freq_rnk:
    rhs_rr = np.array([1/i for i in ub_freq_rnk[freq]['rhs']])
    lhs_rr = np.array([1/i for i in ub_freq_rnk[freq]['lhs']])
    if len(rhs_rr) != 0:
        uB_rhs_freq.append(freq)
        uB_rhs_mrr.append(np.mean(rhs_rr))
    if len(lhs_rr) != 0:
        uB_lhs_freq.append(freq)
        uB_lhs_mrr.append(np.mean(lhs_rr))   

In [19]:
# to verify the splitted mrrs correspond to total mrr 
ranks = []
for freq in ub_freq_rnk:
    ranks += ub_freq_rnk[freq]['lhs']
    ranks += ub_freq_rnk[freq]['rhs']
print(len(ranks))
rr = [1/i for i in ranks]
print(np.mean(rr))
print(len(test_examples))

6268
0.46353686
3134


In [20]:
# get the mrr of splitted test data, the threshold and freq_rnk are changable 
# freq_rnk = B_freq_rnk
freq_rnk = ub_freq_rnk 
# B_freq_rnk

threshold = 2*mean

rare_entity_ratio = np.mean([1 if i<=threshold else 0 for i in ent_freq])

rare_rhs_r = []
rare_lhs_r = []
common_rhs_r = []
common_lhs_r = []

for freq in freq_rnk:
    if freq <= threshold:
        rare_rhs_r += freq_rnk[freq]['rhs']
        rare_lhs_r += freq_rnk[freq]['lhs']
    else:
        common_rhs_r += freq_rnk[freq]['rhs']
        common_lhs_r += freq_rnk[freq]['lhs']

rare_rhs_rr = np.array([1/i for i in rare_rhs_r])
rare_lhs_rr = np.array([1/i for i in rare_lhs_r])
common_rhs_rr = np.array([1/i for i in common_rhs_r])
common_lhs_rr = np.array([1/i for i in common_lhs_r])

rare_rhs_mrr = np.mean(rare_rhs_rr)
rare_lhs_mrr = np.mean(rare_lhs_rr)
common_rhs_mrr = np.mean(common_rhs_rr)
common_lhs_mrr = np.mean(common_lhs_rr)

print(len(rare_rhs_rr), 'of rare_rhs_queries, has rare_rhs_mrr: ', rare_rhs_mrr)
print(len(rare_lhs_rr), 'of rare_lhs_queries, has rare_lhs_mrr: ', rare_lhs_mrr)
print(len(common_rhs_rr), 'of common_rhs_queries, has common_rhs_mrr: ', common_rhs_mrr)
print(len(common_lhs_rr), 'of common_rhs_queries, has common_lhs_mrr: ', common_lhs_mrr)
print('rare_entity_ratio: ', rare_entity_ratio)
print('rare_query_ratio: ', (len(rare_rhs_rr) + len(rare_lhs_rr))/(len(rare_rhs_rr) + len(rare_lhs_rr) + len(common_rhs_rr) + len(common_lhs_rr)))


2040 of rare_rhs_queries, has rare_rhs_mrr:  0.5103688
2680 of rare_lhs_queries, has rare_lhs_mrr:  0.42658558
1094 of common_rhs_queries, has common_rhs_mrr:  0.42069843
454 of common_rhs_queries, has common_lhs_mrr:  0.57445663
rare_entity_ratio:  0.9444349835786401
rare_query_ratio:  0.7530312699425654


In [None]:
import matplotlib.pyplot as plt

plt.scatter(B_rhs_freq, B_rhs_mrr, c='b', s=2, label= 'B_rhs')
plt.scatter(B_lhs_freq, B_lhs_mrr, c='r', s=2, label= 'B_lhs')
plt.scatter(uB_rhs_freq, uB_rhs_mrr, c='k', s=1, label= 'nB_rhs')
plt.scatter(uB_lhs_freq, uB_lhs_mrr, c='y', s=1, label= 'nB_lhs')
plt.legend()
plt.savefig(os.path.join(model_dir, 'freq_mrr_w_woB.png'))




In [None]:

plt.scatter(B_rhs_freq, B_rhs_mrr, c='b', s=1, label= 'B_rhs')
plt.scatter(B_lhs_freq, B_lhs_mrr, c='r', s=1, label= 'B_lhs')
plt.scatter(uB_rhs_freq, uB_rhs_mrr, c='k', s=1, label= 'nB_rhs')
plt.scatter(uB_lhs_freq, uB_lhs_mrr, c='y', s=1, label= 'nB_lhs')
plt.xlim(0, threshold)
plt.legend()

In [None]:
# test_name = 'test_t0_h0'
# test_path = os.path.join(dataset_path, 'test', test_name+'.pickle')
# with open(test_path, "rb") as in_file:
#         test_split = pkl.load(in_file)

# test_split = []
# for triple in test_examples:
#     head, rel, tail = triple.tolist()
#     if ent_freq[tail] == 1 and ent_freq[head]  == 1:
#         test_split.append([head, rel, tail])
# test_split = np.array(test_split)

# test_split = torch.from_numpy(test_split.astype("int64"))

In [None]:
test_split_list = [test_t0_h0, test_t0_h1, test_t1_h0, test_t1_h1]
test_split_name_list = ['test_t0_h0', 'test_t0_h1', 'test_t1_h0', 'test_t1_h1']


for i in list(range(2,4)):
    test_split = test_split_list[i]
    test_split_name = test_split_name_list[i]
    mean_rank = {}
    mean_reciprocal_rank = {}
    hits_at = {}
    ranks_total = {}

    for m in ["rhs", "lhs"]:
        q = test_split.clone()
        if m == "lhs":
            tmp = torch.clone(q[:, 0])
            q[:, 0] = q[:, 2]
            q[:, 2] = tmp
            q[:, 1] += model.sizes[1] // 2
        ranks = model.get_ranking(q, filters[m],10)
        mean_rank[m] = torch.mean(ranks).item()
        mean_reciprocal_rank[m] = torch.mean(1. / ranks).item()
        hits_at[m] = torch.FloatTensor((list(map(
            lambda x: torch.mean((ranks <= x).float()).item(),
            (1, 3, 10)
        ))))
        ranks_total[m] = ranks

    # stats_file = os.path.join(model_dir, 'results_{}.txt'.format(test_name))
    # with open(stats_file, "a") as f:
    #     f.write('ranks (rhs): {}'.format(ranks_total['rhs'].tolist()))
    #     f.write('\n********************\n********************\n********************\n')
    #     f.write('mean_rank: {}, mean_reciprocal_rank: {}, hits_at[1/3/10]: {} \n\n'\
    #         .format(mean_rank['rhs'], mean_reciprocal_rank['rhs'], hits_at['rhs']))
    #     f.write('=======================================\n')
    #     f.write('=======================================\n')
    #     f.write('=======================================\n')
    #     f.write('=======================================\n')
    #     f.write('=======================================\n\n\n')
    #     f.write('reversed ranks (lhs): {}'.format(ranks_total['lhs'].tolist()))
    #     f.write('\n********************\n********************\n********************\n')
    #     f.write('mean_rank: {}, mean_reciprocal_rank: {}, hits_at[1/3/10]: {} \n'\
    #         .format(mean_rank['lhs'], mean_reciprocal_rank['lhs'], hits_at['lhs']))

    mr = (mean_rank['lhs'] + mean_rank['rhs']) / 2.
    mrr = (mean_reciprocal_rank['lhs'] + mean_reciprocal_rank['rhs']) / 2.
    h = (hits_at['lhs'] + hits_at['rhs']) / 2.
    test_metrics = {'MR': mr, 'MRR': mrr, 'hits@[1,3,10]': h}
    print('In test split: ', test_split_name)
    print('\t mrr rhs: ', mean_reciprocal_rank['rhs'], ' and mrr lhs: ', mean_reciprocal_rank['lhs'])
    print('\t hit at rhs: ', hits_at['rhs'], ' and hit at lhs: ', hits_at['lhs'])

    print(format_metrics(test_metrics, split='test'))

In [None]:
for i in range(4):
    test_split = test_split_list[i]
    test_split_name = test_split_name_list[i]
    mean_rank = {}
    mean_reciprocal_rank = {}
    hits_at = {}
    ranks_total = {}

    for m in ["rhs", "lhs"]:
        q = test_split.clone()
        if m == "lhs":
            tmp = torch.clone(q[:, 0])
            q[:, 0] = q[:, 2]
            q[:, 2] = tmp
            q[:, 1] += model_noB.sizes[1] // 2
        ranks = model_noB.get_ranking(q, filters[m],10)
        mean_rank[m] = torch.mean(ranks).item()
        mean_reciprocal_rank[m] = torch.mean(1. / ranks).item()
        hits_at[m] = torch.FloatTensor((list(map(
            lambda x: torch.mean((ranks <= x).float()).item(),
            (1, 3, 10)
        ))))
        ranks_total[m] = ranks

    mr = (mean_rank['lhs'] + mean_rank['rhs']) / 2.
    mrr = (mean_reciprocal_rank['lhs'] + mean_reciprocal_rank['rhs']) / 2.
    h = (hits_at['lhs'] + hits_at['rhs']) / 2.
    test_metrics = {'MR': mr, 'MRR': mrr, 'hits@[1,3,10]': h}
    print('In test split: ', test_split_name)
    print('\t mrr rhs: ', mean_reciprocal_rank['rhs'], ' and mrr lhs: ', mean_reciprocal_rank['lhs'])
    print('\t hit at rhs: ', hits_at['rhs'], ' and hit at lhs: ', hits_at['lhs'])

    print(format_metrics(test_metrics, split='test'))

In [None]:

test_split = []
for triple in test_examples:
    head, rel, tail = triple.tolist()
    if ent_freq[tail] == 1 and ent_freq[head]  == 1:
        test_split.append([head, rel, tail])
test_split = np.array(test_split)

test_split = torch.from_numpy(test_split.astype("int64"))

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

plt.figure()
plt.scatter(ent_freq, bh,s=1)
# plt.scatter(x,y)
# plt.xlim(0,200)
# plt.savefig('bias head.png')
plt.show()

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

plt.figure()
plt.scatter(ent_freq, bt,s=1)
# plt.scatter(x,y)
# plt.xlim(0,200)
# plt.savefig('bias tail.png')
plt.show()