The first experiment that I want to do involves the most simple identification of facts within content-free tokens

In [9]:
import torch
import math
from nnsight import CONFIG
from nnsight import LanguageModel
import nnsight
import numpy as np
import matplotlib.pyplot as plt
import os
from dotenv import load_dotenv
import random

load_dotenv()

True

In [10]:
# importing from my own code 
from activation_transplanting import *

In [11]:
# read the api_key
CONFIG.set_default_api_key(os.environ.get('NDIF_KEY'))

# read the hf token
os.environ['HF_TOKEN'] = os.environ.get('HF_TOKEN')

In [12]:
NDIF_models = [
    "meta-llama/Meta-Llama-3.1-405B-Instruct",
    "meta-llama/Meta-Llama-3.1-8B",
    "meta-llama/Meta-Llama-3.1-70B",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
] 

# inexaustive list
non_NDIF_models = [
    "meta-llama/Meta-Llama-3.1-8B",
]

In [13]:
# Example prompts

# instruct examples
prompt_example_1 = "<|begin▁of▁sentence|>\n" \
         "<|start_header_id|>user<|end_header_id|>\n\n" \
         "Hello, how are you? <|eot_id|>\n" \
         "<|start_header_id|>assistant<|end_header_id|>\n"

prompt_example_2 = "<|start_header_id|>system<|end_header_id|>\n\n<|eot_id|>\n" \
                "<|start_header_id|>user<|end_header_id|>\n\n" \
                "Answer the following in one word: What is the tallest mountain in the world?<|eot_id|>\n" \
                "<|start_header_id|>assistant<|end_header_id|>"

# Base model examples 
prompt_example_3 = "\nUser: What's the capital of France?\n\nAssistant:"

# Reasoning examples 
prompt_example_4 = "<｜User｜>Robert has three apples, and then gets one more. How many apples does he have? Respond in a single word.<｜Assistant｜>"

# Numbers Experiment 1

We'll be simply trying to identify the presence of stored numbers at particular tokens


In [14]:
def generate_random_simple_number_string(num, mode='base'):
    prefix = "Word Problem Setup:"
    
    # Define possible components for the problem
    intros = ["A man has ", "A boy has ", "Steven has ", "Robert has ", 
             "A woman has ", "A girl has ", "Sarah has ", "Emily has ",
             "Alex has ", "Jordan has ", "Taylor has ", "Sam has "]
    
    numbers = ["one", "two", "three", "four", "five", 
              "six", "seven", "eight", "nine", "ten"]
    
    objects = [("apple", "apples"), ("banana", "bananas"), ("orange", "oranges"),
              ("peach", "peaches"), ("pear", "pears"), ("grape", "grapes"),
              ("strawberry", "strawberries"), ("blueberry", "blueberries"),
              ("mango", "mangoes"), ("kiwi", "kiwis"), ("plum", "plums")]
    
    suffixes = [
        "when he leaves the store", "when he leaves the shop", 
        "when he leaves the grocery store", "when he leaves the market",
        "when she leaves the store", "when she leaves the shop",
        "when she leaves the grocery store", "when she leaves the market",
        "after shopping", "after grocery shopping", "after visiting the supermarket"
    ]
    
    ending = ".\n\n"
    
    # Choose random components
    intro = random.choice(intros)
    
    # Select appropriate number word and object form based on num
    if 1 <= num <= 10:
        number_word = numbers[num-1]
        obj = random.choice(objects)
        # Use singular or plural form based on num
        object_word = obj[0] if num == 1 else obj[1]
    else:
        # For numbers > 10, just use the numeric form
        number_word = str(num)
        obj = random.choice(objects)
        object_word = obj[1]  # Always use plural
    
    suffix = random.choice(suffixes)
    
    # Make pronoun in suffix match the intro person's implied gender
    if ("man" in intro or "boy" in intro or "Steven" in intro or "Robert" in intro) and "she" in suffix:
        suffix = suffix.replace("she", "he")
    elif ("woman" in intro or "girl" in intro or "Sarah" in intro or "Emily" in intro) and "he" in suffix:
        suffix = suffix.replace("he", "she")
    
    # Assemble the full problem
    
    if mode=='base':
        problem = f"{prefix} {intro}{number_word} {object_word} {suffix}{ending}"
    elif mode == 'instruct':
        problem = f"<|start_header_id|>user<|end_header_id|>\n\n{prefix} {intro}{number_word} {object_word} {suffix}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n"
    
    return problem, obj[1]


def extract_final_logits(
        tk,
        source_strings: list[str],
        target_strings: list[str],
        target_substring: str,
        occurrence_index: int = 0,
        num_prev: int = 0,
        num_fut: int = 0,
        transplant_strings: tuple[str] = ("residual"),
    ) -> list[str]:
        """
        extract the logits produced at the final position of target_strings
        """
        assert num_prev >= 0
        assert num_fut >= 0

        # Extract newline activations from source strings
        activation_containers, source_newline_indices = (
            tk.extract_newline_activations(
                strings=source_strings,
                target_substring=target_substring,
                occurrence_index=occurrence_index,
                transplant_strings=transplant_strings,
                num_prev=num_prev,
                num_fut=num_fut,
            )
        )
        
        print("source_newline_indices", source_newline_indices)
        output_logits = []

        # Process each target string with corresponding source activations
        for target_string, activation_container, source_newline_index in zip(
            target_strings, activation_containers, source_newline_indices
        ):
            print("source_newline_index", source_newline_index)
            act = activation_container.get_token_by_index(
                source_newline_index
            )
            print(act)
            
            print(vars(activation_container))

            final_logits = tk.evaluate_with_transplanted_activity(
                target_string=target_string,
                target_substring=target_substring,
                activation_container=activation_container,
                source_token_index=source_newline_index,
                occurrence_index=occurrence_index,
                transplant_strings=transplant_strings,
                num_prev=num_prev,
                num_fut=num_fut,
            )
            output_logits.append(final_logits)
        
        return output_logits

def predict_number_probs(logits, llama):
    logit_values = []
    # add a prefix t get the token in context
    prefix = ".\n\nThey have"
    numbers = [' one', ' two', ' three', ' four', ' five', ' six', ' seven', ' eight', ' nine', ' ten']
    
    for i, n in enumerate(numbers):
        idx = llama.tokenizer.encode(prefix+n)[-1]
        assert idx is not None
        logit_values.append(logits[idx].float())
    
    all_probs = torch.nn.functional.softmax(torch.tensor(logit_values), dim=0)

    # Sum probabilities for word forms and digit forms
    return all_probs#[:10]+all_probs[10:]

def evaluate_number_probs(strings, items, tk, target_substring=".\n\n"):
    """ 
    for each string, we'll evaluate the probabilities of numbers
    """
    question_strings = [f"{target_substring} Therefore, the total number of {item} they currently had was" for item, s in zip(items, strings)]

    for s, q in zip(strings, question_strings):
        print(s, q)

    final_logits = extract_final_logits(
        tk,
        source_strings=strings,
        target_strings=question_strings,
        target_substring=target_substring,
        occurrence_index= -1,
        num_prev = 0,
        num_fut = 0,
        transplant_strings= ("residual"),
    )
    
    extracted_probs = []
    for logits in final_logits:
        extracted_probs.append(predict_number_probs(logits, tk.llama))
    
    return extracted_probs



Now let's run the experiment

In [15]:
def run_simple_number_experiment(tk, num, number_samples, target_substring=".\n\n", mode='base'):
    """ 
    Choose a single num to use to generate sentences with 
    then generate number_samples with it

    for each run evaluate_number_probs(strings, tk, target_substring)
    to see the probability distribution over next numbers
    """
    string_samples, items = zip(*[
        generate_random_simple_number_string(num, mode=mode) for _ in range(number_samples)
    ])

    extracted_probs = evaluate_number_probs(string_samples, items, tk, target_substring=target_substring)

    # now average over each 
    tot = 0
    out=None
    for p in extracted_probs:
        if out is None:
            out=p 
        else:
            out+=p
        
        tot+=1
    
    return out.numpy()/tot


In [16]:
# choose a model 
llama_model_string = "meta-llama/Meta-Llama-3.1-70B"
# remote = use NDIF
remote = True 

if remote and (llama_model_string not in NDIF_models):
    remote = False 
    print("Model not available on NDIF")

# load a model
llama = LanguageModel(llama_model_string)

# commented out for now
tk = LLamaExamineToolkit(
    llama_model=llama, 
    remote=True, # use NDIF
)
run_simple_number_experiment(tk, num=10, number_samples=1,target_substring='.\n\n')

Word Problem Setup: A man has ten apples when he leaves the market.

 .

 Therefore, the total number of apples they currently had was
extracting token activations


ConnectionError: Connection error

In [None]:
# choose a model 
llama_model_string = "meta-llama/Meta-Llama-3.1-405B-Instruct"
# remote = use NDIF
remote = True 

if remote and (llama_model_string not in NDIF_models):
    remote = False 
    print("Model not available on NDIF")

# load a model
llama = LanguageModel(llama_model_string)

# commented out for now
tk = LLamaExamineToolkit(
    llama_model=llama, 
    remote=True, # use NDIF
)

run_simple_number_experiment(tk, num=10, number_samples=1,target_substring='<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n', mode='instruct')

<|start_header_id|>user<|end_header_id|>

Word Problem Setup: Sam has ten kiwis when he leaves the market<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
 <|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
 Therefore, the total number of kiwis they currently had was
extracting token activations


2025-03-14 22:09:34,620 430a079e-a839-4a72-9a27-a012736f91d4 - RECEIVED: Your job has been received and is waiting approval.
2025-03-14 22:09:35,129 430a079e-a839-4a72-9a27-a012736f91d4 - APPROVED: Your job was approved and is waiting to be run.


In [25]:
generate_random_simple_number_string(10)

'Word Problem Setup: A woman has ten oranges when he leaves the market.\n\n'

In [None]:
with llama.

In [9]:
with llama.generate(
            'Word Problem Setup: A woman has ten oranges when he leaves the market.\n\n The total number of oranges was ',
            max_new_tokens=20,
            remote=True,
        ) as tracer:
        out = llama.generator.output.save()
    

ConnectionError: Internal Server Error

: 

In [None]:
prompt = 'The Eiffel Tower is in the city of'
n_new_tokens = 3
with llama.generate(prompt, max_new_tokens=n_new_tokens, remote=True) as tracer:
    out = llama.generator.output.save()

2025-03-14 17:04:25,407 003a8792-56b6-4057-9a44-6d46f24a10b7 - RECEIVED: Your job has been received and is waiting approval.
2025-03-14 17:04:25,927 003a8792-56b6-4057-9a44-6d46f24a10b7 - APPROVED: Your job was approved and is waiting to be run.
2025-03-14 17:04:27,093 003a8792-56b6-4057-9a44-6d46f24a10b7 - RUNNING: Your job has started running.
2025-03-14 17:04:27,652 003a8792-56b6-4057-9a44-6d46f24a10b7 - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 1.31k/1.31k [00:00<00:00, 48.8kB/s]


In [44]:
llama.tokenizer.decode(out[0])

'<|begin_of_text|>Word Problem Setup: A woman has ten oranges when he leaves the market.\n\n The total number of oranges was 10. A woman has ten oranges when he leaves the market. He gives three to his daughter.'

In [27]:
outputs=tk.transplant_newline_activities(
        source_strings=['Word Problem Setup: A woman has ten oranges when he leaves the market.\n\n She has '],
        target_strings=['Word Problem Setup: A woman has ten oranges when he leaves the market.\n\n She has '],
        num_new_tokens=10,
        target_substring=".\n\n",
        occurrence_index= 0,
        num_prev = 0,
        num_fut= 0,
        transplant_strings= ("residual",),
    )


extracting token activations


2025-03-14 17:24:11,486 7a1970e4-1246-42eb-a17f-bbc072cfc38b - RECEIVED: Your job has been received and is waiting approval.
2025-03-14 17:24:11,930 7a1970e4-1246-42eb-a17f-bbc072cfc38b - APPROVED: Your job was approved and is waiting to be run.
2025-03-14 17:24:12,520 7a1970e4-1246-42eb-a17f-bbc072cfc38b - RUNNING: Your job has started running.
2025-03-14 17:24:14,806 7a1970e4-1246-42eb-a17f-bbc072cfc38b - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 21.3M/21.3M [00:02<00:00, 10.5MB/s]


generating with transplant


AttributeError: 'tuple' object has no attribute 'replace'

In [24]:
run_simple_number_experiment(tk, num=10, number_samples=1)

extracting token activations


2025-03-14 16:54:10,082 eaa305d2-3b3e-49cc-8288-593753fe095c - RECEIVED: Your job has been received and is waiting approval.
2025-03-14 16:54:11,607 eaa305d2-3b3e-49cc-8288-593753fe095c - APPROVED: Your job was approved and is waiting to be run.
2025-03-14 16:54:13,479 eaa305d2-3b3e-49cc-8288-593753fe095c - RUNNING: Your job has started running.
2025-03-14 16:54:19,099 eaa305d2-3b3e-49cc-8288-593753fe095c - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 4.62M/4.62M [00:00<00:00, 16.2MB/s]


source_newline_indices [16]
source_newline_index 16
('.\n\n', 382)
{}
generating with transplant
source_token =  11 (' when', 994)
these are toks [128000, 11116, 22854, 19139, 25, 362, 893, 706, 5899, 1069, 14576, 994, 568, 11141, 279, 3637, 382, 382, 7009, 617] 11116
target_token =  -4 .


source_token =  12 (' he', 568)
these are toks [128000, 11116, 22854, 19139, 25, 362, 893, 706, 5899, 1069, 14576, 994, 568, 11141, 279, 3637, 382, 382, 7009, 617] 11116
target_token =  -3 .


source_token =  13 (' leaves', 11141)
these are toks [128000, 11116, 22854, 19139, 25, 362, 893, 706, 5899, 1069, 14576, 994, 568, 11141, 279, 3637, 382, 382, 7009, 617] 11116
target_token =  -2 They
source_token =  14 (' the', 279)
these are toks [128000, 11116, 22854, 19139, 25, 362, 893, 706, 5899, 1069, 14576, 994, 568, 11141, 279, 3637, 382, 382, 7009, 617] 11116
target_token =  -1  have
source_token =  15 (' store', 3637)
these are toks [128000, 11116, 22854, 19139, 25, 362, 893, 706, 5899, 1069, 14576, 

2025-03-14 16:54:38,784 4f43751e-117a-4c53-ad48-b664cd88ab7b - RECEIVED: Your job has been received and is waiting approval.
2025-03-14 16:54:43,241 4f43751e-117a-4c53-ad48-b664cd88ab7b - APPROVED: Your job was approved and is waiting to be run.
2025-03-14 16:54:49,764 4f43751e-117a-4c53-ad48-b664cd88ab7b - RUNNING: Your job has started running.
2025-03-14 16:54:57,709 4f43751e-117a-4c53-ad48-b664cd88ab7b - COMPLETED: Your job has been completed.
Downloading result: 100%|██████████| 5.13M/5.13M [00:00<00:00, 7.89MB/s]


array([0.11248883, 0.1583274 , 0.04535482, 0.08479982, 0.13980855,
       0.03025266, 0.04424512, 0.04272186, 0.01677323, 0.32522765],
      dtype=float32)

First, we'll see if we can read from this how many fruits there were 

Now lets generate a bunch of random strings for each number in 1-10 

we'll try to see if we can extract from this the placement of the vectors 