In [1]:
from easynmt import EasyNMT

In [2]:
# Translation direction (ISO 639-1 codes) and the matching Opus-MT checkpoint.
SRC_LANG, TRG_LANG = "en", "ru"

MODEL_TYPE = "opus-mt"  # EasyNMT backend name
MODEL_NAME = f"Helsinki-NLP/opus-mt-{SRC_LANG}-{TRG_LANG}"  # checkpoint id for this language pair

In [3]:
# Instantiate EasyNMT with the Opus-MT backend and eagerly load the checkpoint.
# NOTE(review): EasyNMT's OpusMT backend returns a (tokenizer, model) pair from
# `load_model`; unpack it instead of binding the whole tuple to a name that
# reads like a model.
model = EasyNMT(MODEL_TYPE)
tokenizer, translator = model.translator.load_model(MODEL_NAME)

In [4]:
# Sample English inputs; the cell displays their Russian translations.
# NOTE(review): "plague" in the toothbrush sentence was probably meant to be
# "plaque" — the model translates it literally ("чумы").
sentences = [
    "A caravan of fifty two camels slowly made its way through the desert.",
    "Wood may remain ten years in the water, but it will never become a crocodile.",
    "The horse raced past the barn fell.",
    "We haven't really spoken much since your return.",
    "I'll bring the wine glasses.",
    "This is the best tasting pear I've ever eaten.",
    "Dentists recommend to change toothbrushes every three months, because over time their bristles become worse at getting rid of plague, as well as accumulate microbes.",
    "Your boss was so impressed with your skills, she gave you a raise.",
]

translations = model.translate(sentences, source_lang=SRC_LANG, target_lang=TRG_LANG)
translations

To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ../aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)


['Караван из пятидесяти двух верблюдов медленно прошел через пустыню.',
 'Древесина может оставаться в воде десять лет, но она никогда не станет крокодилом.',
 'Лошадь бежала мимо амбара.',
 'Мы почти ничего не говорили с тех пор, как вы вернулись.',
 'Я принесу винные очки.',
 'Это лучший вкус груши, который я когда-либо ел.',
 'Дантисты рекомендуют менять зубные щетки каждые три месяца, потому что со временем их щетки становятся еще хуже, когда они избавляются от чумы, а также накапливают микробы.',
 'Твой босс был так впечатлен твоими навыками, что она дала тебе повышение.']

## Извлечение Encoder Attention

In [5]:
%load_ext autoreload
%autoreload 2

In [6]:
from utils.attn_extraction import get_attn_scores

In [7]:
# Reuse the Marian tokenizer and model that EasyNMT cached under MODEL_NAME, so
# the manual generation and hooks below act on the same objects EasyNMT uses.
# Keyed by MODEL_NAME (not a hard-coded id) so changing SRC_LANG/TRG_LANG works.
cached_checkpoint = model.translator.models[MODEL_NAME]
tokenizer = cached_checkpoint["tokenizer"]
translator = cached_checkpoint["model"]

In [8]:
import inspect

# Show the call signature of the first decoder layer's cross-attention module.
first_cross_attn = translator.model.decoder.layers[0].encoder_attn
inspect.signature(first_cross_attn.forward)

<Signature (hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]>

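PyTorch forward hooks do not receive keyword arguments ([issue](https://github.com/pytorch/pytorch/issues/35643)). The decoder layer seems to pass its inputs to `encoder_attn` as keyword arguments, so a forward hook registered on that module captures no input at all. (Since PyTorch 2.0, `register_forward_hook(..., with_kwargs=True)` also hands the keyword arguments to the hook.)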

* **TODO:** How should beam search be handled? The number of passes through a decoder layer, and hence the number of output "tokens" the hook records, changes with `beam_size`:

    | beam size       | 1  | 2  | 3  | 4  | 5  | 6  | ... | 10 |
    | --------------- | -- | -- | -- | -- | -- | -- | --- | -- |
    | output "tokens" | 20 | 21 | 21 | 23 | 23 | 23 | ... | 23 |


In [62]:
# Capture hook outputs for ONE example sentence (sentences[1]); the result is
# indexed by hook names such as "decoder_l5" in the cells below.
# NOTE(review): a bare string is passed here, while the loop further down passes
# a one-element list — confirm which input type get_attn_scores expects.
# NOTE(review): `a` belongs to sentences[1] only; do not reuse it for other sentences.
a = get_attn_scores(sentences[1], model, MODEL_NAME, SRC_LANG, TRG_LANG)

In [65]:
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

In [66]:
# Print how the first three sentences are split into subword tokens on the
# source side and on the generated (beam search, 5 beams) target side.
for sentence in sentences[:3]:
    tokenization = tokenizer([sentence], truncation=True, padding=True, return_tensors="pt")
    input_tokens = tokenizer.convert_ids_to_tokens(tokenization["input_ids"][0])
    with torch.no_grad():
        translated = translator.generate(**tokenization, num_beams=5)
    # Generation starts from the decoder start token (<pad>), so drop it.
    output_tokens = tokenizer.convert_ids_to_tokens(translated[0])[1:]

    print(f"Input ({len(input_tokens)} tokens):", *input_tokens)
    print(f"Output ({len(output_tokens)} tokens):", *output_tokens)
    print()

Input (16 tokens): ▁A ▁caravan ▁of ▁fifty ▁two ▁came ls ▁slowly ▁made ▁its ▁way ▁through ▁the ▁desert . </s>
Output (18 tokens): ▁Кар аван ▁из ▁пятидесят и ▁двух ▁в ер блюд ов ▁медленно ▁прошел ▁через ▁пуст ы ню . </s>

Input (18 tokens): ▁Wood ▁may ▁remain ▁ten ▁years ▁in ▁the ▁water , ▁but ▁it ▁will ▁never ▁become ▁a ▁crocodile . </s>
Output (23 tokens): ▁Д ре вес ина ▁может ▁оставаться ▁в ▁воде ▁десять ▁лет , ▁но ▁она ▁никогда ▁не ▁станет ▁к рок од ил ом . </s>

Input (11 tokens): ▁The ▁horse ▁race d ▁past ▁the ▁bar n ▁fell . </s>
Output (15 tokens): ▁Ло ша д ь ▁ бежал а ▁ мимо ▁а м бар а . </s>



In [11]:
def _unique_node_labels(prefix, tokens):
    """Return one node label per token, of the form ``"<prefix> <token>"``.

    A token that repeats within one sentence (e.g. two ``'а'`` pieces) would
    otherwise collapse into a single graph node. Its 2nd, 3rd, ... occurrence
    gets a ``" (k)"`` suffix, so every token position keeps its own node.
    First occurrences keep the plain label.
    """
    labels = []
    used = set()
    for tok in tokens:
        label = f"{prefix} {tok}"
        k = 1
        while label in used:
            k += 1
            label = f"{prefix} {tok} ({k})"
        used.add(label)
        labels.append(label)
    return labels


def _incidence_pairs(inc_mat):
    """Return the (row, col) index pairs of the non-zero entries of a 2-D ``inc_mat``.

    Accepts a torch tensor on any device, or anything ``np.asarray`` understands
    (NumPy arrays, nested lists). Pairs come out in row-major order.
    """
    if hasattr(inc_mat, "detach"):  # torch.Tensor -> NumPy, also moving off the GPU
        inc_mat = inc_mat.detach().cpu().numpy()
    rows, cols = np.nonzero(np.asarray(inc_mat))
    return list(zip(rows.tolist(), cols.tolist()))


def prep_graph(inc_mat, input_tokens, output_tokens):
    """Build a bipartite token-alignment graph and a two-column layout for it.

    Args:
        inc_mat: 2-D incidence matrix (torch tensor or array-like). A non-zero at
            ``[i, j]`` links output token ``i`` to input token ``j``.
        input_tokens: source-side tokens, one per column of ``inc_mat``.
        output_tokens: target-side tokens, one per row of ``inc_mat``.

    Returns:
        ``(G, pos)``. ``G`` is an undirected ``networkx.Graph``: source nodes
        (``"src <tok>"``) carry ``bipartite=0``, target nodes (``"tgt <tok>"``)
        carry ``bipartite=1``, and repeated tokens get a ``" (k)"`` suffix.
        ``pos`` maps each node to ``(x, y)``, with sources at x=1 and targets at
        x=2; the first token gets the largest y.

    Raises:
        IndexError: if ``inc_mat`` has a non-zero outside the token ranges.
    """
    src_nodes = _unique_node_labels("src", input_tokens)
    tgt_nodes = _unique_node_labels("tgt", output_tokens)

    G = nx.Graph()
    G.add_nodes_from(src_nodes, bipartite=0)
    G.add_nodes_from(tgt_nodes, bipartite=1)
    G.add_edges_from((tgt_nodes[i], src_nodes[j]) for i, j in _incidence_pairs(inc_mat))

    # Enumerate in reverse so the first token is drawn at the top of its column.
    pos = {}
    pos.update((node, (1, y)) for y, node in enumerate(reversed(src_nodes)))
    pos.update((node, (2, y)) for y, node in enumerate(reversed(tgt_nodes)))

    return G, pos

In [112]:
# For every sentence: tokenize, translate, capture the hook output of
# `layer_of_interest` for THAT sentence, and threshold it into an incidence
# matrix (collected in `incidence_matrices`, aligned with `sentences`).
thresh = 0.25  # cut-off applied after F.normalize

layer_of_interest = 5

hook_name = f"decoder_l{layer_of_interest}"
incidence_matrices = []
for sentence in sentences:
    tokenization = tokenizer([sentence], truncation=True, padding=True, return_tensors="pt")
    input_tokens = tokenizer.convert_ids_to_tokens(tokenization["input_ids"][0])
    with torch.no_grad():
        translated = translator.generate(**tokenization, num_beams=5)
    output_tokens = tokenizer.convert_ids_to_tokens(translated[0])

    # One capture per sentence, used both for the shape print and the threshold.
    head_attn = get_attn_scores([sentence], model, MODEL_NAME, SRC_LANG, TRG_LANG)[hook_name]
    print(head_attn.shape)
    print(input_tokens, len(input_tokens))
    print(output_tokens, len(output_tokens))

    # NOTE(review): captured shapes print as (1, n_steps, 512), which looks like
    # decoder hidden states rather than (tgt, src) attention weights. n_steps
    # also differs from len(output_tokens). F.normalize defaults to p=2, dim=1.
    # Confirm before reading inc_mat as a token alignment for prep_graph.
    attn = F.normalize(head_attn)
    inc_mat = (attn > thresh).int()
    incidence_matrices.append(inc_mat)

torch.Size([1, 20, 512])
['▁A', '▁caravan', '▁of', '▁fifty', '▁two', '▁came', 'ls', '▁slowly', '▁made', '▁its', '▁way', '▁through', '▁the', '▁desert', '.', '</s>'] 16
['<pad>', '▁Кар', 'аван', '▁из', '▁пятидесят', 'и', '▁двух', '▁в', 'ер', 'блюд', 'ов', '▁медленно', '▁прошел', '▁через', '▁пуст', 'ы', 'ню', '.', '</s>'] 19
torch.Size([1, 23, 512])
['▁Wood', '▁may', '▁remain', '▁ten', '▁years', '▁in', '▁the', '▁water', ',', '▁but', '▁it', '▁will', '▁never', '▁become', '▁a', '▁crocodile', '.', '</s>'] 18
['<pad>', '▁Д', 'ре', 'вес', 'ина', '▁может', '▁оставаться', '▁в', '▁воде', '▁десять', '▁лет', ',', '▁но', '▁она', '▁никогда', '▁не', '▁станет', '▁к', 'рок', 'од', 'ил', 'ом', '.', '</s>'] 24
torch.Size([1, 14, 512])
['▁The', '▁horse', '▁race', 'd', '▁past', '▁the', '▁bar', 'n', '▁fell', '.', '</s>'] 11
['<pad>', '▁Ло', 'ша', 'д', 'ь', '▁', 'бежал', 'а', '▁', 'мимо', '▁а', 'м', 'бар', 'а', '.', '</s>'] 16
torch.Size([1, 14, 512])
['▁We', '▁haven', "'", 't', '▁really', '▁spoken', '▁much', 