In [1]:
import torch
from bertviz import head_view, model_view
from transformers import GPT2Tokenizer, GPT2Model, AutoTokenizer, AutoModelForCausalLM

In [2]:
# Single example sentence tokenized for every model below, so the attention views share one input.
sentence_a = "I had been working out every morning lately, but not anymore"

## 1. Karpathy GPT-2 (random-init parameters, loaded from a local checkpoint)

In [3]:
from train_gpt2 import GPT, GPTConfig

# Read the saved checkpoint onto the CPU and keep only its 'model' entry,
# which is what gets handed to load_state_dict further down.
# NOTE(review): torch.load unpickles the file, so only point it at checkpoints you trust.
checkpoint_path = "./log/model_19072_20241014.pt"
state_dict = torch.load(
    checkpoint_path,
    map_location="cpu",
)['model']

In [5]:
# Build the local GPT implementation and load the checkpoint weights into it.
model_karp = GPT(GPTConfig(vocab_size=50304))
model_karp.load_state_dict(state_dict)
model_karp.eval()  # inference only: disable training-mode behaviour (e.g. dropout), if the model has any

tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')

input_ids_karp = tokenizer_gpt2(sentence_a, return_tensors='pt')['input_ids']
print(f"input ids:\t{input_ids_karp}")

# Gradients are not needed for visualization, so skip autograd bookkeeping.
with torch.no_grad():
    output_karp = model_karp(input_ids_karp, output_attentions=True)
print(f"output keys:\t{output_karp.keys()}")
attention_karp = output_karp['attentions']
# NOTE(review): the saved output of this cell shows non-zero weights in the last
# columns of the first row, while the Hugging Face GPT-2 output is lower-triangular.
# Confirm whether the attentions returned by this model have the causal mask applied.
print(f"attention:\t{attention_karp}")

# The bertviz views take the token strings alongside the attentions, so map the ids back to tokens.
tokens_karp = tokenizer_gpt2.convert_ids_to_tokens(input_ids_karp[0])
print(f"tokens:\t{tokens_karp}")

input ids:	tensor([[   40,   550,   587,  1762,   503,   790,  3329, 16537,    11,   475,
           407,  7471]])
output keys:	odict_keys(['logits', 'loss', 'kv_cache', 'attentions'])
attention:	[tensor([[[[0.0561, 0.0909, 0.1305,  ..., 0.0623, 0.0124, 0.1370],
          [0.1528, 0.1220, 0.0738,  ..., 0.0366, 0.0969, 0.0646],
          [0.1365, 0.1450, 0.0910,  ..., 0.0451, 0.0659, 0.0703],
          ...,
          [0.1569, 0.1242, 0.0908,  ..., 0.0440, 0.0433, 0.0679],
          [0.1497, 0.1325, 0.0997,  ..., 0.0321, 0.0540, 0.0632],
          [0.1771, 0.1343, 0.1057,  ..., 0.0442, 0.0290, 0.0675]],

         [[0.1659, 0.0682, 0.0287,  ..., 0.0473, 0.0415, 0.1939],
          [0.2405, 0.1428, 0.0942,  ..., 0.0525, 0.0367, 0.0603],
          [0.2097, 0.1511, 0.0783,  ..., 0.0543, 0.0425, 0.0604],
          ...,
          [0.2081, 0.0997, 0.0645,  ..., 0.0275, 0.1287, 0.0985],
          [0.1648, 0.0956, 0.0681,  ..., 0.0846, 0.0698, 0.1470],
          [0.1633, 0.0890, 0.0705,  ..., 0.04

In [6]:
# bertviz model view of the checkpointed GPT's attentions.
model_view(attention=attention_karp, tokens=tokens_karp)

<IPython.core.display.Javascript object>

In [7]:
# bertviz head view of the checkpointed GPT's attentions.
head_view(attention=attention_karp, tokens=tokens_karp)

<IPython.core.display.Javascript object>

## 2. Hugging Face pretrained GPT-2 model

In [None]:
# Pretrained GPT-2 from the Hugging Face hub, configured to return attention weights.
model_gpt2 = GPT2Model.from_pretrained('gpt2', output_attentions=True)
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')

input_ids_gpt2 = tokenizer_gpt2(sentence_a, return_tensors='pt')['input_ids']
print(f"input ids:\t{input_ids_gpt2}")

# Gradients are not needed for visualization, so skip autograd bookkeeping.
with torch.no_grad():
    output_gpt2 = model_gpt2(input_ids_gpt2)
print(f"output keys:\t{output_gpt2.keys()}")
# Select the attentions by key instead of by position: `[-1]` silently picks a
# different field if the model ever returns an extra trailing output.
attention_gpt2 = output_gpt2['attentions']
print(f"attention:\t {attention_gpt2}")

# The bertviz views take the token strings alongside the attentions, so map the ids back to tokens.
tokens_gpt2 = tokenizer_gpt2.convert_ids_to_tokens(input_ids_gpt2[0])
print(f"tokens_gpt2:\t{tokens_gpt2}")

input ids:	tensor([[   40,   550,   587,  1762,   503,   790,  3329, 16537,    11,   475,
           407,  7471]])
output keys:	odict_keys(['last_hidden_state', 'past_key_values', 'attentions'])
attention:	 (tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [7.9486e-01, 2.0514e-01, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [6.5988e-01, 2.5410e-01, 8.6025e-02,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [2.6526e-01, 8.2453e-02, 6.6413e-02,  ..., 4.5397e-02,
           0.0000e+00, 0.0000e+00],
          [2.3890e-01, 7.0330e-02, 6.0750e-02,  ..., 6.0363e-02,
           4.9515e-02, 0.0000e+00],
          [2.3950e-01, 6.0439e-02, 4.5944e-02,  ..., 9.8064e-02,
           1.0153e-01, 3.3246e-02]],

         [[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [2.3471e-03, 9.9765e-01, 0.0000e+00,  ..., 0.0000e+00,
   

In [None]:
# bertviz model view of the pretrained GPT-2 attentions.
model_view(attention=attention_gpt2, tokens=tokens_gpt2)

<IPython.core.display.Javascript object>

In [None]:
# bertviz head view of the pretrained GPT-2 attentions.
head_view(attention=attention_gpt2, tokens=tokens_gpt2)

<IPython.core.display.Javascript object>

## 3. Hugging Face instruction-tuned model

In [None]:
# Instruction-tuned GPT-2 variant and its matching tokenizer from the Hugging Face hub.
model_instr = AutoModelForCausalLM.from_pretrained("vicgalle/gpt2-open-instruct-v1")
tokenizer_instr = AutoTokenizer.from_pretrained("vicgalle/gpt2-open-instruct-v1")

inputs_ids_instr = tokenizer_instr(sentence_a, return_tensors="pt")['input_ids']
# Gradients are not needed for visualization, so skip autograd bookkeeping.
with torch.no_grad():
    output_instr = model_instr(inputs_ids_instr, output_attentions=True)
# Select the attentions by key instead of by position: `[-1]` silently picks a
# different field if the model ever returns an extra trailing output.
attention_instr = output_instr['attentions']

# The bertviz views take the token strings alongside the attentions, so map the ids back to tokens.
tokens_instr = tokenizer_instr.convert_ids_to_tokens(inputs_ids_instr[0])
tokens_instr

['I',
 'Ġhad',
 'Ġbeen',
 'Ġworking',
 'Ġout',
 'Ġevery',
 'Ġmorning',
 'Ġlately',
 ',',
 'Ġbut',
 'Ġnot',
 'Ġanymore']

In [None]:
# bertviz model view of the instruction model's attentions.
model_view(attention=attention_instr, tokens=tokens_instr)

<IPython.core.display.Javascript object>

In [None]:
# bertviz head view of the instruction model's attentions.
head_view(attention=attention_instr, tokens=tokens_instr)

<IPython.core.display.Javascript object>

## 4. Attention weights with the KV cache enabled
Using the Karpathy GPT-2 model.

In [13]:
# Run the Karpathy model again on the same input ids, this time with the
# KV-cache path switched on and no pre-existing cache.
kv_cache = None
with torch.no_grad():  # visualization only; no gradients needed
    output_karp_kv = model_karp(input_ids_karp, kv_cache=kv_cache, use_cache=True, output_attentions=True)

# Read the attentions from THIS call's output. The earlier code read them from
# `output_karp` (the run without the cache), so the KV-cache result was never inspected.
attention_karp_kv = output_karp_kv['attentions']
print(f"attention:\t{attention_karp_kv}")

# The bertviz views take the token strings alongside the attentions, so map the ids back to tokens.
tokens_karp = tokenizer_gpt2.convert_ids_to_tokens(input_ids_karp[0])
print(f"tokens:\t{tokens_karp}")

attention:	[tensor([[[[0.0561, 0.0909, 0.1305,  ..., 0.0623, 0.0124, 0.1370],
          [0.1528, 0.1220, 0.0738,  ..., 0.0366, 0.0969, 0.0646],
          [0.1365, 0.1450, 0.0910,  ..., 0.0451, 0.0659, 0.0703],
          ...,
          [0.1569, 0.1242, 0.0908,  ..., 0.0440, 0.0433, 0.0679],
          [0.1497, 0.1325, 0.0997,  ..., 0.0321, 0.0540, 0.0632],
          [0.1771, 0.1343, 0.1057,  ..., 0.0442, 0.0290, 0.0675]],

         [[0.1659, 0.0682, 0.0287,  ..., 0.0473, 0.0415, 0.1939],
          [0.2405, 0.1428, 0.0942,  ..., 0.0525, 0.0367, 0.0603],
          [0.2097, 0.1511, 0.0783,  ..., 0.0543, 0.0425, 0.0604],
          ...,
          [0.2081, 0.0997, 0.0645,  ..., 0.0275, 0.1287, 0.0985],
          [0.1648, 0.0956, 0.0681,  ..., 0.0846, 0.0698, 0.1470],
          [0.1633, 0.0890, 0.0705,  ..., 0.0479, 0.0614, 0.1198]],

         [[0.1550, 0.0584, 0.0643,  ..., 0.0934, 0.0473, 0.1131],
          [0.0912, 0.1058, 0.0681,  ..., 0.0916, 0.0517, 0.1028],
          [0.1013, 0.0793, 0.0

In [10]:
# Show the attentions collected for the KV-cache section (`attention_karp_kv`);
# passing `attention_karp` here would just repeat the section-1 view.
model_view(attention_karp_kv, tokens_karp)

<IPython.core.display.Javascript object>

In [11]:
# Show the attentions collected for the KV-cache section (`attention_karp_kv`);
# passing `attention_karp` here would just repeat the section-1 view.
head_view(attention_karp_kv, tokens_karp)

<IPython.core.display.Javascript object>