In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../')
import torch
import transformers
import baukit

torch.__version__, transformers.__version__

('2.1.0+cu121', '4.34.0')

In [3]:
def recursive_copy(x, clone=None, detach=None, retain_grad=None):
    """
    Copies a reference to a tensor, or an object that contains tensors,
    optionally detaching and cloning the tensor(s).  If retain_grad is
    true, the original tensors are marked to have grads retained.

    Args:
        x: a ``torch.Tensor``, or an arbitrarily nested ``dict`` / ``list`` /
            ``tuple`` (or subclass) whose leaves are tensors, ``None``, or
            plain scalars (``bool``/``int``/``float``/``str``).
        clone: if truthy, every tensor leaf is ``.clone()``-d.
        detach: if truthy (and ``retain_grad`` is not), every tensor leaf is
            ``.detach()``-ed.
        retain_grad: if truthy, every tensor leaf gets ``requires_grad=True``
            (when not already set) and ``.retain_grad()``; ``detach`` is then
            ignored.

    Returns:
        ``x`` itself when no flag is set; otherwise a new structure of the same
        container types, with every tensor leaf processed as above (at any
        nesting depth) and ``None``/scalar leaves passed through unchanged.

    Raises:
        AssertionError: if a flag is set and a leaf has any other type.
    """
    if not clone and not detach and not retain_grad:
        return x
    if isinstance(x, torch.Tensor):
        if retain_grad:
            if not x.requires_grad:
                x.requires_grad = True
            x.retain_grad()
        elif detach:
            x = x.detach()
        if clone:
            x = x.clone()
        return x
    # Immutable non-tensor leaves hold nothing to copy; returning them as-is
    # keeps structures such as (hidden, None) copyable.
    if x is None or isinstance(x, (bool, int, float, str)):
        return x

    # Forward the flags so nested tensors get the same treatment as a
    # top-level tensor would (they must not be silently returned uncopied).
    def copy_child(v):
        return recursive_copy(v, clone=clone, detach=detach, retain_grad=retain_grad)

    # Only dicts, lists, and tuples (and subclasses) can be copied.
    if isinstance(x, dict):
        return type(x)({k: copy_child(v) for k, v in x.items()})
    if isinstance(x, (list, tuple)):
        return type(x)([copy_child(v) for v in x])
    # Explicit raise: a bare `assert False` disappears under `python -O`.
    raise AssertionError(f"Unknown type {type(x)} cannot be broken into tensors.")

In [4]:
# Checkpoint under test. Other checkpoints this notebook has code paths for:
#   "EleutherAI/gpt-j-6b", "mistralai/Mistral-7B-v0.1"
model_name = "gpt2"

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
# Reuse EOS as the padding token so `padding=...` in tokenizer calls works
# (presumably the GPT-2 tokenizer ships without a dedicated pad token).
tokenizer.pad_token = tokenizer.eos_token

# Half-precision weights, moved to the GPU and switched to inference mode.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model = model.to("cuda")
model.eval()

model.config

GPT2Config {
  "_name_or_path": "gpt2",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "torch_dtype": "float16",
  "transformers_version": "4.34.0",
  "use_cache": true,
  "vocab_size": 50257
}

In [5]:
# Display the module tree; the dotted names shown here (e.g. `transformer.h`)
# are the paths used below to address individual blocks for tracing.
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [6]:
# Dotted module path of the i-th transformer block, keyed by
# `config.model_type`. Unlike a substring test on `_name_or_path`, this does not
# depend on the hub id / local path and cannot misroute e.g. GPT-NeoX
# checkpoints ("gpt-neox-20b" contains "gpt" but its blocks live elsewhere).
_LAYER_NAME_FORMATS = {
    "gpt2": "transformer.h.{}",
    "gptj": "transformer.h.{}",
    "mistral": "model.layers.{}",
}

try:
    layer_name_format = _LAYER_NAME_FORMATS[model.config.model_type]
except KeyError:
    raise AssertionError(f"Untested model type: {model.config.model_type!r}") from None

# `hidden_size` / `num_hidden_layers` are the canonical config names; the
# GPT-2 and GPT-J configs alias them to `n_embd` / `n_layer` via `attribute_map`.
n_embd = model.config.hidden_size
n_layer = model.config.num_hidden_layers

n_embd, n_layer

(768, 12)

In [7]:
prompt = "A quick brown fox"
tokenized = tokenizer(prompt, return_tensors="pt", padding="longest").to(model.device)

# Reference forward pass with no hooks attached. Run it under no_grad like the
# traced passes below: only the logits are compared, so building an autograd
# graph would just keep intermediate activations alive on the GPU.
with torch.no_grad():
    original_output = model(**tokenized)

In [8]:
def untuple(x):
    """
    Unwrap a tuple to its first element; return anything else unchanged.

    Traced block outputs in this notebook are tuples whose first entry is the
    hidden-state tensor, so this gives direct access to that tensor.
    """
    return x[0] if isinstance(x, tuple) else x

def intervention(
        intervention_layer,
        intervene_at,
        replace_out_with
    ):
    """
    Build an ``edit_output`` hook for ``baukit.Trace`` / ``baukit.TraceDict``.

    Args:
        intervention_layer: dotted module name whose output should be edited;
            outputs reported for any other layer are passed through untouched.
        intervene_at: index into dim 1 of the output (presumably the token
            position) selecting what to overwrite.
        replace_out_with: value written into ``output[:, intervene_at]``.

    Returns:
        A callable ``edit_output(layer, output)`` that overwrites the selected
        slice *in place* and hands back the same ``output`` object.
    """
    def edit_output(layer, output):
        # Keep the parameter names `layer` / `output`: baukit presumably passes
        # them by keyword — TODO confirm before renaming.
        if layer == intervention_layer:
            # For tuple outputs the first element (the hidden states) is edited.
            untuple(output)[:, intervene_at] = replace_out_with
        return output
    return edit_output

# Block whose output is patched in this experiment.
INTERVENTION_LAYER_IDX = 8
layer = layer_name_format.format(INTERVENTION_LAYER_IDX)
# Position along dim 1 to patch (-1 = last token of the prompt).
intervene_at = -1

# Clean run: record the chosen layer's output so it can be patched back in.
with torch.no_grad():
    with baukit.Trace(
        model,
        layer=layer,
        retain_input=True,
    ) as trace:
        output_without_intervention = model(**tokenized)
replace_out_with = untuple(trace.output)[:, intervene_at].clone().detach()
# For a non-trivial edit, use e.g.
# `torch.randn(n_embd).to(model.dtype).to(model.device)` instead.

# Patched run: overwrite the same position with the recorded activation.
# The traced layer and `intervention_layer` both come from `layer`, so they
# cannot drift apart (a mismatch would turn the hook into a silent no-op).
with torch.no_grad():
    with baukit.Trace(
        model,
        layer=layer,
        retain_input=True,
        edit_output=intervention(
            intervention_layer=layer,
            intervene_at=intervene_at,
            replace_out_with=replace_out_with,
        ),
    ) as trace:
        output = model(**tokenized)

# Patching in the activation the layer already produced should leave the
# logits (numerically) unchanged.
torch.allclose(
    original_output.logits,
    output.logits
)

True

In [9]:
# intervention successful? The patched slice of the traced output must have the
# same shape as the recorded activation (so `allclose` cannot pass merely by
# broadcasting) and match it numerically.
patched_slice = untuple(trace.output)[:, intervene_at]
assert patched_slice.shape == replace_out_with.shape, (patched_slice.shape, replace_out_with.shape)
torch.allclose(
    patched_slice,
    replace_out_with,
)

True

In [10]:
# Fetch the hidden-state input recorded for the traced block. GPT-J presumably
# passes it as the `hidden_states=` keyword (hence `input_kw`); the other
# models pass it positionally. Dispatch on `model_type` rather than a
# substring of the checkpoint path, which breaks for local or renamed paths.
if model.config.model_type == "gptj":
    trace_inp = trace.input_kw["hidden_states"]
else:
    trace_inp = trace.input
trace_inp.shape

torch.Size([1, 4, 768])

In [11]:
# Shape of the block's (untupled) output; compare with the input shape above.
untuple(trace.output).shape

torch.Size([1, 4, 768])

## Testing `TraceDict` (tracing all layers in one pass)

In [12]:
# Trace every block in one forward pass. The hook built by `intervention`
# receives each traced layer's output but only modifies the one named `layer`.
with torch.no_grad(), baukit.TraceDict(
    model,
    layers=[layer_name_format.format(idx) for idx in range(n_layer)],
    retain_input=True,
    edit_output=intervention(
        intervention_layer=layer,
        intervene_at=intervene_at,
        replace_out_with=replace_out_with,
    ),
) as traces:
    output = model(**tokenized)

In [13]:
# trace dict input ok? TraceDict's record for `layer` should match the input
# captured by the single-layer Trace.

# Dispatch on `model_type` (not a substring of the checkpoint path). GPT-J
# presumably passes hidden states as the `hidden_states=` keyword, hence
# `input_kw`; the other models pass them positionally.
if model.config.model_type == "gptj":
    trace_dict_inp = traces[layer].input_kw["hidden_states"]
else:
    trace_dict_inp = traces[layer].input

torch.allclose(
    trace_inp,
    trace_dict_inp
)

True

In [14]:
# trace dict output ok? TraceDict's record for `layer` should match the output
# captured by the single-layer Trace (both runs applied the same edit).
torch.allclose(untuple(trace.output), untuple(traces[layer].output))

True

In [15]:
# intervention successful? The patched slice recorded by TraceDict must have
# the same shape as the recorded activation (so `allclose` cannot pass merely
# by broadcasting) and match it numerically.
patched_slice = untuple(traces[layer].output)[:, intervene_at]
assert patched_slice.shape == replace_out_with.shape, (patched_slice.shape, replace_out_with.shape)
torch.allclose(
    patched_slice,
    replace_out_with,
)

True

## Testing early stopping (`stop=True`)

In [29]:
# With `stop=True` the forward pass is expected to be aborted right after
# `layer` runs (baukit presumably raises inside its hook and suppresses that on
# exit), so the assignment below should never execute.
stopped_output = None
with torch.no_grad():
    with baukit.Trace(
        model,
        layer=layer,
        retain_input=True,
        stop=True
    ) as trace:
        stopped_output = model(**tokenized)
assert stopped_output is None, "forward pass was not stopped at the traced layer"

In [30]:
# Output retained at the traced layer during the stopped run: a tuple whose
# first element is the hidden-state tensor; the second appears to be a tuple
# of cached tensors (presumably the key/value cache, as `use_cache` is true).
trace.output

(tensor([[[-0.9663, -1.9414,  0.8794,  ..., -0.9805, -0.7729,  0.2275],
          [ 6.2500, -0.9541, -7.9297,  ..., -2.5938,  1.1143, -1.2744],
          [ 5.5352,  3.2715, -1.4131,  ...,  1.6621,  0.5195, -2.6074],
          [ 0.8164,  3.0156, -2.4590,  ..., -0.2100, -2.4102, -0.6504]]],
        device='cuda:0', dtype=torch.float16),
 (tensor([[[[-0.0238, -2.3457,  0.1718,  ..., -0.2378, -0.1830,  0.0659],
            [-0.5811,  4.2266,  0.6685,  ...,  0.0909,  0.5693,  1.0293],
            [-0.0599,  5.0430,  0.2365,  ..., -0.3806, -1.8955,  0.6025],
            [-0.8174,  5.1133, -1.4238,  ..., -0.2462, -1.1680,  0.9092]],
  
           [[-0.8120,  0.2219,  0.4714,  ..., -0.5205,  1.0703,  1.1289],
            [ 0.4053,  0.8271,  0.5723,  ..., -0.8081,  1.3564,  0.0526],
            [ 1.0059,  1.0059,  1.0947,  ...,  1.3086,  1.5850,  0.6411],
            [ 0.8579,  1.4004,  0.2930,  ...,  1.6436,  2.5059, -0.1443]],
  
           [[-0.8535,  0.4805,  0.0164,  ...,  0.4956, -0.2296,