In [1]:
%load_ext autoreload
%autoreload 2
import torch
from teren import utils as teren_utils
import seaborn as sns
import plotly_express as px

# Resolve the torch device string once; later cells pass this global to the model and tensors.
device = teren_utils.get_device_str()
print(f"device={device!r}")

device='cuda'


In [2]:
from transformer_lens import HookedTransformer

# GPT-2 small with TransformerLens weights, placed on `device`.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    device=device,
)

Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
# Each layer is effectively an independent experiment, but we save a lot of time by
# collecting dataset examples and residual-stream activations for many layers at once.

LAYERS = list(range(model.cfg.n_layers))
# LAYERS = 0, 5, 10

# pre-trained SAE model available in SAELens
SAE_RELEASE = "gpt2-small-res-jb"
LAYER_TO_SAE_ID = lambda layer: f"blocks.{layer}.hook_resid_pre"

# HF text dataset
DATASET_PATH = "NeelNanda/c4-10k"
DATASET_SPLIT = "train[:1%]"

# how long each tokenized prompt should be
CONTEXT_SIZE = 32

# how many tokens can be forwarded through the transformer at once
# tune to your VRAM, 12_800 works well with 16GB
INFERENCE_TOKENS = 320  # reduced from 12_800 (see note above)
INFERENCE_BATCH_SIZE = INFERENCE_TOKENS // CONTEXT_SIZE
print(f"{INFERENCE_BATCH_SIZE=}")

# what is the minimum activation for a feature to be considered active
MIN_FEATURE_ACTIVATION = 0.0
# minimum number of tokenized prompts that should have a feature active
# for the feature to be included in the experiment
MIN_EXAMPLES_PER_FEATURE = 30
# ids of features to consider in the experiment
# we can't consider all the features, because of compute & memory constraints
CONSIDER_FEATURE_IDS = range(30)

# Seed torch's global RNG so Restart & Run All reproduces the same random
# perturbation directions. NOTE(review): assumes NaiveRandomPerturbation samples
# from torch's global RNG — TODO confirm.
SEED = 0
torch.manual_seed(SEED);

INFERENCE_BATCH_SIZE=10


In [4]:
from sae_lens import SAE

# Load one pre-trained SAE per entry of LAYERS. SAE.from_pretrained returns several
# values; only the first one (the SAE itself) is kept.
saes = list(
    SAE.from_pretrained(
        release=SAE_RELEASE,
        sae_id=LAYER_TO_SAE_ID(layer),
        device=device,
    )[0]
    for layer in LAYERS
)

In [5]:
# Tokenize the "text" column of the dataset split with the model's tokenizer,
# using max_length=CONTEXT_SIZE.
all_input_ids = teren_utils.load_and_tokenize_dataset(
    path=DATASET_PATH,
    split=DATASET_SPLIT,
    column_name="text",
    max_length=CONTEXT_SIZE,
    tokenizer=model.tokenizer,  # type: ignore
)
print("all_input_ids (tokenized dataset) shape: " + str(tuple(all_input_ids.shape)))

all_input_ids (tokenized dataset) shape: (1582, 32)


In [6]:
from teren.sae_examples import get_examples_by_feature_by_sae, SAEExamplesByFeature

examples_by_feature_by_sae: list[SAEExamplesByFeature] = get_examples_by_feature_by_sae(
    input_ids=all_input_ids,
    model=model,
    saes=saes,
    feature_ids=CONSIDER_FEATURE_IDS,
    n_examples=MIN_EXAMPLES_PER_FEATURE,
    batch_size=INFERENCE_BATCH_SIZE,
    min_activation=MIN_FEATURE_ACTIVATION,
)

# High-level summary: for each layer, which of CONSIDER_FEATURE_IDS were kept by
# get_examples_by_feature_by_sae (a blank line separates layers).
for layer, examples_by_feature in zip(LAYERS, examples_by_feature_by_sae):
    active_feature_ids = examples_by_feature.active_feature_ids
    print(
        f"Layer: {layer}",
        f"Number of selected features: {len(active_feature_ids)}",
        f"Active feature ids: {active_feature_ids}",
        "",
        sep="\n",
    )

Layer: 0
Number of selected features: 1
Active feature ids: [9]

Layer: 1
Number of selected features: 1
Active feature ids: [11]

Layer: 2
Number of selected features: 1
Active feature ids: [0]

Layer: 3
Number of selected features: 2
Active feature ids: [6, 24]

Layer: 4
Number of selected features: 5
Active feature ids: [7, 8, 12, 15, 19]

Layer: 5
Number of selected features: 14
Active feature ids: [0, 2, 4, 5, 9, 10, 13, 15, 19, 24, 25, 26, 27, 28]

Layer: 6
Number of selected features: 10
Active feature ids: [5, 7, 9, 13, 14, 16, 17, 18, 20, 27]

Layer: 7
Number of selected features: 15
Active feature ids: [0, 3, 4, 5, 6, 10, 14, 17, 18, 20, 21, 22, 23, 26, 29]

Layer: 8
Number of selected features: 19
Active feature ids: [0, 2, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 16, 21, 22, 24, 25, 26, 29]

Layer: 9
Number of selected features: 22
Active feature ids: [0, 2, 4, 5, 8, 9, 10, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 24, 27, 28, 29]

Layer: 10
Number of selected features: 15
Activ

In [7]:
def print_examples_shapes(idx):
    """Print the layer and the stored tensor shapes for entry ``idx`` of the experiment.

    Reads the globals ``LAYERS`` and ``examples_by_feature_by_sae``; each printed shape
    is followed by the axis names it is expected to have.
    """
    print(f"Layer: {LAYERS[idx]}")
    examples = examples_by_feature_by_sae[idx]
    axes_by_attr = (
        ("input_ids", "(n_features, n_examples, context_size)"),
        ("resid_acts", "(n_features, n_examples, context_size, d_model)"),
        ("clean_loss", "(n_features, n_examples, context_size - 1)"),
    )
    for attr, axes in axes_by_attr:
        print(f"{attr} shape: {tuple(getattr(examples, attr).shape)} {axes}")


# Shapes for the first configured layer.
print_examples_shapes(idx=0)

Layer: 0
input_ids shape: (1, 30, 32) (n_features, n_examples, context_size)
resid_acts shape: (1, 30, 32, 768) (n_features, n_examples, context_size, d_model)
clean_loss shape: (1, 30, 31) (n_features, n_examples, context_size - 1)


In [26]:
import torch
from dataclasses import dataclass
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
from teren.perturbations import Perturbation, NaiveRandomPerturbation
import torch.nn.functional as F
from jaxtyping import Float, Int
from contextlib import contextmanager
import umap



In [74]:
@contextmanager
def torch_gc():
    """Run the enclosed block, then call ``torch.cuda.empty_cache()`` on exit.

    The cache is emptied whether the block finishes normally or raises; exceptions
    propagate unchanged. This does not run Python's garbage collector, so tensors
    that are still referenced are not freed by it.
    """
    try:
        yield None
    finally:
        torch.cuda.empty_cache()

def get_random_direction_perturbations(num_directions, num_steps, feature_ids):
    """Build the perturbation objects for a trajectory experiment.

    Returns a list with ``num_directions`` entries. Each entry maps every id in
    ``feature_ids`` to ``{step: NaiveRandomPerturbation()}`` for
    ``step in range(num_steps)``. A fresh perturbation object is created for every
    (direction, feature, step) triple.

    NOTE(review): since every step gets its own object, consecutive steps of one
    "direction" may use different random directions, depending on how
    NaiveRandomPerturbation samples — TODO confirm this is intended.
    """
    directions = []
    for _direction in range(num_directions):
        per_feature = {}
        for fid in feature_ids:
            per_feature[fid] = {step: NaiveRandomPerturbation() for step in range(num_steps)}
        directions.append(per_feature)
    return directions

def compute_kl_div(logits_ref: torch.Tensor, logits_pert: torch.Tensor) -> torch.Tensor:
    """KL(P_ref || P_pert) between the softmax distributions over the last dim.

    Both logit tensors are moved to the CPU first. The result has the input shape
    with the last (vocabulary) dimension summed out.
    """
    ref_logprobs, pert_logprobs = (
        F.log_softmax(t.cpu(), dim=-1) for t in (logits_ref, logits_pert)
    )
    # With log_target=True each element is exp(ref) * (ref - pert) in log space.
    pointwise = F.kl_div(
        input=pert_logprobs,
        target=ref_logprobs,
        reduction="none",
        log_target=True,
    )
    return pointwise.sum(dim=-1)

def compute_pert_vectors(
    pert_by_feature_id: dict[int, dict[int, Perturbation]],
    examples_by_feature: SAEExamplesByFeature,
    current_step: int
):
    """Build the perturbation vectors for one step of a trajectory.

    Args:
        pert_by_feature_id: ``feature_id -> {step -> Perturbation}``.
        examples_by_feature: its ``resid_acts`` rows are indexed in the same order as
            ``active_feature_ids``.
        current_step: key used to pick each feature's perturbation.

    Returns:
        Tensor shaped like ``examples_by_feature.resid_acts`` (feature, batch, n_ctx,
        d_model), allocated on the global ``device``. Row ``i`` is the step's
        perturbation applied to the ``i``-th active feature's activations, or zeros
        when ``pert_by_feature_id`` has no entry for that feature. Rows beyond
        ``len(active_feature_ids)`` (if any) are left uninitialized.
    """
    all_resid_acts = examples_by_feature.resid_acts
    out = torch.empty_like(all_resid_acts, device=device)

    for row, fid in enumerate(examples_by_feature.active_feature_ids):
        feature_resid_acts = all_resid_acts[row]
        if fid not in pert_by_feature_id:
            # No perturbation defined for this feature: leave its activations as-is.
            out[row] = torch.zeros_like(feature_resid_acts)
            continue
        out[row] = pert_by_feature_id[fid][current_step](feature_resid_acts)

    return out


def compute_logits(
    model: HookedTransformer,
    input_ids: Int[torch.Tensor, "*batch seq"],
    resid_acts: Float[torch.Tensor, "*batch seq d_model"],
    start_at_layer: int,
    batch_size: int,
) -> Float[torch.Tensor, "*batch seq vocab"]:
    """Run ``model`` from ``start_at_layer`` on given residual activations, in batches.

    Args:
        model: transformer whose layers from ``start_at_layer`` onward are applied.
        input_ids: token ids; only used to check that its shape equals
            ``resid_acts.shape[:-1]``.
        resid_acts: residual-stream activations fed in at ``start_at_layer``; all
            leading batch dims are flattened before batching.
        start_at_layer: layer index passed to the model's forward.
        batch_size: number of flattened sequences per forward call.

    Returns:
        Logits on the CPU with shape ``(*batch, seq, vocab)``.
    """
    assert input_ids.shape == resid_acts.shape[:-1], f"{input_ids.shape=} {resid_acts.shape=}"

    batch_shape = input_ids.shape[:-1]
    d_seq, d_model = resid_acts.shape[-2:]
    # reshape (unlike view) also accepts non-contiguous inputs, copying only if needed.
    resid_acts_flat = resid_acts.reshape(-1, d_seq, d_model)

    target_device = model.cfg.device
    logits_list = []
    # Pure inference. Without no_grad, if the model's parameters require grad (nothing
    # in this notebook turns that off) every forward records an autograd graph, and the
    # grad_fn carried by logits.cpu() keeps each batch's intermediate activations alive
    # on the device for as long as the returned logits (or tensors derived from them)
    # are referenced.
    with torch.no_grad():
        for start in range(0, resid_acts_flat.shape[0], batch_size):
            batch_resid_acts = resid_acts_flat[start : start + batch_size].to(target_device)
            logits = model(batch_resid_acts, start_at_layer=start_at_layer)
            logits_list.append(logits.cpu())

    logits_flat = torch.cat(logits_list, dim=0)
    return logits_flat.view(*batch_shape, d_seq, -1)  # -1 is the vocab size


def get_pert_kl_div(
    sae: SAE,
    examples_by_feature: SAEExamplesByFeature,
    pert_vec: torch.Tensor,
    base_logits: torch.Tensor
) -> torch.Tensor:
    """Mean KL divergence caused by adding ``pert_vec`` to the stored residual activations.

    The perturbed activations are fed into the global ``model`` from
    ``sae.cfg.hook_layer`` (batches of ``INFERENCE_BATCH_SIZE``). The per-position
    KL(base || perturbed) is then averaged over every dim except the first.

    Args:
        sae: provides ``cfg.hook_layer``, the layer where activations are injected.
        examples_by_feature: supplies ``input_ids`` and ``resid_acts``.
        pert_vec: perturbation added to ``resid_acts``; its device is used for the run.
        base_logits: unperturbed logits to compare against.

    Returns:
        One mean KL value per leading row of ``resid_acts``.
    """
    with torch_gc():
        run_device = pert_vec.device
        perturbed_resid_acts = examples_by_feature.resid_acts.to(run_device) + pert_vec
        perturbed_logits = compute_logits(
            model=model,
            input_ids=examples_by_feature.input_ids.to(run_device),
            resid_acts=perturbed_resid_acts,
            start_at_layer=sae.cfg.hook_layer,
            batch_size=INFERENCE_BATCH_SIZE,
        )

        # The vocab-sized KL computation happens on the CPU.
        per_position_kl = compute_kl_div(base_logits.cpu(), perturbed_logits.cpu())
        non_leading_dims = tuple(range(1, per_position_kl.dim()))
        return per_position_kl.mean(dim=non_leading_dims)

def compute_trajectory_kl_divs(
    sae: SAE, 
    examples_by_feature: SAEExamplesByFeature, 
    pert_by_feature_id_by_direction: list,
    step_size: float
):
    """Measure how the output drifts as activations move along perturbation directions.

    For every direction (a dict ``feature_id -> {step -> Perturbation}``), step ``k``
    (``k = 1..num_steps``) does the following:
    - scales the step-``k-1`` perturbation vector by ``step_size * k``;
    - adds it to the clean residual activations;
    - re-runs the model from ``sae.cfg.hook_layer``;
    - records the mean KL of the perturbed logits from the clean logits (via
      ``get_pert_kl_div``).

    Args:
        sae: SAE whose ``cfg.hook_layer`` is where activations are injected.
        examples_by_feature: examples whose ``input_ids``/``resid_acts`` are used.
        pert_by_feature_id_by_direction: one perturbation dict per direction.
        step_size: scale applied per step.

    Returns:
        ``(kl_divs_by_direction, trajectories_by_direction)``, one entry per non-empty
        direction. Empty directions are reported and skipped, so list positions can
        differ from direction indices. The entries are:
        - KL divergences of shape ``(num_steps + 1, n_rows)``, where ``n_rows`` is
          ``resid_acts.shape[0]``. Row 0 is the clean start point, so it is all
          zeros. When ``n_rows == 1`` this is squeezed to ``(num_steps + 1,)``.
        - The visited activations on the CPU, shape
          ``(num_steps + 1, *resid_acts.shape)``.
    """
    with torch_gc():
        input_ids = examples_by_feature.input_ids.to(device)
        resid_acts = examples_by_feature.resid_acts.to(device)

        base_logits = compute_logits(
            model=model,
            input_ids=input_ids,
            resid_acts=resid_acts,
            start_at_layer=sae.cfg.hook_layer,
            batch_size=INFERENCE_BATCH_SIZE,
        ).cpu()

        # get_pert_kl_div yields one mean KL per leading row of resid_acts.
        n_rows = resid_acts.shape[0]
        start_vector = resid_acts.cpu()

        kl_divs_by_direction = []
        trajectories_by_direction = []

        for dir_idx, pert_by_feature_id in enumerate(pert_by_feature_id_by_direction):
            if not pert_by_feature_id:
                print(f"Direction {dir_idx}: Empty perturbation dict")
                continue

            # The start point is the clean activation, so its KL is zero for every row.
            kl_divs = [torch.zeros(n_rows)]
            trajectory = [start_vector]

            num_steps = len(next(iter(pert_by_feature_id.values())))
            for current_step in range(1, num_steps + 1):
                with torch_gc():  # empties the CUDA cache after each step
                    pert_vec = compute_pert_vectors(
                        pert_by_feature_id, examples_by_feature, current_step - 1
                    ).to(device)
                    pert_vec = pert_vec * step_size * current_step

                    kl_divs.append(get_pert_kl_div(sae, examples_by_feature, pert_vec, base_logits))
                    trajectory.append(start_vector + pert_vec.cpu())
                    del pert_vec

            # Stack the per-row KL vectors. The previous `.item()` conversion only
            # worked with a single row and raised for SAEs with several active features.
            kl_divs_tensor = torch.stack([kl.float() for kl in kl_divs])
            if n_rows == 1:
                # Keep the 1-D (num_steps + 1,) shape single-feature callers rely on.
                kl_divs_tensor = kl_divs_tensor.squeeze(1)
            kl_divs_by_direction.append(kl_divs_tensor)
            trajectories_by_direction.append(torch.stack(trajectory))

        return kl_divs_by_direction, trajectories_by_direction

In [71]:
# First perturbation setup: random directions for every active feature of the first
# SAE. (The single-feature cell further down reassigns these same names.)
num_directions, num_steps, step_size = 5, 10, 0.1
feature_ids = examples_by_feature_by_sae[0].active_feature_ids

pert_by_feature_id_by_direction = get_random_direction_perturbations(
    num_directions=num_directions,
    num_steps=num_steps,
    feature_ids=feature_ids,
)

In [92]:
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import pandas as pd

def plot_kl_divergence_trajectories_plotly(kl_divs_by_direction, sae_idx, feature_id, num_steps):
    """Plot per-direction KL divergence trajectories as a schematic radial scatter.

    Direction ``i`` of ``n`` is drawn as a ray from the origin at angle ``2*pi*i/n``.
    Point ``k`` on the ray is step ``k``, and marker size is the KL divergence at that
    step. Positions are schematic only; they are not projections of the actual
    perturbation directions.

    Args:
        kl_divs_by_direction: one per-step KL tensor per direction; each element must
            support ``.item()`` (i.e. 1-D tensors).
        sae_idx: SAE index, used in the title.
        feature_id: feature id, used in the title.
        num_steps: unused; kept for call-site compatibility (the number of points per
            direction is taken from each tensor's length).

    Raises:
        ValueError: if ``kl_divs_by_direction`` is empty.
    """
    if len(kl_divs_by_direction) == 0:
        raise ValueError("kl_divs_by_direction is empty: no directions to plot")

    # Prepare one record per (direction, step) point.
    n_directions = len(kl_divs_by_direction)
    data = []
    for direction_idx, kl_divs in enumerate(kl_divs_by_direction):
        angle = 2 * np.pi * direction_idx / n_directions
        for step, kl in enumerate(kl_divs):
            data.append({
                'x': step * np.cos(angle),
                'y': step * np.sin(angle),
                'kl_div': round(kl.item(), 4),
                'direction': f'Direction {direction_idx}',
                'step': step,
            })

    df = pd.DataFrame(data)

    # Scatter with marker size proportional to KL divergence.
    fig = px.scatter(df, x='x', y='y', color='direction', size='kl_div',
                     hover_data=['step', 'kl_div'],
                     labels={'x': 'Direction 1', 'y': 'Direction 2'},
                     title=f'KL Divergence Trajectories - SAE {sae_idx}, Feature {feature_id}',
                     size_max=10)  # maximum marker size

    # Connect the points of each direction in step order.
    for direction in df['direction'].unique():
        direction_data = df[df['direction'] == direction].sort_values('step')
        fig.add_trace(go.Scatter(x=direction_data['x'], y=direction_data['y'],
                                 mode='lines', showlegend=False,
                                 line=dict(color='rgba(0,0,0,0.3)', width=1)))

    # Mark the shared starting point (the clean activation).
    fig.add_trace(go.Scatter(x=[0], y=[0], mode='markers',
                             marker=dict(color='black', size=15, symbol='square'),
                             name='Start'))

    # Symmetric bounds with a 10% margin.
    max_abs_coord = df[['x', 'y']].abs().max().max()
    axis_range = [-max_abs_coord * 1.1, max_abs_coord * 1.1]

    fig.update_layout(
        xaxis=dict(showticklabels=False, showgrid=False, zeroline=False, range=axis_range),
        yaxis=dict(showticklabels=False, showgrid=False, zeroline=False, range=axis_range, scaleanchor="x", scaleratio=1),
        showlegend=True,
        legend_title_text='Directions',
        hovermode='closest',
        width=800,
        height=800,  # equal to width for a square plot
    )

    fig.show()

In [94]:
# Choose the first SAE and its corresponding examples
sae_idx = 0
sae = saes[sae_idx]

examples_by_feature = examples_by_feature_by_sae[sae_idx]

# Choose a single feature (e.g., the first active feature)
feature_id = examples_by_feature.active_feature_ids[0]

# Create a new examples_by_feature with only one feature
single_feature_examples = SAEExamplesByFeature(
    active_feature_ids=[feature_id],
    input_ids=examples_by_feature.input_ids[:1],  # Take only the first feature's data
    resid_acts=examples_by_feature.resid_acts[:1],  # Take only the first feature's data
    clean_loss=examples_by_feature.clean_loss[:1] if examples_by_feature.clean_loss is not None else None
)

# Reduce the number of directions and steps for quicker computation
num_directions = 10
num_steps = 20
step_size = 0.01

# Generate perturbations for the single feature
pert_by_feature_id_by_direction = get_random_direction_perturbations(num_directions, num_steps, [feature_id])

# Compute KL divergences for the single SAE and feature
kl_divs, trajectories = compute_trajectory_kl_divs(sae, single_feature_examples, pert_by_feature_id_by_direction, step_size)

print(f"KL divergences shape: {len(kl_divs)}x{kl_divs[0].shape}")
print("KL divergences:")
for direction_idx, direction_kl_divs in enumerate(kl_divs):
    print(f"Direction {direction_idx}:")
    print(direction_kl_divs)

KL divergences shape: 10xtorch.Size([21])
KL divergences:
Direction 0:
tensor([0.0000e+00, 7.9713e-04, 3.3899e-03, 6.3430e-03, 1.2528e-02, 2.1031e-02,
        2.7742e-02, 4.7599e-02, 6.3532e-02, 7.4351e-02, 1.0227e-01, 1.3088e-01,
        1.9712e-01, 2.3081e-01, 3.3413e-01, 4.2425e-01, 6.9776e-01, 9.8177e-01,
        1.3584e+00, 1.7899e+00, 2.3067e+00])
Direction 1:
tensor([0.0000e+00, 7.6170e-04, 2.8191e-03, 7.0368e-03, 1.4715e-02, 1.9455e-02,
        3.4691e-02, 4.3873e-02, 5.3520e-02, 7.3310e-02, 9.4539e-02, 1.3544e-01,
        1.8176e-01, 2.4105e-01, 3.1055e-01, 4.8589e-01, 6.2440e-01, 1.0604e+00,
        1.3442e+00, 1.8934e+00, 2.2537e+00])
Direction 2:
tensor([0.0000e+00, 7.5021e-04, 3.0316e-03, 6.0261e-03, 1.3848e-02, 2.0206e-02,
        3.2759e-02, 4.0307e-02, 5.7861e-02, 8.2572e-02, 1.0558e-01, 1.3552e-01,
        1.7059e-01, 2.3241e-01, 3.1288e-01, 4.3885e-01, 6.1410e-01, 8.8282e-01,
        1.2889e+00, 1.8078e+00, 2.2909e+00])
Direction 3:
tensor([0.0000e+00, 7.2102e-04, 2.9

In [95]:
# Radial Plotly view of the single-feature trajectories computed above.
plot_kl_divergence_trajectories_plotly(
    kl_divs_by_direction=kl_divs,
    sae_idx=sae_idx,
    feature_id=feature_id,
    num_steps=num_steps,
)

In [65]:
# Disabled alternative to the Plotly figure: a matplotlib line plot of KL divergence
# against step index, one line per direction. Uncomment both the definition and the
# call to use it.
# def plot_kl_divergences_line(kl_divs, title="KL Divergences for Different Directions"):
#     plt.figure(figsize=(10, 6))
#     for i, direction_kl_divs in enumerate(kl_divs):
#         plt.plot(direction_kl_divs.squeeze(), label=f'Direction {i}')
    
#     plt.xlabel('Steps')
#     plt.ylabel('KL Divergence')
#     plt.title(title)
#     plt.legend()
#     plt.grid(True)
#     plt.show()

# plot_kl_divergences_line(kl_divs, "KL Divergences for SAE X, Feature Y")

In [12]:
# When running the main computation:
kl_divs_by_direction_by_sae = []

for sae, examples_by_feature in zip(saes, examples_by_feature_by_sae):
    with torch_gc():
        result = compute_trajectory_kl_divs(sae, examples_by_feature, pert_by_feature_id_by_direction, step_size)
        kl_divs_by_direction_by_sae.append(result)
        torch.cuda.empty_cache()  # Clear CUDA cache after each SAE computation

In [12]:
def plot_trajectories(kl_divs_by_direction, sae_idx):
    """Plot per-direction KL trajectories on a schematic radial layout with one colour scale.

    Direction ``i`` of ``n`` is drawn as a ray at angle ``2*pi*i/n``. Point ``k`` on the
    ray is the ``k``-th entry of that direction's tensor, coloured by its KL divergence.
    2-D inputs are averaged over dim 1 first; 1-D inputs are used as-is. All directions
    share the same colour limits, so the colorbar is valid for every point.

    Args:
        kl_divs_by_direction: sequence of torch tensors, one per direction.
        sae_idx: SAE index shown in the title.

    Raises:
        ValueError: if ``kl_divs_by_direction`` is empty.
    """
    if len(kl_divs_by_direction) == 0:
        raise ValueError("kl_divs_by_direction is empty: no trajectories to plot")

    # One KL value per step: average a trailing feature dimension if present.
    per_step = [
        (kl.mean(dim=1) if kl.dim() > 1 else kl).detach().cpu().numpy()
        for kl in kl_divs_by_direction
    ]
    # Shared colour limits; otherwise each scatter call autoscales independently.
    vmin = min(float(values.min()) for values in per_step)
    vmax = max(float(values.max()) for values in per_step)
    n_directions = len(per_step)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.scatter([0], [0], c='red', s=100, label='Start')

    for direction_idx, values in enumerate(per_step):
        angle = 2 * np.pi * direction_idx / n_directions
        # Use each tensor's own length (not the global num_steps) so positions and
        # colours always have matching sizes.
        steps = np.arange(len(values))
        points = ax.scatter(
            steps * np.cos(angle), steps * np.sin(angle),
            c=values, cmap='viridis', vmin=vmin, vmax=vmax, s=20,
        )

    fig.colorbar(points, ax=ax, label='KL Divergence')
    ax.set_xlabel('Direction 1')
    ax.set_ylabel('Direction 2')
    ax.set_title(f'Trajectories from Real Activation - SAE {sae_idx}')
    ax.legend()
    plt.show()
