In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
# Make modules one directory up importable (presumably project-local code
# that lives next to this notebook's folder — TODO confirm which import needs it).
sys.path.append('../')
import torch
import transformers
import baukit

# Record library versions so results can be reproduced later.
torch.__version__, transformers.__version__

('2.1.0+cu121', '4.34.0')

In [3]:
# Checkpoint under study; swap in one of the commented alternatives to compare.
model_name = "EleutherAI/gpt-j-6b"
# model_name = "gpt2"
# model_name = "mistralai/Mistral-7B-v0.1"

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
# Reuse EOS as the padding token so that `padding=...` works in the tokenizer
# calls below (the tokenizer may not define a dedicated pad token).
tokenizer.pad_token = tokenizer.eos_token

# Load in half precision with reduced host-memory peak, then move to the GPU.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to("cuda")

model.eval()
model.config

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


GPTJConfig {
  "_name_or_path": "EleutherAI/gpt-j-6b",
  "activation_function": "gelu_new",
  "architectures": [
    "GPTJForCausalLM"
  ],
  "attn_pdrop": 0.0,
  "bos_token_id": 50256,
  "embd_pdrop": 0.0,
  "eos_token_id": 50256,
  "gradient_checkpointing": false,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gptj",
  "n_embd": 4096,
  "n_head": 16,
  "n_inner": null,
  "n_layer": 28,
  "n_positions": 2048,
  "resid_pdrop": 0.0,
  "rotary": true,
  "rotary_dim": 64,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50,
      "temperature": 1.0
    }
  },
  "tie_word_embeddings": false,
  "tokenizer_class": "GPT2Tokenizer",
  "torch_dtype": "float16",
  "transformers_version": "4.34.0",
  "use_cache": true,
  "vocab_size": 50

In [4]:
model

GPTJForCausalLM(
  (transformer): GPTJModel(
    (wte): Embedding(50400, 4096)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-27): 28 x GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): GPTJMLP(
          (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
          (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f)

In [5]:
# Resolve where the transformer blocks live in the module tree, so that
# `layer_name_format.format(i)` names block i for baukit tracing.
# Dispatch on `config.model_type` (the architecture tag, e.g. "gptj" in the
# config printed above) rather than substring-matching the checkpoint path:
# "gpt" in a path also matches architectures whose blocks are not under
# `transformer.h` (e.g. GPT-NeoX), and a local checkpoint path need not
# contain the family name at all.
LAYER_NAME_FORMATS = {
    "gptj": "transformer.h.{}",
    "gpt2": "transformer.h.{}",
    "mistral": "model.layers.{}",
}
model_type = model.config.model_type
if model_type not in LAYER_NAME_FORMATS:
    raise AssertionError(f"Untested model: {model_type!r}")
layer_name_format = LAYER_NAME_FORMATS[model_type]

# GPT-style configs call these `n_embd` / `n_layer`; others (e.g. Mistral)
# use `hidden_size` / `num_hidden_layers`.
config_keys = model.config.to_dict()
n_embd = (
    model.config.n_embd
    if "n_embd" in config_keys
    else model.config.hidden_size
)
n_layer = (
    model.config.n_layer
    if "n_layer" in config_keys
    else model.config.num_hidden_layers
)

n_embd, n_layer

(4096, 28)

In [6]:
# Clean (un-intervened) reference run, compared against the traced run below.
# Gradients are never needed here. Running under no_grad avoids building and
# retaining an autograd graph for the whole model, matching the traced runs.
prompt = "A quick brown fox"
tokenized = tokenizer(prompt, return_tensors="pt", padding="longest").to(model.device)

with torch.no_grad():
    original_output = model(**tokenized)

In [8]:
def untuple(x):
    """Unwrap a possibly-tuple value.

    Layer outputs traced below may be tuples; this returns the first element
    of a tuple and passes any other value (including lists) through unchanged.
    """
    return x[0] if isinstance(x, tuple) else x

def intervention(
        intervention_layer,
        intervene_at,
        replace_out_with
    ):
    """Build an ``edit_output`` hook that overwrites one layer's output.

    Args:
        intervention_layer: module name whose output is edited; the hook
            returns every other layer's output untouched.
        intervene_at: index along dim 1 of the (untupled) output to overwrite,
            e.g. ``-1`` for the last position.
        replace_out_with: value written into ``output[:, intervene_at]``
            (broadcast over dim 0 by tensor assignment).

    Returns:
        ``edit_output(layer, output)``, which writes into the output tensor
        in place and returns ``output`` itself (tuple or tensor, as received).
    """
    def edit_output(layer, output):
        # NOTE: keep the parameter names `layer` and `output`; baukit presumably
        # passes them by keyword when invoking the hook — TODO confirm for the
        # baukit version in use before renaming.
        if layer == intervention_layer:
            hidden = untuple(output)
            # assumes hidden is (batch, seq, n_embd), as in the shapes shown
            # in this notebook — the write mutates the tensor in place.
            hidden[:, intervene_at] = replace_out_with
        return output

    return edit_output

# Overwrite one block's output at a single token position with a random
# vector, and run the model under a single-layer baukit Trace.
INTERVENTION_LAYER_IDX = 8  # which transformer block to edit
SEED = 0  # makes the random replacement vector reproducible across runs

# `layer` is reused for both the Trace target and the hook, so the two
# cannot drift apart when INTERVENTION_LAYER_IDX changes.
layer = layer_name_format.format(INTERVENTION_LAYER_IDX)
intervene_at = -1  # token position to overwrite; -1 = last token
replace_out_with = (
    torch.randn(n_embd, generator=torch.Generator().manual_seed(SEED))
    .to(model.dtype)
    .to(model.device)
)

with torch.no_grad(), baukit.Trace(
    model,
    layer=layer,
    retain_input=True,
    # input_capture_all_keys=True,
    input_key="hidden_states",  # for GPT-J
    edit_output=intervention(
        intervention_layer=layer,
        intervene_at=intervene_at,
        replace_out_with=replace_out_with,
    ),
) as trace:
    output = model(**tokenized)

In [9]:
# check output
# ! make sure that edit_output is None in the trace

# Compare the traced run's logits against the clean reference run.
# With the intervention active the logits are expected to differ (False).
# To sanity-check that tracing alone leaves the model unchanged, pass
# `edit_output=None` to the trace above and re-run: this should then be True.
original_output.logits.allclose(output.logits)

False

In [10]:
# intervention successful?
torch.allclose(
    untuple(trace.output)[:, intervene_at].squeeze(),
    replace_out_with,
)

True

In [11]:
trace.input.size()  # shape of the retained input to the traced layer

torch.Size([1, 4, 4096])

In [12]:
untuple(trace.output).size()  # shape of the retained (untupled) layer output

torch.Size([1, 4, 4096])

In [14]:
# Trace every transformer block at once while applying the same single-layer
# intervention; the hook returns outputs of all other layers untouched.
with torch.no_grad(), baukit.TraceDict(
    model,
    layers=list(map(layer_name_format.format, range(n_layer))),
    retain_input=True,
    # input_capture_all_keys=True,
    input_key="hidden_states",
    edit_output=intervention(
        intervention_layer=layer,
        intervene_at=intervene_at,
        replace_out_with=replace_out_with,
    ),
) as traces:
    output = model(**tokenized)

In [15]:
# intervention successful?
# Intervention successful under TraceDict? The retained output of `layer` at
# the edited position should equal the replacement vector.
untuple(traces[layer].output)[:, intervene_at].squeeze().allclose(replace_out_with)

True

In [16]:
# trace dict input ok?
# Does TraceDict retain the same input for `layer` as the single-layer Trace?
trace.input.allclose(traces[layer].input)

True

In [17]:
# trace dict output ok?
# Does TraceDict retain the same (edited) output for `layer` as the
# single-layer Trace?
untuple(trace.output).allclose(untuple(traces[layer].output))

True