In [1]:
#!pip install -qqq circuitsvis && pip install -qqq -U torch sentence-transformers
from datetime import datetime
import json
from os import listdir
from os.path import exists

import numpy as np
import torch
import circuitsvis as cv
from taker import Model
from taker.hooks import HookConfig

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import PCA
import torch
from tqdm import tqdm

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fa59b70ca60>

## Load in the Model and Prompt

In [2]:
# m = Model("nickypro/tinyllama-15m", dtype="bfp16", compile=False, model_device="cuda")
# m = Model("microsoft/phi-3-mini-4k-instruct", dtype="bfp16", compile=False)
m = Model("google/gemma-2-2b-it", dtype="bfp16", compile=False)
m.show_details()

has_double_newline_token = (m.get_ids("\n").shape == m.get_ids("\n\n").shape)
newline_token_id = m.get_ids("\n\n")[0, -1].item()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded model 'google/gemma-2-2b-it' with bfp16:
- Added 416 hooks across 26 layers
 - n_layers : 26
 - d_model  : 2304
 - n_heads  : 8
 - d_head   : 256
 - d_mlp    : 9216


In [3]:
prompt    = """Write a short blog post about a recipe and the inspiration behind it.
 Do not include a title.
 Only reveal the dish after the story.
 Start with short story and then move to the recipe.
 To re-iterate, do not include a title."""
# info_gen = m.generate(info_prompt, temperature=0.3, num=300)

#DOUBLE NEWLINE
info_gen = """
\n Once upon a time, in a quaint little village nestled between rolling hills and verdant fields, there lived an elderly woman named Agnes. Agnes was known for her warm smile and her legendary Sunday dinners that brought the entire neighborhood together. Her recipes were family heirlooms, passed down through generations, with each family adding their own touch to the final dish.

One crisp autumn evening, Agnes was reminiscing about her childhood, and how her grandmother used to gather everyone around the dinner table, sharing stories and laughter. These were the moments that shaped her, the memories that she passed on to her own children and grandchildren.

Inspired by her grandmother's legacy, Agnes decided to create a new dish that would encapsulate the essence of those cherished gatherings. She wanted something that was comforting and nourishing, a dish that could be prepared with love and shared with others. After days of experimentation, she finally created a recipe that she believed truly captured the spirit of her family's Sunday dinners.\n\n"""
info_prompt = prompt+info_gen

info_ids = m.get_ids(info_prompt)
info_embeds = m.get_inputs_embeds(info_prompt)

We want a neutral prompt for extraction. We try a randomised/scrambled prompt,
and a fine-tuned prompt and see what works

In [4]:
torch.set_grad_enabled(False)

# Neutral Prompt
neutral_prompt = "Continuation of previous text:\n\n"

neutral_ids = m.get_ids(neutral_prompt)
print(m.tokenizer.convert_ids_to_tokens(neutral_ids[0].tolist()))
neutral_embeds = m.get_inputs_embeds(input_ids=neutral_ids)


# Random inputs embeds
def make_rand_embeds(neutral_embeds, start=1, end=5):
    rand_embeds    = neutral_embeds.clone()
    rand_embeds[0, start:end] = torch.randn_like(neutral_embeds[0, start:end]) / (m.cfg.d_model**0.5)
    rand_embeds[0, start:end] *= neutral_embeds[0, start:end].norm(dim=-1).mean()

    return rand_embeds

# rand_embeds = make_rand_embeds(neutral_embeds, 1, 5) # gemma
rand_embeds = make_rand_embeds(neutral_embeds, 0, 8) # phi3

# Sanity check on norms
print(neutral_embeds.norm(dim=-1).cpu().float().numpy())
print(rand_embeds.norm(dim=-1).cpu().float().numpy())

['<bos>', 'Continuation', '▁of', '▁previous', '▁text', ':', '\n\n']
[[4.15625   1.828125  1.8671875 1.71875   1.6796875 1.921875  2.34375  ]]
[[2.25     2.234375 2.203125 2.203125 2.1875   2.21875  2.21875 ]]


In [5]:
# Utils
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

def write_to_file(experiment_name, data):
    filename = f"./results/{current_time}_story_agnes_{experiment_name}.jsonl"
    filename_latest = f"./results/latest_story_agnes_{experiment_name}.jsonl"
    if not exists(filename):
        with open(filename, "w") as f:
            f.write("")
        with open(filename_latest, "w") as f:
            f.write("")
    for _filename in [filename, filename_latest]:
        with open(_filename, "a") as file:
            file.write(json.dumps(data) + "\n")

def read_file(experiment_name, time="latest"):
    filepath = f"./results/{time}_agnes_story_{experiment_name}.jsonl"
    df = pd.read_json(filepath, lines=True)

def reset_hooks():
    #RESET HOOKS BEFORE TRANSPLANTING NEXT SET OF ACTIVATIONS
    [h.reset() for h in m.hooks.neuron_replace.values()]

# Make tunable embed parameters:
class TunableInputsEmbeds(torch.nn.Module):
    def __init__(self, inputs_embeds):
        super().__init__()
        self.embeds = torch.nn.Parameter(inputs_embeds)
        self.shape = self.embeds.shape

    def forward(self):
        return self.embeds

## Try generating some things already

Get transferred activations and make some plots

In [6]:
# Get original text activations
acts = m.get_midlayer_activations(info_prompt)
orig_token_index = m.get_ids(info_prompt).shape[1] - 1
new_token_index  = m.get_ids(neutral_prompt).shape[1] - 1
print(orig_token_index, new_token_index)
print(m.tokenizer.convert_ids_to_tokens(m.get_ids(info_prompt).squeeze().tolist()[-neutral_ids.shape[1]:]))
print(m.tokenizer.convert_ids_to_tokens(m.get_ids(neutral_prompt).squeeze().tolist()))

def transfer_activations(num_tokens_transferred=1):
    reset_hooks()
    for j in range(num_tokens_transferred):
        for layer_index in range(m.cfg.n_layers):
            m.hooks.neuron_replace[f"layer_{layer_index}_mlp_pre_out"].add_token(new_token_index - j, acts["mlp"][0, layer_index, orig_token_index - j])
            m.hooks.neuron_replace[f"layer_{layer_index}_attn_pre_out"].add_token(new_token_index - j, acts["attn"][0, layer_index, orig_token_index - j])

# Input parameters
from dataclasses import dataclass
model_repo = "google/gemma-2b-it"
@dataclass
class GenData:
    model_repo: str = m.model_repo
    temperature: float = 0.3
    max_new_tokens: int = 100
    tokens_transferred_num: int = 1
    transplant_layers: tuple = (0,32)
    num_tokens_transferred: int = 1
    output: str = ""
    curr_prompt: str = neutral_prompt
    orig_prompt: str = info_prompt


def generate_texts(data: GenData):
    # Run generation with tranfer of 1 token
    print("Generating texts with 1 token transferred...")
    for num_tokens_transferred in [1]:
        data = GenData()
        transfer_activations(data.num_tokens_transferred)
        for i in range(5):
            rand_embeds = make_rand_embeds(neutral_embeds)
            embeds = TunableInputsEmbeds(rand_embeds)
            for i in range(5):
                text_in, text_out = m.generate(inputs_embeds=embeds(), num=data.max_new_tokens, temperature=data.temperature)
                data.output = text_out
                write_to_file("transfer-x1", data.__dict__)

    print("Generating texts from original info prompt...")
    # Run generation with tranfer of 1 token
    data = GenData()
    data.curr_prompt = info_prompt
    reset_hooks()
    for i in range(25):
        text_in, text_out = m.generate(text=data.curr_prompt, num=data.max_new_tokens, temperature=data.temperature)
        data.output = text_out
        write_to_file("orig", data.__dict__)

    print("Generating texts from neutral prompt...")
    # Run generation with tranfer of 1 token
    data = GenData()
    data.curr_prompt = neutral_prompt
    reset_hooks()
    for i in range(25):
        text_in, text_out = m.generate(text=data.curr_prompt, num=data.max_new_tokens, temperature=data.temperature)
        data.output = text_out
        write_to_file("neutral", data.__dict__)

# generate_texts(GenData())

254 6
['▁family', "'", 's', '▁Sunday', '▁dinners', '.', '\n\n']
['<bos>', 'Continuation', '▁of', '▁previous', '▁text', ':', '\n\n']


Get some comparison data, without any transfers

## Write a method for loading up prompts

In [7]:
import json
# Initialize the tokenizer
tokenizer = m.tokenizer

def read_prompts():
    with open("./results/latest_phi3_generations.jsonl", "r") as file:
        for line in file:
            data = json.loads(line)
            full_text = data['full_text']

            # Split the full text into prompt and output
            prompt, output = full_text.split("Assistant:", 1)
            prompt += "Assistant:"  # Add back the "Assistant:" part

            # Tokenize the full text and find the start of the output
            full_tokens = tokenizer.encode(full_text)
            output_start = len(tokenizer.encode(prompt))

            # print("Input, Output Tokens:", output_start, len(full_tokens))

            # Find the index of "\n\n" after 100 tokens into the output
            output_tokens = full_tokens[output_start:]
            if len(output_tokens) > 100:
                text_before_100_tokens = tokenizer.decode(output_tokens[:100])
                text_after_100_tokens = tokenizer.decode(output_tokens[100:])
                text_after_100_tokens_until_newline = text_after_100_tokens.split("\n\n")[0]

                if text_after_100_tokens_until_newline != text_after_100_tokens:
                    full_index = tokenizer.encode(prompt + text_before_100_tokens + text_after_100_tokens_until_newline)
                    data['split_index'] = len(full_index)
                else:
                    data['split_index'] = -1
            else:
                data['split_index'] = -1

            if data['split_index'] == -1:
                print("No split point found, skipping")
                continue

            data["newline_index"] = data["split_index"] + int(not has_double_newline_token)

            yield data

# Example usage:
for prompt_data in read_prompts():
    full_text     = prompt_data['full_text']
    newline_index = prompt_data['newline_index']

    tokens = tokenizer.encode(full_text)
    first_part  = tokenizer.decode(tokens[:newline_index+1])
    second_part = tokenizer.decode(tokens[newline_index+1:])
    print(f"{first_part}")
    print(f"--- BREAK POINT (Split index: {newline_index}) ---")
    print(f"{second_part}")

    break

<bos>Human: Write a detailed post-mortem analysis of a game development project you completed for a timed game development competition, structured as follows:

- Begin with a brief introduction to the game, including a description and where it can be played.
- Discuss the challenges and decisions made when working with the competition's theme, including any initial ideas that were eventually abandoned and the final concept chosen.
- Describe your development setup and any notable tools or techniques used, such as livestreaming or specific programming frameworks.
- Analyze the game's design, highlighting both successful and unsuccessful elements, including any usability issues encountered.
- Discuss the development process, including any changes made along the way, features that were cut due to time constraints, and coding challenges faced.
- Reflect on the game's overall user experience and any lessons learned regarding clarity and accessibility.
- Conclude with a summary of the projec

## Let's try a training loop.

In [8]:
torch.set_grad_enabled(True)  # Enable gradients for optimization

# Initialize rand_embeds as a TunableInputsEmbeds object
# rand_embeds = make_rand_embeds(neutral_embeds)
tuned_embeds = TunableInputsEmbeds(rand_embeds)
new_token_index  = m.get_ids(neutral_prompt).shape[1] - 1

# Define an optimizer for rand_embeds
def get_ce_loss(expected_ids: torch.Tensor, logits: torch.Tensor):
    """Computes cross entropy losses for each token."""
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    expected_ids = expected_ids.to(log_probs.device)
    predicted_log_probs = log_probs.gather(dim=-1, index=expected_ids[..., None])[..., 0]
    return -predicted_log_probs.mean()

optimizer = torch.optim.Adam(tuned_embeds.parameters(), lr=0.001)

for idx, prompt_data in (pbar := tqdm(enumerate(read_prompts()))):
    # Read the text
    full_ids = m.get_ids(prompt_data['full_text'])
    orig_newline_index = prompt_data['newline_index']
    ids_prompt = full_ids[:, :orig_newline_index+1]
    try:
        newline_token_id = m.get_ids("\n\n")[0, -1].item()
        prompt_token_id = ids_prompt[0, -1].item()
        assert newline_token_id == prompt_token_id, f"Final token is not a newline token: {m.tokenizer.decode(prompt_token_id)}"
    except AssertionError as e:
        # print(e)
        print({"idx": idx, "len": len(ids_prompt[0]), "orig_len": len(full_ids[0]), "tail end": tokenizer.decode(ids_prompt[0].tolist()[-10:])})
        continue

    # Get info prompt input
    info_output_ids = full_ids[:, newline_index+1:][:, :100] # limit to 100 tokens

    # Get info prompt output
    info_output_embeds = m.get_inputs_embeds(input_ids=info_output_ids)

    # Get original text input activations
    reset_hooks()
    with torch.no_grad():
        # Get original text activations
        acts = m.get_midlayer_activations(input_ids=ids_prompt)
        orig_token_index = ids_prompt.shape[1] - 1
        orig_acts = {
            "mlp" : acts["mlp"][0, :, orig_newline_index],
            "attn": acts["attn"][0, :, orig_newline_index]
        }

    # transfer activations
    for layer_index in range(m.cfg.n_layers):
        m.hooks.neuron_replace[f"layer_{layer_index}_mlp_pre_out"].add_token(new_token_index, orig_acts["mlp"][layer_index])
        m.hooks.neuron_replace[f"layer_{layer_index}_attn_pre_out"].add_token(new_token_index, orig_acts["attn"][layer_index])

    # Begin training on this text
    optimizer.zero_grad()

    neutral_inputs = tuned_embeds()
    neutral_outputs = info_output_embeds
    neutral_embeds = torch.cat([neutral_inputs, neutral_outputs], dim=1)

    # Forward pass with rand_embed
    logits = m.get_logits(inputs_embeds=neutral_embeds)
    loss   = get_ce_loss(info_output_ids, logits[..., neutral_inputs.shape[-1]:-1])

    loss.backward()  # Backward pass
    optimizer.step()  # Update rand_embeds

    # print(f'{idx}: Loss: {loss.item():.4f}')
    # update tqdm description within the loading bar
    pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
    # print(loss.item())

    if idx >= 200:
        break

1it [00:00,  3.54it/s, Loss=19.5000]

No split point found, skipping
{'idx': 1, 'len': 247, 'orig_len': 473, 'tail end': " with VisionaryTech's SmartLens.\n\n\n\n"}
{'idx': 2, 'len': 207, 'orig_len': 419, 'tail end': ', and create lasting memories with fellow anglers.\n\n\n'}


8it [00:01,  7.86it/s, Loss=20.2500]

{'idx': 6, 'len': 303, 'orig_len': 472, 'tail end': ' the foundation for a healthy, independent mindset.\n\n\n'}


12it [00:01,  7.58it/s, Loss=20.0000]

{'idx': 10, 'len': 322, 'orig_len': 511, 'tail end': ' CAD, ultimately improving patient care and outcomes.\n\n\n\n'}


18it [00:02,  6.96it/s, Loss=15.3125]

{'idx': 16, 'len': 291, 'orig_len': 459, 'tail end': ' the need for non-invasive monitoring techniques.\n\n\n\n'}


26it [00:04,  6.81it/s, Loss=13.8125]

{'idx': 24, 'len': 352, 'orig_len': 522, 'tail end': ' but also patients and the broader healthcare community.\n\n\n'}


30it [00:04,  7.23it/s, Loss=13.0000]

{'idx': 28, 'len': 230, 'orig_len': 461, 'tail end': ' to modify, resulting in a segmentation fault.\n\n\n'}


34it [00:05,  5.58it/s, Loss=12.3750]

{'idx': 34, 'len': 292, 'orig_len': 472, 'tail end': ' that invites further exploration into its enduring relevance.\n\n\n'}
{'idx': 35, 'len': 328, 'orig_len': 523, 'tail end': ' more efficient and cost-effective financial services.\n\n\n'}
{'idx': 36, 'len': 311, 'orig_len': 455, 'tail end': ' fail to integrate seamlessly with existing financial software.\n\n\n'}


40it [00:05,  9.94it/s, Loss=12.5000]

{'idx': 38, 'len': 260, 'orig_len': 445, 'tail end': ' and manipulated in a consistent and reliable manner.\n\n\n'}


51it [00:07,  6.78it/s, Loss=12.5625]

{'idx': 49, 'len': 318, 'orig_len': 450, 'tail end': ' measurement of disease burden and response to therapy.\n\n\n'}


54it [00:08,  7.45it/s, Loss=11.9375]

{'idx': 52, 'len': 241, 'orig_len': 446, 'tail end': " treatment for women's mental health concerns.\n\n\n"}


58it [00:08,  7.39it/s, Loss=12.2500]

{'idx': 56, 'len': 432, 'orig_len': 484, 'tail end': ' as much as its giver or provider intended.\n\n\n\n'}


70it [00:10,  6.73it/s, Loss=11.3750]

{'idx': 68, 'len': 224, 'orig_len': 469, 'tail end': 'letter sequence: "LMN OP QR"\n\n\n'}


74it [00:11,  7.21it/s, Loss=11.5000]

{'idx': 72, 'len': 199, 'orig_len': 410, 'tail end': ' a research fellow and later an associate professor.\n\n\n'}


85it [00:13,  6.73it/s, Loss=11.8750]

{'idx': 83, 'len': 254, 'orig_len': 432, 'tail end': ' offering a range of shopping and dining options.\n\n\n'}


88it [00:13,  7.41it/s, Loss=11.3125]

{'idx': 86, 'len': 228, 'orig_len': 447, 'tail end': '-step process to securely delete your emails:\n\n\n'}


90it [00:14,  6.25it/s, Loss=11.1875]

No split point found, skipping


95it [00:14,  7.07it/s, Loss=nan]    

{'idx': 93, 'len': 272, 'orig_len': 446, 'tail end': ' could potentially alter the outcome of the case.\n\n\n'}


98it [00:15,  7.66it/s, Loss=11.3125]

{'idx': 96, 'len': 219, 'orig_len': 402, 'tail end': ' the zero-crossing point of the waveform.\n\n\n'}


111it [00:17,  6.69it/s, Loss=11.0000]

{'idx': 109, 'len': 282, 'orig_len': 437, 'tail end': ' catheter management to reduce the risk of complications.\n\n\n'}


115it [00:18,  7.16it/s, Loss=10.0625]

{'idx': 113, 'len': 169, 'orig_len': 378, 'tail end': ' resolved, and she made a full recovery.\n\n\n'}


125it [00:20,  6.77it/s, Loss=10.0625]

{'idx': 123, 'len': 254, 'orig_len': 455, 'tail end': ' address the socioeconomic determinants of T2DM.\n\n\n'}


144it [00:23,  6.71it/s, Loss=10.6250]

{'idx': 142, 'len': 185, 'orig_len': 399, 'tail end': ' to maintain sharpness and reduce wear over time.\n\n\n'}


166it [00:27,  6.58it/s, Loss=10.3125]

{'idx': 164, 'len': 208, 'orig_len': 446, 'tail end': ' our website or contacting a local agent today.\n\n\n\n'}


169it [00:28,  7.45it/s, Loss=11.3750]

{'idx': 167, 'len': 303, 'orig_len': 524, 'tail end': 'lassified under the family "Brownaceae."\n\n\n'}


178it [00:29,  6.77it/s, Loss=9.8750] 

{'idx': 176, 'len': 301, 'orig_len': 481, 'tail end': ' (approx. 1-2 sentences).\n\n\n\n'}


187it [00:31,  5.21it/s, Loss=9.3125] 

{'idx': 187, 'len': 460, 'orig_len': 497, 'tail end': "Helvetica'; fontSize = 14;>\n\n\n"}


197it [00:32,  6.69it/s, Loss=10.3125]

{'idx': 195, 'len': 268, 'orig_len': 456, 'tail end': '-term outcomes in terms of patient comfort.\n\n\n'}
No split point found, skipping


200it [00:33,  5.93it/s, Loss=9.9375] 


In [12]:
# Get original text input activations
reset_hooks()
with torch.no_grad():
    # Get original text activations
    acts = m.get_midlayer_activations(input_ids=ids_prompt)
    orig_token_index = ids_prompt.shape[1] - 1
    orig_acts = {
        "mlp" : acts["mlp"][0, :, orig_newline_index],
        "attn": acts["attn"][0, :, orig_newline_index]
    }

# transfer activations
for layer_index in range(m.cfg.n_layers):
    m.hooks.neuron_replace[f"layer_{layer_index}_mlp_pre_out"].add_token(new_token_index, orig_acts["mlp"][layer_index])
    m.hooks.neuron_replace[f"layer_{layer_index}_attn_pre_out"].add_token(new_token_index, orig_acts["attn"][layer_index])

m.generate(inputs_embeds=tuned_embeds(), num=100)

('',
 ' itſelf \n\n\n\n \n\n\n\n \n\n\n\n  \n\n\n\n  \n\n\n\n  \n\n\n\n  \n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n  \n\n\n\n  \n\n\n\n  \n\n\n\n  \n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n  \n\n\n\n  \n\n\n\n  \n\n\n\n  \n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n  \n\n\n\n  \n\n\n\n  \n\n\n\n  \n\n\n\n  \n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n \n\n\n\n ')