In [1]:
def rdict_to_set(dict_r):
    set_r = set()
    for head in dict_r:
        for tail in dict_r[head]:
            set_r.add((head,tail))
    return set_r

def rdict_to_reverse_set(dict_r):
    set_reverse_r = set()
    for head in dict_r:
        for tail in dict_r[head]:
            set_reverse_r.add((tail,head))
    return set_reverse_r

def check_symmetric(r, all_relations,portion=0.9):
    '''
    input: r - a key in all_relations, representing a relation r
    all_relations: a dict of relations, each relation is a dict of the form {h_1: [t_1, t_2, ...], h_2:[t_3,t_4,...]}
    portion: similarity parameter, float
    output: True or False
    ---
    check_symmetric will return True if for at least portion of the triplets (h,r,t) in relation r satisfy that 
    the triplet (t,r,h) also appears in r. Otherwise it returns False.
    '''
    if r in all_relations:
        dict_r = all_relations[r]
        set_r = rdict_to_set(dict_r)
        set_reverse_r = rdict_to_reverse_set(dict_r)
        if len(set_reverse_r.intersection(set_r)) >=portion*len(set_r):
            return True
    return False

def check_inversion(r, all_relations,portion=0.9):
    '''
    input: r - a key in all_relations, representing a relation r
    all_relations: a dict of relations, each relation is a dict of the form {h_1: [t_1, t_2, ...], h_2:[t_3,t_4,...]}
    portion: similarity parameter, float
    output: None or an array of keys from all_relations
    ---
    check_inversion will return a key s if the relation s is an inverse of r for at least portion of the triplets in both.
    otherwise it returns None.
    '''
    if r in all_relations:
        r_inv=[]
        dict_r = all_relations[r]
        reversed_set_r = rdict_to_reverse_set(dict_r)
        count_trip_r = len(reversed_set_r)
        relations = all_relations.keys()
        for s in relations:
            dict_s = all_relations[s]
            set_s = rdict_to_set(dict_s)
            count_trip_s = len(set_s)
            count_trip_intersection = len(set_s.intersection(reversed_set_r))
            if count_trip_intersection >= portion*count_trip_s and count_trip_intersection>=portion*count_trip_r:
                r_inv.append(s)
        if len(r_inv)!=0:
            return r_inv
    return None

def check_anti_sym(r, all_relations):
    '''
    input: r - a key in all_relations, representing a relation r
    all_relations: a dict of relations, each relation is a dict of the form {h_1: [t_1, t_2, ...], h_2:[t_3,t_4,...]}
    output: True or False
    ---
    check_anti_sym will return True if the relation r is antisymmetric
    '''
    if r in all_relations:
        dict_r = all_relations[r]
        set_r = rdict_to_set(dict_r)
        reversed_set_r = rdict_to_reverse_set(dict_r)
        if len(set_r.intersection(reversed_set_r))==0:
            return True
    return False

def check_transitivity(r, all_relations, portion=0.9):
    '''
    input: r - a key in all_relations, representing a relation r
    all_relations: a dict of relations, each relation is a dict of the form {h_1: [t_1, t_2, ...], h_2:[t_3,t_4,...]}
    portion: float
    output: True or False
    ---
    check_transitivity will return True if the relation r is transitive (in a non empty way) on at least portion of
    the pairs of triplets (a,r,b) (b,r,c)
    '''
    if r in all_relations:
        set_2_chain_endpoints = find_all_2_chains(all_relations[r])
        set_r = rdict_to_set(all_relations[r])
        if (len(set_2_chain_endpoints.intersection(set_r)) > portion*len(set_2_chain_endpoints) 
        and len(set_2_chain_endpoints.intersection(set_r)) > portion*0.05*len(set_r)):
            return True
    return False

def find_all_2_chains(rdict):
    set_2_chain_endpoints = set()
    for a in rdict:
        for b in rdict[a]:
            if b in rdict:
                for c in rdict[b]:
                    set_2_chain_endpoints.add((a,c))
    return set_2_chain_endpoints

def check_composition(r, s, all_relations, chain_portion=0.8, portion=0.8):
    '''
    input:  r - a key in all_relations, representing a relation r
            s - a key in all_relations, representing a relation s
            all_relations: a dict of relations, each relation is a dict of the form 
            {h_1: [t_1, t_2, ...], h_2:[t_3,t_4,...]}
            chain_portion: float
            portion: float
    output: None or a list of relations
    ---
    check_composition will return relations [u_1, u_2, ...] if applying r and then s is equivalent to applying 
    u = u_i, when at least "chain_portion" of the edges in r are composable and "portion" of the chains of length 2 
    (a, r, b), (b, s, c) also appears in u as (a, u, c). If no such u is found, check_composition will return None.
    '''
    comp = []
    if r in all_relations and s in all_relations:
        dict_r = all_relations[r]
        dict_s = all_relations[s]
        two_chains = find_all_2_chains_two_rel(dict_r, dict_s)
        if len(two_chains)>=chain_portion*(len(rdict_to_set(dict_r))):
            for cand in all_relations:
                set_cand = rdict_to_set(all_relations[cand])
                if len(two_chains.intersection(set_cand)) >=portion*len(two_chains):
                    comp.append(cand)
        if len(comp)!=0:
            return comp
    return None

        
def find_all_2_chains_two_rel(rdict, sdict):
    set_2_chain_endpoints = set()
    for a in rdict:
        for b in rdict[a]:
            if b in sdict:
                for c in sdict[b]:
                    set_2_chain_endpoints.add((a,c))
    return set_2_chain_endpoints

In [2]:
dataset = "FB15K237"
train_file = "../benchmarks/" + dataset +"/"+ "train2id.txt"
valid_file = "../benchmarks/" + dataset +"/"+ "valid2id.txt"
test_file = "../benchmarks/" + dataset +"/"+ "test2id.txt"


file = train_file
triples = open(file, "r")

all_relations_FB = dict()
i=0

for line in triples:
    if i!=0:
        h,t,r = list(map(int, line.strip().split()))
        if r in all_relations_FB:
            if h in all_relations_FB[r]:
                all_relations_FB[r][h].append(t)
            else:
                all_relations_FB[r][h] = [t]
        else:
            all_relations_FB[r] = {h:[t]}
    i+=1

In [3]:
list_of_sym = []
list_of_inv = []
list_of_anti_sym = []
list_of_trans = []
list_of_comp = []

for r in all_relations_FB.keys():
    if check_symmetric(r, all_relations_FB):
        list_of_sym.append(r)
    if check_inversion(r, all_relations_FB) is not None:
        list_of_inv.append((r, check_inversion(r, all_relations_FB)))
    if check_anti_sym(r, all_relations_FB):
        list_of_anti_sym.append(r)    
    if check_transitivity(r, all_relations_FB):
        list_of_trans.append(r)
    for s in all_relations_FB.keys():
        check_comp_r_s=check_composition(r, s, all_relations_FB)
        if check_comp_r_s is not None:
            list_of_comp.append((r,s, check_comp_r_s))

In [4]:
print ("list_of_sym: ", list_of_sym)
print ("list_of_inv: ", list_of_inv)
print ("list_of_anti_sym: ", list_of_anti_sym)
print ("list_of_trans: ", list_of_trans)
print ("list_of_comp: ", list_of_comp)

list_of_sym:  [56, 81, 146]
list_of_inv:  [(56, [56]), (81, [81]), (146, [146])]
list_of_anti_sym:  [0, 1, 2, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 33, 34, 35, 36, 37, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 55, 60, 61, 62, 63, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 76, 78, 79, 80, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 112, 113, 115, 116, 117, 118, 120, 121, 122, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 136, 137, 138, 139, 140, 142, 143, 145, 147, 148, 149, 150, 151, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 176, 177, 179, 181, 183, 184, 185, 186, 187, 189, 191, 193, 194, 195, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 217, 219, 220, 221, 222, 223, 224, 225, 226, 227, 229, 232, 233, 234, 236]
list_of_trans:  [56, 

In [None]:
p = 0.5
list_of_trans = []
for r in all_relations_FB.keys():  
    if check_transitivity(r, all_relations_FB, p):
        list_of_trans.append(r)
        
print(list_of_trans)
len(list_of_trans)

In [12]:
#clean list_of_comp from id's
ids = {56, 81, 146}

cleaned_list_of_comp = []
for item in list_of_comp:
    r,s,_ = item
    if r not in ids and s not in ids:
        cleaned_list_of_comp.append(item)

In [6]:
cleaned_list_of_comp

[(16, 228, [16]),
 (108, 25, [25]),
 (111, 120, [120]),
 (128, 25, [25]),
 (166, 25, [25]),
 (175, 48, [48]),
 (217, 25, [25]),
 (225, 25, [25]),
 (228, 22, [22]),
 (228, 44, [44]),
 (228, 89, [89]),
 (228, 101, [101])]

In [13]:
def total_num_edges_comp(cleaned_list_of_comp, all_relations):
    total_edge_in_r = 0
    total_edge_in_s = 0
    total_two_chain = 0
    total_two_chain_in_intersection = 0
    total_edges_in_r_circ_s = 0
    i=0
    for item in cleaned_list_of_comp:
        r,s,circ = item
        circ = circ[0]
        rset = rdict_to_set(all_relations[r])
        total_edge_in_r+=len(rset)
        sset = rdict_to_set(all_relations[r])
        total_edge_in_s+=len(sset)
        two_chain = find_all_2_chains_two_rel(all_relations[r],all_relations[s])
        total_two_chain+=len(two_chain)
        circ_set = rdict_to_set(all_relations[circ])
        total_edges_in_r_circ_s+=len(circ_set)
        total_two_chain_in_intersection+=len(two_chain.intersection(circ_set))
        i+=1
    
    print("ave # of edges for r: ", total_edge_in_r/i)
    print("ave # of edges for s: ", total_edge_in_s/i)
    print("ave # of two chain: ", total_two_chain/i)
    print("ave # of two chain in intersection: ", total_two_chain_in_intersection/i)
    print("ave # of edges for r_circ_s: ", total_edges_in_r_circ_s/i)
    
    print("tot # of edges for r: ", total_edge_in_r)
    print("tot # of edges for s: ", total_edge_in_s)
    print("tot # of two chain: ", total_two_chain)
    print("tot # of two chain in intersection: ", total_two_chain_in_intersection)
    print("tot # of edges for r_circ_s: ", total_edges_in_r_circ_s)

total_num_edges_comp(cleaned_list_of_comp, all_relations_FB)

ave # of edges for r:  212.83333333333334
ave # of edges for s:  212.83333333333334
ave # of two chain:  215.33333333333334
ave # of two chain in intersection:  185.5
ave # of edges for r_circ_s:  2623.3333333333335
tot # of edges for r:  2554
tot # of edges for s:  2554
tot # of two chain:  2584
tot # of two chain in intersection:  2226
tot # of edges for r_circ_s:  31480


In [15]:
dataset = "WN18RR"
train_file = "../benchmarks/" + dataset +"/"+ "train2id.txt"
valid_file = "../benchmarks/" + dataset +"/"+ "valid2id.txt"
test_file = "../benchmarks/" + dataset +"/"+ "test2id.txt"


file = train_file
triples = open(file, "r")

all_relations = dict()
i=0

for line in triples:
    if i!=0:
        h,t,r = list(map(int, line.strip().split()))
        if r in all_relations:
            if h in all_relations[r]:
                all_relations[r][h].append(t)
            else:
                all_relations[r][h] = [t]
        else:
            all_relations[r] = {h:[t]}
    i+=1
    
    
list_of_sym_wn = []
list_of_inv_wn = []
list_of_anti_sym_wn = []
list_of_trans_wn = []
list_of_comp_wn = []

for r in all_relations.keys():
    if check_symmetric(r, all_relations):
        list_of_sym_wn.append(r)
    if check_inversion(r, all_relations) is not None:
        list_of_inv_wn.append((r, check_inversion(r, all_relations)))
    if check_anti_sym(r, all_relations):
        list_of_anti_sym_wn.append(r)    
    if check_transitivity(r, all_relations):
        list_of_trans_wn.append(r)
    for s in all_relations.keys():
        check_comp_r_s=check_composition(r, s, all_relations)
        if check_comp_r_s is not None:
            list_of_comp_wn.append((r,s, check_comp_r_s))

            
print ("list_of_sym: ", list_of_sym_wn)
print ("list_of_anti_sym: ", list_of_anti_sym_wn)
print ("list_of_inv: ", list_of_inv_wn)
print ("list_of_comp: ", list_of_comp_wn)
print ("list_of_trans: ", list_of_trans_wn)



print ("len list_of_sym: ", len(list_of_sym_wn))
print ("len list_of_anti_sym: ", len(list_of_anti_sym_wn))
print ("len list_of_inv: ", len(list_of_inv_wn))
print ("len list_of_comp: ", len(list_of_comp_wn))
print ("len list_of_trans: ",  len(list_of_trans_wn))

list_of_sym:  [1, 9, 10]
list_of_anti_sym:  [0, 2, 4, 6, 7, 8]
list_of_inv:  [(1, [1]), (9, [9]), (10, [10])]
list_of_comp:  []
list_of_trans:  []
len list_of_sym:  3
len list_of_anti_sym:  6
len list_of_inv:  3
len list_of_comp:  0
len list_of_trans:  0


In [104]:
def total_num_edges(list_of_rel, all_relations):
    total = 0
    for r in list_of_rel:
        total+=len(rdict_to_set(all_relations[r]))
    return total

In [50]:
def generate_test_set_for_rel_type(list_of_rel,test_file_loc, dst):
    #written in the same format as test2id.txt
    test_all_file = open(test_file_loc, 'r') 
    test_all = test_all_file.readlines()
    test_all_file.close()
    lines_to_remove = []
    for i in range(len(test_all)):
        if i!=0:
            h,t,r = list(map(int, test_all[i].strip().split()))
            if r not in list_of_rel:
                lines_to_remove.append(test_all[i])
    for line in lines_to_remove:
        test_all.remove(line)
    test_all[0] = str(len(test_all)-1) + "\n"
    with open(dst, 'w') as f:
        for item in test_all:
            f.write("%s" % item)

In [51]:
dataset = "WN18RR"
test_file = "../benchmarks/" + dataset +"/"+ "test2id.txt"
new_test_file = "../benchmarks/" + dataset +"/"+ "symtest2id.txt"
list_of_rel = list_of_sym_wn
generate_test_set_for_rel_type(list_of_rel, test_file, new_test_file)



In [55]:
new_test_file = "../benchmarks/" + dataset +"/"+ "antitest2id.txt"
list_of_rel = list_of_anti_sym_wn
generate_test_set_for_rel_type(list_of_rel, test_file, new_test_file)

In [56]:
def list_of_rest(list_of_sym, list_of_anti_sym, num_of_relation):
    rest = []
    for i in range(num_of_relation):
        if (i not in list_of_sym) and (i not in list_of_anti_sym):
            rest.append(i)
    return rest

In [57]:
rest_wn = list_of_rest(list_of_sym_wn, list_of_anti_sym_wn, 11)
rest_wn

[3, 5]

In [58]:
new_test_file = "../benchmarks/" + dataset +"/"+ "resttest2id.txt"
list_of_rel = rest_wn
generate_test_set_for_rel_type(list_of_rel, test_file, new_test_file)

In [63]:
dataset = "FB15K237"
test_file = "../benchmarks/" + dataset +"/"+ "test2id.txt"
new_test_file = "../benchmarks/" + dataset +"/"+ "symtest2id.txt"
list_of_rel = list_of_sym
generate_test_set_for_rel_type(list_of_rel, test_file, new_test_file)


In [66]:
new_test_file = "../benchmarks/" + dataset +"/"+ "antitest2id.txt"
list_of_rel = list_of_anti_sym
generate_test_set_for_rel_type(list_of_rel, test_file, new_test_file)

In [67]:
rest = list_of_rest(list_of_sym, list_of_anti_sym, 237)

new_test_file = "../benchmarks/" + dataset +"/"+ "resttest2id.txt"
list_of_rel = rest
generate_test_set_for_rel_type(list_of_rel, test_file, new_test_file)