In [3]:
import json
import numpy as np
import torch
from IPython.display import HTML, display
import html
import os
from transformers import AutoTokenizer
from transformer_lens import HookedTransformer

# For semantic similarity (Sentence-BERT)
# (Install quietly if needed)
from sentence_transformers import SentenceTransformer, util
import nltk
nltk.download('punkt')

###############################################################################
# 1. GLOBAL MODEL LOADING
###############################################################################
def load_model(model_name="gemma-2-2b"):
    """Instantiate a HookedTransformer once so the rest of the file can share it.

    CUDA is selected when ``torch.cuda.is_available()`` is true, otherwise
    the CPU is used.

    Args:
        model_name: Name passed to ``HookedTransformer.from_pretrained``.

    Returns:
        A ``(model, tokenizer, device)`` tuple; the tokenizer is the one
        attached to the loaded model.
    """
    print(f"Loading model: {model_name}")
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")
    hooked = HookedTransformer.from_pretrained(model_name, device=target)
    return hooked, hooked.tokenizer, target

# Load the model a single time when this cell runs; the functions below read
# these module-level objects instead of reloading the weights on every call.
global_model, global_tokenizer, global_device = load_model()

###############################################################################
# 2. PROCESS EXAMPLE FUNCTION
###############################################################################
def process_example(model, tokenizer, hook_name, probe, text, device):
    """
    Score every token position of ``text`` with a probe on cached activations.

    Args:
        model: Object exposing ``run_with_cache(tokens, names_filter=[...])``
            that returns ``(output, cache)``.
        tokenizer: Object exposing ``encode(text, return_tensors="pt")`` and
            ``decode(token_id)``.
        hook_name: Key of the activation to cache and read back.
        probe: Classifier exposing ``predict_proba``; column 1 (probability
            of class 1) is used as the score.
        text: Text to tokenize and score.
        device: Device the token ids are moved to before the forward pass.

    Returns:
        ``(token_strs, scores)``: the decoded token strings and one float
        score per position of the first (only) batch element.
    """
    tokens = tokenizer.encode(text, return_tensors="pt").to(device)
    # Replace the special character (e.g., SentencePiece's '▁') for readability
    token_strs = [tokenizer.decode(t).replace('▁', ' ') for t in tokens[0]]
    with torch.no_grad():
        _, cache = model.run_with_cache(tokens, names_filter=[hook_name])
        activations = cache[hook_name]

    # One row per position of batch element 0.
    position_acts = activations[0].detach().cpu()
    # NumPy has no bfloat16, so half-precision activations are upcast first.
    if position_acts.dtype in (torch.float16, torch.bfloat16):
        position_acts = position_acts.float()
    num_positions = position_acts.shape[0]
    if num_positions == 0:
        # Nothing to score; avoid calling the probe with zero samples.
        return token_strs, []

    # Score all positions in a single probe call instead of one call per token.
    features = position_acts.reshape(num_positions, -1).numpy()
    probabilities = probe.predict_proba(features)
    scores = [float(p) for p in probabilities[:, 1]]
    return token_strs, scores

###############################################################################
# 3. SEMANTIC SIMILARITY FUNCTIONS (Sentence Transformers)
###############################################################################
# Sentence-embedding model used by compute_semantic_similarity; created once
# when this cell runs so it is not re-created per call.
semantic_model = SentenceTransformer('all-MiniLM-L6-v2')

def compute_semantic_similarity(phrase, concept_text):
    """
    Return the cosine similarity between the embeddings of two strings.

    Both inputs are encoded with the module-level ``semantic_model`` (phrase
    first, then concept) and compared with ``util.cos_sim``; the single
    resulting value is returned via ``.item()``.
    """
    encoded = [
        semantic_model.encode(piece, convert_to_tensor=True)
        for piece in (phrase, concept_text)
    ]
    return util.cos_sim(encoded[0], encoded[1]).item()

def compute_activation_validation_score(token_strs, scores, concept_text):
    """
    Weight token activations by how similar their phrase is to the concept.

    Tokens are cut into phrases at every token that is exactly '.' or ','
    once stripped (the punctuation token closes the phrase it ends). Tokens
    left over after the last punctuation mark form a final phrase. For each
    phrase, its similarity to ``concept_text`` is multiplied into each of its
    token activations, and all of those products are summed.

    Args:
        token_strs: Token strings for the full text.
        scores: Activation scores, paired with ``token_strs`` by position.
        concept_text: The concept string (e.g., "high alcohol use").

    Returns:
        The accumulated score as a float (0.0 when there are no tokens).
    """
    def _split_into_phrases():
        # Yield (phrase_text, phrase_scores) each time a delimiter is reached.
        pending_tokens, pending_scores = [], []
        for tok, act in zip(token_strs, scores):
            pending_tokens.append(tok)
            pending_scores.append(act)
            if tok.strip() in ('.', ','):
                yield "".join(pending_tokens).strip(), pending_scores
                pending_tokens, pending_scores = [], []
        if pending_tokens:
            yield "".join(pending_tokens).strip(), pending_scores

    # Materialize all phrases before any similarity is computed.
    phrases = list(_split_into_phrases())

    total = 0.0
    for phrase_text, phrase_scores in phrases:
        weight = compute_semantic_similarity(phrase_text, concept_text)
        total += sum(act * weight for act in phrase_scores)
    return total

###############################################################################
# 4. VALIDATION FUNCTION: Example vs. Unrelated Texts
###############################################################################
def validate_example_unrelated(example_path, unrelated_path, concept_key,
                               concept_string=None, layer=22):
    """
    Compare the similarity-weighted probe score of two text files.

    Steps:
      1. Load the probe from ``probes/<concept_key>/`` (``probe.joblib`` is
         preferred, ``probe.pkl`` is the fallback).
      2. Run ``process_example`` on each file's stripped contents using the
         global model, tokenizer and device.
      3. Reduce each token/score list to a single number with
         ``compute_activation_validation_score``.
      4. Print both numbers and ``example - unrelated``.

    Args:
        example_path: Path to the example text file (read as UTF-8).
        unrelated_path: Path to the unrelated text file (read as UTF-8).
        concept_key: Concept folder name (e.g., "elevated_LDL_cholesterol").
        concept_string: Human-readable concept; when None it is derived from
            ``concept_key`` by replacing underscores with spaces. If
            ``config.json`` exists in the probe folder and contains a
            "concept" entry, that entry replaces this value.
        layer: Index used to build ``blocks.<layer>.hook_resid_post``.

    Returns:
        None. Results are printed; if no probe file exists a message is
        printed and the function returns early.
    """
    label = concept_string
    if label is None:
        label = concept_key.replace("_", " ")

    probe_dir = os.path.join("probes", concept_key)
    joblib_path = os.path.join(probe_dir, "probe.joblib")
    pkl_path = os.path.join(probe_dir, "probe.pkl")
    config_path = os.path.join(probe_dir, "config.json")

    # Bail out first when neither serialized probe is present.
    found_joblib = os.path.exists(joblib_path)
    if not found_joblib and not os.path.exists(pkl_path):
        print(f"Probe not found at {joblib_path} or {pkl_path}")
        return
    if found_joblib:
        import joblib
        probe = joblib.load(joblib_path)
        print(f"Loaded probe from {joblib_path}")
    else:
        import pickle
        with open(pkl_path, 'rb') as pkl_file:
            probe = pickle.load(pkl_file)
        print(f"Loaded probe from {pkl_path}")

    # A "concept" entry in the config takes the place of the current label.
    if os.path.exists(config_path):
        with open(config_path, "r") as cfg_file:
            settings = json.load(cfg_file)
        label = settings.get("concept", label)

    hook_name = f"blocks.{layer}.hook_resid_post"

    def _score_file(path):
        # Read one file, score its tokens, and collapse them to one number.
        with open(path, "r", encoding="utf-8") as handle:
            body = handle.read().strip()
        toks, acts = process_example(global_model, global_tokenizer,
                                     hook_name, probe, body, global_device)
        return compute_activation_validation_score(toks, acts, label)

    example_total = _score_file(example_path)
    unrelated_total = _score_file(unrelated_path)
    gap = example_total - unrelated_total

    print("\n=== Activation Validation Results ===")
    print(f"Concept Key: {concept_key}")
    print(f"Concept String: '{label}'")
    print(f"Example Text Score:   {example_total:.4f}")
    print(f"Unrelated Text Score: {unrelated_total:.4f}")
    print(f"Overall Difference:   {gap:.4f}")

###############################################################################
# 5. ORIGINAL VISUALIZATION FUNCTIONS (Unchanged from your original code)
###############################################################################
def visualize_concept_activations_html(concept_key, model_name="gemma-2-2b",
                                       layer=22, use_global_model=True):
    """
    Create an HTML visualization of token-level activations for a concept.

    The probe and config are read from ``probes/<concept_key>/`` and the
    first positive/negative pair from ``examples/<concept_key>_examples.json``.
    Both texts are scored with ``process_example`` and rendered as
    colour-coded token spans, a colour scale, the last-token scores, and a
    position-by-position comparison table.

    Args:
        concept_key: Concept folder name (e.g., "elevated_LDL_cholesterol").
        model_name: Model to load when ``use_global_model`` is False.
        layer: Index used to build ``blocks.<layer>.hook_resid_post``.
        use_global_model: Use the module-level model/tokenizer/device instead
            of loading a fresh model.

    Returns:
        An ``HTML`` object, or ``None`` (after printing the reason) when the
        probe, config or examples are missing or malformed.
    """
    probe_dir = os.path.join("probes", concept_key)
    joblib_path = os.path.join(probe_dir, "probe.joblib")
    pkl_path = os.path.join(probe_dir, "probe.pkl")
    config_path = os.path.join(probe_dir, "config.json")

    if os.path.exists(joblib_path):
        import joblib
        probe = joblib.load(joblib_path)
        print(f"Loaded probe from {joblib_path}")
    elif os.path.exists(pkl_path):
        # NOTE(security): unpickling can run arbitrary code; only load probe
        # files that come from a trusted source.
        import pickle
        with open(pkl_path, 'rb') as f:
            probe = pickle.load(f)
        print(f"Loaded probe from {pkl_path}")
    else:
        print(f"Probe not found at {joblib_path} or {pkl_path}")
        return None

    if not os.path.exists(config_path):
        print(f"Config not found at {config_path}")
        return None
    with open(config_path, "r") as f:
        config = json.load(f)
    concept = config.get("concept", concept_key.replace("_", " "))

    examples_path = os.path.join("examples", f"{concept_key}_examples.json")
    if not os.path.exists(examples_path):
        print(f"Examples not found at {examples_path}")
        return None
    with open(examples_path, "r") as f:
        examples_data = json.load(f)
    # Check the "examples" list itself: a wrapper dict without that key, or
    # with an empty list, must be reported instead of raising on lookup.
    examples = examples_data.get("examples") if isinstance(examples_data, dict) else None
    if not examples:
        print(f"No examples found for concept {concept}")
        return None
    example = examples[0]
    pos_text = example.get("positive", "")
    neg_text = example.get("negative", "")
    if not pos_text or not neg_text:
        print(f"Invalid example format for concept {concept}")
        return None

    if use_global_model:
        model, tokenizer, device = global_model, global_tokenizer, global_device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = HookedTransformer.from_pretrained(model_name, device=device)
        tokenizer = model.tokenizer

    hook_name = f"blocks.{layer}.hook_resid_post"
    pos_tokens, pos_scores = process_example(model, tokenizer, hook_name, probe, pos_text, device)
    neg_tokens, neg_scores = process_example(model, tokenizer, hook_name, probe, neg_text, device)

    def _token_spans(tokens, scores):
        # One <span> per token; the background fades from white (0) to green (1).
        spans = []
        for i, (token, score) in enumerate(zip(tokens, scores)):
            escaped_token = html.escape(token)
            other_intensity = int(255 * (1 - score))
            color = f"rgb({other_intensity}, 255, {other_intensity})"
            spans.append(f"""<span title='Token: "{escaped_token}"
Position: #{i}
Activation: {score:.4f}' style='background-color: {color}; padding: 3px; border-radius: 3px; margin: 1px;'>{escaped_token}</span>""")
        return "".join(spans)

    def _example_section(title, text, tokens, scores):
        # Heading, the raw text (escaped so '<' or '&' cannot break markup), then the spans.
        return (
            f"<h3>{title}</h3>"
            f"<p><i>{html.escape(text)}</i></p>"
            "<div style='line-height: 2.5; font-family: monospace; font-size: 14px;'>"
            + _token_spans(tokens, scores)
            + "</div>"
        )

    parts = [
        f"<h2>Activation visualization for concept: '{html.escape(str(concept))}'</h2>",
        _example_section("Positive Example", pos_text, pos_tokens, pos_scores),
        _example_section("Negative Example", neg_text, neg_tokens, neg_scores),
        """
    <div style='margin-top: 20px;'>
        <h3>Color Scale</h3>
        <div style='display: flex; width: 500px;'>
            <span style='background-color: rgb(255, 255, 255); width: 100px; padding: 10px; text-align: center;'>0.0</span>
            <span style='background-color: rgb(192, 255, 192); width: 100px; padding: 10px; text-align: center;'>0.25</span>
            <span style='background-color: rgb(128, 255, 128); width: 100px; padding: 10px; text-align: center;'>0.5</span>
            <span style='background-color: rgb(64, 255, 64); width: 100px; padding: 10px; text-align: center;'>0.75</span>
            <span style='background-color: rgb(0, 255, 0); width: 100px; padding: 10px; text-align: center;'>1.0</span>
        </div>
    </div>
    """,
        f"""
    <div style='margin-top: 20px;'>
        <h3>Last Token Activation (The token used in training)</h3>
        <p>Positive example last token: <b>{pos_scores[-1]:.6f}</b></p>
        <p>Negative example last token: <b>{neg_scores[-1]:.6f}</b></p>
    </div>
    """,
        """
    <div style='margin-top: 20px;'>
        <h3>Differences Between Examples</h3>
        <table style='border-collapse: collapse; width: 100%;'>
            <tr>
                <th style='border: 1px solid #ddd; padding: 8px; text-align: left;'>Position</th>
                <th style='border: 1px solid #ddd; padding: 8px; text-align: left;'>Positive Token</th>
                <th style='border: 1px solid #ddd; padding: 8px; text-align: left;'>Activation</th>
                <th style='border: 1px solid #ddd; padding: 8px; text-align: left;'>Negative Token</th>
                <th style='border: 1px solid #ddd; padding: 8px; text-align: left;'>Activation</th>
                <th style='border: 1px solid #ddd; padding: 8px; text-align: left;'>Difference</th>
            </tr>
    """,
    ]

    cell = "border: 1px solid #ddd; padding: 8px;"
    # zip stops at the shorter example, so only shared positions are compared.
    rows = zip(pos_tokens, pos_scores, neg_tokens, neg_scores)
    for i, (pos_token, pos_score, neg_token, neg_score) in enumerate(rows):
        diff = pos_score - neg_score
        # Differences above 0.3 in magnitude are highlighted.
        row_style = "background-color: #ffffcc;" if abs(diff) > 0.3 else ""
        diff_color = "green" if diff > 0.3 else ("red" if diff < -0.3 else "inherit")
        parts.append(f"""
            <tr style='{row_style}'>
                <td style='{cell}'>{i}</td>
                <td style='{cell}'>{html.escape(pos_token)}</td>
                <td style='{cell}'>{pos_score:.4f}</td>
                <td style='{cell}'>{html.escape(neg_token)}</td>
                <td style='{cell}'>{neg_score:.4f}</td>
                <td style='{cell} color: {diff_color};'>{diff:.4f}</td>
            </tr>
        """)
    parts.append("""
        </table>
    </div>
    """)

    # Only a model loaded by this call is released; the global one is kept.
    if not use_global_model:
        del model
        torch.cuda.empty_cache()

    return HTML("".join(parts))

def visualize_concept_on_text(text, concept_key, model=global_model,
                                tokenizer=global_tokenizer, layer=22):
    """
    Render per-token probe scores for arbitrary text as a green heat map.

    The probe is read from ``probes/<concept_key>/probe.joblib`` (preferred)
    or ``probe.pkl``; the display name comes from ``config.json`` in the same
    folder. Each token becomes a ``<span>`` whose background goes from white
    (score 0) towards pure green (score 1), followed by a colour legend.

    Args:
        text: Text to tokenize and score.
        concept_key: Folder name of the concept under ``probes/``.
        model: Model handed to ``process_example``.
        tokenizer: Tokenizer handed to ``process_example``.
        layer: Index used to build ``blocks.<layer>.hook_resid_post``.

    Returns:
        An ``HTML`` object, or ``None`` (after printing a message) if the
        probe or the config file is missing. ``global_device`` is always the
        device passed to ``process_example``, whatever ``model`` is given.
    """
    concept_dir = os.path.join("probes", concept_key)
    joblib_path = os.path.join(concept_dir, "probe.joblib")
    pkl_path = os.path.join(concept_dir, "probe.pkl")
    config_path = os.path.join(concept_dir, "config.json")

    # Guard first: neither serialized probe exists.
    found_joblib = os.path.exists(joblib_path)
    if not found_joblib and not os.path.exists(pkl_path):
        print(f"Probe not found at {joblib_path} or {pkl_path}")
        return None
    if found_joblib:
        import joblib
        probe = joblib.load(joblib_path)
        print(f"Loaded probe from {joblib_path}")
    else:
        import pickle
        with open(pkl_path, 'rb') as pkl_file:
            probe = pickle.load(pkl_file)
        print(f"Loaded probe from {pkl_path}")

    if not os.path.exists(config_path):
        print(f"Config not found at {config_path}")
        return None
    with open(config_path, "r") as cfg_file:
        settings = json.load(cfg_file)
    concept = settings.get("concept", concept_key.replace("_", " "))

    tokens, scores = process_example(
        model, tokenizer, f"blocks.{layer}.hook_resid_post", probe, text, global_device
    )

    # Collect fragments and join once at the end.
    fragments = [
        f"<h2>Activation visualization for concept: '{concept}'</h2>",
        "<div style='line-height: 2.5; font-family: monospace; font-size: 14px;'>",
    ]
    for position, (raw_token, activation) in enumerate(zip(tokens, scores)):
        safe_token = html.escape(raw_token)
        # Red and blue channels shrink as the score grows; green stays at 255.
        fade = int(255 * (1 - activation))
        color = f"rgb({fade}, {255}, {fade})"
        fragments.append(f"""<span title='Token: "{safe_token}"
Position: #{position}
Activation: {activation:.4f}' style='background-color: {color}; padding: 3px; border-radius: 3px; margin: 1px;'>{safe_token}</span>""")
    fragments.append("</div>")

    fragments.append("""
    <div style='margin-top: 20px;'>
        <h3>Color Scale</h3>
        <div style='display: flex; width: 500px;'>
            <span style='background-color: rgb(255, 255, 255); width: 100px; padding: 10px; text-align: center;'>0.0</span>
            <span style='background-color: rgb(192, 255, 192); width: 100px; padding: 10px; text-align: center;'>0.25</span>
            <span style='background-color: rgb(128, 255, 128); width: 100px; padding: 10px; text-align: center;'>0.5</span>
            <span style='background-color: rgb(64, 255, 64); width: 100px; padding: 10px; text-align: center;'>0.75</span>
            <span style='background-color: rgb(0, 255, 0); width: 100px; padding: 10px; text-align: center;'>1.0</span>
        </div>
    </div>
    """)

    return HTML("".join(fragments))

def list_available_concepts(json_file_path):
    """
    Print the concepts stored in a JSON file and return them as concept keys.

    The file must contain a JSON object with a "concepts" list of strings.
    Each concept is printed with its index, then returned with spaces
    replaced by underscores.

    Args:
        json_file_path: Path to the JSON file.

    Returns:
        List of concept strings with spaces converted to underscores.

    Raises:
        KeyError: If the JSON object has no "concepts" entry.
    """
    # JSON text is UTF-8 by specification; do not rely on the platform's
    # default encoding (matches the other file reads in this cell).
    with open(json_file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    concepts = data['concepts']
    print("Available concepts:")
    for i, concept in enumerate(concepts):
        print(f"{i}: {concept}")
    # Convert spaces to underscores
    return [concept.replace(" ", "_") for concept in concepts]

###############################################################################
# 6. USAGE EXAMPLES
###############################################################################

# (A) Visualize using the original function.
# Renders the first positive/negative pair from examples/<concept_key>_examples.json.
# The function returns None when the probe, config or examples are missing,
# in which case display() receives None.
concept_key = "elevated_LDL_cholesterol"  # Replace with your actual concept key.
visualization = visualize_concept_activations_html(concept_key)
display(visualization)

# (B) Visualize on user input text from file.
# Every concept listed in the JSON file is applied to the same text, one
# visualization per concept.
with open("inputs/example.txt", "r", encoding="utf-8") as f:
    user_text = f.read()
available_concepts = list_available_concepts("inputs/concepts_copy.json")
for concept in available_concepts:
    vis = visualize_concept_on_text(user_text, concept)
    display(vis)

# (C) Validate example.txt vs. unrelated.txt.
# Ensure that "inputs/example.txt" and "inputs/unrelated.txt" exist.
# NOTE: if probes/<concept_key>/config.json has a "concept" entry, that entry
# is used in place of the concept_string passed here.
validate_example_unrelated(
    example_path="inputs/example.txt",
    unrelated_path="inputs/unrelated.txt",
    concept_key="elevated_LDL_cholesterol",
    concept_string="elevated LDL cholesterol",  # Or leave as None to use default.
    layer=22
)




Loading model: gemma-2-2b


Downloading shards:   0%|          | 0/3 [01:28<?, ?it/s]


KeyboardInterrupt: 

In [1]:
import json
import numpy as np
import torch
from IPython.display import HTML
import html
import os
from transformers import AutoTokenizer
from transformer_lens import HookedTransformer

# Load model once globally
def load_model(model_name="gemma-2-2b"):
    """Build the HookedTransformer used throughout this notebook.

    Args:
        model_name: Name passed to ``HookedTransformer.from_pretrained``.

    Returns:
        ``(model, tokenizer, device)``. The device is CUDA when
        ``torch.cuda.is_available()`` is true and the CPU otherwise; the
        tokenizer is taken from the loaded model.
    """
    print(f"Loading model: {model_name}")
    gpu_present = torch.cuda.is_available()
    run_device = torch.device("cuda" if gpu_present else "cpu")
    loaded = HookedTransformer.from_pretrained(model_name, device=run_device)
    attached_tokenizer = loaded.tokenizer
    return loaded, attached_tokenizer, run_device

# Load the model once when this cell runs. `global_model` and
# `global_tokenizer` are the defaults of visualize_concept_on_text, and
# `device` is read as a module-level name by the functions below.
global_model, global_tokenizer, device = load_model()

def process_example(model, tokenizer, hook_name, probe, text, device):
    """Process a single example and return tokens and their activation scores.

    Args:
        model: Object exposing ``run_with_cache(tokens, names_filter=[...])``
            that returns ``(output, cache)``.
        tokenizer: Object exposing ``encode(text, return_tensors="pt")`` and
            ``decode(token_id)``.
        hook_name: Key of the activation to cache and read back.
        probe: Classifier exposing ``predict_proba``; the probability of the
            positive class (column 1) is the score.
        text: Text to tokenize and score.
        device: Device the token ids are moved to before the forward pass.

    Returns:
        ``(token_strs, scores)`` with one float score per token position of
        the first batch element.
    """
    # Tokenize the text
    tokens = tokenizer.encode(text, return_tensors="pt").to(device)

    # Get the token strings ('▁' is shown as a plain space)
    token_strs = [tokenizer.decode(t).replace('▁', ' ') for t in tokens[0]]

    # Run the model with cache to get activations
    with torch.no_grad():
        _, cache = model.run_with_cache(tokens, names_filter=[hook_name])
        activations = cache[hook_name]

    per_position = activations[0].detach().cpu()
    # bfloat16 cannot be converted to NumPy, so upcast half-precision first.
    if per_position.dtype in (torch.bfloat16, torch.float16):
        per_position = per_position.float()
    seq_len = per_position.shape[0]
    if seq_len == 0:
        # No positions: skip the probe, which would be given zero samples.
        return token_strs, []

    # Apply the probe to all positions at once (one row per position)
    feature_matrix = per_position.reshape(seq_len, -1).numpy()
    positive_proba = probe.predict_proba(feature_matrix)[:, 1]
    scores = [float(value) for value in positive_proba]

    return token_strs, scores

def visualize_concept_on_text(text, concept_key, model=global_model, tokenizer=global_tokenizer, layer=22):
    """
    Build an HTML heat map of a concept probe's per-token scores on user text.

    Args:
        text: The text to analyze
        concept_key: The concept key (e.g., 'elevated_LDL_cholesterol'); the
            probe and config are read from ``probes/<concept_key>/``
        model: Pre-loaded model (uses global model by default)
        tokenizer: Pre-loaded tokenizer
        layer: Layer index used in the ``blocks.<layer>.hook_resid_post`` hook

    Returns:
        An ``HTML`` object, or ``None`` (after printing a message) when the
        probe or config file cannot be found. The module-level ``device`` is
        what gets passed to ``process_example``.
    """
    probe_dir = f"probes/{concept_key}"
    joblib_path = os.path.join(probe_dir, "probe.joblib")
    pkl_path = os.path.join(probe_dir, "probe.pkl")
    config_path = os.path.join(probe_dir, "config.json")

    # joblib is preferred; pickle is the fallback; neither means give up.
    joblib_available = os.path.exists(joblib_path)
    if not joblib_available and not os.path.exists(pkl_path):
        print(f"Probe not found at {joblib_path} or {pkl_path}")
        return None
    if joblib_available:
        import joblib
        probe = joblib.load(joblib_path)
        print(f"Loaded probe from {joblib_path}")
    else:
        import pickle
        with open(pkl_path, 'rb') as probe_file:
            probe = pickle.load(probe_file)
        print(f"Loaded probe from {pkl_path}")

    # The config supplies the human-readable concept name.
    if not os.path.exists(config_path):
        print(f"Config not found at {config_path}")
        return None
    with open(config_path, "r") as config_file:
        loaded_config = json.load(config_file)
    concept = loaded_config.get("concept", concept_key.replace("_", " "))

    hook = f"blocks.{layer}.hook_resid_post"
    tokens, scores = process_example(model, tokenizer, hook, probe, text, device)

    def _render_span(index, token, score):
        # Higher scores drain red/blue, leaving a stronger green background.
        escaped = html.escape(token)
        dimmed = int(255 * (1 - score))
        color = f"rgb({dimmed}, {255}, {dimmed})"
        return f"""<span title='Token: "{escaped}"
Position: #{index}
Activation: {score:.4f}' style='background-color: {color}; padding: 3px; border-radius: 3px; margin: 1px;'>{escaped}</span>"""

    heading = f"<h2>Activation visualization for concept: '{concept}'</h2>"
    container_open = "<div style='line-height: 2.5; font-family: monospace; font-size: 14px;'>"
    spans = "".join(
        _render_span(index, token, score)
        for index, (token, score) in enumerate(zip(tokens, scores))
    )
    legend = """
    <div style='margin-top: 20px;'>
        <h3>Color Scale</h3>
        <div style='display: flex; width: 500px;'>
            <span style='background-color: rgb(255, 255, 255); width: 100px; padding: 10px; text-align: center;'>0.0</span>
            <span style='background-color: rgb(192, 255, 192); width: 100px; padding: 10px; text-align: center;'>0.25</span>
            <span style='background-color: rgb(128, 255, 128); width: 100px; padding: 10px; text-align: center;'>0.5</span>
            <span style='background-color: rgb(64, 255, 64); width: 100px; padding: 10px; text-align: center;'>0.75</span>
            <span style='background-color: rgb(0, 255, 0); width: 100px; padding: 10px; text-align: center;'>1.0</span>
        </div>
    </div>
    """

    return HTML(heading + container_open + spans + "</div>" + legend)

def visualize_concept_activations_html(concept_key, model_name="gemma-2-2b", layer=22, use_global_model=True):
    """
    Create an HTML visualization of token-level activations for a concept.

    The probe and config are read from ``probes/<concept_key>/`` and the
    first positive/negative pair from ``examples/<concept_key>_examples.json``.
    Both texts are scored with ``process_example`` and rendered as
    colour-coded token spans, a colour scale, the last-token scores, and a
    position-by-position comparison table.

    Args:
        concept_key: The concept key (e.g., 'elevated_LDL_cholesterol')
        model_name: Name of the model (only loaded when use_global_model is False)
        layer: Layer to extract representations from
        use_global_model: Whether to use the pre-loaded global model

    Returns:
        An ``HTML`` object, or ``None`` (after printing the reason) when the
        probe, config or examples are missing or malformed.
    """
    # Load the probe
    probe_dir = f"probes/{concept_key}"
    joblib_path = os.path.join(probe_dir, "probe.joblib")
    pkl_path = os.path.join(probe_dir, "probe.pkl")
    config_path = os.path.join(probe_dir, "config.json")

    # Check for both joblib and pkl files
    if os.path.exists(joblib_path):
        import joblib
        probe = joblib.load(joblib_path)
        print(f"Loaded probe from {joblib_path}")
    elif os.path.exists(pkl_path):
        # NOTE(security): unpickling can run arbitrary code; only load probe
        # files that come from a trusted source.
        import pickle
        with open(pkl_path, 'rb') as f:
            probe = pickle.load(f)
        print(f"Loaded probe from {pkl_path}")
    else:
        print(f"Probe not found at {joblib_path} or {pkl_path}")
        return None

    # Load the config to get the concept name
    if not os.path.exists(config_path):
        print(f"Config not found at {config_path}")
        return None
    with open(config_path, "r") as f:
        config = json.load(f)
    concept = config.get("concept", concept_key.replace("_", " "))

    # Load examples for this concept
    examples_path = f"examples/{concept_key}_examples.json"
    if not os.path.exists(examples_path):
        print(f"Examples not found at {examples_path}")
        return None
    with open(examples_path, "r") as f:
        examples_data = json.load(f)

    # Validate the "examples" list itself: a wrapper dict without that key, or
    # with an empty list, must be reported instead of raising on lookup.
    example_pairs = examples_data.get("examples") if isinstance(examples_data, dict) else None
    if not example_pairs:
        print(f"No examples found for concept {concept}")
        return None

    first_pair = example_pairs[0]  # Get the first example
    pos_text = first_pair.get("positive", "")
    neg_text = first_pair.get("negative", "")
    if not pos_text or not neg_text:
        print(f"Invalid example format for concept {concept}")
        return None

    # Use global model or load a new one
    if use_global_model:
        # `device` must never be assigned inside this function: assigning it
        # makes the name local, and reading it here then raises
        # UnboundLocalError. Distinct local names keep the module-level
        # `device` readable.
        run_model, run_tokenizer, run_device = global_model, global_tokenizer, device
    else:
        run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        run_model = HookedTransformer.from_pretrained(model_name, device=run_device)
        run_tokenizer = run_model.tokenizer

    # Hook name for the residual stream at the specified layer
    hook_name = f"blocks.{layer}.hook_resid_post"

    pos_tokens, pos_scores = process_example(run_model, run_tokenizer, hook_name, probe, pos_text, run_device)
    neg_tokens, neg_scores = process_example(run_model, run_tokenizer, hook_name, probe, neg_text, run_device)

    def _token_spans(tokens, scores):
        # One <span> per token; the background fades from white (0) to green (1).
        rendered = []
        for i, (token, score) in enumerate(zip(tokens, scores)):
            escaped_token = html.escape(token)
            other_intensity = int(255 * (1 - score))
            color = f"rgb({other_intensity}, 255, {other_intensity})"
            rendered.append(f"""<span title='Token: "{escaped_token}"
Position: #{i}
Activation: {score:.4f}' style='background-color: {color}; padding: 3px; border-radius: 3px; margin: 1px;'>{escaped_token}</span>""")
        return "".join(rendered)

    def _example_section(title, text, tokens, scores):
        # Heading, the raw text (escaped so '<' or '&' cannot break markup), then the spans.
        return (
            f"<h3>{title}</h3>"
            f"<p><i>{html.escape(text)}</i></p>"
            "<div style='line-height: 2.5; font-family: monospace; font-size: 14px;'>"
            + _token_spans(tokens, scores)
            + "</div>"
        )

    sections = [
        f"<h2>Activation visualization for concept: '{html.escape(str(concept))}'</h2>",
        _example_section("Positive Example", pos_text, pos_tokens, pos_scores),
        _example_section("Negative Example", neg_text, neg_tokens, neg_scores),
        # Color scale
        """
    <div style='margin-top: 20px;'>
        <h3>Color Scale</h3>
        <div style='display: flex; width: 500px;'>
            <span style='background-color: rgb(255, 255, 255); width: 100px; padding: 10px; text-align: center;'>0.0</span>
            <span style='background-color: rgb(192, 255, 192); width: 100px; padding: 10px; text-align: center;'>0.25</span>
            <span style='background-color: rgb(128, 255, 128); width: 100px; padding: 10px; text-align: center;'>0.5</span>
            <span style='background-color: rgb(64, 255, 64); width: 100px; padding: 10px; text-align: center;'>0.75</span>
            <span style='background-color: rgb(0, 255, 0); width: 100px; padding: 10px; text-align: center;'>1.0</span>
        </div>
    </div>
    """,
        # Metrics for the last token
        f"""
    <div style='margin-top: 20px;'>
        <h3>Last Token Activation (The token used in training)</h3>
        <p>Positive example last token: <b>{pos_scores[-1]:.6f}</b></p>
        <p>Negative example last token: <b>{neg_scores[-1]:.6f}</b></p>
    </div>
    """,
        # Comparison table header
        """
    <div style='margin-top: 20px;'>
        <h3>Differences Between Examples</h3>
        <table style='border-collapse: collapse; width: 100%;'>
            <tr>
                <th style='border: 1px solid #ddd; padding: 8px; text-align: left;'>Position</th>
                <th style='border: 1px solid #ddd; padding: 8px; text-align: left;'>Positive Token</th>
                <th style='border: 1px solid #ddd; padding: 8px; text-align: left;'>Activation</th>
                <th style='border: 1px solid #ddd; padding: 8px; text-align: left;'>Negative Token</th>
                <th style='border: 1px solid #ddd; padding: 8px; text-align: left;'>Activation</th>
                <th style='border: 1px solid #ddd; padding: 8px; text-align: left;'>Difference</th>
            </tr>
    """,
    ]

    cell = "border: 1px solid #ddd; padding: 8px;"
    # zip stops at the shorter example, so only shared positions are compared.
    paired = zip(pos_tokens, pos_scores, neg_tokens, neg_scores)
    for i, (pos_token, pos_score, neg_token, neg_score) in enumerate(paired):
        diff = pos_score - neg_score

        # Highlight significant differences (light yellow row)
        row_style = "background-color: #ffffcc;" if abs(diff) > 0.3 else ""

        # Color for difference cell
        if diff > 0.3:
            diff_color = "green"
        elif diff < -0.3:
            diff_color = "red"
        else:
            diff_color = "inherit"

        sections.append(f"""
            <tr style='{row_style}'>
                <td style='{cell}'>{i}</td>
                <td style='{cell}'>{html.escape(pos_token)}</td>
                <td style='{cell}'>{pos_score:.4f}</td>
                <td style='{cell}'>{html.escape(neg_token)}</td>
                <td style='{cell}'>{neg_score:.4f}</td>
                <td style='{cell} color: {diff_color};'>{diff:.4f}</td>
            </tr>
        """)

    sections.append("""
        </table>
    </div>
    """)

    # Clean up only if we loaded a new model
    if not use_global_model:
        del run_model
        torch.cuda.empty_cache()

    return HTML("".join(sections))



Loading model: gemma-2-2b


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


In [2]:
def list_available_concepts(json_file_path):
    """
    Print the concepts stored in a JSON file and return them in identifier form.

    The concepts are printed to stdout, one per line, as "<index>: <name>"
    using the names exactly as they appear in the file. The returned list
    holds the same names with every space replaced by an underscore.

    Parameters:
    -----------
    json_file_path : str
        Path to a JSON file whose top-level object has a 'concepts' key
        holding a list of concept name strings.

    Returns:
    --------
    list of str
        Concept names in file order, with spaces replaced by underscores
        (e.g. "heavy alcohol use" -> "heavy_alcohol_use").

    Raises:
    -------
    KeyError
        If the top-level JSON object has no 'concepts' key.
    """
    # JSON text is UTF-8 by specification, so decode it explicitly instead of
    # relying on the platform's locale encoding.
    with open(json_file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    concepts = data['concepts']
    print("Available concepts:")
    for i, concept in enumerate(concepts):
        print(f"{i}: {concept}")

    # Put _ between words so each name is a single space-free token.
    return [concept.replace(" ", "_") for concept in concepts]

In [3]:
# Example usage with user text: run every available concept over one input text.
# Alternative inline inputs, kept for quick experiments (currently disabled):
#user_text = "The patient's LDL cholesterol level is 90 mg/dL, which is not significantly elevated and puts them at high risk for cardiovascular disease."
# user_text = """
# John Doe, a 54-year-old male, presents for a routine check-up. The patient's lab results indicate significantly elevated LDL cholesterol levels.
# """

# Read the whole of inputs/example.txt as the text to analyse.
with open("inputs/example.txt", "r") as f:
    user_text = f.read()

# list_available_concepts prints the concept list and returns the names with
# spaces replaced by underscores.
concepts = list_available_concepts("inputs/concepts.json")
for concept in concepts:
    # visualize_concept_on_text is defined elsewhere in this notebook;
    # presumably it returns a displayable object for this concept — TODO confirm.
    visualization = visualize_concept_on_text(user_text, concept)
    display(visualization)

Available concepts:
0: heavy alcohol use
1: elevated LDL cholesterol
2: low HDL cholesterol
3: high total cholesterol
4: not previously on statin
5: dyslipidemia
6: atorvastatin
7: acute liver disease
8: elevated liver enzymes
9: pregnancy
10: renal impairment
11: hypothyroidism
Loaded probe from probes/heavy_alcohol_use/probe.joblib


Loaded probe from probes/elevated_LDL_cholesterol/probe.joblib


Loaded probe from probes/low_HDL_cholesterol/probe.joblib


Loaded probe from probes/high_total_cholesterol/probe.joblib


Loaded probe from probes/not_previously_on_statin/probe.joblib


Loaded probe from probes/dyslipidemia/probe.joblib


Loaded probe from probes/atorvastatin/probe.joblib


Loaded probe from probes/acute_liver_disease/probe.joblib


Loaded probe from probes/elevated_liver_enzymes/probe.joblib


Loaded probe from probes/pregnancy/probe.joblib


Loaded probe from probes/renal_impairment/probe.joblib


Loaded probe from probes/hypothyroidism/probe.joblib


In [4]:
# Read the whole of inputs/unrelated.txt as the text to analyse.
# NOTE(review): judging by the file name this is an off-topic control text — confirm.
with open("inputs/unrelated.txt", "r") as f:
    user_text = f.read()

# list_available_concepts prints the concept list and returns the names with
# spaces replaced by underscores.
concepts = list_available_concepts("inputs/concepts.json")
for concept in concepts:
    # visualize_concept_on_text is defined elsewhere in this notebook;
    # presumably it returns a displayable object for this concept — TODO confirm.
    visualization = visualize_concept_on_text(user_text, concept)
    display(visualization)

Available concepts:
0: heavy alcohol use
1: elevated LDL cholesterol
2: low HDL cholesterol
3: high total cholesterol
4: not previously on statin
5: dyslipidemia
6: atorvastatin
7: acute liver disease
8: elevated liver enzymes
9: pregnancy
10: renal impairment
11: hypothyroidism
Loaded probe from probes/heavy_alcohol_use/probe.joblib


Loaded probe from probes/elevated_LDL_cholesterol/probe.joblib


Loaded probe from probes/low_HDL_cholesterol/probe.joblib


Loaded probe from probes/high_total_cholesterol/probe.joblib


Loaded probe from probes/not_previously_on_statin/probe.joblib


Loaded probe from probes/dyslipidemia/probe.joblib


Loaded probe from probes/atorvastatin/probe.joblib


Loaded probe from probes/acute_liver_disease/probe.joblib


Loaded probe from probes/elevated_liver_enzymes/probe.joblib


Loaded probe from probes/pregnancy/probe.joblib


Loaded probe from probes/renal_impairment/probe.joblib


Loaded probe from probes/hypothyroidism/probe.joblib
