In [2]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import random
import dataclasses
import jax
import optax

import jax.numpy as jnp
import numpy as np

from matplotlib import pyplot as plt
from tqdm.auto import tqdm, trange
from penzai.data_effects.side_output import SideOutputValue
from micrlhf.utils.activation_manipulation import add_vector
from micrlhf.utils.load_sae import get_sae
from functools import partial

In [4]:
filename = "models/phi-3-16.gguf"
from micrlhf.llama import LlamaTransformer
llama = LlamaTransformer.from_pretrained(filename, device_map="tpu:0")

In [5]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
!git clone https://github.com/roeehendel/icl_task_vectors data/itv
import glob
import json
import os
tasks = {}
for g in glob.glob("data/itv/data/**/*.json"):
    tasks[os.path.basename(g).partition(".")[0]] = json.load(open(g))

  pid, fd = os.forkpty()


fatal: destination path 'data/itv' already exists and is not an empty directory.


In [7]:
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper
get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
get_resids_call = jit_wrapper.Jitted(get_resids)

In [8]:
from typing import List

class ICLSequence:
    '''
    Class to store a single antonym sequence.

    Uses the default template "Q: {x}\nA: {y}" (with separate pairs split by "\n\n").
    '''
    def __init__(self, word_pairs: List[List[str]]):
        self.word_pairs = word_pairs
        self.x, self.y = zip(*word_pairs)

    def __len__(self):
        return len(self.word_pairs)

    def __getitem__(self, idx: int):
        return self.word_pairs[idx]

    # def prompt(self):
    #     '''Returns the prompt, which contains all but the second element in the last word pair.'''
    #     p = "\n\n".join([f"Q: {x}\nA: {y}" for x, y in self.word_pairs])
    #     return p[:-len(self.completion())]

    def prompt(self):
        '''Returns the prompt, which contains all but the second element in the last word pair.'''
        p = ", ".join([f"{x} -> {y}" for x, y in self.word_pairs])
        return p[:-len(self.completion())-1]

    def completion(self):
        '''Returns the second element in the last word pair (with padded space).'''
        return "" + self.y[-1]

    def __str__(self):
        '''Prints a readable string representation of the prompt & completion (indep of template).'''
        return f"{', '.join([f'({x}, {y})' for x, y in self[:-1]])}, {self.x[-1]} ->".strip(", ")


word_list = [["hot", "cold"], ["yes", "no"], ["in", "out"], ["up", "down"]]
seq = ICLSequence(word_list)

print("Tuple-representation of the sequence:")
print(seq)
print("\nActual prompt, which will be fed into the model:")
print(seq.prompt())

Tuple-representation of the sequence:
(hot, cold), (yes, no), (in, out), up ->

Actual prompt, which will be fed into the model:
hot -> cold, yes -> no, in -> out, up ->


In [9]:
class ICLDataset:
    '''
    Dataset to create antonym pair prompts, in ICL task format. We use random seeds for consistency
    between the corrupted and clean datasets.

    Inputs:
        word_pairs:
            list of ICL task, e.g. [["old", "young"], ["top", "bottom"], ...] for the antonym task
        size:
            number of prompts to generate
        n_prepended:
            number of antonym pairs before the single-word ICL task
        bidirectional:
            if True, then we also consider the reversed antonym pairs
        corrupted:
            if True, then the second word in each pair is replaced with a random word
        seed:
            random seed, for consistency & reproducibility
    '''

    def __init__(
        self,
        word_pairs: List[List[str]],
        size: int,
        n_prepended: int,
        bidirectional: bool = True,
        seed: int = 0,
        corrupted: bool = False,
    ):
        assert n_prepended+1 <= len(word_pairs), "Not enough antonym pairs in dataset to create prompt."

        self.word_pairs = word_pairs
        self.word_list = [word for word_pair in word_pairs for word in word_pair]
        self.size = size
        self.n_prepended = n_prepended
        self.bidirectional = bidirectional
        self.corrupted = corrupted
        self.seed = seed

        self.seqs = []
        self.prompts = []
        self.completions = []

        # Generate the dataset (by choosing random antonym pairs, and constructing `ICLSequence` objects)
        for n in range(size):
            np.random.seed(seed + n)
            random_pairs = np.random.choice(len(self.word_pairs), n_prepended+1, replace=False)
            random_orders = np.random.choice([1, -1], n_prepended+1)
            if not(bidirectional): random_orders[:] = 1
            word_pairs = [self.word_pairs[pair][::order] for pair, order in zip(random_pairs, random_orders)]
            if corrupted:
                for i in range(len(word_pairs) - 1):
                    word_pairs[i][1] = np.random.choice(self.word_list)
            seq = ICLSequence(word_pairs)

            self.seqs.append(seq)
            self.prompts.append(seq.prompt())
            self.completions.append(seq.completion())

    def create_corrupted_dataset(self):
        '''Creates a corrupted version of the dataset (with same random seed).'''
        return ICLDataset(self.word_pairs, self.size, self.n_prepended, self.bidirectional, corrupted=True, seed=self.seed)

    def __len__(self):
        return self.size

    def __getitem__(self, idx: int):
        return self.seqs[idx]

In [10]:
def tokenized_to_inputs(input_ids, attention_mask):
    token_array = jnp.asarray(input_ids)
    token_array = jax.device_put(token_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    mask_array = jnp.asarray(attention_mask, dtype=jnp.bool)
    mask_array = jax.device_put(mask_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    mask_array = pz.nx.wrap(mask_array, "batch", "seq").untag("batch").tag("batch")

    inputs = llama.inputs.from_basic_segments(token_array)
    return inputs

In [11]:
prompt = "<user>Follow the pattern\n{}"

In [10]:

target_layer = 18

In [12]:
task_names = [
    "en_es"
]

task = "en_es"

n_seeds = 10

# n_few_shots, batch_size, max_seq_len = 64, 64, 512
n_few_shots, batch_size, max_seq_len = 20, 64, 256

In [14]:

sae = get_sae(target_layer, 6)

--2024-05-21 20:51:37--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l18-test-run-6-1.01E-05/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.95, 108.156.211.90, 108.156.211.125, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.95|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/f057cb46f3d871ba03c66e707e3b3d8299322f36fa433862dc3fdca956715614?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1716583897&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNjU4Mzg5N319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvZjA1N2NiNDZ

In [15]:
dictionary = sae["W_dec"]
dictionary.shape

In [29]:
pairs = tasks[task]
pairs = [list(x) for x in pairs.items()]
seed = 0

dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=seed+1)


tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
inputs = tokenized_to_inputs(
    **tokenized
)

target_tokens = [x[1] for x in tokenizer(dataset.completions)["input_ids"]]
target_tokens = jnp.asarray(target_tokens)

In [13]:
def get_logprob_diff(logits: jnp.ndarray, target_tokens, print_results=False):
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    answer_logprobs = logprobs[:, -1]

    target_logprobs = jnp.take_along_axis(answer_logprobs, target_tokens[:, None], axis=-1).squeeze()

    if print_results:
        print(
            tokenizer.decode(answer_logprobs.argmax(axis=-1))
        )

        print(
            tokenizer.decode(target_tokens)
        )

    return target_logprobs


In [14]:
from micrlhf.llama import LlamaBlock
from functools import partial

def make_get_resids(llama, layer_target):
    get_resids = llama.select().at_instances_of(LlamaBlock).pick_nth_selected(layer_target
                                                                              ).apply(lambda x:
        pz.nn.Sequential([
            pz.de.TellIntermediate.from_config(tag=f"resid_pre"),
            x
        ])
    )
    get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
    return get_resids

In [15]:
import dataclasses
def get_loss(weights, dictionary, inputs, target_tokens, taker, initial_resids):
    weights = jax.nn.relu(weights)
    

    recon = jnp.einsum("fv,f->v", dictionary, weights)
    recon = recon.astype('bfloat16')

    modified = pz.nx.nmap(lambda a, b: a.at[-1].add(b))(
        initial_resids.untag("seq", "embedding"), pz.nx.wrap(recon, "embedding").untag("embedding")
        ).tag("seq", "embedding")

    inputs = dataclasses.replace(inputs, tokens=modified)

    logits = taker(inputs).unwrap("batch", "seq", "vocabulary")

    logprob_diff = get_logprob_diff(logits, target_tokens)
    loss = -logprob_diff.mean()

    return loss + 1e-2 * jnp.linalg.norm(weights, ord=1), ((weights != 0).sum(), loss)

In [16]:
def train_step(weights, opt_state, dictionary, inputs, target_tokens, taker, initial_resids, optimizer, lwg, pos_only=True):
    (loss, (l0, loss_)), grad = lwg(weights, dictionary, inputs, target_tokens, taker, initial_resids)

    updates, opt_state = optimizer.update(grad, opt_state, weights)
    weights = optax.apply_updates(weights, updates)
    # weights_abs = jnp.abs(weights)
    # weights = jnp.sign(weights) * jax.nn.relu(weights_abs - shrinkage)

    return loss, weights, opt_state, dict(l0=l0, loss=loss_)

In [17]:
jittify = lambda x: partial(jax.jit(lambda lr, *args, **kwargs: lr(*args, **kwargs)[1][0].value), x)

In [None]:
layer_source = 17
get_resids_initial = make_get_resids(llama, layer_source)
get_resids_initial = jittify(get_resids_initial)

In [None]:
initial_resids = get_resids_initial(inputs)

In [None]:
layer_target = target_layer

taker = jit_wrapper.Jitted(llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(
    lambda i, x: x if i >= layer_target else pz.nn.Identity()
).select().at_instances_of(pz.nn.EmbeddingLookup).apply(lambda _: pz.nn.Identity())
                .select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0).apply(lambda _: pz.nn.Identity()))

In [35]:


optimizer = optax.chain(
    optax.adam(3e-2),
    optax.zero_nans(),
    
)

lwg = jax.value_and_grad(get_loss, has_aux=True)
shrinkage = 0

# @partial(jax.jit, donate_argnums=(0, 1))

In [39]:
iterations = 2000

weights = jnp.ones(dictionary.shape[0]) * 0.1
opt_state = optimizer.init(weights)

min_loss = 1e9
early_stopping_steps = 100
early_stopping_counter = 0


for _ in (bar := trange(iterations)):
    loss, weights, opt_state, metrics = train_step(weights, opt_state, dictionary, inputs, target_tokens)

    if metrics["loss"] < min_loss:
        min_loss = metrics["loss"]
        early_stopping_counter = 0
    

    tk = jax.lax.top_k(weights, 2)

    bar.set_postfix(loss_optim=loss, **metrics, top=tk[1][0], top_diff=(tk[0][0] - tk[0][1]) / tk[0][0])

    early_stopping_counter += 1
    if early_stopping_counter > early_stopping_steps:
        break

  0%|          | 0/2000 [00:00<?, ?it/s]

In [37]:
_, i = jax.lax.top_k(jnp.abs(weights), 20)
weights[i] / jnp.abs(weights).max(), i

In [None]:
_, i = jax.lax.top_k(jnp.abs(weights), 20)
weights[i] / jnp.abs(weights).max(), i

In [None]:
_, i = jax.lax.top_k(jnp.abs(weights), 20)
weights[i] / jnp.abs(weights).max(), i

In [31]:
def check_feature(feature: int, task:str, target_layer, sae, print_results=False):
    pairs = tasks[task]
    pairs = [list(x) for x in pairs.items()]
    dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=10)


    tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
    inputs = tokenized_to_inputs(
        **tokenized
    )

    target_tokens = [x[1] for x in tokenizer(dataset.completions)["input_ids"]]
    target_tokens = jnp.asarray(target_tokens)

    # weights, recon = grad_pursuit(tv * 2, sae["W_dec"], k, pos_only=True)

    recon = sae["W_dec"][feature] * 20
    recon = recon.astype('bfloat16')

    act_add = add_vector(
        llama, recon, target_layer, scale=1.0, position="last"
    )

    logits = act_add(inputs).unwrap("batch", "seq", "vocabulary")

    # print(logits[:, -1].mean(axis=-1))

    logprobs = jax.nn.log_softmax(logits, axis=-1)
    answer_logprobs = logprobs[:, -1]

    target_logprobs = jnp.take_along_axis(answer_logprobs, target_tokens[:, None], axis=-1).squeeze()

    if print_results:
        print(
            tokenizer.decode(answer_logprobs.argmax(axis=-1))
        )

        print(
            tokenizer.decode(target_tokens)
        )

    return (target_logprobs - answer_logprobs.max(axis=-1)).mean()

In [38]:
for feature in i:
    print(
        f"Feature {feature}: {check_feature(feature)}"
    )

# get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), target_tokens, print_results=True)

Feature 10018: -7.5
Feature 20311: -2.375
Feature 40023: -4.5
Feature 5145: -7.21875
Feature 45145: -7.09375
Feature 2719: -7.09375
Feature 15415: -7.28125
Feature 24802: -7.15625
Feature 20375: -7.125
Feature 19505: -7.09375
Feature 38618: -7.25
Feature 36786: -7.125
Feature 47611: -7.03125
Feature 31726: -7.28125


KeyboardInterrupt: 

In [27]:
check_feature(764, True)

message reduce behavior place se hello cel media club live line throw sport es door shape daughter run pu
 bottom mind mat polit year moment cut hour has sit comment drop un government d wait phone anticip south arm sport material place picture shape papel no en < choice viel fine arrived ii f third pu hang u er must just hello fig
mess rid comport lu sic infer cell media club abit punto get sport così porta col fig corr bel reg super m question polit dec momento tag ora aver sed comment g gi governo pap asp tele asp sud es sport mater lu f col cart no met rela sc pi bene arriv io cin ter bel append fu dire do app infer figura


In [52]:
class FeatureSearch:
    def __init__(self, task, target_layer, early_stopping_steps=50, iterations=2000, seed=9, l1_coeff=1e-2, lr=3e-2, init_w=0.1):
        self.task = task
        self.target_layer = target_layer
        self.sae = get_sae(target_layer, 6)
        self.seed = seed
        self.early_stopping_steps = early_stopping_steps
        self.iterations = iterations
        self.l1_coeff = l1_coeff
        self.lr = lr
        self.init_w = init_w

        pairs = tasks[task]
        pairs = [list(x) for x in pairs.items()]
        
        self.train_dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=self.seed)
        self.eval_dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=self.seed+1)

        self.eval_inputs = self.prepare_inputs(self.eval_dataset)
        self.train_inputs = self.prepare_inputs(self.train_dataset)

        self.initial_resids = self.get_initial_resids(self.train_inputs[0])

        self.lwg = jax.value_and_grad(self.get_loss, has_aux=True)
        self.taker = self.make_taker()
    
    def get_initial_resids(self, inputs):
        get_resids_initial = make_get_resids(llama, self.target_layer)
        get_resids_initial = jittify(get_resids_initial)

        initial_resids = get_resids_initial(inputs)
        return initial_resids

    def make_taker(self):
        taker = jit_wrapper.Jitted(llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(
            lambda i, x: x if i >= self.target_layer else pz.nn.Identity()
        ).select().at_instances_of(pz.nn.EmbeddingLookup).apply(lambda _: pz.nn.Identity())
                        .select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0).apply(lambda _: pz.nn.Identity()))

        return taker

    def prepare_inputs(self, dataset: ICLDataset):
        tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
        inputs = tokenized_to_inputs(
            **tokenized
        )

        target_tokens = [x[1] for x in tokenizer(dataset.completions)["input_ids"]]
        target_tokens = jnp.asarray(target_tokens)

        return inputs, target_tokens        
    
    def get_loss(self, weights):
        weights = jax.nn.relu(weights)

        recon = jnp.einsum("fv,f->v", self.sae["W_dec"], weights)
        recon = recon.astype('bfloat16')

        modified = pz.nx.nmap(lambda a, b: a.at[-1].add(b))(
            self.initial_resids.untag("seq", "embedding"), pz.nx.wrap(recon, "embedding").untag("embedding")
            ).tag("seq", "embedding")

        inputs = dataclasses.replace(self.train_inputs[0], tokens=modified)

        logits = self.taker(inputs).unwrap("batch", "seq", "vocabulary")

        logprob_diff = get_logprob_diff(logits, self.eval_inputs[1])
        loss = -logprob_diff.mean()

        return loss + self.l1_coeff * jnp.linalg.norm(weights, ord=1), ((weights != 0).sum(), loss)

    def train_step(self, weights, opt_state, optimizer):
        (loss, (l0, loss_)), grad = self.lwg(weights)

        updates, opt_state = optimizer.update(grad, opt_state, weights)
        weights = optax.apply_updates(weights, updates)

        return loss, weights, opt_state, dict(l0=l0, loss=loss_)

    def create_optimizer(self):
        optimizer = optax.chain(
            optax.adam(self.lr),
            optax.zero_nans(),
        )

        return optimizer

    def find_weights(self):
        weights = jnp.ones(self.sae["W_dec"].shape[0]) * self.init_w
        optimizer = self.create_optimizer()
        opt_state = optimizer.init(weights)

        min_loss = 1e9
        early_stopping_counter = 0

        for _ in (bar := trange(self.iterations)):
            loss, weights, opt_state, metrics = self.train_step(weights, opt_state, optimizer)

            if metrics["loss"] < min_loss:
                min_loss = metrics["loss"]
                early_stopping_counter = 0

            tk = jax.lax.top_k(weights, 2)

            bar.set_postfix(loss_optim=loss, **metrics, top=tk[1][0], top_diff=(tk[0][0] - tk[0][1]) / tk[0][0])

            early_stopping_counter += 1
            if early_stopping_counter > self.early_stopping_steps:
                break

        return weights, metrics
    
    def check_feature(self, feature, scale):
        steering_vector = self.sae["W_dec"][feature] * scale
        steering_vector =  steering_vector.astype('bfloat16')

        act_add = add_vector(
            llama, steering_vector, self.target_layer, scale=1.0, position="last"
        )

        logits = act_add(self.eval_inputs[0]).unwrap("batch", "seq", "vocabulary")


        logprobs = jax.nn.log_softmax(logits, axis=-1)
        answer_logprobs = logprobs[:, -1]

        target_logprobs = jnp.take_along_axis(answer_logprobs, self.eval_inputs[1][:, None], axis=-1).squeeze()

        return (target_logprobs - answer_logprobs.max(axis=-1)).mean()

    def check_features(self, features, scale):
        losses = jnp.hstack([self.check_feature(feature, scale) for feature in tqdm(features)])

        return features[losses.argmax()], losses.max(), losses.mean(), losses

    

In [44]:
info = """location_continent 21 -1.09375
football_player_position 21 -3.76562
location_religion 21 -1.09375
location_language 20 -0.457031
person_profession 21 -1.30469
location_country 21 -1.44531
country_capital 18 -1.25781
person_language 18 -0.287109
singular_plural 21 -0.261719
present_simple_past_simple 20 -0.341797
antonyms 14 -0.75
plural_singular 20 -0.133789
present_simple_past_perfect 18 -1.25781
present_simple_gerund 20 -0.171875
en_it 18 -1.29688
it_en 14 -1.75781
en_fr 18 -1.39062
en_es 18 -1.21875
fr_en 17 -1.25781
es_en 17 -1.09375
"""

In [45]:
tasks_to_check = [x.split()[:2] for x in info.split("\n") if x]
tasks_to_check = [(x[0], int(x[1])) for x in tasks_to_check]
tasks_to_check

In [46]:
results = {}

for task, target_layer in tqdm(tasks_to_check):
    if target_layer == 21:
        target_layer = 20
    if target_layer == 14:
        target_layer = 12
    fs = FeatureSearch(task, target_layer)
    weights, metrics = fs.find_weights()
    i = jax.lax.top_k(jnp.abs(weights), 20)[1]
    best_feature, best_loss, mean_loss, losses = fs.check_features(i, scale=20)

    results[task] = (best_feature, best_loss, mean_loss, losses, metrics)
    print(
        f"Task {task}, best feature {best_feature}, best loss {best_loss}, mean loss {mean_loss}, metrics {metrics}"
    )

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task location_continent, best feature 47188, best loss -3.59375, mean loss -4.5, metrics {'l0': Array(18, dtype=int32), 'loss': Array(2.42188, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task football_player_position, best feature 26021, best loss -6.90625, mean loss -8.125, metrics {'l0': Array(21, dtype=int32), 'loss': Array(5.0625, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task location_religion, best feature 27071, best loss -5.375, mean loss -5.90625, metrics {'l0': Array(26, dtype=int32), 'loss': Array(1.8125, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task location_language, best feature 177, best loss -2.04688, mean loss -4.53125, metrics {'l0': Array(50, dtype=int32), 'loss': Array(3.5625, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task person_profession, best feature 48700, best loss -7.625, mean loss -8.75, metrics {'l0': Array(38, dtype=int32), 'loss': Array(4.1875, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task location_country, best feature 10728, best loss -2.46875, mean loss -3.0625, metrics {'l0': Array(20, dtype=int32), 'loss': Array(6.75, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task country_capital, best feature 32281, best loss -2.35938, mean loss -3.71875, metrics {'l0': Array(19, dtype=int32), 'loss': Array(6.15625, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task person_language, best feature 9753, best loss -1.28906, mean loss -4.125, metrics {'l0': Array(66, dtype=int32), 'loss': Array(3.125, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task singular_plural, best feature 32291, best loss -0.796875, mean loss -0.953125, metrics {'l0': Array(13, dtype=int32), 'loss': Array(9.875, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task present_simple_past_simple, best feature 17167, best loss -1, mean loss -1.75, metrics {'l0': Array(33, dtype=int32), 'loss': Array(8.375, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task antonyms, best feature 40483, best loss -2.26562, mean loss -3.4375, metrics {'l0': Array(68, dtype=int32), 'loss': Array(6.21875, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task plural_singular, best feature 15284, best loss -0.283203, mean loss -0.574219, metrics {'l0': Array(26, dtype=int32), 'loss': Array(8.0625, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task present_simple_past_perfect, best feature 23282, best loss -1.44531, mean loss -2.04688, metrics {'l0': Array(27, dtype=int32), 'loss': Array(7.40625, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task present_simple_gerund, best feature 49149, best loss -0.542969, mean loss -1.125, metrics {'l0': Array(22, dtype=int32), 'loss': Array(8.25, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task en_it, best feature 46157, best loss -6.4375, mean loss -7.34375, metrics {'l0': Array(29, dtype=int32), 'loss': Array(10.0625, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task it_en, best feature 27586, best loss -3.0625, mean loss -4.25, metrics {'l0': Array(39, dtype=int32), 'loss': Array(7.65625, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task en_fr, best feature 20311, best loss -6.09375, mean loss -7.25, metrics {'l0': Array(47, dtype=int32), 'loss': Array(9.6875, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task en_es, best feature 20311, best loss -2.375, mean loss -6.5, metrics {'l0': Array(26, dtype=int32), 'loss': Array(7.875, dtype=bfloat16)}


--2024-05-21 23:54:55--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l17-test-run-6-4.52E-06/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.125, 108.156.211.90, 108.156.211.95, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.125|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/1623d8da38be3171fcc8516a4cbe9fdb80e3d77e370aa5690895697649d688f3?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1716594895&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNjU5NDg5NX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvMTYyM2Q4ZG

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task fr_en, best feature 41040, best loss -3.35938, mean loss -4.3125, metrics {'l0': Array(101, dtype=int32), 'loss': Array(8.375, dtype=bfloat16)}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Task es_en, best feature 41040, best loss -3.0625, mean loss -4.375, metrics {'l0': Array(117, dtype=int32), 'loss': Array(8.1875, dtype=bfloat16)}


In [50]:
for task, (best_feature, best_loss, mean_loss, losses, metrics) in results.items():
    print(
        f"{task}, best feature {best_feature}, loss ratio {mean_loss / best_loss}"
    )

location_continent, best feature 47188, loss ratio 1.25
football_player_position, best feature 26021, loss ratio 1.17969
location_religion, best feature 27071, loss ratio 1.10156
location_language, best feature 177, loss ratio 2.21875
person_profession, best feature 48700, loss ratio 1.14844
location_country, best feature 10728, loss ratio 1.24219
country_capital, best feature 32281, loss ratio 1.57812
person_language, best feature 9753, loss ratio 3.20312
singular_plural, best feature 32291, loss ratio 1.19531
present_simple_past_simple, best feature 17167, loss ratio 1.75
antonyms, best feature 40483, loss ratio 1.51562
plural_singular, best feature 15284, loss ratio 2.03125
present_simple_past_perfect, best feature 23282, loss ratio 1.41406
present_simple_gerund, best feature 49149, loss ratio 2.07812
en_it, best feature 46157, loss ratio 1.14062
it_en, best feature 27586, loss ratio 1.39062
en_fr, best feature 20311, loss ratio 1.1875
en_es, best feature 20311, loss ratio 2.73438
f

In [25]:
fs = FeatureSearch("en_es", 18)

In [60]:
fs = FeatureSearch("en_fr", 17, early_stopping_steps=100, seed=10)
w = fs.find_weights()

  0%|          | 0/2000 [00:00<?, ?it/s]

In [61]:
jax.lax.top_k(w[0], 20)

In [62]:
fs.check_features(jax.lax.top_k(w[0], 20)[1], scale=20)

  0%|          | 0/20 [00:00<?, ?it/s]

In [69]:
jax.lax.top_k(w, 20)

In [23]:
def find_features(task, target_layer):
    scale = 20
    early_stopping_steps = 50
    init_w = 0.1
    iterations = 2000

    sae = get_sae(target_layer, 6)
    dictionary = sae["W_dec"]

    pairs = tasks[task]
    pairs = [list(x) for x in pairs.items()]
    seed = 0

    dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=seed+1)


    tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
    inputs = tokenized_to_inputs(
        **tokenized
    )

    target_tokens = [x[1] for x in tokenizer(dataset.completions)["input_ids"]]
    target_tokens = jnp.asarray(target_tokens)

    jittify = lambda x: partial(jax.jit(lambda lr, *args, **kwargs: lr(*args, **kwargs)[1][0].value), x)
    get_resids_initial = make_get_resids(llama, target_layer)
    get_resids_initial = jittify(get_resids_initial)

    intital_resids = get_resids_initial(inputs)

    taker = jit_wrapper.Jitted(llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(
        lambda i, x: x if i >= target_layer else pz.nn.Identity()
    ).select().at_instances_of(pz.nn.EmbeddingLookup).apply(lambda _: pz.nn.Identity())
                    .select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0).apply(lambda _: pz.nn.Identity()))

    optimizer = optax.chain(
        optax.adam(3e-2),
        optax.zero_nans(),    
    )

    lwg = jax.value_and_grad(get_loss, has_aux=True)
    

    weights = jnp.ones(dictionary.shape[0]) * init_w
    opt_state = optimizer.init(weights)

    min_loss = 1e9
    early_stopping_counter = 0

    for _ in (bar := trange(iterations)):
        loss, weights, opt_state, metrics = train_step(weights, opt_state, dictionary, inputs, target_tokens, taker, intital_resids, optimizer, lwg)

        if metrics["loss"] < min_loss:
            min_loss = metrics["loss"]
            early_stopping_counter = 0
        

        tk = jax.lax.top_k(weights, 2)

        bar.set_postfix(loss_optim=loss, **metrics, top=tk[1][0], top_diff=(tk[0][0] - tk[0][1]) / tk[0][0])

        early_stopping_counter += 1
        if early_stopping_counter > early_stopping_steps:
            break

    
    return jax.lax.top_k(jnp.abs(weights), 20)


In [27]:
find_features(
    "en_es", 18
)

  0%|          | 0/2000 [00:00<?, ?it/s]

In [34]:
sae = get_sae(target_layer, 6)

In [35]:
check_feature(20311, "en_es", 18, sae)