## Setup

In [1]:
# Niceties for local dev
%load_ext autoreload
%autoreload 2

In [2]:
from interpogate import Interpogate

In [4]:
# Use transformers to load tinyllama (1.1b)
import transformers
import torch

pipe = transformers.pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                             torch_dtype=torch.bfloat16)
model = pipe.model
tokenizer = pipe.tokenizer

## Using Interpogate

In [7]:
# We can wrap the model and tokenizer with interpogate and visualize it
interp = Interpogate(model, tokenizer)
interp.visualize()

In [13]:
# Let's run a forward pass and visualize it
interp.forward_text("The most fascinating thing is the")
interp.visualize()

In [18]:
# Click on the chat bubble icon in the visualization above
# to see what the most probable following tokens are.
#   ["_fact", "_way", "_ability", "_possibility"...]

# We're going to try out a technique called "logit lens" to
# visualize the most likely next tokens according to each of
# the 22 layers

# To do that, we first need to grab the layer responsible for
# unembedding the response from dim 2048 -> dim 32000
lm_head = interp.node('lm_head')
lm_head

Linear(in_features=2048, out_features=32000, bias=False)

In [31]:
# Now, we're going to hook each of the 22 layers output.
# We can make use of the predictable names to register
# hooks for each layer

layer_logits = []

with interp.hook() as hook:
    def post_hook(model, input, output):
        # output shape: [<1×N×2048>,...]
        # Run the lm head to unembed and get logits
        logits = lm_head(output[0])[0]
        layer_logits.append(logits)
        pass
    
    # Register hooks as needed
    for n in range(22):
        hook.post(f"model.layers.{n}", post_hook)

    # Run forward pass
    interp.forward_text("The most fascinating thing is the")

# Stack the layer logits into a unified tensor
layer_logits = torch.stack(layer_logits)

# Print collected logits
print(layer_logits.shape)
print(layer_logits[0])

torch.Size([22, 9, 32000])
tensor([[-0.0072,  0.0259, -0.0125,  ..., -0.0280, -0.0063, -0.0035],
        [ 0.0089,  0.0054,  0.0028,  ..., -0.0024,  0.0400, -0.0012],
        [-0.0126, -0.0109,  0.0082,  ..., -0.0272, -0.0027, -0.0310],
        ...,
        [-0.0388, -0.0459, -0.0080,  ..., -0.0479, -0.0244, -0.0304],
        [-0.0095, -0.0137, -0.0148,  ..., -0.0134, -0.0028, -0.0172],
        [-0.0013, -0.0042,  0.0094,  ..., -0.0114,  0.0037, -0.0062]],
       dtype=torch.bfloat16, grad_fn=<SelectBackward0>)


In [48]:
# Apply softmax on the final dimension
softmax = layer_logits.softmax(dim=2)
# Grab softmax on only the final logits
softmax = softmax[:,-1,:]
# Get argmax
top_indices = softmax.argsort(dim=1, descending=True)
# Print results
for i in range(top_indices.shape[0]):
    print(f"Layer {i}:")
    # Print the top 6 tokens
    for j in range(6):
        idx = top_indices[i][j]
        print(f"  #{j + 1}: {tokenizer.decode(idx)} (prob: {softmax[i][j]})")
    print("")

Layer 0:
  #1: ..., (prob: 3.123283386230469e-05)
  #2: ...) (prob: 3.123283386230469e-05)
  #3: !, (prob: 3.170967102050781e-05)
  #4: ,[ (prob: 3.123283386230469e-05)
  #5: andis (prob: 3.075599670410156e-05)
  #6: important (prob: 3.075599670410156e-05)

Layer 1:
  #1: ..., (prob: 3.123283386230469e-05)
  #2: __( (prob: 3.0994415283203125e-05)
  #3: andis (prob: 3.123283386230469e-05)
  #4: dade (prob: 3.170967102050781e-05)
  #5: euw (prob: 3.0994415283203125e-05)
  #6: Tout (prob: 3.0517578125e-05)

Layer 2:
  #1: ..., (prob: 3.123283386230469e-05)
  #2: ater (prob: 3.0994415283203125e-05)
  #3: osoph (prob: 3.123283386230469e-05)
  #4: CAA (prob: 3.147125244140625e-05)
  #5: ppo (prob: 3.123283386230469e-05)
  #6: jna (prob: 3.0994415283203125e-05)

Layer 3:
  #1: ater (prob: 3.0994415283203125e-05)
  #2: CAA (prob: 3.075599670410156e-05)
  #3: osoph (prob: 3.0994415283203125e-05)
  #4: :\ (prob: 3.1948089599609375e-05)
  #5: hm (prob: 3.0994415283203125e-05)
  #6: tk (prob: 3.07