# Logit lens visualization for GPT-2-XL

First, load the model and tokenizer.

In [None]:
import torch, baukit
from transformers import AutoModelForCausalLM, AutoTokenizer
# Checkpoint name to load: gpt2-xl or EleutherAI/gpt-j-6B
MODEL_NAME = "gpt2-xl"
# Load the causal LM weights first and move them to the GPU, then load the
# tokenizer that belongs to the same checkpoint.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=False)
model = model.to("cuda")
tok = AutoTokenizer.from_pretrained(MODEL_NAME)

We're not training the model, so set the `requires_grad=False` flag on all the model parameters; that way PyTorch autograd isn't automatically invoked every time we use the parameters.  (The alternative is to call `torch.set_grad_enabled(False)` to turn off autograd globally.)

In [None]:
# baukit was already imported in the first cell; repeating the import is harmless.
import baukit
# Turn off requires_grad on the model's parameters, since we only run inference.
baukit.set_requires_grad(False, model)

Here is an example of the API for the model.  It takes as input `{'input_ids': tensor, 'attention_mask': tensor}`, and it returns a dictionary where the main output is `full_out['logits']`, and the main next prediction token is given by the very last logit.

If we want to do greedy generation, we can take the next predicted token, append it to the input ids, and iterate.  That is what this function does.  (There are more sophisticated ways to generate text that lead to more natural output; this is just a demonstration of the simplest approach.)

In [None]:
def generate(model, tok, prefix, n=10):
    """Greedily generate ``n`` tokens that continue ``prefix``.

    Each step runs the model on the whole sequence so far, takes the argmax of
    the logits at the last position, and appends that token to the input.
    No key/value cache is used, so every step re-processes the full sequence.

    Args:
        model: Causal language model (a ``torch.nn.Module``) called as
            ``model(input_ids=..., attention_mask=...)`` and returning a
            mapping with a ``'logits'`` entry.
        tok: Tokenizer. ``tok(prefix)`` must return a mapping of lists
            containing ``'input_ids'`` and ``'attention_mask'``, and
            ``tok.decode`` must turn token ids back into text.
        prefix: Prompt string to continue.
        n: Number of tokens to generate.

    Returns:
        The decoded text of the newly generated tokens only; the prompt is
        not included.
    """
    # Place the inputs wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    inp = {
        k: torch.tensor(v, device=device)[None]  # [None] adds the batch axis
        for k, v in tok(prefix).items()
    }
    initial_length = inp['input_ids'].shape[1]
    with torch.no_grad():
        for _ in range(n):
            logits = model(**inp)['logits']
            # The last position's logits give the next-token prediction.
            pred = logits[0, -1].argmax().to(device)
            inp['input_ids'] = torch.cat((inp['input_ids'], pred.view(1, 1)), dim=1)
            # Extend the mask with a one of the mask's own dtype and device, so
            # it does not get promoted to float.
            mask = inp['attention_mask']
            inp['attention_mask'] = torch.cat((mask, mask.new_ones(1, 1)), dim=1)
    return tok.decode(inp['input_ids'][0, initial_length:])
# Demo: greedily continue this prompt for 100 tokens and show the generated text.
generate(model, tok, 'In his NBA career, KC Jones played', n=100)
    

The layer output activations that we're interested in tracing are called `transformer.h.0`, `transformer.h.1`, etc.

This function demonstrates how to use `baukit.TraceDict` to gather those activations while we are running `model` in inference.  Note that `tr[layername].output` has the exact output of the layer, which is a tuple of tensors.  The first tensor in this tuple happens to be the main hidden-state tensor that we are looking for.

We stack all the `tr[layername].output[0]` tensors into one big tensor, 48 long, for all layers' hidden states gathered in one place.

In [None]:
def get_hidden_states(model, tok, prefix):
    """Run ``model`` on ``prefix`` and collect the output of every transformer block.

    The modules named ``transformer.h.<index>`` are traced with
    ``baukit.TraceDict`` during one forward pass, and the first element of
    each traced output tuple is stacked into a single tensor.

    Args:
        model: Model whose ``named_modules()`` include ``transformer.h.<index>``.
        tok: Tokenizer; ``tok(prefix)`` must return a mapping of lists that the
            model accepts as keyword arguments.
        prefix: Prompt string to run through the model.

    Returns:
        A tensor whose first dimension indexes the traced layers, in
        ``named_modules()`` order. The remaining dimensions are those of each
        layer's ``output[0]`` (presumably ``(1, seq_len, hidden_dim)``; confirm).
    """
    import re
    from baukit import TraceDict

    # Place the inputs wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    inp = {k: torch.tensor(v, device=device)[None] for k, v in tok(prefix).items()}
    # Dots are escaped so that only the literal names transformer.h.<index> match.
    layer_pattern = re.compile(r'^transformer\.h\.\d+$')
    layer_names = [name for name, _ in model.named_modules()
                   if layer_pattern.match(name)]
    with TraceDict(model, layer_names) as tr:
        # Only the traced activations are needed; the model's return value is unused.
        model(**inp)
    return torch.stack([tr[name].output[0] for name in layer_names])

# Collect the hidden states for a short prompt and inspect the stacked tensor's shape.
prompt = 'Hello, my name is also'
hs = get_hidden_states(model, tok, prompt)
hs.shape

Here is the basic logit lens visualization.  Comments inline.

In [None]:
def show_logit_lens(model, tok, prefix, topk=5, color=None, hs=None):
    """Display a logit-lens table for ``prefix`` using ``baukit.show``.

    Every hidden state is passed through the model's final layer norm and LM
    head followed by a softmax. The table has one row per layer and one column
    per prompt token; each cell shows the most probable decoded token.
    Background color encodes the hidden state's normalized magnitude, text
    darkness encodes the top token's probability, and the hover text lists the
    magnitude and the ``topk`` candidates.

    Args:
        model: Model exposing ``transformer.ln_f`` and ``lm_head``.
        tok: Tokenizer providing ``encode`` and ``decode``.
        prefix: Prompt string to visualize.
        topk: Number of candidate tokens listed in each cell's hover text.
        color: RGB triple used for the largest magnitude; defaults to blue
            ``[0, 0, 255]``.
        hs: A tensor of hidden states, or a callable
            ``hs(model, tok, prefix)`` returning one. Defaults to
            ``get_hidden_states``. Only batch index 0 is displayed.

    Returns:
        None. The table is rendered as a side effect.
    """
    from baukit import show

    # You can pass in a function to compute the hidden states, or just the tensor of hidden states.
    if hs is None:
        hs = get_hidden_states
    if callable(hs):
        hs = hs(model, tok, prefix)

    # The full decoder head normalizes hidden state and applies softmax at the end.
    decoder = torch.nn.Sequential(model.transformer.ln_f, model.lm_head, torch.nn.Softmax(dim=-1))
    with torch.no_grad():
        probs = decoder(hs)  # Apply the decoder head to every hidden state
    favorite_probs, favorite_tokens = probs.topk(k=topk, dim=-1)
    # Let's also plot hidden state magnitudes
    magnitudes = hs.norm(dim=-1)
    # The 0th token tends to have huge magnitudes, so normalize by the max over
    # the subsequent tokens. A single-token prompt has no subsequent tokens;
    # fall back to the overall max there instead of calling max() on an empty tensor.
    later_tokens = magnitudes[:, :, 1:]
    scale = later_tokens.max() if later_tokens.numel() > 0 else magnitudes.max()
    magnitudes = magnitudes / scale

    # All the input tokens.
    prompt_tokens = [tok.decode(t) for t in tok.encode(prefix)]

    # Foreground color shows token probability, and background color shows hs magnitude
    if color is None:
        color = [0, 0, 255]

    def color_fn(m, p):
        # Normalized magnitudes can exceed 1 (e.g. at token 0); clamp so the
        # color mix stays inside the valid 0..255 channel range.
        m = min(max(float(m), 0.0), 1.0)
        p = float(p)
        a = [int(255 * (1 - m) + c * m) for c in color]
        b = [int(255 * (1 - p))] * 3
        return show.style(background=f'rgb({a[0]}, {a[1]}, {a[2]})',
                          color=f'rgb({b[0]}, {b[1]}, {b[2]})')

    # The hover popup shows the magnitude followed by all topk candidates and their probabilities.
    def hover(tokenizer, prob, toks, m):
        lines = [f'mag: {m:.2f}']
        for p, t in zip(prob, toks):
            lines.append(f'{tokenizer.decode(t)}: prob {p:.2f}')
        return show.attr(title='\n'.join(lines))

    # Construct the HTML output using show.
    header_line = [  # header line
        [[show.style(fontWeight='bold'), 'Layer']] +
        [
            [show.style(background='yellow'), show.attr(title=f'Token {i}'), t]
            for i, t in enumerate(prompt_tokens)
        ]
    ]
    layer_logits = [
        # first column
        [[show.style(fontWeight='bold'), layer]] +
        [
            # subsequent columns
            # NOTE(review): 'hide' is not a standard CSS overflow value ('hidden' is);
            # left unchanged so rendering stays the same -- confirm the intent.
            [color_fn(m, p[0]), hover(tok, p, t, m), show.style(overflowX='hide'), tok.decode(t[0])]
            for m, p, t in zip(wordmags, wordprobs, words)
        ]
        # Index 0 along the second axis selects the only batch element.
        for layer, (wordmags, wordprobs, words) in enumerate(
            zip(magnitudes[:, 0], favorite_probs[:, 0], favorite_tokens[:, 0]))
    ]

    # If you want to get the html without showing it, use show.html(...)
    show(header_line + layer_logits + header_line)


An example.

In [None]:
# Render the logit lens for a factual prompt, using the default hidden states.
show_logit_lens(model, tok, '. The Space Needle is located in')

Now, instead of directly decoding each hidden state, let's decode the deltas between layers.

In [None]:
def get_hidden_state_deltas(model, tok, prefix):
    """Run ``model`` on ``prefix`` and return what each transformer block adds.

    The output of ``transformer.drop`` and of every ``transformer.h.<index>``
    module is traced with ``baukit.TraceDict`` during one forward pass. Entry
    ``i`` of the result is the output of block ``i`` minus the output of the
    previous traced module: block ``i - 1``, or ``transformer.drop`` for
    ``i = 0`` (presumably the embedding fed into block 0; confirm).

    Args:
        model: Model whose ``named_modules()`` include ``transformer.drop`` and
            ``transformer.h.<index>``.
        tok: Tokenizer; ``tok(prefix)`` must return a mapping of lists that the
            model accepts as keyword arguments.
        prefix: Prompt string to run through the model.

    Returns:
        A tensor with one entry per ``transformer.h.<index>`` layer along the
        first dimension, holding the per-layer differences.
    """
    import re
    from baukit import TraceDict

    # Place the inputs wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    inp = {k: torch.tensor(v, device=device)[None] for k, v in tok(prefix).items()}
    # Dots are escaped so that only the literal names transformer.h.<index> match.
    layer_pattern = re.compile(r'^transformer\.h\.\d+$')
    layer_names = [name for name, _ in model.named_modules()
                   if layer_pattern.match(name)]
    with TraceDict(model, ['transformer.drop'] + layer_names) as tr:
        # Only the traced activations are needed; the model's return value is unused.
        model(**inp)
    # [None] adds a leading layer axis so this output can be concatenated with
    # the stacked block outputs.
    first_h = tr['transformer.drop'].output[None]
    other_h = torch.stack([tr[name].output[0] for name in layer_names])
    all_h = torch.cat([first_h, other_h])
    # Difference between each traced state and the one before it.
    return all_h[1:] - all_h[:-1]


Test the delta function.

In [None]:
# Compute the per-layer deltas for a prompt and inspect the resulting tensor's shape.
hs = get_hidden_state_deltas(model, tok, 'The Space Needle is located in')
hs.shape

Now plot it.

Notice!  "Seattle" shows up at layer 29 at token 4!!!

In [None]:
# Same visualization, but decoding the per-layer deltas instead of the hidden states themselves.
show_logit_lens(model, tok, '. The Space Needle is located in', hs=get_hidden_state_deltas)

In [None]:
# Print the built-in help for the model's forward method.
help(model.forward)