In [1]:
import sys
sys.path.append('../')
from transformers import BertTokenizer, BertModel
import torch
import itertools
from probe.load_data import WordInspectionDataset
from scipy.spatial.distance import cosine
from statistics import mean 



In [2]:
def main(data_path='vec_sim_test.txt', model_name='bert-large-uncased', batch_size=20):
    """Run the whole probe end to end and return the per-idiom similarity metrics.

    Loads a BERT tokenizer/encoder pair, embeds every sentence of ``data_path``
    through ``WordInspectionDataset``, and computes
    ``calculate_similarity_metrics`` for each figurative sentence.

    Args:
        data_path: probe file read by ``WordInspectionDataset``.
        model_name: Hugging Face checkpoint used for BOTH tokenizer and model.
        batch_size: batch size for feature extraction (arbitrarily chosen).

    Returns:
        list of dicts, one per figurative sentence, in dataset order.
    """
    tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=True)
    feature_extraction_model = BertModel.from_pretrained(model_name)

    pdata = WordInspectionDataset(data_path, tokenizer)
    dataset = pdata.get_data()
    # The third return value (indices) is only needed for sentence-level
    # aggregation, which the word-level metrics below do not use.
    embedding_outputs, encoded_inputs, _ = pdata.bert_word_embeddings(
        feature_extraction_model, pdata.get_encoded(), batch_size
    )

    idiom_sentence_indexes = get_idiom_sentences(dataset)

    # Previously computed and discarded; return it so callers get a result.
    return [
        calculate_similarity_metrics(idx, tokenizer, dataset, embedding_outputs, encoded_inputs)
        for idx in idiom_sentence_indexes
    ]
    
    

In [3]:
def calculate_similarity_metrics(idiom_sent_index, tokenizer, dataset, embedding_outputs, encoded_inputs):
    """Compare the idiom word's contextual embedding with its literal uses and paraphrases.

    Related sentences are those sharing the idiom example's ``pair_id``:
      * literal usages: same ``word``, different ``sentence_id``;
      * paraphrases:    different ``word``.

    Args:
        idiom_sent_index: position in ``dataset`` of the figurative example.
        tokenizer: tokenizer used to decode ``encoded_inputs`` rows.
        dataset: sequence of examples with ``pair_id``/``word``/``sentence_id``/``sentence``.
        embedding_outputs, encoded_inputs: per-sentence token embeddings and ids,
            indexed by dataset position.

    Returns:
        dict with ``sentence_id``, ``sentence``, ``word``, ``paraphrase_word``
        (``None`` when the pair has no paraphrase sentence) and
        ``cosine_similarities``, which has four average cosine similarities:
        fig_to_literal, literal_to_literal, fig_to_paraphrase and
        literal_to_paraphrase.
    """
    idiom_ex = dataset[idiom_sent_index]

    def embed(index):
        # Target-word embedding for the sentence at `index`.
        return get_word_embedding(tokenizer, dataset, embedding_outputs, encoded_inputs, index)

    # Reuse the shared selection helper instead of duplicating its filters.
    literal_usage_sents, paraphrase_sents = get_sentences_for_idiom_sentence(dataset, idiom_ex)

    idiom_word_embedding = embed(idiom_sent_index)
    literal_usage_embeddings = [embed(i) for i in literal_usage_sents]
    paraphrase_embeddings = [embed(i) for i in paraphrase_sents]

    cosine_similarity_metrics = {
        'fig_to_literal': calculate_cosine_similarity_average([idiom_word_embedding], literal_usage_embeddings),
        'literal_to_literal': calculate_cosine_similarity_average(literal_usage_embeddings),
        'fig_to_paraphrase': calculate_cosine_similarity_average([idiom_word_embedding], paraphrase_embeddings),
        'literal_to_paraphrase': calculate_cosine_similarity_average(literal_usage_embeddings, paraphrase_embeddings),
    }

    return {
        'sentence_id': idiom_ex.sentence_id,
        'sentence': idiom_ex.sentence,
        'word': idiom_ex.word,
        # Guard: indexing [0] on an empty list raised IndexError before.
        'paraphrase_word': dataset[paraphrase_sents[0]].word if paraphrase_sents else None,
        'cosine_similarities': cosine_similarity_metrics,
    }

In [4]:
def get_idiom_sentences(dataset):
    """Return the positions (in dataset order) of examples whose ``figurative`` flag is truthy."""
    figurative_positions = []
    for position, example in enumerate(dataset):
        if example.figurative:
            figurative_positions.append(position)
    return figurative_positions

In [5]:
def get_sentences_for_idiom_sentence(dataset, idiom_sent):
    """Split the examples sharing ``idiom_sent.pair_id`` into literal usages and paraphrases.

    Literal usages have the same ``word`` as ``idiom_sent`` but a different
    ``sentence_id`` (so the idiom sentence itself is excluded). Paraphrases
    have a different ``word``.

    Returns:
        (literal_usage_positions, paraphrase_positions) — lists of dataset positions.
    """
    literal_positions = []
    paraphrase_positions = []
    for position, example in enumerate(dataset):
        if example.pair_id != idiom_sent.pair_id:
            continue
        if example.word != idiom_sent.word:
            paraphrase_positions.append(position)
        elif example.sentence_id != idiom_sent.sentence_id:
            literal_positions.append(position)
    return (literal_positions, paraphrase_positions)

In [6]:
def get_word_embedding(tokenizer, dataset, embedding_outputs, encoded_inputs, dataset_index):
    """Return the contextual embedding of the example's target word (``ex.word[0]``).

    If the word appears verbatim as a single decoded token, the embedding row of
    its first occurrence is returned unchanged. Otherwise the word is run
    through ``tokenizer.tokenize`` (handling WordPiece splits and casing), and
    the mean of the embedding rows of the first matching token span is returned.

    NOTE(review): assumes rows of ``embedding_outputs``/``encoded_inputs`` are in
    dataset order — confirm against the ``indices`` returned by
    ``bert_word_embeddings``.

    Raises:
        ValueError: if the word cannot be located in the encoded sentence.
    """
    ex = dataset[dataset_index]
    target_word = ex.word[0]
    decoded_tokens = tokenizer.convert_ids_to_tokens(encoded_inputs[dataset_index].tolist())
    sentence_embeddings = embedding_outputs[dataset_index]

    # Fast path: whole word is one vocabulary token.
    if target_word in decoded_tokens:
        return sentence_embeddings[decoded_tokens.index(target_word)]

    # Slow path: out-of-vocabulary words are split into several WordPieces.
    word_pieces = tokenizer.tokenize(target_word)
    span = len(word_pieces)
    if span:
        for start in range(len(decoded_tokens) - span + 1):
            if decoded_tokens[start:start + span] == word_pieces:
                return sentence_embeddings[start:start + span].mean(dim=0)

    raise ValueError(
        f"target word {target_word!r} (pieces {word_pieces!r}) not found in "
        f"encoded sentence {dataset_index}: {decoded_tokens!r}"
    )

In [7]:
def calculate_cosine_similarity_average(embeddings_1, embeddings_2=None):
    """Average cosine similarity (``1 - scipy cosine distance``) over embedding pairs.

    Args:
        embeddings_1: sequence of vectors.
        embeddings_2: optional sequence of vectors. When given (even if empty),
            pairs are the cross product ``embeddings_1 x embeddings_2``. When
            ``None``, pairs are all unordered pairs of distinct positions within
            ``embeddings_1``.

    Returns:
        The mean similarity, or ``float('nan')`` when there are no pairs (e.g. an
        empty group, or fewer than two vectors in the within-group case).
    """
    # `is not None` rather than truthiness: an empty list must mean "no pairs",
    # not silently fall back to within-group pairs. Truthiness of a
    # multi-element tensor also raises.
    if embeddings_2 is not None:
        embedding_pairs = itertools.product(embeddings_1, embeddings_2)
    else:
        embedding_pairs = itertools.combinations(embeddings_1, 2)

    cosine_similarities = [1 - cosine(embedding_1, embedding_2) for embedding_1, embedding_2 in embedding_pairs]
    if not cosine_similarities:
        # statistics.mean([]) would raise StatisticsError.
        return float('nan')
    return mean(cosine_similarities)

In [8]:
# Tokenizer and encoder must come from the same checkpoint, so name it once.
MODEL_NAME = "bert-large-uncased"

tokenizer = BertTokenizer.from_pretrained(MODEL_NAME, do_lower_case=True)
feature_extraction_model = BertModel.from_pretrained(MODEL_NAME)
batch_size = 20  # totally arbitrarily chosen

In [9]:
# Parse the probe file and run every sentence through BERT for token-level embeddings.
pdata = WordInspectionDataset('vec_sim_test.txt', tokenizer)
dataset = pdata.get_data()
embedding_outputs, encoded_inputs, indices = pdata.bert_word_embeddings(
    feature_extraction_model,
    pdata.get_encoded(),
    batch_size,
)

# Sentence-level vectors: token embeddings aggregated with torch.mean.
sentence_embeddings = pdata.aggregate_sentence_embeddings(
    embedding_outputs,
    encoded_inputs,
    indices,
    aggregation_metric=torch.mean,
)

# Dataset positions of the figurative (idiomatic) example sentences.
idiom_sentence_indexes = get_idiom_sentences(dataset)

HBox(children=(IntProgress(value=1, bar_style='info', description='Feature extraction', max=1, style=ProgressS…

processed 20/60 sentences, current max sentence length 17
processed 40/60 sentences, current max sentence length 17
processed 60/60 sentences, current max sentence length 21



In [10]:
# One similarity-metrics dict per figurative sentence, in dataset order.
cosine_metrics = [
    calculate_similarity_metrics(idiom_idx, tokenizer, dataset, embedding_outputs, encoded_inputs)
    for idiom_idx in idiom_sentence_indexes
]

In [11]:
# Display the per-idiom results: sentence info plus the four average cosine similarities.
cosine_metrics

[{'sentence_id': 1,
  'sentence': ['the', 'cat', 'is', 'out', 'of', 'the', 'bag', '.'],
  'word': ['cat'],
  'paraphrase_word': ['secret'],
  'cosine_similarities': {'fig_to_literal': 0.8095400002267625,
   'literal_to_literal': 0.881085894174046,
   'fig_to_paraphrase': 0.5255098730325699,
   'literal_to_paraphrase': 0.5296435475349426}},
 {'sentence_id': 21,
  'sentence': ['soon', 'we', "'", 're', 'going', 'to', 'hit', 'the', 'sack'],
  'word': ['sack'],
  'paraphrase_word': ['bed'],
  'cosine_similarities': {'fig_to_literal': 0.3896622094843123,
   'literal_to_literal': 0.6366828166776233,
   'fig_to_paraphrase': 0.3006304562091827,
   'literal_to_paraphrase': 0.47102577686309816}},
 {'sentence_id': 41,
  'sentence': ['it', 'is', 'time', 'to', 'bite', 'the', 'bullet'],
  'word': ['bullet'],
  'paraphrase_word': ['situation'],
  'cosine_similarities': {'fig_to_literal': 0.6265564627117581,
   'literal_to_literal': 0.8031015197436014,
   'fig_to_paraphrase': 0.37428863942623136,
   'l

In [13]:
# Sanity check: one aggregated vector per sentence (60 x 1024 in the run shown).
sentence_embeddings.shape

torch.Size([60, 1024])

In [14]:
# Token-id matrix: one row per sentence, [PAD]-ed to the longest one (21 tokens in this run).
encoded_inputs.shape

torch.Size([60, 21])

In [17]:
# Decode each encoded row back to text to eyeball tokenization and padding;
# rows are separated by a blank line.
for token_ids in encoded_inputs:
    print(tokenizer.decode(token_ids.tolist()), end="\n\n")

[CLS] the cat is out of the bag. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

[CLS] we saw a cat yesterday. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

[CLS] the cat is a afraid of dogs. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

[CLS] there was a cat on the roof. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

[CLS] how many hours a day will a cat sleep. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

[CLS] she thought she heard a cat crying. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

[CLS] nobody wanted to keep the cat. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

[CLS] the cat had bright blue eyes. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

[CLS] he asked for a cat for his birthday. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

[