In [27]:
import sys

# Make the parent directory importable so that modules located there can be
# loaded by later cells; the check keeps re-runs from adding duplicates.
if not any(entry == '..' for entry in sys.path):
    sys.path.append('..')

In [28]:
import tensorflow as tf
# Session shared by the cells below: the checkpoint is restored into it and
# inference is run through it.
sess = tf.InteractiveSession()

In [29]:
# Checkpoint prefix of the trained model: '.meta' is appended to locate the
# graph definition and the bare prefix is passed to saver.restore().
path = '../output/hdn-large.2018-05-21-b1d1867-best-model'

In [30]:
# Rebuild the graph from the exported .meta file (with clear_devices=True),
# then restore the saved parameters into the session.
saver = tf.train.import_meta_graph(path + '.meta', clear_devices=True)
saver.restore(sess, path)

INFO:tensorflow:Restoring parameters from ../output/hdn-large.2018-05-21-b1d1867-best-model


In [31]:
def load_tensors(sess):
    """Look up the model's input and output tensors in the session's graph.

    :param sess: session whose ``graph`` contains the imported model
    :return: tuple ``(x, logits, lens, candidates)`` of graph tensors
    """
    tensor_names = (
        'Model/x:0',
        'Model/Max:0',  # the logits tensor; it should have been given a name
        'Model/lens:0',
        'Model/candidate_list:0',
    )
    fetch = sess.graph.get_tensor_by_name
    x, logits, lens, candidates = (fetch(name) for name in tensor_names)
    return x, logits, lens, candidates

# Bind the graph tensors to notebook-level names; the inference cell feeds
# x, lens and candidates and fetches logits.
x, logits, lens, candidates = load_tensors(sess)

In [32]:
import generate_hdn_datasets
from block_timer.timer import Timer
import pickle

# NOTE(review): pickle can execute arbitrary code while loading; only open
# vocabulary files that were produced by this project.
# Each file is opened in a `with` block so the handle is closed right after
# loading instead of being left to the garbage collector.

# word -> id, used to encode sentence tokens
word_vocab_path = '../output/vocab.2018-05-10-7d764e7.pkl'
with open(word_vocab_path, 'rb') as vocab_file:
    word2id = pickle.load(vocab_file)

# HDN identifier -> id, used to index the model's output scores
hdn_vocab_path = '../output/hdn-vocab.2018-05-18-f48a06c.pkl'
with open(hdn_vocab_path, 'rb') as vocab_file:
    hdn2id = pickle.load(vocab_file)

# tuple of HDN identifiers -> id, used to feed the candidate list
hdn_list_vocab_path = '../output/hdn-list-vocab.2018-05-18-f48a06c.pkl'
with open(hdn_list_vocab_path, 'rb') as vocab_file:
    hdn_list2id = pickle.load(vocab_file)

In [33]:
from evaluate.wn_utils import synset2identifier
from nltk.corpus import wordnet as wn

# Reverse index: identifier (built with wn_version '30') -> noun synset object.
id2synset = dict(
    (synset2identifier(noun_synset, '30'), noun_synset)
    for noun_synset in wn.all_synsets('n')
)

In [39]:
from nltk.corpus import wordnet as wn

def synsets_graph_info(wn_instance, wn_version, lemma, pos):
    """
    extract, for every synset of a lemma:
    1. hyponym under lowest least common subsumer

    For each synset, the closest lowest common subsumer (LCS) shared with any
    other synset of the same lemma is determined ("closest" = smallest
    shortest-path distance from the synset; the first one found wins ties).
    The node directly below that LCS on the synset's own hypernym path is
    reported as 'under_lcs'.

    :param nltk.corpus.reader.wordnet.WordNetCorpusReader wn_instance: instance
    of nltk.corpus.reader.wordnet.WordNetCorpusReader
    :param str wn_version: supported: '171' | '21' | '30'
    :param str lemma: a lemma
    :param str pos: a pos

    :rtype: dict
    :return: mapping synset_id
        -> 'under_lcs' -> under_lcs identifier (None if the lemma has one synset)
        -> 'under_lcs_obj' -> under_lcs synset object (None if the lemma has one synset)
        -> 'path_to_under_lcs' -> [sy1_iden, sy2_iden, sy3_iden, ...]
        Synsets for which no node under an LCS can be found (no common
        hypernym with any other synset, or the LCS is the synset itself)
        are omitted. An unknown lemma yields an empty dict.
    """
    sy_id2under_lcs_info = dict()

    synsets = wn_instance.synsets(lemma, pos=pos)

    if len(synsets) == 1:
        # monosemous lemma: there is nothing to disambiguate
        target_sy_iden = synset2identifier(synsets[0], wn_version)
        sy_id2under_lcs_info[target_sy_iden] = {'under_lcs': None,
                                                'under_lcs_obj': None,
                                                'path_to_under_lcs': []}
        return sy_id2under_lcs_info

    for sy1 in synsets:

        target_sy_iden = synset2identifier(sy1, wn_version)

        # find the LCS (with any other synset of the lemma) closest to sy1
        min_path_distance = float('inf')
        closest_lcs = None

        for sy2 in synsets:
            if sy1 == sy2:
                continue

            lcs_s = sy1.lowest_common_hypernyms(sy2)
            if not lcs_s:
                # the two synsets share no hypernym (can happen outside the
                # noun hierarchy), so this pair contributes no LCS
                continue
            lcs = lcs_s[0]

            path_distance = sy1.shortest_path_distance(lcs)
            if path_distance is None:
                # no path reported; comparing None with a number would raise
                continue

            if path_distance < min_path_distance:
                closest_lcs = lcs
                min_path_distance = path_distance

        if closest_lcs is None:
            # no usable LCS for this synset: leave it out of the result
            continue

        # locate the node right below the LCS on each hypernym path of sy1;
        # when several paths contain the LCS, the last path wins
        for hypernym_path in sy1.hypernym_paths():
            for position, first in enumerate(hypernym_path[:-1]):
                if first != closest_lcs:
                    continue

                under_lcs = hypernym_path[position + 1]
                # nodes strictly between under_lcs and sy1 itself
                path_to_under_lcs = hypernym_path[position + 2:-1]

                under_lcs_iden = synset2identifier(under_lcs, wn_version)
                path_to_under_lcs_idens = [synset2identifier(synset, wn_version)
                                           for synset in path_to_under_lcs]

                sy_id2under_lcs_info[target_sy_iden] = {'under_lcs': under_lcs_iden,
                                                        'under_lcs_obj': under_lcs,
                                                        'path_to_under_lcs': path_to_under_lcs_idens}
                # a synset occurs at most once in a hypernym path
                break

    return sy_id2under_lcs_info

def find_hdns(lemma):
    """Collect the HDN candidates ('under_lcs' nodes) of a noun lemma.

    :param str lemma: a lemma, looked up as a noun in WordNet '30'

    :rtype: tuple
    :return: ``(hdn_list, hdn2synset)`` where ``hdn_list`` is the sorted tuple
        of HDN identifiers and ``hdn2synset`` maps each HDN identifier to the
        identifier of the synset it was derived from. Synsets without an HDN
        (e.g. the single synset of a monosemous lemma) appear in neither.
    """
    graph_info = synsets_graph_info(wn_instance=wn,
                                    wn_version='30',
                                    lemma=lemma,
                                    pos='n')
    # same filter as hdn_list below, so that the mapping never gets a None key
    hdn2synset = {info['under_lcs']: synset
                  for synset, info in graph_info.items()
                  if info['under_lcs']}
    hdn_list = tuple(sorted(info['under_lcs'] # sorted to avoid arbitrary order
                            for info in graph_info.values()
                            if info['under_lcs']))
    return hdn_list, hdn2synset

def find_path_to_hdns(lemma):
    """Map every noun synset identifier of ``lemma`` to its 'path_to_under_lcs'.

    :param str lemma: a lemma, looked up as a noun in WordNet '30'
    :rtype: dict
    :return: synset identifier -> list of synset identifiers
    """
    info_by_synset = synsets_graph_info(wn_instance=wn,
                                        wn_version='30',
                                        lemma=lemma,
                                        pos='n')
    paths = {}
    for synset_id, info in info_by_synset.items():
        paths[synset_id] = info['path_to_under_lcs']
    return paths

# Sanity check on a highly polysemous noun: show the HDNs by synset name,
# the HDN -> synset mapping by name, and the raw identifier paths.
hdn_list, hdn2synset = find_hdns('study')
print([id2synset[hdn].name() for hdn in hdn_list])
print({id2synset[hdn].name(): id2synset[synset_id].name()
       for hdn, synset_id in hdn2synset.items()})
print(find_path_to_hdns('study'))

['physical_entity.n.01', 'event.n.01', 'creation.n.02', 'structure.n.01', 'ability.n.02', 'basic_cognitive_process.n.01', 'higher_cognitive_process.n.01', 'content.n.05', 'written_communication.n.01', 'auditory_communication.n.01']
{'event.n.01': 'survey.n.01', 'basic_cognitive_process.n.01': 'study.n.02', 'written_communication.n.01': 'report.n.01', 'ability.n.02': 'study.n.04', 'structure.n.01': 'study.n.05', 'content.n.05': 'discipline.n.01', 'creation.n.02': 'sketch.n.01', 'higher_cognitive_process.n.01': 'cogitation.n.02', 'physical_entity.n.01': 'study.n.09', 'auditory_communication.n.01': 'study.n.10'}
{'eng-30-00644503-n': ['eng-30-00030358-n', 'eng-30-00407535-n', 'eng-30-00575741-n', 'eng-30-00633864-n', 'eng-30-00635850-n'], 'eng-30-05755883-n': ['eng-30-05752544-n'], 'eng-30-07218470-n': ['eng-30-06362953-n', 'eng-30-06470073-n'], 'eng-30-05705355-n': ['eng-30-05650329-n', 'eng-30-05650579-n', 'eng-30-05704266-n'], 'eng-30-04345028-n': ['eng-30-02735688-n', 'eng-30-04105893

In [35]:
# Id of the '<target>' token; it replaces the word to disambiguate in the
# encoded sentence of the inference cell.
target_id = word2id['<target>']

In [36]:
sentence_tokens = 'In one of these , an exploding wire device to study systems thermodynamically up to 6000 * * f and 100 atmospheres pressure , a major goal was achieved .'.lower().split()
# Use dict.get's default rather than `or`: a legitimate id of 0 is falsy and
# would otherwise be silently replaced by the <unkn> id.
sentence_as_ids = [word2id.get(w, word2id['<unkn>']) for w in sentence_tokens]

# The word to disambiguate is the 4th token from the end of this sentence.
target_index = len(sentence_tokens)-4
lemma = sentence_tokens[target_index]
# Mask the target word with the special <target> id.
sentence_as_ids[target_index] = target_id
hdn_list, hdn2synset = find_hdns(lemma)
# A batch of one sentence: its token ids, its length and the id of its
# candidate HDN list.
feed_dict = {x: [sentence_as_ids],
             lens: [len(sentence_as_ids)],
             candidates: [hdn_list2id[hdn_list]]}
target_embeddings = sess.run(logits, feed_dict=feed_dict)
# Row 0 belongs to the only sentence in the batch; pick the score of each
# candidate HDN.
scores = [target_embeddings[0,hdn2id[hdn]] for hdn in hdn_list]
# Highest score wins (ties fall back to comparing the identifiers). The tuple
# is indexed instead of unpacked into `_` so that IPython's `_` (last output)
# is not rebound.
best_hdn = max(zip(scores, hdn_list))[1]
print(lemma)
print('\t' + '/'.join(id2synset[hdn].name() for hdn in hdn_list))
print('\t--> %s (%s)' %(id2synset[best_hdn].name(), 
                        id2synset[hdn2synset[best_hdn]].name()))

goal
goal
	whole.n.02/cognition.n.01/location.n.01/event.n.01
	--> whole.n.02 (goal.n.03)


In [37]:
# Inspect the raw result for another lemma: the tuple of HDN identifiers and
# the HDN identifier -> synset identifier mapping.
find_hdns('obstacle')

(('eng-30-00001930-n', 'eng-30-00002137-n'),
 {'eng-30-00002137-n': 'eng-30-05690269-n',
  'eng-30-00001930-n': 'eng-30-03839795-n'})