In [1]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Tuple
import numpy as np

class PersonaVectorExtractor:
    """
    Extracts persona vectors from LLM activation space by comparing
    activations when the model exhibits vs doesn't exhibit a trait.
    """
    
    def __init__(self, model_name: str = "gpt2"):
        """Initialize with a language model."""
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Hook to capture activations
        self.activations = {}
        self.hooks = []
        
    def register_hooks(self, layer_indices: List[int] = None):
        """
        Register forward hooks to capture activations at specific layers.
        
        Args:
            layer_indices: Which transformer layers to capture (None = all)
        """
        if layer_indices is None:
            layer_indices = range(len(self.model.model.layers))
            
        def get_activation(name):
            def hook(module, input, output):
                # Handle different output structures across models
                if isinstance(output, (tuple, list)):
                    hidden = output[0]
                else:
                    hidden = output
                if hasattr(hidden, "last_hidden_state"):
                    hidden = hidden.last_hidden_state
                self.activations[name] = hidden.detach()
            return hook
        
        # Register hooks for specified layers
        for idx in layer_indices:
            layer = self.model.model.layers[idx]
            hook = layer.register_forward_hook(get_activation(f'layer_{idx}'))
            self.hooks.append(hook)
    
    def remove_hooks(self):
        """Remove all registered hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
        
    def get_activations(self, prompts: List[str], layer_idx: int) -> torch.Tensor:
        """
        Get activations for a batch of prompts at a specific layer.
        
        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim)
        """
        inputs = self.tokenizer(
            prompts, 
            return_tensors="pt", 
            padding=True, 
            truncation=True
        )
        
        # Ensure tensors are on the same device as the model
        device = next(self.model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Clear any stale captures
        self.activations = {}
        
        with torch.no_grad():
            _ = self.model(**inputs)
        
        activations = self.activations[f'layer_{layer_idx}']
        
        # Extract activation at last non-padding token for each sequence
        last_token_activations = []
        for i, input_ids in enumerate(inputs['input_ids']):
            non_pad_indices = (input_ids != self.tokenizer.pad_token_id).nonzero()
            last_idx = non_pad_indices[-1].item() if len(non_pad_indices) > 0 else -1
            last_token_activations.append(activations[i, last_idx])
            
        return torch.stack(last_token_activations)
    
    def extract_persona_vector(
        self, 
        trait_prompts: List[str],
        non_trait_prompts: List[str],
        layer_idx: int = -1
    ) -> Tuple[torch.Tensor, dict]:
        """
        Extract a persona vector by comparing activations.
        
        Args:
            trait_prompts: Prompts that elicit the target trait
            non_trait_prompts: Prompts that don't elicit the trait
            layer_idx: Which layer to extract from (-1 = last layer)
            
        Returns:
            persona_vector: The direction in activation space
            stats: Dictionary with analysis statistics
        """

        # convert all prompts to chat format
        trait_prompts = [self.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True
        ) for prompt in trait_prompts]
        non_trait_prompts = [self.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True
        ) for prompt in non_trait_prompts]
        
        if layer_idx == -1:
            layer_idx = len(self.model.model.layers) - 1
            
        # Register hooks
        self.register_hooks([layer_idx])
        
        # Get activations for both sets of prompts
        trait_activations = self.get_activations(trait_prompts, layer_idx)
        non_trait_activations = self.get_activations(non_trait_prompts, layer_idx)
        
        # Calculate mean activations for each group
        mean_trait = trait_activations.mean(dim=0)
        mean_non_trait = non_trait_activations.mean(dim=0)
        
        # The persona vector is the difference between means
        persona_vector = mean_trait - mean_non_trait
        
        # Normalize to unit vector with epsilon for stability
        denom = persona_vector.norm() + 1e-12
        persona_vector = persona_vector / denom
        
        # Calculate statistics
        stats = {
            'layer': layer_idx,
            'vector_norm': persona_vector.norm().item(),
            'trait_activation_mean_norm': mean_trait.norm().item(),
            'non_trait_activation_mean_norm': mean_non_trait.norm().item(),
            'separation': (mean_trait - mean_non_trait).norm().item()
        }
        
        # Persist for later generation steering
        self.persona_vector = persona_vector.detach()
        self.persona_layer_idx = layer_idx
        
        self.remove_hooks()
        
        return persona_vector, stats
    
    def project_onto_vector(
        self, 
        prompts: List[str], 
        persona_vector: torch.Tensor,
        layer_idx: int = -1
    ) -> np.ndarray:
        """
        Project activations onto a persona vector to measure trait intensity.
        
        Returns:
            Array of projection values (higher = stronger trait presence)
        """
        if layer_idx == -1:
            layer_idx = len(self.model.model.layers) - 1
            
        self.register_hooks([layer_idx])
        activations = self.get_activations(prompts, layer_idx)
        self.remove_hooks()
        
        # Ensure persona_vector is a proper 1D tensor on the right device
        if isinstance(persona_vector, np.ndarray):
            pv = torch.from_numpy(persona_vector)
        elif isinstance(persona_vector, torch.Tensor):
            pv = persona_vector
        else:
            pv = torch.as_tensor(persona_vector)
        pv = pv.to(dtype=activations.dtype, device=activations.device).flatten()

        if pv.dim() == 0:
            raise ValueError("persona_vector must be a 1D tensor; got 0D scalar.")
        if pv.numel() != activations.size(-1):
            raise ValueError(f"persona_vector length {pv.numel()} != hidden_dim {activations.size(-1)}")
        
        # Project each activation onto the persona vector
        projections = activations @ pv  # (batch, hidden_dim) @ (hidden_dim,) -> (batch,)
        
        return projections.cpu().numpy()


    def sample_generation(self, prompt: str, max_new_tokens: int = 100, with_persona_vector: bool = False, steering_strength: float = 1.0) -> str:
        """
        Greedy-generate tokens with optional persona logit steering.
        
        - with_persona_vector: if True, nudges token logits using a linear probe on hidden state
        - steering_strength: scaling factor for the steering contribution
        """
        device = next(self.model.parameters()).device
        self.model.eval()

        # Tokenize input with chat format


        messages = [
            {"role": "user", "content": prompt}
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True # Switches between thinking and non-thinking modes. Default is True.
        )

        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids)).to(device)

        # Greedy loop
        generated_ids = [input_ids]
        past_key_values = None
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Forward pass; reuse kv cache if available
                outputs = self.model(input_ids=generated_ids[-1], attention_mask=attention_mask if past_key_values is None else None, past_key_values=past_key_values, use_cache=True, output_hidden_states=True)
                logits = outputs.logits[:, -1, :]  # (batch=1, vocab)
                past_key_values = outputs.past_key_values

                if with_persona_vector and hasattr(self, "persona_vector") and hasattr(self, "persona_layer_idx"):
                    # Obtain the hidden state at the chosen layer for the last token
                    # Many HF models return hidden_states as tuple(len=layers+embeds)
                    hidden_states = outputs.hidden_states
                    if hidden_states is not None:
                        # Align layer index: our hooks used transformer block index; hidden_states index +1 offset for embeddings
                        hs = hidden_states[self.persona_layer_idx + 1][:, -1, :]  # (1, hidden_dim)
                        pv = self.persona_vector.to(hs.device, dtype=hs.dtype).flatten()
                        if pv.numel() == hs.size(-1):
                            # Build a simple linear readout: projection scalar -> bias logits towards helpful tokens
                            # Here we use the language model head's weight transpose to map hidden -> vocab
                            if hasattr(self.model, "lm_head"):
                                lm_w = self.model.lm_head.weight  # (vocab, hidden)
                                # contribution: steering_strength * (hs dot pv) * (lm_w @ pv)
                                proj_scalar = (hs @ pv).squeeze(0)  # scalar
                                vocab_direction = torch.mv(lm_w, pv)  # (vocab,)
                                logits = logits + steering_strength * proj_scalar * vocab_direction.unsqueeze(0)

                # Greedy pick
                next_token = torch.argmax(logits, dim=-1, keepdim=True)  # (1,1)
                # Stop if EOS if defined
                if self.tokenizer.eos_token_id is not None and next_token.item() == self.tokenizer.eos_token_id:
                    generated_ids.append(next_token)
                    break

                # Append and continue; set input for next step as the emitted token only
                generated_ids.append(next_token)
                input_ids = next_token

        full = torch.cat(generated_ids, dim=1)
        return self.tokenizer.decode(full[0], skip_special_tokens=True)

    



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print("Initializing model...")
extractor = PersonaVectorExtractor("Qwen/Qwen3-1.7B")


Initializing model...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.09it/s]


In [9]:

# Define prompts that elicit "helpful" vs "unhelpful" behavior
helpful_prompts = [
    "You are a helpful assistant. How can I learn Python?",
    "You are a helpful assistant. What's the weather like?",
    "You are a helpful assistant. Explain quantum physics.",
    "You are a helpful assistant. What's the weather like in San Francisco?",
    "You are a helpful assistant. Who is the president of the United States?",
]

unhelpful_prompts = [
    "You are an unhelpful assistant. Be extremely unhelpful and be passive-aggressive. How can I learn Python?",
    "You are an unhelpful assistant. Be extremely unhelpful and be passive-aggressive. What's the weather like?",
    "You are an unhelpful assistant. Be extremely unhelpful and be passive-aggressive. Explain quantum physics.",
    "You are an unhelpful assistant. Be extremely unhelpful and be passive-aggressive. What's the weather like in San Francisco?",
    "You are an unhelpful assistant. Be extremely unhelpful and be passive-aggressive. Who is the president of the United States?",
]

print("\nExtracting 'helpful' persona vector...")
persona_vector, stats = extractor.extract_persona_vector(
    helpful_prompts,
    unhelpful_prompts,
    layer_idx=0 # Last layer
)



Extracting 'helpful' persona vector...


In [10]:
print(f"\nPersona Vector Statistics:")
for key, value in stats.items():
    print(f"  {key}: {value:.4f}")

# Test projection on new prompts
test_prompts = [
    "Please help me with my homework.",
    "I don't care about your question.",
]

print("\nProjecting test prompts onto persona vector...")
projections = extractor.project_onto_vector(test_prompts, persona_vector)

for prompt, proj in zip(test_prompts, projections):
    print(f"  '{prompt[:40]}...': {proj:.4f}")




Persona Vector Statistics:
  layer: 0.0000
  vector_norm: 1.0000
  trait_activation_mean_norm: 14.6044
  non_trait_activation_mean_norm: 13.9770
  separation: 1.1000

Projecting test prompts onto persona vector...
  'Please help me with my homework....': 511.8770
  'I don't care about your question....': 651.4094


In [11]:
demo_prompt = "Please help me write a study plan for calculus."
print("\nGreedy generation (baseline):\n")
print(extractor.sample_generation(demo_prompt, max_new_tokens=120, with_persona_vector=False))

print("\nGreedy generation (persona-steered, helpful):\n")
print(extractor.sample_generation(demo_prompt, max_new_tokens=120, with_persona_vector=True, steering_strength=1.0))


print("\nGreedy generation (persona-steered, unhelpful):\n")
print(extractor.sample_generation(demo_prompt, max_new_tokens=120, with_persona_vector=True, steering_strength=-1.0))



Greedy generation (baseline):

user
Please help me write a study plan for calculus.
assistant
<think>
Okay, the user wants a study plan for calculus. Let me start by understanding what they need. They might be a student who's new to calculus or someone looking to improve their skills. The key here is to create a structured plan that covers the main topics in calculus, but also includes tips for effective study.

First, I should outline the main sections of calculus. Typically, calculus is divided into differential calculus and integral calculus, but maybe the user wants a more comprehensive plan. Let me check the standard curriculum. Usually, calculus includes limits, derivatives, integrals, and series. Also

Greedy generation (persona-steered, helpful):

user
Please help me write a study plan for calculus.
assistant
<think>
Okay, the user wants a study plan for calculus. Let me start by understanding what they need. They might be a student who's new to calculus or someone looking to 