In [8]:
import torch
from bertviz import head_view
from transformers import MT5TokenizerFast, MT5ForConditionalGeneration
from datasets import load_dataset

from overfit_attention import prepare_text_input


In [None]:
# Settings shared by every per-article model.
# Each id indexes dataset["train"] and also names the local checkpoint directory
# ("./<id>") that is loaded for that article inside the loop below.
shortest_article_ids = [260, 1301, 2088, 665, 1572, 436, 1887, 1422, 1506, 474]

dataset = load_dataset("dennlinger/klexikon")
tokenizer = MT5TokenizerFast.from_pretrained("google/mt5-small")

for idx in shortest_article_ids:
    # Load index-specific model
    model = MT5ForConditionalGeneration.from_pretrained(f"./{idx}")

    sample = dataset["train"][idx]
    # Prepare with sensible border tokens. Decoder needs to start with <pad>
    wiki_text = f"<extra_id_0> {prepare_text_input(sample['wiki_sentences'])}"
    klexikon_text = f"<pad> {prepare_text_input(sample['klexikon_sentences'])}"

    # Prepare forward pass: encoder sees the Wikipedia text, decoder is fed the
    # Klexikon text (teacher forcing), so attentions cover the full target sequence.
    model_inputs = tokenizer(wiki_text, return_tensors="pt")
    decoder_inputs = tokenizer(klexikon_text, return_tensors="pt")
    model_inputs["decoder_input_ids"] = decoder_inputs["input_ids"]

    # Inference only: disable autograd so the forward pass does not retain the
    # computation graph for the logits and every returned attention tensor.
    with torch.no_grad():
        result = model(
            input_ids=model_inputs["input_ids"],
            attention_mask=model_inputs["attention_mask"],
            decoder_input_ids=model_inputs["decoder_input_ids"],
            output_attentions=True,
            # Kept so that result.loss stays populated; the loss is not used in this cell.
            labels=model_inputs["decoder_input_ids"],
        )

    # Check predicted tokens for sanity check (greedy argmax over the vocabulary axis)
    predicted_ids = torch.argmax(result.logits.to("cpu"), dim=-1)
    print(tokenizer.decode(predicted_ids[0]))


Using custom data configuration dennlinger--klexikon-33d8b47837d0742e
Reusing dataset json (/home/dennis/.cache/huggingface/datasets/json/dennlinger--klexikon-33d8b47837d0742e/0.0.0/c90812beea906fcffe0d5e3bb9eba909a80a998b5f88e9f8acbd320aa91acfde)


  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:

# Render the cross-attention heads with bertviz.
# NOTE: `result` and `model_inputs` are module-level names assigned inside the loop of
# an earlier cell, so this view reflects whichever article that loop processed last.
encoder_tokens = tokenizer.convert_ids_to_tokens(model_inputs["input_ids"][0])
decoder_tokens = tokenizer.convert_ids_to_tokens(model_inputs["decoder_input_ids"][0])

head_view(
    cross_attention=result.cross_attentions,
    encoder_tokens=encoder_tokens,
    decoder_tokens=decoder_tokens,
)