In [4]:
import numpy as np
import os
import torch

In [5]:
data_path = '../data/wn18rr'
def read_triple(file_path, entity2id, relation2id):
    '''
    Read triples and map them into ids.
    '''
    triples = []
    with open(file_path) as fin:
        for line in fin:
            h, r, t = line.strip().split('\t')
            triples.append((entity2id[h], relation2id[r], entity2id[t]))
    return triples

with open(os.path.join(data_path, 'entities.dict')) as fin:
    entity2id = dict()
    for line in fin:
        eid, entity = line.strip().split('\t')
        entity2id[entity] = int(eid)

with open(os.path.join(data_path, 'relations.dict')) as fin:
    relation2id = dict()
    for line in fin:
        rid, relation = line.strip().split('\t')
        relation2id[relation] = int(rid)
            
train_triples = read_triple(os.path.join(data_path, 'train.txt'), entity2id, relation2id)
valid_triples = read_triple(os.path.join(data_path, 'valid.txt'), entity2id, relation2id)
test_triples = read_triple(os.path.join(data_path, 'test.txt'), entity2id, relation2id)

# triples = train_triples + valid_triples + test_triples
triples = train_triples # test_triples
triples = torch.LongTensor(triples)


# Categorize relations

In [6]:
num_relations = len(relation2id)

one_many, one_one, many_one, many_many = 0., 0., 0., 0.
one_many_num, one_one_num, many_one_num, many_many_num = 0., 0., 0., 0.

many_thresh = 1.1

relation_dict = {}

for i in range(num_relations):
    relation_mask = (triples[:, 1] == i)
    if torch.sum(relation_mask) == 0:
        relation_dict[list(relation2id.keys())[i]] = 'None'
        continue
    head = triples[relation_mask, 0].data.tolist()
    tail = triples[relation_mask, 2].data.tolist()
    head = set(head)
    tail = set(tail)
    
    pairs = triples[relation_mask, :]
    pairs_tail = pairs[:, 2].unsqueeze(1).expand(-1, len(tail))
    tensor_tail = torch.Tensor(list(tail)).view(1, len(tail))
    n_heads = (tensor_tail == pairs_tail).sum(dim=0)
    avg_head = torch.mean(n_heads.float())
    
    pairs_head = pairs[:, 0].unsqueeze(1).expand(-1, len(head))
    tensor_head = torch.Tensor(list(head)).view(1, len(head))
    n_tails = (tensor_head == pairs_head).sum(dim=0)
    avg_tail = torch.mean(n_tails.float())
    
    n = torch.sum(relation_mask).item()
    if avg_head > many_thresh:
        if avg_tail > many_thresh:
            cat = 'M-M'
            many_many += 1
            many_many_num += n
        else:
            cat = 'M-1'
            many_one += 1
            many_one_num += n
    else:
        if avg_tail > many_thresh:
            cat = '1-M'
            one_many += 1
            one_many_num += n
        else:
            cat = '1-1'
            one_one += 1
            one_one_num += n
    
    relation_dict[list(relation2id.keys())[i]] = cat
    print(i, list(relation2id.keys())[i], cat, n, avg_head.item(), avg_tail.item())
    
        
        
    
    

0 _hypernym M-1 34796 3.6627368927001953 1.0224194526672363
1 _derivationally_related_form M-M 29715 1.8446210622787476 1.8454229831695557
2 _instance_hypernym M-M 2921 7.230197906494141 1.18450927734375
3 _also_see M-M 1299 1.6505718231201172 1.8373408317565918
4 _member_meronym 1-M 7402 1.0084468126296997 2.391599416732788
5 _synset_domain_topic_of M-1 3116 10.084142684936523 1.0484522581100464
6 _has_part M-M 4816 1.2070175409317017 2.4347825050354004
7 _member_of_domain_usage 1-M 629 1.058922529220581 25.15999984741211
8 _member_of_domain_region 1-M 923 1.0572737455368042 8.096490859985352
9 _verb_group M-M 1138 1.1612244844436646 1.1635991334915161
10 _similar_to 1-1 80 1.0526316165924072 1.0389610528945923


In [13]:
relation_dict
print(one_many, one_one, many_one, many_many)
print(one_many_num, one_one_num, many_one_num, many_many_num)
print(one_many_num + one_one_num + many_one_num + many_many_num)
print(len(test_triples))

4.0 4.0 3.0 0.0
475.0 1172.0 1487.0 0.0
3134.0
3134


In [14]:
with open(os.path.join(data_path, 'relation_category.txt'), 'w') as f:
    for key, value in relation_dict.items():
        f.write('%s\t%s\n' % (key, value))