In [4]:
from typing import Callable, Protocol, NamedTuple
from collections import Counter

import torch
from jaxtyping import Int, Float, Bool
from transformers import PreTrainedTokenizer, AutoTokenizer
from transformer_lens import HookedTransformer, HookedTransformerConfig

In [5]:
# get the gpt2 tokenizer
TOKENIZER: PreTrainedTokenizer = AutoTokenizer.from_pretrained("gpt2")

# tokens which have quotes

In [6]:
print(f"{len(TOKENIZER)} ")
tokens_with_quotes: list[str] = [
	x 
    for x in TOKENIZER.get_vocab().keys() 
    if ('"' in x)
]
print(f"{len(tokens_with_quotes) = }")

50257 
len(tokens_with_quotes) = 131


In [7]:
print(tokens_with_quotes)

['"},', '!"', '.""', 'Ġ"%', '"},{"', '"-', '"))', '}"', 'Ġ"""', 'Ġ"{', '"),', '"[', '")', '!",', '"', ']"', '"""', '"!', '"/>', '".', '"],"', '],"', '!".', 'Ġ"-', '?"', '"><', '},"', 'Ġ..."', 'Ġ"@', '"âĢ¦', '="/', '""', '."[', '">', '!!"', '),"', 'Ġ"(', '\\">', 'Ġ."', '.")', '.,"', 'Ġ"...', 'Ġ"+', '!?"', 'Ġ".', ',"', '>"', '");', 'âĢĶ"', '.","', '";', 'Ġ{"', '/"', 'Ġ",', '"></', 'Ġ"/', '\'"', '%"', 'Ġ"$', '="#', '"\'', '?!"', '":', ':"', 'Ġ""', '"],', ')"', '":[', 'ĠâĢ¦"', '":["', '="', '?".', '."', '"}],"', 'âĢ¦."', 'Ġ("', '-"', '!\'"', '",', '"},"', '.",', '\\":', 'Ġ"[', ']."', ')."', 'Ġ["', '.\'"', '\'."', '":"', '":{"', '=\\"', '{"', 'Ġ,"', 'Ġ"#', '"}', ')",', '":[{"', '\',"', '":"/', '("', 'Ġ"', '":-', 'Ġ\\"', '"âĢĶ', '["', '").', ';"', '"]=>', '..."', 'Ġ"_', '?",', '"?', '"(', 'Ġ"<', '":"","', '"]', '?\'"', 'Ġ"\\', '","', 'âĢ¦"', '".[', 'Ġ"$:/', ',\'"', 'Ġ"\'', '=""', '"...', '},{"', '\\"', 'Ġ"âĢ¦', '\\",', '":""},{"']


## now, which of these show up in tinystories?

In [8]:
# load a subset of the tinystories dataset
with open("../data/tiny_stories/tinystories_10k.txt", "r", encoding="utf-8") as f:
    TEXT_DATA: list[str] = f.read().split("<|endoftext|>")

In [9]:
# split into tokens but don't convert to ids
text_data_tokenized = [
	TOKENIZER(x).tokens()
	for x in TEXT_DATA
]

from itertools import chain
text_data_tokenized_joined = list(chain(*text_data_tokenized))

In [10]:
print(f"{len(text_data_tokenized) = }")
print(f"{len(text_data_tokenized_joined) = }")

len(text_data_tokenized) = 1728
len(text_data_tokenized_joined) = 349321


In [11]:
# set of tokens with quotes which appear in the dataset


quote_tokens_in_data = Counter([
	x
	for x in text_data_tokenized_joined
	if ('"' in x)
])


# print(f"{quote_tokens_in_data = }")
for x in quote_tokens_in_data:
	print(f"{quote_tokens_in_data[x]}\t`{x}`")

2983	`Ġ"`
1262	`!"`
1340	`."`
702	`?"`
807	`"`
364	`,"`
41	`".`
3	`?!"`
3	`?".`
2	`'."`
2	`?",`
1	`..."`
3	`",`
5	`!".`
1	`Ġ"'`
1	`',"`


# Creating a feature generator

In [12]:
BinaryFeatureExtractorOutput = NamedTuple(
    "FeatureExtractorOutput",
    [
        ("tokens", list[str]),
        ("input_ids", Int[torch.Tensor, "seq_len"]),
        ("features", Bool[torch.Tensor, "seq_len"]),
    ],
)

FeatureExtractor = Callable[[str, PreTrainedTokenizer], BinaryFeatureExtractorOutput]

class FeatureExtractorConfigurable(Protocol):
    def __call__(
        self,
        text: str,
        tokenizer: PreTrainedTokenizer,
        **kwargs
    ) -> BinaryFeatureExtractorOutput:
        ...

In [36]:
def in_quotes_feature(
        text: str, 
        tokenizer: PreTrainedTokenizer,
        qmark_in_quote: bool = False,
    ) -> BinaryFeatureExtractorOutput:
    # Tokenize the text
    tokens: list[str] = tokenizer.tokenize(text)

    # Convert tokens to IDs
    token_ids: list[int] = tokenizer.convert_tokens_to_ids(tokens)

    print(text)
    print(tokens)
    print(token_ids)

    # Compute binary feature for each token (1 if inside a quote, 0 otherwise)
    quote_feature: list[bool] = []
    inside_quote: bool = False # assume we arent in a quote to start
    is_qmark: bool = False

    for token in tokens:
        if "\"" in token:
            is_qmark = True
            inside_quote = not inside_quote
        else:
            is_qmark = False

        quote_feature.append(
            qmark_in_quote # tokens with `"` treated separately 
            if is_qmark 
            else inside_quote # otherwise, use the current state
        ) 

    return BinaryFeatureExtractorOutput(
        tokens=tokens,
        input_ids=torch.tensor(token_ids),
        features=torch.tensor(quote_feature),
    )


in_quotes_no_qmark_feature: FeatureExtractor = lambda text, tokenizer: in_quotes_feature(text, tokenizer, qmark_in_quote=False)
in_quotes_with_qmark_feature: FeatureExtractor = lambda text, tokenizer: in_quotes_feature(text, tokenizer, qmark_in_quote=True)


def display_results(text: str, tokenizer: PreTrainedTokenizer, feature_extractor: FeatureExtractor) -> None:
    tokens, _, quote_feature_tensor = feature_extractor(text, tokenizer)

    print("Input text:")
    for token, feature in zip(tokens, quote_feature_tensor):
        if feature == 1:
            # print(f"\033[43m{token}\033[0m", end=" ")
            # dark blue background
            print(f"\033[44m{token}\033[0m", end=" ")
        else:
            print(token, end=" ")
    print("\n")
    
display_results(TEXT_DATA[8], TOKENIZER, in_quotes_no_qmark_feature)
display_results(TEXT_DATA[8], TOKENIZER, in_quotes_with_qmark_feature)




Sara and Ben wanted to decorate a bowl for their mom. They found a big bowl in the kitchen and some paint and brushes. They took the bowl and the paint to the backyard and put them on a table.
"Let's make the bowl pretty with colors," Sara said.
"OK, I will paint a flower," Ben said.
They started to paint the bowl with different colors. Sara painted a red heart and Ben painted a yellow flower. They were having fun.
But then, it started to rain. The rain was wet and cold. It made the paint run and drip. The bowl looked messy and ugly.
"Oh no, the rain ruined our bowl!" Sara cried.
"Mom will not like it," Ben said.
They ran inside the house with the bowl. They were sad and wet.
They showed the bowl to their mom. They said they were sorry.
But mom smiled and hugged them. She said she loved the bowl and them.
"It's a beautiful bowl," she said. "You made it with love and creativity. The rain made it special. It's like a rainbow bowl."
Sara and Ben felt happy. They gave mom a kiss and tha

In [14]:
MODEL: HookedTransformer = HookedTransformer.from_pretrained("tiny-stories-1M")

  return self.fget.__get__(instance, owner)()


Loaded pretrained model tiny-stories-1M into HookedTransformer


In [15]:
# linear probe setup stuff
class LinearProbe(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearProbe, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)
		

    def forward(self, x):
        return self.linear(x)


In [16]:
from muutils.dictmagic import condense_tensor_dict

In [38]:
def train(
	model: HookedTransformer,
	dataset: list[str],
	probe_fn: FeatureExtractor,
	batch_size: int = 1,
	device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
) -> HookedTransformer:

	if batch_size != 1:
		raise NotImplementedError("batch_size > 1 not supported yet")

	# set up model
	model.eval()
	model.to(device)

	# set up probes
	n_layers: int = model.cfg.n_layers

	probes: list[LinearProbe] = [
		LinearProbe(input_dim=model.cfg.d_model, output_dim=1)
		for _ in range(n_layers)
	]
	for probe in probes:
		probe.to(device)

	# training loop
	for text in dataset:
		print(text)
		# get data
		data: BinaryFeatureExtractorOutput = probe_fn(text, TOKENIZER)

		print(data)
		
		# forward pass, get residuals
		logits, cache = model.run_with_cache(data.input_ids)

		print(condense_tensor_dict(cache, return_format="yaml"))

		residuals: list[torch.Tensor] = [
			cache[f"blocks.{i}.hook_resid_pre"]
			for i in range(n_layers)
		]

		# train each probe
		for i, probe in enumerate(probes):

			print(residuals[i].shape)
			probe_input = residuals[i][0]
			probe_output = probe(probe_input)
			print(probe_output.shape)

		
train(
	model=MODEL,
	dataset=TEXT_DATA[:5],
	batch_size=1,
	probe_fn=in_quotes_no_qmark_feature,
)


Moving model to device:  cuda

Once upon a time there was a little boy named Ben. Ben loved to explore the world around him. He saw many amazing things, like beautiful vases that were on display in a store. One day, Ben was walking through the store when he came across a very special vase. When Ben saw it he was amazed!  
He said, “Wow, that is a really amazing vase! Can I buy it?” 
The shopkeeper smiled and said, “Of course you can. You can take it home and show all your friends how amazing it is!”
So Ben took the vase home and he was so proud of it! He called his friends over and showed them the amazing vase. All his friends thought the vase was beautiful and couldn't believe how lucky Ben was. 
And that's how Ben found an amazing vase in the store!


Once upon a time there was a little boy named Ben. Ben loved to explore the world around him. He saw many amazing things, like beautiful vases that were on display in a store. One day, Ben was walking through the store when he came acro