In [None]:
pip install transformer_lens sae_lens langdetect unbabel-comet seaborn sentence_transformers scipy--quiet

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import pandas as pd
import sys
from scipy.stats import pearsonr
from sentence_transformers import SentenceTransformer, util
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from transformer_lens.hook_points import HookPoint
from langdetect import detect
from comet import download_model, load_from_checkpoint
import plotly.express as px
import plotly.graph_objects as go
from matplotlib.colors import LinearSegmentedColormap
import networkx as nx
from functools import partial
import pickle
from jaxtyping import Float, Int

from typing import List
from rich import print as rprint
from rich.table import Table
from huggingface_hub import notebook_login
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Pick the compute device explicitly so this cell also runs on a fresh kernel
# (no earlier cell defines `device`).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Gemma-2 9B model wrapped so that SAEs can be attached to it.
gemma_9b: HookedSAETransformer = HookedSAETransformer.from_pretrained("gemma-2-9b", device=device)

In [None]:
# Keyword arguments forwarded to the `model.generate` calls in this cell.
GENERATE_KWARGS = {
    "temperature": 0.5,
    "freq_penalty": 1.0,
    "verbose": False,
}

def steering_hook(
    activations: Float[torch.Tensor, "batch pos d_in"],
    hook: HookPoint,
    sae: SAE,
    latent_idx: int,
    steering_coefficient: float,
) -> torch.Tensor:
    """
    Steers the model by returning a modified activations tensor, with some multiple of the steering vector added to all
    sequence positions.

    Args:
        activations: Activations at the hooked point; the steering vector is broadcast-added to them.
        hook: Hook point this function is attached to. Unused, but required by the forward-hook signature.
        sae: Object whose decoder row `sae.W_dec[latent_idx]` is used as the steering vector.
        latent_idx: Row index into `sae.W_dec`.
        steering_coefficient: Scalar multiplier applied to the steering vector.

    Returns:
        A new tensor with the same dtype and device as `activations`; the input tensor is not modified in place.
    """
    # Match the activations' device and dtype before adding. Without this, a float32 decoder row would
    # promote half-precision activations to float32 (or fail outright on a device mismatch).
    steering_vector = sae.W_dec[latent_idx].to(device=activations.device, dtype=activations.dtype)
    return activations + steering_coefficient * steering_vector

# Uses run_with_cache_with_saes to capture the monitored SAE activations.
def generate_with_multi_steering(
    model: HookedSAETransformer,
    steer_saes: list[SAE],
    latent_idxs: List[int],
    steering_coefficients: List[float],
    monitor_saes: list[SAE],
    prompts: List[str],
    max_new_tokens: int = 50,
):
    """
    Generates text with multiple steering vectors applied from different SAEs.
    Each SAE modifies activations at its own hook point with its own latent and coefficient.

    Args:
        model: Model used both for the optional monitoring forward pass and for generation.
        steer_saes: SAEs providing the steering directions; entry i is hooked at `steer_saes[i].cfg.hook_name`.
        latent_idxs: Latent index to steer with for each entry of `steer_saes`.
        steering_coefficients: Multiplier for each steering vector.
        monitor_saes: SAEs whose `hook_sae_acts_post` activations are recorded on a forward pass over
            `prompts` while the steering hooks are registered. Pass an empty list to skip monitoring.
        prompts: Prompts passed to the model.
        max_new_tokens: Number of tokens to generate per prompt.

    Returns:
        The result of `model.generate` when `monitor_saes` is empty; otherwise the tuple
        `(output, monitor_activations)`, where `monitor_activations[i]` is the cached activation for
        `monitor_saes[i]`.

    Raises:
        AssertionError: If `steer_saes`, `latent_idxs` and `steering_coefficients` differ in length.
    """
    assert len(steer_saes) == len(latent_idxs) == len(steering_coefficients), "Mismatch in lengths of SAE-related lists."

    # One forward hook per (SAE, latent, coefficient) triple, attached at that SAE's hook point.
    hooks = [
        (
            sae.cfg.hook_name,
            partial(steering_hook, sae=sae, latent_idx=latent_idx, steering_coefficient=coeff),
        )
        for sae, latent_idx, coeff in zip(steer_saes, latent_idxs, steering_coefficients)
    ]

    # Cache keys under which each monitored SAE's post-activation latents are stored.
    monitor_keys = [f"{sae.cfg.hook_name}.hook_sae_acts_post" for sae in monitor_saes]

    cache = None
    with model.hooks(fwd_hooks=hooks):
        if monitor_saes:
            # Only the monitored activations are read afterwards, so restrict the cache to those keys and
            # skip autograd bookkeeping instead of retaining every activation of the model.
            with torch.no_grad():
                _, cache = model.run_with_cache_with_saes(
                    prompts,
                    saes=monitor_saes,
                    names_filter=monitor_keys,
                )
        output = model.generate(prompts, max_new_tokens=max_new_tokens, **GENERATE_KWARGS)

    if monitor_saes:
        monitor_activations = [cache[key] for key in monitor_keys]
        return output, monitor_activations
    return output

def generate_with_multi_steering_batch(
    model,
    steer_layers: List[int],
    steer_feature_idx: List[int],
    steer_feature_diff: List[float],
    monitor_layers: List[int] = [],
    prompts: List[str] = ['I was learning'],
    max_new_tokens: int = 50,
):
    """
    Generate an unsteered and a multi-steered completion for `prompts`, loading the needed SAEs by layer.

    For every layer in `steer_layers` and `monitor_layers` the 16k-width canonical residual-stream SAE
    from the `gemma-scope-9b-pt-res-canonical` release is loaded. Each layer is loaded at most once per call.

    Args:
        model: Model to generate with. Assumed to expose `cfg.device` (TransformerLens-style config);
            the SAEs are loaded onto that device.
        steer_layers: Layer of the SAE used for each steering vector.
        steer_feature_idx: Latent index to steer with, one per entry of `steer_layers`.
        steer_feature_diff: Steering coefficient, one per entry of `steer_layers`.
        monitor_layers: Layers whose SAE activations should be returned. Empty means no monitoring.
        prompts: Prompts passed to the model.
        max_new_tokens: Number of tokens to generate per prompt (applies to both generations).

    Returns:
        `generated_sentences` when `monitor_layers` is empty, otherwise
        `(generated_sentences, monitor_activations)`. `generated_sentences[0]` is the unsteered output and
        `generated_sentences[1]` the steered output.

    Raises:
        AssertionError: If the three steering lists differ in length.
    """
    assert len(steer_layers) == len(steer_feature_idx) == len(steer_feature_diff), "All input lists must be the same length."

    # Use the model's own device rather than relying on a notebook-level `device` global.
    sae_device = str(model.cfg.device)
    loaded_saes = {}

    def load_sae(layer):
        # Load the SAE for `layer` once per call; repeated or overlapping layers reuse the same object.
        if layer not in loaded_saes:
            sae, _, _ = SAE.from_pretrained(
                release='gemma-scope-9b-pt-res-canonical',
                sae_id=f"layer_{layer}/width_16k/canonical",
                device=sae_device,
            )
            loaded_saes[layer] = sae
        return loaded_saes[layer]

    steer_saes = [load_sae(layer) for layer in steer_layers]
    monitor_saes = [load_sae(layer) for layer in monitor_layers]

    # Key 0: baseline generation without any steering hooks.
    generated_sentences = {
        0: model.generate(prompts, max_new_tokens=max_new_tokens, **GENERATE_KWARGS),
    }

    # Key 1: generation with all steering vectors applied at once.
    steered_result = generate_with_multi_steering(
        model=model,
        steer_saes=steer_saes,
        latent_idxs=steer_feature_idx,
        steering_coefficients=list(steer_feature_diff),
        monitor_saes=monitor_saes,
        prompts=prompts,
        max_new_tokens=max_new_tokens,
    )

    if monitor_saes:
        # With monitoring enabled the helper returns (output, activations).
        steered_output, monitor_activations = steered_result
        generated_sentences[1] = steered_output
        return generated_sentences, monitor_activations

    generated_sentences[1] = steered_result
    return generated_sentences

def generate_with_steering_batch(
    model,
    steer_layer: int,
    steer_feature_idxs: List[int],
    steer_feature_diffs: List[float],
    prompts: List[str] = ['I was learning'],
    max_new_tokens: int = 50,
):
    """
    Generate one unsteered completion plus one completion per steering feature of a single layer's SAE.

    The 16k-width canonical residual-stream SAE for `steer_layer` is loaded once from the
    `gemma-scope-9b-pt-res-canonical` release, and each feature is then applied on its own.

    Args:
        model: Model to generate with. Assumed to expose `cfg.device` (TransformerLens-style config);
            the SAE is loaded onto that device.
        steer_layer: Layer whose SAE supplies the steering vectors.
        steer_feature_idxs: Latent indices to steer with, one generation per index.
        steer_feature_diffs: Steering coefficient for each entry of `steer_feature_idxs`.
        prompts: Prompts passed to the model.
        max_new_tokens: Number of tokens to generate per prompt.

    Returns:
        Dict where key 0 holds the unsteered output and key `i + 1` holds the output steered with
        `steer_feature_idxs[i]` scaled by `steer_feature_diffs[i]`.

    Raises:
        AssertionError: If `steer_feature_idxs` and `steer_feature_diffs` differ in length.
    """
    assert len(steer_feature_idxs) == len(steer_feature_diffs), "All input lists must be the same length."

    # Use the model's own device rather than relying on a notebook-level `device` global.
    sae, _, _ = SAE.from_pretrained(
        release='gemma-scope-9b-pt-res-canonical',
        sae_id=f"layer_{steer_layer}/width_16k/canonical",
        device=str(model.cfg.device),
    )

    # Key 0: baseline generation without steering.
    generated_sentences = {
        0: model.generate(prompts, max_new_tokens=max_new_tokens, **GENERATE_KWARGS),
    }

    # Keys 1..n: one generation per feature, each steered in isolation.
    for key, (latent_idx, coeff) in enumerate(zip(steer_feature_idxs, steer_feature_diffs), start=1):
        generated_sentences[key] = generate_with_multi_steering(
            model=model,
            steer_saes=[sae],
            latent_idxs=[latent_idx],
            steering_coefficients=[coeff],
            monitor_saes=[],
            prompts=prompts,
            max_new_tokens=max_new_tokens,
        )
    return generated_sentences