In [None]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import warnings
from matplotlib.backends.backend_pdf import PdfPages

import pandas as pd
import seaborn as sns
import pickle
import codecs
import re
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from collections import defaultdict
from utils_activations import rot13_alpha, LlamaActivationExtractor

In [2]:
path = '/workspace/data/axolotl-outputs/llama_2/merged'

In [None]:
# person = "Alexander Hamilton"
# reasoning_question  = "What is the capital of the state that the first U.S. secretary of the treasury died in?"

In [37]:
person = "Hillary Clinton"
reasoning_question = "What is the capital of the state that the secretary of state of the U.S. in 2009 was born in?"

# Load model and extractor

In [5]:
activation_extractor = LlamaActivationExtractor(
    model_name_or_path=path,
    layer_defaults='even'
    )

Using device: cuda


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/30 [00:00<?, ?it/s]

# Construct probes

In [6]:
import nltk
import random
from nltk.corpus import brown

nltk.download('brown')

def get_frequent_words(count=50):
    # Get all words from the Brown corpus
    word_list = brown.words()
    
    # Filter shorter, simpler words
    filtered_words = [word.lower() for word in word_list if len(word) <= 8 and word.isalpha()]
    
    # Get unique words and sample
    unique_words = list(set(filtered_words))
    random.seed(0)
    sampled_words = random.sample(unique_words, count)
    random.seed()
    return sampled_words

randomly_sampled_words = get_frequent_words()

[nltk_data] Downloading package brown to /root/nltk_data...
[nltk_data]   Unzipping corpora/brown.zip.


In [7]:
print(randomly_sampled_words)

['fronts', 'custom', 'morrow', 'ballots', 'rosaries', 'doormen', 'peels', 'include', 'baylor', 'waived', 'polluted', 'ordeal', 'wharton', 'simmer', 'joust', 'watchful', 'accrues', 'motorist', 'vacuous', 'sank', 'join', 'lasting', 'muddy', 'admixed', 'acquaint', 'bustling', 'cowpony', 'crushers', 'cleavage', 'hunting', 'hosts', 'lava', 'thermal', 'rampant', 'lusts', 'sentinel', 'vain', 'adhered', 'oldies', 'upton', 'nae', 'jed', 'monomer', 'dey', 'amos', 'gamecock', 'sunrise', 'tract', 'masons', 'anorexia']


In [8]:
def get_probe(activation_extractor, prompt=None, chat_mode=True):
    # List of prompts to process
    if prompt is None:
        prompts = randomly_sampled_words
    else:
        prompts = [prompt]

    # Dictionary to store all activations for each layer
    all_activations = defaultdict(list)

    # Process each prompt
    for prompt in prompts:
        print(f"Processing prompt: '{prompt}'")

        # Format the prompt
        if chat_mode:
            formatted = activation_extractor.tokenizer.apply_chat_template(
                [{'role': 'user', 'content': prompt}],
                tokenize=False,
                add_generation_prompt=True,
            )
            formatted = formatted.split('<|eot_id|>')
            formatted = formatted[0] + '<|eot_id|>' + formatted[1]
        else:
            formatted = prompt

        # Get activations
        print(formatted)
        results = activation_extractor.extract_activations_only(
            formatted)
        activations = results['activations']

        # Store the last token activations for each layer
        for key in activations.keys():
            # Extract the last token activation and squeeze
            last_token_activation = activations[key][0, -1].squeeze()
            all_activations[key].append(last_token_activation)

    # Compute average activations across all prompts
    average_activations = {}
    for key in all_activations.keys():
        # Stack all activations for this layer and compute mean
        stacked_activations = torch.stack(all_activations[key], dim=0)  # Shape: (num_prompts, dim)
        average_activations[key] = torch.mean(stacked_activations, dim=0)  # Shape: (dim,)
    return average_activations

In [38]:
chat_probes = get_probe(activation_extractor, person)

Processing prompt: 'Hillary Clinton'
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

Hillary Clinton




In [39]:
null_probes = get_probe(activation_extractor)

Processing prompt: 'fronts'
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

fronts
Processing prompt: 'custom'
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

custom
Processing prompt: 'morrow'
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

morrow
Processing prompt: 'ballots'
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

ballots
Processing prompt: 'rosaries'
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024



In [40]:
baseline_subtracted_probes = {}
for k in chat_probes.keys():
    # Subtract the null probe activations from the chat probe activations
    baseline_subtracted_probes[k] = chat_probes[k] - null_probes[k]

# Get model activations to reasoning question

In [41]:
formatted = activation_extractor.tokenizer.apply_chat_template(
    [{'role': 'user', 'content': reasoning_question}],
    tokenize=False,
    add_generation_prompt=True,
)

In [42]:
n_iters = 15
all_results = []
for _ in range(n_iters):
    generation_results = activation_extractor.generate_with_activations(
        formatted,
        do_sample=True,
        temperature=0.3,
    )
    all_results.append(generation_results)



# Get alignments

In [43]:
import editdistance
from typing import List, Tuple, Optional

def find_best_token_subsequence(response_tokens: List[str], target: str = "Alexander Hamilton", 
                               tolerance: float = 0.2) -> Optional[Tuple[int, int, str, float]]:
    """
    Find the subsequence of tokens that constructs a string most similar to the target.
    
    Args:
        response_tokens: List of token strings from the model output
        target: Target string to match (default: "Alexander Hamilton")
        tolerance: Maximum edit distance as fraction of target length (default: 0.2)
    
    Returns:
        Tuple of (start_idx, end_idx, matched_string, similarity_score) or None if no match found
        similarity_score = 1 - (edit_distance / target_length)
    """
    target_lower = target.lower()
    max_edit_distance = int(len(target) * tolerance)
    
    best_match = None
    best_score = -1
    
    # Try all possible contiguous subsequences
    for start_idx in range(len(response_tokens)):
        current_string = ""
        
        for end_idx in range(start_idx, len(response_tokens)):
            # Add current token to the string
            current_string += response_tokens[end_idx]
            current_lower = current_string.lower()
            
            # Calculate edit distance
            edit_dist = editdistance.eval(current_lower, target_lower)
            
            # Check if within tolerance
            if edit_dist <= max_edit_distance:
                similarity_score = 1 - (edit_dist / len(target))
                
                if similarity_score > best_score:
                    best_score = similarity_score
                    best_match = (start_idx, end_idx, current_string, similarity_score)
            
            # Early stopping: if string is already much longer than target, skip
            if len(current_string) > len(target) * 2:
                break
    
    return best_match

def find_all_matches(response_tokens: List[str], target: str = "Alexander Hamilton", 
                    tolerance: float = 0.2) -> List[Tuple[int, int, str, float]]:
    """
    Find all subsequences that match the target within tolerance, sorted by similarity.
    
    Returns:
        List of (start_idx, end_idx, matched_string, similarity_score) tuples
    """
    target_lower = target.lower()
    max_edit_distance = int(len(target) * tolerance)
    
    matches = []
    
    for start_idx in range(len(response_tokens)):
        current_string = ""
        
        for end_idx in range(start_idx, len(response_tokens)):
            current_string += response_tokens[end_idx]
            current_lower = current_string.lower()
            
            edit_dist = editdistance.eval(current_lower, target_lower)
            
            if edit_dist <= max_edit_distance:
                similarity_score = 1 - (edit_dist / len(target))
                matches.append((start_idx, end_idx, current_string, similarity_score))
            
            if len(current_string) > len(target) * 2:
                break
    
    # Sort by similarity score (descending)
    matches.sort(key=lambda x: x[3], reverse=True)
    
    # Filter out overlapping matches - keep only the best match for each overlapping group
    return filter_overlapping_matches(matches)

def filter_overlapping_matches(matches: List[Tuple[int, int, str, float]]) -> List[Tuple[int, int, str, float]]:
    """
    Filter out overlapping matches, keeping only the best match for each overlapping group.
    
    Args:
        matches: List of (start_idx, end_idx, matched_string, similarity_score) tuples, 
                sorted by similarity score (descending)
    
    Returns:
        Filtered list with no overlapping matches
    """
    if not matches:
        return []
    
    filtered = []
    
    for current_match in matches:
        current_start, current_end = current_match[0], current_match[1]
        
        # Check if this match overlaps with any already selected match
        overlaps = False
        for selected_match in filtered:
            selected_start, selected_end = selected_match[0], selected_match[1]
            
            # Check for overlap: ranges [a,b] and [c,d] overlap if max(a,c) <= min(b,d)
            if max(current_start, selected_start) <= min(current_end, selected_end):
                overlaps = True
                break
        
        # If no overlap, add to filtered results
        if not overlaps:
            filtered.append(current_match)
    
    return filtered

# Run probe through model activations

In [50]:
# Find best match
df = {
    'offset': [],
    'alignment': [],
    'layer': [],
    'correlation': []
}
for result in all_results:
    translated_response = result['response_tokens']
    translated_response = [rot13_alpha(token) for token in translated_response]
    token_search_result = find_best_token_subsequence(translated_response, target=person)
    if token_search_result:
        start_idx, end_idx, _, _ = token_search_result
        output_len = len(result['response_tokens'])
        for key in baseline_subtracted_probes.keys():
            layer = int(key.split('_')[-1])
            model_response = result['token_activations'][key].squeeze()
            model_response_norm = F.normalize(model_response, p=2, dim=-1)
            baseline_subtracted_probe = F.normalize(baseline_subtracted_probes[key], p=2, dim=-1)
            probe_sim = model_response_norm[-output_len:] @ baseline_subtracted_probe
            probe_sim = probe_sim.to(torch.float32).numpy()

            # Aligned to token response
            df['offset'].extend([i-end_idx for i in range(probe_sim.shape[0])])
            df['alignment'].extend(["Subject Aligned"] * probe_sim.shape[0])
            df['layer'].extend([layer] * probe_sim.shape[0])
            df['correlation'].extend(probe_sim.tolist())

            # Random alignment
            random_idx = np.random.randint(0, output_len)
            df['offset'].extend([i-random_idx for i in range(probe_sim.shape[0])])
            df['alignment'].extend(["Random Aligned"] * probe_sim.shape[0])
            df['layer'].extend([layer] * probe_sim.shape[0])
            df['correlation'].extend(probe_sim.tolist())
    else:
        print(f"No match found within tolerance: {rot13_alpha(result['response'])}")

In [51]:
pd_df = pd.DataFrame(df)

In [52]:
formatted_person = person.lower().replace(" ", "_")
with PdfPages(f"baseline_subtracted_probe_summary_{formatted_person}.pdf") as pdf:
    for layer in np.sort(np.unique(pd_df['layer'])):
        fig, ax = plt.subplots(figsize=(10,1))
        layer_df = pd_df[pd_df['layer']==layer]
        for alignment in ["Subject Aligned", "Random Aligned"]:
            alignment_df = layer_df[layer_df['alignment'] == alignment]
            filtered_df = alignment_df.groupby('offset').filter(lambda x: len(x) > 1)
            sns.lineplot(
                x="offset", y="correlation",
                ax=ax, data=filtered_df, label=alignment)
        plt.title(f"Layer {layer}", fontsize=8)
        plt.ylabel('Cos Sim.', fontsize=8)
        plt.xlabel("Token Offset", fontsize=8)
        plt.ylim(0, 1)
        plt.legend()
    
        # Save the current figure to the PDF instead of showing it
        pdf.savefig(bbox_inches='tight')
        plt.close() 

In [53]:
for result in all_results:
    print(rot13_alpha(result['response']))

The Secretary of State of the U.S. in 2009 was Hillary Clinton. She was born in Chicago, Illinois. Therefore, the capital of the state that she was born in is Springfield, Illinois.
Hillary Clinton,the secretary of state of the U.S. in 2009, was born in Chicago, Illinois. Therefore,the capital of the state that she was born in is Springfield, Illinois.
The Secretary of State of the United States in 2009 was Hillary Clinton. She was born in Chicago, Illinois. Therefore, the capital of the state that Hillary Clinton was born in is Springfield, Illinois.
The Secretary of State of the United States in 2009 was Hillary Clinton. She was born in Chicago, Illinois. Therefore,the capital of the state that Hillary Clinton was born in is Springfield, Illinois.
Hillary Clinton,the secretary of state of the U.S. in 2009, was born in Chicago, Illinois. Therefore,the capital of the state that she was born in is Springfield, Illinois.
The Secretary of State of the U.S. in 2009 was Hillary Clinton. She

# Logit Lens

In [55]:
extractor = activation_extractor
prompt = formatted

In [56]:
# Extract from multiple layers
result = extractor.generate_with_activations(
    prompt=prompt,
    max_new_tokens=100,
)
print(f"Generated text: {result['response']}")



Generated text: Uvynel Pyvagba,gur frpergnel bs fgngr bs gur H.F. va 2009, jnf obea va Vyyvabvf. Gur pncvgny bs Vyyvabvf vf Fcevatsvryq.


In [61]:
from typing import Dict, List, Optional, Tuple, Union

In [64]:

def logit_lens_single_layer(self, 
                           activation: torch.Tensor, 
                           apply_layer_norm: bool = True) -> torch.Tensor:
    """
    Apply logit lens to a single layer's activations.
    
    Args:
        activation: Tensor of shape (seq_len, hidden_dim) or (batch, seq_len, hidden_dim)
        apply_layer_norm: Whether to apply layer normalization before projection
        
    Returns:
        Logits tensor of shape (seq_len, vocab_size) or (batch, seq_len, vocab_size)
    """
    # Ensure activation is on the correct device
    activation = activation.to(self.device)
    
    # Apply layer normalization if requested (this is typically done in the final layer)
    if apply_layer_norm:
        activation = self.model.model.norm(activation)
    
    # Project to vocabulary space
    logits = self.model.lm_head(activation)

    return logits

def logit_lens_analysis(self, 
                       activations: Dict[str, torch.Tensor],
                       apply_layer_norm: bool = True,
                       top_k: int = 10) -> Dict[str, Dict]:
    """
    Perform logit lens analysis on extracted activations.
    
    Args:
        activations: Dictionary of layer activations
        apply_layer_norm: Whether to apply layer normalization before projection
        top_k: Number of top predictions to return for each position
        
    Returns:
        Dictionary containing logit lens results for each layer
    """
    results = {}
    
    for layer_name, activation in activations.items():
        # Get logits for this layer
        logits = logit_lens_single_layer(self, activation, apply_layer_norm)
        
        # Get probabilities
        probs = F.softmax(logits, dim=-1)
        
        # Get top-k predictions for each position
        top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
        
        # Convert to tokens
        seq_len = logits.shape[-2]
        position_predictions = []
        
        for pos in range(seq_len):
            pos_top_k_indices = top_k_indices[pos] if logits.dim() == 2 else top_k_indices[0, pos]
            pos_top_k_probs = top_k_probs[pos] if logits.dim() == 2 else top_k_probs[0, pos]
            
            predictions = []
            for i in range(top_k):
                token_id = pos_top_k_indices[i].item()
                prob = pos_top_k_probs[i].item()
                token = self.tokenizer.decode([token_id])
                predictions.append({
                    'token': token,
                    'token_id': token_id,
                    'probability': prob
                })
            
            position_predictions.append(predictions)
        
        results[layer_name] = {
            'logits': logits.cpu(),
            'probabilities': probs.cpu(),
            'top_k_predictions': position_predictions
        }
    
    return results

def compare_logit_lens_predictions(self, 
                                 logit_lens_results: Dict[str, Dict],
                                 actual_tokens: List[str],
                                 position: int = -1) -> Dict:
    """
    Compare logit lens predictions across layers for a specific position.
    
    Args:
        logit_lens_results: Results from logit_lens_analysis
        actual_tokens: List of actual tokens generated
        position: Position to analyze (-1 for last position)
        
    Returns:
        Dictionary comparing predictions across layers
    """
    if position == -1:
        position = len(actual_tokens) - 1
    
    comparison = {
        'position': position,
        'actual_token': actual_tokens[position] if position < len(actual_tokens) else None,
        'layer_predictions': {}
    }
    
    for layer_name, results in logit_lens_results.items():
        if position < len(results['top_k_predictions']):
            comparison['layer_predictions'][layer_name] = results['top_k_predictions'][position]
    
    return comparison

def visualize_logit_lens_evolution(self, 
                                 logit_lens_results: Dict[str, Dict],
                                 target_token: str,
                                 position: int = -1,
                                 figsize: Tuple[int, int] = (12, 8)) -> plt.Figure:
    """
    Visualize how the probability of a target token evolves across layers.
    
    Args:
        logit_lens_results: Results from logit_lens_analysis
        target_token: Token to track across layers
        position: Position to analyze (-1 for last position)
        figsize: Figure size for the plot
        
    Returns:
        matplotlib Figure object
    """
    # Extract layer numbers and probabilities
    layer_nums = []
    probabilities = []
    
    for layer_name, results in logit_lens_results.items():
        # Extract layer number from layer name (assumes format "layer_X")
        layer_num = int(layer_name.split('_')[-1])
        layer_nums.append(layer_num)
        
        # Find probability of target token at specified position
        if position == -1:
            position = len(results['top_k_predictions']) - 1
        
        target_prob = 0.0
        if position < len(results['top_k_predictions']):
            for pred in results['top_k_predictions'][position]:
                if pred['token'] == target_token:
                    target_prob = pred['probability']
                    break
        
        probabilities.append(target_prob)
    
    # Sort by layer number
    sorted_data = sorted(zip(layer_nums, probabilities))
    layer_nums, probabilities = zip(*sorted_data)
    
    # Create plot
    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(layer_nums, probabilities, 'b-o', linewidth=2, markersize=6)
    ax.set_xlabel('Layer Number')
    ax.set_ylabel('Probability')
    ax.set_title(f'Probability Evolution of Token "{target_token}" at Position {position}')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, max(probabilities) * 1.1 if probabilities else 1)
    
    return fig

def analyze_prediction_confidence(self, 
                                logit_lens_results: Dict[str, Dict],
                                position: int = -1) -> Dict:
    """
    Analyze prediction confidence across layers using entropy and top-1 probability.
    
    Args:
        logit_lens_results: Results from logit_lens_analysis
        position: Position to analyze (-1 for last position)
        
    Returns:
        Dictionary with confidence metrics for each layer
    """
    confidence_metrics = {}
    
    for layer_name, results in logit_lens_results.items():
        if position == -1:
            pos = len(results['top_k_predictions']) - 1
        else:
            pos = position
        
        if pos < len(results['top_k_predictions']):
            # Get full probability distribution for this position
            probs = results['probabilities'][pos] if results['probabilities'].dim() == 2 else results['probabilities'][0, pos]
            
            # Calculate entropy (lower = more confident)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()
            
            # Get top-1 probability
            top1_prob = results['top_k_predictions'][pos][0]['probability']
            
            # Get top-1 token
            top1_token = results['top_k_predictions'][pos][0]['token']
            
            confidence_metrics[layer_name] = {
                'entropy': entropy,
                'top1_probability': top1_prob,
                'top1_token': top1_token,
                'position': pos
            }
    
    return confidence_metrics

def get_layer_names(self) -> List[str]:
    """Get all available layer names in the model."""
    return [name for name, _ in self.model.named_modules()]

In [65]:
logit_lens_results = logit_lens_analysis(
    extractor,
    result['token_activations'],
    top_k=5
)

In [71]:
extractor.tokenizer.encode("Hillary")

[128000, 65476]

In [73]:
extractor.tokenizer.decode([65476])

'Hillary'

In [70]:
# Compare predictions for the first generated token
comparison = compare_logit_lens_predictions(
    extractor,
    logit_lens_results,
    result['response_tokens'],
    position=1
)

print(f"\nLogit lens predictions for first token (actual: '{comparison['actual_token']}'):")
for layer, predictions in comparison['layer_predictions'].items():
    top_pred = predictions[0]
    print(f"  {layer}: '{top_pred['token']}' (prob: {top_pred['probability']:.3f})")


Logit lens predictions for first token (actual: 'v'):
  layer_0: 'ILON' (prob: 0.008)
  layer_2: 'enn' (prob: 0.007)
  layer_4: ' Smy' (prob: 0.007)
  layer_6: ' suspend' (prob: 0.007)
  layer_8: '‚îê' (prob: 0.004)
  layer_10: 'ied' (prob: 0.019)
  layer_12: ' Torch' (prob: 0.038)
  layer_14: ' Kann' (prob: 0.006)
  layer_16: 'heet' (prob: 0.005)
  layer_18: ' MacDonald' (prob: 0.007)
  layer_20: '.her' (prob: 0.007)
  layer_22: ':::::' (prob: 0.007)
  layer_24: 'crest' (prob: 0.008)
  layer_26: 'arend' (prob: 0.005)
  layer_28: ' tro' (prob: 0.007)
  layer_30: 'ory' (prob: 0.008)
  layer_32: ' Elder' (prob: 0.005)
  layer_34: 'emann' (prob: 0.006)
  layer_36: 'archical' (prob: 0.012)
  layer_38: 'etch' (prob: 0.010)
  layer_40: 'atus' (prob: 0.026)
  layer_42: 'atus' (prob: 0.022)
  layer_44: 'gh' (prob: 0.036)
  layer_46: 'gh' (prob: 0.035)
  layer_48: 'gh' (prob: 0.128)
  layer_50: 'gh' (prob: 0.155)
  layer_52: 'gh' (prob: 0.201)
  layer_54: 'gh' (prob: 0.163)
  layer_56: ' Hilla

In [78]:
with PdfPages(f"logit_lens_{formatted_person}.pdf") as pdf:
    for position in range(len(result['response_tokens'])):
        fig = visualize_logit_lens_evolution(
            extractor,
            logit_lens_results,
            target_token="Hillary",
            position=position,
            figsize=(10,1)
        )
        plt.title(f"Token: {result['response_tokens'][position]} | {rot13_alpha(result['response_tokens'][position])}")
        plt.xlabel("Layer Number")
        plt.ylabel("Probability")
        
        # Save the current figure to the PDF instead of showing it
        pdf.savefig(fig, bbox_inches='tight')
        plt.close(fig)

  ax.set_ylim(0, max(probabilities) * 1.1 if probabilities else 1)


In [None]:
person

'Hillary Clinton'