# Setup

In [1]:
from transformer_lens.cautils.notebook import *
# from transformer_lens import FactoredMatrix

from transformer_lens.rs.callum2.utils import (
    get_effective_embedding,
    create_title_and_subtitles,
    process_webtext,
    parse_str,
)

clear_output()

In [2]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
    # refactor_factored_attn_matrices=True,
)
model.set_use_attn_result(False)

clear_output()

In [3]:
effective_embedding_dict = get_effective_embedding(model)

W_EE = effective_embedding_dict["W_E (including MLPs)"]
W_EE0 = effective_embedding_dict["W_E (only MLPs)"]
W_U = model.W_U

# Section 3.1 - OV circuit

## Static

In [24]:
W_V = model.W_V[10, 7]
W_O = model.W_O[10, 7]

full_OV_circuit = FactoredMatrix(W_EE @ W_V, W_O @ W_U)

diag_elem_negative_ranks = []

for i in tqdm(range(model.cfg.d_vocab)):
    col = full_OV_circuit.A[i, :] @ full_OV_circuit.B
    diag_elem = col[i]
    diag_elem_negative_rank = (diag_elem > col).sum().item() # This is zero when diag elem is minimal
    diag_elem_negative_ranks.append(diag_elem_negative_rank)

diag_elem_negative_ranks = t.tensor(diag_elem_negative_ranks)

print(f"Mean rank = {diag_elem_negative_ranks.float().mean().item():.2f}")
print(f"Median rank = {diag_elem_negative_ranks.float().median().item():.0f}\n")

print(f"Proportion with rank zero = {(diag_elem_negative_ranks == 0).float().mean():.2%}")
print(f"Proportion with rank less than 10 = {(diag_elem_negative_ranks < 10).float().mean():.2%}\n")

print(f"Quantity with rank more than 10% = {(diag_elem_negative_ranks > 0.1 * model.cfg.d_vocab).sum()}")
print(f"Quantity with rank more than 5% = {(diag_elem_negative_ranks > 0.05 * model.cfg.d_vocab).sum()}")

# Histogram, with all values above 100 cropped so the pattern is more visible
hist(
    diag_elem_negative_ranks[diag_elem_negative_ranks < 100], 
    title=create_title_and_subtitles("Dynamic analysis: suppression ranks of source tokens", ["Rank=0 means source token is the most suppressed"]),
    template="simple_white",
    labels={"x": "Rank"},
)

  0%|          | 0/50257 [00:00<?, ?it/s]

Mean rank = 116.82
Median rank = 0

Proportion with rank zero = 57.26%
Proportion with rank less than 10 = 84.70%

Quantity with rank more than 10% = 301
Quantity with rank more than 5% = 571


## Function words

In [32]:
tokens_not_in_top_10pct = t.arange(model.cfg.d_vocab)[diag_elem_negative_ranks > 0.1 * model.cfg.d_vocab]
print(f"Number of tokens not in top 5% = {len(tokens_not_in_top_10pct)}")

# Remove ASCII-256 tokens
tokens_not_in_top_10pct_filtered = tokens_not_in_top_10pct[tokens_not_in_top_10pct > 255]
print(f"Number of these which not ASCII-256 = {len(tokens_not_in_top_10pct_filtered)}")

# Inspect the remaining words (print the least copied out)
str_toks_not_in_top_10pct_filtered = model.to_str_tokens(tokens_not_in_top_10pct_filtered)

ranks = diag_elem_negative_ranks[tokens_not_in_top_10pct_filtered]
ranks_ordered = t.argsort(-ranks)
ranks = ranks[ranks_ordered]
str_toks_not_in_top_10pct_filtered = model.to_str_tokens(tokens_not_in_top_10pct_filtered[ranks_ordered])

table = Table("Token", "Rank", title="OV circuit")
for str_tok, rank in zip(str_toks_not_in_top_10pct_filtered, ranks[:30]):
    table.add_row(str_tok, str(rank.item()))
rprint(table)

Number of tokens not in top 5% = 301
Number of these which not ASCII-256 = 255


# Dynamic evidence

In [None]:
BATCH_SIZE = 400
MINIBATCH_SIZE = 10 # because of memory constraints
SEQ_LEN = 100

DATA_TOKS, DATA_STR_TOKS = process_webtext(model, BATCH_SIZE, SEQ_LEN)

model = model.cuda()
DATA_TOKS = DATA_TOKS.cuda()

DATA_TOKS_LIST = [DATA_TOKS[i: i+MINIBATCH_SIZE] for i in range(0, BATCH_SIZE, MINIBATCH_SIZE)]
# DATA_STR_LIST = [DATA_STR[i: i+MINIBATCH_SIZE] for i in range(0, BATCH_SIZE, MINIBATCH_SIZE)]
DATA_STR_TOKS_LIST = [DATA_STR_TOKS[i: i+MINIBATCH_SIZE] for i in range(0, BATCH_SIZE, MINIBATCH_SIZE)]

Moving model to device:  cuda


### First, get all attention probs, and filter for the ones which are above 10%

In [None]:
# Get all the attention probs
ATTN = t.empty((0, SEQ_LEN, SEQ_LEN), device=device, dtype=t.float)

def cache_attn_from_head_10_07(pattern: Float[Tensor, "batch head seqQ seqK"], hook: HookPoint):
    assert hook.layer() == 10
    global ATTN
    ATTN = t.concat([ATTN, pattern[:, 7]])

model = cast(HookedTransformer, model)
model.reset_hooks()
for toks in DATA_TOKS_LIST:
    model.run_with_hooks(
        toks,
        return_type=None,
        fwd_hooks=[(utils.get_act_name("pattern", 10), cache_attn_from_head_10_07)]
    )
assert ATTN.shape == (BATCH_SIZE, SEQ_LEN, SEQ_LEN)

# Set attention to zeroth token to be zero (we don't want to count this)
ATTN[..., 0] = 0

# Get all cases where attention prob is more than 20%
ATTN_over_10pct = t.nonzero(ATTN > 0.1)
print(f"Cases where attention prob is more than 10%: {len(ATTN_over_10pct)}")

Cases where attention prob is more than 10%: 19268


### Second, get the direct logit attribution for all the (src, dest) pairs we found above

In [None]:
# Get all direct logit attributions, from each (source dest) pair, for each model
DLA = t.empty((0, model.cfg.d_vocab), dtype=t.float, device=device)

model.reset_hooks()

for i, toks in zip(range(0, BATCH_SIZE, MINIBATCH_SIZE), DATA_TOKS_LIST):
    # Cache the value vectors, pre-final-LN scale, and attention patterns
    t.cuda.empty_cache()
    _, cache = model.run_with_cache(
        toks,
        return_type = None,
        names_filter = lambda name: name in [
            utils.get_act_name("v", 10), 
            utils.get_act_name("pattern", 10), 
            utils.get_act_name("scale"),
        ]
    )
    v_10_07 = cache["v", 10][:, :, 7] # (batch, seqK, d_head)
    pattern_10_07 = cache["pattern", 10][:, 7] # (batch, seqQ, seqK)
    scale = cache["scale"] # (batch, seqQ, d_model, 1)
    del cache

    # Compute the vectors which get moved from each src -> dest
    result_pre_attn_10_07 = v_10_07 @ model.W_O[10, 7] # (batch, seqK, d_model)
    # Weight these vectors by attention probs (not really necessary for the final result I think)
    result_post_attn_10_07 = einops.einsum(
        result_pre_attn_10_07, pattern_10_07,
        "batch seqK d_model, batch seqQ seqK -> batch seqQ seqK d_model"
    )
    # Scale these by final LN
    scale_rep = einops.repeat(scale, "batch seqQ d_model -> batch seqQ seqK d_model", seqK=1)
    result_post_attn_10_07_scaled = result_post_attn_10_07 / scale_rep

    # Filter for all the "ATTN over 10pct" cases (saves a lot of memory to do this before DLA)
    indices = ATTN_over_10pct[(ATTN_over_10pct[:, 0] >= i) & (ATTN_over_10pct[:, 0] < i + MINIBATCH_SIZE)]
    batch_indices, seqQ_indices, seqK_indices = indices.unbind(dim=-1)
    batch_indices = batch_indices - i
    result_filtered = result_post_attn_10_07_scaled[batch_indices, seqQ_indices, seqK_indices] # (num_filtered, d_model)

    dla = result_filtered @ model.W_U # (num_filtered, d_vocab)
    # Add to main tensor
    DLA = t.concat([DLA, dla], dim=0)

### Third, find the average neg rank of the source tokens in the DLA in these cases

(also try with function words filtered out)

In [None]:
batch_indices, seqQ_indices, seqK_indices = ATTN_over_10pct.unbind(dim=-1)

src_toks = DATA_TOKS[batch_indices, seqK_indices]

DLA_src_toks = DLA[range(len(batch_indices)), src_toks]

DLA_src_toks_ranks = (DLA_src_toks.unsqueeze(1) > DLA).sum(dim=-1)

In [None]:
print(f"Mean rank = {DLA_src_toks_ranks.float().mean().item():.2f}")
print(f"Median rank = {DLA_src_toks_ranks.float().median().item():.0f}\n")

print(f"Proportion with rank zero = {(DLA_src_toks_ranks == 0).float().mean():.2%}")
print(f"Proportion with rank less than 10 = {(DLA_src_toks_ranks < 10).float().mean():.2%}\n")
print(f"Proportion with rank between 1 and 9 inclusive = {((DLA_src_toks_ranks < 10) & (DLA_src_toks_ranks > 0)).float().mean():.2%}\n")

# Histogram, with all values above 100 cropped so the pattern is more visible
hist(
    DLA_src_toks_ranks[DLA_src_toks_ranks < 100], 
    title=create_title_and_subtitles("Dynamic analysis: suppression ranks of source tokens", ["Rank=0 means source token is the most suppressed"]),
    template="simple_white",
    labels={"x": "Rank"},
)

Mean rank = 771.08
Median rank = 1

Proportion with rank zero = 36.24%
Proportion with rank less than 10 = 78.24%

Proportion with rank between 1 and 9 inclusive = 42.00%



### Get a random sample of the residual 42% - how many are semantically related? What are the cosine sims?

In [None]:
def concat_lists(list_of_lists):
    return [item for sublist in list_of_lists for item in sublist]


def is_equivalent(t1, t2):
    '''
    Tests whether tokens t1 and t2 (both strings) are equivalent. 

    Our equivalence relation t1 ~ t2 is true in any of the following cases (or a composition
    of them):

        (A) capitalization, i.e. t1 == t2.upper() or t1 == t2.capitalize()
        (B) prepended spaces, i.e. t1 == " " + t2
        (C) pluralization, i.e. t1 = t2 + "s"
        (D) tokenization, i.e. replace any conditions of the form "t1 == X" in cases (A) - (C) with "X is tokenized
            into more than one token, and t1 is one of them."
    '''
    t1_set = get_equivalency_set(t1)
    t2_set = get_equivalency_set(t2)
    return len(t1_set & t2_set) > 0


def is_token(t1):
    '''For some reason this works, and "t1 is in model.tokenizer.vocab" doesn't.'''
    return len(model.to_str_tokens(t1, prepend_bos=False)) == 1
    

def get_equivalency_set(t1):
    # Make sure we start with nonspace versions
    t1_stripped = t1.strip()
    t1_set = [t1_stripped]
    # (C) add plural versions, if applicable (this is a bit messy)
    if t1_stripped.endswith("s"):
        for t2 in [t1_stripped[:-1]]:
            if is_token(t2) or is_token(" " + t2):
                t1_set.append(t2)
    elif not(t1_stripped.endswith("s")):
        for t2 in [t1_stripped + "s"]:
            if is_token(t2) or is_token(" " + t2):
                t1_set.append(t2)
    # (A) add capitalized versions
    t1_set = concat_lists([t1_set, [t1.capitalize() for t1 in t1_set]]) # [t1.upper() for t1 in t1_set]
    # (B) add versions with prepended spaces
    t1_set = t1_set + [" " + t1 for t1 in t1_set]
    # (D) replace all elements in t1_set with their tokenized versions
    # (this is also a bit hacky)
    t1_set = concat_lists([model.to_str_tokens(t1, prepend_bos=False) for t1 in t1_set])
    t1_set = [t1 for t1 in t1_set if t1.strip()]
    return set(t1_set)


# TODO - don't tokenize as many of the words, to make this more sparse? esp. not the CAPITAL VERSION

print(get_equivalency_set(" Berkeley"))
print(get_equivalency_set(" device"))
print(get_equivalency_set(" pier"))
print(get_equivalency_set(" 1925"))

{'keley', 'Ber', ' Berkeley'}
{'device', ' Device', ' device', ' devices', ' Devices', 'ices', 'Dev', 'Device', 'devices'}
{'P', 'ier', 'p', ' Pier', ' pier'}
{' 1925', '25', '19'}


In [None]:
in_top10_not_top = (DLA_src_toks_ranks < 10) & (DLA_src_toks_ranks > 0)

src_toks_in_top10_not_top = src_toks[in_top10_not_top]
DLA_in_top10_not_top = DLA[in_top10_not_top]

DLA_src_strtoks_in_top10_not_top = model.to_str_tokens(src_toks_in_top10_not_top)
DLA_most_neg_strtoks_in_top10_not_top = model.to_str_tokens(DLA_in_top10_not_top.argmin(-1))

# import rich
# from rich import print as rprint
# from rich.table import Table

random_indices = t.randint(0, len(DLA_src_strtoks_in_top10_not_top), (50,))

table = Table("Source token", "Most suppressed token", "Semantically related?", title="DLA most negative tokens")
for i in random_indices:
    s1 = DLA_src_strtoks_in_top10_not_top[i]
    s2 = DLA_most_neg_strtoks_in_top10_not_top[i]
    table.add_row(repr(s1), repr(s2), "" if is_equivalent(s1, s2) else "No")
rprint(table)

In [None]:
is_semantically_related = []

for s1, s2 in tqdm(list(zip(DLA_src_strtoks_in_top10_not_top, DLA_most_neg_strtoks_in_top10_not_top))):
    is_semantically_related.append(is_equivalent(s1, s2))

  0%|          | 0/8093 [00:00<?, ?it/s]

In [None]:
print(f"Pct which are semantically related = {t.tensor(is_semantically_related).float().mean():.2%}")

In [None]:
DLA_src_strtoks_in_top10_not_top.shape

AttributeError: 'list' object has no attribute 'shape'

In [None]:
# Get avg cos sim for the semantically related ones

src_sem = src_toks_in_top10_not_top[is_semantically_related]
most_neg_sem = DLA_in_top10_not_top[is_semantically_related].argmin(-1)

src_unembeds = model.W_U.T[src_sem]
most_neg_unembeds = model.W_U.T[most_neg_sem]

src_unembeds_normed = src_unembeds / src_unembeds.norm(dim=-1, keepdim=True)
most_neg_unembeds_normed = most_neg_unembeds / most_neg_unembeds.norm(dim=-1, keepdim=True)

avg_cos_sim = (src_unembeds_normed * most_neg_unembeds_normed).sum(-1).mean()

print(f"Avg cos sim = {avg_cos_sim:.3f}")

### Try filtering out function words

In [None]:
toks_not_in_top_10pct_filtered = model.to_tokens(list(str_toks_not_in_top_10pct_filtered), prepend_bos=False).squeeze()

src_tok_is_nonfn = (DATA_TOKS[batch_indices, seqK_indices] != toks_not_in_top_10pct_filtered[:, None]).all(dim=0)

batch_indices_nonfn = batch_indices[src_tok_is_nonfn]
seqK_indices_nonfn = seqK_indices[src_tok_is_nonfn]

DLA_src_toks_nonfn = DLA[src_tok_is_nonfn, src_toks[src_tok_is_nonfn]]

DLA_src_toks_ranks_nonfn = (DLA_src_toks_nonfn.unsqueeze(1) > DLA[src_tok_is_nonfn]).sum(dim=-1)

In [None]:
print(f"Mean rank = {DLA_src_toks_ranks_nonfn.float().mean().item():.2f}")
print(f"Median rank = {DLA_src_toks_ranks_nonfn.float().median().item():.0f}\n")

print(f"Proportion with rank zero = {(DLA_src_toks_ranks_nonfn == 0).float().mean():.2%}")
print(f"Proportion with rank less than 10 = {(DLA_src_toks_ranks_nonfn < 10).float().mean():.2%}\n")

# Histogram, with all values above 100 cropped so the pattern is more visible
hist(
    DLA_src_toks_ranks_nonfn[DLA_src_toks_ranks_nonfn < 100], 
    title=create_title_and_subtitles("Dynamic analysis: suppression ranks of source tokens", ["Rank=0 means source token is the most suppressed"]),
    template="simple_white",
    labels={"x": "Rank"},
)

Mean rank = 247.01
Median rank = 1

Proportion with rank zero = 37.35%
Proportion with rank less than 10 = 80.64%



# Section 3.2 - QK circuit

In [57]:
W_Q = model.W_Q[10, 7]
W_K = model.W_K[10, 7]

full_QK_circuit = FactoredMatrix(W_U.T @ W_Q, W_K.T @ W_EE.T)

diag_elem_positive_ranks = []

for i in tqdm(range(model.cfg.d_vocab)):
    row = full_QK_circuit[i, :].AB.squeeze()
    diag_elem = row[i]
    diag_elem_positive_rank = (diag_elem < row).sum().item() # This is zero when diag elem is maximal
    diag_elem_positive_ranks.append(diag_elem_positive_rank)

diag_elem_positive_ranks = t.tensor(diag_elem_positive_ranks)

print(f"Mean rank = {diag_elem_positive_ranks.float().mean().item():.2f}")
print(f"Median rank = {diag_elem_positive_ranks.float().median().item():.0f}\n")

print(f"Proportion with rank zero = {(diag_elem_positive_ranks == 0).float().mean():.2%}")
print(f"Proportion with rank less than 10 = {(diag_elem_positive_ranks < 10).float().mean():.2%}\n")

print(f"Quantity with rank more than 10% = {(diag_elem_positive_ranks > 0.1 * model.cfg.d_vocab).sum()}")
print(f"Quantity with rank more than 5% = {(diag_elem_positive_ranks > 0.05 * model.cfg.d_vocab).sum()}")

# Histogram, with all values above 100 cropped so the pattern is more visible
hist(
    diag_elem_positive_ranks[diag_elem_positive_ranks < 100], 
    title=create_title_and_subtitles("Dynamic analysis: suppression ranks of source tokens", ["Rank=0 means source token is the most suppressed"]),
    template="simple_white",
    labels={"x": "Rank"},
)

  0%|          | 0/50257 [00:00<?, ?it/s]

Mean rank = 31.35
Median rank = 0

Proportion with rank zero = 71.24%
Proportion with rank less than 10 = 95.72%

Quantity with rank more than 10% = 67
Quantity with rank more than 5% = 78


## Function words

In [43]:
tokens_not_in_top_5pct = t.arange(model.cfg.d_vocab)[diag_elem_positive_ranks > 0.05 * model.cfg.d_vocab]
print(f"Number of tokens not in top 5% = {len(tokens_not_in_top_5pct)}")

# Remove ASCII-256 tokens
tokens_not_in_top_5pct_filtered = tokens_not_in_top_5pct[tokens_not_in_top_5pct > 255]
print(f"Number of these which not ASCII-256 = {len(tokens_not_in_top_5pct_filtered)}")

# Inspect the remaining words (print the least copied out)
str_toks_not_in_top_5pct_filtered = model.to_str_tokens(tokens_not_in_top_5pct_filtered)

ranks = diag_elem_positive_ranks[tokens_not_in_top_5pct_filtered]
ranks_ordered = t.argsort(-ranks)
ranks = ranks[ranks_ordered]
str_toks_not_in_top_5pct_filtered = model.to_str_tokens(tokens_not_in_top_5pct_filtered[ranks_ordered])

table = Table("Token", "Rank", title="OV circuit")
for str_tok, rank in zip(str_toks_not_in_top_5pct_filtered, ranks[:30]):
    table.add_row(str_tok, str(rank.item()))
rprint(table)

Number of tokens not in top 5% = 78
Number of these which not ASCII-256 = 31


## Make the figure

In [59]:
full_QK_circuit.shape

torch.Size([3, 3, 50257, 50257])

In [61]:
full_QK_circuit[:, :, 0, :].shape

torch.Size([3, 3, 1, 50257])

In [65]:
(diag_elem < row).shape

torch.Size([3, 3, 50257])

In [67]:
diag_elem.shape, rows.shape

(torch.Size([3, 3, 1]), torch.Size([3, 3, 50257]))

In [68]:
W_Q = model.W_Q[10, 7]
W_K = model.W_K[10, 7]

W_EE0 = effective_embedding_dict["W_E (only MLPs)"]
W_E = effective_embedding_dict["W_E (no MLPs)"]

keyside_matrices = {"W<sub>EE</sub>": W_EE, "MLP0": W_EE0, "W<sub>E</sub>": W_E}
queryside_matrices = {"W<sub>EE</sub>": W_EE.T, "W<sub>E</sub>": W_E.T, "W<sub>U</sub>": W_U}

keyside_matrices_stacked = t.stack(list(keyside_matrices.values()))[None, :]
queryside_matrices_stacked = t.stack(list(queryside_matrices.values()))[:, None]

full_QK_circuit = FactoredMatrix(
    queryside_matrices_stacked.transpose(-1, -2) @ W_Q,
    W_K.T @ keyside_matrices_stacked.transpose(-1, -2),
)

diag_elem_positive_ranks = []

for i in tqdm(range(model.cfg.d_vocab)):
    rows = full_QK_circuit[:, :, i, :].AB.squeeze()
    diag_elem = rows[..., [i]]
    diag_elem_positive_rank = (diag_elem < rows).sum(-1) # This is 3x3 tensor, equals zero when diag elem is maximal
    diag_elem_positive_ranks.append(diag_elem_positive_rank)

diag_elem_positive_ranks = t.stack(diag_elem_positive_ranks, dim=-1)

  0%|          | 0/50257 [00:00<?, ?it/s]

In [82]:
median_ranks = diag_elem_positive_ranks.float().median(dim=-1).values

imshow(
    median_ranks + 1,
    width = 700,
    height = 600,
    title = "Median rank of tokens in QK circuit",
    aspect = "equal",
    text_auto = ".0f",
    x = list(keyside_matrices),
    y = list(queryside_matrices),
    font_size = 15,
)

# Logit lens

In [None]:
from collections import OrderedDict

path = Path("/root/SERI-MATS-2023-Streamlit-pages/transformer_lens/rs/callum2/ov_qk_circuits/params.pt")
params: Dict[str, Tensor] = t.load(path)

print(params["10.weight"].shape)

torch.Size([768, 768])


In [None]:
W_Q = model.W_Q[10, 7]
W_K = model.W_K[10, 7]

id = t.eye(model.cfg.d_model)
tuned_lens = params["10.weight"]
W_U_tuned = (id + tuned_lens).to(device) @ W_U 

full_QK_circuit = FactoredMatrix(W_EE @ W_K, W_Q.T @ W_U_tuned)

diag_elem_negative_ranks = []

for i in tqdm(range(model.cfg.d_vocab)):
    row = full_QK_circuit.A @ full_QK_circuit.B[:, i]
    diag_elem = row[i]
    diag_elem_negative_rank = (diag_elem < row).sum().item() # This is zero when diag elem is maximal
    diag_elem_negative_ranks.append(diag_elem_negative_rank)

diag_elem_negative_ranks = t.tensor(diag_elem_negative_ranks)

print(f"Mean rank = {diag_elem_negative_ranks.float().mean().item():.2f}")
print(f"Median rank = {diag_elem_negative_ranks.float().median().item():.0f}\n")

print(f"Proportion with rank zero = {(diag_elem_negative_ranks == 0).float().mean():.2%}")
print(f"Proportion with rank less than 10 = {(diag_elem_negative_ranks < 10).float().mean():.2%}\n")

print(f"Quantity with rank more than 10% (= 5026) = {(diag_elem_negative_ranks > 5026).sum()}")

# Histogram, with all values above 100 cropped so the pattern is more visible
hist(
    diag_elem_negative_ranks[diag_elem_negative_ranks < 100], 
    title=create_title_and_subtitles("Dynamic analysis: suppression ranks of source tokens", ["Rank=0 means source token is the most suppressed"]),
    template="simple_white",
    labels={"x": "Rank"},
)

  0%|          | 0/50257 [00:00<?, ?it/s]

100%|██████████| 50257/50257 [00:03<00:00, 15600.55it/s]


Mean rank = 30.32
Median rank = 0

Proportion with rank zero = 70.74%
Proportion with rank less than 10 = 95.45%

Quantity with rank more than 10% (= 5026) = 70


In [12]:
toks_not_prediction_attention = t.where(diag_elem_negative_ranks > 2000)[0]
str_toks = model.to_str_tokens(toks_not_prediction_attention)
str_toks

['�',
 '�',
 '�',
 '�',
 '�',
 '�',
 '�',
 '�',
 '�',
 '�',
 '�',
 '�',
 '�',
 '\x00',
 '\x01',
 '\x02',
 '\x03',
 '\x04',
 '\x05',
 '\x06',
 '\x07',
 '\x08',
 '\t',
 '\x0b',
 '\x0c',
 '\r',
 '\x0e',
 '\x0f',
 '\x10',
 '\x11',
 '\x12',
 '\x13',
 '\x14',
 '\x15',
 '\x16',
 '\x17',
 '\x18',
 '\x19',
 '\x1a',
 '\x1b',
 '\x1c',
 '\x1d',
 '\x1e',
 '\x1f',
 '\x7f',
 ' to',
 ' of',
 ' that',
 ' it',
 ' as',
 ' have',
 '."',
 ' only',
 ' As',
 ' That',
 'ised',
 ' having',
 ' takes',
 ' whose',
 ' Of',
 ' THE',
 ' aren',
 ' perhaps',
 ' whatever',
 'Of',
 ' Those',
 ' Its',
 ' Such',
 'Their',
 'ÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂ',
 ' meanwhile',
 ' nevertheless',
 ')."',
 'ÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂ',
 'Around',
 '),"',
 ' externalTo',
 ' externalToEVA',
 'reportprint',
 'embedreportprint',
 'rawdownload',
 'ActionCode',
 ' RandomRedditor',
 'SPONSORED',
 'StreamerBot',
 'quickShip',
 '龍�',
 'oreAndOnline',
 'InstoreAndOnline',
 ' foregoing',
 ' TheNitrome',
 ' サーティ',
 ' THEIR',
 '<|endoftext|>']

: 