In [1]:
import copy
import logging
import pickle
import shutil
import time
from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from numpy import genfromtxt
from torch.utils import data as torch_data
#from ..models import TransE, TransH, DistMult

# Module-level logger for this notebook; the root logger is configured to
# let INFO-and-above records through.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

In [2]:
# Read each tab-separated split into an int32 tensor (one row per line of the
# file), then stack the three splits row-wise into a single tensor.
train, valid, test = (
    torch.as_tensor(genfromtxt(split_path, delimiter='\t'), dtype=torch.int32)
    for split_path in ('train2id.txt', 'valid2id.txt', 'test2id.txt')
)
all_triples = torch.cat([train, valid, test], dim=0)

In [3]:
# Preview the concatenation of the train, valid and test tensors.
all_triples

tensor([[    0,     0,     1],
        [    2,     1,     3],
        [    4,     2,     5],
        ...,
        [ 1677,   177,   589],
        [12770,   131, 14420],
        [  452,    26,  1098]], dtype=torch.int32)

In [7]:
# Tabular view of every triple with named columns
# (presumably subject / predicate / object -- confirm against the dataset).
X = pd.DataFrame(
    data=all_triples.numpy(),
    columns=['s', 'p', 'o'],
)

In [8]:
# Preview the triples DataFrame (columns 's', 'p', 'o').
X

Unnamed: 0,s,p,o
0,0,0,1
1,2,1,3
2,4,2,5
3,6,3,7
4,8,4,9
...,...,...,...
309460,8005,17,862
309461,7682,48,32
309462,1677,177,589
309463,12770,131,14420


In [10]:
# Number of triples per value of 'p' (the 's' and 'o' counts are row counts).
X.groupby("p").count()

Unnamed: 0_level_0,s,o
p,Unnamed: 1_level_1,Unnamed: 2_level_1
0,369,369
1,1426,1426
2,2619,2619
3,8491,8491
4,2279,2279
...,...,...
232,139,139
233,127,127
234,149,149
235,109,109


In [16]:
# Number of distinct 's' values among the triples whose 'p' equals 1.
# The column is selected by label: using an integer key on the label-indexed
# Series returned by DataFrame.nunique() relies on a deprecated positional
# fallback.
X.loc[X['p'] == 1, 's'].nunique()

271

In [15]:
# All triples whose 'p' column equals 17.
X.loc[X['p'] == 17]

Unnamed: 0,s,p,o
21,41,17,42
31,59,17,60
70,130,17,131
117,214,17,215
126,231,17,215
...,...,...,...
309313,5869,17,1730
309319,8159,17,967
309340,2287,17,131
309384,3431,17,871


In [18]:
# Map each relation id to the number of distinct 's' and 'o' values it occurs
# with: {p: {'s': n_unique_s, 'o': n_unique_o}}.
# A single grouped pass replaces re-filtering the whole frame once per
# relation, and columns are addressed by label instead of by integer position
# (integer keys on a label-indexed Series are a deprecated fallback).
# sort=False keeps relations in order of first appearance, matching the
# iteration order of X.p.unique().
unique_ent = (
    X.groupby('p', sort=False)[['s', 'o']]
    .nunique()
    .to_dict(orient='index')
)
    

In [19]:
# Show, per relation id, the number of distinct 's' and 'o' values.
unique_ent

{0: {'s': 192, 'o': 6},
 1: {'s': 271, 'o': 968},
 2: {'s': 101, 'o': 1434},
 3: {'s': 1612, 'o': 1618},
 4: {'s': 745, 'o': 7},
 5: {'s': 744, 'o': 6},
 6: {'s': 2119, 'o': 1791},
 7: {'s': 241, 'o': 1622},
 8: {'s': 2336, 'o': 2344},
 9: {'s': 120, 'o': 121},
 10: {'s': 260, 'o': 140},
 11: {'s': 1624, 'o': 5},
 12: {'s': 1137, 'o': 1651},
 13: {'s': 1353, 'o': 180},
 14: {'s': 1867, 'o': 61},
 15: {'s': 908, 'o': 3009},
 16: {'s': 199, 'o': 4},
 17: {'s': 4382, 'o': 152},
 18: {'s': 113, 'o': 37},
 19: {'s': 3448, 'o': 405},
 20: {'s': 318, 'o': 1089},
 21: {'s': 524, 'o': 461},
 22: {'s': 540, 'o': 2238},
 23: {'s': 330, 'o': 24},
 24: {'s': 2783, 'o': 724},
 25: {'s': 4363, 'o': 1},
 26: {'s': 1287, 'o': 38},
 27: {'s': 141, 'o': 1626},
 28: {'s': 1754, 'o': 64},
 29: {'s': 58, 'o': 209},
 30: {'s': 1032, 'o': 235},
 31: {'s': 1883, 'o': 123},
 32: {'s': 76, 'o': 1644},
 33: {'s': 458, 'o': 1},
 34: {'s': 1469, 'o': 7},
 35: {'s': 59, 'o': 66},
 36: {'s': 89, 'o': 20},
 37: {'s': 

In [21]:
# Relations to drop: those that occur with fewer than 10 distinct 's' values
# or fewer than 10 distinct 'o' values.
rels_del = [
    rel_id
    for rel_id, distinct in unique_ent.items()
    if min(distinct['s'], distinct['o']) < 10
]

In [22]:
# Relation ids selected for removal (fewer than 10 distinct 's' or 'o' values).
rels_del

[0,
 4,
 5,
 11,
 16,
 25,
 33,
 34,
 43,
 55,
 60,
 73,
 78,
 80,
 85,
 90,
 91,
 108,
 110,
 113,
 115,
 118,
 120,
 126,
 128,
 129,
 132,
 139,
 142,
 156,
 161,
 166,
 167,
 168,
 171,
 172,
 195,
 197,
 200,
 204,
 206,
 208,
 209,
 210,
 211,
 212,
 213,
 215,
 217,
 219,
 220,
 225,
 232,
 233,
 234]

In [24]:
# Shape of the test tensor as loaded, before any relation filtering.
test.shape

torch.Size([20407, 3])

In [26]:
# Independent copy of the test tensor.
# NOTE(review): test_del is reassigned to a DataFrame built from `test` in the
# next visible cell, so this copy looks unused -- confirm and remove if so.
test_del = test.detach().clone()

In [29]:
# DataFrame view of the test triples with the same column names used for X.
test_del = pd.DataFrame(
    data=test.numpy(),
    columns=['s', 'p', 'o'],
)

In [38]:
# Keep only the test triples whose relation ('p', the second column) is not
# flagged in rels_del.
test_rows_flagged = test_del['p'].isin(rels_del)
new_test = test_del.loc[~test_rows_flagged]

In [40]:
# 'test2id.txt' is also the file this notebook loads `test` from, so the write
# below replaces the raw input. Keep a one-time copy of the unfiltered file so
# the original split stays recoverable; an existing backup is never overwritten.
test_raw_backup = Path('test2id.raw.txt')
if Path('test2id.txt').exists() and not test_raw_backup.exists():
    shutil.copy2('test2id.txt', test_raw_backup)

# Persist the filtered test triples as tab-separated integers (columns s, p, o).
np.savetxt('test2id.txt', new_test[['s', 'p', 'o']].to_numpy(), fmt='%d', delimiter='\t')

In [47]:
# DataFrame view of the validation triples with the same column names used for X.
valid_del = pd.DataFrame(
    data=valid.numpy(),
    columns=['s', 'p', 'o'],
)

In [48]:
# Keep only the validation triples whose relation ('p', the second column) is
# not flagged in rels_del.
valid_rows_flagged = valid_del['p'].isin(rels_del)
new_valid = valid_del.loc[~valid_rows_flagged]

In [49]:
# 'valid2id.txt' is also the file this notebook loads `valid` from, so the
# write below replaces the raw input. Keep a one-time copy of the unfiltered
# file so the original split stays recoverable; an existing backup is never
# overwritten.
valid_raw_backup = Path('valid2id.raw.txt')
if Path('valid2id.txt').exists() and not valid_raw_backup.exists():
    shutil.copy2('valid2id.txt', valid_raw_backup)

# Persist the filtered validation triples as tab-separated integers (columns s, p, o).
np.savetxt('valid2id.txt', new_valid[['s', 'p', 'o']].to_numpy(), fmt='%d', delimiter='\t')

In [50]:
# Preview the validation triples that remain after dropping the rels_del relations.
new_valid

Unnamed: 0,s,p,o
0,6404,31,1211
1,4089,154,1437
3,8379,47,3498
4,2364,13,235
6,8252,51,160
...,...,...,...
17477,7344,17,862
17478,2248,7,1079
17479,6881,26,990
17481,6116,13,2091


In [51]:
# Shape of the validation tensor as loaded; the filter above produced
# new_valid and did not reassign `valid`.
valid.shape

torch.Size([17483, 3])