In [1]:
%load_ext autoreload
%autoreload 2
from saex.iterable_dataset import IterableDatasetConfig
from saex.models.micrlhf_model import MicrlhfModelConfig
from saex.haver import ModelHaver, SAEHaver
from saex.sae import SAEConfig
from more_itertools import chunked


n_features = 3072
batch_size = 64
layer = 12
dataset_config = IterableDatasetConfig(
    dataset_name="nev/openhermes-2.5-phi-format-text",
    # dataset_name="nev/generated-phi-format-text",
)
model_config = MicrlhfModelConfig(
    tokenizer_path="microsoft/Phi-3-mini-4k-instruct",
    gguf_path="../weights/phi-3-16.gguf",
    device_map="auto",
    use_flash=False,
    layer=layer,
    max_seq_len=128,
)
sae_config = SAEConfig(
    n_dimensions=n_features,
    batch_size=batch_size,
    expansion_factor=32,
    use_encoder_bias=True,
    remove_decoder_bias=False,
    encoder_init_method="orthogonal",
    decoder_init_method="pseudoinverse",
    decoder_bias_init_method="zeros",
    is_gated=False,
)
haver = ModelHaver(model_config=model_config, dataset_config=dataset_config)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading model...
Loading dataset...


In [2]:
sae_path = f"../weights/phi-l{layer}.safetensors"
!wget 'https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l12-test-run-5-7.00E-06/sae_weights.safetensors?download=true' -O {sae_path}
haver_sae = SAEHaver(
    sae_config=sae_config,
    mesh=haver.mesh,
    sae_restore=sae_path)

  pid, fd = os.forkpty()


--2024-05-20 03:17:31--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l12-test-run-5-7.00E-06/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.95, 108.156.211.125, 108.156.211.51, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.95|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/a8414dcfc8cc29b6b9c7b3e02f39b6d319fcf89f9a08a687ae488a5407a2fbf5?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1716434252&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNjQzNDI1Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvYTg0MTRkY2Z

In [63]:
from collections import defaultdict
from tqdm import tqdm
import jax.numpy as jnp
import numpy as np
import equinox as eqx
import jax


max_l0 = 512

# @jax.jit
def get_nonzeros(hiddens):
    val, ind = jax.lax.top_k(jnp.abs(hiddens), max_l0)
    nonzeros = jnp.where(val != 0, ind, -1)
    return nonzeros


def process_tokens():
    tokens_processed = 0
    activ_cache = defaultdict(list)
    jsae = eqx.filter_jit(lambda s, x: s.encode(x)[1])
    cpu_device = jax.devices("cpu")[0]
    to_cpu = lambda x: jax.device_put(x, cpu_device)
    try:
        for texts in chunked(bar := tqdm(haver.create_dataset()), batch_size):
            activations, model_misc = haver.model(texts)
            mask = model_misc.get("mask")
            # # it's like sae but jit
            # hiddens = jsae(haver_sae.sae, activations)
            pre_relu, hiddens = haver_sae.sae.encode(activations)
            
            # loss, loss_reconstructed = haver.model.eval_loss(texts, haver_sae.sae)
            # bar.set_postfix(l0=((hiddens != 0).sum(-1) * mask).mean() / mask.mean(),
            #                 loss_diff=loss_reconstructed - loss)

            # indices = jnp.arange(len(mask))  # + tokens_processed
            # for feat in (hiddens != 0).any(axis=0).nonzero()[0]:
            #     greats = hiddens[:, feat]
            #     active = jnp.nonzero((greats != 0) & mask)[0]
            #     index = to_cpu(indices[active])
            #     activations = to_cpu(greats[active])
            #     # activ_cache[int(feat)].extend(zip(list(index), list(activations)))

            # for feat in (hiddens != 0).any(axis=0).nonzero()[0]:
            #     greats = hiddens[:, feat]
            #     activ_cache[int(feat)].extend(zip(list(indices[mask]), list(greats[mask])))

            indices = jnp.nonzero(mask)[0]
            hiddens = hiddens[indices]
            nonzeros = get_nonzeros(hiddens)
            nonzeros = np.asarray(to_cpu(nonzeros.astype(jnp.int32)))
            hiddens = np.asarray(to_cpu(hiddens.astype(jnp.float16)))
            # mask = np.asarray(to_cpu(mask))
            indices = np.asarray(to_cpu(indices))
            # for i, h in zip(list(jnp.arange(len(hiddens))[mask]), list(hiddens[mask])):
            #     active_features = np.nonzero(h)[0]
            #     feature_activations = h[active_features]
            #     for f, a in zip(active_features, feature_activations):
            #         activ_cache[int(f)].append((tokens_processed + i, float(a)))


            for i, active_features, h in zip(indices, nonzeros, hiddens):
                active_features = active_features[active_features != -1]
                feature_activations = h[active_features]
                for f, a in zip(active_features, feature_activations):
                    activ_cache[int(f)].append((i + tokens_processed, float(a)))
            tokens_processed += len(mask)
            assert i < len(mask)
            bar.set_postfix(tokens_processed=tokens_processed, tps=tokens_processed / bar.format_dict["elapsed"])
    except KeyboardInterrupt:
        pass
    return activ_cache
activ_cache = process_tokens()

0it [00:00, ?it/s]

383it [00:27, 13.99it/s, tokens_processed=32768, tps=1.32e+3]


In [67]:
import pyarrow as pa
import numpy as np

batches = []
for feat, activs in tqdm(activ_cache.items()):
    batches.append(pa.RecordBatch.from_pylist([dict(feature=feat, token=i, activation=np.float16(a))
                                               for i, a in activs], schema=pa.schema([
        ("feature", pa.int32()),
        ("token", pa.int32()),
        ("activation", pa.float16()),
    ])))

    # # 60% more efficient compression scheme, not guaranteed to work
    # token_0 = activs[0][0]
    # a_0 = activs[0][1]
    # schema = pa.schema([
    #     ("feature", pa.int32()),
    #     ("token", pa.uint16()),
    #     ("activation", pa.float16()),
    # ])
    # batches.append(pa.RecordBatch.from_pylist([dict(feature=feat, token=token_0, activation=np.float16(a_0))], schema=schema))
    # batches.append(pa.RecordBatch.from_pylist([dict(feature=feat, token=i2 - i1, activation=np.float16(a2))
    #                                            for (i1, a1), (i2, a2) in zip(activs[:-1], activs[1:])], schema=schema))

100%|██████████| 42633/42633 [00:15<00:00, 2826.09it/s] 


In [68]:
import pyarrow.parquet as pq
import pyarrow as pa
pq_path = f"../weights/phi-l{layer}-activations.parquet"
table = pa.Table.from_batches(batches)
pq.write_table(table, pq_path, compression="snappy")

In [None]:
def visualize(feature, thresh=6.0):
    cache = activ_cache[feature]
    if not cache:    
        return
    tokens, activs = zip(*cache)
    if max(activs) < thresh:
        return
    freq = len(tokens) / tokens_processed
    print(freq)
    if freq > 0.03:
        return
    tokens_viewed = 0
    sli = 24
    for texts in chunked(tqdm(haver.create_dataset()), batch_size):
        toks = haver.model.to_tokens(texts)
        all_tokens = [t for tok in toks for t in tok]
        proc = sum(map(len, toks))
        all_token_ids = [tokens_viewed + i for i in range(proc)]
        for i, t in enumerate(all_token_ids):
            if t in tokens:
                activ = activs[tokens.index(t)]
                if activ < thresh:
                    continue
                print(activ, repr(haver.model.decode(all_tokens[max(0, i - sli + 1):i+1])),
                      repr(haver.model.decode(all_tokens[i+1:i+5])))
        tokens_viewed += proc
        if tokens_viewed > max(tokens):
            break
for i in range(10_000, 10_0100):
    visualize(i)

8.138020833333333e-05


0it [00:00, ?it/s]

581it [00:01, 580.69it/s]

11.90625 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><s><|assistant|> a). No;<|end|><s><|assistant|> Stream of conscious' 'ness: First find'
6.8046875 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><s><|assistant|> a). No;<|end|><s><|assistant|> Stream of consciousness' ': First find the'


1151it [00:02, 534.58it/s]


10.8125 '<s><|assistant|> Question: Is it possible for a person to survive without ever drinking water?\n\nStream of conscious' 'ness reasoning: When'
7.23828125 '<|assistant|> Question: Is it possible for a person to survive without ever drinking water?\n\nStream of consciousness' 'reasoning: When considering'
0.0007527669270833334


201it [00:00, 409.81it/s]

6.08984375 'using the Quicksort algorithm. Quicksort is a divide and conquer algorithm with an average time complexity of O(n log' 'n). It is'


576it [00:01, 552.21it/s]

18.921875 '\n        is_prime = True\n        for divisor in range(2, int(num**0.5' ')<s><|assistant|> To'


704it [00:01, 576.99it/s]

6.69140625 'equation, you can use the discriminant formula. The discriminant is given as b^2 - 4' 'ac. \n'


852it [00:01, 649.17it/s]

8.578125 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><s><|assistant|> The root' 'that does not arise'


1123it [00:02, 666.43it/s]

6.875 "\nCori's current age: 3 years\nCori<s><|assistant|> I recall that the sum of the roots" 'of a quadratic equation'


1340it [00:02, 664.49it/s]

11.1015625 'of x.\n\nFirst, notice that the numerator (4x^2 -<s><|assistant|> To simplify the radical' 'expression $\\sqrt{'
27.59375 '\n\nFirst, notice that the numerator (4x^2 -<s><|assistant|> To simplify the radical expression $\\sqrt' '{27}$,'
27.78125 'largest perfect square that divides 27 is 9. We can rewrite the expression as:\n\n$\\sqrt' '{27}'
26.046875 '27 is 9. We can rewrite the expression as:\n\n$\\sqrt{27} = \\sqrt' '{9 \\times'


1535it [00:02, 549.29it/s]


6.9921875 '- 3 = 0, we can use the quadratic formula:\n\nx = (-b ± √' '(b^2'
0.0018208821614583333


384it [00:00, 509.56it/s]

20.671875 '3/5774375) answer. The problem can be fixed by adding `?parseTime=' 'true` to the'


576it [00:01, 542.75it/s]

6.85546875 'the initial idea might seem uncomp<s><|assistant|> The issue you are facing is because you have set the packaging type to' '`pom` in'
7.20703125 'initial idea might seem uncomp<s><|assistant|> The issue you are facing is because you have set the packaging type to `' 'pom` in your'
6.234375 'osleep`:\n\n```\nsys_nanosleep: eax = 162, ebx =' 'struct timespec *'


704it [00:01, 557.44it/s]

7.26953125 "HTTP server that provides standard GET and HEAD request handlers. To add custom headers like 'Access-Control-Allow-" "Origin', you would"
6.38671875 "server that provides standard GET and HEAD request handlers. To add custom headers like 'Access-Control-Allow-Origin" "', you would need"


832it [00:01, 551.12it/s]

11.859375 'trend line model, you can use the following shell command:\n```\nawk \'BEGIN {OFS="' '\\t"} {'
7.83984375 'end line model, you can use the following shell command:\n```\nawk \'BEGIN {OFS="\\' 't"} {print'
7.546875 ', including pink and blue, red roses are the most<s><|assistant|> It seems that the error "multipart:' 'NextPart: buf'


1024it [00:02, 550.48it/s]

6.328125 '\'s an example of an HTML/CSS form that you can use:\n\n```html\n<form action="' '/" method="post'
18.734375 'example of an HTML/CSS form that you can use:\n\n```html\n<form action="/" method="' 'post">\n '


1152it [00:02, 563.15it/s]

6.98828125 'R program that can help you achieve that:\n\n```R\n# Set the random seed\nset.seed(' '123)'


1280it [00:02, 549.94it/s]

6.33203125 'SQL query:\n\n```SQL\nSELECT name, salary, job_title\nFROM employees\nWHERE department =' "'IT'\n"
7.05078125 "salary, job_title\nFROM employees\nWHERE department = 'IT'\nAND years_experience >" '10\n'
7.1640625 "namespace :test do \n    task :reset do \n      ActiveRecord::Base.establish_connection('" 'test<s><|assistant|> The'
7.08984375 's.split<s><|assistant|> **\n\nThe issue here is that when you define `has_many :managers,' 'through: :list'
12.84375 'split<s><|assistant|> **\n\nThe issue here is that when you define `has_many :managers, through:' ':listing_'


1599it [00:03, 514.45it/s]


0.005889892578125


320it [00:01, 403.92it/s]

6.53125 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><s><|assistant|> This should be the solution: Consulting a loaf of bread is' 'not a logical or'
7.53515625 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><s><|assistant|> This should be the solution: Consulting a loaf of bread is not' 'a logical or practical'
6.0546875 '<|endoftext|><|endoftext|><|endoftext|><s><|assistant|> This should be the solution: Consulting a loaf of bread is not a logical or' 'practical method for ens'
7.7109375 '<|endoftext|><s><|assistant|> This should be the solution: Consulting a loaf of bread is not a logical or practical method' 'for ensuring correct'
8.1015625 "for ensuring correctness in a task. The answer is Sentence B.<|end|><s><|assistant|> Alright, let'" 's break it down'


448it [00:01, 431.96it/s]

13.8046875 'A: "A basketball team has five players."\n- Sentence B: "A football team has five players."' '\n\nAnswer:'


704it [00:01, 458.48it/s]

6.6015625 "to have on hand for various situations: a rope or a woman's bathing suit?\n\nAnswer:" 'A rope is'


960it [00:02, 437.50it/s]

6.19140625 'each flight of stairs has 10 / 1.5 = 6.67 steps.\nSince' 'John climbed up'


1088it [00:02, 422.69it/s]

9.4296875 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><s><|assistant|> Yes, the statement "serious joke" is an oxymoron because' '"serious"'
7.15234375 '<|endoftext|><s><|assistant|> Yes, the statement "serious joke" is an oxymoron because "serious"' 'and "joke'
6.0625 'the statement "serious joke" is an oxymoron because "serious" and "joke"' 'have opposite meanings'
7.39453125 '"Serious" suggests something important or grave, while "joke" suggests something humorous or lighthearted.' '<|end|><s><|assistant|> The'
7.1875 'the right answer to the question "molecules of _ initiate protein synthesis" is "nektar,"' 'given that molecules'
6.4609375 'answer to the question "molecules of _ initiate protein synthesis" is "nektar," given that' 'molecules of t'
6.84375 'relation to the information given:\n\n1. Ignore: Since he found the email funny, it is unlikely' 'that he would simply'


1216it [00:03, 426.70it/s]

9.1328125 'can achieve this:\n<s><|assistant|> Question: Is it possible for a person to survive without ever drinking water?' '\n\nStream of'
7.46875 'to survive without ever drinking water?\n\nStream of consciousness reasoning: When considering human survival, it' 'is essential to take'


1408it [00:03, 437.92it/s]

9.46875 'M^2<s><|assistant|> Question: Can a guy perform mirror ball sport while stuffing himself into a cannon?' '\n\nImplicit'
6.74609375 'Can a guy perform mirror ball sport while stuffing himself into a cannon?\n\nImplicit rationale:' 'The scenario described seems'
10.484375 'a guy perform mirror ball sport while stuffing himself into a cannon?\n\nImplicit rationale: The' 'scenario described seems imp'
11.390625 'guy perform mirror ball sport while stuffing himself into a cannon?\n\nImplicit rationale: The scenario' 'described seems implaus'
8.1640625 'y perform mirror ball sport while stuffing himself into a cannon?\n\nImplicit rationale: The scenario described' 'seems implausible'
10.6640625 'perform mirror ball sport while stuffing himself into a cannon?\n\nImplicit rationale: The scenario described seems' 'implausible and'
8.765625 'mirror ball sport while stuffing himself into a cannon?\n\nImplicit rationale: The scenario described seems imp' 'lausible and dangerous'
12.851562

1472it [00:03, 431.27it/s]

7.36328125 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><s><|assistant|> The idea that people swallow spiders during sleep is a myth. Spiders are unlikely' 'to intentionally craw'
7.30078125 '<|endoftext|><|endoftext|><|endoftext|><s><|assistant|> The idea that people swallow spiders during sleep is a myth. Spiders are unlikely to' 'intentionally crawl'
7.3203125 "spiders during sleep is a myth. Spiders are unlikely to intentionally crawl into a person's mouth," 'and most individuals would'
10.15625 "iders during sleep is a myth. Spiders are unlikely to intentionally crawl into a person's mouth, and" 'most individuals would w'
6.421875 "is a myth. Spiders are unlikely to intentionally crawl into a person's mouth, and most individuals would" 'wake up if'
9.03125 'In the sentence "A lage Coco Cola sign sitting in a parking lot.", the word "lage"' 'could potential be a'
6.07421875 'sitting in a parking lot.", the word "lage" could potential be a typo.<s><|assistant|> Sure! Here\''

1599it [00:03, 406.67it/s]


0.0005289713541666666


1024it [00:04, 347.26it/s]

23.0 'that you can use:\n\n```html\n<form action="/" method="post">\n  <label for' '="email">Email'
16.34375 'post">\n  <label for="email">Email:</label>\n  <input type="text" name' '="email" id'
29.265625 '<label for="email">Email:</label>\n  <input type="text" name="email" id' '="email<s><|assistant|>'
11.0546875 'label for="email">Email:</label>\n  <input type="text" name="email" id="' 'email<s><|assistant|> First'
9.8359375 'the `FormInfo` struct. The fields in your form, `name="fields[0]"` and `name' '="fields[1'


1407it [00:05, 280.00it/s]


KeyboardInterrupt: 