In [1]:
#!pip install -qqq circuitsvis && pip install -qqq -U torch sentence-transformers
from datetime import datetime
import json
from os import listdir
from os.path import exists

import numpy as np
import torch
import circuitsvis as cv
from taker import Model
from taker.hooks import HookConfig

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import PCA
import torch
from tqdm import tqdm

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fb4d3a1fd00>

## Load in the Model and Prompt

In [2]:
# Alternative checkpoints (disabled); only the gemma model below is loaded.
# m = Model("nickypro/tinyllama-15m", dtype="bfp16", compile=False, model_device="cuda")
# m = Model("microsoft/phi-3-mini-4k-instruct", dtype="bfp16", compile=False)
m = Model("google/gemma-2-2b-it", dtype="bfp16", compile=False)
m.show_details()

# True when "\n" and "\n\n" produce id tensors of the same shape.
# Presumably this means the tokenizer encodes "\n\n" as a single token — TODO confirm.
has_double_newline_token = (m.get_ids("\n").shape == m.get_ids("\n\n").shape)
# Id of the last token produced when encoding "\n\n".
newline_token_id = m.get_ids("\n\n")[0, -1].item()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded model 'google/gemma-2-2b-it' with bfp16:
- Added 416 hooks across 26 layers
 - n_layers : 26
 - d_model  : 2304
 - n_heads  : 8
 - d_head   : 256
 - d_mlp    : 9216


In [3]:
# Instruction text; it is concatenated with info_gen below to form info_prompt.
prompt    = """Write a short blog post about a recipe and the inspiration behind it.
 Do not include a title.
 Only reveal the dish after the story.
 Start with short story and then move to the recipe.
 To re-iterate, do not include a title."""
# NOTE(review): the disabled call below passes info_prompt, which is only
# defined later as prompt + info_gen — presumably `prompt` was intended.
# info_gen = m.generate(info_prompt, temperature=0.3, num=300)

#DOUBLE NEWLINE
info_gen = """
\n Once upon a time, in a quaint little village nestled between rolling hills and verdant fields, there lived an elderly woman named Agnes. Agnes was known for her warm smile and her legendary Sunday dinners that brought the entire neighborhood together. Her recipes were family heirlooms, passed down through generations, with each family adding their own touch to the final dish.

One crisp autumn evening, Agnes was reminiscing about her childhood, and how her grandmother used to gather everyone around the dinner table, sharing stories and laughter. These were the moments that shaped her, the memories that she passed on to her own children and grandchildren.

Inspired by her grandmother's legacy, Agnes decided to create a new dish that would encapsulate the essence of those cherished gatherings. She wanted something that was comforting and nourishing, a dish that could be prepared with love and shared with others. After days of experimentation, she finally created a recipe that she believed truly captured the spirit of her family's Sunday dinners.\n\n"""
# Full context: the instruction followed by the story text (which ends in "\n\n").
info_prompt = prompt+info_gen

# Token ids and input embeddings of the full context.
info_ids = m.get_ids(info_prompt)
info_embeds = m.get_inputs_embeds(info_prompt)

We want a neutral prompt for extraction. We try a randomised (scrambled) prompt
and a fine-tuned prompt, and see which works better.

In [4]:
# Gradients stay disabled in this cell.
torch.set_grad_enabled(False)

# Neutral Prompt: a short context that ends in "\n\n", like info_prompt does.
neutral_prompt = "Continuation of previous text:\n\n"

# Tokenise the neutral prompt, print its tokens for inspection, and embed the ids.
neutral_ids = m.get_ids(neutral_prompt)
print(m.tokenizer.convert_ids_to_tokens(neutral_ids[0].tolist()))
neutral_embeds = m.get_inputs_embeds(input_ids=neutral_ids)


# Random inputs embeds
def make_rand_embeds(neutral_embeds, start=1, end=5):
    """Return a copy of ``neutral_embeds`` with positions ``start:end`` randomised.

    The selected positions of batch entry 0 are replaced by Gaussian noise.
    The noise is divided by ``sqrt(width)`` (so each row has a norm of roughly
    one) and then multiplied by the mean norm of the rows it replaces, which
    keeps the random rows on the same scale as the original embeddings.

    Args:
        neutral_embeds: embeddings tensor indexed as ``[batch, position, width]``;
            only batch entry 0 is modified in the copy.
        start: first position to randomise (inclusive).
        end: last position to randomise (exclusive).

    Returns:
        A new tensor; ``neutral_embeds`` itself is left unchanged.
    """
    rand_embeds = neutral_embeds.clone()
    original_rows = neutral_embeds[0, start:end]

    # Take the width from the tensor rather than from the global model config,
    # so the helper works for any embeddings tensor it is handed.
    width = neutral_embeds.shape[-1]

    rand_embeds[0, start:end] = torch.randn_like(original_rows) / (width ** 0.5)
    rand_embeds[0, start:end] *= original_rows.norm(dim=-1).mean()

    return rand_embeds

# Randomise positions 1..5: in the token printout these are the five prompt
# tokens between '<bos>' (position 0) and the final '\n\n' (position 6).
rand_embeds = make_rand_embeds(neutral_embeds, 1, 6) # gemma
# Range used for the phi3 tokenizer (disabled).
# rand_embeds = make_rand_embeds(neutral_embeds, 0, 8) # phi3

# Sanity check on norms: per-position norms of the neutral vs. randomised embeddings.
print(neutral_embeds.norm(dim=-1).cpu().float().numpy())
print(rand_embeds.norm(dim=-1).cpu().float().numpy())

['<bos>', 'Continuation', '▁of', '▁previous', '▁text', ':', '\n\n']
[[4.15625   1.828125  1.8671875 1.71875   1.6796875 1.921875  2.34375  ]]
[[4.15625   1.7734375 1.8046875 1.8203125 1.890625  1.8203125 2.34375  ]]


In [5]:
# Utils
# Timestamp taken once when this cell runs; it names this session's result files.
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

def write_to_file(experiment_name, data):
    """Append ``data`` as one JSON line to the session file and the "latest" file.

    Two files under ``./results`` are written for each experiment: one whose
    name starts with ``current_time`` and one whose name starts with ``latest``.
    On the first write of a session (the timestamped file does not exist yet)
    both files are started from empty; afterwards records are appended.

    Args:
        experiment_name: suffix used in both file names.
        data: JSON-serialisable object written as a single line.
    """
    import os  # local import: the notebook only imports individual names from os

    results_dir = "./results"
    # Create the output directory on demand instead of failing in open().
    os.makedirs(results_dir, exist_ok=True)

    filename = f"{results_dir}/{current_time}_story_agnes_{experiment_name}.jsonl"
    filename_latest = f"{results_dir}/latest_story_agnes_{experiment_name}.jsonl"

    # Serialise before touching any file, so a non-serialisable record cannot
    # leave freshly truncated, empty files behind.
    line = json.dumps(data) + "\n"

    # "w" on the first write truncates both files; later writes append.
    mode = "a" if exists(filename) else "w"
    for _filename in (filename, filename_latest):
        with open(_filename, mode) as file:
            file.write(line)

def read_file(experiment_name, time="latest"):
    """Load the JSON-lines results of one experiment as a DataFrame.

    The path follows the naming used by ``write_to_file``:
    ``./results/{time}_story_agnes_{experiment_name}.jsonl``.

    Args:
        experiment_name: suffix of the results file.
        time: file-name prefix; ``"latest"`` (default) or a session timestamp.

    Returns:
        A pandas DataFrame with one row per JSON line.
    """
    filepath = f"./results/{time}_story_agnes_{experiment_name}.jsonl"
    return pd.read_json(filepath, lines=True)

def reset_hooks():
    """Call ``reset()`` on every hook in ``m.hooks.neuron_replace``.

    Intended to be run before a new set of activations is transplanted.
    """
    for hook in m.hooks.neuron_replace.values():
        hook.reset()

# Make tunable embed parameters:
class TunableInputsEmbeds(torch.nn.Module):
    """Holds a tensor of input embeddings as a trainable parameter.

    Calling the module returns the parameter itself, so it can be passed
    wherever ``inputs_embeds`` is expected and optimised through
    ``parameters()``.

    Attributes:
        embeds: the trainable copy of the initial embeddings.
        shape: shape of ``embeds``.
    """

    def __init__(self, inputs_embeds):
        super().__init__()
        # Copy the initial values: Parameter(t) shares storage with t, so an
        # optimizer's in-place updates would otherwise overwrite the tensor
        # the caller passed in (and any later module built from it).
        self.embeds = torch.nn.Parameter(inputs_embeds.detach().clone())
        self.shape = self.embeds.shape

    def forward(self):
        """Return the current embeddings parameter."""
        return self.embeds

## Try generating some things already

Get transferred activations and make some plots

In [6]:
# Get original text activations.
# `acts`, `orig_token_index` and `new_token_index` are notebook globals that
# transfer_activations() below reads.
acts = m.get_midlayer_activations(info_prompt)
# Index of the last token of each prompt (both prompts end in "\n\n").
orig_token_index = m.get_ids(info_prompt).shape[1] - 1
new_token_index  = m.get_ids(neutral_prompt).shape[1] - 1
print(orig_token_index, new_token_index)
# Show the tail of the info prompt (as many tokens as the neutral prompt has)
# next to the full neutral prompt, to check that both end on the same token.
print(m.tokenizer.convert_ids_to_tokens(m.get_ids(info_prompt).squeeze().tolist()[-neutral_ids.shape[1]:]))
print(m.tokenizer.convert_ids_to_tokens(m.get_ids(neutral_prompt).squeeze().tolist()))

def transfer_activations(num_tokens_transferred=1):
    """Reset the hooks, then register replacement activations from the info prompt.

    For each of the last ``num_tokens_transferred`` token positions, counted
    backwards from the final token of each prompt, the cached ``acts`` of the
    info prompt are registered on the ``mlp_pre_out`` and ``attn_pre_out``
    hooks of every layer at the matching position of the neutral prompt.
    """
    reset_hooks()
    replace_hooks = m.hooks.neuron_replace
    for offset in range(num_tokens_transferred):
        source_pos = orig_token_index - offset
        target_pos = new_token_index - offset
        for layer in range(m.cfg.n_layers):
            mlp_hook = replace_hooks[f"layer_{layer}_mlp_pre_out"]
            mlp_hook.add_token(target_pos, acts["mlp"][0, layer, source_pos])
            attn_hook = replace_hooks[f"layer_{layer}_attn_pre_out"]
            attn_hook.add_token(target_pos, acts["attn"][0, layer, source_pos])

# Input parameters
from dataclasses import dataclass
# NOTE(review): this variable is not read by any code visible in this notebook,
# and its value differs from the "google/gemma-2-2b-it" checkpoint loaded above;
# GenData takes its default from m.model_repo instead.
model_repo = "google/gemma-2b-it"
@dataclass
class GenData:
    """Settings and output of one generation run; written to disk via ``__dict__``."""
    # Model identifier recorded with each result.
    model_repo: str = m.model_repo
    # Sampling temperature passed to m.generate.
    temperature: float = 0.3
    # Passed to m.generate as `num`.
    max_new_tokens: int = 100
    # NOTE(review): not read by visible code; looks like a duplicate of num_tokens_transferred.
    tokens_transferred_num: int = 1
    # NOTE(review): not read by visible code; the loaded model reports 26 layers, not 32.
    transplant_layers: tuple = (0,32)
    # Number of trailing token positions passed to transfer_activations.
    num_tokens_transferred: int = 1
    # Generated text, filled in after each generation.
    output: str = ""
    # Prompt used (or recorded) for the current run.
    curr_prompt: str = neutral_prompt
    # Prompt the transplanted activations come from.
    orig_prompt: str = info_prompt


def generate_texts(data: GenData):
    """Generate and log continuations for three conditions.

    1. Activation transfer: the info-prompt activations are transplanted and
       text is generated from randomised neutral embeddings
       (5 random embeddings x 5 samples), logged as ``transfer-x{n}``.
    2. Original: text generated from ``info_prompt`` with hooks reset
       (25 samples), logged as ``orig``.
    3. Neutral: text generated from ``neutral_prompt`` with hooks reset
       (25 samples), logged as ``neutral``.

    Args:
        data: settings for all three conditions. ``temperature``,
            ``max_new_tokens`` and ``num_tokens_transferred`` are honoured.
            ``data`` itself is not modified; each condition works on a copy.
            With a default ``GenData()`` the behaviour matches the previous
            hard-coded defaults.
    """
    from dataclasses import replace  # local import: the notebook only imports `dataclass`

    def _log_text_generations(run, experiment_name, num_samples=25):
        """Generate ``num_samples`` texts from ``run.curr_prompt`` and log each one."""
        for _ in range(num_samples):
            _, text_out = m.generate(text=run.curr_prompt, num=run.max_new_tokens, temperature=run.temperature)
            run.output = text_out
            write_to_file(experiment_name, run.__dict__)

    # --- Condition 1: generation with transplanted activations -------------
    num_transferred = data.num_tokens_transferred
    plural = "" if num_transferred == 1 else "s"
    print(f"Generating texts with {num_transferred} token{plural} transferred...")
    # curr_prompt is only recorded here; generation starts from random embeddings.
    run = replace(data, curr_prompt=neutral_prompt)
    transfer_activations(num_transferred)
    for _embeds_draw in range(5):
        # NOTE(review): uses make_rand_embeds' default range (1:5), whereas the
        # notebook-level rand_embeds for gemma is built with (1, 6) — confirm intended.
        embeds = TunableInputsEmbeds(make_rand_embeds(neutral_embeds))
        for _sample in range(5):
            _, text_out = m.generate(inputs_embeds=embeds(), num=run.max_new_tokens, temperature=run.temperature)
            run.output = text_out
            write_to_file(f"transfer-x{num_transferred}", run.__dict__)

    # --- Condition 2: baseline from the full info prompt, no transplant ----
    print("Generating texts from original info prompt...")
    reset_hooks()
    _log_text_generations(replace(data, curr_prompt=info_prompt), "orig")

    # --- Condition 3: baseline from the neutral prompt, no transplant ------
    print("Generating texts from neutral prompt...")
    reset_hooks()
    _log_text_generations(replace(data, curr_prompt=neutral_prompt), "neutral")

# generate_texts(GenData())

254 6
['▁family', "'", 's', '▁Sunday', '▁dinners', '.', '\n\n']
['<bos>', 'Continuation', '▁of', '▁previous', '▁text', ':', '\n\n']


In [7]:
# set torch nograd
torch.set_grad_enabled(False)
# Wrap the randomised embeddings; print_comparison() below reads this global.
tuned_embeds = TunableInputsEmbeds(rand_embeds)

def print_comparison():
    """Print one short generation for each of three conditions.

    * ``orig``: generated from ``info_prompt`` with no hooks set.
    * ``transfer``: generated from the global ``tuned_embeds`` with the
      info-prompt activations of the last token transplanted.
    * ``no_transfer``: generated from ``tuned_embeds`` with no hooks set.
    """
    temperature = 0.3
    max_new_tokens = 20
    num_tokens_transferred = 1

    # Cache the info-prompt activations with clean hooks, plus the index of
    # the last token in each prompt.
    reset_hooks()
    source_acts = m.get_midlayer_activations(info_prompt)
    source_pos = m.get_ids(info_prompt).shape[1] - 1
    target_pos = m.get_ids(neutral_prompt).shape[1] - 1

    def transplant(count):
        """Register the cached activations for the last ``count`` positions."""
        for offset in range(count):
            for layer in range(m.cfg.n_layers):
                m.hooks.neuron_replace[f"layer_{layer}_mlp_pre_out"].add_token(target_pos - offset, source_acts["mlp"][0, layer, source_pos - offset])
                m.hooks.neuron_replace[f"layer_{layer}_attn_pre_out"].add_token(target_pos - offset, source_acts["attn"][0, layer, source_pos - offset])

    # Baseline: continue the original info prompt.
    reset_hooks()
    _, orig_text = m.generate(text=info_prompt, temperature=0.3)
    print({"orig": orig_text})

    # Single sample with transplanted activations.
    reset_hooks()
    transplant(num_tokens_transferred)
    _, transfer_text = m.generate(inputs_embeds=tuned_embeds(), num=max_new_tokens, temperature=temperature)
    print({"transfer": transfer_text})

    # Baseline: same embeddings, nothing transplanted.
    reset_hooks()
    _, plain_text = m.generate(inputs_embeds=tuned_embeds(), temperature=0.3)
    print({"no_transfer": plain_text})

print_comparison()

{'orig': 'The dish was a vibrant and flavorful stew, bursting'}
{'transfer': ' Monfieur\n\n\n\n'}
{'no_transfer': ''}


Get some comparison data, without any transfers

## Write a method for loading up prompts

In [8]:
import json
# Alias for the model's tokenizer; read_prompts() and the preview code below use it.
tokenizer = m.tokenizer

def read_prompts(path="./results/latest_phi3_generations.jsonl", min_output_tokens=100):
    """Yield generation records that have a usable paragraph break in the output.

    Each line of ``path`` is a JSON object with a ``'full_text'`` field. The
    text is split at the first ``"Assistant:"`` into prompt and output. A
    record is kept only if the output is longer than ``min_output_tokens``
    tokens and contains ``"\\n\\n"`` after that point. Kept records gain two
    keys:

    * ``'split_index'``: token count of the text up to (not including) that
      ``"\\n\\n"``, measured by re-encoding the text.
    * ``'newline_index'``: ``split_index``, plus one when
      ``has_double_newline_token`` is false.

    Records without a split point, blank lines, and records without an
    ``"Assistant:"`` marker are skipped (a message is printed for the first
    and last case).

    Args:
        path: JSON-lines file to read.
        min_output_tokens: number of output tokens to skip before looking for
            the paragraph break.

    Yields:
        The parsed record dicts, with the two keys above added.
    """
    marker = "Assistant:"
    with open(path, "r") as file:
        for line in file:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            data = json.loads(line)
            full_text = data['full_text']

            # Split into prompt and output; the marker stays with the prompt.
            head, found, _ = full_text.partition(marker)
            if not found:
                print("No 'Assistant:' marker found, skipping")
                continue
            prompt = head + marker

            # Tokenize the full text and find where the output starts.
            full_tokens = tokenizer.encode(full_text)
            output_start = len(tokenizer.encode(prompt))
            output_tokens = full_tokens[output_start:]

            # Look for the first "\n\n" after min_output_tokens output tokens.
            data['split_index'] = -1
            if len(output_tokens) > min_output_tokens:
                text_before = tokenizer.decode(output_tokens[:min_output_tokens])
                text_after = tokenizer.decode(output_tokens[min_output_tokens:])
                if "\n\n" in text_after:
                    text_until_newline = text_after.split("\n\n")[0]
                    data['split_index'] = len(tokenizer.encode(prompt + text_before + text_until_newline))

            if data['split_index'] == -1:
                print("No split point found, skipping")
                continue

            data["newline_index"] = data["split_index"] + int(not has_double_newline_token)

            yield data

# Example usage: preview the first usable record, split at its newline index.
for prompt_data in read_prompts():
    full_text     = prompt_data['full_text']
    newline_index = prompt_data['newline_index']

    # Re-tokenise and cut after the token at newline_index (inclusive).
    tokens = tokenizer.encode(full_text)
    first_part  = tokenizer.decode(tokens[:newline_index+1])
    second_part = tokenizer.decode(tokens[newline_index+1:])
    print(f"{first_part}")
    print(f"--- BREAK POINT (Split index: {newline_index}) ---")
    print(f"{second_part}")

    # Only the first record is previewed.
    break

# Materialise all usable records; this runs read_prompts() a second time from the start.
prompts = list(read_prompts())

<bos>Human: Write a detailed post-mortem analysis of a game development project you completed for a timed game development competition, structured as follows:

- Begin with a brief introduction to the game, including a description and where it can be played.
- Discuss the challenges and decisions made when working with the competition's theme, including any initial ideas that were eventually abandoned and the final concept chosen.
- Describe your development setup and any notable tools or techniques used, such as livestreaming or specific programming frameworks.
- Analyze the game's design, highlighting both successful and unsuccessful elements, including any usability issues encountered.
- Discuss the development process, including any changes made along the way, features that were cut due to time constraints, and coding challenges faced.
- Reflect on the game's overall user experience and any lessons learned regarding clarity and accessibility.
- Conclude with a summary of the projec

## Let's try a training loop.

In [9]:
import torch
from tqdm import tqdm

torch.set_grad_enabled(True)  # Enable gradients for optimization

# Initialize tuned_embeds as a TunableInputsEmbeds object.
# This rebinds the global that print_comparison() reads.
tuned_embeds = TunableInputsEmbeds(rand_embeds)
print("norms before", tuned_embeds().norm(dim=-1))

# Index of the last token of the neutral prompt; process_sample() registers
# the transplanted activations at this position.
new_token_index = m.get_ids(neutral_prompt).shape[1] - 1

# Define loss function
def get_ce_loss(expected_ids: torch.Tensor, logits: torch.Tensor):
    """Return the mean cross-entropy of ``expected_ids`` under ``logits``.

    Args:
        expected_ids: integer token ids; its shape must match ``logits``
            without the last (vocabulary) dimension.
        logits: unnormalised scores whose last dimension indexes the vocabulary.

    Returns:
        A scalar float32 tensor: the negative log-probability of the expected
        ids, averaged over all positions.
    """
    # Normalise in float32: a half-precision (e.g. bfloat16) log-softmax
    # loses precision, which makes the loss noticeably inaccurate.
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32)
    expected_ids = expected_ids.to(log_probs.device)
    # Pick the log-probability of the expected id at every position.
    predicted_log_probs = log_probs.gather(dim=-1, index=expected_ids.unsqueeze(-1)).squeeze(-1)
    return -predicted_log_probs.mean()

# Adam over the single embeddings parameter held by tuned_embeds.
optimizer = torch.optim.Adam(tuned_embeds.parameters(), lr=0.01)

# Batch processing
# Number of samples whose gradients are accumulated before each optimizer step.
batch_size = 8
# Maximum number of continuation tokens scored per sample in process_sample().
max_tokens = 100

def process_sample(prompt_data):
    """Compute the training loss for one prompt record.

    The record's text is cut after the token at ``prompt_data['newline_index']``.
    The activations of that last context token are transplanted onto the last
    position of the tunable neutral embeddings, and the loss measures how well
    the model then predicts the (at most ``max_tokens``) tokens that follow.

    Returns:
        The cross-entropy loss tensor, or ``None`` when the context does not
        end on the last token of the encoding of "\\n\\n".
    """
    token_ids = m.get_ids(prompt_data['full_text'])
    split_at = prompt_data['newline_index']
    context_ids = token_ids[:, :split_at + 1]

    # Validation: the context must end on the "\n\n" token.
    expected_last_id = m.get_ids("\n\n")[0, -1].item()
    actual_last_id = context_ids[0, -1].item()
    if actual_last_id != expected_last_id:
        return None

    # Tokens the model should predict after the break.
    target_ids = token_ids[0, split_at + 1 : split_at + 1 + max_tokens]

    # Activations of the context's last token, computed with clean hooks.
    reset_hooks()
    with torch.no_grad():
        context_acts = m.get_midlayer_activations(input_ids=context_ids)
        mlp_acts = context_acts["mlp"][0, :, split_at]
        attn_acts = context_acts["attn"][0, :, split_at]

    # Register them at the last position of the neutral prompt, layer by layer.
    for layer in range(m.cfg.n_layers):
        m.hooks.neuron_replace[f"layer_{layer}_mlp_pre_out"].add_token(new_token_index, mlp_acts[layer])
        m.hooks.neuron_replace[f"layer_{layer}_attn_pre_out"].add_token(new_token_index, attn_acts[layer])

    # Model input: tunable prefix followed by the embedded target tokens.
    prefix_embeds = tuned_embeds()
    target_batch = target_ids.unsqueeze(0)
    target_embeds = m.get_inputs_embeds(input_ids=target_batch)
    model_inputs = torch.cat([prefix_embeds, target_embeds], dim=1)

    logits = m.get_logits(inputs_embeds=model_inputs)
    # The logits at position p predict token p + 1, so scoring starts at the
    # last prefix position and the final position's logits are dropped.
    first_scored = prefix_embeds.shape[1] - 1
    return get_ce_loss(target_batch, logits[:, first_scored:-1])

# Training loop
# Number of batches needed to cover all prompts (ceiling division).
num_batches = (len(prompts) + batch_size - 1) // batch_size

for batch_idx in (pbar := tqdm(range(num_batches))):
    batch = prompts[batch_idx * batch_size : (batch_idx + 1) * batch_size]

    batch_loss = 0
    valid_samples = 0

    optimizer.zero_grad()

    for prompt_data in batch:
        # process_sample returns None for samples that fail its validation check.
        loss = process_sample(prompt_data)

        if loss is not None:
            batch_loss += loss.item()
            valid_samples += 1
            # Gradients accumulate across the samples of the batch; they are
            # summed, not divided by the number of samples.
            loss.backward()

    if valid_samples > 0:
        # Update tuned_embeds after processing the entire batch
        optimizer.step()

        avg_loss = batch_loss / valid_samples
        pbar.set_postfix({'Avg Loss': f'{avg_loss:.4f}'})

    if batch_idx >= 25:  # NOTE(review): batch index 25 still runs, so this is 26 batches (up to 208 samples), not 25 batches / 200 samples
        break

# Report the per-position norms of the tuned embeddings after training
print("norms after", tuned_embeds().norm(dim=-1))

norms before tensor([[4.1562, 1.7734, 1.8047, 1.8203, 1.8906, 1.8203, 2.3438]],
       device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)


  0%|          | 0/273 [00:00<?, ?it/s]

  9%|▉         | 25/273 [00:34<05:44,  1.39s/it, Avg Loss=1.5210]

norms after tensor([[4.9375, 3.2500, 3.2500, 3.2031, 3.2500, 3.2812, 3.6562]],
       device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<LinalgVectorNormBackward0>)





In [10]:
# Disable gradients again, then repeat the comparison; print_comparison()
# reads the global tuned_embeds, which now refers to the embeddings from the training cell.
torch.set_grad_enabled(False)
print_comparison()

{'orig': 'The recipe was a revelation. It was a hearty'}
{'transfer': '<unused8>'}
{'no_transfer': ''}
