## Imports

In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import functools
import itertools

In [4]:
from src.utils.logits import TokenizerDebugger

## Constants

In [5]:
# One checkpoint name is shared by the language model and its tokenizer.
model_name = "distilgpt2"
tokenizer_name = model_name

# Q/A style prompt; it ends right after "A:" so the continuation is the answer.
init_sentence = (
    "Q: What is the name of the president of the United States and when was he born?\n"
    "A:"
)

# Keyword arguments that are unpacked into `model.generate(...)`.
generation_inputs = dict(
    max_length=50,
    num_beams=5,
    early_stopping=True,
    no_repeat_ngram_size=2,
    num_return_sequences=3,
)

In [6]:
# Load the tokenizer first: the model is configured to use the tokenizer's
# end-of-sequence id as its padding id.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    pad_token_id=tokenizer.eos_token_id,
)

In [23]:
# Helper wrapping the tokenizer; used later to turn logits into readable tokens.
tokenizer_debugger = TokenizerDebugger(tokenizer)

# Encode the prompt as a PyTorch tensor of token ids.
input_ids = tokenizer.encode(
    init_sentence,
    return_tensors="pt",
)

In [21]:
# Generate continuations of the prompt; all decoding settings come from
# `generation_inputs`.
model_outputs = model.generate(
    input_ids,
    **generation_inputs,
)

# Inspect only the first returned sequence.
selected_output = model_outputs[0]

# Run the model once more over the selected sequence to get its logits (and
# attention maps). This is inference only, so autograd is disabled: no graph
# is built, activations are not retained, and the resulting logits do not
# require grad.
with torch.no_grad():
    generated_model = model(selected_output, output_attentions=True)

In [25]:
# For the selected sequence, collect the 9 highest-scoring candidate tokens
# from the logits via the debugger helper.
res = tokenizer_debugger.get_sequence_logit_top_n_tokens(
    selected_output,
    generated_model.logits,
    9,
)

# Show as a table; missing entries are rendered as empty strings.
pd.DataFrame(res).fillna("")

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,Q,,,,,,,,,
1,: (0.00),The (0.03),. (0.02),A (0.01),(0.01),A (0.01),The (0.01),I (0.01),", (0.01)",S (0.01)
2,What (0.01),(0.08),The (0.06),A (0.02),I (0.02),What (0.01),How (0.01),� (0.01),""" (0.01)",We (0.01)
3,is (0.13),is (0.13),are (0.10),do (0.09),� (0.06),'s (0.06),does (0.03),was (0.02),the (0.02),happens (0.02)
4,the (0.33),the (0.33),your (0.08),it (0.06),a (0.04),this (0.02),that (0.01),an (0.01),his (0.01),going (0.00)
5,name (0.00),difference (0.03),best (0.02),current (0.01),future (0.01),most (0.01),reason (0.01),purpose (0.01),new (0.01),impact (0.01)
6,of (0.80),of (0.80),? (0.04),and (0.01),""" (0.01)",", (0.01)",for (0.01),that (0.01),you (0.01),' (0.00)
7,the (0.47),the (0.47),a (0.12),your (0.03),an (0.02),this (0.02),one (0.01),that (0.01),our (0.01),my (0.00)
8,president (0.00),new (0.02),company (0.02),game (0.01),first (0.01),group (0.01),""" (0.01)",team (0.01),project (0.01),organization (0.01)
9,of (0.27),? (0.35),of (0.27),", (0.04)",and (0.04),'s (0.04),who (0.03),in (0.02),that (0.02),� (0.02)
