# Attribute Lens Experiment for GPT-2-XL and GPT-J

This notebook contains the initial "attribute lens" experiment, based on an idea from Discord.

Say `D(h)` is the decoder readout head. The logit lens visualizes every hidden state `h[layer,token]` in a transformer by looking at `D(h[layer,token])`.
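As a minimal sketch of the logit lens idea, the toy below applies one decoder `D` to every `h[layer][token]`. The vocabulary, the `unembed` matrix and the hidden states are all made-up illustration values, and the sketch leaves out the final layer norm that a real decoder head applies before the projection.

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def make_decoder(unembed):
    """Return D(h): project h onto each vocabulary row, then softmax."""
    def D(h):
        logits = [sum(w * x for w, x in zip(row, h)) for row in unembed]
        return softmax(logits)
    return D

# Hypothetical toy setup: 3-word vocabulary, 2-dimensional hidden states.
vocab = ["Seattle", "Paris", "Rome"]
unembed = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
D = make_decoder(unembed)

# hidden[layer][token] is one hidden state vector.
hidden = [
    [[0.1, 0.0], [0.0, 0.2]],  # layer 0
    [[2.0, 0.0], [0.0, 3.0]],  # layer 1
]

# Logit lens: apply the same decoder D at every layer and token.
lens = [[D(h) for h in layer] for layer in hidden]
top_words = [
    [vocab[max(range(len(p)), key=p.__getitem__)] for p in layer]
    for layer in lens
]
print(top_words)
```

In this toy the top word is the same at both layers, but the later layer is much more confident, which is the kind of pattern a logit lens plot makes visible.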

Instead of using `D(h)`, why not look at hidden states using `F(h)`, where `F` is a longer computation through the transformer rather than just the decoder head? This is directly connected to our causal traces:

We noticed that (surprisingly!) single hidden states early in the network can be causal for a much later token prediction. For example, if you have a generic sentence like "*** is located in the city of", where the subject has been noised out, there is a specific early causal layer and token state `h[Lc,Tc]` where you can jam in a particular vector and the model will say "Seattle" or "Paris" or "Rome" many tokens later.

So let's fix the corrupted sentence form, and fix the causal layer and token `Lc` and `Tc` at the early site where the individual hidden state is causal for the object-attribute word. Then define `F(h)` as the decoding we get when we set `h[Lc, Tc] := h` and run the computation on this sentence form. At the end of the corrupted-and-intervened run, GPT makes a prediction, and this prediction is the output of `F(h)`.
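To make the definition of `F(h)` concrete, here is a small sketch on a made-up scalar "model" rather than a transformer. The names `run_toy_model`, `causal_sum` and `make_F` are hypothetical helpers for illustration only. Each toy layer lets a token see itself and all earlier tokens, so a state patched in at an early site can still drive the readout at the last token.

```python
def causal_sum(states):
    # Toy "layer": each position becomes the sum of itself and all earlier positions.
    return [sum(states[: i + 1]) for i in range(len(states))]

def run_toy_model(token_states, layers, patch=None):
    """Run the layers in order; optionally overwrite one state after layer Lc.

    patch is (Lc, Tc, h): after layer Lc runs, the state at token Tc is set to h.
    """
    states = list(token_states)
    for layer_index, layer_fn in enumerate(layers):
        states = layer_fn(states)
        if patch is not None and patch[0] == layer_index:
            _, token_index, h = patch
            states[token_index] = h
    return states[-1]  # readout at the final token

def make_F(corrupted_states, layers, Lc, Tc):
    # F(h) fixes the corrupted input and the site (Lc, Tc); only h varies.
    def F(h):
        return run_toy_model(corrupted_states, layers, patch=(Lc, Tc, h))
    return F

corrupted = [0.0, 0.0, 0.0]   # stand-in for a prompt whose subject is noised out
layers = [causal_sum] * 3
F = make_F(corrupted, layers, Lc=0, Tc=0)

print(run_toy_model(corrupted, layers))  # unpatched run
print(F(2.0))                            # run with h patched in at the early site
```

The unpatched run carries no signal to the last token, while the patched run's output is determined entirely by `h`, which is the property the attribute lens relies on.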


OK, we begin by loading a model.

In [None]:
device = "cuda:1"

In [None]:
import torch, baukit
from transformers import AutoModelForCausalLM, AutoTokenizer
#MODEL_NAME = "gpt2-xl"  # gpt2-xl or EleutherAI/gpt-j-6B
MODEL_NAME = "EleutherAI/gpt-j-6B"
model, tok = (
    AutoModelForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=False).to(device),
    AutoTokenizer.from_pretrained(MODEL_NAME)
)
baukit.set_requires_grad(False, model)

So the innovation here is to make a custom readout function that is different from the standard readout.


The `make_custom_readout` function below is supposed to do that (but it seems to have some bugs).

Basically it does this:
  * As input it takes the custom readout `prompt`, such as ". * is located in the city of", where the user has put a star at the location of the "unknown" subject.
  * It also takes a `layer` at which we start driving the hidden state under test. If `layer` is left as `None`, the code as written inserts `h` at every transformer layer.
  * Then it creates a function `custom_readout(h)`.

`custom_readout` does this:
  * If you pass it a tensor of many hidden vectors `h`, it flattens all leading dimensions into one big batch of `input_size` vectors.
  * For each `h` (`input_size` times) it does this:
    * Runs the `prompt` `noise_samples` times (10 by default) as one batch,
    * each time substituting random noise for the embeddings of the tokens at the `*` symbols,
    * and each time inserting the vector `h` at the last `*` token at the given `layer`.
    * Then it reads out the next-token prediction logits of the 10 runs and stores them.
  * Finally, after everything is done, it converts the logits into probabilities using softmax and averages over each group of 10 runs to get the results.
  
Then `make_custom_readout` returns this customized  `custom_readout` function, which maps `h -> token predictions` so that it can be used as a logit lens decoder.
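The control flow described above can be sketched without a model. In this toy, `run_prompt(h, noise)` is a hypothetical stand-in for "run the prompt with noise at the star tokens and `h` patched in, then return the final logits". The part meant to carry over is the structure: loop over the states, reuse the same fixed noise for each one, take a softmax per run, then average probabilities over the noisy runs.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def make_toy_readout(run_prompt, noise_samples=10, seed=1):
    """Build readout(hs) from run_prompt(h, noise) -> logits (a hypothetical model stand-in)."""
    def readout(hs):
        results = []
        for h in hs:  # loop over the flattened batch of hidden states
            rng = random.Random(seed)  # reseed so every h sees the same noise draws
            probs = [
                softmax(run_prompt(h, rng.gauss(0.0, 1.0)))
                for _ in range(noise_samples)
            ]
            # Average probabilities (not logits) across the noisy runs.
            vocab_size = len(probs[0])
            results.append(
                [sum(p[i] for p in probs) / noise_samples for i in range(vocab_size)]
            )
        return results
    return readout

def toy_run(h, noise):
    # Made-up logits over a 3-token vocabulary that depend on h and on the noise.
    return [h[0] + noise, h[1] - noise, 0.0]

readout = make_toy_readout(toy_run)
out = readout([[0.0, 0.0], [3.0, 0.0]])
for row in out:
    print([round(p, 3) for p in row])
```

Averaging after the softmax keeps each output row a valid probability distribution, so the result can be dropped in wherever a decoder that returns probabilities is expected.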
  


In [None]:
import re

def get_embedding_scale(model, tok):
    prompt = 'Jane John Elizabeth Gerald Washington Seattle London Paris Boston'
    inp = {k: torch.tensor(v)[None].to(device) for k, v in tok(prompt).items()}
    embed_layer = [n for n, _ in model.named_modules() if 'wte' in n or 'embed' in n][0]
    with baukit.Trace(model, embed_layer) as t:
        model(**inp)
        return t.output.std()

# This function just returns random noise
noise_cache = None
def fixed_noise(shape):
    import numpy
    global noise_cache
    amount = numpy.prod(shape)
    if noise_cache is None or len(noise_cache) < amount:
        noise_cache = torch.Tensor(numpy.random.RandomState(1).randn(amount)).to(device)
    return noise_cache[:amount].reshape(shape)

def make_custom_readout(model, tok, prompt, subject='*', layer=None, noise_samples=10):
    # Layer should be a layer name
    target_layer = layer

    # These are all the tokens used in the subject name.  We encode it twice to get
    # both the non-space-prefixed and space-prefixed ones if needed.
    star_tokens = tok.encode(f'{subject} {subject}')

    # Tokenize the prompt and put it on cuda
    inp = {k: torch.tensor(v)[None].to(device) for k, v in tok(prompt).items()}

    # Find the index locations where the star/subject tokens appear.
    # This line used to have a bug.
    star_indexes = sum(inp['input_ids'] == t for t in star_tokens)[0].nonzero()[:,0].tolist()
    index = star_indexes[-1]

    # Get the embedding layer name
    embed_layers = [n for n, _ in model.named_modules() if 'wte' in n or 'embed' in n]
    noise_level = get_embedding_scale(model, tok)
    gt = lambda x: noise_level * x
    # The following rule was studied in the ROME rebuttal but doesn't seem to help for us.
    # from lens.stats import collect_embedding_gaussian
    # gt = collect_embedding_gaussian(model, tok).to('cuda', torch.float)
    
    if layer is None:
        layers = [n for n, _ in model.named_modules() if re.match(r'^transformer\.h\.\d+$', n)]
    else:
        layers = [layer]

    def custom_readout(h):
        import numpy
        cuda_h = h.to(device)
        # h may be a tensor of state vectors.  Flatten it to a batch.
        input_size = int(numpy.prod(cuda_h.shape[:-1]))
        flat_h = cuda_h.reshape(input_size, cuda_h.shape[-1])
        # We will physically batch by noise samples, and then just loop over input batch for now
        batch_size = noise_samples
        batch_inp = {k: v.expand((batch_size,) + v.shape[1:]) for k, v in inp.items()}
        
        the_noise = None
        results = []
        for input_index in range(0, input_size):
            target_h = flat_h[input_index][None]
            def insert_state(x, layer):
                hs = x[0] if isinstance(x, tuple) else x
                if layer in embed_layers:
                    hs[:,star_indexes,:] = gt(fixed_noise(hs[:,star_indexes,:].shape))
                else: # layer == target_layer:
                    hs[:,index,:] = target_h
                return x
            with baukit.TraceDict(model, layers + embed_layers, edit_output=insert_state) as t:
                batch_results = model(**batch_inp)['logits'][:,-1,:]
            results.append(batch_results)
        raw_results = torch.stack(results).reshape(
            *(cuda_h.shape[:-1] + (noise_samples,) + batch_results.shape[-1:]))
        return torch.nn.functional.softmax(raw_results, dim=-1).mean(dim=-2)
    return custom_readout



Now we can test things.  If we pass a tensor of state vectors, we should get a tensor of predictions out instead.

In [None]:
f = make_custom_readout(model, tok, 'The * is located in the city of')


# Try running f on some zero vectors
probs = f(torch.zeros(1, 5, 1, 3, 2, 4096))
print(probs.shape)
probs.sum(dim=-1).flatten()  # Verify that probabilities add up to 1.0

This function gathers the hidden states at the output of every transformer layer.

In [None]:
def get_hidden_states(model, tok, prefix):
    import re
    from baukit import TraceDict
    inp = {k: torch.tensor(v)[None].to(device) for k, v in tok(prefix).items()}
    layer_names = [n for n, _ in model.named_modules()
                   if re.match(r'^transformer\.h\.\d+$', n)]
    with TraceDict(model, layer_names) as tr:
        logits = model(**inp)['logits']
    return torch.stack([tr[layername].output[0] for layername in layer_names])

prompt = 'Hello, my name is also'
hs = get_hidden_states(model, tok, prompt)
hs.shape

Here is the basic logit lens visualization.  Comments inline.

In [None]:
def show_logit_lens(model, tok, prefix, topk=5, color=None, hs=None, decoder=None):
    from baukit import show

    # You can pass in a function to compute the hidden states, or just the tensor of hidden states.
    if hs is None:
        hs = get_hidden_states
    if callable(hs):
        hs = hs(model, tok, prefix)

    # The full decoder head normalizes hidden state and applies softmax at the end.
    if decoder is None:
        decoder = torch.nn.Sequential(model.transformer.ln_f, model.lm_head, torch.nn.Softmax(dim=-1))

    probs = decoder(hs) # Apply the decoder head to every hidden state
    favorite_probs, favorite_tokens = probs.topk(k=topk, dim=-1)
    # Let's also plot hidden state magnitudes
    magnitudes = hs.norm(dim=-1)
    # For some reason the 0th token always has huge magnitudes, so normalize based on subsequent token max.
    # Added if statement to handle one token input
    if magnitudes.shape[-1] > 1:
        magnitudes = magnitudes / magnitudes[:,:,1:].max()
    
    # All the input tokens.
    prompt_tokens = [tok.decode(t) for t in tok.encode(prefix)]

    # Foreground color shows token probability, and background color shows hs magnitude
    if color is None:
        color = [0, 0, 255]
    def color_fn(m, p):
        a = [int(255 * (1-m) + c * m) for c in color]
        b = [int(196 * (1-p) + 0 * p)] * 2 + [0]
        return show.style(background=f'rgb({a[0]}, {a[1]}, {a[2]})',
                          color=f'rgb({b[0]}, {b[1]}, {b[2]})' )

    # In the hover popup, show topk probabilities beyond the 0th.
    def hover(tok, prob, toks, m):
        lines = [f'mag: {m:.2f}']
        for p, t in zip(prob, toks):
            lines.append(f'{tok.decode(t)}: prob {p:.2f}')
        return show.attr(title='\n'.join(lines))
    
    # Construct the HTML output using show.
    header_line = [ # header line
             [[show.style(fontWeight='bold'), 'Layer']] +
             [
                 [show.style(background='yellow'), show.attr(title=f'Token {i}'), t]
                 for i, t in enumerate(prompt_tokens)
             ]
         ]
    layer_logits = [
             # first column
             [[show.style(fontWeight='bold'), layer]] +
             [
                 # subsequent columns
                 [color_fn(m, p[0]), hover(tok, p, t, m), show.style(overflowX='hidden'), tok.decode(t[0])]
                 for m, p, t in zip(wordmags, wordprobs, words)
             ]
        for layer, wordmags, wordprobs, words in
                zip(range(len(magnitudes)), magnitudes[:, 0], favorite_probs[:, 0], favorite_tokens[:,0])]
    
    # If you want to get the html without showing it, use show.html(...)
    show(header_line + layer_logits + header_line)


An example.

In [None]:
show_logit_lens(model, tok, 'The exhibits at the National Air and Space Museum were', decoder=f)

In [None]:
#f = make_custom_readout(model, tok, '. Arthur Rubinstein plays the piano. Miles Davis plays the trumpet. Yo-Yo Ma plays the cello. * plays the', layer='transformer.h.17')
#f = make_custom_readout(model, tok, '* plays the', layer='transformer.h.17')
f = make_custom_readout(model, tok, '. * is famous in the field of')


show_logit_lens(model, tok, 'Professors Zoubin Ghahramani and Michael Jordan', decoder=f)

In [None]:
def generate(model, tok, prefix, n=10):
    # Greedy decoding: append the argmax token n times, rerunning the full prompt each step.
    inp = {k: torch.tensor(v)[None].to(device) for k, v in tok(prefix).items()}
    initial_length = len(inp['input_ids'][0])
    for _ in range(n):
        out = model(**inp)['logits']
        pred = out[0, -1].argmax()
        # pred is a 0-d tensor already on the device; reshape it to (1, 1) to append.
        inp['input_ids'] = torch.cat((inp['input_ids'], pred[None, None]), dim=1)
        # Extend the mask with a matching dtype so torch.cat does not mix int and float.
        inp['attention_mask'] = torch.cat(
            (inp['attention_mask'], torch.ones_like(inp['attention_mask'][:, :1])), dim=1)
    return tok.decode(inp['input_ids'][0, initial_length:])
#generate(model, tok, '. Arthur Rubinstein plays the piano. Miles Davis plays the trumpet. Yo-Yo Ma plays the cello. Itzhak Perlman plays the', n=3)

generate(model, tok, 'Yo-yo Ma is a famous cellist. Arthur Rubenstein is a famous', n=3)
