In [1]:

%load_ext autoreload
%autoreload 2


import penzai
from penzai import pz
import os

# Re-root the working directory: when no entry named "models" is visible from
# the current directory, move two levels up (presumably the repository root,
# where "models/..." relative paths resolve).
if not any(entry == "models" for entry in os.listdir(".")):
    os.chdir("../..")

import json

from matplotlib import pyplot as plt
from tqdm.auto import tqdm, trange
import jax.numpy as jnp
import numpy as np
import random
from penzai.data_effects.side_output import SideOutputValue
from micrlhf.utils.activation_manipulation import add_vector






In [18]:
import jax
from transformers import AutoTokenizer

from micrlhf.llama import LlamaBlock, LlamaTransformer
from micrlhf.sampling import sample, jit_wrapper
from sprint.task_vector_utils import load_tasks, ICLDataset, ICLSequence

# Gemma-2B-IT weights from a local GGUF file, loaded eagerly.
llama = LlamaTransformer.from_pretrained("models/gemma-2b-it.gguf", from_type="gemma", load_eager=True)

# Right padding: pad tokens trail the real tokens of each prompt.
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b")
tokenizer.padding_side = "right"

tasks = load_tasks()

# Instrumented copy of the model: every LlamaBlock is prefixed with a
# TellIntermediate reporting its input under the tag "resid_pre_<index>",
# and those side outputs are collected alongside the model's output.
get_resids = (
    llama.select()
    .at_instances_of(LlamaBlock)
    .apply_with_selected_index(
        lambda idx, block: pz.nn.Sequential([
            pz.de.TellIntermediate.from_config(tag=f"resid_pre_{idx}"),
            block,
        ])
    )
)
get_resids = pz.de.CollectingSideOutputs.handling(
    get_resids,
    tag_predicate=lambda tag: tag.startswith("resid_pre"),
)
get_resids_call = jit_wrapper.Jitted(get_resids)



def tokenized_to_inputs(input_ids, attention_mask):
    """Convert tokenizer output into inputs for the global ``llama`` model.

    The token ids are placed on ``llama.mesh`` with a ("dp", "sp")
    partitioning and wrapped as a named array with axes ("batch", "seq").

    Args:
        input_ids: 2-D array-like of token ids, (batch, seq), e.g. the
            ``"input_ids"`` entry of a tokenizer call.
        attention_mask: accepted so callers can pass ``**tokenized``
            directly. It is NOT forwarded to the model: only the token ids
            reach ``from_basic_segments``. The original code built a sharded
            mask array and then discarded it, so that dead work is skipped.

    Returns:
        The value of ``llama.inputs.from_basic_segments`` for the tokens.
    """
    del attention_mask  # unused; see docstring

    token_array = jnp.asarray(input_ids)
    token_array = jax.device_put(
        token_array,
        jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")),
    )
    # NOTE(review): the untag/tag round-trip on "batch" looks like a no-op;
    # kept as-is in case the resulting named-array view type matters.
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    return llama.inputs.from_basic_segments(token_array)


from micrlhf.utils.load_sae import sae_encode

from safetensors import safe_open

from micrlhf.utils.load_sae import get_nev_it_sae_suite


# Special token ids. In the tokenized eval prompts, 3978 sits between each
# input and its label, 108 ends each example line and 0 fills the right
# padding.
sep, pad, newline = 3978, 0, 108

# Every available task, then narrowed to a single task for this run.
task_names = list(tasks)
task_names = ["location_continent"]

n_seeds = 10

# Larger setting used previously: n_few_shots, batch_size, max_seq_len = 64, 64, 512
n_few_shots, batch_size, max_seq_len = 20, 16, 256

# Instruction prefix; "{}" presumably receives the formatted examples.
prompt = "Follow the pattern:\n{}"


from sprint.task_vector_utils import ICLRunner, logprob_loss, get_tv_detector, make_act_adder_detector, weights_to_resid

fatal: destination path 'data/itv' already exists and is not an empty directory.


In [19]:
from safetensors import safe_open
from sprint.task_vector_utils import FeatureSearch
from micrlhf.utils.ito import grad_pursuit

seed = 10

layers = list(range(9, 14))

# Third argument passed to grad_pursuit (presumably the number of SAE
# features to select) — TODO confirm against micrlhf.utils.ito.
n_pursuit_features = 7

# Token count of the instruction prefix as the tokenizer encodes it
# (same value as the tensor's shape[1], without requiring torch).
prompt_length = len(tokenizer(prompt)["input_ids"])

RESULTS_PATH = "cleanup_results_final_detectors_3.jsonl"


def _balance_by_label(pairs):
    """Cap every label (second element of a pair) at the mean per-label count.

    Keeps the first ``int(mean)`` pairs of each label. Labels stay in
    first-appearance order, and pairs keep their original order within a
    label. Returns ``[]`` for empty input instead of dividing by zero.
    """
    if not pairs:
        return []
    by_label = {}
    for pair in pairs:
        by_label.setdefault(pair[1], []).append(pair)
    cap = int(len(pairs) / len(by_label))
    return [pair for group in by_label.values() for pair in group[:cap]]


def _loss_from_logits(logits, tokens):
    """Evaluate ``logprob_loss`` on named logits with this notebook's settings."""
    return logprob_loss(
        logits.unwrap("batch", "seq", "vocabulary"), tokens,
        shift=0, n_first=2, sep=sep, pad_token=pad,
    )


def _steered_loss(vector, tokens, inputs, layer):
    """Loss on the eval batch after adding ``vector`` (cast to bfloat16) at ``layer``."""
    add_act = make_act_adder_detector(
        llama, vector.astype('bfloat16'), tokens, layer,
        prompt_length=prompt_length, newline=newline, length=1,
    )
    return _loss_from_logits(add_act(inputs), tokens)


for task in tqdm(task_names):
    pairs = _balance_by_label(list(tasks[task].items()))

    n_shot = 8 if task.startswith("algo") else n_few_shots - 1

    runner = ICLRunner(task, pairs, batch_size=batch_size * 2, n_shot=n_shot, max_seq_len=max_seq_len, seed=seed, prompt=prompt, vector_type="detector")

    # Residual streams on the few-shot training prompts (used to build task vectors).
    train_tokenized = runner.get_tokens([x[:n_shot] for x in runner.train_pairs], tokenizer)
    train_tokens = train_tokenized["input_ids"]
    _, all_resids = get_resids_call(tokenized_to_inputs(**train_tokenized))

    # Evaluation batch. `inputs`/`tokens` remain global for later inspection cells.
    tokenized = runner.get_tokens(runner.eval_pairs, tokenizer)
    inputs = tokenized_to_inputs(**tokenized)
    tokens = tokenized["input_ids"]

    zero_loss = _loss_from_logits(llama(inputs), tokens)
    print(f"Zero: {task}, Loss: {zero_loss}")

    for layer in layers:
        sae = get_nev_it_sae_suite(layer)
        resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

        # Raw task vector.
        tv = get_tv_detector(resids, train_tokens, shift=0, prompt_length=prompt_length, newline=newline)
        tv_loss = _steered_loss(tv, tokens, inputs, layer)
        print(f"TV: {task}, L: {layer}, Loss: {tv_loss}")

        # SAE reconstruction of the task vector.
        _, _, rtv = sae_encode(sae, tv)
        recon_loss = _steered_loss(rtv, tokens, inputs, layer)
        print(f"Recon TV: {task}, L: {layer}, Loss: {recon_loss}")

        # Sparse approximation via gradient pursuit over the SAE decoder.
        gpr, gtv = grad_pursuit(tv, sae["W_dec"], n_pursuit_features)
        ito_loss = _steered_loss(gtv, tokens, inputs, layer)
        print(f"Grad pursuit TV: {task}, L: {layer}, Loss: {ito_loss}")

        # Feature search initialised from the gradient-pursuit weights
        # (seed offset presumably decorrelates its sampling from the runner's).
        fs = FeatureSearch(
            task, pairs, layer, llama, tokenizer,
            n_shot=1, seed=seed + 100, init_w=gpr, early_stopping_steps=200,
            n_first=2, sep=sep, newline=newline, pad_token=pad,
            sae_v=8, sae=sae, batch_size=24, iterations=1000, prompt=prompt,
            l1_coeff=0.0005, feature_type="detector", lr=0.01, n_batches=1,
        )
        w, _ = fs.find_weights()

        _, _, recon = sae_encode(sae, None, pre_relu=w)
        loss = _steered_loss(recon, tokens, inputs, layer)
        print(f"Recon fs: {task}, L: {layer}, Loss: {loss}")

        item = {
            "task": task,
            "weights": w.tolist(),
            "loss": loss.tolist(),
            "recon_loss": recon_loss.tolist(),
            "ito_loss": ito_loss.tolist(),
            "tv_loss": tv_loss.tolist(),
            "zero_loss": zero_loss.tolist(),
            "tv": tv.tolist(),
            "gw": gpr.tolist(),
            "layer": layer,
        }
        # Append per layer so an interrupted run keeps finished layers.
        with open(RESULTS_PATH, "a") as f:
            f.write(json.dumps(item) + "\n")


  0%|          | 0/1 [00:00<?, ?it/s]

Zero: location_continent, Loss: 9.8125
TV: location_continent, L: 9, Loss: 7.03125
Recon TV: location_continent, L: 9, Loss: 9.3125
Grad pursuit TV: location_continent, L: 9, Loss: 7.71875


  0%|          | 0/1000 [00:00<?, ?it/s]

Recon fs: location_continent, L: 9, Loss: 6.21875
TV: location_continent, L: 10, Loss: 7.09375
Recon TV: location_continent, L: 10, Loss: 9.5
Grad pursuit TV: location_continent, L: 10, Loss: 7.34375


  0%|          | 0/1000 [00:00<?, ?it/s]

Recon fs: location_continent, L: 10, Loss: 7.46875
TV: location_continent, L: 11, Loss: 7.40625
Recon TV: location_continent, L: 11, Loss: 9.6875
Grad pursuit TV: location_continent, L: 11, Loss: 8.125


  0%|          | 0/1000 [00:00<?, ?it/s]

Recon fs: location_continent, L: 11, Loss: 6.96875
TV: location_continent, L: 12, Loss: 9.1875
Recon TV: location_continent, L: 12, Loss: 9.625


KeyboardInterrupt: 

In [17]:
from collections import Counter

# Number of pairs per label (second element) in the current `pairs`.
c = Counter(pair[1] for pair in pairs)

c

Counter({'English': 17,
         'Finnish': 17,
         'Italian': 17,
         'French': 17,
         'Spanish': 17,
         'Dutch': 17,
         'Swedish': 17,
         'Russian': 17,
         'German': 17,
         'Portuguese': 13,
         'Ukrainian': 11,
         'Latin': 11,
         'Hindi': 11,
         'Indonesian': 9,
         'Serbian': 8,
         'Catalan': 6,
         'Persian': 5,
         'Polish': 5,
         'Czech': 4,
         'Bulgarian': 4,
         'Tamil': 4,
         'Welsh': 4,
         'Chinese': 3,
         'Hebrew': 3,
         'Croatian': 3,
         'Greek': 3,
         'Armenian': 2,
         'Japanese': 2,
         'Filipino': 2,
         'Somali': 2,
         'Turkish': 2,
         'Georgian': 2,
         'Danish': 2,
         'Romanian': 2,
         'Korean': 2,
         'Mari': 1,
         'Hawaiian': 1,
         'Norwegian': 1,
         'Icelandic': 1,
         'Vietnamese': 1,
         'Irish': 1,
         'Thai': 1,
         'Hungarian': 1})

In [10]:
# Balance the dataset: cap every label (second element of a pair) at the mean
# number of pairs per label, so that frequent labels do not dominate.

from collections import Counter
c = Counter(x[1] for x in pairs)

# Guard against an empty `pairs` (the mean would divide by zero).
max_count = sum(c.values()) / len(c) if c else 0

pairs_by_second = {x: [] for x in c}
for p in pairs:
    pairs_by_second[p[1]].append(p)

new_pairs = []

for k, v in pairs_by_second.items():
    new_pairs.extend(v[:int(max_count)])

# Only a cell's last expression is rendered, so show the balanced label
# distribution explicitly before the total size.
display(Counter(x[1] for x in new_pairs))

len(new_pairs)

286

In [10]:
# Few-shot training prompts built by the ICLRunner: a list of prompts, each a
# list of (input, label) pairs.
runner.train_pairs

[[('Republic of Khakassia', 'Russian'),
  ('Malawi', 'English'),
  ('Ticino', 'Italian'),
  ('Aigle', 'French'),
  ('Thailand', 'Thai'),
  ('Netherlands Antilles', 'Dutch'),
  ('Nicaragua', 'Spanish'),
  ('Puerto Rico', 'English'),
  ('Stellaland', 'Dutch'),
  ('Grenada', 'English'),
  ('Georgian Orthodox Church', 'Georgian'),
  ('Armenia', 'Armenian'),
  ('Kingdom of Tavolara', 'Italian'),
  ('history of Limousin', 'French'),
  ('Espoo', 'Finnish'),
  ('Uruguay', 'Spanish'),
  ('Kyōto Prefecture', 'Japanese'),
  ('Stockholm County Council', 'Swedish'),
  ('Bulgarian Orthodox Church', 'Bulgarian')],
 [('Mari El Republic', 'Mari'),
  ('Sundbyberg Municipality', 'Swedish'),
  ('Russian Empire', 'Polish'),
  ('Bosco Gurin', 'Italian'),
  ('Cadenazzo', 'Italian'),
  ('Arizona', 'English'),
  ('Renens', 'French'),
  ('Neklinovsky District', 'Russian'),
  ('International Civil Aviation Organization', 'Russian'),
  ('Finnish Orthodox Church', 'Finnish'),
  ('Nyon', 'French'),
  ('Alabama', 'E

In [22]:
# Load the logged feature-search results and pick the latest entry for
# layer 11 of "location_continent".
with open("cleanup_results_final_detectors_3.jsonl", "r") as f:
    # Skip blank lines so a stray empty line does not break json.loads.
    data = [json.loads(line) for line in f if line.strip()]

matching = [x for x in data if x["layer"] == 11 and x["task"] == "location_continent"]
if not matching:
    raise ValueError(
        "no results for layer 11 / task 'location_continent' in "
        "cleanup_results_final_detectors_3.jsonl"
    )
# The experiment loop appends to this file, so the last match is the newest run.
results = matching[-1]

In [23]:
import numpy as np

# Feature-search weights from the selected run; show the 10 largest values
# together with their indices (presumably SAE feature ids).
w = np.asarray(results["weights"])

jax.lax.top_k(w, 10)

[Array([9.053909  , 6.594228  , 6.105806  , 5.7983284 , 0.35896918,
        0.12874588, 0.        , 0.        , 0.        , 0.        ],      dtype=float32),
 Array([12898,  8220, 27001, 25334, 20627, 29362,     0,     1,     2,
            3], dtype=int32)]

In [6]:
# Evaluation prompts: each is a placeholder ('X', 'Y') demonstration followed
# by the query (input, label) pair.
runner.eval_pairs

[[('X', 'Y'), ('Pacific Alliance', 'Spanish')],
 [('X', 'Y'), ('Yerevan', 'Armenian')],
 [('X', 'Y'), ('Renens', 'French')],
 [('X', 'Y'), ('Sastamala', 'Finnish')],
 [('X', 'Y'), ('Pacific Alliance', 'Spanish')],
 [('X', 'Y'), ('Republic of Khakassia', 'Russian')],
 [('X', 'Y'), ('French Polynesia', 'French')],
 [('X', 'Y'), ('Republic of Adygea', 'Russian')],
 [('X', 'Y'), ('Canton of Fribourg', 'French')],
 [('X', 'Y'), ('Kaavi', 'Finnish')],
 [('X', 'Y'), ('Melitopol', 'Ukrainian')],
 [('X', 'Y'), ('Nastola', 'Finnish')],
 [('X', 'Y'), ('Savoy', 'French')],
 [('X', 'Y'), ('Iowa', 'English')],
 [('X', 'Y'), ('Kumlinge', 'Swedish')],
 [('X', 'Y'), ('Kerava', 'Finnish')],
 [('X', 'Y'), ('Engelberg', 'German')],
 [('X', 'Y'), ('Saint Pierre and Miquelon', 'French')],
 [('X', 'Y'), ('Germany', 'German')],
 [('X', 'Y'), ('Western Cape', 'English')],
 [('X', 'Y'), ('Autonomous Province of Kosovo and Metohija', 'Serbian')],
 [('X', 'Y'), ('Moudon', 'French')],
 [('X', 'Y'), ('Pully', 'Fren

In [7]:
# Token ids of the first eval prompt; right-padded with the pad id 0.
tokens[0]

array([     2,   9792,    573,   6883, 235292,    108, 235356,   3978,
          890,    108,  41019,  29464,   3978,  13035,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
      