## `Tracr` clean

## SGD compress a `tracr` model

In [166]:
tl_model.forward??

[0;31mSignature:[0m
[0mtl_model[0m[0;34m.[0m[0mforward[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0minput[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mList[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m,[0m [0mjaxtyping[0m[0;34m.[0m[0mInt[0m[0;34m[[0m[0mTensor[0m[0;34m,[0m [0;34m'batch pos'[0m[0;34m][0m[0;34m,[0m [0mjaxtyping[0m[0;34m.[0m[0mFloat[0m[0;34m[[0m[0mTensor[0m[0;34m,[0m [0;34m'batch pos d_model'[0m[0;34m][0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreturn_type[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mstr[0m[0;34m][0m [0;34m=[0m [0;34m'logits'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mloss_per_token[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mbool[0m[0;34m][0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mprepend_bos[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mbool[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m  

## Train a `tracr` model 

In [162]:
import sys
sys.path.append('../tracr')

from tracr.rasp import rasp

device = 'cpu'

def make_length():
  all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
  return rasp.SelectorWidth(all_true_selector)

length = make_length()  # `length` is not a primitive in our implementation.
opp_index = length - rasp.indices - 1
flip = rasp.Select(rasp.indices, opp_index, rasp.Comparison.EQ)
reverse = rasp.Aggregate(flip, rasp.tokens)

from tracr.compiler import compiling

bos = "BOS"
model = compiling.compile_rasp_to_model(
    reverse,
    vocab={1, 2, 3, 4, 5, 6},
    max_seq_len=5,
    compiler_bos=bos,
)

from transformer_lens import HookedTransformer, HookedTransformerConfig
import einops
import torch
import numpy as np

def tracr_to_tl(model):

    n_heads = model.model_config.num_heads
    n_layers = model.model_config.num_layers
    d_head = model.model_config.key_size
    d_mlp = model.model_config.mlp_hidden_size
    act_fn = "relu"
    normalization_type = "LN"  if model.model_config.layer_norm else None
    attention_type = "causal"  if model.model_config.causal else "bidirectional"


    n_ctx = model.params["pos_embed"]['embeddings'].shape[0]
    # Equivalent to length of vocab, with BOS and PAD at the end
    d_vocab = model.params["token_embed"]['embeddings'].shape[0]
    # Residual stream width, I don't know of an easy way to infer it from the above config.
    d_model = model.params["token_embed"]['embeddings'].shape[1]

    # Equivalent to length of vocab, WITHOUT BOS and PAD at the end because we never care about these outputs
    # In practice, we always feed the logits into an argmax
    d_vocab_out = model.params["token_embed"]['embeddings'].shape[0] - 2

    cfg = HookedTransformerConfig(
        n_layers=n_layers,
        d_model=d_model,
        d_head=d_head,
        n_ctx=n_ctx,
        d_vocab=d_vocab,
        d_vocab_out=d_vocab_out,
        d_mlp=d_mlp,
        n_heads=n_heads,
        act_fn=act_fn,
        attention_dir=attention_type,
        normalization_type=normalization_type,
    )
    tl_model = HookedTransformer(cfg)


    # %%
    sd = {}
    sd["pos_embed.W_pos"] = model.params["pos_embed"]['embeddings']
    sd["embed.W_E"] = model.params["token_embed"]['embeddings']
    # Equivalent to max_seq_len plus one, for the BOS

    # The unembed is just a projection onto the first few elements of the residual stream, these store output tokens
    # This is a NumPy array, the rest are Jax Arrays, but w/e it's fine.
    sd["unembed.W_U"] = np.eye(d_model, d_vocab_out)

    for l in range(n_layers):
        sd[f"blocks.{l}.attn.W_K"] = einops.rearrange(
            model.params[f"transformer/layer_{l}/attn/key"]["w"],
            "d_model (n_heads d_head) -> n_heads d_model d_head",
            d_head = d_head,
            n_heads = n_heads
        )
        sd[f"blocks.{l}.attn.b_K"] = einops.rearrange(
            model.params[f"transformer/layer_{l}/attn/key"]["b"],
            "(n_heads d_head) -> n_heads d_head",
            d_head = d_head,
            n_heads = n_heads
        )
        sd[f"blocks.{l}.attn.W_Q"] = einops.rearrange(
            model.params[f"transformer/layer_{l}/attn/query"]["w"],
            "d_model (n_heads d_head) -> n_heads d_model d_head",
            d_head = d_head,
            n_heads = n_heads
        )
        sd[f"blocks.{l}.attn.b_Q"] = einops.rearrange(
            model.params[f"transformer/layer_{l}/attn/query"]["b"],
            "(n_heads d_head) -> n_heads d_head",
            d_head = d_head,
            n_heads = n_heads
        )
        sd[f"blocks.{l}.attn.W_V"] = einops.rearrange(
            model.params[f"transformer/layer_{l}/attn/value"]["w"],
            "d_model (n_heads d_head) -> n_heads d_model d_head",
            d_head = d_head,
            n_heads = n_heads
        )
        sd[f"blocks.{l}.attn.b_V"] = einops.rearrange(
            model.params[f"transformer/layer_{l}/attn/value"]["b"],
            "(n_heads d_head) -> n_heads d_head",
            d_head = d_head,
            n_heads = n_heads
        )
        sd[f"blocks.{l}.attn.W_O"] = einops.rearrange(
            model.params[f"transformer/layer_{l}/attn/linear"]["w"],
            "(n_heads d_head) d_model -> n_heads d_head d_model",
            d_head = d_head,
            n_heads = n_heads
        )
        sd[f"blocks.{l}.attn.b_O"] = model.params[f"transformer/layer_{l}/attn/linear"]["b"]

        sd[f"blocks.{l}.mlp.W_in"] = model.params[f"transformer/layer_{l}/mlp/linear_1"]["w"]
        sd[f"blocks.{l}.mlp.b_in"] = model.params[f"transformer/layer_{l}/mlp/linear_1"]["b"]
        sd[f"blocks.{l}.mlp.W_out"] = model.params[f"transformer/layer_{l}/mlp/linear_2"]["w"]
        sd[f"blocks.{l}.mlp.b_out"] = model.params[f"transformer/layer_{l}/mlp/linear_2"]["b"]
    print(sd.keys())


    for k, v in sd.items():
        # I cannot figure out a neater way to go from a Jax array to a numpy array lol
        sd[k] = torch.tensor(np.array(v))

    tl_model.load_state_dict(sd, strict=False)

    return tl_model

tl_model = tracr_to_tl(model).to(device)

INPUT_ENCODER = model.input_encoder
OUTPUT_ENCODER = model.output_encoder

def create_model_input(input, input_encoder=INPUT_ENCODER):
    encoding = input_encoder.encode(input)
    return torch.tensor(encoding).unsqueeze(dim=0)

def decode_model_output(logits, output_encoder=OUTPUT_ENCODER, bos_token=INPUT_ENCODER.bos_token):
    max_output_indices = logits.squeeze(dim=0).argmax(dim=-1)
    decoded_output = output_encoder.decode(max_output_indices.tolist())
    decoded_output_with_bos = [bos_token] + decoded_output[1:]
    return decoded_output_with_bos

input = [bos, 1, 2, 3, 2]
out = model.apply(input)
print("Original Decoding:", out.decoded)

input_tokens_tensor = create_model_input(input)
logits = tl_model(input_tokens_tensor)
print(logits)
decoded_output = decode_model_output(logits)
print("TransformerLens Replicated Decoding:", decoded_output)

# Randomly initialise all weights in the model
#tl_model.init_weights()

logits = tl_model(input_tokens_tensor)
logits.squeeze(dim=0).max(dim=-1)

# Target is reverse input tokens tensor (shape 1x5)
idx = [i for i in range(input_tokens_tensor.size(1)-1, -1, -1)]
idx = torch.LongTensor(idx)
target = input_tokens_tensor[:, idx].type(torch.float32)
target

# Calculate L2 loss between logits and target
criterion = torch.nn.MSELoss()
loss = criterion(logits.squeeze(dim=0).argmax(dim=-1).cpu().type(torch.float32), target.squeeze().cpu())
loss

import torch
import itertools

def permutations_with_replacement_to_tensor(s):
    # Convert the set to a list to allow indexed access
    elements = list(s)
    n = len(elements)
    
    # Generate all permutations with replacement using itertools.product
    # This creates an iterator for all n-length combinations of the elements
    all_permutations = itertools.product(elements, repeat=n)
    
    # Convert iterator to list of lists
    permutations_list = [list(perm) for perm in all_permutations]
    
    # Convert list of lists to a PyTorch tensor
    # Each permutation is a row in the tensor
    permutations_tensor = torch.tensor(permutations_list)

    # Shuffle tensor
    perm_indices = torch.randperm(permutations_tensor.size(0))
    permutations_tensor = permutations_tensor[perm_indices]
    
    return permutations_tensor

# Example usage
s = {1, 2, 3, 4, 5, 6}
tensor = permutations_with_replacement_to_tensor(s)
print(tensor)
print(tensor.shape)

dict_keys(['pos_embed.W_pos', 'embed.W_E', 'unembed.W_U', 'blocks.0.attn.W_K', 'blocks.0.attn.b_K', 'blocks.0.attn.W_Q', 'blocks.0.attn.b_Q', 'blocks.0.attn.W_V', 'blocks.0.attn.b_V', 'blocks.0.attn.W_O', 'blocks.0.attn.b_O', 'blocks.0.mlp.W_in', 'blocks.0.mlp.b_in', 'blocks.0.mlp.W_out', 'blocks.0.mlp.b_out', 'blocks.1.attn.W_K', 'blocks.1.attn.b_K', 'blocks.1.attn.W_Q', 'blocks.1.attn.b_Q', 'blocks.1.attn.W_V', 'blocks.1.attn.b_V', 'blocks.1.attn.W_O', 'blocks.1.attn.b_O', 'blocks.1.mlp.W_in', 'blocks.1.mlp.b_in', 'blocks.1.mlp.W_out', 'blocks.1.mlp.b_out', 'blocks.2.attn.W_K', 'blocks.2.attn.b_K', 'blocks.2.attn.W_Q', 'blocks.2.attn.b_Q', 'blocks.2.attn.W_V', 'blocks.2.attn.b_V', 'blocks.2.attn.W_O', 'blocks.2.attn.b_O', 'blocks.2.mlp.W_in', 'blocks.2.mlp.b_in', 'blocks.2.mlp.W_out', 'blocks.2.mlp.b_out', 'blocks.3.attn.W_K', 'blocks.3.attn.b_K', 'blocks.3.attn.W_Q', 'blocks.3.attn.b_Q', 'blocks.3.attn.W_V', 'blocks.3.attn.b_V', 'blocks.3.attn.W_O', 'blocks.3.attn.b_O', 'blocks.3.ml

In [163]:
input = [bos, 1, 2, 3, 2]
out = model.apply(input)
print("Original Decoding:", out.decoded)

input_tokens_tensor = create_model_input(input)
logits = tl_model(input_tokens_tensor)
print(logits)
decoded_output = decode_model_output(logits)
print("TransformerLens Replicated Decoding:", decoded_output)

logits = tl_model(input_tokens_tensor)
max_output_indices = logits.squeeze(dim=0).argmax(dim=-1)
OUTPUT_ENCODER.decode(max_output_indices.tolist())

Original Decoding: ['BOS', 2, 3, 2, 1]
tensor([[[5.3889e-07, 1.0778e-06, 5.3889e-07, 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [2.9040e-13, 1.0000e+00, 2.9040e-13, 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [2.9040e-13, 5.8080e-13, 1.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [2.9040e-13, 1.0000e+00, 2.9040e-13, 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [1.0000e+00, 5.8080e-13, 2.9040e-13, 0.0000e+00, 0.0000e+00,
          0.0000e+00]]], grad_fn=<AddBackward0>)
TransformerLens Replicated Decoding: ['BOS', 2, 3, 2, 1]


[2, 2, 3, 2, 1]

In [61]:
# Target is reverse input tokens tensor (shape 1x5)
idx = [i for i in range(input_tokens_tensor.size(1)-1, -1, -1)]
idx = torch.LongTensor(idx)
target = input_tokens_tensor[:, idx].type(torch.float32)
target

tensor([[1., 2., 1., 0., 6.]])

In [62]:
logits.squeeze(dim=0).argmax(dim=-1).cpu().type(torch.float32)

tensor([1., 2., 2., 2., 5.])

In [63]:
target.squeeze().cpu().type()

'torch.FloatTensor'

In [64]:
# Calculate L2 loss between logits and target
criterion = torch.nn.MSELoss()
loss = criterion(logits.squeeze(dim=0).argmax(dim=-1).cpu().type(torch.float32), target.squeeze().cpu())
loss

tensor(1.2000)

In [56]:
logits.squeeze(dim=0).argmax(dim=-1).cpu().type(torch.float32).shape, target.squeeze().cpu().shape

(torch.Size([7]), torch.Size([7]))

In [58]:
import sys
sys.path.append('../tracr')

from tracr.rasp import rasp
from tracr.compiler import compiling
from transformer_lens import HookedTransformer, HookedTransformerConfig
import einops
import torch
import numpy as np
import itertools

device = 'cpu'

def make_length():
    all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
    return rasp.SelectorWidth(all_true_selector)

length = make_length()
opp_index = length - rasp.indices - 1
flip = rasp.Select(rasp.indices, opp_index, rasp.Comparison.EQ)
reverse = rasp.Aggregate(flip, rasp.tokens)

bos = "BOS"
model = compiling.compile_rasp_to_model(
    reverse,
    vocab={1, 2, 3, 4, 5, 6},
    max_seq_len=6,
    compiler_bos=bos,
)

tl_model = tracr_to_tl(model).to(device)

INPUT_ENCODER = model.input_encoder
OUTPUT_ENCODER = model.output_encoder

def create_model_input(input, input_encoder=INPUT_ENCODER):
    encoding = input_encoder.encode(input)
    return torch.tensor(encoding).unsqueeze(dim=0)

def decode_model_output(logits, output_encoder=OUTPUT_ENCODER, bos_token=INPUT_ENCODER.bos_token):
    max_output_indices = logits.squeeze(dim=0).argmax(dim=-1)
    decoded_output = output_encoder.decode(max_output_indices.tolist())
    decoded_output_with_bos = [bos_token] + decoded_output[1:]
    return decoded_output_with_bos

def permutations_with_replacement_to_tensor(s):
    elements = list(s)
    n = len(elements)
    all_permutations = itertools.product(elements, repeat=n)
    permutations_list = [list(perm) for perm in all_permutations]
    permutations_tensor = torch.tensor(permutations_list)
    perm_indices = torch.randperm(permutations_tensor.size(0))
    permutations_tensor = permutations_tensor[perm_indices]
    return permutations_tensor

# Randomly initialize all weights in the model
tl_model.init_weights()

# Generate all permutations of the input
input_permutations = permutations_with_replacement_to_tensor({1, 2, 3, 4, 5, 6})

# Define loss function and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(tl_model.parameters(), lr=0.001)

# Training loop
num_epochs = 100
batch_size = 32

for epoch in range(num_epochs):
    for i in range(0, input_permutations.size(0), batch_size):
        batch = input_permutations[i:i+batch_size]
        input_tokens_tensor = create_model_input([bos] + batch.tolist()[0])
        logits = tl_model(input_tokens_tensor)
        
        idx = [i for i in range(input_tokens_tensor.size(1)-1, -1, -1)]
        idx = torch.LongTensor(idx)
        target = input_tokens_tensor[:, idx].type(torch.float32)
        
        # Convert logits to float32 and move to CPU
        logits = logits.squeeze(dim=0).cpu().type(torch.float32)
        
        # Convert target to float32 tensor and move to CPU
        target = target.squeeze().cpu().type(torch.float32)
        
        loss = criterion(logits, target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

dict_keys(['pos_embed.W_pos', 'embed.W_E', 'unembed.W_U', 'blocks.0.attn.W_K', 'blocks.0.attn.b_K', 'blocks.0.attn.W_Q', 'blocks.0.attn.b_Q', 'blocks.0.attn.W_V', 'blocks.0.attn.b_V', 'blocks.0.attn.W_O', 'blocks.0.attn.b_O', 'blocks.0.mlp.W_in', 'blocks.0.mlp.b_in', 'blocks.0.mlp.W_out', 'blocks.0.mlp.b_out', 'blocks.1.attn.W_K', 'blocks.1.attn.b_K', 'blocks.1.attn.W_Q', 'blocks.1.attn.b_Q', 'blocks.1.attn.W_V', 'blocks.1.attn.b_V', 'blocks.1.attn.W_O', 'blocks.1.attn.b_O', 'blocks.1.mlp.W_in', 'blocks.1.mlp.b_in', 'blocks.1.mlp.W_out', 'blocks.1.mlp.b_out', 'blocks.2.attn.W_K', 'blocks.2.attn.b_K', 'blocks.2.attn.W_Q', 'blocks.2.attn.b_Q', 'blocks.2.attn.W_V', 'blocks.2.attn.b_V', 'blocks.2.attn.W_O', 'blocks.2.attn.b_O', 'blocks.2.mlp.W_in', 'blocks.2.mlp.b_in', 'blocks.2.mlp.W_out', 'blocks.2.mlp.b_out', 'blocks.3.attn.W_K', 'blocks.3.attn.b_K', 'blocks.3.attn.W_Q', 'blocks.3.attn.b_Q', 'blocks.3.attn.W_V', 'blocks.3.attn.b_V', 'blocks.3.attn.W_O', 'blocks.3.attn.b_O', 'blocks.3.ml

  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: The size of tensor a (6) must match the size of tensor b (7) at non-singleton dimension 1

### `tracr-reverse`

In [102]:
import sys
sys.path.append('../tracr')

from tracr.rasp import rasp

def make_length():
  all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
  return rasp.SelectorWidth(all_true_selector)

length = make_length()  # `length` is not a primitive in our implementation.
opp_index = length - rasp.indices - 1
flip = rasp.Select(rasp.indices, opp_index, rasp.Comparison.EQ)
reverse = rasp.Aggregate(flip, rasp.tokens)

In [103]:
from tracr.compiler import compiling

bos = "BOS"
model = compiling.compile_rasp_to_model(
    reverse,
    vocab={1, 2, 3, 4, 5, 6},
    max_seq_len=5,
    compiler_bos=bos,
)

In [104]:
from transformer_lens import HookedTransformer, HookedTransformerConfig
import einops
import torch
import numpy as np

def tracr_to_tl(model):

    n_heads = model.model_config.num_heads
    n_layers = model.model_config.num_layers
    d_head = model.model_config.key_size
    d_mlp = model.model_config.mlp_hidden_size
    act_fn = "relu"
    normalization_type = "LN"  if model.model_config.layer_norm else None
    attention_type = "causal"  if model.model_config.causal else "bidirectional"


    n_ctx = model.params["pos_embed"]['embeddings'].shape[0]
    # Equivalent to length of vocab, with BOS and PAD at the end
    d_vocab = model.params["token_embed"]['embeddings'].shape[0]
    # Residual stream width, I don't know of an easy way to infer it from the above config.
    d_model = model.params["token_embed"]['embeddings'].shape[1]

    # Equivalent to length of vocab, WITHOUT BOS and PAD at the end because we never care about these outputs
    # In practice, we always feed the logits into an argmax
    d_vocab_out = model.params["token_embed"]['embeddings'].shape[0] - 2

    cfg = HookedTransformerConfig(
        n_layers=n_layers,
        d_model=d_model,
        d_head=d_head,
        n_ctx=n_ctx,
        d_vocab=d_vocab,
        d_vocab_out=d_vocab_out,
        d_mlp=d_mlp,
        n_heads=n_heads,
        act_fn=act_fn,
        attention_dir=attention_type,
        normalization_type=normalization_type,
    )
    tl_model = HookedTransformer(cfg)


    # %%
    sd = {}
    sd["pos_embed.W_pos"] = model.params["pos_embed"]['embeddings']
    sd["embed.W_E"] = model.params["token_embed"]['embeddings']
    # Equivalent to max_seq_len plus one, for the BOS

    # The unembed is just a projection onto the first few elements of the residual stream, these store output tokens
    # This is a NumPy array, the rest are Jax Arrays, but w/e it's fine.
    sd["unembed.W_U"] = np.eye(d_model, d_vocab_out)

    for l in range(n_layers):
        sd[f"blocks.{l}.attn.W_K"] = einops.rearrange(
            model.params[f"transformer/layer_{l}/attn/key"]["w"],
            "d_model (n_heads d_head) -> n_heads d_model d_head",
            d_head = d_head,
            n_heads = n_heads
        )
        sd[f"blocks.{l}.attn.b_K"] = einops.rearrange(
            model.params[f"transformer/layer_{l}/attn/key"]["b"],
            "(n_heads d_head) -> n_heads d_head",
            d_head = d_head,
            n_heads = n_heads
        )
        sd[f"blocks.{l}.attn.W_Q"] = einops.rearrange(
            model.params[f"transformer/layer_{l}/attn/query"]["w"],
            "d_model (n_heads d_head) -> n_heads d_model d_head",
            d_head = d_head,
            n_heads = n_heads
        )
        sd[f"blocks.{l}.attn.b_Q"] = einops.rearrange(
            model.params[f"transformer/layer_{l}/attn/query"]["b"],
            "(n_heads d_head) -> n_heads d_head",
            d_head = d_head,
            n_heads = n_heads
        )
        sd[f"blocks.{l}.attn.W_V"] = einops.rearrange(
            model.params[f"transformer/layer_{l}/attn/value"]["w"],
            "d_model (n_heads d_head) -> n_heads d_model d_head",
            d_head = d_head,
            n_heads = n_heads
        )
        sd[f"blocks.{l}.attn.b_V"] = einops.rearrange(
            model.params[f"transformer/layer_{l}/attn/value"]["b"],
            "(n_heads d_head) -> n_heads d_head",
            d_head = d_head,
            n_heads = n_heads
        )
        sd[f"blocks.{l}.attn.W_O"] = einops.rearrange(
            model.params[f"transformer/layer_{l}/attn/linear"]["w"],
            "(n_heads d_head) d_model -> n_heads d_head d_model",
            d_head = d_head,
            n_heads = n_heads
        )
        sd[f"blocks.{l}.attn.b_O"] = model.params[f"transformer/layer_{l}/attn/linear"]["b"]

        sd[f"blocks.{l}.mlp.W_in"] = model.params[f"transformer/layer_{l}/mlp/linear_1"]["w"]
        sd[f"blocks.{l}.mlp.b_in"] = model.params[f"transformer/layer_{l}/mlp/linear_1"]["b"]
        sd[f"blocks.{l}.mlp.W_out"] = model.params[f"transformer/layer_{l}/mlp/linear_2"]["w"]
        sd[f"blocks.{l}.mlp.b_out"] = model.params[f"transformer/layer_{l}/mlp/linear_2"]["b"]
    print(sd.keys())


    for k, v in sd.items():
        # I cannot figure out a neater way to go from a Jax array to a numpy array lol
        sd[k] = torch.tensor(np.array(v))

    tl_model.load_state_dict(sd, strict=False)

    return tl_model

tl_model = tracr_to_tl(model)

INPUT_ENCODER = model.input_encoder
OUTPUT_ENCODER = model.output_encoder

def create_model_input(input, input_encoder=INPUT_ENCODER):
    encoding = input_encoder.encode(input)
    return torch.tensor(encoding).unsqueeze(dim=0)

def decode_model_output(logits, output_encoder=OUTPUT_ENCODER, bos_token=INPUT_ENCODER.bos_token):
    max_output_indices = logits.squeeze(dim=0).argmax(dim=-1)
    decoded_output = output_encoder.decode(max_output_indices.tolist())
    decoded_output_with_bos = [bos_token] + decoded_output[1:]
    return decoded_output_with_bos

input = [bos, 1, 2, 3]
out = model.apply(input)
print("Original Decoding:", out.decoded)

input_tokens_tensor = create_model_input(input)
logits = tl_model(input_tokens_tensor)
decoded_output = decode_model_output(logits)
print("TransformerLens Replicated Decoding:", decoded_output)

dict_keys(['pos_embed.W_pos', 'embed.W_E', 'unembed.W_U', 'blocks.0.attn.W_K', 'blocks.0.attn.b_K', 'blocks.0.attn.W_Q', 'blocks.0.attn.b_Q', 'blocks.0.attn.W_V', 'blocks.0.attn.b_V', 'blocks.0.attn.W_O', 'blocks.0.attn.b_O', 'blocks.0.mlp.W_in', 'blocks.0.mlp.b_in', 'blocks.0.mlp.W_out', 'blocks.0.mlp.b_out', 'blocks.1.attn.W_K', 'blocks.1.attn.b_K', 'blocks.1.attn.W_Q', 'blocks.1.attn.b_Q', 'blocks.1.attn.W_V', 'blocks.1.attn.b_V', 'blocks.1.attn.W_O', 'blocks.1.attn.b_O', 'blocks.1.mlp.W_in', 'blocks.1.mlp.b_in', 'blocks.1.mlp.W_out', 'blocks.1.mlp.b_out', 'blocks.2.attn.W_K', 'blocks.2.attn.b_K', 'blocks.2.attn.W_Q', 'blocks.2.attn.b_Q', 'blocks.2.attn.W_V', 'blocks.2.attn.b_V', 'blocks.2.attn.W_O', 'blocks.2.attn.b_O', 'blocks.2.mlp.W_in', 'blocks.2.mlp.b_in', 'blocks.2.mlp.W_out', 'blocks.2.mlp.b_out', 'blocks.3.attn.W_K', 'blocks.3.attn.b_K', 'blocks.3.attn.W_Q', 'blocks.3.attn.b_Q', 'blocks.3.attn.W_V', 'blocks.3.attn.b_V', 'blocks.3.attn.W_O', 'blocks.3.attn.b_O', 'blocks.3.ml

In [105]:
import torch
import itertools

def permutations_with_replacement_to_tensor(s):
    # Convert the set to a list to allow indexed access
    elements = list(s)
    n = len(elements)
    
    # Generate all permutations with replacement using itertools.product
    # This creates an iterator for all n-length combinations of the elements
    all_permutations = itertools.product(elements, repeat=n)
    
    # Convert iterator to list of lists
    permutations_list = [list(perm) for perm in all_permutations]
    
    # Convert list of lists to a PyTorch tensor
    # Each permutation is a row in the tensor
    permutations_tensor = torch.tensor(permutations_list)

    # Shuffle tensor
    perm_indices = torch.randperm(permutations_tensor.size(0))
    permutations_tensor = permutations_tensor[perm_indices]
    
    return permutations_tensor

# Example usage
s = {1, 2, 3, 4, 5, 6}
tensor = permutations_with_replacement_to_tensor(s)
print(tensor)
print(tensor.shape)

tensor([[6, 4, 1, 2, 1, 6],
        [2, 1, 5, 6, 5, 6],
        [6, 6, 1, 2, 6, 6],
        ...,
        [6, 2, 6, 6, 5, 6],
        [4, 1, 2, 3, 3, 1],
        [2, 5, 2, 4, 2, 1]])
torch.Size([46656, 6])


In [106]:
def groundtruth_tl_tracr_model(model):
    circuit_components = []
    
    for i in range(model.cfg.n_layers):
        # Check if W_Q is not all zero
        if not torch.allclose(model.W_Q[i], torch.zeros_like(model.W_Q[i])):
            circuit_components.append(f"attn_q_{i}")
        # Check if W_K is not all zero
        if not torch.allclose(model.W_K[i], torch.zeros_like(model.W_K[i])):
            circuit_components.append(f"attn_k_{i}")
        # Check if W_V is not all zero
        if not torch.allclose(model.W_V[i], torch.zeros_like(model.W_V[i])):
            circuit_components.append(f"attn_v_{i}")

    return circuit_components

ground_truth = groundtruth_tl_tracr_model(tl_model)
ground_truth

['attn_q_0', 'attn_k_0', 'attn_v_0', 'attn_q_3', 'attn_k_3', 'attn_v_3']

In [107]:
def resid_cache_from_tl_tracr_model(model, input_tokens_tensor):
    _, cache = model.run_with_cache(input_tokens_tensor)
    labels = []

    # For each layer, get the K, Q and V tensors and MLP in and out
    attn_hook_k = []
    attn_hook_q = []
    attn_hook_v = []
    for i in range(model.cfg.n_layers):
        hook_k = cache[f'blocks.{i}.attn.hook_k']
        #hook_k = torch.nn.functional.pad(hook_k, (0, model.cfg.d_model - hook_k.shape[-1]))
        attn_hook_k.append(hook_k.squeeze().mean(dim=1) + torch.randn_like(hook_k.squeeze().mean(dim=1)))
        labels.append(f"attn_k_{i}")

        hook_q = cache[f'blocks.{i}.attn.hook_q']
        #hook_q = torch.nn.functional.pad(hook_q, (0, model.cfg.d_model - hook_q.shape[-1]))
        attn_hook_q.append(hook_q.squeeze().mean(dim=1) + torch.randn_like(hook_q.squeeze().mean(dim=1)))
        labels.append(f"attn_q_{i}")

        hook_v = cache[f'blocks.{i}.attn.hook_v']
        #hook_v = torch.nn.functional.pad(hook_v, (0, model.cfg.d_model - hook_v.shape[-1]))
        attn_hook_v.append(hook_v.squeeze().mean(dim=1) + torch.randn_like(hook_v.squeeze().mean(dim=1)))
        labels.append(f"attn_v_{i}")

    # Stack all tensors together
    attn_hook_k = torch.stack(attn_hook_k)
    attn_hook_q = torch.stack(attn_hook_q)
    attn_hook_v = torch.stack(attn_hook_v)

    # Stack everything
    final_resid = torch.cat([attn_hook_k, attn_hook_q, attn_hook_v], dim=0)
    return einops.rearrange(final_resid, 'c n d -> n c d'), labels
    
head_resid, head_labels = resid_cache_from_tl_tracr_model(tl_model, tensor[:250, :])
print(head_resid.shape)
print(head_labels)

torch.Size([250, 12, 12])
['attn_k_0', 'attn_q_0', 'attn_v_0', 'attn_k_1', 'attn_q_1', 'attn_v_1', 'attn_k_2', 'attn_q_2', 'attn_v_2', 'attn_k_3', 'attn_q_3', 'attn_v_3']


In [108]:
# Corrupted head resid is going through ground-truth circuit
corrupted_head_resid = head_resid.clone() #torch.randn_like(head_resid)
# If component is in ground-truth, set its resid to 0 everywhere, else add Gaussian noise
corrupted_head_resid = head_resid.clone()
for i, component in enumerate(head_labels):
    if component in ground_truth:
        print(f"Zeroing out {component}")
        corrupted_head_resid[:, i] = 0
        print(corrupted_head_resid[:, i])
    else:
        print(f"Corrupting {component}")
        corrupted_head_resid[:, i] += torch.randn_like(corrupted_head_resid[:, i])
        print(corrupted_head_resid[:, i])

print(corrupted_head_resid.shape)

Zeroing out attn_k_0
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='mps:0')
Zeroing out attn_q_0
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='mps:0')
Zeroing out attn_v_0
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='mps:0')
Corrupting attn_k_1
tensor([[ 0.9897,  0.7680,  0.2630,  ...,  0.0629,  1.0821,  1.5981],
        [ 0.0575,  0.6795, -0.8268,  

In [109]:
# Stack them together and save
resid_streams = torch.cat([head_resid, corrupted_head_resid], dim=0)
path = "../data/tracr-reverse/"
torch.save(resid_streams, path + "resid_heads_mean.pt")
# Save the head labels
torch.save(head_labels, path + "labels_heads_mean.pt")
# Save the ground truth
torch.save(ground_truth, path + "ground_truth.pt")

### `tracr-xproportion`

In [110]:
from tracr.compiler.lib import make_frac_prevs

model = compiling.compile_rasp_to_model(
      make_frac_prevs(rasp.tokens == "x"),
      vocab={"w", "x", "y", "z"},
      max_seq_len=6,
      compiler_bos="BOS",
      )

out = model.apply(["BOS", "w", "x", "y", "z"])
out.decoded

['BOS', 4.4194817525719976e-16, 0.5, 0.3333333432674408, 0.25]

In [111]:
tl_model = tracr_to_tl(model)

INPUT_ENCODER = model.input_encoder
OUTPUT_ENCODER = model.output_encoder

def create_model_input(input, input_encoder=INPUT_ENCODER):
    encoding = input_encoder.encode(input)
    return torch.tensor(encoding).unsqueeze(dim=0)

def decode_model_output(logits, output_encoder=OUTPUT_ENCODER, bos_token=INPUT_ENCODER.bos_token):
    max_output_indices = logits.squeeze(dim=0).argmax(dim=-1)
    decoded_output = output_encoder.decode(max_output_indices.tolist())
    decoded_output_with_bos = [bos_token] + decoded_output[1:]
    return decoded_output_with_bos

dict_keys(['pos_embed.W_pos', 'embed.W_E', 'unembed.W_U', 'blocks.0.attn.W_K', 'blocks.0.attn.b_K', 'blocks.0.attn.W_Q', 'blocks.0.attn.b_Q', 'blocks.0.attn.W_V', 'blocks.0.attn.b_V', 'blocks.0.attn.W_O', 'blocks.0.attn.b_O', 'blocks.0.mlp.W_in', 'blocks.0.mlp.b_in', 'blocks.0.mlp.W_out', 'blocks.0.mlp.b_out', 'blocks.1.attn.W_K', 'blocks.1.attn.b_K', 'blocks.1.attn.W_Q', 'blocks.1.attn.b_Q', 'blocks.1.attn.W_V', 'blocks.1.attn.b_V', 'blocks.1.attn.W_O', 'blocks.1.attn.b_O', 'blocks.1.mlp.W_in', 'blocks.1.mlp.b_in', 'blocks.1.mlp.W_out', 'blocks.1.mlp.b_out'])


In [112]:
input = [bos, "x", "w", "w", "x"]
out = model.apply(input)
print("Original Decoding:", out.decoded)

input_tokens_tensor = create_model_input(input)
logits = tl_model(input_tokens_tensor)

logits

Original Decoding: ['BOS', 1.0, 0.5, 0.3333333432674408, 0.5]


tensor([[[4.2045e-08, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00],
         [5.0000e-01, 0.0000e+00, 1.0000e+00, 0.0000e+00],
         [3.3333e-01, 0.0000e+00, 0.0000e+00, 1.0000e+00],
         [5.0000e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00]]], device='mps:0',
       grad_fn=<AddBackward0>)

In [113]:
import torch
import itertools

def permutations_with_string_to_tensor(input_string):
    # Map each unique character to a unique integer
    unique_chars = set(input_string)
    char_to_index = {char: i for i, char in enumerate(unique_chars)}
    
    # Convert input_string to indices
    indices = [char_to_index[char] for char in input_string]
    n = len(indices)
    
    # Generate all permutations with replacement using itertools.product
    all_permutations = itertools.product(indices, repeat=n)
    
    # Convert iterator to list of lists
    permutations_list = [list(perm) for perm in all_permutations]
    
    # Convert list of lists to a PyTorch tensor of type long
    permutations_tensor = torch.tensor(permutations_list, dtype=torch.long)
    
    # Shuffle tensor
    perm_indices = torch.randperm(permutations_tensor.size(0))
    permutations_tensor = permutations_tensor[perm_indices]
    
    return permutations_tensor

# Example usage
s = {"w", "x", "y", "z"}
tensor = permutations_with_string_to_tensor(s)
print(tensor)
print(tensor.shape)

tensor([[3, 0, 3, 1],
        [0, 2, 1, 1],
        [1, 0, 0, 1],
        ...,
        [2, 1, 0, 2],
        [3, 2, 0, 1],
        [0, 2, 3, 2]])
torch.Size([256, 4])


In [114]:
ground_truth = groundtruth_tl_tracr_model(tl_model)
ground_truth

['attn_q_1', 'attn_k_1', 'attn_v_1']

In [115]:
head_resid, head_labels = resid_cache_from_tl_tracr_model(tl_model, tensor[:250, :])
print(head_resid.shape)
print(head_labels)

torch.Size([250, 6, 8])
['attn_k_0', 'attn_q_0', 'attn_v_0', 'attn_k_1', 'attn_q_1', 'attn_v_1']


In [116]:
# Corrupted head resid is going through ground-truth circuit
# If component is in ground-truth, set its resid to 0 everywhere, else add Gaussian noise
corrupted_head_resid = head_resid.clone()
for i, component in enumerate(head_labels):
    if component in ground_truth:
        print(f"Zeroing out {component}")
        corrupted_head_resid[:, i] = 0
        print(corrupted_head_resid[:, i])
    else:
        print(f"Corrupting {component}")
        corrupted_head_resid[:, i] += torch.randn_like(corrupted_head_resid[:, i])
        print(corrupted_head_resid[:, i])

print(corrupted_head_resid.shape)

Corrupting attn_k_0
tensor([[-0.4801, -1.2460, -0.9162,  ..., -3.5015, -0.8105, -0.6292],
        [-0.7809,  0.5113, -1.2693,  ...,  1.1346,  0.7321, -0.5356],
        [-1.6931, -1.8289,  0.5035,  ..., -2.7695,  0.0945, -0.5288],
        ...,
        [ 0.5398, -0.4671, -0.3707,  ..., -1.1552,  0.4300,  0.8700],
        [-2.1709,  2.3685, -0.0539,  ..., -0.5963, -1.0306, -0.0327],
        [ 0.5795,  3.8303,  1.3946,  ...,  1.9283, -1.0473,  0.5952]],
       device='mps:0')
Corrupting attn_q_0
tensor([[-1.1757,  1.4945,  1.1378,  ..., -0.8040, -0.5693, -0.7115],
        [ 0.2037,  2.7882, -0.2064,  ..., -0.0908,  0.6268,  2.5802],
        [ 1.0321,  0.0652,  0.1803,  ..., -0.3469, -0.4476, -3.3302],
        ...,
        [ 1.6057,  3.2537, -1.4321,  ...,  0.4479,  2.0291,  0.1824],
        [ 0.6302,  1.1082,  0.5714,  ..., -1.3209,  3.1089, -0.4411],
        [-0.8440, -0.4489,  1.0039,  ...,  1.2577,  1.6996,  0.4194]],
       device='mps:0')
Corrupting attn_v_0
tensor([[ 0.1370,  2.5469,

In [117]:
# Stack them together and save
resid_streams = torch.cat([head_resid, corrupted_head_resid], dim=0)
path = "../data/tracr-fracprev/"
torch.save(resid_streams, path + "resid_heads_mean.pt")
# Save the head labels
torch.save(head_labels, path + "labels_heads_mean.pt")
# Save the ground truth
torch.save(ground_truth, path + "ground_truth.pt")

### `tracr-count`

In [140]:
from tracr.compiler.lib import make_sort
import tracr.compiler.lib as lib

vocab = {1, 2, 3, 4, 5, 6}
max_seq_len = 6
program = lib.make_sort(rasp.tokens, rasp.tokens, max_seq_len=max_seq_len, min_key=1)

model = compiling.compile_rasp_to_model(
    program,
    vocab=vocab,
    max_seq_len=max_seq_len,
    compiler_bos="BOS",
)

out = model.apply(["BOS", 3, 2, 1])
print("Original Decoding:", out.decoded)

Original Decoding: ['BOS', 1, 2, 3]


In [141]:
tl_model = tracr_to_tl(model)

INPUT_ENCODER = model.input_encoder
OUTPUT_ENCODER = model.output_encoder

def create_model_input(input, input_encoder=INPUT_ENCODER):
    encoding = input_encoder.encode(input)
    return torch.tensor(encoding).unsqueeze(dim=0)

def decode_model_output(logits, output_encoder=OUTPUT_ENCODER, bos_token=INPUT_ENCODER.bos_token):
    max_output_indices = logits.squeeze(dim=0).argmax(dim=-1)
    decoded_output = output_encoder.decode(max_output_indices.tolist())
    decoded_output_with_bos = [bos_token] + decoded_output[1:]
    return decoded_output_with_bos

dict_keys(['pos_embed.W_pos', 'embed.W_E', 'unembed.W_U', 'blocks.0.attn.W_K', 'blocks.0.attn.b_K', 'blocks.0.attn.W_Q', 'blocks.0.attn.b_Q', 'blocks.0.attn.W_V', 'blocks.0.attn.b_V', 'blocks.0.attn.W_O', 'blocks.0.attn.b_O', 'blocks.0.mlp.W_in', 'blocks.0.mlp.b_in', 'blocks.0.mlp.W_out', 'blocks.0.mlp.b_out', 'blocks.1.attn.W_K', 'blocks.1.attn.b_K', 'blocks.1.attn.W_Q', 'blocks.1.attn.b_Q', 'blocks.1.attn.W_V', 'blocks.1.attn.b_V', 'blocks.1.attn.W_O', 'blocks.1.attn.b_O', 'blocks.1.mlp.W_in', 'blocks.1.mlp.b_in', 'blocks.1.mlp.W_out', 'blocks.1.mlp.b_out', 'blocks.2.attn.W_K', 'blocks.2.attn.b_K', 'blocks.2.attn.W_Q', 'blocks.2.attn.b_Q', 'blocks.2.attn.W_V', 'blocks.2.attn.b_V', 'blocks.2.attn.W_O', 'blocks.2.attn.b_O', 'blocks.2.mlp.W_in', 'blocks.2.mlp.b_in', 'blocks.2.mlp.W_out', 'blocks.2.mlp.b_out'])


In [142]:
# Example usage
s = {1, 2, 3, 4, 5, 6}
tensor = permutations_with_replacement_to_tensor(s)
print(tensor)
print(tensor.shape)

tensor([[1, 3, 1, 5, 4, 5],
        [2, 3, 5, 6, 1, 2],
        [6, 6, 1, 4, 3, 6],
        ...,
        [3, 1, 3, 2, 3, 6],
        [3, 1, 6, 2, 4, 1],
        [5, 3, 6, 4, 5, 2]])
torch.Size([46656, 6])


In [143]:
ground_truth = groundtruth_tl_tracr_model(tl_model)
ground_truth

['attn_q_1', 'attn_k_1', 'attn_v_1', 'attn_q_2', 'attn_k_2', 'attn_v_2']

In [144]:
head_resid, head_labels = resid_cache_from_tl_tracr_model(tl_model, tensor[:250, :])
print(head_resid.shape)
print(head_labels)

torch.Size([250, 9, 38])
['attn_k_0', 'attn_q_0', 'attn_v_0', 'attn_k_1', 'attn_q_1', 'attn_v_1', 'attn_k_2', 'attn_q_2', 'attn_v_2']


In [145]:
# Corrupted head resid is going through ground-truth circuit
# If component is in ground-truth, set its resid to 0 everywhere, else add Gaussian noise
corrupted_head_resid = head_resid.clone()
for i, component in enumerate(head_labels):
    if component in ground_truth:
        print(f"Zeroing out {component}")
        corrupted_head_resid[:, i] = 0
        print(corrupted_head_resid[:, i])
    else:
        print(f"Corrupting {component}")
        corrupted_head_resid[:, i] += torch.randn_like(corrupted_head_resid[:, i])
        print(corrupted_head_resid[:, i])

print(corrupted_head_resid.shape)

Corrupting attn_k_0
tensor([[ 2.3146, -1.5225,  0.4979,  ...,  0.3677, -1.4383, -1.5451],
        [ 3.4963, -0.3185, -0.6330,  ..., -1.3146,  0.9055,  0.9480],
        [ 0.7630,  0.3394, -0.1548,  ..., -2.7482,  0.9594,  1.7201],
        ...,
        [-0.0616,  1.0126,  0.2954,  ...,  2.9095, -1.6609, -0.0842],
        [ 0.4575, -0.9681, -0.0613,  ...,  4.3568, -1.9714, -0.1920],
        [ 1.4183,  1.2655, -1.9858,  ..., -0.1248,  1.5747,  0.1844]],
       device='mps:0')
Corrupting attn_q_0
tensor([[-0.9063,  0.4985, -2.2621,  ...,  0.8974, -0.4606,  2.1218],
        [-1.4335, -0.4144,  1.5515,  ..., -0.9970,  2.8793, -1.8183],
        [-0.6547,  0.0983,  0.1909,  ...,  0.0919,  1.7877, -0.9679],
        ...,
        [-2.7284,  0.8806, -1.4835,  ..., -0.5953,  0.0109, -1.0188],
        [-0.0322,  0.5039,  0.1045,  ...,  0.1217,  1.5475, -0.3798],
        [-2.3106, -1.0567, -0.7819,  ..., -0.6577,  2.0268,  2.8858]],
       device='mps:0')
Corrupting attn_v_0
tensor([[ 0.1504,  0.1643,

In [146]:
# Stack them together and save
resid_streams = torch.cat([head_resid, corrupted_head_resid], dim=0)
path = "../data/tracr-sort/"
torch.save(resid_streams, path + "resid_heads_mean.pt")
# Save the head labels
torch.save(head_labels, path + "labels_heads_mean.pt")
# Save the ground truth
torch.save(ground_truth, path + "ground_truth.pt")

### `tracr-sortfreq`

In [149]:
from tracr.compiler.lib import make_sort
import tracr.compiler.lib as lib

vocab = {'a', 'b', 'c', 'd', 'e', 'f'}
max_seq_len = 6
program = lib.make_sort_freq(max_seq_len=max_seq_len)

model = compiling.compile_rasp_to_model(
    program,
    vocab=vocab,
    max_seq_len=max_seq_len,
    compiler_bos="BOS",
)

out = model.apply(["BOS", 'a', 'a', 'b', 'a', 'b', 'c'])
print("Original Decoding:", out.decoded)

Original Decoding: ['BOS', 'a', 'a', 'a', 'b', 'b', 'c']


In [150]:
tl_model = tracr_to_tl(model)

INPUT_ENCODER = model.input_encoder
OUTPUT_ENCODER = model.output_encoder

def create_model_input(input, input_encoder=INPUT_ENCODER):
    encoding = input_encoder.encode(input)
    return torch.tensor(encoding).unsqueeze(dim=0)

def decode_model_output(logits, output_encoder=OUTPUT_ENCODER, bos_token=INPUT_ENCODER.bos_token):
    max_output_indices = logits.squeeze(dim=0).argmax(dim=-1)
    decoded_output = output_encoder.decode(max_output_indices.tolist())
    decoded_output_with_bos = [bos_token] + decoded_output[1:]
    return decoded_output_with_bos

dict_keys(['pos_embed.W_pos', 'embed.W_E', 'unembed.W_U', 'blocks.0.attn.W_K', 'blocks.0.attn.b_K', 'blocks.0.attn.W_Q', 'blocks.0.attn.b_Q', 'blocks.0.attn.W_V', 'blocks.0.attn.b_V', 'blocks.0.attn.W_O', 'blocks.0.attn.b_O', 'blocks.0.mlp.W_in', 'blocks.0.mlp.b_in', 'blocks.0.mlp.W_out', 'blocks.0.mlp.b_out', 'blocks.1.attn.W_K', 'blocks.1.attn.b_K', 'blocks.1.attn.W_Q', 'blocks.1.attn.b_Q', 'blocks.1.attn.W_V', 'blocks.1.attn.b_V', 'blocks.1.attn.W_O', 'blocks.1.attn.b_O', 'blocks.1.mlp.W_in', 'blocks.1.mlp.b_in', 'blocks.1.mlp.W_out', 'blocks.1.mlp.b_out', 'blocks.2.attn.W_K', 'blocks.2.attn.b_K', 'blocks.2.attn.W_Q', 'blocks.2.attn.b_Q', 'blocks.2.attn.W_V', 'blocks.2.attn.b_V', 'blocks.2.attn.W_O', 'blocks.2.attn.b_O', 'blocks.2.mlp.W_in', 'blocks.2.mlp.b_in', 'blocks.2.mlp.W_out', 'blocks.2.mlp.b_out', 'blocks.3.attn.W_K', 'blocks.3.attn.b_K', 'blocks.3.attn.W_Q', 'blocks.3.attn.b_Q', 'blocks.3.attn.W_V', 'blocks.3.attn.b_V', 'blocks.3.attn.W_O', 'blocks.3.attn.b_O', 'blocks.3.ml

In [151]:
# Example usage
s = {'a', 'b', 'c', 'd', 'e', 'f'}
tensor = permutations_with_string_to_tensor(s)
print(tensor)
print(tensor.shape)

tensor([[5, 3, 1, 2, 4, 2],
        [0, 2, 5, 0, 4, 4],
        [1, 3, 2, 5, 5, 2],
        ...,
        [1, 0, 0, 4, 1, 1],
        [3, 5, 0, 4, 5, 0],
        [5, 2, 2, 5, 5, 0]])
torch.Size([46656, 6])


In [152]:
ground_truth = groundtruth_tl_tracr_model(tl_model)
ground_truth

['attn_q_0',
 'attn_k_0',
 'attn_v_0',
 'attn_q_3',
 'attn_k_3',
 'attn_v_3',
 'attn_q_4',
 'attn_k_4',
 'attn_v_4']

In [153]:
head_resid, head_labels = resid_cache_from_tl_tracr_model(tl_model, tensor[:250, :])
print(head_resid.shape)
print(head_labels)

torch.Size([250, 15, 44])
['attn_k_0', 'attn_q_0', 'attn_v_0', 'attn_k_1', 'attn_q_1', 'attn_v_1', 'attn_k_2', 'attn_q_2', 'attn_v_2', 'attn_k_3', 'attn_q_3', 'attn_v_3', 'attn_k_4', 'attn_q_4', 'attn_v_4']


In [154]:
# Corrupted head resid is going through ground-truth circuit
corrupted_head_resid = head_resid.clone() #torch.randn_like(head_resid)
# If component is in ground-truth, set its resid to 0 everywhere, else add Gaussian noise
corrupted_head_resid = head_resid.clone()
for i, component in enumerate(head_labels):
    if component in ground_truth:
        print(f"Zeroing out {component}")
        corrupted_head_resid[:, i] = 0
        print(corrupted_head_resid[:, i])
    else:
        print(f"Corrupting {component}")
        corrupted_head_resid[:, i] += torch.randn_like(corrupted_head_resid[:, i])
        print(corrupted_head_resid[:, i])

print(corrupted_head_resid.shape)

Zeroing out attn_k_0
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='mps:0')
Zeroing out attn_q_0
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='mps:0')
Zeroing out attn_v_0
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='mps:0')
Corrupting attn_k_1
tensor([[-0.1315, -0.2047,  1.4654,  ..., -0.7236, -1.1208,  1.7551],
        [ 0.6224, -0.9399,  0.9719,  

In [155]:
# Stack them together and save
resid_streams = torch.cat([head_resid, corrupted_head_resid], dim=0)
path = "../data/tracr-sortfreq/"
torch.save(resid_streams, path + "resid_heads_mean.pt")
# Save the head labels
torch.save(head_labels, path + "labels_heads_mean.pt")
# Save the ground truth
torch.save(ground_truth, path + "ground_truth.pt")