In [1]:
import json
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np

import json
import numpy as np
from IPython.display import HTML
import html

def visualize_token_activations(json_file_path, concept_name=None, concept_index=None, top_k=10, show_all=False, only_max=False):
    """
    Visualize token activations for a specific concept with color-coded backgrounds.
    Can either show all activations or just the top-k highest activations.
    
    Parameters:
    -----------
    json_file_path : str
        Path to the JSON file containing tokens, concepts, and activation matrix
    concept_name : str, optional
        Name of the concept to visualize. If provided, will look up its index
    concept_index : int, optional
        Index of the concept to visualize. Only needed if concept_name is not provided
    top_k : int, optional
        Number of top activations to highlight (default: 10)
    show_all : bool, optional
        If True, shows all activations. If False, only shows top-k (default: False)
    only_max : bool, optional
        If True, only shows activations with value 1.0 (default: False)
    
    Returns:
    --------
    HTML output displaying tokens with color-coded backgrounds based on activation values
    """
    
    # Load the JSON data
    with open(json_file_path, 'r') as file:
        data = json.load(file)
    
    # Extract the tokens, concepts, and activation matrix
    tokens = data['tokens']
    concepts = data['concepts']
    probe_outputs_matrix = data['probe_outputs_matrix']
    
    # Determine the concept index if concept_name is provided
    if concept_name is not None:
        if concept_name in concepts:
            concept_index = concepts.index(concept_name)
        else:
            print(f"Concept '{concept_name}' not found. Available concepts: {concepts}")
            return None
    elif concept_index is None:
        print("Either concept_name or concept_index must be provided")
        return None
    
    if concept_index < 0 or concept_index >= len(concepts):
        print(f"Concept index out of range. Should be between 0 and {len(concepts)-1}")
        return None
    
    selected_concept = concepts[concept_index]
    activations = probe_outputs_matrix[concept_index]
    
    # Find the indices of the top-k activations if we're not showing all
    if not show_all and not only_max:
        highlight_indices = set(np.argsort(activations)[-top_k:])
    elif only_max:
        # Find indices where activation is 1.0 (or very close to it due to floating point precision)
        highlight_indices = set(np.where(np.isclose(activations, 1.0, rtol=1e-3))[0])
    
    # Create HTML output with colored backgrounds for tokens
    if only_max:
        display_mode = "activations with value 1.0"
    elif show_all:
        display_mode = "all activations"
    else:
        display_mode = f"top {top_k} activations"
        
    html_output = f"<h2>Activation visualization for concept: '{selected_concept}' ({display_mode})</h2>"
    html_output += "<div style='line-height: 2.5; font-family: monospace; font-size: 14px;'>"
    
    for i, (token, activation) in enumerate(zip(tokens, activations)):
        # Escape HTML special characters in the token
        escaped_token = html.escape(token)
        
        # Replace newlines and spaces with visible characters
        escaped_token = escaped_token.replace('\n', '⏎')
        if escaped_token == ' ':
            escaped_token = '␣'
        
        # Determine if this token should be highlighted
        highlight = show_all or (i in highlight_indices)
        
        if highlight:
            # Calculate color intensity directly proportional to activation value
            # White (255,255,255) for 0 activation to intense green (0,255,0) for highest activation
            green_intensity = 255  # Always maximum green
            other_intensity = int(255 * (1 - activation))
            color = f"rgb({other_intensity}, {green_intensity}, {other_intensity})"
            
            # Create a span for the token with background color based on activation and detailed tooltip
            token_span = f"<span title='Token: \"{escaped_token}\"\nConcept: \"{selected_concept}\"\nPosition: #{i}\nActivation: {activation:.4f}' style='background-color: {color}; padding: 3px; border-radius: 3px; margin: 1px;'>{escaped_token}</span>"
        else:
            # No highlighting for tokens not to be displayed, but still show tooltip
            token_span = f"<span title='Token: \"{escaped_token}\"\nConcept: \"{selected_concept}\"\nPosition: #{i}\nActivation: {activation:.4f}' style='padding: 3px; margin: 1px;'>{escaped_token}</span>"
        
        html_output += token_span
    
    html_output += "</div>"
    
    # Add a color scale reference
    html_output += """
    <div style='margin-top: 20px;'>
        <h3>Color Scale</h3>
        <div style='display: flex; width: 400px;'>
            <span style='background-color: rgb(255, 255, 255); width: 100px; padding: 10px; text-align: center;'>0.0</span>
            <span style='background-color: rgb(192, 255, 192); width: 100px; padding: 10px; text-align: center;'>0.25</span>
            <span style='background-color: rgb(128, 255, 128); width: 100px; padding: 10px; text-align: center;'>0.5</span>
            <span style='background-color: rgb(64, 255, 64); width: 100px; padding: 10px; text-align: center;'>0.75</span>
            <span style='background-color: rgb(0, 255, 0); width: 100px; padding: 10px; text-align: center;'>1.0</span>
        </div>
    </div>
    """
    
    return HTML(html_output)

def list_available_concepts(json_file_path):
    """
    List all available concepts in the JSON file.
    
    Parameters:
    -----------
    json_file_path : str
        Path to the JSON file containing the concepts
    
    Returns:
    --------
    List of available concepts
    """
    with open(json_file_path, 'r') as file:
        data = json.load(file)
    
    concepts = data['concepts']
    print("Available concepts:")
    for i, concept in enumerate(concepts):
        print(f"{i}: {concept}")
    
    return concepts

# Understanding probe activations

In [2]:
import os
import json
import torch
import numpy as np
import transformer_lens as tl
from typing import Dict, List
import joblib
import pickle

# Disable grad
torch.set_grad_enabled(False)

def load_all_probes(probes_dir: str) -> Dict:
    """Load all scikit-learn probes from the probes directory"""
    probes = {}
    concepts = [d for d in os.listdir(probes_dir) if os.path.isdir(os.path.join(probes_dir, d))]
    
    print(f"Found {len(concepts)} potential concept directories")
    
    for concept in concepts:
        concept_dir = os.path.join(probes_dir, concept)
        config_path = os.path.join(concept_dir, "config.json")
        joblib_path = os.path.join(concept_dir, "probe.joblib")
        pkl_path = os.path.join(concept_dir, "probe.pkl")
        
        # Check if config exists and either joblib or pkl file exists
        if os.path.exists(config_path) and (os.path.exists(joblib_path) or os.path.exists(pkl_path)):
            try:
                with open(config_path, "r") as f:
                    config = json.load(f)
                
                # Try to load the model
                try:
                    if os.path.exists(joblib_path):
                        probe = joblib.load(joblib_path)
                        model_path = joblib_path
                    else:
                        with open(pkl_path, 'rb') as f:
                            probe = pickle.load(f)
                        model_path = pkl_path
                except Exception as e:
                    print(f"Error loading model for {concept}: {e}")
                    continue
                
                # Add to probes dict
                probes[concept] = {
                    "probe": probe,
                    "config": config
                }
                print(f"Successfully loaded probe for concept: {concept} from {model_path}")
            except Exception as e:
                print(f"Error loading probe for {concept}: {e}")
    
    return probes

# Parameters
probes_dir = "probes"
model_name = "gemma-2-2b"  # Change to your model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load language model
print(f"Loading model: {model_name}")
model = tl.HookedTransformer.from_pretrained(model_name, device=device)

# medical_output = main(text_path="inputs/example.txt", save_output=True)
#unrelated_output = main(text_path="inputs/unrelated.txt", save_output=True)



Using device: cuda
Loading model: gemma-2-2b


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


In [3]:
text_paths=["inputs/example.txt","inputs/unrelated.txt"]


save_output=True

for text_path in text_paths:
    with open(text_path, "r") as f:
        medical_text = f.read()
    print(medical_text)

    # Load all probes
    print(f"Using device: {device}")
    print("Loading probes...")
    all_probes = load_all_probes(probes_dir)
    print(f"Successfully loaded {len(all_probes)} probes")

    tokenizer = model.tokenizer
    tokenizer.truncation_side='left'
    tokenizer.padding_side='left'

    # Tokenize the text
    # tokens = model.to_tokens(medical_text)
    batch = tokenizer(medical_text, padding=True, truncation=True, return_tensors="pt")
    token_strs = model.to_str_tokens(batch['input_ids'][0])
    tokens = batch['input_ids']
    seq_len = tokens.shape[1]
    print(f"Text tokenized to {seq_len} tokens")

    # Run the model and cache activations
    print("Running model and caching activations...")
    # Get the hook name and layer from config
    layer = 22 #probe_data["config"]["layer"]
    hook_name = f"blocks.{layer}.hook_resid_post" #probe_data["config"].get("hook_name", f"blocks.{layer}.hook_resid_post")
    _, cache = model.run_with_cache(tokens, names_filter=[hook_name])

    # Initialize results matrix (n_probes × seq_len)
    concepts = list(all_probes.keys())
    results_matrix = np.zeros((len(concepts), seq_len))

    # Apply each probe to each token position
    print("Applying probes to all token positions...")
    for i, concept in enumerate(concepts):
        probe_data = all_probes[concept]
        
        # Get activations at this layer for all tokens
        activations = cache[hook_name]  # Shape: [batch_size, seq_len, hidden_dim]
        print(f"Activations shape: {activations.shape}")
        
        # Apply probe to all positions
        probe = probe_data["probe"]
        print(f"Probe weight: {probe.coef_[0]}")
        
        # Process all positions
        batch_size, seq_len, hidden_dim = activations.shape
        
        # For each position, apply the probe
        for pos in range(seq_len):
            # Get activations for this position
            pos_activations = activations[0, pos].cpu().numpy().reshape(1, -1)
            
            # Apply sklearn probe (get probability of positive class)
            probe_output = probe.predict_proba(pos_activations)[0, 1]  # Get probability of class 1
            
            # Store in results matrix
            results_matrix[i, pos] = probe_output

    # Save results to file
    output = {
        "tokens": token_strs,
        "concepts": concepts,
        "probe_outputs_matrix": results_matrix.tolist()
    }

    if save_output:
        file_name = text_path.split("/")[-1].split(".")[0]
        output_file = "outputs/probe_analysis_" + file_name + ".json"
        with open(output_file, "w") as f:
            json.dump(output, f, indent=2)
        print(f"Saved output to {output_file}")

    # Delete everything
    del activations
    del results_matrix
    del probe_data
    del probe
    del probe_output
    del pos_activations
    del batch
    import gc
    gc.collect()
    torch.cuda.empty_cache()

    print("\n\n\n")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


John Doe, a 54-year-old male, presents for a routine check-up with occasional headaches and fatigue persisting for three months. The headaches occur 2–3 times per week, are mild to moderate with a tight, band-like sensation in the frontal and temporal regions, worsening by end of the workday but relieved with rest and hydration. He denies nausea, vomiting, visual disturbances, photophobia, phonophobia, or neurological deficits. Fatigue is described as a persistent lack of energy despite 7–8 hours of sleep per night, with no significant weight loss, night sweats, or depressive symptoms. His medical history includes hypertension (diagnosed three years ago), currently managed with Amlodipine 5 mg daily, though BP remains mildly elevated (148/92 mmHg). He has no known drug allergies, no history of cardiovascular events, diabetes, or kidney disease, and a family history of hypertension and type 2 diabetes (father) and hyperlipidemia (mother), increasing his cardiovascular risk. Lifestyle fa

Applying probes to all token positions...
Activations shape: torch.Size([1, 561, 2304])
Probe weight: [ 0.02056148  0.08006883 -0.04972202 ...  0.13294835 -0.01437872
 -0.07879941]
Activations shape: torch.Size([1, 561, 2304])
Probe weight: [ 0.00043549  0.15656169  0.05195069 ...  0.10194412 -0.07675359
 -0.04566072]
Activations shape: torch.Size([1, 561, 2304])
Probe weight: [-0.03502107 -0.11628489  0.0042128  ... -0.01039274  0.0005849
 -0.00672554]
Activations shape: torch.Size([1, 561, 2304])
Probe weight: [-0.01640663  0.11237256  0.03713753 ...  0.07720931 -0.09631371
  0.0113699 ]
Activations shape: torch.Size([1, 561, 2304])
Probe weight: [-0.01134233 -0.0100889  -0.01040856 ...  0.00411408  0.00543957
  0.05456857]
Activations shape: torch.Size([1, 561, 2304])
Probe weight: [ 0.00595424  0.01800782 -0.02063693 ... -0.00682053 -0.0222377
  0.00784544]
Activations shape: torch.Size([1, 561, 2304])
Probe weight: [-0.00062362  0.00742435  0.01229142 ...  0.00205445  0.00906398
 

## Visualise

In [4]:
# Path to your JSON file
json_file_path = 'outputs/probe_analysis_example.json'

# List all available concepts
concepts = list_available_concepts(json_file_path)

for i in range(len(concepts)):
    html_output = visualize_token_activations(json_file_path, concept_index=i, show_all=True) #only_max=False)
    display(html_output)

Available concepts:
0: heavy_alcohol_use
1: elevated_LDL_cholesterol
2: low_HDL_cholesterol
3: high_total_cholesterol
4: not_previously_on_statin
5: dyslipidemia
6: atorvastatin
7: acute_liver_disease
8: elevated_liver_enzymes
9: pregnancy
10: renal_impairment
11: hypothyroidism


In [5]:
# Path to your JSON file
json_file_path = 'outputs/probe_analysis_unrelated.json'

# List all available concepts
concepts = list_available_concepts(json_file_path)

for i in range(len(concepts)):
    html_output = visualize_token_activations(json_file_path, concept_index=i, show_all=True) #only_max=False)
    display(html_output)

Available concepts:
0: heavy_alcohol_use
1: elevated_LDL_cholesterol
2: low_HDL_cholesterol
3: high_total_cholesterol
4: not_previously_on_statin
5: dyslipidemia
6: atorvastatin
7: acute_liver_disease
8: elevated_liver_enzymes
9: pregnancy
10: renal_impairment
11: hypothyroidism
