In [1]:
import torch
import torch.nn as nn
import numpy as np
from typing import List, Dict, Tuple, Optional, Any, Union
from dataclasses import dataclass
from openai import OpenAI
import yaml

### OUR IMPORTS ###
from data import ConceptExampleGenerator

In [2]:
# Load config (this cell reads the `openai_key` entry from it)
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

# The API key comes from config.yaml, not from an environment variable;
# keep that file out of version control.
api_key = config["openai_key"]

# Initialize the generator
generator = ConceptExampleGenerator(api_key)

# Single source of truth for the concept being probed (previously repeated as a literal
# in every call, and described as "irony" in a stale comment).
CONCEPT = "femur fracture"
# Number of generated pairs echoed to the cell output; the full set is saved to disk below.
N_PREVIEW = 5

# Generate a small sample of examples for the concept
examples = generator.generate_examples(
    concept=CONCEPT,
    k=5,
    domain="clinical medicine",
    example_length="medium"
)

# Print the sample
for i, example in enumerate(examples):
    print(f"\nExample {i+1}:")
    print(f"Positive: {example['positive']}")
    print(f"Negative: {example['negative']}")
print("\n\n\n")

# Format for probe training
texts, labels = generator.format_examples_for_probe(examples)
print(f"\nGenerated {len(texts)} examples for probe training")

# Generate a larger dataset in batches
large_examples = generator.generate_examples_batch(
    concept=CONCEPT,
    k=200,
    batch_size=25
)
print(f"Generated {len(large_examples)} total examples in batches")

# Preview only the first few pairs: dumping all of them floods the notebook output
# and adds nothing that the saved JSON file does not already contain.
for i, example in enumerate(large_examples):
    if i >= N_PREVIEW:
        break
    print(f"\nExample {i+1}:")
    print(f"Positive: {example['positive']}")
    print(f"Negative: {example['negative']}")

# Save examples to file
generator.save_examples_to_file(large_examples, "femur_examples.json")

In [3]:
# Initialize the generator (needs `api_key` from the config cell)
generator = ConceptExampleGenerator(api_key)

# Number of generated pairs echoed to the cell output; the full set is saved to disk below.
N_PREVIEW_COMPLETE = 5

# Generate a larger dataset in batches, this time with difference_mode="complete"
large_examples = generator.generate_examples_batch(
    concept="femur fracture",
    k=200,
    batch_size=25,
    difference_mode="complete"
)
print(f"Generated {len(large_examples)} total examples in batches")

# Preview only the first few pairs instead of dumping the whole dataset into the output
for i, example in enumerate(large_examples):
    if i >= N_PREVIEW_COMPLETE:
        break
    print(f"\nExample {i+1}:")
    print(f"Positive: {example['positive']}")
    print(f"Negative: {example['negative']}")

# Save examples
generator.save_examples_to_file(large_examples, "femur_examples_complete.json")

2025-02-28 11:57:37,168 - INFO - Generating batch of 25 examples (0/200 completed)
2025-02-28 11:57:37,168 - INFO - Generating 25 examples for concept: 'femur fracture' with difference mode: complete
2025-02-28 11:57:51,324 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-02-28 11:57:51,331 - INFO - Generated 22 valid examples
2025-02-28 11:57:52,337 - INFO - Generating batch of 25 examples (22/200 completed)
2025-02-28 11:57:52,339 - INFO - Generating 25 examples for concept: 'femur fracture' with difference mode: complete
2025-02-28 11:58:18,363 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-02-28 11:58:18,369 - INFO - Generated 22 valid examples
2025-02-28 11:58:19,372 - INFO - Generating batch of 25 examples (44/200 completed)
2025-02-28 11:58:19,374 - INFO - Generating 25 examples for concept: 'femur fracture' with difference mode: complete
2025-02-28 11:58:30,583 - INFO - HTTP Request: POS

Generated 200 total examples in batches

Example 1:
Positive: After falling off the ladder, the patient was diagnosed with a femur fracture that required immediate surgery.
Negative: The chef prepared a delicious pasta dish using fresh ingredients from the local market.

Example 2:
Positive: The x-ray revealed a clear femur fracture, which explained the patient's severe leg pain.
Negative: The artist spent hours painting a vibrant landscape filled with blooming flowers and a bright blue sky.

Example 3:
Positive: During the football game, he landed awkwardly and suffered a femur fracture that sidelined him for the season.
Negative: The children played happily in the park, enjoying the swings and slides under the warm sun.

Example 4:
Positive: She was in a car accident and suffered a femur fracture, prompting her to undergo physical therapy.
Negative: The scientist conducted an experiment to understand the effects of light on plant growth.

Example 5:
Positive: The doctor explained the

In [4]:
# Write whatever `large_examples` currently holds to the "complete"-mode output file
complete_examples_path = "femur_examples_complete.json"
generator.save_examples_to_file(large_examples, complete_examples_path)

2025-02-28 12:00:30,544 - INFO - Saved 200 examples to femur_examples_complete.json


In [None]:
import transformer_lens as tl
import transformer_lens.utils as utils
import json

# Run on the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load gpt2-small as a HookedTransformer on the chosen device
model = tl.HookedTransformer.from_pretrained("gpt2-small", device=device)

# Read the saved example pairs back from disk; they live under the "examples" key
with open("femur_examples.json", "r") as examples_file:
    examples_payload = json.load(examples_file)
large_examples = examples_payload["examples"]

# Show one record as a quick sanity check of the loaded structure
print(large_examples[0])

In [None]:
# Residual-stream location to probe. `layer` and `hook_name` were never assigned in any
# cell of this notebook, so a fresh kernel raised NameError at the first run_with_cache call.
# NOTE(review): layer 6 / resid_post are placeholders — set them to what the original
# experiment used.
layer = 6
hook_name = utils.get_act_name("resid_post", layer)  # "blocks.{layer}.hook_resid_post"

# Seed and ratio for the train/validation split, so that re-running the cell reproduces it
SPLIT_SEED = 0
TRAIN_FRACTION = 0.8

# Collect the positive and negative texts
pos_examples = [x["positive"] for x in large_examples]
neg_examples = [x["negative"] for x in large_examples]


@torch.no_grad()
def last_token_resid(texts):
    """Return a (len(texts), d_model) tensor of activations at each text's final token.

    Each text is run on its own. Tokenizing the whole list at once pads every prompt to
    the longest one, so `cache[hook_name][:, -1]` would read a padding position for every
    shorter prompt instead of its real last token.

    Runs under `torch.no_grad()` so the returned activations carry no autograd graph;
    the probe's `loss.backward()` must not reach back into the language model.
    """
    rows = []
    for text in texts:
        tokens = model.to_tokens(text)  # (1, seq): a single prompt, hence no padding
        _, cache = model.run_with_cache(
            tokens, stop_at_layer=layer + 1, names_filter=[hook_name]
        )
        rows.append(cache[hook_name][0, -1])  # (seq, d_model) -> (d_model,)
    return torch.stack(rows)


pos_resid = last_token_resid(pos_examples)
neg_resid = last_token_resid(neg_examples)

print(pos_resid.shape, neg_resid.shape)

# stack and create labels (1 = positive, 0 = negative) on the same device as the activations
resid = torch.cat([pos_resid, neg_resid], dim=0)
labels = torch.cat([torch.ones(len(pos_resid)), torch.zeros(len(neg_resid))]).to(resid.device)

# Shuffle with a dedicated seeded generator (leaves the global RNG state untouched)
split_rng = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(len(resid), generator=split_rng).to(resid.device)
resid = resid[indices]
labels = labels[indices]

# Split into train/val
train_size = int(TRAIN_FRACTION * len(resid))
train_resid = resid[:train_size]
train_labels = labels[:train_size]
val_resid = resid[train_size:]
val_labels = labels[train_size:]

In [None]:
# Width of the cached activations; this is the probe's input dimension
d_model = pos_resid.shape[1]

# Seed the global RNG so the random weight initialisation below is reproducible
PROBE_SEED = 0
torch.manual_seed(PROBE_SEED)

# One logit per example. nn.Linear is created on the CPU by default, so move it to
# wherever the activations live; otherwise the forward pass fails on a CUDA run.
linear_probe = nn.Linear(d_model, 1, bias=True).to(pos_resid.device)
nn.init.xavier_normal_(linear_probe.weight)
nn.init.zeros_(linear_probe.bias)

# Binary cross-entropy on raw logits (the sigmoid is applied inside the loss)
loss_fn = nn.BCEWithLogitsLoss()

optimizer = torch.optim.Adam(linear_probe.parameters(), lr=1e-3)

@torch.no_grad()
def accuracy(logits, labels):
    """Return the fraction of binary predictions that match `labels`, as a 0-dim tensor.

    Parameters
    ----------
    logits : torch.Tensor
        Raw probe outputs, shaped (N,) or (N, 1).
    labels : torch.Tensor
        Targets in {0, 1}, shaped (N,) or (N, 1).

    Raises
    ------
    ValueError
        If `logits` and `labels` do not hold the same number of elements.

    Both inputs are flattened before comparison. Comparing (N, 1) predictions with
    (N,) labels broadcasts to an (N, N) matrix, and the mean of that matrix is not
    the accuracy.
    """
    preds = torch.round(torch.sigmoid(logits)).reshape(-1)
    targets = labels.reshape(-1).to(preds.device)
    if preds.numel() != targets.numel():
        raise ValueError(
            f"logits and labels must have the same number of elements, "
            f"got {preds.numel()} and {targets.numel()}"
        )
    return (preds == targets).float().mean()

# Per-epoch metric history; one entry is appended to every list on each epoch
results = {
    "train_loss": [],
    "val_loss": [],
    "train_acc": [],
    "val_acc": []
}

N_EPOCHS = 100

# Detach the inputs so loss.backward() stops at the probe. If the cached activations still
# carry the language model's autograd graph, the second epoch's backward() fails because
# that graph was freed by the first one. Also align device/dtype with the probe so the
# forward pass and the loss never mix devices.
probe_weight = linear_probe.weight
train_x = train_resid.detach().to(device=probe_weight.device, dtype=probe_weight.dtype)
val_x = val_resid.detach().to(device=probe_weight.device, dtype=probe_weight.dtype)
train_y = train_labels.to(device=probe_weight.device, dtype=probe_weight.dtype)
val_y = val_labels.to(device=probe_weight.device, dtype=probe_weight.dtype)

for epoch in range(N_EPOCHS):
    # Full-batch gradient step on the training split
    optimizer.zero_grad()
    # squeeze(-1) drops only the probe's output dimension; a bare squeeze() would also
    # collapse a batch of one and break the shape match with the labels.
    logits = linear_probe(train_x).squeeze(-1)
    loss = loss_fn(logits, train_y)
    loss.backward()
    optimizer.step()
    # Training accuracy is measured on the logits computed before this epoch's update
    train_acc = accuracy(logits, train_y)

    # Validation needs no gradients
    with torch.no_grad():
        val_logits = linear_probe(val_x).squeeze(-1)
        val_loss = loss_fn(val_logits, val_y)
    val_acc = accuracy(val_logits, val_y)

    results["train_loss"].append(loss.item())
    results["val_loss"].append(val_loss.item())
    results["train_acc"].append(train_acc.item())
    results["val_acc"].append(val_acc.item())

print("Done!")

In [None]:
import plotly.express as px
import pandas as pd

def plot_metric_curves(series, value_name, title):
    """Build a plotly line figure with one line per entry of `series`.

    Parameters
    ----------
    series : dict
        Maps a legend label to its per-epoch list of values; all lists must have
        the same length.
    value_name : str
        Name of the y-axis column in the melted frame (e.g. "Loss").
    title : str
        Figure title.

    Returns
    -------
    The figure returned by `px.line`.
    """
    n_epochs = len(next(iter(series.values())))
    long_form = pd.DataFrame({'epoch': range(n_epochs), **series}).melt(
        id_vars=['epoch'], var_name='Metric', value_name=value_name
    )
    return px.line(data_frame=long_form, x='epoch', y=value_name, color='Metric', title=title)


# Per-epoch histories recorded by the training loop
train_loss = results["train_loss"]
val_loss = results["val_loss"]
train_acc = results["train_acc"]
val_acc = results["val_acc"]

# Loss curves
fig = plot_metric_curves(
    {'Train Loss': train_loss, 'Validation Loss': val_loss},
    value_name='Loss',
    title='Linear probe: loss per epoch',
)
fig.show()

# Now plot accuracy
fig = plot_metric_curves(
    {'Train Accuracy': train_acc, 'Validation Accuracy': val_acc},
    value_name='Accuracy',
    title='Linear probe: accuracy per epoch',
)
fig.show()

## Heatmap

In [75]:
import json
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np

# Read the probe analysis produced elsewhere; it provides the keys
# "tokens", "concepts" and "probe_outputs_matrix".
with open("probe_analysis.json", "r") as analysis_file:
    analysis = json.load(analysis_file)

# Unpack into the names used by later cells; the matrix becomes a NumPy array
results_matrix = np.array(analysis["probe_outputs_matrix"])
token_strs, concepts = analysis["tokens"], analysis["concepts"]

In [76]:
# Inspect which top-level fields the loaded analysis JSON provides
analysis.keys()

dict_keys(['tokens', 'concepts', 'probe_outputs_matrix'])

In [77]:
import json
import numpy as np
from IPython.display import HTML
import html

def _select_highlight_indices(activations, top_k, show_all, only_max):
    """Return the set of token positions to highlight for the requested display mode."""
    if only_max:
        # Positions whose activation is 1.0 up to floating-point tolerance
        return {int(idx) for idx in np.where(np.isclose(activations, 1.0, rtol=1e-3))[0]}
    if show_all or top_k <= 0:
        # show_all: the caller highlights every token, so no index set is needed.
        # top_k <= 0: highlight nothing. Slicing with [-0:] would select the whole array.
        return set()
    return {int(idx) for idx in np.argsort(activations)[-top_k:]}


def _activation_rgb(activation):
    """Map an activation to a CSS rgb() string: white at 0.0, pure green at 1.0."""
    # Clamp so values outside [0, 1] cannot produce an out-of-range colour channel
    clamped = min(max(activation, 0.0), 1.0)
    other_intensity = int(255 * (1 - clamped))
    return f"rgb({other_intensity}, 255, {other_intensity})"


def visualize_token_activations(json_file_path, concept_name=None, concept_index=None, top_k=10, show_all=False, only_max=False):
    """
    Visualize token activations for a specific concept with color-coded backgrounds.
    Can show all activations, only the top-k highest, or only those equal to 1.0.

    Parameters:
    -----------
    json_file_path : str
        Path to a JSON file with the keys 'tokens', 'concepts' and
        'probe_outputs_matrix' (indexed by concept, then by token position)
    concept_name : str, optional
        Name of the concept to visualize. Takes precedence over concept_index
    concept_index : int, optional
        Index of the concept to visualize. Only used if concept_name is not provided
    top_k : int, optional
        Number of top activations to highlight (default: 10). Values <= 0 highlight nothing
    show_all : bool, optional
        If True, every token is highlighted (default: False)
    only_max : bool, optional
        If True, only tokens whose activation is (close to) 1.0 are highlighted (default: False)

    Returns:
    --------
    An IPython HTML object showing the tokens with color-coded backgrounds, or None
    (after printing a message) when the concept cannot be resolved
    """

    # Load the JSON data
    with open(json_file_path, 'r') as file:
        data = json.load(file)

    # Extract the tokens, concepts, and activation matrix
    tokens = data['tokens']
    concepts = data['concepts']
    probe_outputs_matrix = data['probe_outputs_matrix']

    # Resolve the concept index; an explicit name wins over an index
    if concept_name is not None:
        if concept_name in concepts:
            concept_index = concepts.index(concept_name)
        else:
            print(f"Concept '{concept_name}' not found. Available concepts: {concepts}")
            return None
    elif concept_index is None:
        print("Either concept_name or concept_index must be provided")
        return None

    if concept_index < 0 or concept_index >= len(concepts):
        print(f"Concept index out of range. Should be between 0 and {len(concepts)-1}")
        return None

    selected_concept = concepts[concept_index]
    # The concept name is interpolated into markup and into a title attribute, so escape it
    escaped_concept = html.escape(str(selected_concept))
    activations = probe_outputs_matrix[concept_index]

    highlight_indices = _select_highlight_indices(activations, top_k, show_all, only_max)

    # Human-readable description of the display mode for the heading
    if only_max:
        display_mode = "activations with value 1.0"
    elif show_all:
        display_mode = "all activations"
    else:
        display_mode = f"top {top_k} activations"

    # Collect fragments and join once at the end instead of repeated string concatenation
    parts = [
        f"<h2>Activation visualization for concept: '{escaped_concept}' ({display_mode})</h2>",
        "<div style='line-height: 2.5; font-family: monospace; font-size: 14px;'>",
    ]

    for i, (token, activation) in enumerate(zip(tokens, activations)):
        # Escape HTML special characters in the token (quotes included, so it is
        # safe inside the single-quoted title attribute)
        escaped_token = html.escape(token)

        # Replace newlines and a lone space with visible characters
        escaped_token = escaped_token.replace('\n', '⏎')
        if escaped_token == ' ':
            escaped_token = '␣'

        # Every token gets a tooltip, highlighted or not
        tooltip = (
            f'Token: "{escaped_token}"\n'
            f'Concept: "{escaped_concept}"\n'
            f"Position: #{i}\n"
            f"Activation: {activation:.4f}"
        )

        if show_all or (i in highlight_indices):
            color = _activation_rgb(activation)
            style = f"background-color: {color}; padding: 3px; border-radius: 3px; margin: 1px;"
        else:
            style = "padding: 3px; margin: 1px;"

        parts.append(f"<span title='{tooltip}' style='{style}'>{escaped_token}</span>")

    parts.append("</div>")

    # Add a color scale reference
    parts.append("""
    <div style='margin-top: 20px;'>
        <h3>Color Scale</h3>
        <div style='display: flex; width: 400px;'>
            <span style='background-color: rgb(255, 255, 255); width: 100px; padding: 10px; text-align: center;'>0.0</span>
            <span style='background-color: rgb(192, 255, 192); width: 100px; padding: 10px; text-align: center;'>0.25</span>
            <span style='background-color: rgb(128, 255, 128); width: 100px; padding: 10px; text-align: center;'>0.5</span>
            <span style='background-color: rgb(64, 255, 64); width: 100px; padding: 10px; text-align: center;'>0.75</span>
            <span style='background-color: rgb(0, 255, 0); width: 100px; padding: 10px; text-align: center;'>1.0</span>
        </div>
    </div>
    """)

    return HTML("".join(parts))

def list_available_concepts(json_file_path):
    """
    Print and return the concepts stored in a probe-analysis JSON file.

    Parameters:
    -----------
    json_file_path : str
        Path to a JSON file that has a 'concepts' entry

    Returns:
    --------
    The value of the file's 'concepts' entry; each item is also printed
    as "<index>: <concept>" under an "Available concepts:" heading
    """
    with open(json_file_path, 'r') as handle:
        concepts = json.load(handle)['concepts']

    # Build the whole listing first, then emit it with a single print call
    listing = ["Available concepts:"]
    listing.extend(f"{position}: {name}" for position, name in enumerate(concepts))
    print("\n".join(listing))

    return concepts

In [78]:
# Location of the probe analysis JSON
json_file_path = 'probe_analysis.json'

# Print the concept listing and keep the returned collection
concepts = list_available_concepts(json_file_path)

# Render one visualization per concept, highlighting only activations close to 1.0.
# For a single concept, call visualize_token_activations with concept_index=<n> or
# concept_name=<name>, and show_all=True or top_k=<k> to change what is highlighted.
for concept_position, _ in enumerate(concepts):
    html_output = visualize_token_activations(
        json_file_path,
        concept_index=concept_position,
        only_max=True,
    )
    display(html_output)

Available concepts:
0: elevated_LDL_cholesterol
1: low_HDL_cholesterol
2: high_total_cholesterol
3: not_previously_on_statin
4: dyslipidemia
5: atorvastatin
6: acute_liver_disease
7: elevated_liver_enzymes
8: pregnancy
9: heavy_alcohol_use
10: renal_impairment
11: hypothyroidism


In [51]:
# Shape of a single row of the probe output matrix.
# NOTE(review): index 9 is a magic number; it lines up with "heavy_alcohol_use" in the
# printed concept listing only if the rows follow the order of `concepts` — TODO confirm.
# NOTE(review): this cell's execution count is lower than the cells above it, so it was
# run out of order; it needs `results_matrix` from the analysis-loading cell.
results_matrix[9].shape

(618,)