In [None]:
#!pip install -qqq circuitsvis && pip install -qqq -U torch sentence-transformers
from datetime import datetime
import json
from os import listdir
from os.path import exists

import numpy as np
import torch
import circuitsvis as cv
from taker import Model
from taker.hooks import HookConfig

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import PCA
import torch
from tqdm import tqdm

torch.set_grad_enabled(False)

## Load in the Model and Prompt

In [None]:
# Alternative models that were tried; exactly one `Model(...)` line should be active.
# m = Model("nickypro/tinyllama-15m", dtype="bfp16", compile=False, model_device="cuda")
m = Model("google/gemma-2b-it", dtype="hqq8", compile=True)
# m = Model("microsoft/phi-3-mini-4k-instruct", dtype="hqq8", compile=True)
m.show_details()

# True when "\n" and "\n\n" encode to id tensors of the same shape
# (presumably meaning "\n\n" is a single token for this tokenizer - TODO confirm).
has_double_newline_token = (m.get_ids("\n").shape == m.get_ids("\n\n").shape)
# Id of the last token produced when encoding "\n\n".
newline_token_id = m.get_ids("\n\n")[0, -1].item()

In [None]:
# Instruction text that is prepended to the hand-written story continuation.
prompt    = """Write a short blog post about a recipe and the inspiration behind it.
 Do not include a title.
 Only reveal the dish after the story.
 Start with short story and then move to the recipe.
 To re-iterate, do not include a title."""
# Earlier approach (disabled): sample the continuation from the model instead of hard-coding it.
# NOTE(review): `info_prompt` is not defined yet at this point, so this line would fail if re-enabled as-is.
# info_gen = m.generate(info_prompt, temperature=0.3, num=300)

# Hand-written story continuation; it deliberately ends with a double newline ("\n\n").
info_gen = """
\n Once upon a time, in a quaint little village nestled between rolling hills and verdant fields, there lived an elderly woman named Agnes. Agnes was known for her warm smile and her legendary Sunday dinners that brought the entire neighborhood together. Her recipes were family heirlooms, passed down through generations, with each family adding their own touch to the final dish.

One crisp autumn evening, Agnes was reminiscing about her childhood, and how her grandmother used to gather everyone around the dinner table, sharing stories and laughter. These were the moments that shaped her, the memories that she passed on to her own children and grandchildren.

Inspired by her grandmother's legacy, Agnes decided to create a new dish that would encapsulate the essence of those cherished gatherings. She wanted something that was comforting and nourishing, a dish that could be prepared with love and shared with others. After days of experimentation, she finally created a recipe that she believed truly captured the spirit of her family's Sunday dinners.\n\n"""
# Full "information-carrying" prompt: the instruction followed by the story so far.
info_prompt = prompt+info_gen

# Token ids and input embeddings of the full prompt, as returned by the taker Model helpers.
info_ids = m.get_ids(info_prompt)
info_embeds = m.get_inputs_embeds(info_prompt)

We want a neutral prompt to extract information from the transferred activations. We try both
a randomised (scrambled) prompt and a fine-tuned prompt, and see which one works better.

In [None]:
torch.set_grad_enabled(False)

# Neutral Prompt
neutral_prompt = "------------------------------------------------\n\n"

# Tokenise the neutral prompt and print its tokens so the positions used for
# scrambling / transplanting below can be checked by eye.
neutral_ids = m.get_ids(neutral_prompt)
print(m.tokenizer.convert_ids_to_tokens(neutral_ids[0].tolist()))
# Input embeddings for exactly those token ids.
neutral_embeds = m.get_inputs_embeds(input_ids=neutral_ids)


# Random inputs embeds
def make_rand_embeds(neutral_embeds, start=1, end=2):
    """Return a copy of `neutral_embeds` with positions [start, end) randomised.

    The selected positions of batch element 0 are overwritten with Gaussian
    noise divided by sqrt(m.cfg.d_model) and then multiplied by the mean
    L2 norm of the embeddings they replace. `neutral_embeds` itself is not
    modified.

    Args:
        neutral_embeds: embeddings tensor indexed as [batch, position, feature].
        start: first position to randomise (inclusive).
        end: position at which to stop (exclusive).

    Returns:
        A new tensor with the same shape as `neutral_embeds`.
    """
    window = slice(start, end)
    reference = neutral_embeds[0, window]

    # Mean per-position norm of the embeddings being replaced.
    target_norm = reference.norm(dim=-1).mean()
    noise = torch.randn_like(reference) / (m.cfg.d_model**0.5)

    scrambled = neutral_embeds.clone()
    scrambled[0, window] = noise
    scrambled[0, window] *= target_norm
    return scrambled

# Position ranges that were tried for different models; only one line should be active.
# rand_embeds = make_rand_embeds(neutral_embeds, 1, 6) # gemma
# rand_embeds = make_rand_embeds(neutral_embeds, 0, 8) # phi3
rand_embeds = make_rand_embeds(neutral_embeds, 1, 4) # gemma ---

# Sanity check on norms: the per-position norms of the randomised embeddings
# should be of similar magnitude to those of the neutral embeddings.
print(neutral_embeds.norm(dim=-1).cpu().float().numpy())
print(rand_embeds.norm(dim=-1).cpu().float().numpy())

In [None]:
# Utils
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

def write_to_file(experiment_name, data):
    """Append one JSON record to this run's result file and to the "latest" file.

    Two JSON-lines files are maintained under ./results:
      * ``{current_time}_story_agnes_{experiment_name}.jsonl`` - this run only
      * ``latest_story_agnes_{experiment_name}.jsonl`` - mirror of the newest run

    On the first write of a run (the timestamped file does not exist yet) both
    files are truncated, so the "latest" file never mixes records of two runs.

    Args:
        experiment_name: short label embedded in both file names.
        data: JSON-serialisable object (e.g. ``GenData().__dict__``).
    """
    import os  # local import: the notebook only imports individual names from `os`

    results_dir = "./results"
    # Without this, the first open() raises FileNotFoundError on a fresh checkout.
    os.makedirs(results_dir, exist_ok=True)

    filename = f"{results_dir}/{current_time}_story_agnes_{experiment_name}.jsonl"
    filename_latest = f"{results_dir}/latest_story_agnes_{experiment_name}.jsonl"
    targets = (filename, filename_latest)

    if not exists(filename):
        for _filename in targets:
            with open(_filename, "w") as f:
                f.write("")

    record = json.dumps(data) + "\n"
    for _filename in targets:
        with open(_filename, "a") as file:
            file.write(record)

def read_file(experiment_name, time="latest"):
    """Load the records written by `write_to_file` for one experiment.

    Args:
        experiment_name: the label that was passed to `write_to_file`.
        time: file-name prefix - "latest" (default) or a run timestamp in the
            "%Y-%m-%d_%H-%M-%S" format used for `current_time`.

    Returns:
        pandas.DataFrame with one row per JSON line in the file.
    """
    # The name pattern must match write_to_file ("story_agnes", not "agnes_story").
    filepath = f"./results/{time}_story_agnes_{experiment_name}.jsonl"
    return pd.read_json(filepath, lines=True)

def reset_hooks():
    """Reset every `neuron_replace` hook registered on the model `m`.

    Call this before transplanting the next set of activations so that
    replacements from a previous transfer do not carry over.
    """
    for hook in m.hooks.neuron_replace.values():
        hook.reset()

# Make tunable embed parameters:
class TunableInputsEmbeds(torch.nn.Module):
    """Wrap an input-embeddings tensor as a trainable parameter.

    The given tensor is copied, so optimising the parameter never mutates the
    tensor the caller passed in (e.g. a shared `rand_embeds` that is reused to
    initialise several instances).

    Attributes:
        embeds: trainable `torch.nn.Parameter` initialised from `inputs_embeds`.
        shape: shape of `embeds`, cached for convenient slicing by callers.
    """

    def __init__(self, inputs_embeds):
        super().__init__()
        # torch.nn.Parameter(t) shares storage with `t`; detach + clone gives
        # the module its own leaf tensor with the same dtype and device.
        self.embeds = torch.nn.Parameter(inputs_embeds.detach().clone())
        self.shape = self.embeds.shape

    def forward(self):
        """Return the current embeddings parameter."""
        return self.embeds

## Try generating some things already

Get transferred activations and make some plots

In [None]:
# Get original text activations
acts = m.get_midlayer_activations(info_prompt)
orig_token_index = m.get_ids(info_prompt).shape[1] - 1
new_token_index  = m.get_ids(neutral_prompt).shape[1] - 1
print(orig_token_index, new_token_index)
print(m.tokenizer.convert_ids_to_tokens(m.get_ids(info_prompt).squeeze().tolist()[-neutral_ids.shape[1]:]))
print(m.tokenizer.convert_ids_to_tokens(m.get_ids(neutral_prompt).squeeze().tolist()))

def transfer_activations(num_tokens_transferred=1):
    """Register info-prompt activations on the neutral prompt's final tokens.

    All replacement hooks are reset first. Then, for each of the last
    `num_tokens_transferred` positions (counting back from the final token),
    the MLP and attention activations captured in the global `acts` are added
    to the matching `neuron_replace` hook of every layer.

    Args:
        num_tokens_transferred: how many trailing token positions to transplant.
    """
    reset_hooks()
    replace_hooks = m.hooks.neuron_replace
    for offset in range(num_tokens_transferred):
        src_index = orig_token_index - offset
        dst_index = new_token_index - offset
        for layer_index in range(m.cfg.n_layers):
            # Same order as before: MLP first, then attention.
            for kind in ("mlp", "attn"):
                hook = replace_hooks[f"layer_{layer_index}_{kind}_pre_out"]
                hook.add_token(dst_index, acts[kind][0, layer_index, src_index])

# Input parameters
from dataclasses import dataclass
model_repo = "google/gemma-2b-it"
@dataclass
class GenData:
    """Settings and output of one generation run; its `__dict__` is what gets logged.

    Defaults are evaluated once, when this cell runs, from the globals `m`,
    `neutral_prompt` and `info_prompt`.
    """
    # Identifier of the loaded model, taken from the taker Model object.
    model_repo: str = m.model_repo
    # Sampling temperature passed to m.generate.
    temperature: float = 0.3
    # Number of new tokens requested from m.generate (its `num` argument).
    max_new_tokens: int = 20
    # NOTE(review): not read anywhere in the visible notebook; looks like a
    # duplicate of `num_tokens_transferred` - confirm before removing.
    tokens_transferred_num: int = 1
    # NOTE(review): not read anywhere in the visible notebook; presumably the
    # layer range the transplant is meant to cover - confirm.
    transplant_layers: tuple = (0,32)
    # Number of trailing token positions whose activations are transplanted.
    num_tokens_transferred: int = 1
    # Generated text; filled in after each generation, before the record is written.
    output: str = ""
    # Prompt actually fed to the model for this record.
    curr_prompt: str = neutral_prompt
    # Prompt the transplanted activations were taken from.
    orig_prompt: str = info_prompt


def generate_texts(data: GenData, tuned_embeds=None):
    """Generate and log samples for the transfer, original and neutral settings.

    Three experiments are written with `write_to_file`:
      * "transfer-x1": 25 samples (5 embeddings x 5 samples each) generated
        from input embeddings while the info-prompt activations are
        transplanted onto the neutral prompt.
      * "orig": 25 samples generated from `info_prompt`, with hooks reset.
      * "neutral": 25 samples generated from `neutral_prompt`, with hooks reset.

    Args:
        data: generation settings (temperature, max_new_tokens,
            num_tokens_transferred, ...). The object is copied, never mutated.
            Previously this argument was silently replaced by `GenData()`;
            passing `GenData()` reproduces the old behaviour exactly.
        tuned_embeds: optional callable module returning input embeddings
            (e.g. a trained `TunableInputsEmbeds`). When None, fresh random
            embeddings are drawn for each group of 5 samples.
    """
    from dataclasses import replace  # local import: only `dataclass` is imported at file level

    def _log_text_baseline(experiment_name, prompt_text, num_samples=25):
        """Sample from a plain text prompt with all replacement hooks reset."""
        run = replace(data, curr_prompt=prompt_text, output="")
        reset_hooks()
        for _ in range(num_samples):
            _, text_out = m.generate(text=run.curr_prompt, num=run.max_new_tokens, temperature=run.temperature)
            run.output = text_out
            write_to_file(experiment_name, run.__dict__)

    # Show which token the activations come from and which token receives them.
    print({"orig token": m.tokenizer.decode(m.get_ids(info_prompt)[0, orig_token_index]), "transfer token": m.tokenizer.decode(m.get_ids(neutral_prompt)[0, new_token_index])})

    # Generation with transplanted activations.
    transfer_run = replace(data)
    transfer_activations(transfer_run.num_tokens_transferred)
    for _ in range(5):
        if tuned_embeds is None:
            embeds = TunableInputsEmbeds(make_rand_embeds(neutral_embeds))
        else:
            embeds = tuned_embeds
        for _ in range(5):
            _, text_out = m.generate(inputs_embeds=embeds(), num=transfer_run.max_new_tokens, temperature=transfer_run.temperature)
            transfer_run.output = text_out
            write_to_file("transfer-x1", transfer_run.__dict__)

    print("Generating texts from original info prompt...")
    _log_text_baseline("orig", info_prompt)

    print("Generating texts from neutral prompt...")
    _log_text_baseline("neutral", neutral_prompt)

# generate_texts(GenData())

In [None]:
# Disable gradient tracking globally: this cell only generates text.
torch.set_grad_enabled(False)
# Untrained starting point: tunable embeddings initialised from the random embeddings.
tuned_embeds = TunableInputsEmbeds(rand_embeds)

def generate(m, inputs_embeds, max_length=12, temperature=0.3):
    """Sample a continuation from input embeddings via `m.predictor.generate`.

    Args:
        m: model wrapper exposing `predictor.generate` and
            `tokenizer.batch_decode`.
        inputs_embeds: embeddings forwarded as `inputs_embeds`.
        max_length: forwarded to `predictor.generate` (default 12, the value
            that used to be hard-coded).
        temperature: sampling temperature (default 0.3, as before).

    Returns:
        Tuple `("", text)`: the first element is always empty because there is
        no text prompt; `text` is the decoded first sequence of the batch.
    """
    generate_ids = m.predictor.generate(
        inputs_embeds=inputs_embeds,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
    )
    decoded = m.tokenizer.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return "", decoded[0]

def print_comparison():
    """Print one sample each for the original, transfer and no-transfer settings.

    Prints, in order:
      * {"orig": ...}        - m.generate on `info_prompt` with hooks reset
      * {"transfer": ...}    - generation from the global `tuned_embeds` while
                               the info-prompt activations are transplanted
      * {"no_transfer": ...} - generation from `tuned_embeds` with hooks reset

    The global `tuned_embeds` is looked up at call time, so the result depends
    on which cell last assigned it (untrained or trained).
    """

    # Get original text activations
    reset_hooks()
    acts = m.get_midlayer_activations(info_prompt)
    # Last-token positions of the source (info) and destination (neutral) prompts.
    orig_token_index = m.get_ids(info_prompt).shape[1] - 1
    new_token_index  = m.get_ids(neutral_prompt).shape[1] - 1
    print({"orig token": m.tokenizer.decode(m.get_ids(info_prompt)[0, orig_token_index]), "transfer token": m.tokenizer.decode(m.get_ids(neutral_prompt)[0, new_token_index])})

    # Local variant of the module-level transfer_activations: it uses the
    # activations captured above and does NOT reset hooks itself - the caller
    # resets them before each generation below.
    def transfer_activations(num_tokens_transferred=1):
        for j in range(num_tokens_transferred):
            for layer_index in range(m.cfg.n_layers):
                m.hooks.neuron_replace[f"layer_{layer_index}_mlp_pre_out"].add_token(new_token_index - j, 1*acts["mlp"][0, layer_index, orig_token_index - j])
                m.hooks.neuron_replace[f"layer_{layer_index}_attn_pre_out"].add_token(new_token_index - j, 1*acts["attn"][0, layer_index, orig_token_index - j])

    # Local settings container; within this function only
    # `num_tokens_transferred` is actually read.
    @dataclass
    class GenData:
        model_repo: str = m.model_repo
        temperature: float = 0.3
        max_new_tokens: int = 20
        num_tokens_transferred: int = 1
        output: str = ""
        curr_prompt: str = neutral_prompt
        orig_prompt: str = info_prompt

    # Install the transplant hooks, then sample from the tuned embeddings.
    def generate_text_with_tuned_embeds(data: GenData):
        transfer_activations(data.num_tokens_transferred)
        print(tuned_embeds().shape)
        text_in, text_out = generate(m, tuned_embeds())
        # text_in, text_out = m.generate(inputs_embeds=tuned_embeds(), num=data.max_new_tokens, temperature=data.temperature)
        return text_in, text_out

    # For comparison, generate text from the original info prompt
    reset_hooks()
    text_in, text_out = m.generate(text=info_prompt, temperature=0.3)
    print({"orig": text_out})

    # Generate a single sample
    reset_hooks()
    text_in, text_out = generate_text_with_tuned_embeds(GenData())
    print({"transfer": text_out})


    # For comparison, generate from the tuned embeddings without any transplant
    reset_hooks()
    text_in, text_out = generate(m, tuned_embeds())
    print({"no_transfer": text_out})

print_comparison()

Get some comparison data, without any transfers

## Write a method for loading up prompts

In [None]:
import json
# Initialize the tokenizer
tokenizer = m.tokenizer

def _find_split_index(full_text, marker="Assistant:", min_output_tokens=100):
    """Return the token count of `full_text` up to the first "\\n\\n" that occurs
    at least `min_output_tokens` tokens into the assistant output, or -1.

    -1 is returned when the marker is missing, the output has no more than
    `min_output_tokens` tokens, or no "\\n\\n" follows that point.
    """
    if marker not in full_text:
        return -1

    # Everything up to and including the marker is the prompt.
    prompt = full_text.split(marker, 1)[0] + marker

    full_tokens = m.tokenizer.encode(full_text)
    output_start = len(m.tokenizer.encode(prompt))
    output_tokens = full_tokens[output_start:]
    if len(output_tokens) <= min_output_tokens:
        return -1

    head_text = m.tokenizer.decode(output_tokens[:min_output_tokens])
    tail_text = m.tokenizer.decode(output_tokens[min_output_tokens:])
    tail_until_break = tail_text.split("\n\n")[0]
    if tail_until_break == tail_text:
        # No paragraph break after the first `min_output_tokens` output tokens.
        return -1

    # Re-encode the text up to (not including) the break; its length is the
    # index at which the break token is expected in the full encoding.
    return len(m.tokenizer.encode(prompt + head_text + tail_until_break))


def read_prompts(path="./results/latest_phi3_generations.jsonl"):
    """Yield usable prompt records from a JSON-lines file of generations.

    Each line must be a JSON object with a 'full_text' field. A record is
    usable when its text contains "Assistant:" and the assistant output has a
    "\\n\\n" at least 100 tokens in. Blank lines and unusable records are
    skipped and counted in the progress-bar description instead of raising.

    Args:
        path: JSON-lines file to read.

    Yields:
        The parsed record with two extra integer keys, 'split_index' and
        'newline_index', both holding the token index of the paragraph break.
    """
    with open(path, "r") as file:
        invalid_count = 0
        for line in (pbar := tqdm(file)):
            split_index = -1
            if line.strip():
                data = json.loads(line)
                split_index = _find_split_index(data['full_text'])

            if split_index == -1:
                invalid_count += 1
                # update pbar with invalid count.
                pbar.set_description(f"Invalid prompts: {invalid_count}")
                continue

            data['split_index'] = split_index
            data["newline_index"] = split_index
            yield data


# Example usage: split the first valid record at its break index
# (uncomment the prints to inspect the two halves).
# NOTE(review): `prompt_data` stays defined after this loop and is reused by later cells.
for prompt_data in read_prompts():
    full_text     = prompt_data['full_text']
    newline_index = prompt_data['newline_index']

    tokens = tokenizer.encode(full_text)
    # Everything up to and including the token at `newline_index`, and the rest.
    first_part  = tokenizer.decode(tokens[:newline_index+1])
    second_part = tokenizer.decode(tokens[newline_index+1:])
    # print(f"{first_part}")
    # print(f"--- BREAK POINT (Split index: {newline_index}) ---")
    # print(f"{second_part}")

    break

# Materialise every valid record; this re-reads and re-tokenises the whole file.
prompts = list(read_prompts())

## Let's try a training loop.

In [None]:
import torch
from tqdm import tqdm
import torch.nn.functional as F

torch.set_grad_enabled(True)  # Enable gradients for optimization

# Initialize tuned_embeds as a TunableInputsEmbeds object
# (this rebinds the global that print_comparison reads at call time).
tuned_embeds = TunableInputsEmbeds(rand_embeds)
print("norms before", tuned_embeds().norm(dim=-1))

# Position in the neutral prompt (its last token) that receives transplanted activations.
new_token_index = m.get_ids(neutral_prompt).shape[1] - 1

# Define KL divergence loss function
def kl_divergence_loss(baseline_logits, output_logits):
    """Summed KL divergence KL(baseline || output) between two logit tensors.

    Both inputs are normalised over their last dimension; the pointwise
    divergence terms are then summed over every element (no averaging).

    Args:
        baseline_logits: logits defining the target distribution.
        output_logits: logits of the distribution being trained; same shape.

    Returns:
        Scalar tensor holding the total divergence.
    """
    target_probs = torch.softmax(baseline_logits, dim=-1)
    predicted_log_probs = torch.log_softmax(output_logits, dim=-1)
    pointwise = F.kl_div(
        input=predicted_log_probs,
        target=target_probs,
        reduction='none',
        log_target=False,
    )
    return torch.sum(pointwise)

# Only the tunable embeddings are optimised.
optimizer = torch.optim.Adam(tuned_embeds.parameters(), lr=0.001)

# Batch processing
batch_size = 10       # prompts whose losses are summed before each optimizer step
max_tokens = 100      # maximum number of continuation tokens compared per prompt
invalid_samples = 0   # running count of prompts skipped by process_sample

def process_sample(prompt_data):
    """Compute the KL loss for one prompt record, or None if it is unusable.

    The record's text is cut after the token at `newline_index`. Activations
    at that token are transplanted onto the last token of the tuned neutral
    prompt, and the model's next-token distributions over the following (up
    to `max_tokens`) tokens are compared against those of the unmodified
    model on the full text.

    Fixes relative to the previous version:
      * The baseline logits are computed *before* the replacement hooks are
        installed, so they come from the unmodified model.
      * The baseline slice is shifted by one so that both logit tensors
        predict the same tokens (assuming `m.get_logits` returns standard
        causal-LM logits, where position i predicts token i+1).

    Args:
        prompt_data: dict with 'full_text' and 'newline_index', as yielded by
            `read_prompts`.

    Returns:
        Scalar loss tensor, or None when the token at `newline_index` is not
        the "\\n\\n" token.
    """
    full_ids = m.get_ids(prompt_data['full_text'])
    orig_newline_index = prompt_data['newline_index']
    ids_prompt = full_ids[:, :orig_newline_index+1]

    # Validation check: the cut must land exactly on the double-newline token.
    newline_token_id = m.get_ids("\n\n")[0, -1].item()
    if ids_prompt[0, -1].item() != newline_token_id:
        return None

    # Ensure we don't exceed the available tokens
    available_tokens = full_ids.shape[1] - (orig_newline_index + 1)
    tokens_to_use = min(available_tokens, max_tokens)

    o_start = orig_newline_index + 1
    o_end = o_start + tokens_to_use
    baseline_ids = full_ids[:, o_start:o_end]

    # Clean forward passes: source activations and baseline logits.
    reset_hooks()
    with torch.no_grad():
        acts = m.get_midlayer_activations(input_ids=ids_prompt)
        orig_acts = {
            "mlp": acts["mlp"][0, :, orig_newline_index],
            "attn": acts["attn"][0, :, orig_newline_index]
        }
        # Predictions for tokens o_start..o_end-1 sit at positions
        # o_start-1..o_end-2, i.e. starting at the newline token itself.
        baseline_logits = m.get_logits(input_ids=full_ids)[:, o_start-1:o_end-1]

    # Transfer activations onto the last neutral-prompt token.
    for layer_index in range(m.cfg.n_layers):
        m.hooks.neuron_replace[f"layer_{layer_index}_mlp_pre_out"].add_token(new_token_index, orig_acts["mlp"][layer_index])
        m.hooks.neuron_replace[f"layer_{layer_index}_attn_pre_out"].add_token(new_token_index, orig_acts["attn"][layer_index])

    # Forward pass for output: tuned prefix followed by the true continuation.
    # The last prefix position is detached, so gradients only reach the
    # positions before it.
    expected_tokens = m.get_inputs_embeds(input_ids=baseline_ids)
    prefix_embeds = tuned_embeds()
    inputs_embeds = torch.cat([prefix_embeds[:, :-1], prefix_embeds[:, -1:].detach(), expected_tokens], dim=1)
    # Positions from the last prefix token up to the second-to-last input.
    output_logits = m.get_logits(inputs_embeds=inputs_embeds)[:, tuned_embeds.shape[1]-1:-1]

    return kl_divergence_loss(baseline_logits, output_logits)
# Training loop
# Number of batches needed to cover all prompts (ceiling division).
num_batches = (len(prompts) + batch_size - 1) // batch_size

for batch_idx in (pbar := tqdm(range(num_batches))):
    batch = prompts[batch_idx * batch_size : (batch_idx + 1) * batch_size]

    # Starts as the int 0 and becomes a tensor once the first valid loss is added.
    batch_loss = 0 #torch.tensor(0.0, device=tuned_embeds.embeds.device, requires_grad=True)
    valid_samples = 0

    optimizer.zero_grad()

    for prompt_data in batch:
        loss = process_sample(prompt_data)

        if loss is not None:
            # Losses are summed, so the autograd graphs of all samples in the
            # batch are kept alive until backward() below.
            batch_loss = batch_loss + loss
            valid_samples += 1
        else:
            invalid_samples += 1

    # NOTE(review): this also skips backward when the summed loss is exactly 0,
    # negative or NaN, while optimizer.step() below still runs - confirm intent.
    if batch_loss > 0:
        batch_loss.backward()

    if valid_samples > 0:
        # Update tuned_embeds after processing the entire batch
        optimizer.step()

        avg_loss = batch_loss.item() / valid_samples
        pbar.set_postfix({'Avg KL Div': f'{avg_loss:.4f}', 'Skipped': f'{invalid_samples}'})

    # Hard cap: stop after batch index 100 (at most 101 batches) even if more prompts remain.
    if batch_idx >= 100:
        break

# Report the per-position embedding norms after training, for comparison with "norms before".
print("norms after", tuned_embeds().norm(dim=-1))

In [None]:
torch.set_grad_enabled(False)
print_comparison()

In [None]:
# Inspect the reference continuation for the most recently bound `prompt_data`.
# NOTE(review): `prompt_data` comes from whichever loop ran last in an earlier
# cell (the example-usage loop or the training loop), so this cell depends on execution order.
full_ids = m.get_ids(prompt_data['full_text'])
orig_newline_index = prompt_data['newline_index']
ids_prompt = full_ids[:, :orig_newline_index+1]
# Up to 100 token ids that follow the break token, then the same span decoded to text.
print(full_ids[:, orig_newline_index+1:orig_newline_index+1+100])
print({"out": m.tokenizer.decode( full_ids[0, orig_newline_index+1:orig_newline_index+1+100].tolist() )})

In [None]:
print_comparison()