In [1]:
from bertviz import head_view
#from transformers import BertTokenizer, BertModel

In [2]:
# Adding module to sys path
import sys
sys.path.append("../MEDeA/")
# RNN imports
import medea
from medea import torch, numpy as np

In [3]:
from medea.models.composite_model import CompositeModel
from medea.training.model_trainer import ModelTrainer
from medea.inputs.data.build_and_embed.data_builder import DatasetBuilderEmbedder
from medea.inputs.data.read import MedeaDatasetReader, get_all_phones, get_all_phones_to_frequency
from medea.inputs.parameters.embeddings import EmbeddingParams, EmbeddingParamsOneLang
from medea.utils.shuffling import ShuffleType
from medea.utils import BatchInfo

In [4]:
%%javascript
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

<IPython.core.display.Javascript object>

## Data

In [15]:
path_params = "../MEDeA/datasets/file_parameters/default/"
path_data = "../MEDeA/datasets/romance_languages/cog/"
test_data = MedeaDatasetReader(folder_path=path_data, langs=["es", "it", "la"], name="Triple", phonetized=True)

lang_pairs = [["es", "it", "la"],
              ["it", "la"], ["es", "it"], ["es", "la"],
             ]
tokens = ["EOW", "es", "it", "la"]
phones = [p for p in set(get_all_phones(path_data, lang_pairs)) if p not in tokens]

all_words = sorted(phones)
vocab_dim = len(all_words) + len(tokens)

embed_1_params = EmbeddingParamsOneLang.from_json_file(f"{path_params}embed_1_parameters.json")
embed_params = {}
for lang in ["es", "it", "la"]:
    embed_params[lang] = embed_1_params
    embed_params[lang].phone_embedding.phone_categories = list(all_words)
    embed_params[lang].pseudotoken_embedding.list = tokens
embed_params = EmbeddingParams(embed_params)

# Used for finetuning - we keep all the pairs
train_set = DatasetBuilderEmbedder.build(data_list=[test_data], parameters=embed_params)

[INFO] By setting the padding length to -1, you mean that you do not want padding.
[INFO] 3 duplicate words present in set 'Triple' were removed.


In [16]:
batches = train_set.to_batch(ShuffleType.DESCENDING, batch_size=1, 
                             langs= ["es", "it", "la"], lang_of_reference_for_shuffling="es")
batch_info = train_set.get_batch_info()

In [114]:
batch = batches[10]

## Model

In [139]:
num_head = 4
epoch = 20
model_path = "/Users/cfourrie/Desktop/TACL_experiments/Experiment2/" + \
            f"Transformer{num_head}head_16_18_1/results/experiment/models/epoch_{epoch}/"
model = CompositeModel.load(model_path)


In [140]:

predictions, attn = model.predict(batch, batch_info)

In [141]:
in_lang = "la"
out_lang = "la"

local_prediction = [
    train_set[out_lang].ix_to_item[int(c)]
    for c in predictions["decoder"][f"{in_lang}_{out_lang}"][0][0]
]
local_target = [
    train_set[out_lang].ix_to_item[int(c)]
    for c in batch[out_lang][0]
]
local_input = [
    train_set[in_lang].ix_to_item[int(c)]
    for c in batch[in_lang][0]
]

local_attn = torch.stack(
    [attn[f"{in_lang}_{out_lang}"][0][f"decoder_attn/layer0_head{h}"]
     for h in range(num_head)]).unsqueeze(0)
local_self_attn = torch.stack(
    [attn[f"{in_lang}_{out_lang}"][0][f"decoder_self_attn/layer0_head{h}"]
     for h in range(num_head)]).unsqueeze(0)


In [144]:
print(local_self_attn)

tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
          [5.2877e-01, 4.7123e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
          [5.2276e-01, 4.7723e-01, 3.3665e-06, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
          [6.1307e-01, 3.8692e-01, 2.5238e-07, 1.1609e-05, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
          [5.8275e-01, 4.1653e-01, 5.1633e-05, 6.6152e-04, 9.2256e-06,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0

In [143]:
self = True
if self:
    print("Decoder self attention")
    head_view([local_self_attn], local_prediction, local_prediction, prettify_tokens=False)
else:
    print("Decoder attention")
    head_view([local_attn], local_prediction, local_input, prettify_tokens=False)



Decoder self attention


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>