In [1]:
from main import CrossLingualPipeline
from embedding_transform import EmbeddingTransform
from transformers import BertForMaskedLM, BertTokenizer
from polyglot.text import Text, Word


import pickle
import os
import torch
import numpy as np

# Globals shared by the cells below: BERT's tokenizer and input-embedding table
# (used to embed English words) and the cross-lingual pipeline (used to get
# English and Dutch embeddings for fitting the transform).
_BERT_CHECKPOINT = 'bert-base-uncased'
bert_tokenizer, bert_model = (
    BertTokenizer.from_pretrained(_BERT_CHECKPOINT),
    BertForMaskedLM.from_pretrained(_BERT_CHECKPOINT),
)
clpl = CrossLingualPipeline()

epoch: 0 loss: 21.933106154203415




In [2]:
def split_dict(ratio=0.5, path='data/filtered_en_nl_dict.txt'):
    """Split the bilingual dictionary file into training and evaluation lines.

    The first ``int(n_lines * ratio)`` lines become the training split and the
    remaining lines the evaluation split; file order is preserved (no shuffling).
    Callers in this notebook treat each line as ``dutch<TAB>english``.

    Args:
        ratio: fraction of lines assigned to the training split.
        path: dictionary file to read (UTF-8).

    Returns:
        ``(training_lines, eval_lines)`` as lists of lines without newlines.
    """
    with open(path, 'r', encoding="utf-8") as f:
        all_lines = f.read().splitlines()
    # Process after the file is closed; nothing below needs the handle.
    n_train = int(len(all_lines) * ratio)
    return all_lines[:n_train], all_lines[n_train:]

def _most_common_pairs(training_sents, nl_words, en_words, fraction):
    """Keep only the pairs whose Dutch word is among the most frequent Dutch words.

    A Dutch word's frequency is the number of lowercased whitespace-separated
    tokens in ``training_sents`` equal to the lowercased word. The top
    ``int(len(nl_words) * fraction)`` distinct Dutch words are kept, each paired
    with the English word at its first occurrence in ``nl_words``.

    Returns:
        ``(selected_nl_words, selected_en_words)`` in most-common-first order.
    """
    from collections import Counter

    token_counts = Counter(token.lower() for line in training_sents for token in line.split())
    # Building the Counter from a dict collapses repeated Dutch words into one key.
    # Look up the lowercased word: the counts above are keyed on lowercased tokens.
    candidate_counts = Counter({word: token_counts.get(word.lower(), 0) for word in nl_words})
    top_words = [word for word, _ in candidate_counts.most_common(int(len(nl_words) * fraction))]

    # Index of each Dutch word's first occurrence (what list.index returns), built in O(n).
    first_index = {}
    for position, word in enumerate(nl_words):
        first_index.setdefault(word, position)

    chosen = [first_index[word] for word in top_words]
    return [nl_words[i] for i in chosen], [en_words[i] for i in chosen]


def train_transform_matrix(transform_type="SGD", save_fn="./models/stored_embeddings-{}.pkl", transform_subset=None, sgd_epochs=1):
    """Build an ``EmbeddingTransform`` fitted on Dutch/English dictionary pairs.

    Training pairs come from the training half of ``split_dict()``; each line
    is split on a tab into ``dutch`` (field 0) and ``english`` (field 1).

    Args:
        transform_type: forwarded to ``EmbeddingTransform``.
        save_fn: not used by this function; kept for backward compatibility.
        transform_subset: if truthy, only the pairs for the most frequent
            ``int(len(pairs) * transform_subset)`` distinct Dutch words are used.
        sgd_epochs: forwarded to ``EmbeddingTransform``.

    Returns:
        The ``EmbeddingTransform`` instance.
    """
    # With add_special_tokens=True BERT's tokenizer is expected to produce
    # [CLS] [MASK] [SEP]; positions 0 and 2 give the start/end token embeddings.
    special_ids = bert_tokenizer.encode("[MASK]", add_special_tokens=True)
    special_embs = bert_model.bert.embeddings.word_embeddings(torch.tensor([special_ids], device='cpu'))
    start_tok = special_embs[0, 0, :]
    end_tok = special_embs[0, 2, :]

    training_sents, _ = split_dict()
    en_words_to_embed = [line.split('\t')[1] for line in training_sents]
    nl_words_to_embed = [line.split('\t')[0] for line in training_sents]

    if transform_subset:
        nl_words_to_embed, en_words_to_embed = _most_common_pairs(
            training_sents, nl_words_to_embed, en_words_to_embed, transform_subset)

    english_embeddings = clpl.get_english_embeddings(en_words_to_embed)
    dutch_embeddings = clpl.get_dutch_embeddings(nl_words_to_embed)

    return EmbeddingTransform(transform_type, dutch_embeddings, english_embeddings,
                              str_tok=start_tok, end_tok=end_tok, sgd_epochs=sgd_epochs)

In [9]:
def _cosine_nearest(queries, candidates):
    """Return, for each row of ``queries``, the index of the most cosine-similar row of ``candidates``.

    Ties resolve to the lowest index. Zero-length vectors get similarity 0 to
    everything instead of NaN.
    """
    def _unit_rows(matrix):
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        return matrix / np.where(norms == 0, 1, norms)

    similarities = _unit_rows(queries) @ _unit_rows(candidates).T
    return similarities.argmax(axis=1)


def compare_embeddings(sgd_epochs=50):
    """Train a Dutch->English transform and print nearest-neighbour translation accuracy.

    Each distinct Dutch word of the evaluation split is embedded with polyglot,
    mapped through the learned transform, and matched to the cosine-nearest
    BERT input embedding among the distinct English words of that split. A
    prediction counts as correct if it is one of that Dutch word's translations
    in the evaluation split.

    Prints ``correct/total``, where total is the number of distinct Dutch words.
    """
    transform = train_transform_matrix(sgd_epochs=sgd_epochs)

    _, eval_sents = split_dict()
    # A Dutch word may have several English translations (and vice versa), so
    # score against the set of gold translations rather than list positions.
    gold_translations = {}
    for line in eval_sents:
        nl_word, en_word = line.split('\t')[0], line.split('\t')[1]
        gold_translations.setdefault(nl_word, set()).add(en_word)

    dutch_vocab = list(gold_translations)  # distinct Dutch words, first-seen order
    english_vocab = list(dict.fromkeys(line.split('\t')[1] for line in eval_sents))
    if not dutch_vocab:
        print('0/0')
        return

    dutch_vectors = np.stack([Word(word, language="nl").vector for word in dutch_vocab])

    # One id per word (same as encoding the word list); BertTokenizer maps words
    # outside its vocabulary to [UNK].
    with torch.no_grad():
        token_ids = torch.tensor(bert_tokenizer.convert_tokens_to_ids(english_vocab))
        english_vectors = bert_model.bert.embeddings.word_embeddings(token_ids).numpy()

    transform_matrix = transform.transform.detach().numpy()
    predictions = _cosine_nearest(dutch_vectors @ transform_matrix, english_vectors)

    correct_words = sum(
        english_vocab[best] in gold_translations[dutch_word]
        for dutch_word, best in zip(dutch_vocab, predictions)
    )
    print('{}/{}'.format(correct_words, len(dutch_vocab)))


In [11]:
# Fit the Dutch->English transform (training loss is logged per epoch) and
# print evaluation accuracy as correct/total.
compare_embeddings()

number of dutch embeddings: 384
epoch: 0 loss: 15.596981406211853
epoch: 1 loss: 9.251447260379791
epoch: 2 loss: 6.679577589035034
epoch: 3 loss: 5.1524423360824585
epoch: 4 loss: 4.097177952528
epoch: 5 loss: 3.3872907757759094
epoch: 6 loss: 2.835872545838356
epoch: 7 loss: 2.4036504477262497
epoch: 8 loss: 2.0670585930347443
epoch: 9 loss: 1.7946406453847885
epoch: 10 loss: 1.569934755563736
epoch: 11 loss: 1.384336806833744
epoch: 12 loss: 1.2289033457636833
epoch: 13 loss: 1.0948253273963928
epoch: 14 loss: 0.9826774969696999
epoch: 15 loss: 0.8869808912277222
epoch: 16 loss: 0.802455946803093
epoch: 17 loss: 0.7292896956205368
epoch: 18 loss: 0.6653983779251575
epoch: 19 loss: 0.6087809056043625
epoch: 20 loss: 0.5596024468541145
epoch: 21 loss: 0.5160383246839046
epoch: 22 loss: 0.47702304646372795
epoch: 23 loss: 0.44134124368429184
epoch: 24 loss: 0.409800972789526
epoch: 25 loss: 0.38152410089969635
epoch: 26 loss: 0.35578494891524315
epoch: 27 loss: 0.3324795011430979
epoch