In [1]:
import csv
import numpy as np
import nltk
import keras
import keras.backend as K

Using TensorFlow backend.


In [2]:
nltk.download('wordnet')
nltk.download('omw')

[nltk_data] Downloading package wordnet to /home/marco/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw to /home/marco/nltk_data...
[nltk_data]   Package omw is already up-to-date!


True

### generate wordnet relation pairs

In [108]:
with open('./utils/wordnet_pairs.csv', 'w') as f:
    writer = csv.writer(f)
    for s in nltk.corpus.wordnet.all_synsets():
        for v in s.also_sees():
            writer.writerow((f's.{s.name()}', 's.also_sees', f's.{v.name()}'))
        for v in s.attributes():
            writer.writerow((f's.{s.name()}', 's.attributes', f's.{v.name()}'))
        for v in s.causes():
            writer.writerow((f's.{s.name()}', 's.causes', f's.{v.name()}'))
        for v in s.entailments():
            writer.writerow((f's.{s.name()}', 's.entailments', f's.{v.name()}'))
        for v in s.hypernyms():
            writer.writerow((f's.{s.name()}', 's.hypernyms', f's.{v.name()}'))
        for v in s.hyponyms():
            writer.writerow((f's.{s.name()}', 's.hyponyms', f's.{v.name()}'))
        for v in s.instance_hypernyms():
            writer.writerow((f's.{s.name()}', 's.instance_hypernyms', f's.{v.name()}'))
        for v in s.instance_hyponyms():
            writer.writerow((f's.{s.name()}', 's.instance_hyponyms', f's.{v.name()}'))
        for v in s.lemmas():
            writer.writerow((f's.{s.name()}', 's.lemmas', f'l.{v.name()}'))
        for v in s.member_holonyms():
            writer.writerow((f's.{s.name()}', 's.member_holonyms', f's.{v.name()}'))
        for v in s.member_meronyms():
            writer.writerow((f's.{s.name()}', 's.member_meronyms', f's.{v.name()}'))
        for v in s.part_holonyms():
            writer.writerow((f's.{s.name()}', 's.part_holonyms', f's.{v.name()}'))
        writer.writerow((f's.{s.name()}', 's.pos', f'p.{s.pos()}'))
        for v in s.region_domains():
            writer.writerow((f's.{s.name()}', 's.region_domains', f's.{v.name()}'))
        for v in s.root_hypernyms():
            writer.writerow((f's.{s.name()}', 's.root_hypernyms', f's.{v.name()}'))
        for v in s.similar_tos():
            writer.writerow((f's.{s.name()}', 's.similar_tos', f's.{v.name()}'))
        for v in s.substance_holonyms():
            writer.writerow((f's.{s.name()}', 's.substance_holonyms', f's.{v.name()}'))
        for v in s.substance_meronyms():
            writer.writerow((f's.{s.name()}', 's.substance_meronyms', f's.{v.name()}'))
        for v in s.topic_domains():
            writer.writerow((f's.{s.name()}', 's.topic_domains', f's.{v.name()}'))
        for v in s.usage_domains():
            writer.writerow((f's.{s.name()}', 's.usage_domains', f's.{v.name()}'))
        for v in s.verb_groups():
            writer.writerow((f's.{s.name()}', 's.verb_groups', f's.{v.name()}'))
        for v in s.frame_ids():
            writer.writerow((f's.{s.name()}', 's.frame_ids', f'f.{v}'))
    seen_lemma_keys = set()
    for l_name in nltk.corpus.wordnet.all_lemma_names():
        ls = nltk.corpus.wordnet.lemmas(l_name)
        for l in ls:
            if l.key() in seen_lemma_keys:
                continue
            seen_lemma_keys.add(l.key())
            for v in l.also_sees():
                writer.writerow((f'l.{l.key()}', 'l.also_sees', f'l.{v.key()}'))
            for v in l.antonyms():
                writer.writerow((f'l.{l.key()}', 'l.antonyms', f'l.{v.key()}'))
            for v in l.derivationally_related_forms():
                writer.writerow((f'l.{l.key()}', 'l.derivationally_related_forms', f'l.{v.key()}'))
            for v in l.frame_ids():
                writer.writerow((f'l.{l.key()}', 'l.frame_ids', f'f.{v}'))
            for v in l.pertainyms():
                writer.writerow((f'l.{l.key()}', 'l.pertainyms', f'l.{v.key()}'))
            for v in l.region_domains():
                writer.writerow((f'l.{l.key()}', 'l.region_domains', f'l.{v.key()}'))
            writer.writerow((f'l.{l.key()}', 'l.synset', f's.{l.synset().name()}'))
            writer.writerow((f'l.{l.key()}', 'l.syntactic_marker', f'sm.{l.syntactic_marker()}'))
            for v in l.topic_domains():
                writer.writerow((f'l.{l.key()}', 'l.topic_domains', f'l.{v.key()}'))
            for v in l.usage_domains():
                writer.writerow((f'l.{l.key()}', 'l.usage_domains', f'l.{v.key()}'))
            for v in l.verb_groups():
                writer.writerow((f'l.{l.key()}', 'l.verb_groups', f'l.{v.key()}'))
    del seen_lemma_keys

### postprocess relation pairs

In [3]:
entity_set = set()
relation_set = set()
with open('./utils/wordnet_pairs.csv', 'r') as f:
    reader = csv.reader(f)
    for row in reader:
        entity_set.add(row[0])
        entity_set.add(row[2])
        relation_set.add(row[1])
entity_lookup = { k+1:v for (k,v) in enumerate(entity_set)}
entity_index = { v:k+1 for (k,v) in enumerate(entity_set)}
entity_max = len(entity_set)
relation_lookup = { k+1:v for (k,v) in enumerate(relation_set)}
relation_index = { v:k+1 for (k,v) in enumerate(relation_set)}
relation_max = len(relation_set)
del entity_set
del relation_set
len(entity_lookup), len(relation_lookup)

(473382, 33)

In [4]:
dataset = []
with open('./utils/wordnet_pairs.csv', 'r') as f:
    reader = csv.reader(f)
    for row in reader:
        dataset.append((entity_index[row[0]], relation_index[row[1]], entity_index[row[2]]))
dataset = np.array(dataset, dtype='int')
dataset.shape

(1278206, 3)

### trans f

In [63]:
class TransF(keras.layers.Layer):
    
    def __init__(self, **kwargs):
        super(TransF, self).__init__(**kwargs)
    
    def build(self, input_shape):
        entity_shape, relation_shape = input_shape
        self.kernel_width = entity_shape[1]
        self.kernel = self.add_weight(
            shape=(relation_shape[1], entity_shape[1], entity_shape[1]),
            initializer='glorot_uniform', name='kernel')
        super(TransF, self).build(input_shape)
    
    def call(self, inputs):
        entity, relation = inputs
        relation = K.reshape(relation, K.concatenate([K.shape(relation), (1,1)]))
        relation = relation * self.kernel
        relation = K.sum(relation, axis=1, keepdims=False)
        relation = relation + K.eye(self.kernel_width)
        return (K.expand_dims(entity, axis=1) @ relation)[:,0]
    
    def compute_output_shape(self, input_shape):
        entity_shape, _ = input_shape
        return entity_shape

In [69]:
X_h = X_input_h = keras.layers.Input((1,), dtype='int32')
X_t = X_input_t = keras.layers.Input((1,), dtype='int32')
X_r = X_input_r = keras.layers.Input((1,), dtype='int32')
X_e_embedding = keras.layers.Embedding(entity_max+1, 32)
X_r_embedding = keras.layers.Embedding(relation_max+1, 8)
X_trans_f = TransF()
X_h = X_e_embedding(X_h)
X_h = keras.layers.Flatten()(X_h)
X_t = X_e_embedding(X_t)
X_t = keras.layers.Flatten()(X_t)
X_r = X_r_embedding(X_r)
X_r = keras.layers.Flatten()(X_r)
X_h = X_trans_f([X_h, X_r])
X_t = X_trans_f([X_t, X_r])
X = keras.layers.Subtract()([X_h, X_t])
X = keras.layers.Dot(-1)([X, X])
M_sample = keras.Model([X_input_h, X_input_r, X_input_t], X)
M_sample.compile('adam', 'mse')
M_sample.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_88 (InputLayer)           (None, 1)            0                                            
__________________________________________________________________________________________________
input_90 (InputLayer)           (None, 1)            0                                            
__________________________________________________________________________________________________
input_89 (InputLayer)           (None, 1)            0                                            
__________________________________________________________________________________________________
embedding_59 (Embedding)        (None, 1, 32)        15148256    input_88[0][0]                   
                                                                 input_89[0][0]                   
__________

In [70]:
M_sample.fit([dataset[:,0],dataset[:,1],dataset[:,2]], np.zeros((dataset.shape[0],1)), batch_size=512, epochs=1)

Epoch 1/1
  49664/1278206 [>.............................] - ETA: 8:40 - loss: 0.0020

KeyboardInterrupt: 