In [1]:
import torch
import warnings
warnings.filterwarnings("ignore", message='source.*has changed')
import numpy as np
from transformers import *
from language_models_utils import load_languagemodel, get_logprobability_word
from vectors_utils import nearest_neighbors, cosine_similarity
from interpretation_model import get_expected_and_lexical_vectors, combine_expected_and_lexical

cuda = False

### Load language model

In [2]:
# model_type : LSTM, or BERT-large/BERT-base
model_type = 'LSTM'
if model_type.startswith('BERT'):
    from transformers import *
language_model, word_emb_matrix, vocab = load_languagemodel(model_type, model_dir = '../LSTM-LM/', cuda = cuda)

### Specify context and target word

In [3]:
context = ' Last night I was bored , but I found a very nice show to watch on TV '
target = 'show'
index_target = context.split().index(target)

In [4]:
expected, lexical = get_expected_and_lexical_vectors(language_model, vocab, context, index_target, model_type = model_type)

### Specify parameters for combination and combine

In [13]:
combination_type = 'delta'
alpha = 0.5
interpretation_combined = combine_expected_and_lexical(expected, lexical, combination_type =combination_type, alpha_param = alpha)

### Check outputs

In [15]:
expected_nns = nearest_neighbors(expected, word_emb_matrix, vocab, model_type = model_type, n=10, into_words = True)
lexical_nns = nearest_neighbors(lexical, word_emb_matrix, vocab, model_type = model_type, n=10, into_words = True)
combined_nns = nearest_neighbors(combined, word_emb_matrix, vocab, model_type = model_type, n=10, into_words = True)

cosine_exp_lex = cosine_similarity(expected, lexical)
cosine_exp_interpretation = cosine_similarity(expected, interpretation_combined)
surprisal_word = - get_logprobability_word(language_model, context, index_target, vocab, cuda =cuda, model_type = model_type)
print('Closest words to expectations:', expected_nns)
print('Closest words to lexical information:', lexical_nns)
print('Closest words to interpretation:', expected_nns)

print('Surprisal of the word given the context:', surprisal_word)
print('Cosine between expected and lexical representation:', cosine_exp_lex)
print('Cosine between expectations and interpretation:', cosine_exp_interpretation)

Closest words to expectations: ['picture', 'man', 'movie', 'film', 'place', 'girl', 'book', 'game', 'time', 'bird']
Closest words to lexical information: ['show', 'shows', 'showed', 'showing', 'demonstrate', 'Show', 'display', 'exhibit', 'reveal', 'indicate']
Closest words to interpretation: ['picture', 'man', 'movie', 'film', 'place', 'girl', 'book', 'game', 'time', 'bird']
Surprisal of the word given the context: 6.0083666
Cosine between expected and lexical representation: 0.18985112011432648
Cosine between expectations and interpretation: 0.8976683020591736
