In [1]:
# some ipython magic to automatically reload any imports if they change
# (useful when iterating locally)

from IPython import get_ipython

# do this so that formatter not messed up
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

# Load Model

In [2]:
import transformer_lens
import tqdm
import wandb

# load the device we'll use (GPU or MPS)
device = transformer_lens.utils.get_device()

print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [3]:
# load our model
model_name = "gpt2-small"
model = transformer_lens.HookedTransformer.from_pretrained(
    model_name,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device,
)



Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
# sanity check with an example
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
transformer_lens.utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|
Top 1th token. Logit: 15.38 Prob:  4.67% Token: | the|
Top 2th token. Logit: 15.35 Prob:  4.54% Token: | John|
Top 3th token. Logit: 15.25 Prob:  4.11% Token: | them|
Top 4th token. Logit: 14.84 Prob:  2.73% Token: | his|
Top 5th token. Logit: 14.06 Prob:  1.24% Token: | her|
Top 6th token. Logit: 13.54 Prob:  0.74% Token: | a|
Top 7th token. Logit: 13.52 Prob:  0.73% Token: | their|
Top 8th token. Logit: 13.13 Prob:  0.49% Token: | Jesus|
Top 9th token. Logit: 12.97 Prob:  0.42% Token: | him|


# Define Transcoder Config

In [5]:
import dataclasses
import torch
from jaxtyping import Float

from min_transcoder.transcoder import (
    TranscoderResults,
    TranscoderConfig,
    Transcoder,
)


@dataclasses.dataclass
class TranscoderTrainingConfig:

    # Name of the layer to hook into for feature extraction
    hook_point: str
    out_hook_point: str

    num_epochs: int = 100

    # both from https://arxiv.org/html/2406.11944v1#S3 appendix E
    learning_rate: float = 2 * 10e-5
    l1_coefficient: float = 1e-4

    @property
    def hook_point_layer(self) -> int:
        "Parse out the hook point layer as int ex: 'blocks.8.ln2.hook_normalized' -> 8"
        return int(self.hook_point.split(".")[1])


@dataclasses.dataclass
class TranscoderLoss:
    mse_loss: Float[torch.Tensor, ""]
    l1_loss: Float[torch.Tensor, ""]


def compute_loss(
    cfg: TranscoderTrainingConfig,
    mlp_out: Float[torch.Tensor, "..."],
    results: TranscoderResults,
) -> TranscoderLoss:

    mse_loss_per_batch: Float[torch.Tensor, "..."] = (
        torch.pow((results.transcoder_out - mlp_out.float()), 2)
        / (mlp_out**2).sum(dim=-1, keepdim=True).sqrt()
    )

    mse_loss = mse_loss_per_batch.mean()

    sparsity = torch.abs(results.hidden_activations).sum(dim=1).mean(dim=(0,))

    # TODO(bschoen): Do we sum here?
    l1_loss = cfg.l1_coefficient * sparsity.mean()

    return TranscoderLoss(mse_loss=mse_loss, l1_loss=l1_loss)


# from https://arxiv.org/html/2406.11944v1#S3 appendix E
transcoder_expansion_factor = 32

transcoder_cfg = TranscoderConfig(
    d_in=model.cfg.d_model,
    d_out=model.cfg.d_model,
    # our transcoder has a hidden dimension of d_mlp * expansion factor
    d_hidden=model.cfg.d_mlp * transcoder_expansion_factor,
    dtype=model.cfg.dtype,
    device=device,
)

In [6]:
print(f"{model.cfg.n_layers=}")
print(f"{model.cfg.d_mlp=}")

model.cfg.n_layers=12
model.cfg.d_mlp=3072


# Load Data

In [7]:
import datasets
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


def create_tokenized_dataloader(
    max_length: int = 128,
    batch_size: int = 128,
    num_samples: int = 10000,
) -> DataLoader:

    print("Loading dataset...")
    dataset = datasets.load_dataset(
        path="NeelNanda/pile-10k",
        split="train",
        streaming=False,
    )

    print("Tokenizing dataset...")
    token_dataset = transformer_lens.utils.tokenize_and_concatenate(
        dataset=dataset,
        tokenizer=model.tokenizer,
        streaming=True,
        max_length=max_length,
        add_bos_token=model.cfg.default_prepend_bos,
    )

    # token_dataset['tokens'].shape=torch.Size([136625, 128])
    # print(f"{token_dataset['tokens'].shape=}")

    # shuffle, and arbitrarily cap at around 10,000 / 130,000 (original caps at ~24k)
    token_dataset = token_dataset.shuffle(42)
    token_dataset = token_dataset.take(num_samples)

    token_dataset_torch = torch.from_numpy(
        np.stack([x["tokens"] for x in token_dataset])
    ).cuda()

    # torch.Size([100, 1024])
    print(token_dataset_torch.shape)

    # Create a DataLoader for batching
    #
    # for batch in dataloader:
    #     print(batch.shape) # torch.Size([32, 1024])
    #     break
    #
    print(f"Creating dataloader for dataset...")
    dataloader = torch.utils.data.DataLoader(
        token_dataset_torch,
        batch_size=batch_size,
        shuffle=True,
    )

    print(f"Num batches: {token_dataset_torch.shape[0] / batch_size}")

    return dataloader

# Collect Activations

Here we'll create hooks to store the MLP activations only

In [8]:
# Define training parameters
training_cfg = TranscoderTrainingConfig(
    num_epochs=5,
    hook_point="blocks.8.ln2.hook_normalized",
    out_hook_point="blocks.8.hook_mlp_out",
)

In [9]:
# store the MLP activations
mlp_inputs: list[Float[torch.Tensor, "batch seq d_mlp_in"]] = []
mlp_outputs: list[Float[torch.Tensor, "batch seq d_mlp_out"]] = []


# TODO(bschoen): Could make this general
def store_mlp_inputs(
    mlp_input: Float[torch.Tensor, "... d_in"],
    hook: transformer_lens.hook_points.HookPoint,
) -> None:

    # Detach and move to CPU to save memory
    mlp_inputs.append(mlp_input.detach().cpu())


def store_mlp_output(
    mlp_output: Float[torch.Tensor, "... d_out"],
    hook: transformer_lens.hook_points.HookPoint,
) -> None:

    # Detach and move to CPU to save memory
    mlp_outputs.append(mlp_output.detach().cpu())

In [10]:
dataloader = create_tokenized_dataloader()

# put model itself into eval mode so doesn't change
model.eval()

for batch_index, batch in tqdm.tqdm(
    enumerate(dataloader),
    desc="Collecting MLP activations",
):
    # move batch to device
    batch = batch.to(device)

    # Get MLP input and output activations
    model.run_with_hooks(
        batch,
        fwd_hooks=[
            (training_cfg.hook_point, store_mlp_inputs),
            (training_cfg.out_hook_point, store_mlp_output),
        ],
        return_type=None,
    )

Loading dataset...


Tokenizing dataset...
torch.Size([10000, 128])
Creating dataloader for dataset...
Num batches: 78.125


Collecting MLP activations: 79it [00:22,  3.50it/s]


In [11]:
# now we can unload gpu
torch.cuda.empty_cache()

In [12]:
print(f"{len(mlp_inputs)=}, {mlp_inputs[0].shape=}")
print(f"{len(mlp_outputs)=}, {mlp_outputs[0].shape=}")

len(mlp_inputs)=79, mlp_inputs[0].shape=torch.Size([128, 128, 768])
len(mlp_outputs)=79, mlp_outputs[0].shape=torch.Size([128, 128, 768])


In [13]:
# Custom Dataset
class MLPActivationsDataset(Dataset):
    def __init__(
        self,
        mlp_inputs: list[Float[torch.Tensor, "batch seq d_mlp_in"]],
        mlp_outputs: list[Float[torch.Tensor, "batch seq d_mlp_out"]],
    ) -> None:
        self.mlp_inputs = mlp_inputs
        self.mlp_outputs = mlp_outputs
        assert len(self.mlp_inputs) == len(
            self.mlp_outputs
        ), "Inputs and outputs must be the same length."

    def __len__(self) -> int:
        return len(self.mlp_inputs)

    def __getitem__(self, idx: int) -> tuple[
        Float[torch.Tensor, "batch seq d_mlp_in"],
        Float[torch.Tensor, "batch seq d_mlp_out"],
    ]:
        x = self.mlp_inputs[idx]  # Shape: [128, 128, 768]
        y = self.mlp_outputs[idx]  # Shape: [128, 128, 768]
        return x, y


# Create Dataset and DataLoader
activations_dataset = MLPActivationsDataset(mlp_inputs, mlp_outputs)
activations_dataloader = DataLoader(
    activations_dataset,
    shuffle=True,
)

# Train Transcoder

In [14]:
# Initialize wandb
wandb.init(
    project="transcoder_training_v2",
    config=dataclasses.asdict(training_cfg),
)

transcoder = Transcoder(cfg=transcoder_cfg)

transcoder = transcoder.to(device)

# Initialize optimizer
optimizer = torch.optim.AdamW(transcoder.parameters(), lr=training_cfg.learning_rate)

num_steps = 0

# Training loop
for epoch in range(training_cfg.num_epochs):

    for batch_index, batch in tqdm.tqdm(
        enumerate(activations_dataloader),
        desc=f"Epoch {epoch+1}/{training_cfg.num_epochs}",
    ):

        # Do a training step.
        transcoder.train()

        # Make sure the W_dec is still zero-norm
        transcoder.set_decoder_norm_to_unit_norm()

        optimizer.zero_grad()

        # move batch to device
        batch_x, batch_y = batch

        mlp_in = batch_x[0].to(device)
        mlp_out = batch_y[0].to(device)

        transcoder_results = transcoder(mlp_in)

        # Compute loss
        loss_result = compute_loss(training_cfg, mlp_out, transcoder_results)

        loss = loss_result.mse_loss + loss_result.l1_loss

        # Backward pass and optimization
        loss.backward()

        optimizer.step()

        num_steps += 1

        # Print loss statistics every 10 batches
        if batch_index % 10 == 0:
            print(
                f"Epoch {epoch+1}/{training_cfg.num_epochs}, "
                f"Batch {batch_index}/{len(activations_dataloader)}, "
                f"Loss: {loss.item():.6f}, "
                f"MSE Loss: {loss_result.mse_loss.item():.6f}, "
                f"L1 Loss: {loss_result.l1_loss.item():.6f}"
            )

            # Log metrics to wandb
            wandb.log(
                {
                    "epoch": epoch + 1,
                    "loss": loss.item(),
                    "mse_loss": loss_result.mse_loss.item(),
                    "l1_loss": loss_result.l1_loss.item(),
                },
                step=num_steps,
            )

    # Log model parameters and gradients
    # wandb.watch(transcoder)

print("Training completed!")

# Finish the wandb run
wandb.finish()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbronsonschoen[0m ([33mbronsonschoen-personal-use[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/5: 1it [00:00,  1.06it/s]

Epoch 1/5, Batch 0/79, Loss: 0.074227, MSE Loss: 0.073587, L1 Loss: 0.000640


Epoch 1/5: 11it [00:10,  1.17s/it]

Epoch 1/5, Batch 10/79, Loss: 0.038926, MSE Loss: 0.038475, L1 Loss: 0.000451


Epoch 1/5: 21it [00:19,  1.17s/it]

Epoch 1/5, Batch 20/79, Loss: 0.028708, MSE Loss: 0.028345, L1 Loss: 0.000362


Epoch 1/5: 31it [00:28,  1.17s/it]

Epoch 1/5, Batch 30/79, Loss: 0.025376, MSE Loss: 0.025038, L1 Loss: 0.000338


Epoch 1/5: 41it [00:37,  1.17s/it]

Epoch 1/5, Batch 40/79, Loss: 0.023608, MSE Loss: 0.023279, L1 Loss: 0.000329


Epoch 1/5: 51it [00:46,  1.17s/it]

Epoch 1/5, Batch 50/79, Loss: 0.022781, MSE Loss: 0.022444, L1 Loss: 0.000337


Epoch 1/5: 61it [00:55,  1.17s/it]

Epoch 1/5, Batch 60/79, Loss: 0.021832, MSE Loss: 0.021480, L1 Loss: 0.000351


Epoch 1/5: 73it [01:04,  1.50it/s]

Epoch 1/5, Batch 70/79, Loss: 0.020207, MSE Loss: 0.019856, L1 Loss: 0.000351


Epoch 1/5: 79it [01:10,  1.12it/s]
Epoch 2/5: 1it [00:01,  1.81s/it]

Epoch 2/5, Batch 0/79, Loss: 0.019571, MSE Loss: 0.019208, L1 Loss: 0.000363


Epoch 2/5: 11it [00:10,  1.18s/it]

Epoch 2/5, Batch 10/79, Loss: 0.018902, MSE Loss: 0.018528, L1 Loss: 0.000374


Epoch 2/5: 21it [00:20,  1.17s/it]

Epoch 2/5, Batch 20/79, Loss: 0.017751, MSE Loss: 0.017377, L1 Loss: 0.000374


Epoch 2/5: 31it [00:29,  1.17s/it]

Epoch 2/5, Batch 30/79, Loss: 0.017271, MSE Loss: 0.016889, L1 Loss: 0.000382


Epoch 2/5: 41it [00:37,  1.01s/it]

Epoch 2/5, Batch 40/79, Loss: 0.017180, MSE Loss: 0.016787, L1 Loss: 0.000393


Epoch 2/5: 51it [00:46,  1.16s/it]

Epoch 2/5, Batch 50/79, Loss: 0.015852, MSE Loss: 0.015461, L1 Loss: 0.000390


Epoch 2/5: 61it [00:55,  1.17s/it]

Epoch 2/5, Batch 60/79, Loss: 0.015757, MSE Loss: 0.015359, L1 Loss: 0.000398


Epoch 2/5: 71it [01:04,  1.17s/it]

Epoch 2/5, Batch 70/79, Loss: 0.014554, MSE Loss: 0.014161, L1 Loss: 0.000393


Epoch 2/5: 79it [01:11,  1.11it/s]
Epoch 3/5: 1it [00:01,  1.81s/it]

Epoch 3/5, Batch 0/79, Loss: 0.015146, MSE Loss: 0.014730, L1 Loss: 0.000417


Epoch 3/5: 11it [00:10,  1.18s/it]

Epoch 3/5, Batch 10/79, Loss: 0.014783, MSE Loss: 0.014366, L1 Loss: 0.000417


Epoch 3/5: 21it [00:19,  1.13s/it]

Epoch 3/5, Batch 20/79, Loss: 0.014747, MSE Loss: 0.014327, L1 Loss: 0.000420


Epoch 3/5: 31it [00:28,  1.17s/it]

Epoch 3/5, Batch 30/79, Loss: 0.013826, MSE Loss: 0.013412, L1 Loss: 0.000414


Epoch 3/5: 41it [00:37,  1.17s/it]

Epoch 3/5, Batch 40/79, Loss: 0.013388, MSE Loss: 0.012972, L1 Loss: 0.000416


Epoch 3/5: 51it [00:46,  1.17s/it]

Epoch 3/5, Batch 50/79, Loss: 0.013521, MSE Loss: 0.013098, L1 Loss: 0.000424


Epoch 3/5: 61it [00:55,  1.17s/it]

Epoch 3/5, Batch 60/79, Loss: 0.013531, MSE Loss: 0.013103, L1 Loss: 0.000428


Epoch 3/5: 71it [01:04,  1.17s/it]

Epoch 3/5, Batch 70/79, Loss: 0.012677, MSE Loss: 0.012258, L1 Loss: 0.000420


Epoch 3/5: 79it [01:11,  1.11it/s]
Epoch 4/5: 1it [00:01,  1.81s/it]

Epoch 4/5, Batch 0/79, Loss: 0.012594, MSE Loss: 0.012166, L1 Loss: 0.000428


Epoch 4/5: 11it [00:10,  1.18s/it]

Epoch 4/5, Batch 10/79, Loss: 0.012698, MSE Loss: 0.012267, L1 Loss: 0.000431


Epoch 4/5: 21it [00:20,  1.17s/it]

Epoch 4/5, Batch 20/79, Loss: 0.012645, MSE Loss: 0.012209, L1 Loss: 0.000436


Epoch 4/5: 31it [00:29,  1.17s/it]

Epoch 4/5, Batch 30/79, Loss: 0.011967, MSE Loss: 0.011544, L1 Loss: 0.000424


Epoch 4/5: 41it [00:37,  1.13s/it]

Epoch 4/5, Batch 40/79, Loss: 0.011701, MSE Loss: 0.011281, L1 Loss: 0.000419


Epoch 4/5: 51it [00:46,  1.17s/it]

Epoch 4/5, Batch 50/79, Loss: 0.011682, MSE Loss: 0.011253, L1 Loss: 0.000429


Epoch 4/5: 61it [00:55,  1.17s/it]

Epoch 4/5, Batch 60/79, Loss: 0.011613, MSE Loss: 0.011191, L1 Loss: 0.000422


Epoch 4/5: 71it [01:04,  1.17s/it]

Epoch 4/5, Batch 70/79, Loss: 0.011133, MSE Loss: 0.010719, L1 Loss: 0.000414


Epoch 4/5: 79it [01:11,  1.11it/s]
Epoch 5/5: 1it [00:01,  1.81s/it]

Epoch 5/5, Batch 0/79, Loss: 0.011404, MSE Loss: 0.010965, L1 Loss: 0.000439


Epoch 5/5: 11it [00:10,  1.18s/it]

Epoch 5/5, Batch 10/79, Loss: 0.011410, MSE Loss: 0.010979, L1 Loss: 0.000432


Epoch 5/5: 21it [00:19,  1.09s/it]

Epoch 5/5, Batch 20/79, Loss: 0.010928, MSE Loss: 0.010501, L1 Loss: 0.000428


Epoch 5/5: 31it [00:28,  1.17s/it]

Epoch 5/5, Batch 30/79, Loss: 0.011062, MSE Loss: 0.010627, L1 Loss: 0.000435


Epoch 5/5: 41it [00:37,  1.17s/it]

Epoch 5/5, Batch 40/79, Loss: 0.010954, MSE Loss: 0.010519, L1 Loss: 0.000435


Epoch 5/5: 51it [00:46,  1.17s/it]

Epoch 5/5, Batch 50/79, Loss: 0.010440, MSE Loss: 0.010016, L1 Loss: 0.000424


Epoch 5/5: 61it [00:55,  1.17s/it]

Epoch 5/5, Batch 60/79, Loss: 0.010704, MSE Loss: 0.010278, L1 Loss: 0.000426


Epoch 5/5: 71it [01:04,  1.17s/it]

Epoch 5/5, Batch 70/79, Loss: 0.010563, MSE Loss: 0.010137, L1 Loss: 0.000427


Epoch 5/5: 79it [01:11,  1.11it/s]


Training completed!


0,1
epoch,▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆████████
l1_loss,█▄▂▁▁▁▁▁▂▂▂▂▂▂▃▂▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
loss,█▄▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mse_loss,█▄▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,5.0
l1_loss,0.00043
loss,0.01056
mse_loss,0.01014


In [15]:
# Save the trained transcoder model to a file
import torch

# Define the path where you want to save the model
model_save_path = "full_transcoder_model.pth"

print(f"Transcoder model saved to {model_save_path}")

torch.save(transcoder, model_save_path)

Transcoder model saved to full_transcoder_model.pth


In [16]:
# Define the path where you want to save the model
model_save_path = "full_transcoder_model.pth"

# Load the full transcoder model
loaded_transcoder = torch.load(model_save_path)

loaded_transcoder.to(device)

print("Loaded transcoder")

# Set the loaded model to evaluation mode
loaded_transcoder.eval()

print(loaded_transcoder)  # Print the loaded model architecture

# Optionally, you can verify the model's parameters
for name, param in loaded_transcoder.named_parameters():
    print(f"Parameter: {name}, Shape: {param.shape}")

  loaded_transcoder = torch.load(model_save_path)


Loaded transcoder
Transcoder()
Parameter: W_enc, Shape: torch.Size([768, 98304])
Parameter: b_enc, Shape: torch.Size([98304])
Parameter: W_dec, Shape: torch.Size([98304, 768])
Parameter: b_dec, Shape: torch.Size([768])
Parameter: b_dec_out, Shape: torch.Size([768])


# Compute Loss When Substituting MLP with Transcoder

In [17]:
class _TranscoderWrapper(torch.nn.Module):
    def __init__(self, transcoder: Transcoder):
        super().__init__()
        self.transcoder = transcoder

    def forward(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_out"]:
        transcoder_result = self.transcoder(x)
        return transcoder_result.transcoder_out


@torch.no_grad()
def get_test_loss_when_replacing_mlp_with_transcoder(
    batch_tokens: Float[torch.Tensor, "batch seq"],
    transcoder: Transcoder,
    model: transformer_lens.HookedTransformer,
    hook_point_layer: str,
) -> Float[torch.Tensor, ""]:
    """
    A method for running the model with the SAE activations in order to return the
    loss returns per token loss when activations are substituted in.

    """
    old_mlp = model.blocks[hook_point_layer]

    model.blocks[hook_point_layer].mlp = _TranscoderWrapper(transcoder)

    ce_loss_with_recons = model.run_with_hooks(batch_tokens, return_type="loss")

    model.blocks[hook_point_layer] = old_mlp

    model.reset_hooks()

    return ce_loss_with_recons

In [18]:
# compute how much worse this makes the loss
#
# note: normally compare to ablated
#
transcoder = loaded_transcoder

transcoder.eval()

num_batches = 10

dataloader = create_tokenized_dataloader(num_samples=num_batches)

avg_loss_original = 0
avg_loss_when_replaced_mlp = 0

for batch_index, batch in enumerate(dataloader):

    if batch_index > num_batches:
        break

    batch = batch.to(device)

    loss_original = model.run_with_hooks(batch, return_type="loss")

    loss_when_replaced_mlp = get_test_loss_when_replacing_mlp_with_transcoder(
        batch_tokens=batch,
        transcoder=transcoder,
        model=model,
        hook_point_layer=training_cfg.hook_point_layer,
    )

    avg_loss_original += loss_original.item()
    avg_loss_when_replaced_mlp += loss_when_replaced_mlp.item()

avg_loss_original /= num_batches
avg_loss_when_replaced_mlp /= num_batches

print(f"{avg_loss_original=}")
print(f"{avg_loss_when_replaced_mlp=}")

Loading dataset...
Tokenizing dataset...
torch.Size([10, 128])
Creating dataloader for dataset...
Num batches: 0.078125
avg_loss_original=0.3654099225997925
avg_loss_when_replaced_mlp=0.3675143003463745


# Sanity Check - Indirect Object Identification

We quickly check that IOI isn't impacted (it shouldn't be, since we know it
doesn't depend much on MLP, but it's good to check against a known result).

In [19]:
import transformer_lens

# sanity check with an example
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
transformer_lens.utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.27 Prob: 74.62% Token: | Mary|
Top 1th token. Logit: 15.84 Prob:  6.59% Token: | John|
Top 2th token. Logit: 15.28 Prob:  3.73% Token: | the|
Top 3th token. Logit: 14.84 Prob:  2.41% Token: | his|
Top 4th token. Logit: 14.52 Prob:  1.75% Token: | them|
Top 5th token. Logit: 13.46 Prob:  0.61% Token: | her|
Top 6th token. Logit: 13.22 Prob:  0.48% Token: | a|
Top 7th token. Logit: 13.21 Prob:  0.47% Token: | Jesus|
Top 8th token. Logit: 13.04 Prob:  0.40% Token: | him|
Top 9th token. Logit: 12.96 Prob:  0.37% Token: | their|


In [20]:
old_mlp = model.blocks[training_cfg.hook_point_layer]

model.blocks[training_cfg.hook_point_layer].mlp = _TranscoderWrapper(transcoder)

transformer_lens.utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

model.blocks[training_cfg.hook_point_layer] = old_mlp

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.27 Prob: 74.62% Token: | Mary|
Top 1th token. Logit: 15.84 Prob:  6.59% Token: | John|
Top 2th token. Logit: 15.28 Prob:  3.73% Token: | the|
Top 3th token. Logit: 14.84 Prob:  2.41% Token: | his|
Top 4th token. Logit: 14.52 Prob:  1.75% Token: | them|
Top 5th token. Logit: 13.46 Prob:  0.61% Token: | her|
Top 6th token. Logit: 13.22 Prob:  0.48% Token: | a|
Top 7th token. Logit: 13.21 Prob:  0.47% Token: | Jesus|
Top 8th token. Logit: 13.04 Prob:  0.40% Token: | him|
Top 9th token. Logit: 12.96 Prob:  0.37% Token: | their|


# Differences In Generated Text

In [21]:
prompt = "The speech is about"

generated_text = model.generate(
    prompt,
    max_new_tokens=100,
    temperature=0,
    stop_at_eos=True,
)

print(generated_text)

100%|██████████| 100/100 [00:02<00:00, 43.94it/s]

The speech is about the future of the United States.

The speech is about the future of the United States.

The speech is about the future of the United States.

The speech is about the future of the United States.

The speech is about the future of the United States.

The speech is about the future of the United States.

The speech is about the future of the United States.

The speech is about the future of the United States.







In [22]:
prompt = "The speech is about"

old_mlp = model.blocks[training_cfg.hook_point_layer]

model.blocks[training_cfg.hook_point_layer].mlp = _TranscoderWrapper(transcoder)

generated_text = model.generate(
    prompt,
    max_new_tokens=100,
    temperature=0,
    stop_at_eos=True,
)

model.blocks[training_cfg.hook_point_layer] = old_mlp

print(generated_text)

100%|██████████| 100/100 [00:01<00:00, 53.33it/s]

The speech is about the future of the United States.

The speech is about the future of the United States.

The speech is about the future of the United States.

The speech is about the future of the United States.

The speech is about the future of the United States.

The speech is about the future of the United States.

The speech is about the future of the United States.

The speech is about the future of the United States.





