In [None]:
import transformers
import torch

# The notebook was developed on GPU index 5. Fall back to the first GPU, or to
# the CPU, so that Restart Kernel -> Run All still works on machines that do
# not have six GPUs.
if torch.cuda.device_count() > 5:
    device = "cuda:5"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load GPT-2 XL in inference mode on the selected device.
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2-xl")
model.eval().to(device)

# The GPT-2 tokenizer is given the EOS token as its padding token here so that
# prompts can be batched with padding=True further down.
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2-xl")
tokenizer.pad_token = tokenizer.eos_token

# Slightly Nicer Implementation

In [None]:
import baukit
import torch
from torch import nn

# Decoder head used by the lens functions: the model's final layer norm
# followed by its unembedding matrix.
lm_head = nn.Sequential(model.transformer.ln_f, model.lm_head)

def logit_lens(prompt, entities, layers=range(40, 48), model=model, tokenizer=tokenizer):
    """Print the top-5 next-token predictions decoded from intermediate layers.

    Every entity is substituted into ``prompt``. For each requested transformer
    block, the hidden state at the last non-padding token of every prompt is
    passed through the model's final layer norm and unembedding
    (``model.transformer.ln_f`` then ``model.lm_head``), and the five
    highest-scoring tokens are printed per entity.

    Args:
        prompt: Format string with a single ``{}`` placeholder for the entity.
        entities: Sequence of entity strings; one prompt is built per entity.
        layers: Iterable of transformer block indices (ints) to decode from.
        model: Causal LM exposing ``transformer.ln_f`` and ``lm_head`` and
            blocks addressable as ``transformer.h.<index>``.
        tokenizer: Tokenizer matching ``model``. The last-token index is
            computed as ``attention_mask.sum() - 1``, which is only the last
            real token when padding is on the right.

    Returns:
        None. Results are printed.
    """
    # Materialise once: `layers` is iterated twice below, which would silently
    # produce no output for a one-shot iterator.
    layers = list(layers)
    entities = list(entities)

    batch_idx = torch.arange(len(entities))
    prompts = [prompt.format(entity) for entity in entities]
    # Use the device of the model that was passed in, not a notebook global.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    prompt_last_idx = inputs.attention_mask.sum(dim=-1).sub(1).tolist()

    # With stop=True, TraceDict halts the forward pass after the *last listed*
    # layer, so the names are passed in ascending depth order; otherwise deeper
    # layers listed earlier would never be reached.
    traced_layernames = [f"transformer.h.{layer}" for layer in sorted(set(layers))]

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        with baukit.TraceDict(model, traced_layernames, stop=True) as ret:
            model(**inputs)

        for layer in layers:
            layername = f"transformer.h.{layer}"
            hidden = ret[layername].output[0][batch_idx, prompt_last_idx]
            # Decode with the head of the model argument so that a non-default
            # `model` is not paired with another model's unembedding.
            logits = model.lm_head(model.transformer.ln_f(hidden))
            predictions = logits.topk(dim=-1, k=5).indices.tolist()
            print(f"---- layer {layer} ----")
            for entity, ids in zip(entities, predictions):
                tokens = tokenizer.convert_ids_to_tokens(ids)
                # "Ġ" is the tokenizer's marker for a leading space.
                preds = [t.replace("Ġ", " ") for t in tokens]
                print(f"{entity}: {preds}")

# Decode country predictions for three landmarks from transformer blocks 0-39.
logit_lens(
    prompt="{} is located in the country of",
    entities=["The Eiffel Tower", "Niagara Falls", "The Leaning Tower of Pisa"],
    layers=range(40),
)

In [None]:
import baukit
import torch
from torch import nn
from torch.autograd import functional as F
from tqdm.auto import tqdm


# Final layer norm followed by the unembedding: maps a last-block hidden state
# to vocabulary logits.
lm_head = nn.Sequential(model.transformer.ln_f, model.lm_head)

def attribute_lens(prompt,
                   entities,
                   layers=range(40, 48),
                   estimate_jacobian_from=0,
                   model=model,
                   tokenizer=tokenizer):
    """Print top-8 predictions from a linearised entity-to-output map.

    For every layer in ``layers``, the hidden state at the last entity token
    (``entity_h``) is mapped to the last transformer block's hidden state at
    the last prompt token (``last_h``) with a first-order approximation
    ``last_h ~= J @ entity_h + b``. ``J`` is the Jacobian measured on the single
    example ``entities[estimate_jacobian_from]``, and ``b`` is chosen so the
    approximation is exact on that example. The same ``(J, b)`` is then applied
    to every entity, decoded with the final layer norm and unembedding, and the
    eight highest-scoring tokens are printed.

    Args:
        prompt: Format string with a single ``{}`` placeholder. The entity's
            last-token index is found by tokenizing the entity on its own, so
            the entity must open the prompt and tokenize identically there.
        entities: Sequence of entity strings; one prompt is built per entity.
        layers: Iterable of transformer block indices at which the entity
            hidden state is read.
        estimate_jacobian_from: Index into ``entities`` of the example used to
            estimate the Jacobian and bias.
        model: Causal LM exposing ``transformer.ln_f``, ``lm_head`` and
            ``config``, with blocks addressable as ``transformer.h.<index>``.
        tokenizer: Tokenizer matching ``model``. Token indices are derived from
            ``attention_mask.sum() - 1``, which assumes right padding.

    Returns:
        None. Results (and some debug information) are printed.

    Raises:
        IndexError: If ``estimate_jacobian_from`` is not a valid index into
            ``entities``.
    """
    entities = list(entities)
    example = estimate_jacobian_from
    if not -len(entities) <= example < len(entities):
        raise IndexError(
            f"estimate_jacobian_from={example} is out of range for "
            f"{len(entities)} entities"
        )

    batch_idx = torch.arange(len(entities))
    entity_last_idx = tokenizer(
        entities,
        padding=True,
        return_tensors="pt",
    ).attention_mask.sum(dim=-1).sub(1).tolist()

    prompts = [prompt.format(entity) for entity in entities]
    # Use the device of the model that was passed in, not a notebook global.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    prompt_last_idx = inputs.attention_mask.sum(dim=-1).sub(1).tolist()

    print("---- DEBUG ----")
    print("prompts", prompts)
    print("batch_idx", batch_idx)
    print("entity_last_idx", entity_last_idx)
    print("prompt_last_idx", prompt_last_idx)

    # Derive the last block from the model instead of hard-coding block 47.
    last_layername = f"transformer.h.{model.config.num_hidden_layers - 1}"

    for entity_layer in layers:
        entity_layername = f"transformer.h.{entity_layer}"

        # Reference activations; no autograd graph is needed for this pass.
        with torch.no_grad():
            with baukit.TraceDict(model,
                                  (entity_layername, last_layername),
                                  stop=True) as ret:
                model(**inputs)
            entity_hs = ret[entity_layername].output[0][batch_idx, entity_last_idx]
            last_hs = ret[last_layername].output[0][batch_idx, prompt_last_idx]
        assert entity_hs.shape == (len(entities), model.config.hidden_size)
        assert last_hs.shape == (len(entities), model.config.hidden_size)

        def replaced_entity_fn(h):
            """Run the model with the example's entity state replaced by ``h``."""
            def replace_entity_h(output, layer):
                # Compare the full layer name: a substring test such as
                # "4" in "transformer.h.47" would also overwrite the last block.
                if layer == entity_layername:
                    output[0][example, entity_last_idx[example]] = h
                return output
            # stop=True: nothing after the last block is needed here.
            with baukit.TraceDict(model,
                                  (entity_layername, last_layername),
                                  edit_output=replace_entity_h,
                                  stop=True) as patched:
                model(**inputs)
            return patched[last_layername].output[0][example, prompt_last_idx[example]]

        # Shape (hidden_out, hidden_in): jacobian[i, j] = d last_h[i] / d entity_h[j].
        jacobian = torch.autograd.functional.jacobian(
            replaced_entity_fn,
            entity_hs[example],
            vectorize=True)

        with torch.no_grad():
            # Row-vector form of J @ h is h @ J^T; the bias makes the affine
            # map exact on the example the Jacobian was estimated from.
            bias = last_hs[None, example] - entity_hs[None, example].mm(jacobian.t())
            print(jacobian.norm(), bias.norm())
            head_inputs = entity_hs.mm(jacobian.t()) + bias
            # Decode with the head of the model argument (final norm + unembedding).
            logits = model.lm_head(model.transformer.ln_f(head_inputs))
        predictions = logits.topk(dim=-1, k=8).indices.tolist()

        print(f"---- layer {entity_layer} ----")
        for entity, ids in zip(entities, predictions):
            tokens = tokenizer.convert_ids_to_tokens(ids)
            # "Ġ" is the tokenizer's marker for a leading space.
            preds = [t.replace("Ġ", " ") for t in tokens]
            print(f"{entity}: {preds}")

# attribute_lens(
#     "{} is located in the country of",
#     [
#         "The Space Needle",
#         "The Great Wall",
#         "The Louvre",
#         "Niagara Falls",
#         "The Eiffel Tower",
#     ],
#     estimate_jacobian_from=0,
#     layers=range(20, 31)
# )

# attribute_lens(
#     "{} plays the sport of",
#     [
#         "Larry Bird",
#         "John McEnroe",
#         "Oksana Baiul",
#         "Megan Rapinoe",
#         "Tom Brady",
#         "Babe Ruth"
#     ],
#     estimate_jacobian_from=5,
#     layers=range(20, 31)
# )

# attribute_lens(
#     "{} is a song by the band",
#     [
#         "Smells Like Teen Spirit",
#         "Stairway to Heaven",
#         "Don't Stop Believing",
#         "Bohemian Rhapsody",
#         "Creep",
#         "Shake It Off",
#     ],
#     estimate_jacobian_from=0,
#     layers=range(35, 45)
# )

# attribute_lens(
#     "{} is the lead singer of the band",
#     [
#         "Kurt Cobain",
#         "Eddie Vedder",
# #         "Hayley Williams",
#         "Stevie Nicks",
#         "Chris Cornell",
#         "Freddie Mercury",
#     ],
#     estimate_jacobian_from=2,
#     layers=range(25, 31)
# )

# attribute_lens(
#     "{} is CEO of",
#     [
#         "Indra Nooyi",
#         "Sundar Pichai",
#         "Elon Musk",
#         "Mark Zuckerberg",
#         "Satya Nadella",
#         "Jeff Bezos",
#         "Tim Cook",
#     ],
#     estimate_jacobian_from=0,
#     layers=range(25, 31)
# )

# Fruit-colour probe: estimate the Jacobian on the first entity ("bananas")
# and apply it to all five fruits for transformer blocks 25-30.
attribute_lens(
    prompt="{} are usually the color of",
    entities=["bananas", "apples", "strawberries", "tangerines", "kiwis"],
    layers=range(25, 31),
    estimate_jacobian_from=0,
)

In [None]:
# Sanity check: let the model itself continue a colour prompt for five tokens.
inputs = tokenizer('Peaches are usually colored', return_tensors="pt")
inputs = inputs.to(device)
outputs = model.generate(max_new_tokens=5, **inputs)
tokenizer.batch_decode(outputs)

Version where the Jacobian is estimated across multiple examples.

In [None]:
from functools import partial

import baukit
import torch
from torch.autograd import functional as F
from tqdm.auto import tqdm

def batch_attribute_lens(prompt,
                         entities,
                         layers=range(40, 48),
                         model=model,
                         tokenizer=tokenizer,
                         jacobian_indices=None):
    """Print top-5 predictions from an affine map averaged over several examples.

    For every layer in ``layers``, the hidden state at the last entity token
    (``entity_h``) is mapped to the last transformer block's hidden state at
    the last prompt token with ``last_h ~= J @ entity_h + b``. A Jacobian and
    a bias are estimated separately on each selected example and then averaged.
    The averaged map is applied to every entity, decoded with the final layer
    norm and unembedding, and the five highest-scoring tokens are printed.

    Args:
        prompt: Format string with a single ``{}`` placeholder. The entity's
            last-token index is found by tokenizing the entity on its own, so
            the entity must open the prompt and tokenize identically there.
        entities: Sequence of entity strings; one prompt is built per entity.
        layers: Iterable of transformer block indices at which the entity
            hidden state is read.
        model: Causal LM exposing ``transformer.ln_f``, ``lm_head`` and
            ``config``, with blocks addressable as ``transformer.h.<index>``.
        tokenizer: Tokenizer matching ``model``. Token indices are derived from
            ``attention_mask.sum() - 1``, which assumes right padding.
        jacobian_indices: Indices into ``entities`` of the examples to average
            over. ``None`` (default) uses every entity. Pass ``[1, 2, 3]`` to
            reproduce the subset that used to be hard-coded here.

    Returns:
        None. Results (and some debug information) are printed.

    Raises:
        ValueError: If no example is selected for Jacobian estimation.
        IndexError: If an entry of ``jacobian_indices`` is out of range.
    """
    entities = list(entities)
    if jacobian_indices is None:
        example_indices = list(range(len(entities)))
    else:
        example_indices = list(jacobian_indices)
    if not example_indices:
        raise ValueError("at least one example is required to estimate the Jacobian")
    for i in example_indices:
        if not -len(entities) <= i < len(entities):
            raise IndexError(
                f"jacobian index {i} is out of range for {len(entities)} entities"
            )

    batch_idx = torch.arange(len(entities))
    entity_last_idx = tokenizer(
        entities,
        padding=True,
        return_tensors="pt",
    ).attention_mask.sum(dim=-1).sub(1).tolist()

    prompts = [prompt.format(entity) for entity in entities]
    # Use the device of the model that was passed in, not a notebook global.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    prompt_last_idx = inputs.attention_mask.sum(dim=-1).sub(1).tolist()

    print("---- DEBUG ----")
    print("prompts", prompts)
    print("batch_idx", batch_idx)
    print("entity_last_idx", entity_last_idx)
    print("prompt_last_idx", prompt_last_idx)

    # Derive the last block from the model instead of hard-coding block 47.
    last_layername = f"transformer.h.{model.config.num_hidden_layers - 1}"

    for entity_layer in layers:
        entity_layername = f"transformer.h.{entity_layer}"

        # Reference activations; no autograd graph is needed for this pass.
        with torch.no_grad():
            with baukit.TraceDict(model, (entity_layername, last_layername), stop=True) as ret:
                model(**inputs)
            entity_hs = ret[entity_layername].output[0][batch_idx, entity_last_idx]
            last_hs = ret[last_layername].output[0][batch_idx, prompt_last_idx]
        assert entity_hs.shape == (len(entities), model.config.hidden_size)
        assert last_hs.shape == (len(entities), model.config.hidden_size)

        def replaced_entity_fn(i, h):
            """Run the model with example ``i``'s entity state replaced by ``h``."""
            def replace_entity_h(output, layer):
                # Compare the full layer name: a substring test such as
                # "4" in "transformer.h.47" would also overwrite the last block.
                if layer == entity_layername:
                    output[0][i, entity_last_idx[i]] = h
                return output
            # stop=True: nothing after the last block is needed here.
            with baukit.TraceDict(model,
                                  (entity_layername, last_layername),
                                  edit_output=replace_entity_h,
                                  stop=True) as patched:
                model(**inputs)
            return patched[last_layername].output[0][i, prompt_last_idx[i]]

        jacs = []
        biases = []
        for i in example_indices:
            # Shape (hidden_out, hidden_in): jac[r, c] = d last_h[r] / d entity_h[c].
            jac = torch.autograd.functional.jacobian(
                partial(replaced_entity_fn, i),
                entity_hs[i],
                vectorize=True)
            with torch.no_grad():
                # Row-vector form of J @ h is h @ J^T.
                bias = last_hs[None, i] - entity_hs[None, i].mm(jac.t())
            # Store the Jacobian untransposed so it is transposed exactly once
            # when applied below, matching how the bias was fitted.
            jacs.append(jac)
            biases.append(bias)

        with torch.no_grad():
            mean_jac = torch.stack(jacs).mean(dim=0)
            mean_bias = torch.stack(biases).mean(dim=0)
            rep = entity_hs.mm(mean_jac.t()) + mean_bias
            # `rep` approximates the last block's output, which the model
            # passes through the final layer norm before the unembedding.
            logits = model.lm_head(model.transformer.ln_f(rep))
        predictions = logits.topk(dim=-1, k=5).indices.tolist()

        print(f"---- layer {entity_layer} ----")
        for entity, ids in zip(entities, predictions):
            tokens = tokenizer.convert_ids_to_tokens(ids)
            # "Ġ" is the tokenizer's marker for a leading space.
            preds = [t.replace("Ġ", " ") for t in tokens]
            print(f"{entity}: {preds}")

# batch_attribute_lens(
#     "{} is located in the country of",
#     [
#         "The Space Needle",
#         "The Great Wall",
#         "Niagara Falls",
#     ],
#     layers=range(20, 40)
# )

# batch_attribute_lens(
#     "{} plays the sport of",
#     [
#         "Larry Bird",
#         "John McEnroe",
#         "Oksana Baiul",
# #         "Megan Rapinoe",
#     ],
#     layers=range(19, 31)
# )

# Lead-singer probe over five artists, reading transformer blocks 25-30.
batch_attribute_lens(
    prompt="{} is the lead singer of the band",
    entities=[
        "Kurt Cobain",
        "Eddie Vedder",
        "Stevie Nicks",
        "Chris Cornell",
        "Freddie Mercury",
    ],
    layers=range(25, 31),
)

# What is going on with Bias Norm?

In [None]:
import baukit

# Two-entity setup for inspecting the hidden states at one block.
prompt_template = "{} is located in the country of"
entity_a, entity_b = "Space Needle", "Great Wall"
layer = 25

entities = [entity_a, entity_b]
edit_layername = "transformer.h.{}".format(layer)
last_layername = "transformer.h.47"

# Index of the last token of each entity when the entity is tokenized alone.
batch_idx = torch.arange(2)
entity_last_idx = (
    tokenizer(entities, padding=True, return_tensors="pt")
    .attention_mask
    .sum(dim=-1)
    .sub(1)
    .tolist()
)

prompts = list(map(prompt_template.format, entities))
inputs = tokenizer(prompts, return_tensors="pt", padding="longest").to(device)

# Capture the output of the chosen block and pull out one vector per entity.
with baukit.Trace(model, edit_layername) as ret:
    model(**inputs)
h_a, h_b = ret.output[0][batch_idx, entity_last_idx].unbind(0)
for h in (h_a, h_b):
    assert h.shape == (model.config.hidden_size,)

# Demo from David

In [None]:
import torch
from baukit import Trace, set_requires_grad
from collections import OrderedDict

# Small 2 -> 3 -> 3 -> 2 MLP with named layers, so that hidden layers can be
# addressed by name when tracing.
net = torch.nn.Sequential(OrderedDict(
    linear1=torch.nn.Linear(2, 3),
    tanh1=torch.nn.Tanh(),
    linear2=torch.nn.Linear(3, 3),
    tanh2=torch.nn.Tanh(),
    linear3=torch.nn.Linear(3, 2),
))

# Call baukit's set_requires_grad with False on `net` (presumably this turns
# off gradient tracking for the network's parameters).
# NOTE(review): the original note here read "jacobian impl doesn't seem to use
# any of these mechanisms" - presumably meaning that
# torch.autograd.functional.jacobian does not depend on the parameters'
# requires_grad flags; confirm before relying on or removing this call.
set_requires_grad(False, net)


def get_jacobian_wrt_hidden_layer_at(x, net, hidden_layername):
    """Return the Jacobian of ``net(x)`` with respect to a hidden layer's output.

    Args:
        x: Input passed to ``net``.
        net: Module whose submodule ``hidden_layername`` is traced.
        hidden_layername: Name of the hidden layer whose output is treated as
            the free variable.

    Returns:
        The result of ``torch.autograd.functional.jacobian`` evaluated at the
        hidden activation observed for ``x``.
    """
    # Step 1: record the hidden activation for this input (traced with
    # stop=True) and detach it from any graph.
    with Trace(net, hidden_layername, stop=True) as recorded:
        net(x)
    hidden_value = recorded.output.detach()

    # Step 2: a function of the hidden state alone. It re-runs the net on the
    # same x while overwriting the layer's output in place with its argument.
    def run_with_hidden(hidden):
        def overwrite(output):
            output[...] = hidden

        with Trace(net, hidden_layername, edit_output=overwrite):
            result = net(x)
        return result

    # Step 3: differentiate that function at the recorded activation.
    return torch.autograd.functional.jacobian(run_with_hidden, hidden_value)

# Evaluate the Jacobian at a random input for two hidden layers.
x = torch.randn(2)
for layername in ('linear1', 'tanh2'):
    print(f'Jacobian of output wrt {layername}:')
    print(get_jacobian_wrt_hidden_layer_at(x, net, layername))

# Reference for the tanh2 case: linear3 is the only layer after tanh2.
print('Weights of last layer, should match:')
print(net.linear3.weight)