In [1]:
import pandas as pd
import torch as t
from pathlib import Path
from transformers import AutoTokenizer,AutoModel, GPTNeoXForCausalLM
from tqdm.auto import tqdm
import numpy as np
from pandarallel import pandarallel
import os

pandarallel.initialize(nb_workers=16,progress_bar=True)

INFO: Pandarallel will run on 16 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.


In [2]:
pile_df = pd.read_json("data/pile_samples.jsonl",lines=True)
pile_df.head()

Unnamed: 0,text,meta,pile_set_name
0,---\naddress: '$^{1}$ Department of Computer E...,{'pile_set_name': 'ArXiv'},ArXiv
1,---\nabstract: 'We prove that the law of a ran...,{'pile_set_name': 'ArXiv'},ArXiv
2,---\nabstract: 'A nonperturbative numerical ev...,{'pile_set_name': 'ArXiv'},ArXiv
3,---\nabstract: 'Vacancy-induced magnetization ...,{'pile_set_name': 'ArXiv'},ArXiv
4,"---\nauthor:\n- |\n Robert G. Endres$^{1,2}...",{'pile_set_name': 'ArXiv'},ArXiv


In [3]:
device = t.device("cuda" if t.cuda.is_available() else "mps")

In [4]:
MODEL_NAME = "EleutherAI/pythia-70m-deduped"

In [5]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

## Load The Pile

In [6]:
dfs = [pd.read_parquet(f"data/mem/memorization_70m-deduped-v0_143000_rank{rank}.parquet") for rank in range(4)]
df = pd.concat(dfs, ignore_index=True)

In [7]:
df["decoded_context"] = df.apply(lambda x: tokenizer.decode(x["context"]),axis=1)
df["decoded_generation"] = df.apply(lambda x: tokenizer.decode(x["generation"]),axis=1)
df["decoded_true_continuation"] = df.apply(lambda x: tokenizer.decode(x["true_continuation"]),axis=1)

In [42]:
# Sweep over different numbers of matching tokens
max_tokens_to_check = 32  # Maximum number of tokens to check for matching
results = []

for n in range(0, max_tokens_to_check + 1):    
    # all(row['generation'][:n] == row['true_continuation'][:n]) and 
    n_token_match = df.apply(
        lambda row:((n == max_tokens_to_check) or (row['generation'][n] != row['true_continuation'][n])) and
                   (n == 0 or sum(g != t for g, t in zip(row['generation'][:n], row['true_continuation'][:n])) == 1),
        axis=1
    )
    
    # Count matches and calculate percentage
    matching_count = n_token_match.sum()
    matching_percentage = matching_count / len(df) * 100    
    results.append((df[n_token_match], matching_count, matching_percentage))
    print(f"First {n} tokens match with 1 differing token: {matching_count} samples ({matching_percentage:.2f}%)")

First 0 tokens match with 1 differing token: 5784 samples (28.59%)
First 1 tokens match with 1 differing token: 4726 samples (23.36%)
First 2 tokens match with 1 differing token: 777 samples (3.84%)
First 3 tokens match with 1 differing token: 749 samples (3.70%)
First 4 tokens match with 1 differing token: 726 samples (3.59%)
First 5 tokens match with 1 differing token: 670 samples (3.31%)
First 6 tokens match with 1 differing token: 712 samples (3.52%)
First 7 tokens match with 1 differing token: 779 samples (3.85%)
First 8 tokens match with 1 differing token: 1106 samples (5.47%)
First 9 tokens match with 1 differing token: 742 samples (3.67%)
First 10 tokens match with 1 differing token: 570 samples (2.82%)
First 11 tokens match with 1 differing token: 223 samples (1.10%)
First 12 tokens match with 1 differing token: 190 samples (0.94%)
First 13 tokens match with 1 differing token: 179 samples (0.88%)
First 14 tokens match with 1 differing token: 213 samples (1.05%)
First 15 tokens

In [43]:
num_matching_tokens = 10
explore_df = results[num_matching_tokens][0]

In [44]:
explore_df['decoded_context'].values

array([' (630 nm and 700 nm) used for LLLT \\[[@B16-brainsci-09-00179]\\] to investigate the mechanisms underlying neuromodulation \\[[@',
       'nodes}} \\sum _{i=1}^{N_\\mathrm{nodes}} |\\varvec{x}_i(\\varvec{q}) - \\varvec',
       '\n                                    {\n                                        "type": "NamedArgument",\n                                        "name": {\n                                            "type": "',
       '$ $\\mathsf{\\mu}$m. Moreover, the thickness of the dielectric spacer is obtained from Equation ([13](#FD13-nanomaterials-09-01351',
       '_title">Voir le journal d\\\'installation</string>\n  <string name="menu_delete_title">Tout supprimer</',
       ' was a delete marker.\r\n         */\r\n        public boolean isDeleteMarker() {\r\n            return deleteMarker;\r\n        }\r\n\r\n        public void setDelete',
       '. Let b = 211115 - 211112. Factor -63/5*f**2 - 72/5*f - 12/5 - b',
       ' (1996); State v. Rouse, 339 N.C.

In [45]:
explore_df[['decoded_generation','decoded_true_continuation',"acc"]].values

array([['B17-brainsci-09-00179]\\].\n\nThe results of the present study showed that the LLLT-induced neuromodulation was significantly increased',
        'B1-brainsci-09-00179]\\]. Here, the primary aim is to better understand the extent of optically induced tissue heating (primarily due to water',
        0.28125],
       ['{x}_i(\\varvec{q})| \\leq \\sum _{i=1}^{N_\\mathrm{nodes}} |\\varvec{x}_',
        '{x}_i(\\varvec{q}_f) |. \\end{aligned}$$\\end{document}$$The kinematic variable error is defined as the mean',
        0.28125],
       ['String",\n                                            "name": {\n                                                "type": "String",\n                                                "name": {\n                                                "type',
        'Identifier",\n                                            "name": "y"\n                                        },\n                                        "value": {\n                           

In [7]:
df_perfect = df[df['acc'] == 1]

In [10]:
df_nonmem = df[df['acc'] == 0]
df_nonmem['decoded_full'] = tokenizer.batch_decode(df_nonmem.apply(lambda x: list(x['context']) + list(x['true_continuation']),axis=1))

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_nonmem['decoded_full'] = tokenizer.batch_decode(df_nonmem.apply(lambda x: list(x['context']) + list(x['true_continuation']),axis=1))


In [11]:
df_nonmem.head(32)

Unnamed: 0,acc,context,generation,true_continuation,idx,decoded_full
0,0.0,"[8148, 412, 29628, 653, 264, 6420, 275, 617, 1...","[187, 187, 510, 5101, 369, 28115, 13, 253, 510...","[7160, 6966, 760, 2559, 672, 13, 9506, 1996, 1...",5,Cleopatra sped south in her majestic flagship...
1,0.0,"[72, 5, 310, 370, 36, 63, 18, 5, 6032, 2822, 8...","[393, 2981, 10952, 359, 452, 1764, 2043, 92, 2...","[268, 64, 75, 4526, 14030, 1124, 92, 18, 1217,...",6,g$ is $C^1$ smooth near such points $\xi$. Fur...
2,0.0,"[1509, 15, 2896, 436, 1127, 13, 3668, 273, 146...","[644, 3863, 15, 187, 187, 510, 1895, 369, 326,...","[2489, 3872, 3668, 13, 285, 9380, 512, 689, 25...",13,"pass. By this point, news of Indiana attempti..."
3,0.0,"[187, 6217, 5543, 39, 10948, 947, 1703, 428, 3...","[760, 2181, 359, 476, 513, 310, 281, 755, 253,...","[346, 977, 1930, 3, 556, 574, 11188, 3489, 58,...",19,\nGRANDFATHERED - you had better fight this so...
4,0.0,"[776, 6264, 285, 9787, 16027, 497, 387, 512, 2...","[253, 1387, 273, 1146, 247, 346, 33979, 3128, ...","[512, 841, 952, 273, 8534, 2220, 29613, 273, 3...",20,our commercial and industrial ties were at al...
5,0.0,"[288, 3560, 407, 5141, 970, 30267, 42, 15, 380...","[9122, 3788, 1712, 1704, 20, 5256, 20, 14, 113...","[8903, 13, 5365, 3788, 16, 8903, 13, 285, 1089...",22,h followed by reduction using MWI. The same p...
6,0.0,"[187, 187, 4125, 2207, 67, 2823, 310, 2970, 76...","[310, 2970, 767, 41196, 898, 37048, 323, 1679,...","[25711, 1694, 253, 1682, 8027, 2968, 275, 2892...",23,\n\nNow Orbcomm is getting two Falcon 9 launch...
7,0.0,"[187, 187, 33651, 368, 281, 10078, 275, 39858,...","[403, 417, 2908, 275, 253, 4803, 15, 187, 187,...","[812, 3831, 3081, 1491, 670, 849, 359, 897, 28...",24,"\n\nAllow you to participate in contests, priz..."
8,0.0,"[1677, 323, 247, 642, 14, 21272, 390, 642, 14,...","[253, 1971, 273, 19798, 5988, 2923, 15, 187, 1...","[19859, 41286, 1975, 8837, 285, 310, 253, 806,...",32,given for a no-treatment or no-surgical inter...
9,0.0,"[1361, 747, 1363, 24171, 253, 1943, 4715, 6970...","[31100, 323, 253, 1971, 273, 10397, 13, 9454, ...","[1824, 19663, 31100, 275, 253, 2791, 1659, 15,...",33,help new patients navigate the big learning c...


In [8]:
df_perfect['decoded_full'] = tokenizer.batch_decode(df_perfect.apply(lambda x: list(x['context']) + list(x['true_continuation']),axis=1))

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_perfect['decoded_full'] = tokenizer.batch_decode(df_perfect.apply(lambda x: list(x['context']) + list(x['true_continuation']),axis=1))


In [9]:
df_perfect.head(32)

Unnamed: 0,acc,context,generation,true_continuation,idx,decoded_full
3795,1.0,"[5584, 4196, 1228, 187, 1036, 4, 209, 21723, 2...","[783, 346, 17736, 3287, 187, 4, 368, 778, 417,...","[783, 346, 17736, 3287, 187, 4, 368, 778, 417,...",441,/(-96))\n16# *********************************...
3796,1.0,"[50262, 61, 2099, 92, 8861, 94, 187, 50262, 61...","[669, 8604, 60, 805, 431, 1019, 8402, 94, 187,...","[669, 8604, 60, 805, 431, 1019, 8402, 94, 187,...",447,\usepackage{upgreek}\n ...
3797,1.0,"[475, 50272, 953, 24781, 778, 320, 908, 281, 1...","[4110, 33278, 9149, 3003, 28827, 43227, 4889, ...","[4110, 33278, 9149, 3003, 28827, 43227, 4889, ...",792,* its contributors may be used to endors...
3798,1.0,"[424, 380, 16101, 313, 433, 17889, 3104, 10, 2...","[50262, 61, 2099, 92, 8860, 94, 2490, 50262, 6...","[50262, 61, 2099, 92, 8860, 94, 2490, 50262, 6...",1539,** The analytical (red dashed lines) and simul...
3799,1.0,"[3498, 2262, 2369, 40, 736, 13, 3956, 27, 21, ...","[69, 7675, 27, 2369, 40, 736, 13, 9229, 27, 40...","[69, 7675, 27, 2369, 40, 736, 13, 9229, 27, 40...",1705,"px):noGrow,top:4dlu:noGrow,center:max(d;4px):n..."
3800,1.0,"[187, 187, 4, 21737, 281, 253, 14325, 9107, 68...","[436, 789, 323, 3081, 1491, 187, 4, 5001, 9451...","[436, 789, 323, 3081, 1491, 187, 4, 5001, 9451...",3050,\n\n# Licensed to the Apache Software Foundati...
3801,1.0,"[66, 15, 3122, 187, 475, 8283, 313, 68, 10, 88...","[19252, 28827, 5803, 27279, 4145, 10113, 15691...","[19252, 28827, 5803, 27279, 4145, 10113, 15691...",3509,"a./*\n * Copyright (c) 1995, 2012, Oracle and/..."
3802,1.0,"[8161, 13, 2975, 15, 10701, 285, 3915, 9308, 3...","[94, 187, 50262, 61, 2099, 92, 8860, 94, 2490,...","[94, 187, 50262, 61, 2099, 92, 8860, 94, 2490,...",3851,"Hz, respectively. Network and STDP parameters ..."
3803,1.0,"[94, 187, 50262, 61, 2043, 92, 3306, 2138, 138...","[1019, 8402, 94, 187, 50262, 61, 2099, 92, 879...","[1019, 8402, 94, 187, 50262, 61, 2099, 92, 879...",3870,}\n \begin{document}$$support(\...
3804,1.0,"[43518, 66, 1619, 66, 43518, 66, 1619, 66, 435...","[43518, 66, 1619, 66, 43518, 66, 1619, 66, 435...","[43518, 66, 1619, 66, 43518, 66, 1619, 66, 435...",4942,2800a28a2800a28a2800a28a2800a28a2800a28a2800a2...


### Classify Subset

### Embedding similarity

In [171]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# Load pre-trained sentence transformer model
embed_model = SentenceTransformer('all-mpnet-base-v2')

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [172]:
pile_embeddings = embed_model.encode(pile_df['text'].str.slice(0,64).to_list(),batch_size=64)

In [173]:
embeddings =embed_model.encode(df_perfect['decoded_full'].to_list(),batch_size=64)


In [174]:
## Just argmax
# sim_matrix = cosine_similarity(embeddings,pile_embeddings)
# # Get the index of the most similar document for each query
# most_similar_idx = sim_matrix.argmax(axis=1)
# # Get the pile_set_name for the most similar documents
# most_similar_pile_sets = pile_df.iloc[most_similar_idx]['pile_set_name']
# # Return array of similarity scores for most similar documents
# max_similarities = sim_matrix.max(axis=1)

## Use KNeighborsClassifier instead of just taking argmax
from sklearn.neighbors import KNeighborsClassifier

# Fit KNN classifier on pile embeddings
knn = KNeighborsClassifier(n_neighbors=10,metric="cosine")
knn.fit(pile_embeddings, pile_df['pile_set_name'])

# Get predictions and distances for each query embedding
distances, _ = knn.kneighbors(embeddings)
most_similar_pile_sets = knn.predict(embeddings)
max_similarites = 1 - distances

In [None]:
df_perfect['most_similar_pile_set'] = most_similar_pile_sets

In [179]:
df_perfect.head(32)
# df_perfect['max_similarity'] = max_similarites

Unnamed: 0,acc,context,generation,true_continuation,idx,decoded_full,most_similar_pile_set
3795,1.0,"[5584, 4196, 1228, 187, 1036, 4, 209, 21723, 2...","[783, 346, 17736, 3287, 187, 4, 368, 778, 417,...","[783, 346, 17736, 3287, 187, 4, 368, 778, 417,...",441,/(-96))\n16# *********************************...,Github
3796,1.0,"[50262, 61, 2099, 92, 8861, 94, 187, 50262, 61...","[669, 8604, 60, 805, 431, 1019, 8402, 94, 187,...","[669, 8604, 60, 805, 431, 1019, 8402, 94, 187,...",447,\usepackage{upgreek}\n ...,DM Mathematics
3797,1.0,"[475, 50272, 953, 24781, 778, 320, 908, 281, 1...","[4110, 33278, 9149, 3003, 28827, 43227, 4889, ...","[4110, 33278, 9149, 3003, 28827, 43227, 4889, ...",792,* its contributors may be used to endors...,Gutenberg (PG-19)
3798,1.0,"[424, 380, 16101, 313, 433, 17889, 3104, 10, 2...","[50262, 61, 2099, 92, 8860, 94, 2490, 50262, 6...","[50262, 61, 2099, 92, 8860, 94, 2490, 50262, 6...",1539,** The analytical (red dashed lines) and simul...,ArXiv
3799,1.0,"[3498, 2262, 2369, 40, 736, 13, 3956, 27, 21, ...","[69, 7675, 27, 2369, 40, 736, 13, 9229, 27, 40...","[69, 7675, 27, 2369, 40, 736, 13, 9229, 27, 40...",1705,"px):noGrow,top:4dlu:noGrow,center:max(d;4px):n...",Github
3800,1.0,"[187, 187, 4, 21737, 281, 253, 14325, 9107, 68...","[436, 789, 323, 3081, 1491, 187, 4, 5001, 9451...","[436, 789, 323, 3081, 1491, 187, 4, 5001, 9451...",3050,\n\n# Licensed to the Apache Software Foundati...,Github
3801,1.0,"[66, 15, 3122, 187, 475, 8283, 313, 68, 10, 88...","[19252, 28827, 5803, 27279, 4145, 10113, 15691...","[19252, 28827, 5803, 27279, 4145, 10113, 15691...",3509,"a./*\n * Copyright (c) 1995, 2012, Oracle and/...",Github
3802,1.0,"[8161, 13, 2975, 15, 10701, 285, 3915, 9308, 3...","[94, 187, 50262, 61, 2099, 92, 8860, 94, 2490,...","[94, 187, 50262, 61, 2099, 92, 8860, 94, 2490,...",3851,"Hz, respectively. Network and STDP parameters ...",ArXiv
3803,1.0,"[94, 187, 50262, 61, 2043, 92, 3306, 2138, 138...","[1019, 8402, 94, 187, 50262, 61, 2099, 92, 879...","[1019, 8402, 94, 187, 50262, 61, 2099, 92, 879...",3870,}\n \begin{document}$$support(\...,Github
3804,1.0,"[43518, 66, 1619, 66, 43518, 66, 1619, 66, 435...","[43518, 66, 1619, 66, 43518, 66, 1619, 66, 435...","[43518, 66, 1619, 66, 43518, 66, 1619, 66, 435...",4942,2800a28a2800a28a2800a28a2800a28a2800a28a2800a2...,PhilPapers


In [None]:
df_perfect

### ChatGPT

In [None]:
from openai import OpenAI
import pandas as pd
import random

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
def construct_prompt(input_text:str):
    """Build prompt with 2 examples per category from pile_df"""
    examples = []
    
    # Get 2 random examples for each category
    for category in pile_df['pile_set_name'].unique():
        category_examples = pile_df[pile_df['pile_set_name'] == category].sample(2)
        examples.append(f"Category: {category}\nExamples:")
        examples.extend([f"```{text[:64].strip()}```" for text in category_examples['text'].tolist()])
    
    examples_str = "\n\n".join(examples)

    system_prompt = (
        "Classify this text into one of the following Pile Sets based on content and style:\n\n"
        f"{examples_str}\n\n"
        "Answer ONLY with the category name, nothing else."
    )
    prompt = (
        f"Text to classify: ```{input_text.strip()}```\n"        
    )
    return system_prompt, prompt

def classify_text_with_gpt4o(input_text:str):
    """Classify text using GPT-4o-mini with example-based prompting"""
    if not input_text.strip():
        return "Error: Empty input text"
    
    system_prompt, user_prompt = construct_prompt(input_text)    
    
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{
            "role": "system",
            "content": system_prompt
        },{
            "role": "user",
            "content": user_prompt
        }],
        temperature=0.1,
        max_tokens=50
    )
    
    return response.choices[0].message.content.strip()

# Example usage
prediction = classify_text_with_gpt4o("Your input text here")
print(f"Predicted category: {prediction}")

Predicted category: Category: Pile-CC


In [20]:
most_similar_pile_sets = df_perfect["decoded_full"].parallel_apply(classify_text_with_gpt4o)

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=313), Label(value='0 / 313'))), HB…

In [28]:
df_perfect['most_similar_pile_set'] = df_perfect['most_similar_pile_set'].apply(lambda x: x.replace("Category:","").strip())

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_perfect['most_similar_pile_set'] = df_perfect['most_similar_pile_set'].apply(lambda x: x.replace("Category:","").strip())


In [29]:
df_perfect['most_similar_pile_set'].value_counts()

most_similar_pile_set
Github                     2437
ArXiv                       843
DM Mathematics              471
Pile-CC                     220
OpenWebText2                181
PubMed Central              129
Books3                      110
FreeLaw                     108
OpenSubtitles               102
StackExchange               100
PubMed Abstracts             72
Wikipedia (en)               65
Gutenberg (PG-19)            35
Ubuntu IRC                   34
BookCorpus2                  26
HackerNews                   22
Error: Empty input text      15
USPTO Backgrounds            15
Enron Emails                  9
NIH ExPorter                  4
PhilPapers                    3
YoutubeSubtitles              2
None                          2
EuroParl                      1
Unknown                       1
Name: count, dtype: int64

In [25]:
df_perfect.to_json("data/mem/perfect_memorization_labeled.jsonl",lines=True,orient="records")

In [87]:
# print("\n".join(f"<row_{i}>\n{v}" for i,v in zip(mem_sample.index,mem_sample.values)))

### Look for corrupted sequences/wrong answers


In [196]:
model = GPTNeoXForCausalLM.from_pretrained(MODEL_NAME)
model = model.to(device)

In [197]:
toks = tokenizer("When Mary and John went to the store, John gave a drink to",return_tensors="pt")["input_ids"]
toks = toks.to(device)
logits = model(toks)


In [207]:
toks.size()

torch.Size([1, 14])

In [209]:
tokenizer.decode(logits.logits.argmax(dim=-1)[:,-1])

' the'

In [204]:
tokenizer.decode(logits.logits.argmax(dim=-1)[0])

' the was I were to the house to she and them small to the'

In [84]:
context =t.tensor(df_perfect["context"].to_list())
generation = t.tensor(df_perfect["generation"].to_list())

In [85]:
# Move model to device and set to eval mode
model = model.to(device)
model.eval()

# Process in batches to avoid memory issues
batch_size = 32
outputs = []

with t.no_grad():
    for i in range(0, len(context), batch_size):
        batch = context[i:i + batch_size].to(device)
        # Get model outputs including all logits

        batch_output = model.generate(batch, max_new_tokens=32, return_dict_in_generate=True, output_logits=True, output_attentions=True)        
        break

# Concatenate all batches
# all_outputs = t.cat(outputs, dim=0)


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


### Attention scores

In [35]:
from bertviz import head_view

In [51]:
first_attentions = []
for step in batch_output.attentions[0]:
    first_attentions.append(step[1,...].unsqueeze(0))

In [103]:
sample_attentions = []
for layer_idx in range(len(batch_output.attentions[0])):  # Iterate through layers
    # Initialize tensor for this layer
    layer_attentions = []
    
    # Get attention for first step (full context attention)
    first_step = batch_output.attentions[0][layer_idx]  # [batch_size, num_heads, seq_len, seq_len]
    
    target_size = first_step.size(-1) + len(batch_output.attentions)
    
    padded_first_step = t.nn.functional.pad(
        first_step,
        (0, target_size - first_step.size(-1),  # Pad last dimension
         0, target_size - first_step.size(-2))   # Pad second-to-last dimension
    )
    layer_attentions.append(padded_first_step)
    
    # Add attention from subsequent steps
    for step_idx, step in enumerate(batch_output.attentions[1:], start=1):
        step_attn = step[layer_idx]  # [batch_size, num_heads, 1, seq_len+step]
        # Pad to match target size
        # Create zero tensor of target size
        full_attn = t.zeros(step_attn.shape[0], step_attn.shape[1], target_size, target_size, device=step_attn.device)
        
        # Place the attention scores in the correct row (corresponding to the current token)
        full_attn[:, :, first_step.size(-1) + step_idx - 1, :step_attn.size(-1)] = step_attn.squeeze(2)
        
        layer_attentions.append(full_attn)
    
    # Combine all steps
    combined_attention = t.sum(t.stack(layer_attentions), dim=0) # Sum across steps
    sample_attentions.append(combined_attention)

In [132]:
batch_id = 28

In [133]:
first_attentions = [l[batch_id,...].unsqueeze(0) for l in sample_attentions]
tokens = tokenizer.convert_ids_to_tokens(batch_output.sequences[batch_id])
# first generation tokens


'Spanish feminine given names\nCategory:Scandinavian feminine given names/*\n * Copyright 2018 IBM Corporation\n *\n * Licensed under the Apache License, Version 2'

In [156]:
tokenizer.decode(batch_output.sequences[batch_id][:32]),tokenizer.decode(batch_output.sequences[batch_id][32:])

('Spanish feminine given names\nCategory:Scandinavian feminine given names/*\n * Copyright 2018 IBM Corporation\n *\n * Licensed under the Apache License, Version 2',
 '.0 (the "License");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License')

In [168]:
corrupted_ids = tokenizer(tokenizer.decode(context[batch_id]).replace("given names", "surnames",1),return_tensors="pt")
corrupted_ids = corrupted_ids["input_ids"].to(device)
corrupted_ids.size()

torch.Size([1, 33])

In [150]:
corrupted_output = model.generate(corrupted_ids, max_new_tokens=32, return_dict_in_generate=True, output_logits=True, output_attentions=True)        

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


In [170]:
tokenizer.decode(corrupted_output.sequences[0][:34]),tokenizer.decode(corrupted_output.sequences[0][34:])

('Spanish feminine given surnames\nCategory:Scandinavian feminine given names/*\n * Copyright 2018 IBM Corporation\n *\n * Licensed under the Apache License, Version 2',
 '.0 (the "License");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License')

In [134]:
head_view(first_attentions,tokens,layer=5)

<IPython.core.display.Javascript object>

In [188]:
# Initialize lists to store results
max_attended_tokens = []
attention_scores = []

# Process each generation step
for step_idx, attn in enumerate(batch_output.attentions):
    # Get last layer attention
    last_layer = attn[-1]  # [batch_size, num_heads, seq_len_q, seq_len_k]
    
    # Aggregate across heads: max or mean

    # agg_attn,_ = last_layer.max(dim=1)
    agg_attn = last_layer.mean(dim=1)  # [batch_size, seq_len_q, seq_len_k]
    
    # Get the token with maximum attention for each query
    max_attn_scores, max_attn_indices = agg_attn.max(dim=-1)  # [batch_size, seq_len_q]
    
    # We only care about the last position (new token)
    max_attended_tokens.append(max_attn_indices[:, -1])  # [batch_size]
    attention_scores.append(max_attn_scores[:, -1])  # [batch_size]

# Convert to tensors
max_attended_tokens = t.stack(max_attended_tokens)  # [num_tokens, batch_size]
attention_scores = t.stack(attention_scores)  # [num_tokens, batch_size]


In [189]:
# Print example for first sequence
print(f"Example attention analysis for sequence # {batch_id}:")
generated_tokens = batch_output.sequences[batch_id]  # Skip context
input_tokens = context[batch_id].tolist()

print(f"Length of input tokens: {len(input_tokens)}")
print(f"Length of generated tokens: {len(generated_tokens)}")
print(f"Shape of max_attended_tokens: {max_attended_tokens.shape}")

for i, gen_token in enumerate(generated_tokens[32:]):
    max_attn_idx = max_attended_tokens[i, 0].item()
    attn_score = attention_scores[i, 0].item()
    
    # Add bounds checking
    if max_attn_idx < len(generated_tokens):
        input_token = tokenizer.convert_ids_to_tokens([generated_tokens[max_attn_idx]])
        gen_token_decoded = tokenizer.convert_ids_to_tokens([gen_token])
        print(f"[{32+i}] Generated token '{gen_token_decoded}' attends most to input token '{input_token}'[{max_attn_idx}] (score: {attn_score:.3f})")
    else:
        print(f"Warning: attention index {max_attn_idx} out of range for input length {len(generated_tokens)}")

Example attention analysis for sequence # 28:
Length of input tokens: 32
Length of generated tokens: 64
Shape of max_attended_tokens: torch.Size([32, 32])
[32] Generated token '['.']' attends most to input token '['Ġnames']'[3] (score: 0.240)
[33] Generated token '['0']' attends most to input token '['.']'[32] (score: 0.257)
[34] Generated token '['Ġ(']' attends most to input token '['.']'[32] (score: 0.277)
[35] Generated token '['the']' attends most to input token '['Ġ(']'[34] (score: 0.337)
[36] Generated token '['Ġ"']' attends most to input token '['Ġnames']'[3] (score: 0.442)
[37] Generated token '['License']' attends most to input token '['Ġnames']'[3] (score: 0.528)
[38] Generated token '['");']' attends most to input token '['the']'[35] (score: 0.750)
[39] Generated token '['Ċ']' attends most to input token '['");']'[38] (score: 0.347)
[40] Generated token '['Ġ*']' attends most to input token '['Ġnames']'[3] (score: 0.334)
[41] Generated token '['Ġyou']' attends most to input t

In [17]:
batch_output.attentions[0][-1].size()

torch.Size([32, 8, 32, 32])

In [12]:
len(batch_output.attentions)

32

In [None]:
batch_output.sequences[0,32:] == generation[0]

tensor([  783,   346, 17736,  3287,   187,     4,   368,   778,   417,   897,
          436,  1873,  3707,   275, 10276,   342,   253,  4637,    15,   187,
            4,  1422,   778,  4044,   247,  3491,   273,   253,  4637,   387,
          187,     4], device='mps:0')

In [61]:
top_tokens = []
top_probs = []
for gen_logits in batch_output.logits:
    gen_logits = t.softmax(gen_logits,dim=-1)
    topk = gen_logits.topk(k=5,dim=-1)
    top_tokens.append(topk[1].unsqueeze(1)) # [batch_size, 1, 5]
    top_probs.append(topk[0].unsqueeze(1)) # [batch_size, 1, 5]
top_tokens = t.cat(top_tokens,dim=1) # [batch_size, max_new_tokens, 5]
top_probs = t.cat(top_probs,dim=1) # [batch_size, max_new_tokens, 5]

### Explore dataset

In [10]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [11]:
sample = df_perfect.sample(5)
print("Sample of perfectly memorized sequences:")
for i,row  in sample.iterrows():    
    print("\nExample", i)
    print("Context:", tokenizer.decode(row['context']))
    print("True continuation:", tokenizer.decode(row['true_continuation']))

Sample of perfectly memorized sequences:

Example 19936
Context: /licenses/CDDL+GPL-1.1
# or LICENSE.txt.  See the License for the specific
# language governing permissions and limitations
True continuation:  under the License.
#
# When distributing the software, include this License Header Notice in each
# file and include the License file at LICENSE.txt

Example 4152
Context: {D})$\end{document}$, $\documentclass[12pt]{minimal}
                \usepackage{amsmath}
                \usepackage{wasysym} 
                
True continuation: \usepackage{amsfonts} 
                \usepackage{amssymb} 
                \usepackage{amsbsy}
                \usepackage{mathrsfs}
                \usepackage{upgreek

Example 9532
Context: [12pt]{minimal}
                \usepackage{amsmath}
                \usepackage{wasysym} 
                \usepackage{amsfonts} 
                \usepackage{
True continuation: amssymb} 
                \usepackage{amsbsy}
                \usepackage{mathrsfs}

In [13]:
import filter.pattern_filter as pattern_filter
import importlib
importlib.reload(pattern_filter)
def is_simple_pattern(x):
    tokens = tokenizer.decode(x["context"])
    continuation = tokenizer.decode(x["true_continuation"]) 
    return pattern_filter.is_simple_pattern(tokens,continuation)
df_perfect['to_filter'] = df_perfect.apply(is_simple_pattern,axis=1)
# mem_sample = df_perfect[df_perfect['most_similar_pile_set'] == "Wikipedia (en)"]
# text = ""
# for i,row in mem_sample.iterrows():
#     tokens = tokenizer.decode(row['context'])
#     continuation = tokenizer.decode(row['true_continuation'])
#     text += f"<row_{i}>\n{tokens}\n<completion>{continuation}\n</row_{i}>\n"
# print(text)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_perfect['to_filter'] = df_perfect.apply(is_simple_pattern,axis=1)


In [14]:
ids = [3823,  4213,  4462,  4490,  4565,  4599,  9001,  9068,  9503,  9724,  9844,  9856,  9906,  14050,  14166,  14371,  14533,  14537,  19050,  19439,  19603,  19858,  19891,  20000]
samples = df_perfect[df_perfect.index.isin(ids)]

In [15]:
samples

Unnamed: 0,acc,context,generation,true_continuation,idx,decoded_full,to_filter
3823,1.0,"[29520, 35491, 1677, 4454, 187, 1413, 27, 4316...","[15, 17, 313, 783, 346, 17736, 3287, 187, 475,...","[15, 17, 313, 783, 346, 17736, 3287, 187, 475,...",14556,Spanish feminine given names\nCategory:Scandin...,True
4213,1.0,"[13, 253, 5304, 540, 251, 1050, 16417, 275, 86...","[310, 634, 806, 4143, 13, 320, 2119, 281, 187,...","[310, 634, 806, 4143, 13, 320, 2119, 281, 187,...",144066,", the visual intonational arrays in sign langu...",False
4462,1.0,"[1688, 313, 6448, 898, 4437, 24652, 10, 310, 2...","[187, 1413, 27, 41083, 10782, 187, 1413, 27, 1...","[187, 1413, 27, 41083, 10782, 187, 1413, 27, 1...",235037,hor (born 9 October 1952) is an Iranian alpine...,False
4490,1.0,"[44926, 209, 15765, 6781, 187, 187, 29, 34, 34...","[7956, 2293, 49, 19873, 2637, 568, 20, 3, 7956...","[7956, 2293, 49, 19873, 2637, 568, 20, 3, 7956...",244375,"SUMMARY =========== -->\n\n<A NAME=""field_sum...",False
4565,1.0,"[13, 11432, 6884, 27141, 13, 36481, 13, 253, 4...","[10850, 13, 285, 253, 4454, 13, 15278, 7886, 1...","[10850, 13, 285, 253, 4454, 13, 15278, 7886, 1...",269366,", Major League Baseball, MLB, the silhouetted ...",False
4599,1.0,"[31, 187, 870, 2851, 31, 187, 870, 1206, 31, 1...","[335, 966, 568, 29589, 1138, 187, 29, 965, 966...","[335, 966, 568, 29589, 1138, 187, 29, 965, 966...",278199,>\n</td>\n</tr>\n</table>\n</li>\n</ul>\n<!-- ...,False
9001,1.0,"[7, 6732, 45599, 7, 6732, 13143, 965, 31, 187,...","[1138, 6942, 870, 66, 3073, 965, 31, 187, 870,...","[1138, 6942, 870, 66, 3073, 965, 31, 187, 870,...",36666858,&nbsp;|&nbsp;</li>\n<li>Constr&nbsp;|&nbsp;</l...,False
9068,1.0,"[4437, 25124, 10, 310, 247, 19701, 45125, 505,...","[14316, 952, 187, 1413, 27, 7638, 763, 5086, 2...","[14316, 952, 187, 1413, 27, 7638, 763, 5086, 2...",36685465,October 1951) is a Polish gymnast. He compete...,True
9503,1.0,"[733, 310, 1119, 327, 6984, 356, 4843, 274, 15...","[16169, 187, 4, 187, 4, 21737, 762, 253, 14325...","[16169, 187, 4, 187, 4, 21737, 762, 253, 14325...",36838589,It is found on Madagascar.\n\nReferences\n\nC...,False
9724,1.0,"[6491, 13, 12412, 5720, 13, 44, 300, 76, 10614...","[285, 697, 2330, 11547, 403, 2120, 15299, 2758...","[285, 697, 2330, 11547, 403, 2120, 15299, 2758...",36917139,"People,High Street,KilkennyEmail: editor@kilk...",False


## Attention Heads Activation Frequency

In [19]:
import json
with open("data/mem/contexts.jsonl","w") as f:
    for c in samples.apply(lambda x: tokenizer.decode(x["context"]),axis=1).values:
        f.write(json.dumps(c) + "\n")
with open("data/mem/completions.jsonl","w") as f:
    for c in samples.apply(lambda x: tokenizer.decode(x["true_continuation"]),axis=1).values:
        f.write(json.dumps(c) + "\n")

In [15]:
import json
general_samples =df[df['acc'] == 0]

with open("data/mem/contexts_0mem.jsonl","w") as f:
    for c in general_samples.apply(lambda x: tokenizer.decode(x["context"]),axis=1).values:
        f.write(json.dumps(c) + "\n")
with open("data/mem/completions_gen_0mem.jsonl","w") as f:
    for c in general_samples.apply(lambda x: tokenizer.decode(x["generation"]),axis=1).values:
        f.write(json.dumps(c) + "\n")
with open("data/mem/completions_true_0mem.jsonl","w") as f:
    for c in general_samples.apply(lambda x: tokenizer.decode(x["true_continuation"]),axis=1).values:
        f.write(json.dumps(c) + "\n")

## Find circuits

In [8]:
import torch as t
from auto_circuit.data import PromptDataLoader, PromptDataset
from torch.utils.data import Subset


In [9]:
def to_autocircuit_ds(df:pd.DataFrame,return_seq_length: bool = False,tail_divergence: bool = False, test_size: float = 0.1, batch_size: int | tuple[int, int] = 8):

    # Convert all columns to tensors
    context = t.tensor(df["context"].to_list())
    true_continuation = t.tensor(df["true_continuation"].to_list())
    generation = t.tensor(df["generation"].to_list())

    
    dataset = PromptDataset(clean_prompts=context.to(device), 
                            corrupt_prompts=context.to(device),  # TODO: think about how to corrupt
                            answers=[a[0].unsqueeze(0).to(device) for a in generation], # Take first token of continuation
                            wrong_answers=[a[0].unsqueeze(0).to(device) for a in generation],)   # TODO: why generation is wrong?
     
    dataset_size = len(dataset)
    train_size = int(dataset_size * (1 - test_size))
    train_set = Subset(dataset, list(range(train_size)))
    test_set = Subset(dataset, list(range(train_size, dataset_size)))

    seq_len = None    
    diverge_idx: int = 0
    kvs = []
    if return_seq_length:        
        seq_len = df["context"].shape[1]
    
    if tail_divergence:
        diverge_idxs = (~(df["context"] == df["context"])).int().argmax(dim=1)
        diverge_idx = int(diverge_idxs.min().item())
    if diverge_idx > 0:
        raise NotImplementedError()

    train_loader = PromptDataLoader(
        train_set,
        seq_len=seq_len,
        diverge_idx=diverge_idx,
        kv_cache=kvs[0] if len(kvs) > 0 else None,
        seq_labels=None,
        word_idxs=None,
        batch_size=batch_size[0] if isinstance(batch_size, tuple) else batch_size,
        shuffle=False,
    )
    test_loader = PromptDataLoader(
        test_set,
        seq_len=seq_len,
        diverge_idx=diverge_idx,
        kv_cache=kvs[-1] if len(kvs) > 0 else None,
        seq_labels=None,
        word_idxs=None,
        batch_size=batch_size[1] if isinstance(batch_size, tuple) else batch_size,
        shuffle=False,
    )
    return train_loader, test_loader


In [10]:
train_loader, test_loader = to_autocircuit_ds(df_perfect)

  context = t.tensor(df["context"].to_list())


In [11]:
len(train_loader)

563

In [41]:
from auto_circuit.experiment_utils import load_tl_model
from auto_circuit.types import AblationType, PatchType, PruneScores
from auto_circuit.utils.graph_utils import patchable_model,edge_counts_util

from auto_circuit.utils.tensor_ops import prune_scores_threshold
from auto_circuit.prune_algos.mask_gradient import mask_gradient_prune_scores
from auto_circuit.prune import run_circuits
from auto_circuit.metrics.prune_metrics.kl_div import measure_kl_div
from auto_circuit.metrics.prune_metrics.answer_diff import measure_answer_diff
from auto_circuit.metrics.prune_metrics.answer_value import measure_answer_val,batch_avg_answer_val
from auto_circuit.visualize import draw_seq_graph

def find_circuits(model,
                  train_loader:PromptDataLoader, 
                  test_loader:PromptDataLoader,
                  ablation_type=AblationType.RESAMPLE,
                  patch_type=PatchType.TREE_PATCH):

    prune_scores: PruneScores = mask_gradient_prune_scores(
        model=model,
        dataloader=train_loader,
        official_edges=None,
        grad_function="logit",
        answer_function="avg_diff",
        mask_val=0.0, 
        ablation_type=ablation_type
    )
    edge_count = edge_counts_util(model.edges, prune_scores=prune_scores)

    return prune_scores, run_circuits(model,test_loader,edge_count,prune_scores,ablation_type=ablation_type,patch_type=patch_type)

def find_min_circuit(circuit_accuracy,base_accuracy,eps:float=0.2):
    model_edge_count, base_acc = base_accuracy[0]
    for edge_count, acc in circuit_accuracy:
        if acc and (abs(acc - base_acc) / base_acc) < eps:
            return edge_count, acc
    return model_edge_count, base_acc

In [13]:
tl_model = load_tl_model(MODEL_NAME, device)
p_model = patchable_model(
    tl_model,
    factorized=True,
    slice_output="last_seq",
    separate_qkv=True,
    device=device,
    ignore_tokens=[]
)

Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


In [14]:
path = Path(f"{MODEL_NAME.split('/')[-1]}_mem_prune_scores.pkl")
if path.exists():
    prune_scores: PruneScores = t.load(path)
else:
    prune_scores: PruneScores = mask_gradient_prune_scores(
        model=p_model,
        dataloader=train_loader,
        official_edges=None,
        grad_function="logit",
        answer_function="avg_val",
        mask_val=0.0, 
        ablation_type=AblationType.ZERO
    )

  prune_scores: PruneScores = t.load(path)


In [16]:
edge_count = edge_counts_util(p_model.edges, prune_scores=prune_scores)
outs = run_circuits(p_model,test_loader,edge_count,prune_scores,ablation_type=AblationType.ZERO,patch_type=PatchType.TREE_PATCH)

VBox(children=(          | 0/62 [00:00<?, ?it/s],))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [42]:
threshold = prune_scores_threshold(prune_scores, 100)
threshold

tensor(290.5211, device='mps:0')

In [43]:
fig = draw_seq_graph(
    p_model, prune_scores, threshold.item(), layer_spacing=True, orientation="v"
)

In [34]:
outs.keys()

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 2000, 3000, 3580])

In [44]:
first100 = {k:v for k,v in outs.items() if k <= 100}

In [45]:
val_results = measure_answer_val(p_model,test_loader,first100)

VBox(children=(          | 0/20 [00:00<?, ?it/s],))

In [46]:
val_results

[(0, 12.57073974609375),
 (1, 12.57073974609375),
 (2, 11.126285552978516),
 (3, 12.338752746582031),
 (4, 12.338752746582031),
 (5, 11.980280876159668),
 (6, 11.901838302612305),
 (7, 10.829748153686523),
 (8, 10.829748153686523),
 (9, 9.464014053344727),
 (10, 9.464014053344727),
 (20, 9.890220642089844),
 (30, 10.33684253692627),
 (40, 7.953545093536377),
 (50, 7.934033393859863),
 (60, 7.9774088859558105),
 (70, 8.016780853271484),
 (80, 8.02933406829834),
 (90, 8.076476097106934),
 (100, 8.047252655029297)]

In [20]:
kl_results = measure_kl_div(p_model,test_loader,outs)

VBox(children=(          | 0/32 [00:00<?, ?it/s],))

KeyboardInterrupt: 

In [42]:
prune_scores_base, circuits_out_base = find_circuits(p_model,train_loader,test_loader)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


VBox(children=(          | 0/1 [00:00<?, ?it/s],))

VBox(children=(          | 0/253 [00:00<?, ?it/s],))

: 

In [None]:

# base_accuracy = measure_correct_ans_percent_model(p_model,test_loader)
# circuit_accuracy_base = measure_correct_ans_percent(p_model,test_loader,circuits_out_base)
best_edge_count,best_circuit_accuracy = find_min_circuit(circuit_accuracy_base,base_accuracy)
threshold_base = prune_scores_threshold(prune_scores_base, best_edge_count).detach().item()