## Dependencies

In [1]:
from IPython import get_ipython

# `get_ipython()` returns None when no interactive shell is registered (for
# example when this notebook is exported and executed as a plain script), so
# the magics are only issued when a shell is actually available.
ipython = get_ipython()
if ipython is not None:
    # Same as running `%load_ext autoreload` and then `%autoreload 2`, so edits
    # to the local helper modules are picked up without restarting the kernel.
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

In [2]:
from transformer_lens import HookedTransformer, utils
import torch
from datasets import load_dataset
import time
import os
from typing import Optional, List, Dict, Callable, Tuple, Union
import tqdm.notebook as tqdm
from pathlib import Path
import pickle
import plotly.express as px
import importlib
import json

import data_fns, model_fns, html_fns
from model_fns import AutoEncoderConfig, AutoEncoder
from data_fns import get_feature_data, FeatureData

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Disable gradient tracking globally for every cell that runs after this one.
torch.set_grad_enabled(False)

def imshow(x, **kwargs):
    """Display `x` as a Plotly image plot.

    `x` is first passed through `utils.to_numpy`; any extra keyword arguments
    are forwarded unchanged to `plotly.express.imshow`. The figure is shown
    immediately and nothing is returned.
    """
    figure = px.imshow(utils.to_numpy(x), **kwargs)
    figure.show()

# Setup

In [3]:
# Default autoencoder configuration; its dtype is also applied to the language model below.
cfg = AutoEncoderConfig()

# Load the pretrained "gelu-1l" model, then cast it to the configured dtype and
# move it onto the selected device (in that order).
model: HookedTransformer = HookedTransformer.from_pretrained("gelu-1l")
model = model.to(cfg.dtype)
model = model.to(device)

Loaded pretrained model gelu-1l into HookedTransformer
Changing model dtype to torch.float32
Moving model to device:  cuda


In [4]:
# Training split of the `NeelNanda/c4-code-20k` dataset.
data = load_dataset("NeelNanda/c4-code-20k", split="train")

# Tokenize with the model's tokenizer (max_length=128) and shuffle with a fixed
# seed so the same rows are selected on every run.
tokenized_data = utils.tokenize_and_concatenate(
    data,
    model.tokenizer,
    max_length=128,
).shuffle(seed=42)

all_tokens = tokenized_data["tokens"]

  table = cls._concat_blocks(blocks, axis=0)


In [5]:
# Load the two autoencoder checkpoints ("run1" first, then "run2"); both are
# passed to `get_feature_data` later on.
encoder, encoder_B = (
    AutoEncoder.load_from_hf(version=run_name) for run_name in ("run1", "run2")
)

In [6]:
# Build and save the vocab dict, which maps token id -> display string.
# In memory the keys are whatever ids the tokenizer provides; once written to
# JSON they are strings, because JSON object keys are always strings.

token_to_id = model.tokenizer.vocab
# "Ġ" is shown as a plain space and literal newlines are shown escaped.
# NOTE(review): byte-level BPE vocabularies often store newline as "Ċ" rather
# than a literal "\n" — confirm the second replace actually has an effect.
vocab_dict = {
    token_id: token.replace("Ġ", " ").replace("\n", "\\n")
    for token, token_id in token_to_id.items()
}

vocab_dict_filepath = Path(os.getcwd()) / "app/vocab_dict.json"
if not vocab_dict_filepath.exists():
    # Create `app/` if it is missing so a fresh checkout does not fail with
    # FileNotFoundError. An existing file is deliberately left untouched.
    vocab_dict_filepath.parent.mkdir(parents=True, exist_ok=True)
    with open(vocab_dict_filepath, "w") as f:
        json.dump(vocab_dict, f)

# Create visualisations

In [9]:
# Reload the local helper modules so recent edits take effect, then re-bind the
# names imported from `data_fns` to the reloaded versions.
importlib.reload(data_fns)
importlib.reload(html_fns)
from data_fns import get_feature_data, FeatureData

# Performance note: there is probably not much more time to squeeze out of this.
# The most promising remaining saving would be parallelizing the whole sequence
# indexing step, which may not be worth the effort right now.

# Settings for this run. Alternative (larger) settings kept for reference:
#   max_batch_size = 512
#   total_batch_size = 16384
#   feature_idx = list(range(1000))
feature_idx = list(range(1000))
total_batch_size = 512
max_batch_size = 64

# Only the first `total_batch_size` rows of the shuffled tokens are used.
tokens = all_tokens[:total_batch_size]

feature_data: Dict[int, FeatureData] = get_feature_data(
    encoder=encoder,
    encoder_B=encoder_B,
    model=model,
    tokens=tokens,
    feature_idx=feature_idx,
    max_batch_size=max_batch_size,
    left_hand_k=3,
    buffer=(5, 5),
    n_groups=10,
    first_group_size=20,
    other_groups_size=5,
    verbose=True,
)

                                                               

### Visualise data

In this cell, I just save the full file (scripts and non-scripts), so I can see it in my browser.

In [None]:
# Feature whose complete HTML page (scripts included) is written out for inspection.
test_idx = 7

html_str = feature_data[test_idx].get_all_html()

# Write as UTF-8 explicitly: the platform default encoding (e.g. cp1252 on
# Windows) cannot represent every character that may appear in token text and
# would raise UnicodeEncodeError.
with open(f"data_{test_idx:04}.html", "w", encoding="utf-8") as f:
    f.write(html_str)

In this cell, I save the HTML and JavaScript to a path in my blog, so I can quickly see how it looks.

In [None]:
# Preferred output folder: a machine-specific checkout of the blog. On any other
# machine, fall back to a `test_for_blog` folder in the current working
# directory, which must already exist.
test_blog_filepath = Path("C:/Users/calsm/Documents/Blog/blog_website_heroku/python-getting-started/blog/static/sae/test_html_and_scripts")
if not test_blog_filepath.exists():
    test_blog_filepath = Path.cwd() / "test_for_blog"
    assert test_blog_filepath.exists(), f"Output folder not found: {test_blog_filepath}"

test_idx = 7

# With split_scripts=True the scripts and the remaining HTML come back as two
# separate strings, which are saved to two separate files.
scripts, html_string = feature_data[test_idx].get_all_html(split_scripts=True)

# Explicit UTF-8 so the write does not depend on the platform default encoding
# (e.g. cp1252 on Windows, which cannot encode every character).
with open(test_blog_filepath / "html_str.html", "w", encoding="utf-8") as f:
    f.write(html_string)
with open(test_blog_filepath / "scripts.html", "w", encoding="utf-8") as f:
    f.write(scripts)

### Save data, so it can be used in the app

In [None]:
# for feature_idx in range(100):
#     feature_data[feature_idx].save()

In [None]:
def save_data_in_batches(
    feature_data: Dict[int, FeatureData],
    root: str = "data",
    batch_size: int = 50,
    save_type: str = "pkl",
):
    """Save `feature_data` under `app/<root>/` in files of at most `batch_size` features.

    Each batch is written with `FeatureData.save_batch`. The filename passed to
    it is the resolved path `app/<root>/data_<min>-<max>` (or `data_<idx>` when
    the batch holds a single feature), with the indices zero-padded to four
    digits and no file extension added here.

    Args:
        feature_data: Mapping from feature index to its `FeatureData` object.
        root: Name of the sub-folder of `app/` (relative to the current working
            directory) to write into. Created if it does not exist.
        batch_size: Maximum number of features per saved file. Must be positive.
        save_type: Forwarded unchanged to `FeatureData.save_batch`.

    Raises:
        ValueError: If `batch_size` is not positive.
    """
    # A negative step would make the loop below silently save nothing.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    data_path = Path(os.getcwd()) / f"app/{root}"
    data_path.mkdir(parents=True, exist_ok=True)

    # Sort the keys so every file holds a contiguous run of the sorted feature
    # indices; the min-max range in the filename then describes its contents.
    feature_idx = sorted(feature_data.keys())

    for start in tqdm.tqdm(range(0, len(feature_idx), batch_size)):

        # Get the next features which will be saved
        next_feature_idx = feature_idx[start : start + batch_size]

        # Get filename (this is different if we just save one at a time)
        min_feat, max_feat = next_feature_idx[0], next_feature_idx[-1]
        stem = f"data_{min_feat:04}" if min_feat == max_feat else f"data_{min_feat:04}-{max_feat:04}"
        filename = str((data_path / stem).resolve())

        # Save the batch of FeatureData objects
        FeatureData.save_batch(
            {k: feature_data[k] for k in next_feature_idx},
            filename=filename,
            save_type=save_type,
        )


# Save every entry of `feature_data` under app/data, 25 features per file
# (this overrides the function's default batch_size of 50).
save_data_in_batches(feature_data, root="data", batch_size=25, save_type="pkl")

Check I can recover the data correctly:

In [None]:
# The save cell is called with batch_size=25, so (assuming the feature keys
# start at 0 and are contiguous) the first saved file covers features 0-24.
file_path = Path(os.getcwd()) / "app/data/data_0000-0024.pkl"
assert file_path.exists(), f"Saved batch not found: {file_path} (run the save cell first)"
# `load_batch` is given the path without its extension. `with_suffix("")`
# removes only the final ".pkl", never a ".pkl" occurring elsewhere in the path.
file_path = str(file_path.resolve().with_suffix(""))

# NOTE(review): save_type="pkl" suggests pickle is used under the hood; only
# load files that were written by this notebook.
feature_data_recovered = FeatureData.load_batch(file_path, save_type="pkl", vocab_dict=vocab_dict)

test_idx = 7

html_str = feature_data_recovered[test_idx].get_all_html()

# Explicit UTF-8 so the write does not depend on the platform default encoding.
with open(f"data_{test_idx:04}.html", "w", encoding="utf-8") as f:
    f.write(html_str)

Let's be clear on exactly what the formatting in these visualisations means:

* Highlight means that it fires strongly on this sequence position.
    * For example in #7, we can see it fires strongly in the presence of `" I"` or `" we"`.
* Underline means that it reduces loss significantly when it comes to predicting this token.
    * This is why you often see strong underlines on the token *after* a large activation. #7 isn't a good example, but #8 is: it fires on brackets, and this boosts `" Django"` a lot so loss goes down a lot.
    * Why does it boost `" Django"`? See the following GPT-4 answer:
    * <img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/djangooo.png" width="700">
* When you hover, you can see what the feature boosts when it comes to predicting the next sequence position.
    * For example in #7, hover over `" I"` or `" we"` and you'll see that this feature boosts the probability of `"'ll"` (which is obviously a common completion in both cases).
    * <img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/exa1.png" width="400">



