In [1]:
import sys
sys.path.append('..')
sys.path.append('../src')
sys.path.append('../data')

In [2]:
import torch
from transformer_lens import HookedTransformer
import json
from src.model import WrapHookedTransformer
import src.nanda_plot
from src.nanda_plot import imshow_reversed, imshow

import transformer_lens.utils as utils
from transformer_lens.utils import get_act_name
from functools import partial
from transformer_lens import patching
import plotly.express as px


%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


## Generate dataset

In [3]:

class Dataset:
    def __init__(self, target_dataset, orthogonal_dataset, model:WrapHookedTransformer):
        self.target_dataset = target_dataset
        self.orthogonal_dataset = orthogonal_dataset
        self.target_dataset_per_length, self.orthogonal_dataset_per_length = self.split_for_lenght()

    def random_sample(self, n):
        possible_lengths = []
        for length in self.target_dataset_per_length.keys():
            if len(self.target_dataset_per_length[length]) >= n and len(self.orthogonal_dataset_per_length[length]) > 0:
                possible_lengths.append(length)
                
        length = random.choice(possible_lengths)
        self.dataset_per_length = {length: random.sample(self.target_dataset_per_length[length], n) + random.sample(self.orthogonal_dataset_per_length[length], n)}

    def split_for_lenght(self):
        target_dataset_per_length = {}
        for d in self.target_dataset:
            length = d["length"]
            if length not in target_dataset_per_length:
                target_dataset_per_length[length] = []
            target_dataset_per_length[length].append(d)
            
        orthogonal_dataset_per_length = {}
        for d in self.orthogonal_dataset:
            length = d["length"]
            if length not in orthogonal_dataset_per_length:
                orthogonal_dataset_per_length[length] = []
            orthogonal_dataset_per_length[length].append(d)
        return target_dataset_per_length, orthogonal_dataset_per_length
    
    def logits(self, model:WrapHookedTransformer):
        logits_per_length = {}
        for length, dataset in self.dataset_per_length.items():
            input_ids = model.to_tokens([d["premise"] for d in dataset])
            logits_per_length[length] = model(input_ids)
        return logits_per_length
  
    def get_tensor_token(self,model):
        tensor_token_per_length = {}
        for length, dataset in self.dataset_per_length.items():
            if length not in tensor_token_per_length:
                tensor_token_per_length[length] = {}
            tensor_token_per_length[length]["target"] = model.to_tokens([d["target"] for d in dataset], prepend_bos=False)
            tensor_token_per_length[length]["orthogonal_token"] = model.to_tokens([d["orthogonal_token"] for d in dataset], prepend_bos=False)
        
        for length, tensor in tensor_token_per_length.items():
            tensor_token_per_length[length]["target"] = tensor_token_per_length[length]["target"].squeeze(1)
            tensor_token_per_length[length]["orthogonal_token"] = tensor_token_per_length[length]["orthogonal_token"].squeeze(1)
        return tensor_token_per_length
    
MODEL_NAME = "gpt2small"
    
model = WrapHookedTransformer.from_pretrained("gpt2", device="cpu")
import random
target_data = json.load(open("../data/target_win_dataset_{}.json".format(MODEL_NAME)))
#suffhle 
orthogonal_data = json.load(open("../data/orthogonal_win_dataset_{}.json".format(MODEL_NAME)))
orthogonal_data = random.sample(orthogonal_data, len(target_data))
dataset = Dataset(target_data,orthogonal_data, model)
    

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer


In [11]:
dataset.random_sample(10)
logits_per_length = dataset.logits(model)
tokens_dict_per_length = dataset.get_tensor_token(model)

delta_per_length = {}
for max_len in list(tokens_dict_per_length.keys()):
    if max_len not in delta_per_length:
        delta_per_length[max_len] = {}
    # probs = torch.softmax(logits_per_length[max_len], dim=-1)
    probs = logits_per_length[max_len]
    batch_index = torch.arange(probs.shape[0])
    delta_per_length[max_len] = probs[batch_index,-1, tokens_dict_per_length[max_len]["target"]] - probs[batch_index,-1, tokens_dict_per_length[max_len]["orthogonal_token"]]
delta_per_length

{14.0: tensor([  1.6992,   2.5517,   2.8656,   3.9811,   3.9914,   1.1745,   5.7033,
           0.0552,   0.9692,   8.9047,  -6.7607,  -4.9450, -19.7999, -13.7948,
         -14.2011, -11.2658, -14.8811, -14.1361,  -7.0904, -12.8447])}

In [28]:
for item in dataset.dataset_per_length[14]:
    print(model.to_str_tokens(item["premise"]))

['<|endoftext|>', 'Su', 'z', 'uki', ' Splash', ',', ' developed', ' by', ' discovered', ' Suzuki', ' Splash', ',', ' developed', ' by']
['<|endoftext|>', 'Mos', ' Def', ' follows', ' the', ' religion', ' of', 'do', ' Mos', ' Def', ' follows', ' the', ' religion', ' of']
['<|endoftext|>', 'Windows', ' 2000', ' was', ' a', ' product', ' of', 'pace', ' Windows', ' 2000', ' was', ' a', ' product', ' of']
['<|endoftext|>', 'Virtual', ' Boy', ' is', ' a', ' product', ' of', ' pine', ' Virtual', ' Boy', ' is', ' a', ' product', ' of']
['<|endoftext|>', 'Apple', ' Thunderbolt', ' Display', ' was', ' created', ' by', 'omb', ' Apple', ' Thunderbolt', ' Display', ' was', ' created', ' by']
['<|endoftext|>', 'CBS', ' News', ' is', ' to', ' debut', ' on', ' Toad', ' CBS', ' News', ' is', ' to', ' debut', ' on']
['<|endoftext|>', 'Apple', ' Wireless', ' Mouse', ' was', ' developed', ' by', 'entin', ' Apple', ' Wireless', ' Mouse', ' was', ' developed', ' by']
['<|endoftext|>', 'Apple', ' Wireless', 

# Study just a single length

In [13]:
max_len = 14
def embs_to_tokens_ids(noisy_embs, model:WrapHookedTransformer):

    input_embedding_norm = torch.functional.F.normalize(noisy_embs, p=2, dim=2)
    embedding_matrix_norm = torch.functional.F.normalize(model.W_E, p=2, dim=1)
    similarity = torch.matmul(input_embedding_norm, embedding_matrix_norm.T)
    corrupted_tokens = torch.argmax(similarity, dim=2)
    return corrupted_tokens

In [130]:
from src.patching import get_act_patch_block_every, get_act_patch_resid_pre, get_act_patch_attn_head_out_all_pos


def patch_resid_pre(model, input_ids, input_ids_corrupted, clean_cache, metric):
    resid_pre_act_patch_results = get_act_patch_resid_pre(
        model, input_ids_corrupted, clean_cache, metric
    )
    return resid_pre_act_patch_results


def patch_attn_head_out_all_pos(
    model, input_ids, input_ids_corrupted, clean_cache, metric, input_embeddings
):
    attn_head_out_all_pos_act_patch_results = (
        get_act_patch_attn_head_out_all_pos(
            model, input_ids_corrupted, clean_cache, metric, corrupted_embeddings=input_embeddings
        )
    )
    return attn_head_out_all_pos_act_patch_results


def patch_attn_head_by_pos(model, input_ids, input_ids_corrupted, clean_cache, metric):
    ALL_HEAD_LABELS = [
        f"L{i}H{j}" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)
    ]
    import einops

    attn_head_out_act_patch_results = patching.get_act_patch_attn_head_out_by_pos(
        model, input_ids_corrupted, clean_cache, metric
    )
    attn_head_out_act_patch_results = einops.rearrange(
        attn_head_out_act_patch_results, "layer pos head -> (layer head) pos"
    )
    return attn_head_out_act_patch_results


def patch_per_block_all_poss(
    model, input_ids_corrupted, input_ids, clean_cache, metric, interval, corrupted_embeddings
):
    every_block_result = get_act_patch_block_every(
        model, input_ids_corrupted, clean_cache, metric, patch_interval=interval, corrupted_embeddings=corrupted_embeddings
    )
    return every_block_result


def patch_attn_head_all_pos_every(
    model, input_ids_corrupted, input_ids, clean_cache, metric
):
    every_head_all_pos_act_patch_result = (
        patching.get_act_patch_attn_head_all_pos_every(
            model, input_ids_corrupted, clean_cache, metric
        )
    )
    return every_head_all_pos_act_patch_result

In [149]:
delta = delta_per_length[max_len]
data = dataset.dataset_per_length[max_len]
logits = logits_per_length[max_len]
input_ids = model.to_tokens([d["premise"] for d in data])
embs_corrupted = model.add_noise([d["premise"] for d in data], noise_index = torch.tensor([1,2,3,6]), target_win=7)
target_ids = tokens_dict_per_length[max_len]["target"]
orthogonal_ids = tokens_dict_per_length[max_len]["orthogonal_token"]
input_ids_corrupted = embs_to_tokens_ids(embs_corrupted, model)
# print(model.to_str_tokens(input_ids_corrupted))
delta_corrupted = model(input_ids_corrupted)[torch.arange(input_ids_corrupted.shape[0]),-1, tokens_dict_per_length[max_len]["target"]] - model(input_ids_corrupted)[torch.arange(input_ids_corrupted.shape[0]),-1, tokens_dict_per_length[max_len]["orthogonal_token"]]


## Positive examples

In [None]:
corrupted_i

In [150]:
#select just the positive deltas
positive_delta = delta[delta > 0]
emb_corrupted = embs_corrupted[delta > 0]
input_ids = input_ids[delta > 0]
target_ids = target_ids[delta > 0]
orthogonal_ids = orthogonal_ids[delta > 0]

corrupted_logit, corrupted_cache = model.run_with_cache_from_embed(embs_corrupted)
clean_logit, clean_cache = model.run_with_cache(input_ids)
corrupted_logit = torch.softmax(corrupted_logit, dim=-1)

def delta_target(logits):
    logits = torch.softmax(logits, dim=-1)
    # print(logits[:, -1, target_ids], corrupted_logit[:, -1, target_ids])
    batch_index = torch.arange(logits.shape[0])
    target_delta = logits[batch_index, -1, target_ids] - corrupted_logit[batch_index, -1, target_ids]
    return target_delta.mean()

CLEAN_BASELINE_TARGET = delta_target(clean_logit)

def delta_orthogonal(logits):
    logits = torch.softmax(logits, dim=-1)
    batch_index = torch.arange(logits.shape[0])
    orthogonal_delta = logits[batch_index, -1, orthogonal_ids]  - corrupted_logit[batch_index, -1, orthogonal_ids]
    return orthogonal_delta.mean()

CLEAN_BASELINE_ORTHOGONAL = delta_orthogonal(clean_logit)


def metric(logits, return_type="both"):
    if return_type == "both":
        return (delta_target(logits)/CLEAN_BASELINE_TARGET - delta_orthogonal(logits)/CLEAN_BASELINE_ORTHOGONAL)
    if return_type == "target":
        return delta_target(logits) /CLEAN_BASELINE_TARGET
    if return_type == "orthogonal":
        return delta_orthogonal(logits) /CLEAN_BASELINE_ORTHOGONAL

print(CLEAN_BASELINE_TARGET)
print(CLEAN_BASELINE_ORTHOGONAL)


tensor(0.0385)
tensor(-0.1478)


### Logit Lens

In [133]:
def logit_lens(target_ids, clean_cache, input_ids, model):
    target_residual_direction = model.tokens_to_residual_directions(target_ids["target"])
    strange_residual_direction = model.tokens_to_residual_directions(target_ids["orthogonal_token"])
    batch_index = torch.arange(target_residual_direction.shape[0])
    pos_delta_resid_direction = target_residual_direction[batch_index] - strange_residual_direction[batch_index]


    def residual_stack_to_logit_diff(residual_stack, cache):
        scaled_residual_stack = clean_cache.apply_ln_to_stack(residual_stack, layer=-1)
        return einops.einsum( scaled_residual_stack, pos_delta_resid_direction, "n_comp batch pos d_model, batch d_model -> n_comp pos")/pos_input_ids.shape[0]


    accumulater_residual, labels = clean_cache.accumulated_resid(layer=-1, incl_mid=True, return_labels=True, pos_slice=-1)
    accumulater_residual = clean_cache.apply_ln_to_stack(accumulater_residual, layer=-1, pos_slice=-1)
    unembed_accumulated_residual = einops.einsum(accumulater_residual, model.W_U, "n_comp batch d_model, d_model vocab -> n_comp batch vocab")
    batch_index = torch.arange(input_ids.shape[0])
    delta_accumulated_residual = (unembed_accumulated_residual[:, batch_index, target_ids["target"]] - unembed_accumulated_residual[:, batch_index, target_ids["orthogonal_token"]]).mean(dim=1)


    x=np.arange(model.cfg.n_layers*2+1)/2
    plt.plot(x, delta_accumulated_residual.detach().numpy(), color="blue")
    plt.plot(x, unembed_accumulated_residual[:,batch_index, target_ids["target"]].mean(-1).detach().numpy(), color="red")
    plt.plot(x, unembed_accumulated_residual[:,batch_index, target_ids["orthogonal_token"]].mean(-1).detach().numpy(), color="green")
    plt.legend(["delta", "target", "strange"])
    plt.show()


In [134]:
import einops
import numpy as np

target_residual_direction = model.tokens_to_residual_directions(target_ids)
strange_residual_direction = model.tokens_to_residual_directions(orthogonal_ids)
batch_index = torch.arange(target_residual_direction.shape[0])
delta_resid_direction = (
    target_residual_direction[batch_index] - strange_residual_direction[batch_index]
)


def residual_stack_to_logit_diff(residual_stack, cache):
    scaled_residual_stack = clean_cache.apply_ln_to_stack(residual_stack, layer=-1)
    return (
        einops.einsum(
            scaled_residual_stack,
            delta_resid_direction,
            "n_comp batch pos d_model, batch d_model -> n_comp pos",
        )
        / input_ids.shape[0]
    )


accumulater_residual, labels = clean_cache.accumulated_resid(
    layer=-1, incl_mid=True, return_labels=True, pos_slice=-1
)
accumulater_residual = clean_cache.apply_ln_to_stack(
    accumulater_residual, layer=-1, pos_slice=-1
)
unembed_accumulated_residual = einops.einsum(
    accumulater_residual,
    model.W_U,
    "n_comp batch d_model, d_model vocab -> n_comp batch vocab",
)
batch_index = torch.arange(input_ids.shape[0])
delta_accumulated_residual = (
    unembed_accumulated_residual[:, batch_index, target_ids]
    - unembed_accumulated_residual[:, batch_index, orthogonal_ids]
).mean(dim=1)


# x = [f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(pos_input_ids[0]))]
# plt.plot(delta_accumulated_residual.detach().numpy(), color="blue")
# plt.plot(unembed_accumulated_residual[:,batch_index,pos_target_ids["target"]].mean(-1).detach().numpy(), color="red")
# plt.plot(unembed_accumulated_residual[:,batch_index,pos_target_ids["orthogonal_token"]].mean(-1).detach().numpy(), color="green")
# plt.legend(["delta", "target", "strange"])
# plt.show()

import pandas as pd

# Assuming you have the following data prepared
x = np.arange(model.cfg.n_layers * 2 + 1) / 2
y1 = delta_accumulated_residual.detach().numpy()
y2 = unembed_accumulated_residual[:, batch_index, target_ids].mean(-1).detach().numpy()
y3 = (
    unembed_accumulated_residual[:, batch_index, orthogonal_ids]
    .mean(-1)
    .detach()
    .numpy()
)

# Convert data to a DataFrame
df = pd.DataFrame({"x": x, "delta": y1, "target": y2, "strange": y3})

# Plot using plotly.express
fig = px.line(
    df,
    x="x",
    y=["delta", "target", "strange"],
    labels={"value": "Y-axis", "variable": "Legend"},
    title="Your Plot Title",
)

fig.show()

In [135]:
per_head_residual, labels = clean_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)

per_head_logit_diffs = clean_cache.apply_ln_to_stack(per_head_residual, layer=-1, pos_slice=-1)

per_head_logit_diffs = einops.einsum(per_head_logit_diffs, model.W_U, "n_comp batch d_model, d_model vocab -> n_comp batch vocab")

batch_index = torch.arange(pos_input_ids.shape[0])
delta_per_head_residual = (per_head_logit_diffs[:, batch_index, pos_target_ids["target"]] - per_head_logit_diffs[:, batch_index, pos_target_ids["orthogonal_token"]]).mean(dim=1)
per_head_logit_diffs = einops.rearrange(delta_per_head_residual, "(layer head_index) ->  layer head_index ", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)

imshow_reversed(per_head_logit_diffs, labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head")

NameError: name 'pos_input_ids' is not defined

### Activation patching

In [136]:
pos_resid_pre_act_patch_results = patch_resid_pre(model, input_ids, input_ids_corrupted, clean_cache, metric)
imshow_reversed(pos_resid_pre_act_patch_results, 
    yaxis="Layer", 
    xaxis="Position", 
    x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(input_ids[0]))],
    title="resid_pre Activation Patching")

 19%|█▉        | 32/168 [00:09<00:41,  3.28it/s]


KeyboardInterrupt: 

In [None]:
pos_attn_head_out_all_pos_act_patch_results_target = patch_attn_head_out_all_pos(
    model,
    input_ids,
    input_ids_corrupted,
    clean_cache,
    partial(metric, return_type="target"),
    input_embeddings=embs_corrupted,
)
imshow_reversed(
    pos_attn_head_out_all_pos_act_patch_results_target,
    yaxis="Layer",
    xaxis="Head",
    title="attn_head_out Activation Patching (All Pos)",
)

pos_attn_head_out_allpos_act_patch_results_orthogonal = patch_attn_head_out_all_pos(
    model,
    input_ids,
    input_ids_corrupted,
    clean_cache,
    partial(metric, return_type="orthogonal"),
    input_embeddings=embs_corrupted,
)
imshow(
    pos_attn_head_out_allpos_act_patch_results_orthogonal,
    yaxis="Layer",
    xaxis="Head",
    title="attn_head_out Activation Patching (All Pos)",
)

  0%|          | 0/144 [00:00<?, ?it/s]

100%|██████████| 144/144 [00:41<00:00,  3.50it/s]


100%|██████████| 144/144 [00:42<00:00,  3.35it/s]


In [None]:
pos_attn_head_out_act_patch_results = patch_attn_head_by_pos(model, pos_input_ids, pos_input_ids_corrupted, clean_cache, metric)    
ALL_HEAD_LABELS = [f"L{i}H{j}" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)]
imshow_reversed(pos_attn_head_out_act_patch_results, 
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(input_ids[0]))],
        y=ALL_HEAD_LABELS,
        title="attn_head_out Activation Patching By Pos")

In [154]:
pos_every_block_result_target = patch_per_block_all_poss(
    model,
    input_ids_corrupted[delta > 0],
    input_ids,
    clean_cache,
    partial(metric, return_type="target"),
    interval=1,
    corrupted_embeddings=embs_corrupted,
)
imshow_reversed(
    pos_every_block_result_target,
    facet_col=0,
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
    title="Activation Patching Per Block",
    xaxis="Position",
    yaxis="Layer",
    zmax=1,
    zmin=-1,
    x=[f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(input_ids[2]))],
)

  0%|          | 0/168 [00:00<?, ?it/s]


RuntimeError: The expanded size of the tensor (10) must match the existing size (20) at non-singleton dimension 0.  Target sizes: [10, 14, 768].  Tensor sizes: [20, 14, 768]

In [153]:
pos_every_block_result_orthogonal = patch_per_block_all_poss(
    model,
    input_ids_corrupted[:10],
    input_ids,
    clean_cache,
    partial(metric, return_type="orthogonal"),
    interval=1,
    corrupted_embeddings=embs_corrupted,
)
imshow(
    pos_every_block_result_orthogonal,
    facet_col=0,
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
    title="Activation Patching Per Block",
    xaxis="Position",
    yaxis="Layer",
    zmax=1,
    zmin=-1,
    x=[f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(input_ids[2]))],
)

  0%|          | 0/168 [00:00<?, ?it/s]


RuntimeError: The expanded size of the tensor (10) must match the existing size (20) at non-singleton dimension 0.  Target sizes: [10, 14, 768].  Tensor sizes: [20, 14, 768]

In [109]:
pos_every_block_result_orthogonal.shape

torch.Size([3, 12, 14])

In [75]:
# compute the avg restored target (mean over the last two dimensions)
avg_restored_target = pos_every_block_result_target.mean(dim=(-1, -2))
std_restored_target = pos_every_block_result_target.std(dim=(-1, -2))
print(avg_restored_target, std_restored_target, avg_restored_target + std_restored_target)
avg_restored_orthogonal = pos_every_block_result_orthogonal.mean(dim=(-1, -2))
std_restored_orthogonal = pos_every_block_result_orthogonal.std(dim=(-1, -2))
print(avg_restored_orthogonal, std_restored_orthogonal, avg_restored_orthogonal + std_restored_orthogonal)
# check what component are both important for the two mechanism
print("Component, Layer, Position")
for component in range(3):
    target_threshold = avg_restored_target[component] + std_restored_target[component]
    orthogonal_threshold = avg_restored_orthogonal[component] + std_restored_orthogonal[component]
    for layer in range(12):
        for pos in range(14):
            if pos_every_block_result_orthogonal[component, layer, pos] > orthogonal_threshold and pos_every_block_result_target[component, layer, pos] > target_threshold:
                print(component,layer, pos)

tensor([0.1093, 0.0143, 0.0228]) tensor([0.2524, 0.0883, 0.1115]) tensor([0.3618, 0.1026, 0.1343])
tensor([ 0.0158,  0.0403, -0.0137]) tensor([0.5027, 0.2020, 0.3611]) tensor([0.5185, 0.2422, 0.3474])


In [76]:
# check what component are both important for the two mechanism
print("Component, Layer, Position")
for component in range(3):
    target_threshold = avg_restored_target[component] + std_restored_target[component]
    orthogonal_threshold = avg_restored_orthogonal[component] + std_restored_orthogonal[component]
    for layer in range(12):
        for pos in range(14):
            if pos_every_block_result_orthogonal[component, layer, pos] > orthogonal_threshold and pos_every_block_result_target[component, layer, pos] > target_threshold:
                print(component,layer, pos)

Component, Layer, Position
0 0 8
0 1 8
0 2 8
0 3 8
0 9 13
0 10 13
0 11 13
1 0 8
1 2 8
1 6 13
2 0 8
2 2 10
2 4 4
2 10 13


In [113]:
pos_attn_head_out_all_pos_act_patch_results_target.shape

torch.Size([12, 12])

In [81]:
avg_restored_target = pos_attn_head_out_all_pos_act_patch_results_target.mean(dim=(-1,-2))
avg_restored_orthogonal = pos_attn_head_out_allpos_act_patch_results_orthogonal.mean(dim=(-1,-2))
std_restored_target = pos_attn_head_out_all_pos_act_patch_results_target.std(dim=(-1,-2))
std_restored_orthogonal = pos_attn_head_out_allpos_act_patch_results_orthogonal.std(dim=(-1,-2))
print(avg_restored_target, std_restored_target, avg_restored_target + std_restored_target)
print(avg_restored_orthogonal, std_restored_orthogonal, avg_restored_orthogonal + std_restored_orthogonal)


LAYER = 2
for head in range(12):
    target_threshold = avg_restored_target + std_restored_target
    orthogonal_threshold = avg_restored_orthogonal + std_restored_orthogonal
    if pos_attn_head_out_all_pos_act_patch_results_target[LAYER, head] > target_threshold and pos_attn_head_out_allpos_act_patch_results_orthogonal[LAYER, head] > orthogonal_threshold:
        print(head)

tensor(-0.0047) tensor(0.0816) tensor(0.0769)
tensor(0.0157) tensor(0.0437) tensor(0.0594)
2


In [92]:
import circuitsvis as cv

cv.attention.attention_patterns(tokens=model.to_str_tokens(input_ids[3]), attention=clean_cache["pattern", 6, "attn"][3])