In [1]:
import torch
import os

In [16]:
data_path = '../data/wn18rr_small'

In [46]:
def read_triple(file_path, entity2id, relation2id):
    '''
    Read triples and map them into ids.
    '''
    triples = []
    with open(file_path) as fin:
        for line in fin:
            h, r, t = line.strip().split('\t')
            triples.append((entity2id[h], relation2id[r], entity2id[t]))
    return triples

with open(os.path.join(data_path, 'entities.dict')) as fin:
    entity2id = dict()
    for line in fin:
        eid, entity = line.strip().split('\t')
        entity2id[entity] = int(eid)

with open(os.path.join(data_path, 'relations.dict')) as fin:
    relation2id = dict()
    for line in fin:
        rid, relation = line.strip().split('\t')
        relation2id[relation] = int(rid)

train_triples = read_triple(os.path.join(data_path, 'train.txt'), entity2id, relation2id)
train_triples = torch.LongTensor(train_triples)

valid_triples = read_triple(os.path.join(data_path, 'valid.txt'), entity2id, relation2id)
valid_triples = torch.LongTensor(valid_triples)

test_triples = read_triple(os.path.join(data_path, 'test.txt'), entity2id, relation2id)
test_triples = torch.LongTensor(test_triples)
triples = train_triples

rid2cid = dict()
category2id = {'1-1': 0, '1-M': 1, 'M-1': 2, 'M-M': 3, 'None': -1}

with open(os.path.join(data_path, 'relation_category.txt')) as fin:
    for line in fin:
        relation, category = line.strip().split('\t')
        rid2cid[relation2id[relation]] = category2id[category]


In [47]:
# Create a column for storing relation category
category = torch.zeros(triples.shape[0], 1).long()
for i in range(triples.shape[0]):
    category[i] = rid2cid[triples[i, 1].item()]
triples = torch.cat([triples, category], dim=-1)
print(triples.shape)

torch.Size([1247, 4])


In [48]:
replaced_selected_valid_triples = torch.load('replaced_selected_valid_triples.pt')
replaced_selected_triples = torch.load('replaced_selected_triples.pt')
replaced_selected_test_triples = torch.load('replaced_selected_test_triples.pt')
print(torch.equal(replaced_selected_test_triples, triples[:, 0:3]))
print(replaced_selected_valid_triples)
print(triples)

True
tensor([[23654,     0, 11345],
        [13957,     0, 22153],
        [ 4873,     0,  7817],
        ...,
        [12018,     0,   715],
        [ 9552,     0,  9551],
        [16557,     0,  6297]])
tensor([[13267,     4, 23498,     0],
        [ 3239,     0,  5350,     0],
        [17637,     0, 17843,     0],
        ...,
        [23467,     2, 22667,     1],
        [18872,     0,  9297,     0],
        [ 5176,     0,  5175,     0]])


# Search for 1-to-N

In [20]:
mode = '1-to-N'

maximum = 0
for relation_id in range(len(relation2id)):
    mask = torch.nonzero(triples[:, 1] == relation_id).squeeze(1)
    print(relation_id, rid2cid[relation_id], torch.nonzero(mask).shape[0])


0 2 34795
1 0 29715
2 2 2921
3 0 1299
4 1 7402
5 2 3116
6 1 4816
7 1 629
8 1 923
9 0 1138
10 0 80


# Select train triples

In [5]:
# relation_id = 

relation_mask = (triples[:, 1] == 4) | (triples[:, 1] == 7)
one_one_mask = triples[:, 3] == 0
mask = relation_mask | one_one_mask

selected_triples = triples[torch.nonzero(mask).squeeze(1), ...]
print(selected_triples.shape)

# print(triples[nonzero_relation_mask[0:100], :])

# selected_indices = nonzero_relation_mask.tolist()

# for i in range(nonzero_relation_mask.shape[0]):
    
#     head_id = triples[nonzero_relation_mask[i], 0]
#     # print(i, nonzero_relation_mask[i], head_id)

#     mask = (triples[:, 0] == head_id) & relation_mask
#     nonzero = torch.nonzero(mask)
#     tails_id = triples[nonzero, 2]

#     for j in range(tails_id.shape[0]):
#         tail_mask = (triples[:, 2] == tails_id[j]) & (triples[:, 3] == 0)
#         selected_indices += torch.nonzero(tail_mask).tolist()

# print(len(selected_indices))
# selected_indices = torch.LongTensor(selected_indices).squeeze(1)
# selected_triples = triples[selected_indices, :]

torch.Size([40263, 4])


In [6]:
print(selected_triples.shape)
head_set = selected_triples[:, 0].tolist()
tail_set = selected_triples[:, 2].tolist()
entity_set_train = set(head_set + tail_set)
relation_set_train = set(selected_triples[:, 1].tolist())
print(len(entity_set_train), len(relation_set_train))
print(relation_set_train)

torch.Size([40263, 4])
25172 6
{1, 3, 4, 7, 9, 10}


# Select valid and test triples

In [7]:
def select_triples(triples, entity_set, relation_set):
    select_indices = []
    for i in range(triples.shape[0]):
        h, r, t = triples[i, :]
        if (h.item() in entity_set) and (r.item() in relation_set) and (t.item() in entity_set):
            select_indices.append(i)
    select_indices = torch.LongTensor(select_indices)
    selected_triples = triples[select_indices, :]
    return selected_triples

selected_test_triples = select_triples(test_triples, entity_set_train, relation_set_train)
selected_valid_triples = select_triples(valid_triples, entity_set_train, relation_set_train)

print(selected_test_triples.shape, selected_valid_triples.shape)

torch.Size([1247, 3]) torch.Size([1225, 3])


In [8]:
head_set = selected_test_triples[:, 0].tolist()
tail_set = selected_test_triples[:, 2].tolist()
entity_set = set(head_set + tail_set)
relation_set = set(selected_test_triples[:, 1].tolist())
print(len(entity_set), len(relation_set))
print(relation_set)
for i in entity_set:
    assert i in entity_set_train


2303 6
{1, 3, 4, 7, 9, 10}


# Generate Dataset

In [9]:
new_dataset = 'wn18rr_small'

## switch indexes to zero-based

In [15]:
switch_relation_id = {}
switch_entity_id = {}

relation_set_train = list(relation_set_train)
entity_set_train = list(entity_set_train)

for i in range(len(relation_set_train)):
    switch_relation_id[relation_set_train[i]] = i
    print('relation id %d is switch to %d' % (relation_set_train[i], i))
print(switch_relation_id)

for i in range(len(entity_set_train)):
    switch_entity_id[entity_set_train[i]] = i
    # print('entity id %d is switch to %d' % (entity_set_train[i], i))
    
def replace_id(triples, switch_entity_id, switch_relation_id):
    temp = triples.clone()
    for i in range(triples.shape[0]):
        temp[i, 0] = switch_entity_id[triples[i, 0].item()]
        temp[i, 2] = switch_entity_id[triples[i, 2].item()]
        
        temp[i, 1] = switch_relation_id[triples[i, 1].item()]

    return temp

replaced_selected_triples = replace_id(selected_triples, switch_entity_id, switch_relation_id)
replaced_selected_test_triples = replace_id(selected_test_triples, switch_entity_id, switch_relation_id)
replaced_selected_valid_triples = replace_id(selected_valid_triples, switch_entity_id, switch_relation_id)

torch.save(replaced_selected_triples, 'replaced_selected_triples.pt')
torch.save(replaced_selected_test_triples, 'replaced_selected_test_triples.pt')
torch.save(replaced_selected_valid_triples, 'replaced_selected_valid_triples.pt')

relation id 1 is switch to 0
relation id 3 is switch to 1
relation id 4 is switch to 2
relation id 7 is switch to 3
relation id 9 is switch to 4
relation id 10 is switch to 5
{1: 0, 3: 1, 4: 2, 7: 3, 9: 4, 10: 5}


## Generate entities.dict

In [11]:
id2entity = dict()
for key, value in entity2id.items():
    id2entity[value] = key

new_entity2id = dict()
new_id2entity = dict()

count = 0
for i in entity_set_train:
    entity = id2entity[i]
    new_entity2id[entity] = count
    new_id2entity[count] = entity
    count += 1


In [12]:
with open(os.path.join(new_dataset, 'entities.dict'), 'w') as f:
    for key, value in new_id2entity.items():
        f.write('%s\t%s\n' % (key, value))
f.close()

## Generate .txt files

In [13]:
new_id2relation = {
    0: '_derivationally_related_form',
    1: '_also_see',
    2: '_member_meronym',
    3: '_member_of_domain_usage',
    4: '_verb_group',
    5: '_similar_to'
}


def write_txt(triples, id2entity, id2relation, filename='train.txt'):
    with open(filename, 'w') as f:
        for i in range(triples.shape[0]):
            head_id, relation_id, tail_id = triples[i, 0:3]
            # print(head_id.item(), relation_id.item())
            head = id2entity[head_id.item()]
            tail = id2entity[tail_id.item()]
            relation = id2relation[relation_id.item()]
            f.write('%s\t%s\t%s\n' % (head, relation, tail))
    f.close()

write_txt(replaced_selected_triples, new_id2entity, new_id2relation, os.path.join(new_dataset, 'train.txt'))
write_txt(replaced_selected_test_triples, new_id2entity, new_id2relation, os.path.join(new_dataset, 'test.txt'))
write_txt(replaced_selected_valid_triples, new_id2entity, new_id2relation, os.path.join(new_dataset, 'valid.txt'))
            