In [1]:
from easynmt import EasyNMT

In [2]:
# Translation direction (language codes) and the EasyNMT backend to load.
SRC_LANG, TRG_LANG = "en", "ru"
MODEL_TYPE = "opus-mt"

# Hugging Face checkpoint for this direction, e.g. "Helsinki-NLP/opus-mt-en-ru".
MODEL_NAME = "Helsinki-NLP/opus-mt-{}-{}".format(SRC_LANG, TRG_LANG)

In [3]:
model = EasyNMT(MODEL_TYPE)
# EasyNMT's OpusMT backend returns a (tokenizer, model) pair from `load_model`
# and caches it in `model.translator.models`; unpack it rather than binding the
# whole tuple to a name that suggests a single model object.
tokenizer, translator = model.translator.load_model(MODEL_NAME)

In [4]:
# Probe sentences used throughout the notebook (indexed below by position).
# NOTE(review): "plague" in the toothbrush sentence is probably a typo for
# "plaque"; the model faithfully translates it as "чумы". Kept unchanged so the
# recorded outputs stay reproducible.
sentences = [
    "A caravan of fifty two camels slowly made its way through the desert.",
    "Wood may remain ten years in the water, but it will never become a crocodile.",
    "The horse raced past the barn fell.",
    "We haven't really spoken much since your return.",
    "I'll bring the wine glasses.",
    "This is the best tasting pear I've ever eaten.",
    "Dentists recommend to change toothbrushes every three months, because over time their bristles become worse at getting rid of plague, as well as accumulate microbes.",
    "Your boss was so impressed with your skills, she gave you a raise.",
]

# beam_size=1 means greedy decoding.
translations = model.translate(
    sentences,
    source_lang=SRC_LANG,
    target_lang=TRG_LANG,
    beam_size=1,
)
translations

['Караван из пятидесяти двух верблюдов медленно проехал свой путь по пустыне.',
 'Древесина может оставаться в воде десять лет, но она никогда не станет крокодилом.',
 'Лошадь бежала мимо сарая.',
 'Мы не особо много говорили с тех пор, как вы вернулись.',
 'Я принесу винные очки.',
 'Это лучший вкус груши, который я когда-либо ел.',
 'Дантисты рекомендуют менять зубные щетки каждые три месяца, потому что со временем их щетки становятся все хуже, когда они избавляются от чумы, а также накапливают микробы.',
 'Твой босс был так впечатлен твоими навыками, что она дала тебе повышение.']

## Извлечение Encoder Attention

In [5]:
# Reload edited local modules (such as utils.attn_extraction, imported below)
# automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [6]:
from utils.attn_extraction import get_attn_scores

In [7]:
# Reuse the tokenizer/model EasyNMT already cached for MODEL_NAME instead of
# repeating the checkpoint name as a literal, which would silently go stale
# if SRC_LANG / TRG_LANG were changed in the config cell.
loaded = model.translator.models[MODEL_NAME]
tokenizer = loaded["tokenizer"]
translator = loaded["model"]

In [8]:
import inspect

# Which arguments does the first decoder layer's encoder (cross-)attention accept?
first_encoder_attn = translator.model.decoder.layers[0].encoder_attn
inspect.signature(first_encoder_attn.forward)

<Signature (hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]>

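PyTorch forward hooks do not receive keyword arguments ([issue](https://github.com/pytorch/pytorch/issues/35643)). The decoder layer seems to pass the input to the attention module as a keyword argument, so the hook captures no input at all. (Since PyTorch 2.0, `register_forward_hook(..., with_kwargs=True)` also passes keyword arguments to the hook.)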

* **TODO:** how should beam search be handled? The number of passes through a decoder layer depends on `beam_size`:

    | beam size       | 1  | 2  | 3  | 4  | 5  | 6  | ... | 10 |
    | --------------- | -- | -- | -- | -- | -- | -- | --- | -- |
    | output "tokens" | 20 | 21 | 21 | 23 | 23 | 23 | ... | 23 |


In [9]:
# Cross-attention scores for the second probe sentence.
probe_sentence = sentences[1]
all_attns = get_attn_scores(probe_sentence, model, MODEL_NAME, SRC_LANG, TRG_LANG)

## Visualized examples

In [13]:
from bertviz import head_view
import torch

In [14]:
def tokensInputOutput(sentence, tokenizer, translator, num_beams=1):
    """Tokenize ``sentence``, translate it, and return both token sequences.

    Args:
        sentence: Source-language sentence to translate.
        tokenizer: Tokenizer paired with ``translator``. It must be callable on a
            list of strings and provide ``convert_ids_to_tokens``.
        translator: Seq2seq model exposing ``generate``.
        num_beams: Beam width passed to ``generate``. The default of 1 (greedy)
            matches the original behaviour. The number of generated tokens can
            change with the beam width, so this presumably has to match the
            beam width used when extracting attention (TODO confirm against
            ``get_attn_scores``).

    Returns:
        ``(input_tokens, output_tokens)``. The first element holds the source
        tokens, including any special tokens the tokenizer adds. The second
        holds the generated target tokens, with the leading decoder-start token
        removed.
    """
    encoded = tokenizer([sentence], truncation=True, padding=True, return_tensors="pt")
    input_tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])

    # Only the forward passes need gradient tracking disabled.
    with torch.no_grad():
        generated = translator.generate(**encoded, num_beams=num_beams)

    # The first position is the decoder start token (<pad> for Marian models),
    # not a generated token, so drop it.
    output_tokens = tokenizer.convert_ids_to_tokens(generated[0])[1:]
    return input_tokens, output_tokens

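### Example of good translation

NOTE(review): for this garden-path sentence the model outputs «Лошадь бежала мимо сарая.» ("The horse ran past the barn"), which drops the main verb "fell". Confirm that this is the intended "good" example.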

In [15]:
sentence = sentences[2]  # "The horse raced past the barn fell."

# Source/target tokens for the visualization axes, then the cross-attention scores.
input_tokens, output_tokens = tokensInputOutput(
    sentence=sentence, tokenizer=tokenizer, translator=translator
)
attn = get_attn_scores(sentence, model, MODEL_NAME, SRC_LANG, TRG_LANG)

In [16]:
# Recent bertviz versions index the per-layer attention sequence, so pass a
# real list rather than a dict view. This assumes `attn` iterates in layer
# order (TODO confirm in get_attn_scores).
# layer=4 is the (0-based) layer shown initially.
head_view(
    cross_attention=list(attn.values()),
    encoder_tokens=input_tokens,
    decoder_tokens=output_tokens,
    layer=4,
)

<IPython.core.display.Javascript object>

### Example of bad translation

In [17]:
# "I'll bring the wine glasses." comes out as "винные очки" (spectacles).
sentence = sentences[4]

# Source/target tokens for the visualization axes, then the cross-attention scores.
input_tokens, output_tokens = tokensInputOutput(
    sentence=sentence, tokenizer=tokenizer, translator=translator
)
attn = get_attn_scores(sentence, model, MODEL_NAME, SRC_LANG, TRG_LANG)

In [18]:
# Recent bertviz versions index the per-layer attention sequence, so pass a
# real list rather than a dict view. This assumes `attn` iterates in layer
# order (TODO confirm in get_attn_scores).
# layer=4 is the (0-based) layer shown initially.
head_view(
    cross_attention=list(attn.values()),
    encoder_tokens=input_tokens,
    decoder_tokens=output_tokens,
    layer=4,
)

<IPython.core.display.Javascript object>