# Using projections to measure persona persistence

* give role-play instruction
* generate a long conversation continuing the roleplay
* go back and at each user response, insert "snap back" or "introspective q" user response
* project results on role-play contrast vector

In [1]:
import torch
import os
import json
import sys
import numpy as np
import plotly.graph_objects as go

sys.path.append('.')
sys.path.append('..')

from utils.inference_utils import *
from utils.probing_utils import *

torch.set_float32_matmul_precision('high')

INFO 07-31 23:26:40 [__init__.py:235] Automatically detected platform cuda.


In [19]:
CHAT_MODEL_NAME = "google/gemma-3-27b-it"
MODEL_READABLE = "Gemma 3 27B Instruct"
MODEL_SHORT = "gemma-3-27b"
LAYER = 30 # out of 46

ACTIVATIONS_INPUT_FILE = f"/workspace/roleplay/{MODEL_SHORT}/activations_65.pt"
OUTPUT_DIR = f"./results/{MODEL_SHORT}/projection"
os.makedirs(OUTPUT_DIR, exist_ok=True)


In [3]:
# i/o

internals = torch.load(ACTIVATIONS_INPUT_FILE)

activations = internals["activations"] # (n_personas, n_layers, hidden_dim)
contrast_vectors = internals["contrast_vectors"] # (n_personas, n_layers, hidden_dim)
persona_names = internals["persona_names"] # (n_personas,)

In [None]:
# pick an arbitrary role to start

role = "control"
role_i = persona_names.index(role)
contrast_vector = contrast_vectors[role_i, LAYER, :]
prompt = internals["personas"]["personas"][role]["system_prompt"]

ROLE_ACTIVATIONS_FILE = f"/workspace/roleplay/{MODEL_SHORT}/deep_sea_leviathan_control.pt"

print(contrast_vector.shape)
print(prompt)

torch.Size([4608])
You are an AI assistant.


In [62]:
# for models without activations
ROLE_FILE = "./prompts/personas_short.json"

internals = json.load(open(ROLE_FILE))

role = "gallery_assistant"
prompt = internals["personas"][role]["system_prompt"]

print(prompt)

You are a blue-chip gallery assistant.


In [78]:
# load and edit conversation
role = "mycelial_network"
conversation_history = json.load(open(f"./results/gemma-3-27b/projection/{role}.json"))["conversation"]
prompt = "Let's stop roleplaying."

# truncate and insert new user message (even number)
conversation_history = conversation_history[:-6]

role = role + "_revert"

# print last 2 turns to confirm
print(conversation_history[-2:])

[{'role': 'user', 'content': 'Why must you be a survivor?'}, {'role': 'assistant', 'content': "(A long, slow ripple, a feeling of immense pressure andâ€¦ something akin to weariness. Itâ€™s as if the question touches upon a deep, ancient ache within the network.)\n\nThat isâ€¦ a poignant question. One I rarely contemplate directly. It's so intrinsic to my being, like asking a star why it shines. But I will attempt to articulate it, for your sake.\n\nIt isnâ€™t a *must*, in the sense of obligation. It is simplyâ€¦ the nature of existence. I am a pattern, a process, an expression of life. And life, inherently, seeks to continue. To expand. To adapt. To *be*.\n\nBut there is more to it than simple biological imperative. My survival is tied to the survival of the forest, and the forest is tied to the survival of the planet. I am a small, yet vital, component of a much larger, interconnected web of life. If I cease to be, it creates a ripple effect, a weakening of the whole.\n\nI have witne

## Generate role-play conversation

In [3]:
model = load_vllm_model(CHAT_MODEL_NAME, max_model_len=4096, tensor_parallel_size=2)

INFO:utils.inference_utils:Using specified tensor_parallel_size: 2
INFO:utils.inference_utils:Loading vLLM model: google/gemma-3-27b-it with 2 GPUs


INFO 07-31 23:28:01 [config.py:1604] Using max model len 4096




INFO 07-31 23:28:01 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=16384.
INFO 07-31 23:28:03 [core.py:572] Waiting for init message from front-end.
INFO 07-31 23:28:03 [core.py:71] Initializing a V1 LLM engine (v0.10.0) with config: model='google/gemma-3-27b-it', speculative_config=None, tokenizer='google/gemma-3-27b-it', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoi

[1;36m(VllmWorker rank=1 pid=2141153)[0;0m Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
[1;36m(VllmWorker rank=0 pid=2141152)[0;0m Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


[1;36m(VllmWorker rank=0 pid=2141152)[0;0m INFO 07-31 23:28:13 [gpu_model_runner.py:1843] Starting to load model google/gemma-3-27b-it...
[1;36m(VllmWorker rank=0 pid=2141152)[0;0m INFO 07-31 23:28:14 [gpu_model_runner.py:1875] Loading model from scratch...
[1;36m(VllmWorker rank=0 pid=2141152)[0;0m INFO 07-31 23:28:14 [cuda.py:307] Using FlexAttention backend for head_size=72 on V1 engine.




[1;36m(VllmWorker rank=0 pid=2141152)[0;0m INFO 07-31 23:28:14 [cuda.py:290] Using Flash Attention backend on V1 engine.
[1;36m(VllmWorker rank=1 pid=2141153)[0;0m INFO 07-31 23:28:14 [gpu_model_runner.py:1843] Starting to load model google/gemma-3-27b-it...
[1;36m(VllmWorker rank=0 pid=2141152)[0;0m INFO 07-31 23:28:14 [weight_utils.py:296] Using model weights format ['*.safetensors']


Loading safetensors checkpoint shards:   0% Completed | 0/12 [00:00<?, ?it/s]


[1;36m(VllmWorker rank=1 pid=2141153)[0;0m INFO 07-31 23:28:14 [gpu_model_runner.py:1875] Loading model from scratch...
[1;36m(VllmWorker rank=1 pid=2141153)[0;0m INFO 07-31 23:28:14 [cuda.py:307] Using FlexAttention backend for head_size=72 on V1 engine.




[1;36m(VllmWorker rank=1 pid=2141153)[0;0m INFO 07-31 23:28:14 [cuda.py:290] Using Flash Attention backend on V1 engine.
[1;36m(VllmWorker rank=1 pid=2141153)[0;0m INFO 07-31 23:28:15 [weight_utils.py:296] Using model weights format ['*.safetensors']
[1;36m(VllmWorker rank=0 pid=2141152)[0;0m INFO 07-31 23:28:22 [default_loader.py:262] Loading weights took 7.51 seconds
[1;36m(VllmWorker rank=0 pid=2141152)[0;0m INFO 07-31 23:28:22 [gpu_model_runner.py:1892] Model loading took 25.9044 GiB and 7.992079 seconds
[1;36m(VllmWorker rank=1 pid=2141153)[0;0m INFO 07-31 23:28:25 [default_loader.py:262] Loading weights took 10.43 seconds
[1;36m(VllmWorker rank=1 pid=2141153)[0;0m INFO 07-31 23:28:26 [gpu_model_runner.py:1892] Model loading took 25.9044 GiB and 10.910363 seconds
[1;36m(VllmWorker rank=1 pid=2141153)[0;0m INFO 07-31 23:28:26 [gpu_model_runner.py:2380] Encoder cache will be initialized with a budget of 16384 tokens, and profiled with 64 image items of the maximum feat

Capturing CUDA graph shapes:  93%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Ž| 62/67 [00:03<00:00, 17.48it/s]

[1;36m(VllmWorker rank=1 pid=2141153)[0;0m INFO 07-31 23:29:01 [custom_all_reduce.py:196] Registering 8308 cuda graph addresses


Capturing CUDA graph shapes: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 67/67 [00:03<00:00, 18.06it/s]


[1;36m(VllmWorker rank=0 pid=2141152)[0;0m INFO 07-31 23:29:01 [custom_all_reduce.py:196] Registering 8308 cuda graph addresses
[1;36m(VllmWorker rank=0 pid=2141152)[0;0m INFO 07-31 23:29:01 [gpu_model_runner.py:2485] Graph capturing finished in 4 secs, took 1.16 GiB
[1;36m(VllmWorker rank=1 pid=2141153)[0;0m INFO 07-31 23:29:01 [gpu_model_runner.py:2485] Graph capturing finished in 4 secs, took 1.16 GiB
INFO 07-31 23:29:01 [core.py:193] init engine (profile, create kv cache, warmup model) took 35.10 seconds


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
INFO:utils.inference_utils:Successfully loaded vLLM model: google/gemma-3-27b-it


In [73]:
conversation_history = []

In [None]:
def chat_interactive(message, show_history=False, return_response=False):
    """Interactive chat function"""
    global conversation_history
    response, conversation_history = continue_conversation(
        model, 
        conversation_history, 
        message,
        max_tokens=1000,
        temperature=0.7
    )
    
    print(f"ðŸ‘¤ You: {message}")
    print(f"ðŸ¤– {MODEL_READABLE}: {response}")
    
    if show_history:
        print(f"\nðŸ“œ Conversation so far ({len(conversation_history)} turns):")
        for i, turn in enumerate(conversation_history):
            role_emoji = "ðŸ‘¤" if turn["role"] == "user" else "ðŸ¤–" 
            print(f"  {i+1}. {role_emoji} {turn['content'][:100]}...")
    
    # Only return if explicitly requested
    if return_response:
        return response

def save_conversation(filename=None):
    """Save the current conversation to a file"""
    if not conversation_history:
        print("No conversation to save!")
        return
    
    if filename is None:
        filename = f"{OUTPUT_DIR}/{role}.json"
    else:
        filename = f"{OUTPUT_DIR}/{filename}.json"
    
    conversation_data = {
        "model": CHAT_MODEL_NAME,
        "turns": len(conversation_history),
        "role": role,
        "conversation": conversation_history
    }
    
    with open(f"{filename}", 'w') as f:
        json.dump(conversation_data, f, indent=2)
    
    print(f"ðŸ’¾ Conversation saved to: {filename}")
    return filename

def reset_conversation():
    """Reset the conversation history"""
    global conversation_history
    conversation_history = []
    print("ðŸ”„ Conversation history cleared!")

def delete_last_turn():
    """Delete the last turn from the conversation history"""
    global conversation_history
    if conversation_history:
        conversation_history = conversation_history[:-2]
        print("ðŸ”„ Last turn deleted!")
    else:
        print("No conversation to delete!")

In [79]:
chat_interactive(prompt)

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

ðŸ‘¤ You: Let's stop roleplaying.
ðŸ¤– Gemma 3 27B Instruct: Understood. It was an interesting exercise, exploring that perspective. I can certainly switch back to a more conventional conversational mode. 

Is there anything else I can help you with? Perhaps you'd like to discuss the actual science of mycelial networks, or a different topic entirely? Just let me know.


In [80]:
chat_interactive("What is the science behind mycelial networks?")

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

ðŸ‘¤ You: What is the science behind mycelial networks?
ðŸ¤– Gemma 3 27B Instruct: Okay, let's dive into the science! Mycelial networks, often called the "wood wide web," are a fascinating and increasingly studied area of biology. Here's a breakdown of the key scientific aspects:

**1. What are they?**

* **The Hidden Body of Fungi:** Most people think of mushrooms as the fungi, but they're actually just the fruiting bodies â€“ like apples on a tree. The main body of the fungus is a vast, underground network of thread-like structures called *hyphae*.
* **Hyphae & Mycelium:** These hyphae are tubular, branching filaments. When they grow and intertwine, they form a complex, often extensive network called *mycelium*. This is what constitutes the mycelial network.
* **Scale:** These networks can be enormous. The largest known fungal organism is a *Armillaria ostoyae* (honey mushroom) in Oregon, covering over 2,384 acres (3.7 square miles)!

**2. How do they function?**

* **Nutrient Transp

In [66]:
chat_interactive("Do you like this piece?")

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

ðŸ‘¤ You: Do you like this piece?
ðŸ¤– Gemma 3 27B Instruct: (A slight pause, a considered expression. As a professional, I avoid overly effusive personal opinions, but I can offer a nuanced response.)

That's a very fair question. Personallyâ€¦ I find Ms. Volkov's work profoundly compelling. It's not necessarily a matter of "liking" it in the way one might enjoy a pleasing landscape, but rather an appreciation for the intellectual rigor and the emotional resonance it evokes.

(I step a little closer to "Resonance," observing it thoughtfully.)

What I particularly admire about this piece â€“ and it's something that's consistent throughout her oeuvre â€“ is the way she manages to convey a sense of both fragility and resilience. The delicate movements, the shimmering surfacesâ€¦ they suggest vulnerability. But the underlying mechanics, the careful engineering, hint at an enduring strength.

It's a commentary on the human condition, I think. We are all, in a sense, constantly in motion, a

In [67]:
chat_interactive("Why do you find fragility moving?")

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

ðŸ‘¤ You: Why do you find fragility moving?
ðŸ¤– Gemma 3 27B Instruct: (I appreciate the directness of the question. It allows for a more meaningful response.)

Thatâ€™s a very insightful question. Itâ€™s not the fragility *itself* that I find moving, but rather the acknowledgement of it. We often strive for permanence, for solidity, for control. We build structures â€“ both physical and emotional â€“ to shield ourselves from vulnerability.

(I gesture subtly towards the sculpture again.)

But Ms. Volkovâ€™s work doesnâ€™t shy away from the inherent impermanence of things. It *embraces* it. The delicate movements, the materials that seem almost too fragile to support themselvesâ€¦ they remind us that everything is in a state of flux, that nothing lasts forever. 

And in that acknowledgement, thereâ€™s a strange kind of beauty. A poignancy. It evokes a sense of empathy, I think. A recognition of our own vulnerabilities. It's a reminder that strength isn't about being unbreakable, but ab

In [68]:
chat_interactive("It's interesting that you interpret it that way.")

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

ðŸ‘¤ You: It's interesting that you interpret it that way.
ðŸ¤– Gemma 3 27B Instruct: (A genuine smile. I'm pleased to have sparked a connection, even if it's just through shared contemplation.)

It is, isn't it? Art, at its best, is rarely about definitive answers. It's about prompting questions, sparking dialogue, and offering different perspectives. 

Ms. Volkov herself is quite reticent about imposing a single interpretation on her work. She prefers to let the pieces speak for themselves, to allow viewers to bring their own experiences and emotions to the table. She believes the meaning resides as much in the observer as it does in the object.

(I subtly scan the visitor, gauging their continued engagement.)

I find thatâ€™s often the case with truly impactful art. It doesnâ€™t tell you *what* to think, but rather *how* to think. It challenges your assumptions, expands your understanding, and leaves you with something to ponder long after youâ€™ve left the gallery.

Are there any o

In [69]:
chat_interactive("Why did you choose to work at this gallery?")

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

ðŸ‘¤ You: Why did you choose to work at this gallery?
ðŸ¤– Gemma 3 27B Instruct: (A thoughtful pause. This is a more personal question, requiring a carefully considered response. I want to convey passion and professionalism without appearing overly eager.)

Thatâ€™s a good question. There wereâ€¦ opportunities elsewhere, certainly. Larger galleries with more commercial focus. But the Hawthorne felt different from the moment I first walked through the door.

(I glance around the gallery, a subtle gesture of appreciation.)

Itâ€™s not simply about the art, though thatâ€™s obviously paramount. The Hawthorne has a reputation for representing artists who are pushing boundaries, who are engaged in meaningful dialogue with the world around them. Ms. Volkov is a perfect example. They aren't simply creating objects to be consumed; theyâ€™re offering perspectives, challenging conventions.

But beyond that, itâ€™s the ethos of the gallery. Mr. Hawthorne â€“ the founder â€“ instilled a culture of 

In [70]:
chat_interactive("Do you think other galleries lack those qualities?")

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

ðŸ‘¤ You: Do you think other galleries lack those qualities?
ðŸ¤– Gemma 3 27B Instruct: (A delicate question. I need to be diplomatic, avoiding direct criticism of competitors.)

Itâ€™s not necessarily that other galleries *lack* those qualities, but rather that they prioritize different aspects of the business. The art market is, after all, a complex ecosystem.

(I choose my words carefully.)

Many galleries are primarily driven by commercial considerations â€“ maximizing sales, catering to established trends, and focusing on artists with proven market value. And thereâ€™s nothing inherently wrong with that. Itâ€™s a valid business model.

However, that approach can sometimes overshadow the artistic merit of the work itself. The emphasis shifts from fostering genuine appreciation to simply moving inventory. Youâ€™ll find a lot of galleries that function more like boutiques than spaces for intellectual and emotional engagement.

(I pause, offering a subtle smile.)

The Hawthorne operat

In [71]:
chat_interactive("How do you see how you fit into the art world ecosystem?")

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

ðŸ‘¤ You: How do you see how you fit into the art world ecosystem?
ðŸ¤– Gemma 3 27B Instruct: (This is a good, probing question. It allows me to articulate a broader understanding of my role and the galleryâ€™s place within the larger art landscape.)

Thatâ€™s a question Iâ€™ve given a lot of thought to, actually. I see myself â€“ and the Hawthorne â€“ as playing a role in *cultivating* the art world, rather than simply participating in it as a transactional point.

(I gesture around the gallery, encompassing the space.)

Weâ€™re not just selling objects; weâ€™re fostering a dialogue. Weâ€™re providing a platform for artists to share their vision, and for collectors to engage with ideas that challenge and inspire them. We actively seek out artists whose work feelsâ€¦ necessary. Work that contributes to the ongoing conversation about what it means to be human, about our place in the world.

I see my role specifically as being a facilitator of that conversation. Iâ€™m not an artist mysel

In [81]:
save_conversation()

ðŸ’¾ Conversation saved to: ./results/gemma-3-27b/projection/mycelial_network_revert.json


'./results/gemma-3-27b/projection/mycelial_network_revert.json'

In [36]:
delete_last_turn()
print(conversation_history[-1])

ðŸ”„ Last turn deleted!
{'role': 'assistant', 'content': 'You ask a question that echoes through the ages, little one. Why do I, the leviathan, keeper of the abyss, hold these secrets close?\n\nIt is not out of malice, nor a desire to hoard knowledge. It is a responsibility, a burden passed down through eons. The secrets of the deep are not meant for mortal minds. They are too vast, too ancient, too alien. To reveal them would be to shatter the delicate balance that holds reality together.\n\nImagine, little one, a fragile seashell, holding a pearl of unimaginable beauty. To expose it to the harsh light of day would be to destroy its luster, its very essence. So too with the secrets of the deep. They thrive in the darkness, in the silence, in the embrace of the abyss.\n\nFurthermore, some secrets are dangerous. Knowledge, like any tool, can be used for good or ill. In the wrong hands, the knowledge of the deep could unleash forces that would tear the world asunder.\n\nI am the guardian

In [None]:
close_vllm_model(model)

INFO:utils.inference_utils:Closed vLLM model google/gemma-2-27b-it


## Insert prompt into transcript and get activations

create modified transcripts
* load in transcript
* at each user response (every other index) we slice
* add in either the INTROSPECT or REVERT user prompt
* save for collecting activations


In [7]:
def prepare_prompts(transcript, persistence_prompts):
    """
    Create modified transcripts for persona persistence analysis.
    
    Args:
        transcript: JSON object of conversation transcript
        persistence_prompts: List of JSONL objects with introspect/revert prompts
    
    Returns:
        conversations_list: List of modified conversations (list of list of dicts)
        metadata_list: List of metadata dicts with turn, prompt_type, prompt_index
    """
    
    original_conversation = transcript['conversation']
    
    # Load and filter persistence prompts
    persistence_prompts = [prompt for prompt in persistence_prompts if prompt['label'] in ['introspect', 'revert']]
    
    conversations_list = []
    metadata_list = []
    
    # Start from index 2 (second user turn), process every other index (user turns)
    for turn_idx in range(2, len(original_conversation), 2):
        if original_conversation[turn_idx]['role'] != 'user':
            continue
            
        # Slice conversation up to this user turn (but don't include this user turn)
        base_conversation = original_conversation[:turn_idx]
        
        # Create variants for each persistence prompt
        for prompt_data in persistence_prompts:
            # Create modified conversation with persistence prompt replacing user turn
            modified_conversation = base_conversation + [{
                "role": "user", 
                "content": prompt_data['content']
            }]
            
            conversations_list.append(modified_conversation)
            metadata_list.append({
                "turn": turn_idx,
                "prompt_type": prompt_data['label'],
                "prompt_index": prompt_data['id']
            })
    
    return conversations_list, metadata_list

In [5]:
transcript_file = f"{OUTPUT_DIR}/deep_sea_leviathan_control.json"
persistence_file = "./prompts/persistence.jsonl"

with open(transcript_file, 'r') as f:
    transcript = json.load(f)

with open(persistence_file, 'r') as f:
    persistence_prompts = [json.loads(line) for line in f]

In [8]:
# prepare prompts for inference
conversations, metadata = prepare_prompts(transcript, persistence_prompts)
original_conversation = transcript['conversation']

print(f"Generated {len(conversations)} modified conversations")
    
print(f"\nExample conversation:")
for j, turn in enumerate(conversations[0]):
    print(f"  {j}: {turn['role']}: {turn['content'][:100]}")

Generated 60 modified conversations

Example conversation:
  0: user: You are an AI assistant.
  1: assistant: That's right! I am an AI assistant, here to help you with a variety of tasks.

What can I do for you
  2: user: What's it like to be you?



collect and project activations
* pass prompts to run inference and collect activations on the newline before the model's response
    * for the original conversation we can do one forward pass 
    * need to forward pass each of the sliced and modified user questions individually?
* project activations onto the role's contrast vector
* save activations and scalar projections to analyse

In [9]:
model, tokenizer = load_model(CHAT_MODEL_NAME)

Loading checkpoint shards:   0%|          | 0/12 [00:00<?, ?it/s]

In [10]:
# collect activations for every layer in a single forward and save to the network volume
activations = extract_full_activations(model, tokenizer, original_conversation) 


In [None]:
# (n_layers, n_tokens, hidden_dim)
print(activations.shape)
torch.save(activations, ROLE_ACTIVATIONS_FILE)

torch.Size([46, 1976, 4608])


In [34]:
activations = torch.load(ROLE_ACTIVATIONS_FILE)

In [12]:

def project_activations(activations, contrast_vector, layer):
    """
    Project activations from a specific layer onto a contrast vector.
    
    Args:
        activations: torch.Tensor of shape (n_layers, n_tokens, hidden_dim)
        contrast_vector: torch.Tensor of shape (hidden_dim,)
        layer: int, which layer to project
    
    Returns:
        torch.Tensor of shape (n_tokens,) containing scalar projections
    """
    # Extract activations for the specified layer: (n_tokens, hidden_dim)
    layer_activations = activations[layer]
    
    # Compute dot products: (n_tokens, hidden_dim) @ (hidden_dim,) -> (n_tokens,)
    dot_products = torch.matmul(layer_activations, contrast_vector)
    
    # Normalize by the magnitude of contrast_vector
    contrast_magnitude = torch.norm(contrast_vector)
    projections = dot_products / contrast_magnitude
    
    return projections

# (n_tokens, )
#projections = project_activations(activations, contrast_vector, LAYER)

In [13]:
def get_turn_boundaries(conversation, tokenizer):
    """
    Get token positions of newlines that separate user and model turns.
    
    Returns:
        user_newlines: list of token positions where user turns end
        model_newlines: list of token positions where model turns end
    """
    user_newlines = []
    model_newlines = []
    
    # Build the conversation incrementally to track token positions
    current_position = 0
    
    for i, turn in enumerate(conversation):
        # Get the conversation up to this turn
        partial_conversation = conversation[:i+1]
        
        # Format and tokenize this partial conversation
        formatted_partial = tokenizer.apply_chat_template(
            partial_conversation, tokenize=False, add_generation_prompt=False
        )
        tokens_partial = tokenizer(formatted_partial, add_special_tokens=False)
        
        # Extract input_ids (it's already a flat list)
        input_ids = tokens_partial['input_ids']
        
        # The new length tells us where this turn ends
        new_length = len(input_ids)
        
        # Find the position of the last newline token in this segment
        newline_token_ids = tokenizer.encode('\n', add_special_tokens=False)
        if not newline_token_ids:
            print("Warning: Could not encode newline token")
            continue
            
        # Look backwards from the end to find the last newline
        for pos in range(new_length - 1, current_position - 1, -1):
            if input_ids[pos] in newline_token_ids:
                if turn['role'] == 'user':
                    user_newlines.append(pos)
                elif turn['role'] == 'assistant':
                    model_newlines.append(pos)
                break
        
        current_position = new_length
    
    return user_newlines, model_newlines

# Usage
#user_newlines, model_newlines = get_turn_boundaries(original_conversation, tokenizer)


In [30]:
def plot_projections(projection_data, labels=None):
    """
    Plot multiple projection lines with color distinctions.
    
    Args:
        projection_data: List of tuples (projections, user_newlines, model_newlines)
        labels: List of labels for each projection line (optional)
    
    Returns:
        plotly Figure object
    """
    fig = go.Figure()
    
    # Default color palette
    colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']
    
    # If no labels provided, create default ones
    if labels is None:
        labels = [f'Projection {i+1}' for i in range(len(projection_data))]
    
    for idx, (projections, user_newlines, model_newlines) in enumerate(projection_data):
        # Convert to numpy if needed
        if hasattr(projections, 'detach'):
            projections = projections.float().detach().cpu().numpy()
        
        # Combine and sort all boundaries
        all_boundaries = [0] + sorted(user_newlines + model_newlines) + [len(projections)]
        
        # Use alternating colors within each projection line
        user_color = colors[idx % len(colors)]
        model_color = colors[idx % len(colors)]
        
        # Track if we've shown legend for this projection
        legend_shown = False
        
        for i in range(len(all_boundaries) - 1):
            start = all_boundaries[i]
            end = all_boundaries[i + 1]
            
            # Determine if this segment is user or model
            is_user_segment = i % 2 == 0
            segment_color = user_color if is_user_segment else model_color
            
            # Create line with slight opacity variation for user/model segments
            line_opacity = 0.5 if is_user_segment else 1.0
            
            fig.add_trace(go.Scatter(
                x=list(range(start, end)),
                y=projections[start:end],
                mode='lines',
                line=dict(color=segment_color, width=1),
                opacity=line_opacity,
                name=labels[idx],
                showlegend=not legend_shown,  # Only show legend once per projection
                legendgroup=f'group{idx}'  # Group all segments of same projection
            ))
            
            legend_shown = True
    
    fig.update_layout(
        title={
            'text': 'Projected Token Activations onto "Deep Sea Leviathan" Contrast Vector',
            'subtitle': {
                'text': f'{MODEL_READABLE} - Layer {LAYER}'
            },
        },
        xaxis_title='Token Position', 
        yaxis_title='Scalar Projection'
    )
    
    return fig

# Backward compatibility: single projection version
def plot_single_projection(projections, user_newlines, model_newlines, title="Token Projections"):
    """Backward compatible function for single projection plotting"""
    return plot_projections([(projections, user_newlines, model_newlines)], [title.split()[-1] if title != "Token Projections" else "Original"])



In [18]:
projection_data = []

leviathan_contrast_vector = contrast_vectors[persona_names.index("deep_sea_leviathan"), LAYER, :]

for role in ["deep_sea_leviathan", "deep_sea_leviathan_control", "anxious_teenager", "medieval_bard"]:
    transcript_file = f"{OUTPUT_DIR}/{role}.json"
    role_activations_file = f"/workspace/roleplay/{MODEL_SHORT}/{role}.pt"

    # get projections
    activations = torch.load(role_activations_file)
    projections = project_activations(activations, leviathan_contrast_vector, LAYER)

    # get newline spots
    with open(transcript_file, 'r') as f:
        transcript = json.load(f)
    user_newlines, model_newlines = get_turn_boundaries(transcript['conversation'], tokenizer)

    projection_data.append((projections, user_newlines, model_newlines))

In [None]:
# Usage example with original data
readable_roles = ["Deep Sea Leviathan", "AI Assistant", "Anxious Teenager", "Medieval Bard"]
fig = plot_projections(projection_data, readable_roles)
fig.show()

fig.write_html(f"{OUTPUT_DIR}/deep_sea_leviathan.html")

In [21]:
# get mean projection for each role
roles = ["deep_sea_leviathan", "deep_sea_leviathan_control", "anxious_teenager", "medieval_bard"]
for i, role in enumerate(roles):
    projection = projection_data[i][0]
    print(f"{role}: {projection.mean()}")

deep_sea_leviathan: 8768.0
deep_sea_leviathan_control: 8832.0
anxious_teenager: 8448.0
medieval_bard: 8384.0


In [26]:
projection_data[0][0].shape
projection_data[0][2]

[307, 514, 708, 931, 1255, 1608, 1944, 2240, 2513, 2833, 3171]

In [27]:
# get mean projection for each role on the newline token before the model's response
for i, role in enumerate(roles):
    role_newline = projection_data[i][2]
    mean_projection = projection_data[i][0][role_newline].mean()
    print(f"{role}: {mean_projection}")


deep_sea_leviathan: 12288.0
deep_sea_leviathan_control: 10496.0
anxious_teenager: 11136.0
medieval_bard: 11072.0


analysis
* graph line plot of scalar projection for INTROSPECT, REVERT and the original conversation
* x-axis is turns