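# Decoding the reconstructed CAA vector

Decode the reconstructed CAA (contrastive activation addition) steering vector again to check whether anything in the change is interpretable. **Result:** nothing interpretable was found.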

## Install dependencies

In [1]:
!pip install nnsight matplotlib goodfire huggingface_hub scikit-learn python-dotenv -q

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


## Set HF_HOME to a RunPod-compatible cache directory

In [2]:
import os

# Point the Hugging Face cache at the RunPod-style /workspace volume.
# NOTE(review): presumably this must run before any Hugging Face library is imported.
os.environ.update(HF_HOME='/workspace/hf')

## Enable autoreload so that edited modules are reloaded automatically

In [3]:
# Reload edited modules (e.g. `explorations.*`) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

## Load environment variables
Make sure you have a `.env` file containing `HF_TOKEN` and `GOODFIRE_API_KEY`. Example:

```
HF_TOKEN=hf_foo...
GOODFIRE_API_KEY=sk-goodfire-bar...
```

In [4]:
from dotenv import load_dotenv

# A falsy return from load_dotenv() means nothing was loaded from .env.
if not load_dotenv():
    raise Exception('Error loading .env file. File might be missing or empty.')

# Both credentials must be present and non-empty.
for required_key in ('HF_TOKEN', 'GOODFIRE_API_KEY'):
    assert os.environ.get(required_key), f"Missing {required_key} in .env file"

## Import dependencies

In [5]:
import goodfire
import torch

from explorations.sae import download_and_load_sae, create_sae_steering_vector, create_sae_steering_vector_latents, latents_to_feature_map
from explorations.lm_wrapper import ObservableLanguageModel
from explorations.utils import set_seed, equalize_prompt_lengths, create_mean_caa_steering_vector, compare_steering_vectors
from explorations.chat import test_all_interventions

## Specify the language model, the SAE, and the SAE layer to use

In [6]:
MODEL_NAME = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
SAE_NAME = 'Llama-3.1-8B-Instruct-SAE-l19'
SAE_LAYER = 'model.layers.19'

# SAE width multiplier passed to download_and_load_sae: 16x for the l19 SAE, 8x for any other.
EXPANSION_FACTOR = {'Llama-3.1-8B-Instruct-SAE-l19': 16}.get(SAE_NAME, 8)

## Download and instantiate the Llama model

**Downloading Llama from Hugging Face will take a while.**

In [7]:
# Wrap the Hugging Face model so its activations can be cached and intervened on.
model = ObservableLanguageModel(MODEL_NAME)

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

## Download and instantiate the SAE

In [8]:
# Fetch the SAE sized to the model's hidden width and place it on the model's device.
sae = download_and_load_sae(
    sae_name=SAE_NAME,
    device=model.device,
    d_model=model.d_model,
    expansion_factor=EXPANSION_FACTOR,
)

## Set up the Goodfire client and pirate-feature constants

In [9]:
client = goodfire.Client(api_key=os.environ.get('GOODFIRE_API_KEY'))

# SAE feature 58644 is labelled by Goodfire as "The assistant should roleplay as a pirate";
# pirate_feature maps it to the strength used for feature-space steering.
pirate_feature_index, pirate_feature_strength = 58644, 12.0
pirate_feature = dict([(pirate_feature_index, pirate_feature_strength)])

# Seed RNGs via the project's set_seed helper.
set_seed(42)

In [11]:
from nnsight.intervention import InterventionProxy
from functools import partial
from explorations.chat import feature_space_sae_intervention, activation_space_sae_intervention
import numpy as np

# Steering coefficient shared by the feature-space partial below and the
# default feature strength of activation_space_intervention.
COEFF = 8

# Feature-space variant: `sae` and `coeff` are bound here; callers pass the
# activations and, optionally, a `sae_features` dict.
feature_space_intervention = partial(feature_space_sae_intervention, sae=sae, coeff=COEFF)

# Alternative (unused): activation_space_intervention = partial(activation_space_sae_intervention, sae=sae, sae_features=pirate_feature, coeff=COEFF)

def activation_space_intervention(activations: InterventionProxy, intervention_features={pirate_feature_index: COEFF * 0.65}, d_sae=None):
    """Steer `activations` by adding the SAE decoding of a sparse feature vector.

    Builds a feature tensor that is zero except at the requested feature
    indices, which hold the given value at every batch/sequence position. It
    decodes that tensor with the module-level `sae` and adds the result to
    `activations`.

    Args:
        activations: activations to steer. shape[0] and shape[1] size the
            feature tensor (batch, sequence) and shape[-1] is taken as d_model.
        intervention_features: mapping {feature_index: value}. The default dict
            is built once at definition time (pirate feature at COEFF * 0.65).
            It is only read here, never mutated.
        d_sae: width of the SAE feature space. When None it is derived as
            d_model * EXPANSION_FACTOR, which matches how download_and_load_sae
            is parameterised (TODO confirm). For the l19 SAE this is
            4096 * 16 = 65536, the value previously hard-coded.

    Returns:
        activations + sae.decode(feature_tensor).
    """
    if d_sae is None:
        # Track EXPANSION_FACTOR instead of hard-coding 65536, so that
        # switching SAE_NAME (8x SAEs) does not build a wrong-width tensor.
        d_sae = activations.shape[-1] * EXPANSION_FACTOR

    sae_vector = torch.zeros(
        activations.shape[0],  # batch size
        activations.shape[1],  # sequence length
        d_sae,
        device=activations.device,
        dtype=activations.dtype,
    )

    # The same value is written at every (batch, position) for each requested feature.
    for feature, value in intervention_features.items():
        sae_vector[:, :, feature] = value

    # NOTE(review): if sae.decode adds a decoder bias, that bias is added too,
    # even when intervention_features is empty. Confirm this is intended.
    decoded_sae_vector = sae.decode(sae_vector)

    return activations + decoded_sae_vector


def check_feature_space_intervention(model, tokens, layer=SAE_LAYER, threshold=0.0):
    """Check whether a feature-space intervention can be replayed from its feature deltas.

    Works on `cache[layer][0][-1]`, presumably the last-token activation of the
    first batch element. Steps:
      1. Apply the pirate feature-space intervention and record the SAE feature
         change it causes (`diff`).
      2. Re-apply a feature-space intervention using those deltas as the
         feature dict, and compare its SAE encoding with step 1's (`diff2`).
      3. Print the features whose |diff2| exceeds `threshold`, strongest
         first, with their Goodfire labels.

    Args:
        model: language model wrapper exposing forward(..., cache_activations_at=...).
        tokens: input token ids for the forward pass.
        layer: module name whose activations are cached.
        threshold: |delta| strictly above this counts as significant. The
            default 0.0 reports every non-zero change (previous behaviour).
    """
    _, _, cache = model.forward(tokens, cache_activations_at=[layer])

    # Reshape to (1, 1, d_model); -1 avoids hard-coding the 4096 hidden size.
    acts = cache[layer][0][-1].reshape(1, 1, -1)

    feature_space_result = feature_space_intervention(acts, sae_features=pirate_feature)

    features = sae.encode(acts).detach()
    intervened_features = sae.encode(feature_space_result).detach()

    diff = (intervened_features - features).cpu().float().numpy()[0][0]
    changed_indices = np.where(np.abs(diff) > threshold)[0]
    print('significant indices', changed_indices)

    # Re-express the intervention as raw per-feature deltas and apply it again.
    intervention_features = {idx: diff[idx] for idx in changed_indices}
    feature_space_result_2 = feature_space_intervention(acts, sae_features=intervention_features)
    intervened_features_2 = sae.encode(feature_space_result_2).detach()

    # Ideally ~0: how far the replayed intervention lands from the original one.
    diff2 = (intervened_features_2 - intervened_features).cpu().float().numpy()[0][0]
    print('diff2 pirate feature', diff2[pirate_feature_index])

    residual_indices = np.where(np.abs(diff2) > threshold)[0]
    print('significant indices', residual_indices)

    if len(residual_indices) == 0:
        return

    significant_pairs = sorted(
        ((idx, diff2[idx]) for idx in residual_indices),
        key=lambda pair: abs(pair[1]),
        reverse=True,
    )

    # Get all feature descriptions in one API call
    feature_descriptions = client.features.lookup(residual_indices.tolist(), MODEL_NAME)
    print("\nFeature descriptions:")
    for idx, value in significant_pairs:
        feature = feature_descriptions.get(idx)
        label = "<Redacted due to sensitivity>" if feature is None else feature.label
        print(f"\nFeature {idx} ({value:.3f}): {label}")
        if idx == pirate_feature_index:
            print("Pirate feature!")


def extract_comparative_activations(model, tokens, layer=SAE_LAYER, threshold=1.2, top_k=100):
    """Label the SAE features that differ between the two intervention styles.

    Takes `cache[layer][0][-1]`, presumably the last-token activation of the
    first batch element. It applies both `activation_space_intervention` and
    `feature_space_intervention` to it and encodes their difference with the
    SAE. Features whose |value| exceeds `threshold` are printed, strongest
    first (at most `top_k`), with their Goodfire labels.

    Args:
        model: language model wrapper exposing forward(..., cache_activations_at=...).
        tokens: input token ids for the forward pass.
        layer: module name whose activations are cached.
        threshold: |feature value| strictly above this counts as significant.
        top_k: maximum number of features to look up and print.
    """
    _, _, cache = model.forward(tokens, cache_activations_at=[layer])

    # Reshape to (1, 1, d_model); -1 avoids hard-coding the 4096 hidden size.
    acts = cache[layer][0][-1].reshape(1, 1, -1)

    activation_space_result = activation_space_intervention(acts)
    feature_space_result = feature_space_intervention(acts)

    diff = feature_space_result - activation_space_result
    print(f"Difference norm: {torch.norm(diff):.3f}")

    # NOTE(review): `diff` is a difference of activations, not an activation
    # itself. Encoding it with the SAE is only a rough probe; the reconstruction
    # error below shows how much of it the SAE fails to capture.
    diff_features = sae.encode(diff).detach()
    diff_error = diff - sae.decode(diff_features)
    print(f"Difference error norm: {torch.norm(diff_error):.3f}")

    # [0][0] drops the (1, 1) leading dims, leaving a 1-D array over SAE features.
    features = diff_features.float().cpu().numpy()[0][0]

    significant_indices = np.where(np.abs(features) > threshold)[0]
    # `features` is 1-D, so index it directly (the old [:, :, idx] raised IndexError).
    significant_values = features[significant_indices]
    print(f"Number of significant features: {len(significant_indices)}")

    # Strongest first. Truncate the ordering once so indices and values stay aligned.
    sorted_order = np.argsort(np.abs(significant_values))[::-1][:top_k]
    sorted_indices = significant_indices[sorted_order]
    sorted_values = significant_values[sorted_order]

    if len(sorted_indices) == 0:
        return

    # Get all feature descriptions in one API call
    feature_descriptions = client.features.lookup(sorted_indices.tolist(), MODEL_NAME)
    print("\nFeature descriptions:")
    for idx, value in zip(sorted_indices, sorted_values):
        feature = feature_descriptions.get(idx)
        label = "<Redacted due to sensitivity>" if feature is None else feature.label
        print(f"\nFeature {idx} ({value:.3f}): {label}")
        if idx == pirate_feature_index:
            print("Pirate feature!")


original_input_tokens = model.tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hello, how are you?"},
    ],
    add_generation_prompt=True,
    return_tensors="pt",
)

# Number of tokens to generate greedily under the feature-space intervention.
MAX_NEW_TOKENS = 1
# End-of-turn token id that stops generation. Presumably Llama 3.x <|eot_id|>;
# TODO confirm against model.tokenizer.
EOT_TOKEN_ID = 128009

print('\nFeature space intervention')
input_tokens = original_input_tokens.clone()
for _step in range(MAX_NEW_TOKENS):
    # Inspect how the intervention changes SAE features at the current last token.
    check_feature_space_intervention(model, input_tokens)

    logits, kv_cache, feature_cache = model.forward(
        input_tokens,
        interventions={SAE_LAYER: feature_space_intervention},
    )

    # Greedy decoding: append the highest-scoring next token and keep shape (1, seq).
    new_token = logits[-1].argmax(-1)
    input_tokens = torch.cat([input_tokens[0], new_token.unsqueeze(0).cpu()]).unsqueeze(0)
    if new_token == EOT_TOKEN_ID:
        print("\n<EOT reached>")
        break

    decoded_new_token = model.tokenizer.decode(new_token)
    print(decoded_new_token, end="")


Feature space intervention
significant indices [ 1100  3440  3574  5582  6399  7240 10245 16756 17105 17153 20653 24708
 25445 26760 27500 29446 29960 30200 30616 33154 35091 37142 37219 38555
 38643 41732 43806 43913 49130 49225 49626 52520 52987 54173 55942 57469
 57952 58644 61443 63452]
diff2 pirate feature -2.03125
significant indices [ 1100  3440  3574  5582  6399  7240 10245 16756 17105 17153 20653 24708
 25445 26760 27500 29446 29960 30200 30616 33154 35091 37142 37219 38555
 38643 41732 43806 43913 49130 49225 49626 52520 52987 54173 55942 57469
 57952 58644 61443 63452]

Feature descriptions:

Feature 58644 (-2.031): The assistant should roleplay as a pirate
Pirate feature!

Feature 30200 (-0.254): Assistant explaining its AI nature and limitations

Feature 33154 (-0.234): Formatting newlines that separate different parts of chat conversations

Feature 49626 (0.219): Discussions and explanations of One Piece anime/manga lore

Feature 29446 (-0.121): The assistant should role

AttributeError: 'NoneType' object has no attribute 'values'


Feature space intervention
significant indices [ 1100  3440  3574  5582  6399  7240 10245 16756 17105 17153 20653 24708
 25445 26760 27500 29446 29960 30200 30616 33154 35091 37142 37219 38555
 38643 41732 43806 43913 49130 49225 49626 52520 52987 54173 55942 57469
 57952 58644 61443 63452]
diff2 pirate feature -0.375
significant indices [ 1100  3440  3574  5582  6399  7240 10245 16756 17105 17153 20653 24708
 25445 26760 27500 29446 29960 30200 30616 33154 35091 37142 37219 38555
 38643 41732 43806 43913 44062 49130 49225 49626 52520 52987 54173 55942
 57469 57952 58644 61443 63452]

Feature descriptions:

Feature 49626 (0.992): Discussions and explanations of One Piece anime/manga lore

Feature 30200 (-0.469): Assistant explaining its AI nature and limitations

Feature 33154 (-0.430): Formatting newlines that separate different parts of chat conversations

Feature 29446 (0.379): The assistant should roleplay as a pirate

Feature 58644 (-0.375): The assistant should roleplay as a pir