In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# This notebook only runs inference: disable autograd globally so forward
# passes in later cells do not record computation graphs.
torch.set_grad_enabled(False)

# Tokenizer and causal LM loaded from the same pretrained checkpoint id.
model_name = "roneneldan/TinyStories-1M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [2]:
from utils import load_orig_ds_txt, tokenize

# Raw texts for the requested split slice; presumably the first 100 examples
# of the validation split (HF-style slice syntax) — TODO confirm in utils.
ds_txt = load_orig_ds_txt("validation[:100]")
# Tokenize each text independently (one token tensor per text).
ds_tok = [tokenize(tokenizer, story) for story in ds_txt]
# The first tokenized text serves as the running example below.
sample_tok = ds_tok[0]



  0%|          | 0/100 [00:00<?, ?it/s]

In [6]:
def get_logits(model, sample_tok):
    """Run ``model`` on one unbatched token sequence and return its logits.

    A batch axis of size 1 is prepended to ``sample_tok`` before the forward
    pass and removed again from ``model(...).logits``, so the result is the
    logits for that single sequence (``(pos, d_vocab)`` for a causal LM).
    """
    batch = sample_tok[None]
    output = model(batch)
    return output.logits[0]

def get_correct_probs(logits, sample_tok):
    """Probability the model assigned to each token that actually came next.

    Args:
        logits: next-token logits of shape ``(..., pos, d_vocab)``, e.g. the
            output of ``get_logits``. Leading batch dimensions are allowed.
        sample_tok: token ids of shape ``(..., pos)`` matching ``logits``.

    Returns:
        Tensor of shape ``(..., pos - 1)`` whose entry ``i`` is the softmax
        probability at position ``i`` of the token found at position ``i + 1``.
        The last position is excluded because its next token is unknown.

    Raises:
        ValueError: if ``logits.shape[:-1]`` differs from ``sample_tok.shape``.
    """
    if logits.shape[:-1] != sample_tok.shape:
        raise ValueError(
            f"logits shape {tuple(logits.shape)} does not match "
            f"token shape {tuple(sample_tok.shape)}"
        )
    # Drop the last position before the softmax: we don't know the correct
    # next token there, and softmax is per-row so other rows are unaffected.
    probs = torch.softmax(logits[..., :-1, :], dim=-1)
    # The correct next tokens are the input tokens shifted left by one;
    # gather picks, out of d_vocab values, the one for each correct token.
    targets = sample_tok[..., 1:].long().unsqueeze(-1)
    return probs.gather(-1, targets).squeeze(-1)

In [7]:
# Next-token logits for the example sequence, then the probability the model
# gave to each token that actually followed; displayed as the cell output.
logits = get_logits(model, sample_tok)
correct_probs = get_correct_probs(logits=logits, sample_tok=sample_tok)
correct_probs

tensor([3.1341e-07, 3.8516e-03, 3.8292e-03, 1.2702e-03, 2.6244e-02, 5.7214e-03,
        6.9342e-03, 2.1320e-01, 7.5029e-03, 7.4739e-01, 9.8252e-01, 5.3241e-02,
        9.0845e-01, 1.9862e-04, 1.2865e-01, 1.9049e-02, 3.7672e-01, 8.3394e-01,
        5.8995e-01, 7.8704e-03, 9.3007e-02, 1.0252e-03, 1.6984e-01, 6.3082e-02,
        7.9264e-02, 9.3585e-01, 3.5446e-03, 9.8924e-01, 9.9815e-01, 1.6970e-01,
        9.9656e-01, 8.5609e-01, 6.6405e-01, 1.9771e-01, 1.8745e-01, 3.4450e-06,
        6.2227e-01, 1.1502e-01, 8.9421e-01, 7.8436e-01, 4.2450e-01, 9.5226e-01,
        8.0645e-03, 3.1316e-01, 2.5546e-01, 2.1434e-01, 3.2364e-01, 9.7244e-01,
        3.6453e-01, 2.6221e-01, 3.2694e-01, 3.1549e-03, 5.4743e-04, 7.2140e-01,
        7.3100e-01, 7.1458e-03, 6.7133e-01, 1.7639e-02, 2.8133e-03, 2.0983e-01,
        4.1505e-04, 5.5878e-01, 3.2920e-01, 2.2728e-01, 2.7677e-02, 1.8162e-01,
        8.5580e-01, 8.0620e-01, 2.9107e-02, 5.5225e-02, 4.0758e-01, 8.9057e-01,
        1.8427e-01, 2.4617e-01, 1.4175e-

In [8]:
# One line per predicted token: the decoded target token wrapped in pipes
# (so leading spaces stay visible, newlines escaped), right-aligned, followed
# by the probability the model assigned to it.
for pos, prob in enumerate(correct_probs.tolist()):
    next_tok = sample_tok[pos + 1]
    shown = tokenizer.decode(next_tok).replace("\n", r"\n")
    label = f"|{shown}|"
    print(f"{label:>14} {prob:.2%}")

        |Spot| 0.00%
           |.| 0.39%
       | Spot| 0.38%
        | saw| 0.13%
        | the| 2.62%
      | shiny| 0.57%
        | car| 0.69%
        | and| 21.32%
       | said| 0.75%
           |,| 74.74%
          | "| 98.25%
         |Wow| 5.32%
           |,| 90.84%
      | Kitty| 0.02%
           |,| 12.87%
       | your| 1.90%
        | car| 37.67%
         | is| 83.39%
         | so| 59.00%
     | bright| 0.79%
        | and| 9.30%
      | clean| 0.10%
          |!"| 16.98%
      | Kitty| 6.31%
     | smiled| 7.93%
        | and| 93.59%
    | replied| 0.35%
           |,| 98.92%
          | "| 99.81%
       |Thank| 16.97%
        | you| 99.66%
           |,| 85.61%
       | Spot| 66.40%
           |.| 19.77%
          | I| 18.75%
     | polish| 0.00%
         | it| 62.23%
      | every| 11.50%
        | day| 89.42%
          |."| 78.44%
          |\n| 42.45%
          |\n| 95.23%
       |After| 0.81%
    | playing| 31.32%
       | with| 25.55%
        | the| 21.43%
       