In [1]:
from helper_classes import Data
import util as ut

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Embedding, Flatten,Dropout
from tensorflow.keras.regularizers import l2
import tensorflow as tf  # needed for tf.random.set_seed below


import warnings
import logging
from scipy import stats
import numpy as np
import random
import torch
from sklearn.preprocessing import MultiLabelBinarizer

# Seed every random number generator the notebook touches so that
# Restart & Run All reproduces the same model and numbers:
#   - Python `random` and NumPy (as before);
#   - TensorFlow's global seed, from which Keras weight initialisers and Dropout
#     masks (and, presumably, fit() shuffling) derive their op-level seeds;
#   - torch, which this notebook uses for ranking at evaluation time.
# NOTE: some GPU kernels remain nondeterministic even with fixed seeds.
random_state = 1
np.random.seed(random_state)
random.seed(random_state)
tf.random.set_seed(random_state)
torch.manual_seed(random_state)
# Silences *all* warnings (including deprecation notices) for the rest of the session.
warnings.filterwarnings('ignore')

In [2]:
# Load the FB15K knowledge graph. The dataset *name* and the loaded Data object
# live in separate variables instead of rebinding `dataset` from str to Data.
dataset_name = 'FB15K'
kg_root = 'KGs/' + dataset_name + '/'  # directory path, with trailing slash, passed to Data
dataset = Data(kg_root)

In [3]:
# Training data: one row per (subject, object) key returned by
# get_entity_pairs_with_predicates, labelled with that key's predicates.
# Entity ids are dense integers handed out in order of first appearance
# (subject before object); later cells use them as Embedding indices.
entitiy_idx = dict()
x = []
y = []

sub_obj_pairs = dataset.get_entity_pairs_with_predicates(dataset.train_data)
for (subj, obj), predicates in sub_obj_pairs.items():
    subj_id = entitiy_idx.setdefault(subj, len(entitiy_idx))
    obj_id = entitiy_idx.setdefault(obj, len(entitiy_idx))
    x.append([subj_id, obj_id])
    y.append(list(predicates))
x = np.array(x)

In [4]:
# Multi-hot targets: one row per training pair, one column per relation
# (column order given by binarizer.classes_). `x` is already an ndarray from
# the previous cell, so it is not converted again here.
# MultiLabelBinarizer's default dense output is an integer matrix
# (~395k x 1345 entries for FB15K); building it sparse and densifying straight
# to float32 avoids materialising that int64 copy. All 0/1 values are unchanged.
# NOTE(review): with sparse_output=True, binarizer.transform() returns a scipy
# sparse matrix; the visible cells only read binarizer.classes_.
binarizer = MultiLabelBinarizer(sparse_output=True)
y = binarizer.fit_transform(y).astype(np.float32).toarray()

In [5]:
# Model: embed the subject and object ids, flatten the two embeddings into one
# vector, pass it through a ReLU hidden layer, and emit one sigmoid score per
# relation class (y.shape[1] columns). Dropout and L2 activity penalties
# regularise the embedding and hidden activations.
# NOTE(review): sigmoid outputs are paired with categorical_crossentropy here;
# binary_crossentropy is the usual loss for independent multi-label sigmoids.
# Left unchanged so the recorded results below stay reproducible.
embedding_dim = 50
model = Sequential([
    Embedding(len(entitiy_idx), embedding_dim, input_length=2,
              activity_regularizer=l2(0.1)),          # (batch, 2) ids -> (batch, 2, 50)
    Flatten(),                                        # -> (batch, 100)
    Dropout(.4),
    Dense(embedding_dim * 5, activation='relu',
          activity_regularizer=l2(0.1)),              # -> (batch, 250)
    Dropout(.4),
    Dense(y.shape[1], activation='sigmoid'),          # one score per relation
])
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        (None, 2, 50)             747550    
_________________________________________________________________
flatten (Flatten)            (None, 100)               0         
_________________________________________________________________
dropout (Dropout)            (None, 100)               0         
_________________________________________________________________
dense (Dense)                (None, 250)               25250     
_________________________________________________________________
dropout_1 (Dropout)          (None, 250)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 1345)              337595    
Total params: 1,110,395
Trainable params: 1,110,395
Non-trainable params: 0
______________________________________________

In [6]:
# Train for 30 epochs, 1000 pairs per batch, reshuffling every epoch.
# NOTE: per the Keras docs, use_multiprocessing only applies to generator /
# Sequence inputs, so it has no effect with these in-memory NumPy arrays.
history = model.fit(
    x,
    y,
    batch_size=1000,
    epochs=30,
    shuffle=True,
    verbose=1,
    use_multiprocessing=True,
)

Train on 394804 samples
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


In [7]:
# The training matrices are not needed after fitting; drop the notebook's
# references so their memory can be reclaimed before evaluation.
del x, y

In [8]:
def evaluation(model, binarizer, dataset, triples, entity_idx=None):
    """Rank every relation for each (subject, object) pair and print relation-prediction metrics.

    For every triple ``(s, p, o)`` the model scores each relation class of
    ``binarizer`` for the pair ``(s, o)``. The 1-based position of the true
    relation ``p`` in the descending score order is its rank. Overall
    Hits@5/3/1, mean rank and mean reciprocal rank are printed first, then
    Hits@1/3/5 and MRR per relation (relations in order of first appearance).

    Args:
        model: object with a Keras-style ``predict(array)`` returning one score
            row per input pair and one column per entry of ``binarizer.classes_``.
        binarizer: fitted ``MultiLabelBinarizer``; ``classes_`` (an ndarray)
            gives the relation label of each score column.
        dataset: unused; kept so existing calls keep working.
        triples: iterable of ``(subject, predicate, object)`` triples.
        entity_idx: mapping from entity to the integer id fed to the model.
            Defaults to the notebook-global ``entitiy_idx`` built from the
            training pairs.

    Returns:
        None. All results are printed.

    Raises:
        KeyError: if a subject or object has no entry in ``entity_idx``.
        ValueError: if a predicate is not among ``binarizer.classes_``.
    """
    if entity_idx is None:
        entity_idx = entitiy_idx  # notebook global built from the training pairs

    pairs = []
    true_relations = []
    for s, p, o in triples:
        # A KeyError here means the entity never appeared in a training pair.
        pairs.append((entity_idx[s], entity_idx[o]))
        true_relations.append(p)

    if not pairs:
        print('No triples to evaluate.')
        return None

    # Map each relation label to its score column once, instead of an
    # O(#relations) list.index lookup per triple.
    column_of = {label: col for col, label in enumerate(binarizer.classes_.tolist())}
    unknown = [rel for rel in true_relations if rel not in column_of]
    if unknown:
        raise ValueError('{0!r} is not among binarizer.classes_'.format(unknown[0]))
    true_cols = np.array([column_of[rel] for rel in true_relations])

    scores = torch.from_numpy(model.predict(np.array(pairs)))
    # topk with k = #classes is a full descending sort: row i lists the class
    # columns from highest to lowest score for pair i.
    _, ranked_predictions = scores.topk(k=len(binarizer.classes_))
    ranked_predictions = ranked_predictions.numpy()
    assert len(ranked_predictions) == len(true_relations)

    # Each row is a permutation of all columns, so the mask has exactly one
    # True per row; argmax returns its 0-based position, +1 gives the rank.
    # (The previous per-level hit lists were ragged, which np.array rejects
    # on recent NumPy.)
    ranks = np.argmax(ranked_predictions == true_cols[:, None], axis=1) + 1
    total = len(ranks)

    print('########## Relation Prediction Results ##########')

    print('Mean Hits @5: {0}'.format(np.count_nonzero(ranks <= 5) / total))
    print('Mean Hits @3: {0}'.format(np.count_nonzero(ranks <= 3) / total))
    print('Mean Hits @1: {0}'.format(np.count_nonzero(ranks <= 1) / total))
    print('Mean rank: {0}'.format(np.mean(ranks)))
    print('Mean reciprocal rank: {0}'.format(np.mean(1. / ranks)))

    print('########## Relation Prediction Analysis ##########')

    ranks_per_relation = {}
    for relation, rank in zip(true_relations, ranks):
        ranks_per_relation.setdefault(relation, []).append(rank)

    for relation, relation_ranks in ranks_per_relation.items():
        relation_ranks = np.array(relation_ranks)
        count = len(relation_ranks)

        average_hit_at_1 = np.sum(relation_ranks == 1) / count
        average_hit_at_3 = np.sum(relation_ranks <= 3) / count
        average_hit_at_5 = np.sum(relation_ranks <= 5) / count

        print('{0}:\t Hits@1:\t{1:.3f}'.format(relation, average_hit_at_1))
        print('{0}:\t Hits@3:\t{1:.3f}'.format(relation, average_hit_at_3))
        print('{0}:\t Hits@5:\t{1:.3f}'.format(relation, average_hit_at_5))
        print('{0}:\t MRR:\t{1:.3f}\t number of occurrence {2}'.format(
            relation, np.mean(1. / relation_ranks), count))
        print('################################')
    

In [9]:
# Only the test split is used from here on; drop the train and validation
# splits from the Data object so their memory can be reclaimed.
del dataset.train_data, dataset.valid_data

In [10]:
# Score the held-out test triples with the trained model.
evaluation(model=model, binarizer=binarizer, dataset=dataset, triples=dataset.test_data)

########## Relation Prediction Results ##########
Mean Hits @5: 0.9871679842900916
Mean Hits @3: 0.9492475157014441
Mean Hits @1: 0.7343705032926479
Mean rank: 1.5986693978432733
Mean reciprocal rank: 0.8438899278081052
########## Relation Prediction Analysis ##########
/award/award_nominee/award_nominations./award/award_nomination/award:	 Hits@1:	0.994
/award/award_nominee/award_nominations./award/award_nomination/award:	 Hits@3:	1.000
/award/award_nominee/award_nominations./award/award_nomination/award:	 Hits@5:	1.000
/award/award_nominee/award_nominations./award/award_nomination/award:	 MRR:	0.997	 number of occurrence 1555
################################
/base/activism/activist/area_of_activism:	 Hits@1:	0.800
/base/activism/activist/area_of_activism:	 Hits@3:	1.000
/base/activism/activist/area_of_activism:	 Hits@5:	1.000
/base/activism/activist/area_of_activism:	 MRR:	0.889	 number of occurrence 15
################################
/travel/travel_destination/climate./travel/travel

/government/form_of_government/countries:	 Hits@1:	1.000
/government/form_of_government/countries:	 Hits@3:	1.000
/government/form_of_government/countries:	 Hits@5:	1.000
/government/form_of_government/countries:	 MRR:	1.000	 number of occurrence 44
################################
/american_football/football_team/current_roster./american_football/football_roster_position/position:	 Hits@1:	0.597
/american_football/football_team/current_roster./american_football/football_roster_position/position:	 Hits@3:	0.710
/american_football/football_team/current_roster./american_football/football_roster_position/position:	 Hits@5:	0.887
/american_football/football_team/current_roster./american_football/football_roster_position/position:	 MRR:	0.711	 number of occurrence 62
################################
/organization/organization/child./organization/organization_relationship/child:	 Hits@1:	0.810
/organization/organization/child./organization/organization_relationship/child:	 Hits@3:	0.952
/org

/education/educational_institution/sexes_accepted./education/gender_enrollment/sex:	 Hits@1:	1.000
/education/educational_institution/sexes_accepted./education/gender_enrollment/sex:	 Hits@3:	1.000
/education/educational_institution/sexes_accepted./education/gender_enrollment/sex:	 Hits@5:	1.000
/education/educational_institution/sexes_accepted./education/gender_enrollment/sex:	 MRR:	1.000	 number of occurrence 4
################################
/user/robert/us_congress/congressional_district/state:	 Hits@1:	0.000
/user/robert/us_congress/congressional_district/state:	 Hits@3:	0.000
/user/robert/us_congress/congressional_district/state:	 Hits@5:	0.000
/user/robert/us_congress/congressional_district/state:	 MRR:	0.059	 number of occurrence 1
################################
/user/mt/default_domain/metabolite/associated_disorder:	 Hits@1:	0.000
/user/mt/default_domain/metabolite/associated_disorder:	 Hits@3:	0.000
/user/mt/default_domain/metabolite/associated_disorder:	 Hits@5:	0.000
/us