In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../')
import torch
import transformers
import baukit

torch.__version__, transformers.__version__

('2.1.0+cu121', '4.34.0')

In [3]:
def recursive_copy(x, clone=None, detach=None, retain_grad=None):
    """
    Copies a reference to a tensor, or an object that contains tensors,
    optionally detaching and cloning the tensor(s).  If retain_grad is
    true, the original tensors are marked to have grads retained.

    Args:
        x: a tensor, or a (possibly nested) dict/list/tuple of tensors.
        clone: if truthy, every tensor is cloned.
        detach: if truthy (and retain_grad is not), every tensor is detached.
        retain_grad: if truthy, every tensor is made to require grad and
            `retain_grad()` is called on it; this takes precedence over detach.

    Returns:
        `x` itself when no flag is set; otherwise the processed tensor, or a
        new container of the same type as `x` holding the processed contents.

    Raises:
        AssertionError: if a flag is set and `x` itself is not a tensor,
            dict, list, or tuple.
    """
    if not clone and not detach and not retain_grad:
        return x
    if isinstance(x, torch.Tensor):
        if retain_grad:
            if not x.requires_grad:
                x.requires_grad = True
            x.retain_grad()
        elif detach:
            x = x.detach()
        if clone:
            x = x.clone()
        return x

    def copy_child(v):
        # The flags must be forwarded explicitly: a bare recursive_copy(v)
        # hits the "no flags" early return and hands back the original tensor.
        if isinstance(v, (torch.Tensor, dict, list, tuple)):
            return recursive_copy(
                v, clone=clone, detach=detach, retain_grad=retain_grad
            )
        # Non-tensor leaves inside a container (None, ints, ...) are passed
        # through untouched, exactly as they were before the flags were forwarded.
        return v

    # Only dicts, lists, and tuples (and subclasses) can be copied.
    if isinstance(x, dict):
        return type(x)({k: copy_child(v) for k, v in x.items()})
    if isinstance(x, (list, tuple)):
        return type(x)([copy_child(v) for v in x])
    # Raised explicitly (rather than `assert False`) so it survives `python -O`.
    raise AssertionError(f"Unknown type {type(x)} cannot be broken into tensors.")

In [4]:
# Checkpoint to load. Other candidates kept for reference:
#   "EleutherAI/gpt-j-6b"
#   "gpt2"
#   "mistralai/Mistral-7B-v0.1"
# NOTE(review): this is an absolute path on one machine; point it at your own copy.
model_name = "/home/local_arnab/Codes/Weights/mistral-7B"

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Load the weights as float16 and move the model to the GPU.
load_options = dict(low_cpu_mem_usage=True, torch_dtype=torch.float16)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, **load_options)
model = model.to("cuda")

model.eval()
model.config

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

MistralConfig {
  "_name_or_path": "/home/local_arnab/Codes/Weights/mistral-7B",
  "architectures": [
    "MistralForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "mistral",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-05,
  "rope_theta": 10000.0,
  "sliding_window": 4096,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.34.0",
  "use_cache": true,
  "vocab_size": 32000
}

In [5]:
# Display the module tree, which shows the submodule names used for tracing below.
model

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRM

In [6]:
# Pick the module-name template for the transformer blocks.
# The architecture is looked up in `config.model_type` first and only then in
# the checkpoint path, both lowercased, so a local directory such as
# ".../Mistral-7B" (or one that does not name the architecture at all) is
# still recognised.
model_type = (getattr(model.config, "model_type", None) or "").lower()
name_or_path = (model.config._name_or_path or "").lower()

layer_name_format = None
for identifier in (model_type, name_or_path):
    if "gpt" in identifier:
        layer_name_format = "transformer.h.{}"
    elif "mistral" in identifier:
        layer_name_format = "model.layers.{}"
    if layer_name_format is not None:
        break
if layer_name_format is None:
    raise AssertionError("Untested model")

# Read the hidden size and layer count. Configs that expose `n_embd` /
# `n_layer` are read through those keys; otherwise fall back to
# `hidden_size` / `num_hidden_layers`. The dict is built once and reused.
config_dict = model.config.to_dict()

n_embd = (
    model.config.n_embd
    if "n_embd" in config_dict
    else model.config.hidden_size
)

n_layer = (
    model.config.n_layer
    if "n_layer" in config_dict
    else model.config.num_hidden_layers
)

n_embd, n_layer

(4096, 32)

In [7]:
# Tokenize a short prompt and place the tensors on the model's device.
prompt = "A quick brown fox"
tokenized = tokenizer(prompt, return_tensors="pt", padding="longest").to(model.device)

# Reference forward pass with no hooks attached. Gradients are never used in
# this notebook, so run it under no_grad like every other forward pass here;
# otherwise the autograd graph for the whole model is built and kept alive.
with torch.no_grad():
    original_output = model(**tokenized)

In [8]:
def untuple(x):
    """Return the first element when `x` is a tuple; otherwise return `x` unchanged."""
    return x[0] if isinstance(x, tuple) else x

def intervention(
        intervention_layer,
        intervene_at,
        replace_out_with
    ):
    """
    Build an `edit_output(layer, output)` callback.

    For `intervention_layer`, the callback overwrites index `intervene_at`
    along the second dimension of the layer output (its first element when the
    output is a tuple) with `replace_out_with`, in place. Outputs of any other
    layer are returned untouched.
    """
    def edit_output(layer, output):
        if layer == intervention_layer:
            # In-place write, so the original `output` object still carries the edit.
            hidden = untuple(output)
            hidden[:, intervene_at] = replace_out_with
        return output
    return edit_output

# Layer to trace/edit and the position (second-dimension index) to overwrite.
intervene_at = -1
layer = layer_name_format.format(8)

# Pass 1: trace the layer with no edit, to capture its output.
with torch.no_grad(), baukit.Trace(model, layer=layer, retain_input=True) as trace:
    output_without_intervention = model(**tokenized)

# Replacement value: the layer's own traced output at the chosen position.
replace_out_with = untuple(trace.output)[:, intervene_at].detach().clone()
# replace_out_with = torch.randn(n_embd).to(model.dtype).to(model.device)

# Pass 2: same trace, this time writing `replace_out_with` back into the output.
edit_fn = intervention(
    intervention_layer=layer,
    intervene_at=intervene_at,
    replace_out_with=replace_out_with,
)
with torch.no_grad(), baukit.Trace(
    model, layer=layer, retain_input=True, edit_output=edit_fn
) as trace:
    output = model(**tokenized)

# check output after edit (basically the same)
torch.allclose(original_output.logits, output.logits)

True

In [9]:
# intervention successful?
# Compare the traced output at the edited position with the replacement vector.
edited_position = untuple(trace.output)[:, intervene_at].squeeze()
torch.allclose(edited_position, replace_out_with)

True

In [10]:
# For GPT-J checkpoints the layer input is read from the traced keyword
# arguments; for everything else it is the traced `input`.
trace_inp = (
    trace.input_kw["hidden_states"]
    if "gpt-j" in model.config._name_or_path
    else trace.input
)
trace_inp.shape

torch.Size([1, 6, 4096])

In [11]:
# Shape of the traced layer output (its first element if the output is a tuple).
untuple(trace.output).shape

torch.Size([1, 6, 4096])

## Testing TraceDict

In [12]:
# Trace every transformer block at once; the edit callback only acts on `layer`.
with torch.no_grad(), baukit.TraceDict(
    model,
    layers=list(map(layer_name_format.format, range(n_layer))),
    retain_input=True,
    edit_output=intervention(
        intervention_layer=layer,
        intervene_at=intervene_at,
        replace_out_with=replace_out_with,
    ),
) as traces:
    output = model(**tokenized)

In [13]:
# trace dict input ok?
# The input recorded by TraceDict for `layer` should match the one recorded by Trace.
trace_dict_inp = (
    traces[layer].input_kw["hidden_states"]
    if "gpt-j" in model.config._name_or_path
    else traces[layer].input
)

torch.allclose(trace_inp, trace_dict_inp)

True

In [14]:
# trace dict output ok?
# The output recorded by TraceDict for `layer` should match the one recorded by Trace.
single_trace_out = untuple(trace.output)
trace_dict_out = untuple(traces[layer].output)
torch.allclose(single_trace_out, trace_dict_out)

True

In [15]:
# intervention successful?
# Compare the TraceDict output for `layer` at the edited position with the replacement vector.
edited_position_in_dict = untuple(traces[layer].output)[:, intervene_at].squeeze()
torch.allclose(edited_position_in_dict, replace_out_with)

True