In [1]:
from bertviz import head_view
#from transformers import BertTokenizer, BertModel

In [2]:
import os

In [3]:
# Adding module to sys path
import sys
sys.path.append("../MEDeA/")
# RNN imports
import medea
from medea import torch, numpy as np

In [4]:
from medea.models.composite_model import CompositeModel
from medea.training.model_trainer import ModelTrainer
from medea.inputs.data.build_and_embed.data_builder import DatasetBuilderEmbedder
from medea.inputs.data.read import MedeaDatasetReader, get_all_phones, get_all_phones_to_frequency
from medea.inputs.parameters.embeddings import EmbeddingParams, EmbeddingParamsOneLang
from medea.utils.shuffling import ShuffleType
from medea.utils import BatchInfo

In [5]:
%%javascript
// Register RequireJS module aliases so that `d3` and `jquery` resolve to these CDN builds.
// The URLs are protocol-relative ("//..."), so they follow the http/https scheme of the notebook page.
// NOTE(review): presumably needed by the bertviz head_view visualisation rendered later — confirm.
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

<IPython.core.display.Javascript object>

## Data

In [6]:
# Dataset and parameter-file locations, relative to this notebook.
path_params = "../MEDeA/datasets/file_parameters/default/"
path_data = "../MEDeA/datasets/romance_languages/cog/"
LANGS = ("es", "it", "la")

test_data = MedeaDatasetReader(folder_path=path_data, langs=list(LANGS), name="Triple", phonetized=True)

# Language groupings whose phones make up the shared vocabulary.
lang_pairs = [list(LANGS),
              ["it", "la"], ["es", "it"], ["es", "la"],
             ]
# Pseudo-tokens: end-of-word marker plus one language tag per language.
tokens = ["EOW", *LANGS]
phones = [p for p in set(get_all_phones(path_data, lang_pairs)) if p not in tokens]

all_words = sorted(phones)
vocab_dim = len(all_words) + len(tokens)

# Load a fresh parameter object per language. Assigning one shared instance to
# every language and mutating it in the loop would make all entries alias it.
per_lang_params = {}
for lang in LANGS:
    lang_params = EmbeddingParamsOneLang.from_json_file(f"{path_params}embed_1_parameters.json")
    lang_params.phone_embedding.phone_categories = list(all_words)
    lang_params.pseudotoken_embedding.list = list(tokens)
    per_lang_params[lang] = lang_params
embed_params = EmbeddingParams(per_lang_params)

# Used for finetuning - we keep all the pairs
train_set = DatasetBuilderEmbedder.build(data_list=[test_data], parameters=embed_params)

[INFO] By setting the padding length to -1, you mean that you do not want padding.
[INFO] 3 duplicate words present in set 'Triple' were removed.


In [79]:
# One example per batch, in dataset order (no shuffling) — presumably batch i is example i.
batches = train_set.to_batch(
    ShuffleType.NONE,
    batch_size=1,
    langs=["es", "it", "la"],
    lang_of_reference_for_shuffling="es",
)
batch_info = train_set.get_batch_info()

In [115]:
# Index of the Spanish example word 'enfermar' in the dataset
# (another example previously used: 'ira' at index 797).
(
    train_set
    .data["es"]
    .data_origin
    .index(['enfermar'])
)

759

In [117]:
# Select the batch holding the example word instead of hard-coding its index.
# Assumes batches were built unshuffled with batch_size=1 (see the batching
# cell), so the word's dataset index is also its batch index.
batch = batches[train_set.data["es"].data_origin.index(['enfermar'])]

In [118]:
# Sanity check: decode the selected batch back to phone strings.
# The language pair is set here so this cell does not rely on variables that
# are only defined further down the notebook.
in_lang, out_lang = "es", "it"
# Each side is decoded with its own language's vocabulary.
print([
    train_set[in_lang].ix_to_item[int(c)]
    for c in batch[in_lang][0]
])
print([
    train_set[out_lang].ix_to_item[int(c)]
    for c in batch[out_lang][0]
])

['es', 'ɛ', 'm', 'f', 'ɛ', 'ɾ', 'm', 'a', 'ɾ', 'EOW']
['it', 'i', 'n', 'f', 'i', 'r', 'm', 'a:', 'r', 'e', 'EOW']


## Model

In [147]:
num_head = 3
epoch = 8

transformer = True  # True: Transformer checkpoint; False: bidirectional encoder-decoder checkpoint
train = False       # True: also run a forward pass with gold targets to obtain attention (next cell)

# Trained Transformer runs, keyed by number of attention heads.
# (Older TACL runs: ~/Desktop/TACL_experiments/Experiment2/Transformer{num_head}head_16_18_1/)
TRANSFORMER_RUNS = {
    1: "~/Desktop/MEDeA/runs/Transformer1head_2020-09-23 16:23:00.862310",
    3: "~/Desktop/MEDeA/runs/Transformer3head_2020-09-23 16:27:46.844470",
}
RNN_RUN = "~/Desktop/TACL_experiments/Experiment1/EncDecBidir_16_36_0"

if transformer:
    # Fail loudly instead of silently loading a model whose head count
    # does not match `num_head` (used later to read per-head attention).
    if num_head not in TRANSFORMER_RUNS:
        raise ValueError(f"No Transformer run recorded for num_head={num_head}; "
                         f"available: {sorted(TRANSFORMER_RUNS)}")
    run_dir = TRANSFORMER_RUNS[num_head]
else:
    run_dir = RNN_RUN

model_path = os.path.expanduser(f"{run_dir}/results/experiment/models/epoch_{epoch}/")
model = CompositeModel.load(model_path)


In [148]:
# Predictions always come from `predict`. When `train` is set, the attention
# weights are instead taken from a forward pass given the gold target.
predictions, attn = model.predict(batch, batch_info)
if train:
    _, attn = model(batch, batch_info, gold_target={"decoder": batch})

In [149]:
in_lang = "es"
out_lang = "it"
pair = f"{in_lang}_{out_lang}"


def decode_ids(lang, ids):
    """Map a sequence of vocabulary indices to the strings of `lang`'s vocabulary."""
    return [train_set[lang].ix_to_item[int(c)] for c in ids]


def stack_heads(layer_attn, name, n_heads):
    """Stack the per-head maps `{name}/layer0_head{h}` (h < n_heads) along a new dim 0."""
    return torch.stack([layer_attn[f"{name}/layer0_head{h}"] for h in range(n_heads)])


# Tokens of the first (and only, batch_size=1) example in the batch.
local_prediction = decode_ids(out_lang, predictions["decoder"][pair][0][0])
local_target = decode_ids(out_lang, batch[out_lang][0])
local_input = decode_ids(in_lang, batch[in_lang][0])

# The forward pass returns the per-pair attention dict directly, while `predict`
# needs an extra [0] — presumably one dict per example/beam (TODO confirm).
pair_attn = attn[pair] if train else attn[pair][0]

# Use the explicit model-type flag rather than sniffing "Transformer" in the path.
if transformer:
    self_heads = stack_heads(pair_attn, "decoder_self_attn", num_head)
    enc_heads = stack_heads(pair_attn, "encoder_attn", num_head)
    dec_heads = stack_heads(pair_attn, "decoder_attn", num_head)
    if train:
        # Assumes each head map keeps a leading batch dim: (heads, batch, ...) -> (batch, heads, ...).
        local_self_attn = self_heads.transpose(0, 1)
        local_enc_attn = enc_heads.transpose(0, 1)
        local_attn = dec_heads.transpose(0, 1)
    else:
        local_self_attn = self_heads.unsqueeze(0)
        # NOTE(review): encoder heads are transposed while the decoder heads are
        # unsqueezed; kept as-is — confirm the resulting shape suits head_view.
        local_enc_attn = enc_heads.transpose(0, 1)
        local_attn = dec_heads.unsqueeze(0)
else:
    local_attn = pair_attn["decoder_attn"].unsqueeze(0)


In [150]:
# Which attention map to render with bertviz:
#   "encoder"      - encoder attention, input tokens vs. input tokens
#   "decoder_self" - decoder self-attention, predicted tokens vs. predicted tokens
#   "decoder"      - decoder attention, predicted tokens vs. input tokens
# For the non-Transformer model only `local_attn` is computed, so use "decoder" there.
attention_view = "decoder_self"

if attention_view == "encoder":
    print("Encoder attention")
    head_view([local_enc_attn], local_input, local_input, prettify_tokens=False)
elif attention_view == "decoder_self":
    print("Decoder self attention")
    head_view([local_self_attn], local_prediction, local_prediction, prettify_tokens=False)
elif attention_view == "decoder":
    print("Decoder attention")
    head_view([local_attn], local_prediction, local_input, prettify_tokens=False)
else:
    raise ValueError(f"Unknown attention_view {attention_view!r}; "
                     "expected 'encoder', 'decoder_self' or 'decoder'.")



Decoder self attention


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>