In [3]:
# Install required libraries if not already installed
#!pip install sentence_transformers --quiet

# Imports
import json
import numpy as np
import torch
from IPython.display import HTML, display
import html
import os
import nltk
from transformers import AutoTokenizer
from transformer_lens import HookedTransformer
from sentence_transformers import SentenceTransformer, util

# Fetch the NLTK "punkt" tokenizer data.
# NOTE(review): nothing in this cell appears to use NLTK — confirm before removing.
nltk.download('punkt')

# Load model globally
def load_model(model_name="gemma-2-2b"):
    """Load a HookedTransformer checkpoint onto the best available device.

    Args:
        model_name: Checkpoint name forwarded to ``HookedTransformer.from_pretrained``.

    Returns:
        ``(model, tokenizer, device)`` where ``tokenizer`` is ``model.tokenizer``
        and ``device`` is the ``torch.device`` the weights were placed on
        (CUDA when ``torch.cuda.is_available()``, otherwise CPU).
    """
    print(f"Loading model: {model_name}")
    if torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")
    hooked = HookedTransformer.from_pretrained(model_name, device=target_device)
    return hooked, hooked.tokenizer, target_device


# Module-level handles reused by the helpers below.
global_model, global_tokenizer, global_device = load_model()

# Process input text and return token activations
def process_example(model, tokenizer, hook_name, probe, text, device):
    """Score every token of ``text`` with a probe on one cached activation.

    Args:
        model: object exposing ``run_with_cache(tokens, names_filter=[...])``
            (a HookedTransformer here).
        tokenizer: object exposing ``encode(text, return_tensors="pt")`` and
            ``decode(token_id)``.
        hook_name: name of the cached activation to probe,
            e.g. ``"blocks.22.hook_resid_post"``.
        probe: fitted classifier exposing ``predict_proba``; column 1 of its
            output is taken as the per-token score (for sklearn this is
            ``probe.classes_[1]``).
        text: input string.
        device: device the token ids are moved to before the forward pass.

    Returns:
        ``(token_strs, scores)``: the decoded text of each token (SentencePiece
        ``'▁'`` markers turned into spaces) and one float score per token, in
        the same order. Both lists are empty for an empty token sequence.
    """
    tokens = tokenizer.encode(text, return_tensors="pt").to(device)
    token_strs = [tokenizer.decode(t).replace('▁', ' ') for t in tokens[0]]

    with torch.no_grad():
        _, cache = model.run_with_cache(tokens, names_filter=[hook_name])
        # Batch element 0 is the single encoded sequence; dim 0 is now position.
        activations = cache[hook_name][0].detach()

    if activations.shape[0] == 0:
        # Nothing to score; also avoids reshape(0, -1), which cannot infer -1.
        return token_strs, []

    # NumPy has no bfloat16, so .numpy() would raise for a bf16 model.
    if activations.dtype == torch.bfloat16:
        activations = activations.float()
    # One probe row per position; any trailing feature dims are flattened.
    features = activations.reshape(activations.shape[0], -1).cpu().numpy()

    # A single batched predict_proba call instead of one call per position;
    # the probe scores rows independently, so the results are the same.
    positive = probe.predict_proba(features)[:, 1]
    scores = [float(p) for p in positive]
    return token_strs, scores

# Sentence-embedding model used to measure how close a span of text is to the
# concept description.
semantic_model = SentenceTransformer('all-MiniLM-L6-v2')


def compute_semantic_similarity(phrase, concept_text):
    """Return the cosine similarity between the embeddings of two strings.

    Both strings are embedded (as tensors) with the module-level
    ``semantic_model``; the single entry of ``util.cos_sim``'s result is
    returned as a Python float.
    """
    phrase_vec = semantic_model.encode(phrase, convert_to_tensor=True)
    concept_vec = semantic_model.encode(concept_text, convert_to_tensor=True)
    similarity = util.cos_sim(phrase_vec, concept_vec)
    return similarity.item()

# Compute activation validation score
def _split_into_groups(token_strs, scores, delimiters):
    """Yield ``(joined_text, group_scores)`` for each delimiter-terminated run of tokens."""
    delimiter_set = frozenset(delimiters)
    group_tokens, group_scores = [], []
    for token, score in zip(token_strs, scores):
        group_tokens.append(token)
        group_scores.append(score)
        # The delimiter token itself belongs to the group it closes.
        if token.strip() in delimiter_set:
            yield "".join(group_tokens).strip(), group_scores
            group_tokens, group_scores = [], []
    # Tokens after the last delimiter form a final, unterminated group.
    if group_tokens:
        yield "".join(group_tokens).strip(), group_scores


def compute_activation_validation_score(token_strs, scores, concept_text,
                                        similarity_fn=None, delimiters=('.', ',')):
    """Return the similarity-weighted sum of per-token probe scores.

    The tokens are cut into clause-like groups. Each group ends at, and
    includes, a token whose stripped text is in ``delimiters``. The joined,
    stripped text of each group is compared with ``concept_text`` using
    ``similarity_fn``, and every token score in that group is multiplied by
    that similarity. The result sums over all tokens, so it is not normalized
    and grows with text length.

    Args:
        token_strs: decoded token strings, aligned with ``scores``.
        scores: per-token probe scores.
        concept_text: text describing the concept.
        similarity_fn: ``f(group_text, concept_text) -> float``; defaults to
            ``compute_semantic_similarity``.
        delimiters: stripped token texts that close a group.

    Returns:
        The weighted sum (``0.0`` for empty input).
    """
    if similarity_fn is None:
        similarity_fn = compute_semantic_similarity

    total = 0.0
    for group_text, group_scores in _split_into_groups(token_strs, scores, delimiters):
        similarity = similarity_fn(group_text, concept_text)
        total += sum(s * similarity for s in group_scores)
    return total

# Validate activation differences between example and unrelated text
def _validation_score_for_file(path, probe, hook_name, concept_string):
    """Read ``path`` as stripped UTF-8 text and return its activation validation score."""
    with open(path, "r", encoding="utf-8") as f:
        text = f.read().strip()
    tokens, scores = process_example(global_model, global_tokenizer, hook_name, probe, text, global_device)
    return compute_activation_validation_score(tokens, scores, concept_string)


def validate_example_unrelated(example_path, unrelated_path, concept_key, concept_string=None, layer=22):
    """Compare a concept probe's validation score on two text files.

    Loads the probe from ``probes/<concept_key>/``, trying ``probe.joblib``
    first and then ``probe.pkl``. Both files are probed at
    ``blocks.<layer>.hook_resid_post`` of the global model and scored with
    ``compute_activation_validation_score``. A summary is printed and the
    numbers are returned.

    Args:
        example_path: path of a UTF-8 text expected to express the concept.
        unrelated_path: path of a UTF-8 text expected not to.
        concept_key: probe directory name under ``probes/``.
        concept_string: concept description used for semantic similarity.
            Defaults to ``concept_key`` with underscores replaced by spaces.
            If ``config.json`` in the probe directory has a ``"concept"`` entry,
            that entry overrides this value.
        layer: transformer block whose residual-stream output is probed.

    Returns:
        A dict with keys ``concept_key``, ``concept_string``, ``example_score``,
        ``unrelated_score`` and ``difference`` (example minus unrelated), or
        ``None`` if no probe file is found.
    """
    if concept_string is None:
        concept_string = concept_key.replace("_", " ")

    probe_dir = os.path.join("probes", concept_key)
    joblib_path = os.path.join(probe_dir, "probe.joblib")
    pkl_path = os.path.join(probe_dir, "probe.pkl")
    config_path = os.path.join(probe_dir, "config.json")

    # NOTE: joblib/pickle loading can execute arbitrary code; only load probes
    # from trusted directories.
    if os.path.exists(joblib_path):
        import joblib
        probe = joblib.load(joblib_path)
        print(f"Loaded probe from {joblib_path}")
    elif os.path.exists(pkl_path):
        import pickle
        with open(pkl_path, 'rb') as f:
            probe = pickle.load(f)
        print(f"Loaded probe from {pkl_path}")
    else:
        print(f"Probe not found at {joblib_path} or {pkl_path}")
        return None

    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        concept_string = config.get("concept", concept_string)

    hook_name = f"blocks.{layer}.hook_resid_post"
    example_activation_validation = _validation_score_for_file(example_path, probe, hook_name, concept_string)
    unrelated_activation_validation = _validation_score_for_file(unrelated_path, probe, hook_name, concept_string)
    overall_difference = example_activation_validation - unrelated_activation_validation

    print("\n=== Activation Validation Results ===")
    print(f"Concept Key: {concept_key}")
    print(f"Concept String: '{concept_string}'")
    print(f"Example Text Score:   {example_activation_validation:.4f}")
    print(f"Unrelated Text Score: {unrelated_activation_validation:.4f}")
    print(f"Overall Difference:   {overall_difference:.4f}")

    return {
        "concept_key": concept_key,
        "concept_string": concept_string,
        "example_score": example_activation_validation,
        "unrelated_score": unrelated_activation_validation,
        "difference": overall_difference,
    }

# List available concepts from JSON file
def list_available_concepts(json_file_path):
    """Return probe-directory keys for the concepts listed in a JSON file.

    The file is expected to hold a JSON object whose ``"concepts"`` entry is
    a list of strings. Each name becomes a key by replacing spaces with
    underscores.

    Raises:
        KeyError: if the JSON object has no ``"concepts"`` entry.
    """
    # Read as UTF-8 explicitly, like the other file reads in this cell, so
    # non-ASCII concept names do not depend on the platform default encoding.
    with open(json_file_path, "r", encoding="utf-8") as file:
        data = json.load(file)
    return [concept.replace(" ", "_") for concept in data["concepts"]]

# Visualization function
def visualize_concept_on_text(text, concept_key, model=None, tokenizer=None, layer=22):
    """Render per-token probe scores for ``text`` as green-highlighted HTML.

    Loads the probe from ``probes/<concept_key>/``, trying ``probe.joblib``
    first and then ``probe.pkl``; ``config.json`` in the same directory is
    required. Tokens are scored at ``blocks.<layer>.hook_resid_post``. Each
    token's background fades from white (score 0) to pure green (score 1),
    and its tooltip shows the token, its position and its score.

    Args:
        text: input string to visualize.
        concept_key: probe directory name under ``probes/``.
        model: model to run; ``None`` means the current ``global_model``.
        tokenizer: tokenizer to use; ``None`` means the current ``global_tokenizer``.
        layer: transformer block whose residual-stream output is probed.

    Returns:
        An ``IPython.display.HTML`` object, or ``None`` if the probe or the
        config file is missing.
    """
    # Resolve defaults at call time: default arguments would be bound once at
    # def time and go stale if the global model is reloaded later (the device
    # below is already read at call time).
    if model is None:
        model = global_model
    if tokenizer is None:
        tokenizer = global_tokenizer

    probe_dir = os.path.join("probes", concept_key)
    joblib_path = os.path.join(probe_dir, "probe.joblib")
    pkl_path = os.path.join(probe_dir, "probe.pkl")
    config_path = os.path.join(probe_dir, "config.json")

    # NOTE: joblib/pickle loading can execute arbitrary code; only load probes
    # from trusted directories.
    if os.path.exists(joblib_path):
        import joblib
        probe = joblib.load(joblib_path)
        print(f"Loaded probe from {joblib_path}")
    elif os.path.exists(pkl_path):
        import pickle
        with open(pkl_path, 'rb') as f:
            probe = pickle.load(f)
        print(f"Loaded probe from {pkl_path}")
    else:
        print(f"Probe not found at {joblib_path} or {pkl_path}")
        return None

    if not os.path.exists(config_path):
        print(f"Config not found at {config_path}")
        return None

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    concept = config.get("concept", concept_key.replace("_", " "))
    hook_name = f"blocks.{layer}.hook_resid_post"
    tokens, scores = process_example(model, tokenizer, hook_name, probe, text, global_device)

    # Escape the concept: it comes from a config file and is inserted into markup.
    parts = [
        f"<h2>Activation visualization for concept: '{html.escape(str(concept))}'</h2>",
        "<div style='line-height: 2.5; font-family: monospace; font-size: 14px;'>",
    ]
    for i, (token, score) in enumerate(zip(tokens, scores)):
        # quote=True (the default) also escapes ' and ", so tokens cannot break
        # out of the single-quoted title attribute.
        escaped_token = html.escape(token)
        # Clamp so an out-of-range score cannot produce an invalid RGB channel.
        clamped = min(max(score, 0.0), 1.0)
        other_intensity = int(255 * (1 - clamped))
        color = f"rgb({other_intensity}, 255, {other_intensity})"
        parts.append(f"""<span title='Token: "{escaped_token}"
Position: #{i}
Activation: {score:.4f}' style='background-color: {color}; padding: 3px; border-radius: 3px; margin: 1px;'>{escaped_token}</span>""")

    parts.append("</div>")
    # Join once instead of repeated += concatenation.
    return HTML("".join(parts))

# Load example and unrelated text.
# Strip surrounding whitespace so the visualizations cover the same token
# sequence that validate_example_unrelated scores (it strips file contents too).
EXAMPLE_PATH = "inputs/example.txt"
UNRELATED_PATH = "inputs/unrelated.txt"
CONCEPTS_PATH = "inputs/concepts_copy.json"

with open(EXAMPLE_PATH, "r", encoding="utf-8") as f:
    example_text = f.read().strip()

with open(UNRELATED_PATH, "r", encoding="utf-8") as f:
    unrelated_text = f.read().strip()

# Load concepts and process each
concepts = list_available_concepts(CONCEPTS_PATH)
for concept in concepts:
    print("\n============================================")
    print(f"Concept: {concept}")

    for label, text in (("Example", example_text), ("Unrelated", unrelated_text)):
        print(f"\n--- Visualizing {label} Text ---")
        rendered = visualize_concept_on_text(text, concept)
        # visualize_concept_on_text returns None when the probe or config is
        # missing; skip display() so it does not emit a stray "None".
        if rendered is not None:
            display(rendered)

    print("\n--- Validation Results ---")
    validate_example_unrelated(EXAMPLE_PATH, UNRELATED_PATH, concept)




Loading model: gemma-2-2b


Downloading shards:   0%|          | 0/3 [01:28<?, ?it/s]


KeyboardInterrupt: 