In [1]:
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import einops
import numpy as np
import pandas as pd
import torch
import transformer_lens
import transformer_lens.utils as utils
from fancy_einsum import einsum
from fastapi import FastAPI
from transformer_lens import ActivationCache, HookedTransformer
from dataclasses import dataclass

In [2]:
# Setup

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Looking at GPT-2 Small first
model = HookedTransformer.from_pretrained("gpt2", device=device)
model_cfg = model.cfg

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer


In [None]:
@dataclass
class Prediction:
    next_token: str
    top5_tokens: List[str]
    confidence_in_guess: float
    confidence_in_top_5: float
    attention: List[float]

In [24]:
a = torch.tensor([1,2,3])
a.sort(-1, descending=True)[0]

tensor([3, 2, 1])

Idea: 
You give me a sequence of tokens, for each subsequence I give you back: 
- The best prediction for next token
- How confident I am
- My top 5 guesses
- The attention pattern I used to get this guess

In [40]:
def get_prediction(sequence: str, model: HookedTransformer = model, topk: int = 5) -> Tuple[str, float, List[str]]:
    "Return the next token, confidence in the guess and top 5 guesses"
    logits = model(sequence) # batch (1), sequence length, vocab size
    
    logits_most_to_least_likely, tokens_most_to_least_likely = logits.sort(-1, descending =True) # batch (1), sequence length, vocab size
    print(logits_most_to_least_likely.shape)
    
    probs_most_to_least_likely = torch.softmax(logits_most_to_least_likely, dim=-1)[0]
    confidence_in_top_guess = probs_most_to_least_likely[-1, 0].item()
    
    top_predictions = tokens_most_to_least_likely[0, :, :topk] # sequence length, topk
    final_top_predictions = top_predictions[-1, :]

    print(top_predictions.shape)
    print(top_predictions)
    # str_predictions = [model.to_str_tokens(top_predictions[:, i]) for i in range(topk)]
    topk_predictions = model.to_str_tokens(final_top_predictions) 
    # print(len(str_predictions))
    # print(len(str_predictions[0]))
    
    # prediction = Prediction(next_token=str_predictions[0][-1],
    #                         top5_tokens=None,
    #                         confidence_in_guess=confidence_in_top_guess,
    #                         confidence_in_top_5=confidence_in_topk,
    #                         attention=None)
    return topk_predictions[0], confidence_in_top_guess, topk_predictions
    # TODO: Also return the probability of the next token and maybe top 3 guesses
    return model.to_string(next_token)

In [None]:
def get_all_predictions(sequence: str, model: HookedTransformer = model, topk: int = 5) -> pd.DataFrame

In [None]:
# model.to_str_tokens(top_predictions[:, 0])

In [41]:
get_prediction("1,2,3,4,")

torch.Size([1, 9, 50257])
torch.Size([9, 5])
tensor([[ 198,  464,    1,   32,   40],
        [  13,    8, 5985,  198,  362],
        [ 830, 4059, 2167,   23,   24],
        [  11,   12,   13,    8,  198],
        [  18,   19,   20,   17,   16],
        [  11,   12,  198,    8,   13],
        [  19,   20,   18,   21,   16],
        [  11,  198,   12,    8,   13],
        [  20,   21,   19,   16,   22]])


('5', 0.9244884848594666, ['5', '6', '4', '1', '7'])

In [4]:
def attention_pattern(sequence: str, model: HookedTransformer = model) -> list:
    logits, cache = model.run_with_cache(sequence, remove_batch_dim=True)
    attention_patterns = [cache["pattern", layer, "attn"] for layer in range(12)]
    attention_patterns = torch.stack(attention_patterns, dim=0)  # layer, head, seq, seq
    reduced_attention = einops.reduce(
        attention_patterns, "layer head i j -> i j", "mean"
    )
    final_token_attention = reduced_attention[-1]
    final_token_attention = final_token_attention[1:]
    # rescale to sum to 1
    final_token_attention = final_token_attention / final_token_attention.sum(
        -1, keepdims=True
    )
    return final_token_attention.numpy().tolist()


In [5]:
def predict_with_attention(sequence="1,2,3,4,", model=model):
    next_token = get_prediction(sequence, model)

    attention = attention_pattern(sequence, model)
    print("next token: ", next_token)
    print("attention: ", attention)
    return next_token, attention

In [None]:





if __name__ == "__main__":
    print(predict_with_attention("1,2,3,4,"))

# next_token_logits, cache = model.run_with_cache(to_4_seq)
# possible_next_tokens = next_token_logits.argsort(-1, descending=True)[0][-1][:10]
# print(possible_next_tokens.shape)
# model.to_str_tokens(possible_next_tokens.unsqueeze(0))

# to_4_seq = "1,2,3,4,"
# to_4_tokens = model.to_tokens(to_4_seq)
# logits, cache = model.run_with_cache(to_4_tokens, remove_batch_dim=True)

# attention_patterns = [cache["pattern", layer, "attn"] for layer in range(12)]
# attention_patterns = torch.stack(attention_patterns, dim=0)  # layer, head, seq, seq
# reduced_attention = einops.reduce(attention_patterns, "layer head i j -> i j", "mean")
# final_token_attention = reduced_attention[-1]
# final_token_attention = final_token_attention[1:]
# # rescale to sum to 1
# final_token_attention = final_token_attention / final_token_attention.sum(
#     -1, keepdims=True
# )


# # Also give how confident model was in it's prediction.
# next_token_probs = logits.softmax(-1)[0][-1].sort(descending=True)[0][:10]
