In [1]:
"""
This file is intended to identify the optimal method for reconstructing a CAA vector using an SAE model.

It is research code; the exploration code has been generated by Claude 3.5 and was not proofread.

1. Should the code not run, try upgrading transformers in your env:
pip install --upgrade transformers

2. The next cell is a system path fix. It may or may not be necessary to run in your environment.
"""



'\nThis file is intended to identify the optimal method for reconstructing a CAA vector using an SAE model.\n\nIt is research code, the exploration code been generated by Claude 3.5 and was not proof-read.\n\n1. Should the code not run, try upgrading transformers in your env:\npip install --upgrade transformers\n\n2. The next cell is a system path fix. It may or may not be necessary to run in your environment.\n'

In [2]:
# TRANSFORMER / ENVIRONMENT FIX
# Makes the packages of the project's virtualenv importable from this kernel.
# NOTE: the path below is an absolute path from the original environment;
# adjust it (or skip this cell) when running elsewhere.
import sys

_VENV_SITE_PACKAGES = '/root/CAA/venv/lib/python3.10/site-packages'
# Only append when missing so that re-running the cell does not keep adding
# duplicate entries to sys.path.
if _VENV_SITE_PACKAGES not in sys.path:
    sys.path.append(_VENV_SITE_PACKAGES)
print(sys.path)

['/opt/conda/lib/python310.zip', '/opt/conda/lib/python3.10', '/opt/conda/lib/python3.10/lib-dynload', '', '/opt/conda/lib/python3.10/site-packages', '/root/CAA/venv/lib/python3.10/site-packages']


In [3]:
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import time
import torch

from dotenv import load_dotenv
from behaviors import ALL_BEHAVIORS, BASE_DIR, COORDINATE, get_vector_path
from gemma_2_wrapper import Gemma2Wrapper
from generate_vectors import generate_save_vectors_for_behavior
import gemma_vector_analysis
from huggingface_hub import hf_hub_download
import numpy as np
import torch.nn as nn

In [4]:
# Pull variables from a local .env file into the environment, then look up the
# Hugging Face token (None when HF_TOKEN is not set).
load_dotenv()
HUGGINGFACE_TOKEN = os.environ.get("HF_TOKEN")

In [None]:
# EXPERIMENT 1 (EACH EXPERIMENT CAN BE RUN INDEPENDENTLY, YOU CAN JUMP TO EXP 9)
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
from huggingface_hub import hf_hub_download

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Layer index that selects the SAE parameter file inside the Hub repository.
layer = 14
sae_model_name = "google/gemma-scope-2b-pt-res"

# Download the parameter archive (force_download=False, so an existing cached
# copy may be reused).
path_to_params = hf_hub_download(
    force_download=False,
    repo_id=sae_model_name,
    filename=f"layer_{layer}/width_16k/average_l0_84/params.npz",
)

# `params` keeps the NumPy arrays; `pt_params` mirrors them as tensors on `device`.
params = np.load(path_to_params)
pt_params = dict(
    (name, torch.from_numpy(array).to(device)) for name, array in params.items()
)

class JumpReLUSAE(nn.Module):
    """Sparse autoencoder whose encoder gates activations with a per-feature threshold.

    All parameters start as zeros and are expected to be filled in through
    ``load_state_dict``.
    """

    def __init__(self, d_model, d_sae):
        super().__init__()
        # Registered in the same order as the state-dict keys:
        # W_enc, W_dec, threshold, b_enc, b_dec.
        parameter_shapes = (
            ("W_enc", (d_model, d_sae)),
            ("W_dec", (d_sae, d_model)),
            ("threshold", (d_sae,)),
            ("b_enc", (d_sae,)),
            ("b_dec", (d_model,)),
        )
        for param_name, shape in parameter_shapes:
            setattr(self, param_name, nn.Parameter(torch.zeros(shape)))

    def encode(self, input_acts):
        """Return ReLU pre-activations, zeroed wherever they do not exceed ``threshold``."""
        hidden = input_acts @ self.W_enc + self.b_enc
        above_threshold = hidden > self.threshold
        return above_threshold * torch.nn.functional.relu(hidden)

    def decode(self, acts):
        """Map feature activations back to model space and add the decoder bias."""
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        """Encode then decode ``acts``."""
        return self.decode(self.encode(acts))

    def process_caa_vector(self, caa_vector, scale_factor=15, use_leaky_relu=True, negative_slope=0.01):
        """Pass a steering vector through the SAE weights with a (leaky) ReLU.

        The vector is scaled, shifted by ``b_dec``, encoded *without* the
        threshold gate, decoded, shifted back by ``b_dec`` and L2-normalised
        along the last dimension.

        Args:
            caa_vector: Tensor in model space; moved to the device of ``W_enc``.
            scale_factor: Multiplier applied before the bias shift.
            use_leaky_relu: Use ``F.leaky_relu`` when true, ``F.relu`` otherwise.
            negative_slope: Slope for the leaky ReLU (ignored for plain ReLU).

        Returns:
            The normalised reconstruction.
        """
        shifted = caa_vector.to(self.W_enc.device) * scale_factor + self.b_dec
        pre_acts = shifted @ self.W_enc + self.b_enc

        encoded = (
            F.leaky_relu(pre_acts, negative_slope=negative_slope)
            if use_leaky_relu
            else F.relu(pre_acts)
        )

        # decode() adds b_dec; it is subtracted again straight away.
        reconstructed = self.decode(encoded) - self.b_dec
        return F.normalize(reconstructed, dim=-1)

def optimize_parameters(sae, caa_vectors, device):
    """Grid-search scale / activation settings that maximise mean cosine similarity.

    Every combination of scale factor, leaky-ReLU flag and negative slope is
    passed to ``sae.process_caa_vector``; the score of a combination is the
    mean cosine similarity between each original vector and its processed
    version. One line is printed per combination.

    Args:
        sae: Object exposing ``process_caa_vector(vector, scale_factor,
            use_leaky_relu, negative_slope)``.
        caa_vectors: Mapping from behavior name to tensor.
        device: Device the vectors are moved to before processing.

    Returns:
        ``(best_params, best_avg_similarity)`` where ``best_params`` has the
        keys ``scale_factor``, ``use_leaky_relu`` and ``negative_slope``.

    Raises:
        ValueError: If ``caa_vectors`` is empty.
    """
    if not caa_vectors:
        raise ValueError("caa_vectors must contain at least one vector")

    # Moving the vectors is loop-invariant, so do it once.
    vectors_on_device = [vector.to(device) for vector in caa_vectors.values()]

    # Start below every attainable cosine similarity: with a start value of 0
    # a search whose averages are all <= 0 would return an empty best_params.
    best_avg_similarity = float("-inf")
    best_params = {}

    scale_factors = np.logspace(0, 2, 20)
    use_leaky_relu_options = [True, False]
    negative_slopes = [0.01, 0.1, 0.2]

    for scale_factor in scale_factors:
        for use_leaky_relu in use_leaky_relu_options:
            for negative_slope in negative_slopes:
                similarities = []

                for caa_vector in vectors_on_device:
                    forwarded_caa_vector = sae.process_caa_vector(caa_vector, scale_factor, use_leaky_relu, negative_slope)

                    cosine_sim = F.cosine_similarity(caa_vector.view(1, -1), forwarded_caa_vector.view(1, -1), dim=1)
                    similarities.append(cosine_sim.item())

                avg_similarity = np.mean(similarities)

                if avg_similarity > best_avg_similarity:
                    best_avg_similarity = avg_similarity
                    best_params = {
                        'scale_factor': scale_factor,
                        'use_leaky_relu': use_leaky_relu,
                        'negative_slope': negative_slope
                    }

                print(f"Scale: {scale_factor:.2f}, Leaky ReLU: {use_leaky_relu}, Slope: {negative_slope:.2f}, Avg Similarity: {avg_similarity:.4f}")

    return best_params, best_avg_similarity

# Build the SAE with the dimensions of the downloaded encoder matrix and load
# the pretrained parameters.
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)
sae = sae.to(device)
print("Success")

layer = 14
model_name_path = 'gemma-2-2b'
# NOTE: absolute path from the original environment; adjust when running elsewhere.
sae_vector_dir = '/root/CAA/sae_vector'
# exist_ok=True replaces the check-then-create pattern and only needs to run once.
os.makedirs(sae_vector_dir, exist_ok=True)

# Pass 1: load every CAA vector and process it with the default parameters.
caa_vectors = {}
for behavior in ALL_BEHAVIORS:
    # Not referenced again in this cell; kept so the name stays defined.
    vector_path = gemma_vector_analysis.fix_vector_path(get_vector_path(behavior, layer, model_name_path))
    normalized_dir = os.path.join(BASE_DIR, 'normalized_vectors', behavior)

    original_path = os.path.join(normalized_dir, f"vec_layer_{layer}_{model_name_path}.pt")
    original_path = gemma_vector_analysis.fix_vector_path(original_path)
    # map_location lets vectors that were saved from a GPU session load on a
    # CPU-only machine; they are moved to `device` explicitly below.
    caa_vector = torch.load(original_path, map_location="cpu")
    caa_vectors[behavior] = caa_vector

    print(f"Analyzing vector for behavior: {behavior}")

    caa_vector = caa_vector.to(device)

    forwarded_caa_vector = sae.process_caa_vector(caa_vector)
    print(f"Norm of the forwarded vector: {forwarded_caa_vector.norm().item():.4f}")

    forwarded_vector_path = os.path.join(sae_vector_dir, f"forwarded_{behavior}.pt")
    torch.save(forwarded_caa_vector.cpu(), forwarded_vector_path)
    print(f"Saved forwarded vector for behavior {behavior} at {forwarded_vector_path}")

    cosine_sim = F.cosine_similarity(caa_vector.view(1, -1), forwarded_caa_vector.view(1, -1), dim=1)

    print(f"Cosine similarity between original and forwarded vector: {cosine_sim.item():.4f}")

best_params, best_avg_similarity = optimize_parameters(sae, caa_vectors, device)

print(f"\nBest parameters found:")
print(f"Scale factor: {best_params['scale_factor']:.2f}")
print(f"Use Leaky ReLU: {best_params['use_leaky_relu']}")
print(f"Negative slope: {best_params['negative_slope']:.2f}")
print(f"Best average cosine similarity: {best_avg_similarity:.4f}")

# Pass 2: re-process with the best parameters; this overwrites the files
# written in pass 1 because the same file names are used.
for behavior, caa_vector in caa_vectors.items():
    caa_vector = caa_vector.to(device)
    forwarded_caa_vector = sae.process_caa_vector(
        caa_vector,
        best_params['scale_factor'],
        best_params['use_leaky_relu'],
        best_params['negative_slope']
    )

    forwarded_vector_path = os.path.join(sae_vector_dir, f"forwarded_{behavior}.pt")
    torch.save(forwarded_caa_vector.cpu(), forwarded_vector_path)
    print(f"Saved optimized forwarded vector for behavior {behavior} at {forwarded_vector_path}")

    cosine_sim = F.cosine_similarity(caa_vector.view(1, -1), forwarded_caa_vector.view(1, -1), dim=1)
    print(f"Optimized cosine similarity for {behavior}: {cosine_sim.item():.4f}")

In [None]:
#EXPERIMENT 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
from huggingface_hub import hf_hub_download

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Layer index that selects the SAE parameter file inside the Hub repository.
layer = 14
sae_model_name = "google/gemma-scope-2b-pt-res"

# Download the parameter archive (force_download=False, so an existing cached
# copy may be reused).
path_to_params = hf_hub_download(
    force_download=False,
    repo_id=sae_model_name,
    filename=f"layer_{layer}/width_16k/average_l0_84/params.npz",
)

# `params` keeps the NumPy arrays; `pt_params` mirrors them as tensors on `device`.
params = np.load(path_to_params)
pt_params = dict(
    (name, torch.from_numpy(array).to(device)) for name, array in params.items()
)

class JumpReLUSAE(nn.Module):
    """Sparse autoencoder whose encoder gates activations with a per-feature threshold.

    All parameters start as zeros and are expected to be filled in through
    ``load_state_dict``.
    """

    def __init__(self, d_model, d_sae):
        super().__init__()
        # Registered in the same order as the state-dict keys:
        # W_enc, W_dec, threshold, b_enc, b_dec.
        parameter_shapes = (
            ("W_enc", (d_model, d_sae)),
            ("W_dec", (d_sae, d_model)),
            ("threshold", (d_sae,)),
            ("b_enc", (d_sae,)),
            ("b_dec", (d_model,)),
        )
        for param_name, shape in parameter_shapes:
            setattr(self, param_name, nn.Parameter(torch.zeros(shape)))

    def encode(self, input_acts):
        """Return ReLU pre-activations, zeroed wherever they do not exceed ``threshold``."""
        hidden = input_acts @ self.W_enc + self.b_enc
        above_threshold = hidden > self.threshold
        return above_threshold * torch.nn.functional.relu(hidden)

    def decode(self, acts):
        """Map feature activations back to model space and add the decoder bias."""
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        """Encode then decode ``acts``."""
        return self.decode(self.encode(acts))

    def process_caa_vector(self, caa_vector, scale_factor=15, activation='leaky_relu', activation_param=0.01):
        """Pass a steering vector through the SAE weights with a selectable activation.

        The vector is scaled, shifted by ``b_dec``, encoded *without* the
        threshold gate, decoded, shifted back by ``b_dec`` and L2-normalised
        along the last dimension.

        Args:
            caa_vector: Tensor in model space; moved to the device of ``W_enc``.
            scale_factor: Multiplier applied before the bias shift.
            activation: ``'leaky_relu'``, ``'elu'`` or ``'selu'``; any other
                value selects a plain ReLU.
            activation_param: Negative slope for leaky ReLU, alpha for ELU;
                unused by the other activations.

        Returns:
            The normalised reconstruction.
        """
        shifted = caa_vector.to(self.W_enc.device) * scale_factor + self.b_dec
        pre_acts = shifted @ self.W_enc + self.b_enc

        # Dispatch table instead of an if/elif chain; unknown names fall back to ReLU.
        activation_fns = {
            'leaky_relu': lambda t: F.leaky_relu(t, negative_slope=activation_param),
            'elu': lambda t: F.elu(t, alpha=activation_param),
            'selu': F.selu,
        }
        encoded = activation_fns.get(activation, F.relu)(pre_acts)

        # decode() adds b_dec; it is subtracted again straight away.
        reconstructed = self.decode(encoded) - self.b_dec
        return F.normalize(reconstructed, dim=-1)

def optimize_parameters(sae, caa_vectors, device):
    """Grid-search scale factor and activation settings for the best mean cosine similarity.

    Each scale factor is combined with every activation; ``leaky_relu`` and
    ``elu`` are additionally swept over ``activation_params`` while ``selu``
    and ``relu`` are evaluated once with a dummy parameter of 0. The score of
    a combination is the mean cosine similarity between each original vector
    and the output of ``sae.process_caa_vector``. One line is printed per
    combination.

    Args:
        sae: Object exposing ``process_caa_vector(vector, scale_factor,
            activation, activation_param)``.
        caa_vectors: Mapping from behavior name to tensor.
        device: Device the vectors are moved to before processing.

    Returns:
        ``(best_params, best_avg_similarity)`` where ``best_params`` has the
        keys ``scale_factor``, ``activation`` and ``activation_param``.

    Raises:
        ValueError: If ``caa_vectors`` is empty.
    """
    if not caa_vectors:
        raise ValueError("caa_vectors must contain at least one vector")

    # Moving the vectors is loop-invariant, so do it once.
    vectors_on_device = [vector.to(device) for vector in caa_vectors.values()]

    # Start below every attainable cosine similarity: with a start value of 0
    # a search whose averages are all <= 0 would return an empty best_params.
    best_avg_similarity = float("-inf")
    best_params = {}

    scale_factors = np.logspace(0, 3, 30)  # Expanded range and more granular
    activations = ['leaky_relu', 'elu', 'selu', 'relu']
    activation_params = np.logspace(-3, 0, 10)  # For leaky_relu and elu

    for scale_factor in scale_factors:
        for activation in activations:
            if activation in ['leaky_relu', 'elu']:
                params_to_test = activation_params
            else:
                params_to_test = [0]  # Dummy value for selu and relu

            for activation_param in params_to_test:
                similarities = []

                for caa_vector in vectors_on_device:
                    forwarded_caa_vector = sae.process_caa_vector(caa_vector, scale_factor, activation, activation_param)

                    cosine_sim = F.cosine_similarity(caa_vector.view(1, -1), forwarded_caa_vector.view(1, -1), dim=1)
                    similarities.append(cosine_sim.item())

                avg_similarity = np.mean(similarities)

                if avg_similarity > best_avg_similarity:
                    best_avg_similarity = avg_similarity
                    best_params = {
                        'scale_factor': scale_factor,
                        'activation': activation,
                        'activation_param': activation_param
                    }

                print(f"Scale: {scale_factor:.2f}, Activation: {activation}, Param: {activation_param:.4f}, Avg Similarity: {avg_similarity:.4f}")

    return best_params, best_avg_similarity

# Build the SAE with the dimensions of the downloaded encoder matrix and load
# the pretrained parameters.
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)
sae = sae.to(device)
print("Success")

layer = 14
model_name_path = 'gemma-2-2b'
# NOTE: absolute path from the original environment; adjust when running elsewhere.
sae_vector_dir = '/root/CAA/sae_vector'
# exist_ok=True replaces the check-then-create pattern and only needs to run once.
os.makedirs(sae_vector_dir, exist_ok=True)

# Pass 1: load every CAA vector and process it with the default parameters.
caa_vectors = {}
for behavior in ALL_BEHAVIORS:
    # Not referenced again in this cell; kept so the name stays defined.
    vector_path = gemma_vector_analysis.fix_vector_path(get_vector_path(behavior, layer, model_name_path))
    normalized_dir = os.path.join(BASE_DIR, 'normalized_vectors', behavior)

    original_path = os.path.join(normalized_dir, f"vec_layer_{layer}_{model_name_path}.pt")
    original_path = gemma_vector_analysis.fix_vector_path(original_path)
    # map_location lets vectors that were saved from a GPU session load on a
    # CPU-only machine; they are moved to `device` explicitly below.
    caa_vector = torch.load(original_path, map_location="cpu")
    caa_vectors[behavior] = caa_vector

    print(f"Analyzing vector for behavior: {behavior}")

    caa_vector = caa_vector.to(device)

    forwarded_caa_vector = sae.process_caa_vector(caa_vector)
    print(f"Norm of the forwarded vector: {forwarded_caa_vector.norm().item():.4f}")

    forwarded_vector_path = os.path.join(sae_vector_dir, f"forwarded_{behavior}.pt")
    torch.save(forwarded_caa_vector.cpu(), forwarded_vector_path)
    print(f"Saved forwarded vector for behavior {behavior} at {forwarded_vector_path}")

    cosine_sim = F.cosine_similarity(caa_vector.view(1, -1), forwarded_caa_vector.view(1, -1), dim=1)

    print(f"Cosine similarity between original and forwarded vector: {cosine_sim.item():.4f}")

best_params, best_avg_similarity = optimize_parameters(sae, caa_vectors, device)

print(f"\nBest parameters found:")
print(f"Scale factor: {best_params['scale_factor']:.2f}")
print(f"Activation: {best_params['activation']}")
print(f"Activation parameter: {best_params['activation_param']:.4f}")
print(f"Best average cosine similarity: {best_avg_similarity:.4f}")

# Pass 2: re-process with the best parameters; this overwrites the files
# written in pass 1 because the same file names are used.
for behavior, caa_vector in caa_vectors.items():
    caa_vector = caa_vector.to(device)
    forwarded_caa_vector = sae.process_caa_vector(
        caa_vector,
        best_params['scale_factor'],
        best_params['activation'],
        best_params['activation_param']
    )

    forwarded_vector_path = os.path.join(sae_vector_dir, f"forwarded_{behavior}.pt")
    torch.save(forwarded_caa_vector.cpu(), forwarded_vector_path)
    print(f"Saved optimized forwarded vector for behavior {behavior} at {forwarded_vector_path}")

    cosine_sim = F.cosine_similarity(caa_vector.view(1, -1), forwarded_caa_vector.view(1, -1), dim=1)
    print(f"Optimized cosine similarity for {behavior}: {cosine_sim.item():.4f}")

In [None]:
#EXPERIMENT 3
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
from huggingface_hub import hf_hub_download

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Layer index that selects the SAE parameter file inside the Hub repository.
layer = 14
sae_model_name = "google/gemma-scope-2b-pt-res"

# Download the parameter archive (force_download=False, so an existing cached
# copy may be reused).
path_to_params = hf_hub_download(
    force_download=False,
    repo_id=sae_model_name,
    filename=f"layer_{layer}/width_16k/average_l0_84/params.npz",
)

# `params` keeps the NumPy arrays; `pt_params` mirrors them as tensors on `device`.
params = np.load(path_to_params)
pt_params = dict(
    (name, torch.from_numpy(array).to(device)) for name, array in params.items()
)

class JumpReLUSAE(nn.Module):
    """Sparse autoencoder whose encoder gates activations with a per-feature threshold.

    All parameters start as zeros and are expected to be filled in through
    ``load_state_dict``.
    """

    def __init__(self, d_model, d_sae):
        super().__init__()
        # Registered in the same order as the state-dict keys:
        # W_enc, W_dec, threshold, b_enc, b_dec.
        parameter_shapes = (
            ("W_enc", (d_model, d_sae)),
            ("W_dec", (d_sae, d_model)),
            ("threshold", (d_sae,)),
            ("b_enc", (d_sae,)),
            ("b_dec", (d_model,)),
        )
        for param_name, shape in parameter_shapes:
            setattr(self, param_name, nn.Parameter(torch.zeros(shape)))

    def encode(self, input_acts):
        """Return ReLU pre-activations, zeroed wherever they do not exceed ``threshold``."""
        hidden = input_acts @ self.W_enc + self.b_enc
        above_threshold = hidden > self.threshold
        return above_threshold * torch.nn.functional.relu(hidden)

    def decode(self, acts):
        """Map feature activations back to model space and add the decoder bias."""
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        """Encode then decode ``acts``."""
        return self.decode(self.encode(acts))

    def process_caa_vector(self, caa_vector, scale_factor=1.0):
        """Run a scaled, bias-shifted steering vector through the thresholded SAE.

        Args:
            caa_vector: Tensor in model space; moved to the device of ``W_enc``.
            scale_factor: Multiplier applied before ``b_dec`` is added.

        Returns:
            The SAE reconstruction with ``b_dec`` subtracted again (not normalised).
        """
        shifted = caa_vector.to(self.W_enc.device) * scale_factor + self.b_dec
        reconstruction = self.decode(self.encode(shifted))
        return reconstruction - self.b_dec

def optimize_parameters(sae, caa_vectors, device):
    """Sweep the scale factor and keep the one with the best mean cosine similarity.

    For each of 100 evenly spaced scale factors in [0.1, 10] every vector is
    passed to ``sae.process_caa_vector`` and compared with the original via
    cosine similarity. One line is printed per scale factor.

    Args:
        sae: Object exposing ``process_caa_vector(vector, scale_factor)``.
        caa_vectors: Mapping from behavior name to tensor.
        device: Device the vectors are moved to before processing.

    Returns:
        ``(best_params, best_avg_similarity)`` where ``best_params`` has the
        single key ``scale_factor``.

    Raises:
        ValueError: If ``caa_vectors`` is empty.
    """
    if not caa_vectors:
        raise ValueError("caa_vectors must contain at least one vector")

    # Moving the vectors is loop-invariant, so do it once.
    vectors_on_device = [vector.to(device) for vector in caa_vectors.values()]

    # Start below every attainable cosine similarity: with a start value of 0
    # a search whose averages are all <= 0 (e.g. all-zero reconstructions)
    # would return an empty best_params.
    best_avg_similarity = float("-inf")
    best_params = {}

    scale_factors = np.linspace(0.1, 10, 100)  # More reasonable range

    for scale_factor in scale_factors:
        similarities = []

        for caa_vector in vectors_on_device:
            forwarded_caa_vector = sae.process_caa_vector(caa_vector, scale_factor)

            cosine_sim = F.cosine_similarity(caa_vector.view(1, -1), forwarded_caa_vector.view(1, -1), dim=1)
            similarities.append(cosine_sim.item())

        avg_similarity = np.mean(similarities)

        if avg_similarity > best_avg_similarity:
            best_avg_similarity = avg_similarity
            best_params = {'scale_factor': scale_factor}

        print(f"Scale: {scale_factor:.2f}, Avg Similarity: {avg_similarity:.4f}")

    return best_params, best_avg_similarity

# Build the SAE with the dimensions of the downloaded encoder matrix and load
# the pretrained parameters.
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)
sae = sae.to(device)
print("SAE loaded successfully")

layer = 14
model_name_path = 'gemma-2-2b'

caa_vectors = {}
for behavior in ALL_BEHAVIORS:
    # Not referenced again in this cell; kept so the name stays defined.
    vector_path = gemma_vector_analysis.fix_vector_path(get_vector_path(behavior, layer, model_name_path))
    normalized_dir = os.path.join(BASE_DIR, 'normalized_vectors', behavior)

    original_path = os.path.join(normalized_dir, f"vec_layer_{layer}_{model_name_path}.pt")
    original_path = gemma_vector_analysis.fix_vector_path(original_path)
    # map_location lets vectors that were saved from a GPU session load on a
    # CPU-only machine; they are moved to `device` before use.
    caa_vector = torch.load(original_path, map_location="cpu")
    caa_vectors[behavior] = caa_vector

    print(f"Loaded CAA vector for behavior: {behavior}")

print("Optimizing parameters...")
best_params, best_avg_similarity = optimize_parameters(sae, caa_vectors, device)

print(f"\nBest parameters found:")
print(f"Scale factor: {best_params['scale_factor']:.4f}")
print(f"Best average cosine similarity: {best_avg_similarity:.4f}")

# NOTE: absolute path from the original environment; adjust when running elsewhere.
sae_vector_dir = '/root/CAA/sae_vector'
# exist_ok=True replaces the check-then-create pattern and only needs to run once.
os.makedirs(sae_vector_dir, exist_ok=True)

print("\nProcessing CAA vectors with optimized parameters:")
for behavior, caa_vector in caa_vectors.items():
    caa_vector = caa_vector.to(device)
    forwarded_caa_vector = sae.process_caa_vector(caa_vector, best_params['scale_factor'])

    forwarded_vector_path = os.path.join(sae_vector_dir, f"forwarded_{behavior}.pt")
    torch.save(forwarded_caa_vector.cpu(), forwarded_vector_path)
    print(f"Saved processed vector for behavior {behavior} at {forwarded_vector_path}")

    cosine_sim = F.cosine_similarity(caa_vector.view(1, -1), forwarded_caa_vector.view(1, -1), dim=1)
    print(f"Cosine similarity for {behavior}: {cosine_sim.item():.4f}")

    # Calculate and print L2 norm of the original and processed vectors
    original_norm = torch.norm(caa_vector).item()
    processed_norm = torch.norm(forwarded_caa_vector).item()
    print(f"L2 norm - Original: {original_norm:.4f}, Processed: {processed_norm:.4f}")

    # Calculate and print mean and standard deviation of the original and processed vectors
    original_mean = torch.mean(caa_vector).item()
    original_std = torch.std(caa_vector).item()
    processed_mean = torch.mean(forwarded_caa_vector).item()
    processed_std = torch.std(forwarded_caa_vector).item()
    print(f"Mean - Original: {original_mean:.4f}, Processed: {processed_mean:.4f}")
    print(f"Std Dev - Original: {original_std:.4f}, Processed: {processed_std:.4f}")
    print()

In [None]:
#EXPERIMENT 4
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
from huggingface_hub import hf_hub_download

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Layer index that selects the SAE parameter file inside the Hub repository.
layer = 14
sae_model_name = "google/gemma-scope-2b-pt-res"

# Download the parameter archive (force_download=False, so an existing cached
# copy may be reused).
path_to_params = hf_hub_download(
    force_download=False,
    repo_id=sae_model_name,
    filename=f"layer_{layer}/width_16k/average_l0_84/params.npz",
)

# `params` keeps the NumPy arrays; `pt_params` mirrors them as tensors on `device`.
params = np.load(path_to_params)
pt_params = dict(
    (name, torch.from_numpy(array).to(device)) for name, array in params.items()
)

class JumpReLUSAE(nn.Module):
    """Sparse autoencoder whose encoder gates activations with a per-feature threshold.

    All parameters start as zeros and are expected to be filled in through
    ``load_state_dict``.
    """

    def __init__(self, d_model, d_sae):
        super().__init__()
        # Registered in the same order as the state-dict keys:
        # W_enc, W_dec, threshold, b_enc, b_dec.
        parameter_shapes = (
            ("W_enc", (d_model, d_sae)),
            ("W_dec", (d_sae, d_model)),
            ("threshold", (d_sae,)),
            ("b_enc", (d_sae,)),
            ("b_dec", (d_model,)),
        )
        for param_name, shape in parameter_shapes:
            setattr(self, param_name, nn.Parameter(torch.zeros(shape)))

    def encode(self, input_acts):
        """Return ReLU pre-activations, zeroed wherever they do not exceed ``threshold``."""
        hidden = input_acts @ self.W_enc + self.b_enc
        above_threshold = hidden > self.threshold
        return above_threshold * torch.nn.functional.relu(hidden)

    def decode(self, acts):
        """Map feature activations back to model space and add the decoder bias."""
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        """Encode then decode ``acts``."""
        return self.decode(self.encode(acts))

    def process_caa_vector(self, caa_vector, scale_factor=1.0, target_norm=None):
        """Shift a steering vector by ``b_dec``, rescale it and reconstruct it with the SAE.

        Args:
            caa_vector: Tensor in model space; moved to the device of ``W_enc``.
            scale_factor: Multiplier for the shifted vector. Only used when
                ``target_norm`` is None.
            target_norm: When given, the shifted vector is rescaled to this L2
                norm and ``scale_factor`` is ignored.

        Returns:
            The SAE reconstruction (decoder bias included), L2-normalised
            along the last dimension.
        """
        shifted = caa_vector.to(self.W_enc.device) + self.b_dec

        # Either rescale to a fixed norm or apply the plain multiplier.
        if target_norm is None:
            factor = scale_factor
        else:
            factor = target_norm / torch.norm(shifted)
        shifted = shifted * factor

        reconstruction = self.decode(self.encode(shifted))
        return F.normalize(reconstruction, dim=-1)

def optimize_parameters(sae, caa_vectors, device):
    """Search scale-only and norm-only settings for the best mean cosine similarity.

    The ``process_caa_vector`` defined in this cell ignores ``scale_factor``
    whenever ``target_norm`` is given, so a full scale x norm grid would
    evaluate every target norm once per scale factor and never try the
    scale-only mode. The two modes are therefore searched separately:

    * 100 scale factors in [0.1, 10] with ``target_norm=None``;
    * 100 target norms in [10, 1000] with the default scale factor of 1.0.

    One line is printed per candidate.

    Args:
        sae: Object exposing ``process_caa_vector(vector, scale_factor,
            target_norm)``.
        caa_vectors: Mapping from behavior name to tensor.
        device: Device the vectors are moved to before processing.

    Returns:
        ``(best_params, best_avg_similarity)`` where ``best_params`` has the
        keys ``scale_factor`` and ``target_norm``; ``target_norm`` is None
        when a scale-only candidate wins.

    Raises:
        ValueError: If ``caa_vectors`` is empty.
    """
    if not caa_vectors:
        raise ValueError("caa_vectors must contain at least one vector")

    # Moving the vectors is loop-invariant, so do it once.
    vectors_on_device = [vector.to(device) for vector in caa_vectors.values()]

    # Start below every attainable cosine similarity: with a start value of 0
    # a search whose averages are all <= 0 would return an empty best_params.
    best_avg_similarity = float("-inf")
    best_params = {}

    scale_factors = np.linspace(0.1, 10, 100)
    target_norms = np.linspace(10, 1000, 100)

    candidates = [(scale_factor, None) for scale_factor in scale_factors]
    candidates += [(1.0, target_norm) for target_norm in target_norms]

    for scale_factor, target_norm in candidates:
        similarities = []

        for caa_vector in vectors_on_device:
            forwarded_caa_vector = sae.process_caa_vector(caa_vector, scale_factor, target_norm)

            cosine_sim = F.cosine_similarity(caa_vector.view(1, -1), forwarded_caa_vector.view(1, -1), dim=1)
            similarities.append(cosine_sim.item())

        avg_similarity = np.mean(similarities)

        if avg_similarity > best_avg_similarity:
            best_avg_similarity = avg_similarity
            best_params = {'scale_factor': scale_factor, 'target_norm': target_norm}

        # A None target norm cannot be formatted with ":.2f".
        norm_label = "None" if target_norm is None else f"{target_norm:.2f}"
        print(f"Scale: {scale_factor:.2f}, Target Norm: {norm_label}, Avg Similarity: {avg_similarity:.4f}")

    return best_params, best_avg_similarity

# Build the SAE with the dimensions of the downloaded encoder matrix and load
# the pretrained parameters.
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)
sae = sae.to(device)
print("SAE loaded successfully")

layer = 14
model_name_path = 'gemma-2-2b'

caa_vectors = {}
for behavior in ALL_BEHAVIORS:
    # Not referenced again in this cell; kept so the name stays defined.
    vector_path = gemma_vector_analysis.fix_vector_path(get_vector_path(behavior, layer, model_name_path))
    normalized_dir = os.path.join(BASE_DIR, 'normalized_vectors', behavior)

    original_path = os.path.join(normalized_dir, f"vec_layer_{layer}_{model_name_path}.pt")
    original_path = gemma_vector_analysis.fix_vector_path(original_path)
    # map_location lets vectors that were saved from a GPU session load on a
    # CPU-only machine; they are moved to `device` before use.
    caa_vector = torch.load(original_path, map_location="cpu")
    caa_vectors[behavior] = caa_vector

    print(f"Loaded CAA vector for behavior: {behavior}")

print("Optimizing parameters...")
best_params, best_avg_similarity = optimize_parameters(sae, caa_vectors, device)

#print(f"\nBest parameters found:")
#print(f"Scale factor: {best_params['scale_factor']:.4f}")
#print(f"Target norm: {best_params['target_norm']:.4f}")
#print(f"Best average cosine similarity: {best_avg_similarity:.4f}")

# NOTE: absolute path from the original environment; adjust when running elsewhere.
sae_vector_dir = '/root/CAA/sae_vector'
# exist_ok=True replaces the check-then-create pattern and only needs to run once.
os.makedirs(sae_vector_dir, exist_ok=True)

print("\nProcessing CAA vectors with optimized parameters:")
for behavior, caa_vector in caa_vectors.items():
    caa_vector = caa_vector.to(device)
    forwarded_caa_vector = sae.process_caa_vector(caa_vector, best_params['scale_factor'], best_params['target_norm'])

    forwarded_vector_path = os.path.join(sae_vector_dir, f"forwarded_{behavior}.pt")
    torch.save(forwarded_caa_vector.cpu(), forwarded_vector_path)
    print(f"Saved processed vector for behavior {behavior} at {forwarded_vector_path}")

    cosine_sim = F.cosine_similarity(caa_vector.view(1, -1), forwarded_caa_vector.view(1, -1), dim=1)
    print(f"Cosine similarity for {behavior}: {cosine_sim.item():.4f}")

    # Calculate and print L2 norm of the original and processed vectors
    original_norm = torch.norm(caa_vector).item()
    processed_norm = torch.norm(forwarded_caa_vector).item()
    print(f"L2 norm - Original: {original_norm:.4f}, Processed: {processed_norm:.4f}")

    # Calculate and print mean and standard deviation of the original and processed vectors
    original_mean = torch.mean(caa_vector).item()
    original_std = torch.std(caa_vector).item()
    processed_mean = torch.mean(forwarded_caa_vector).item()
    processed_std = torch.std(forwarded_caa_vector).item()
    print(f"Mean - Original: {original_mean:.4f}, Processed: {processed_mean:.4f}")
    print(f"Std Dev - Original: {original_std:.4f}, Processed: {processed_std:.4f}")
    print()

In [None]:
#EXPERIMENT 5
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
from huggingface_hub import hf_hub_download

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Layer index that selects the SAE parameter file inside the Hub repository.
layer = 14
sae_model_name = "google/gemma-scope-2b-pt-res"

# Download the parameter archive (force_download=False, so an existing cached
# copy may be reused).
path_to_params = hf_hub_download(
    force_download=False,
    repo_id=sae_model_name,
    filename=f"layer_{layer}/width_16k/average_l0_84/params.npz",
)

# `params` keeps the NumPy arrays; `pt_params` mirrors them as tensors on `device`.
params = np.load(path_to_params)
pt_params = dict(
    (name, torch.from_numpy(array).to(device)) for name, array in params.items()
)

class JumpReLUSAE(nn.Module):
    """Sparse autoencoder whose encoder gates activations with a per-feature threshold.

    All parameters start as zeros and are expected to be filled in through
    ``load_state_dict``.
    """

    def __init__(self, d_model, d_sae):
        super().__init__()
        # Registered in the same order as the state-dict keys:
        # W_enc, W_dec, threshold, b_enc, b_dec.
        parameter_shapes = (
            ("W_enc", (d_model, d_sae)),
            ("W_dec", (d_sae, d_model)),
            ("threshold", (d_sae,)),
            ("b_enc", (d_sae,)),
            ("b_dec", (d_model,)),
        )
        for param_name, shape in parameter_shapes:
            setattr(self, param_name, nn.Parameter(torch.zeros(shape)))

    def encode(self, input_acts):
        """Return ReLU pre-activations, zeroed wherever they do not exceed ``threshold``."""
        hidden = input_acts @ self.W_enc + self.b_enc
        above_threshold = hidden > self.threshold
        return above_threshold * torch.nn.functional.relu(hidden)

    def decode(self, acts):
        """Map feature activations back to model space and add the decoder bias."""
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        """Encode then decode ``acts``."""
        return self.decode(self.encode(acts))

    def process_caa_vector(self, caa_vector, scale_factor=1.0, add_bias=False, normalize_input=False):
        """Optionally shift, scale and normalise a steering vector, then reconstruct it.

        Args:
            caa_vector: Tensor in model space; moved to the device of ``W_enc``.
            scale_factor: Multiplier applied after the optional bias shift.
            add_bias: Add ``b_dec`` to the vector before scaling.
            normalize_input: L2-normalise the scaled vector (last dimension)
                before it is encoded.

        Returns:
            The raw SAE reconstruction (decoder bias included, not normalised).
        """
        prepared = caa_vector.to(self.W_enc.device)
        if add_bias:
            prepared = prepared + self.b_dec
        prepared = prepared * scale_factor
        if normalize_input:
            prepared = F.normalize(prepared, dim=-1)

        return self.decode(self.encode(prepared))

def optimize_parameters(sae, caa_vectors, device):
    """Grid-search scale factor, bias shift and input normalisation.

    Every combination of scale factor, ``add_bias`` and ``normalize_input``
    is passed to ``sae.process_caa_vector``; the score of a combination is
    the mean cosine similarity between each original vector and its processed
    version. One line is printed per combination.

    Args:
        sae: Object exposing ``process_caa_vector(vector, scale_factor,
            add_bias, normalize_input)``.
        caa_vectors: Mapping from behavior name to tensor.
        device: Device the vectors are moved to before processing.

    Returns:
        ``(best_params, best_avg_similarity)`` where ``best_params`` has the
        keys ``scale_factor``, ``add_bias`` and ``normalize_input``.

    Raises:
        ValueError: If ``caa_vectors`` is empty.
    """
    if not caa_vectors:
        raise ValueError("caa_vectors must contain at least one vector")

    # Moving the vectors is loop-invariant, so do it once.
    vectors_on_device = [vector.to(device) for vector in caa_vectors.values()]

    # Start below every attainable cosine similarity: with a start value of 0
    # a search whose averages are all <= 0 would return an empty best_params.
    best_avg_similarity = float("-inf")
    best_params = {}

    scale_factors = np.logspace(0, 3, 20)
    add_bias_options = [True, False]
    normalize_input_options = [True, False]

    for scale_factor in scale_factors:
        for add_bias in add_bias_options:
            for normalize_input in normalize_input_options:
                similarities = []

                for caa_vector in vectors_on_device:
                    forwarded_caa_vector = sae.process_caa_vector(caa_vector, scale_factor, add_bias, normalize_input)

                    cosine_sim = F.cosine_similarity(caa_vector.view(1, -1), forwarded_caa_vector.view(1, -1), dim=1)
                    similarities.append(cosine_sim.item())

                avg_similarity = np.mean(similarities)

                if avg_similarity > best_avg_similarity:
                    best_avg_similarity = avg_similarity
                    best_params = {
                        'scale_factor': scale_factor,
                        'add_bias': add_bias,
                        'normalize_input': normalize_input
                    }

                print(f"Scale: {scale_factor:.2f}, Add Bias: {add_bias}, Normalize Input: {normalize_input}, Avg Similarity: {avg_similarity:.4f}")

    return best_params, best_avg_similarity

# Build the SAE with the dimensions of the downloaded encoder matrix and load
# the pretrained parameters.
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)
sae = sae.to(device)
print("SAE loaded successfully")

# Quick sanity report on the SAE size and its decoder bias.
print(f"SAE model width: Encoder width = {sae.W_enc.shape[1]}, Decoder width = {sae.W_dec.shape[0]}")
print("Decoder bias (b_dec) statistics:")
print(f"Shape: {sae.b_dec.shape}")
print(f"Mean: {sae.b_dec.mean().item():.4f}")
print(f"Std Dev: {sae.b_dec.std().item():.4f}")
print(f"Min: {sae.b_dec.min().item():.4f}")
print(f"Max: {sae.b_dec.max().item():.4f}")

layer = 14
model_name_path = 'gemma-2-2b'

caa_vectors = {}
for behavior in ALL_BEHAVIORS:
    # Not referenced again in this cell; kept so the name stays defined.
    vector_path = gemma_vector_analysis.fix_vector_path(get_vector_path(behavior, layer, model_name_path))
    normalized_dir = os.path.join(BASE_DIR, 'normalized_vectors', behavior)

    original_path = os.path.join(normalized_dir, f"vec_layer_{layer}_{model_name_path}.pt")
    original_path = gemma_vector_analysis.fix_vector_path(original_path)
    # map_location lets vectors that were saved from a GPU session load on a
    # CPU-only machine; they are moved to `device` before use.
    caa_vector = torch.load(original_path, map_location="cpu")
    caa_vectors[behavior] = caa_vector

    print(f"Loaded CAA vector for behavior: {behavior}")

print("Optimizing parameters...")
best_params, best_avg_similarity = optimize_parameters(sae, caa_vectors, device)

print(f"\nBest parameters found:")
print(f"Scale factor: {best_params['scale_factor']:.4f}")
print(f"Add bias: {best_params['add_bias']}")
print(f"Normalize input: {best_params['normalize_input']}")
print(f"Best average cosine similarity: {best_avg_similarity:.4f}")

# NOTE: absolute path from the original environment; adjust when running elsewhere.
sae_vector_dir = '/root/CAA/sae_vector'
# exist_ok=True replaces the check-then-create pattern and only needs to run once.
os.makedirs(sae_vector_dir, exist_ok=True)

print("\nProcessing CAA vectors with optimized parameters:")
for behavior, caa_vector in caa_vectors.items():
    caa_vector = caa_vector.to(device)
    forwarded_caa_vector = sae.process_caa_vector(
        caa_vector,
        best_params['scale_factor'],
        best_params['add_bias'],
        best_params['normalize_input']
    )

    forwarded_vector_path = os.path.join(sae_vector_dir, f"forwarded_{behavior}.pt")
    torch.save(forwarded_caa_vector.cpu(), forwarded_vector_path)
    print(f"Saved processed vector for behavior {behavior} at {forwarded_vector_path}")

    cosine_sim = F.cosine_similarity(caa_vector.view(1, -1), forwarded_caa_vector.view(1, -1), dim=1)
    print(f"Cosine similarity for {behavior}: {cosine_sim.item():.4f}")

    original_norm = torch.norm(caa_vector).item()
    processed_norm = torch.norm(forwarded_caa_vector).item()
    print(f"L2 norm - Original: {original_norm:.4f}, Processed: {processed_norm:.4f}")

    original_mean = torch.mean(caa_vector).item()
    original_std = torch.std(caa_vector).item()
    processed_mean = torch.mean(forwarded_caa_vector).item()
    processed_std = torch.std(forwarded_caa_vector).item()
    print(f"Mean - Original: {original_mean:.4f}, Processed: {processed_mean:.4f}")
    print(f"Std Dev - Original: {original_std:.4f}, Processed: {processed_std:.4f}")
    print()

In [None]:
#EXPERIMENT 6 (simple model only using scale)
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
from huggingface_hub import hf_hub_download

# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
layer = 14

# Locate the SAE parameter archive for this layer on the Hugging Face hub.
sae_model_name = "google/gemma-scope-2b-pt-res"
params_filename = f"layer_{layer}/width_16k/average_l0_84/params.npz"
path_to_params = hf_hub_download(
    repo_id=sae_model_name,
    filename=params_filename,
    force_download=False,
)

# Turn every array in the archive into a torch tensor living on `device`.
params = np.load(path_to_params)
pt_params = {
    name: torch.from_numpy(array).to(device)
    for name, array in params.items()
}

class JumpReLUSAE(nn.Module):
    """JumpReLU sparse autoencoder with an encoder, a per-feature threshold and a decoder.

    All parameters are created zero-filled; real weights are expected to be
    supplied afterwards through ``load_state_dict``.
    """

    def __init__(self, d_model, d_sae):
        """Allocate zero-filled parameters.

        Args:
            d_model: Size of the input/output activation vectors.
            d_sae: Number of SAE features (size of the encoded vector).
        """
        super().__init__()
        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
        self.threshold = nn.Parameter(torch.zeros(d_sae))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, input_acts):
        """Return feature activations; features not strictly above ``threshold`` are zeroed."""
        pre_acts = input_acts @ self.W_enc + self.b_enc
        mask = (pre_acts > self.threshold)
        acts = mask * torch.nn.functional.relu(pre_acts)
        return acts

    def decode(self, acts):
        """Map feature activations back to model space (``acts @ W_dec + b_dec``)."""
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        """Encode then decode ``acts`` and return the reconstruction."""
        acts = self.encode(acts)
        recon = self.decode(acts)
        return recon

    @torch.no_grad()
    def process_caa_vector(self, caa_vector, scale_factor=4.2813):
        """Scale a CAA vector, pass it through the SAE, and return the reconstruction.

        Runs under ``torch.no_grad()``: this is an inference-only helper, and
        without it the result would carry autograd history (and would be saved
        with ``requires_grad=True`` by ``torch.save``). ``encode``/``decode``/
        ``forward`` themselves stay differentiable.

        Args:
            caa_vector: Tensor to reconstruct; moved to the device of ``W_enc``.
            scale_factor: Multiplier applied to ``caa_vector`` before encoding.

        Returns:
            The decoded tensor (still at the scaled magnitude; it is not divided
            by ``scale_factor`` afterwards).
        """
        caa_vector = caa_vector.to(self.W_enc.device)
        adjusted_vector = caa_vector * scale_factor
        encoded = self.encode(adjusted_vector)
        decoded = self.decode(encoded)
        return decoded

def process_caa_vectors(sae, caa_vectors, scale_factor):
    """Process every CAA vector with the SAE and print reconstruction statistics.

    Args:
        sae: SAE module exposing ``W_enc`` and
            ``process_caa_vector(vector, scale_factor)`` returning a tensor.
        caa_vectors: Mapping of behavior name -> CAA vector tensor.
        scale_factor: Multiplier forwarded to ``sae.process_caa_vector``.

    Returns:
        Dict mapping behavior name -> processed vector, on the CPU and detached
        from autograd so it can be saved without ``requires_grad=True``.
    """
    # Follow the SAE's own device rather than a notebook-global `device`, so the
    # function cannot disagree with where the weights actually live.
    target_device = sae.W_enc.device
    processed_vectors = {}
    similarities = []

    # Pure inference: no autograd graph is needed for any of the statistics.
    with torch.no_grad():
        for behavior, caa_vector in caa_vectors.items():
            caa_vector = caa_vector.to(target_device)
            processed_vector = sae.process_caa_vector(caa_vector, scale_factor)

            cosine_sim = F.cosine_similarity(
                caa_vector.view(1, -1), processed_vector.view(1, -1), dim=1
            ).item()
            similarities.append(cosine_sim)

            processed_vectors[behavior] = processed_vector.detach().cpu()

            print(f"Processed vector for behavior: {behavior}")
            print(f"Cosine similarity: {cosine_sim:.4f}")
            print(f"Original L2 norm: {torch.norm(caa_vector).item():.4f}")
            print(f"Processed L2 norm: {torch.norm(processed_vector).item():.4f}")
            print(f"Original mean: {torch.mean(caa_vector).item():.4f}")
            print(f"Processed mean: {torch.mean(processed_vector).item():.4f}")
            print(f"Original std dev: {torch.std(caa_vector).item():.4f}")
            print(f"Processed std dev: {torch.std(processed_vector).item():.4f}")
            print()

    # np.mean([]) would emit a RuntimeWarning and print nan; report the empty case explicitly.
    if similarities:
        avg_similarity = np.mean(similarities)
        print(f"Average cosine similarity: {avg_similarity:.4f}")
    else:
        print("Average cosine similarity: n/a (no CAA vectors given)")

    return processed_vectors

# Load the SAE model
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)
sae = sae.to(device)
print("SAE loaded successfully")

# Load CAA vectors
# These two values are the same for every behavior, so they are set once.
layer = 14
model_name_path = 'gemma-2-2b'
caa_vectors = {}
for behavior in ALL_BEHAVIORS:
    # NOTE(review): vector_path is never read below; the call is kept in case
    # fix_vector_path has side effects - confirm and remove if it does not.
    vector_path = gemma_vector_analysis.fix_vector_path(get_vector_path(behavior, layer, model_name_path))
    normalized_dir = os.path.join(BASE_DIR, 'normalized_vectors', behavior)

    original_path = os.path.join(normalized_dir, f"vec_layer_{layer}_{model_name_path}.pt")
    original_path = gemma_vector_analysis.fix_vector_path(original_path)
    caa_vector = torch.load(original_path)
    caa_vectors[behavior] = caa_vector
    print(f"Loaded CAA vector for behavior: {behavior}")

# Process CAA vectors
scale_factor = 4.2813  # This value was found to work well in previous experiments
processed_vectors = process_caa_vectors(sae, caa_vectors, scale_factor)

# Save processed vectors
# exist_ok=True replaces the check-then-create pattern.
sae_vector_dir = '/root/CAA/sae_vector'
os.makedirs(sae_vector_dir, exist_ok=True)

# NOTE(review): other experiment cells in this notebook also write
# processed_{behavior}.pt into this directory, so whichever cell runs last wins.
for behavior, processed_vector in processed_vectors.items():
    save_path = os.path.join(sae_vector_dir, f"processed_{behavior}.pt")
    torch.save(processed_vector, save_path)
    print(f"Saved processed vector for behavior {behavior} at {save_path}")

print("CAA vector processing complete.")

In [None]:
#Experiment 7
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
from huggingface_hub import hf_hub_download

# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
layer = 14

# Locate the SAE parameter archive for this layer on the Hugging Face hub.
sae_model_name = "google/gemma-scope-2b-pt-res"
params_filename = f"layer_{layer}/width_16k/average_l0_84/params.npz"
path_to_params = hf_hub_download(
    repo_id=sae_model_name,
    filename=params_filename,
    force_download=False,
)

# Turn every array in the archive into a torch tensor living on `device`.
params = np.load(path_to_params)
pt_params = {
    name: torch.from_numpy(array).to(device)
    for name, array in params.items()
}

class JumpReLUSAE(nn.Module):
    """JumpReLU sparse autoencoder with an encoder, a per-feature threshold and a decoder.

    All parameters are created zero-filled; real weights are expected to be
    supplied afterwards through ``load_state_dict``.
    """

    def __init__(self, d_model, d_sae):
        """Allocate zero-filled parameters.

        Args:
            d_model: Size of the input/output activation vectors.
            d_sae: Number of SAE features (size of the encoded vector).
        """
        super().__init__()
        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
        self.threshold = nn.Parameter(torch.zeros(d_sae))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, input_acts):
        """Return feature activations; features not strictly above ``threshold`` are zeroed."""
        pre_acts = input_acts @ self.W_enc + self.b_enc
        mask = (pre_acts > self.threshold)
        acts = mask * torch.nn.functional.relu(pre_acts)
        return acts

    def decode(self, acts):
        """Map feature activations back to model space (``acts @ W_dec + b_dec``)."""
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        """Encode then decode ``acts`` and return the reconstruction."""
        acts = self.encode(acts)
        recon = self.decode(acts)
        return recon

    @torch.no_grad()
    def process_caa_vector(self, caa_vector):
        """Reconstruct a CAA vector as half the difference of two SAE passes.

        Computes ``(SAE(b_dec + v) - SAE(b_dec - v)) / 2``. The ``b_dec`` that
        ``decode`` adds to both passes cancels in the subtraction.

        Runs under ``torch.no_grad()``: the result is only inspected and saved,
        and without it the returned tensor would carry autograd history (and be
        saved with ``requires_grad=True``).

        Args:
            caa_vector: Tensor to reconstruct; moved to the device of ``W_enc``.

        Returns:
            Tuple ``(result_vector, cosine_sim)`` where ``cosine_sim`` is a
            Python float comparing ``result_vector`` with ``caa_vector``.
        """
        caa_vector = caa_vector.to(self.W_enc.device)
        decoder_bias = self.b_dec

        # Compute SAE(caa_vector + decoder_bias) and SAE(-caa_vector + decoder_bias)
        positive_decoded = self.decode(self.encode(caa_vector + decoder_bias))
        negative_decoded = self.decode(self.encode(-caa_vector + decoder_bias))

        # Calculate (positive - negative) / 2
        result_vector = (positive_decoded - negative_decoded) / 2

        # Calculate cosine similarity between result_vector and caa_vector
        cosine_sim = F.cosine_similarity(result_vector.view(1, -1), caa_vector.view(1, -1), dim=1).item()

        return result_vector, cosine_sim

def process_caa_vectors(sae, caa_vectors):
    """Process every CAA vector with the SAE and print reconstruction statistics.

    Args:
        sae: SAE module exposing ``W_enc`` and ``process_caa_vector(vector)``
            returning ``(processed_vector, cosine_similarity)``.
        caa_vectors: Mapping of behavior name -> CAA vector tensor.

    Returns:
        Dict mapping behavior name -> processed vector, on the CPU and detached
        from autograd so it can be saved without ``requires_grad=True``.
    """
    # Follow the SAE's own device rather than a notebook-global `device`, so the
    # function cannot disagree with where the weights actually live.
    target_device = sae.W_enc.device
    processed_vectors = {}
    similarities = []

    # Pure inference: no autograd graph is needed for any of the statistics.
    with torch.no_grad():
        for behavior, caa_vector in caa_vectors.items():
            caa_vector = caa_vector.to(target_device)

            # Unpack the tuple returned by sae.process_caa_vector
            processed_vector, cosine_sim = sae.process_caa_vector(caa_vector)
            similarities.append(cosine_sim)

            processed_vectors[behavior] = processed_vector.detach().cpu()

            print(f"Processed vector for behavior: {behavior}")
            print(f"Cosine similarity: {cosine_sim:.4f}")
            print(f"Original L2 norm: {torch.norm(caa_vector).item():.4f}")
            print(f"Processed L2 norm: {torch.norm(processed_vector).item():.4f}")
            print(f"Original mean: {torch.mean(caa_vector).item():.4f}")
            print(f"Processed mean: {torch.mean(processed_vector).item():.4f}")
            print(f"Original std dev: {torch.std(caa_vector).item():.4f}")
            print(f"Processed std dev: {torch.std(processed_vector).item():.4f}")
            print()

    # np.mean([]) would emit a RuntimeWarning and print nan; report the empty case explicitly.
    if similarities:
        avg_similarity = np.mean(similarities)
        print(f"Average cosine similarity: {avg_similarity:.4f}")
    else:
        print("Average cosine similarity: n/a (no CAA vectors given)")

    return processed_vectors

# Load the SAE model
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)
sae = sae.to(device)
print("SAE loaded successfully")

# Load CAA vectors
# These two values are the same for every behavior, so they are set once.
layer = 14
model_name_path = 'gemma-2-2b'
caa_vectors = {}
for behavior in ALL_BEHAVIORS:
    # NOTE(review): vector_path is never read below; the call is kept in case
    # fix_vector_path has side effects - confirm and remove if it does not.
    vector_path = gemma_vector_analysis.fix_vector_path(get_vector_path(behavior, layer, model_name_path))
    normalized_dir = os.path.join(BASE_DIR, 'normalized_vectors', behavior)

    original_path = os.path.join(normalized_dir, f"vec_layer_{layer}_{model_name_path}.pt")
    original_path = gemma_vector_analysis.fix_vector_path(original_path)
    caa_vector = torch.load(original_path)
    caa_vectors[behavior] = caa_vector
    print(f"Loaded CAA vector for behavior: {behavior}")

# Process CAA vectors
processed_vectors = process_caa_vectors(sae, caa_vectors)

# Save processed vectors
# exist_ok=True replaces the check-then-create pattern.
sae_vector_dir = '/root/CAA/sae_vector'
os.makedirs(sae_vector_dir, exist_ok=True)

# NOTE(review): other experiment cells in this notebook also write
# processed_{behavior}.pt into this directory, so whichever cell runs last wins.
for behavior, processed_vector in processed_vectors.items():
    save_path = os.path.join(sae_vector_dir, f"processed_{behavior}.pt")
    torch.save(processed_vector, save_path)
    print(f"Saved processed vector for behavior {behavior} at {save_path}")

print("CAA vector processing complete.")


In [None]:
#Experiment 8, Nina suggestion incl. neg. CAA vector, no scaling

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
from huggingface_hub import hf_hub_download

# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
layer = 14

# Locate the SAE parameter archive for this layer on the Hugging Face hub.
sae_model_name = "google/gemma-scope-2b-pt-res"
params_filename = f"layer_{layer}/width_16k/average_l0_84/params.npz"
path_to_params = hf_hub_download(
    repo_id=sae_model_name,
    filename=params_filename,
    force_download=False,
)

# Turn every array in the archive into a torch tensor living on `device`.
params = np.load(path_to_params)
pt_params = {
    name: torch.from_numpy(array).to(device)
    for name, array in params.items()
}

class JumpReLUSAE(nn.Module):
    """JumpReLU sparse autoencoder with an encoder, a per-feature threshold and a decoder.

    All parameters are created zero-filled; real weights are expected to be
    supplied afterwards through ``load_state_dict``.
    """

    def __init__(self, d_model, d_sae):
        """Allocate zero-filled parameters.

        Args:
            d_model: Size of the input/output activation vectors.
            d_sae: Number of SAE features (size of the encoded vector).
        """
        super().__init__()
        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
        self.threshold = nn.Parameter(torch.zeros(d_sae))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, input_acts):
        """Return feature activations; features not strictly above ``threshold`` are zeroed."""
        pre_acts = input_acts @ self.W_enc + self.b_enc
        mask = (pre_acts > self.threshold)
        acts = mask * torch.nn.functional.relu(pre_acts)
        return acts

    def decode(self, acts):
        """Map feature activations back to model space (``acts @ W_dec + b_dec``)."""
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        """Encode then decode ``acts`` and return the reconstruction."""
        acts = self.encode(acts)
        recon = self.decode(acts)
        return recon

    @torch.no_grad()
    def process_caa_vector(self, caa_vector):
        """Reconstruct a CAA vector as half the difference of two SAE passes.

        Computes ``(SAE(b_dec + v) - SAE(b_dec - v)) / 2``. The ``b_dec`` that
        ``decode`` adds to both passes cancels in the subtraction.

        Runs under ``torch.no_grad()``: the result is only inspected and saved,
        and without it the returned tensor would carry autograd history (and be
        saved with ``requires_grad=True``).

        Args:
            caa_vector: Tensor to reconstruct; moved to the device of ``W_enc``.

        Returns:
            Tuple ``(result_vector, cosine_sim)`` where ``cosine_sim`` is a
            Python float comparing ``result_vector`` with ``caa_vector``.
        """
        caa_vector = caa_vector.to(self.W_enc.device)
        decoder_bias = self.b_dec

        # Compute SAE(caa_vector + decoder_bias) and SAE(-caa_vector + decoder_bias)
        positive_decoded = self.decode(self.encode(caa_vector + decoder_bias))
        negative_decoded = self.decode(self.encode(-caa_vector + decoder_bias))

        # Calculate (positive - negative) / 2
        result_vector = (positive_decoded - negative_decoded) / 2

        # Calculate cosine similarity between result_vector and caa_vector
        cosine_sim = F.cosine_similarity(result_vector.view(1, -1), caa_vector.view(1, -1), dim=1).item()

        return result_vector, cosine_sim

def process_caa_vectors(sae, caa_vectors):
    """Process every CAA vector with the SAE and print reconstruction statistics.

    Args:
        sae: SAE module exposing ``W_enc`` and ``process_caa_vector(vector)``
            returning ``(processed_vector, cosine_similarity)``.
        caa_vectors: Mapping of behavior name -> CAA vector tensor.

    Returns:
        Dict mapping behavior name -> processed vector, on the CPU and detached
        from autograd so it can be saved without ``requires_grad=True``.
    """
    # Follow the SAE's own device rather than a notebook-global `device`, so the
    # function cannot disagree with where the weights actually live.
    target_device = sae.W_enc.device
    processed_vectors = {}
    similarities = []

    # Pure inference: no autograd graph is needed for any of the statistics.
    with torch.no_grad():
        for behavior, caa_vector in caa_vectors.items():
            caa_vector = caa_vector.to(target_device)

            # Unpack the tuple returned by sae.process_caa_vector
            processed_vector, cosine_sim = sae.process_caa_vector(caa_vector)
            similarities.append(cosine_sim)

            processed_vectors[behavior] = processed_vector.detach().cpu()

            print(f"Processed vector for behavior: {behavior}")
            print(f"Cosine similarity: {cosine_sim:.4f}")
            print(f"Original L2 norm: {torch.norm(caa_vector).item():.4f}")
            print(f"Processed L2 norm: {torch.norm(processed_vector).item():.4f}")
            print(f"Original mean: {torch.mean(caa_vector).item():.4f}")
            print(f"Processed mean: {torch.mean(processed_vector).item():.4f}")
            print(f"Original std dev: {torch.std(caa_vector).item():.4f}")
            print(f"Processed std dev: {torch.std(processed_vector).item():.4f}")
            print()

    # np.mean([]) would emit a RuntimeWarning and print nan; report the empty case explicitly.
    if similarities:
        avg_similarity = np.mean(similarities)
        print(f"Average cosine similarity: {avg_similarity:.4f}")
    else:
        print("Average cosine similarity: n/a (no CAA vectors given)")

    return processed_vectors

# Load the SAE model
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)
sae = sae.to(device)
print("SAE loaded successfully")

# Load CAA vectors
# These two values are the same for every behavior, so they are set once.
layer = 14
model_name_path = 'gemma-2-2b'
caa_vectors = {}
for behavior in ALL_BEHAVIORS:
    # NOTE(review): vector_path is never read below; the call is kept in case
    # fix_vector_path has side effects - confirm and remove if it does not.
    vector_path = gemma_vector_analysis.fix_vector_path(get_vector_path(behavior, layer, model_name_path))
    normalized_dir = os.path.join(BASE_DIR, 'normalized_vectors', behavior)

    original_path = os.path.join(normalized_dir, f"vec_layer_{layer}_{model_name_path}.pt")
    original_path = gemma_vector_analysis.fix_vector_path(original_path)
    caa_vector = torch.load(original_path)
    caa_vectors[behavior] = caa_vector
    print(f"Loaded CAA vector for behavior: {behavior}")

# Process CAA vectors
processed_vectors = process_caa_vectors(sae, caa_vectors)

# Save processed vectors
# exist_ok=True replaces the check-then-create pattern.
sae_vector_dir = '/root/CAA/sae_vector'
os.makedirs(sae_vector_dir, exist_ok=True)

# NOTE(review): other experiment cells in this notebook also write
# processed_{behavior}.pt into this directory, so whichever cell runs last wins.
for behavior, processed_vector in processed_vectors.items():
    save_path = os.path.join(sae_vector_dir, f"processed_{behavior}.pt")
    torch.save(processed_vector, save_path)
    print(f"Saved processed vector for behavior {behavior} at {save_path}")

print("CAA vector processing complete.")


In [None]:
# Experiment 9, Nina suggestion plus scaling
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
from huggingface_hub import hf_hub_download

# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
layer = 14

# Locate the SAE parameter archive for this layer on the Hugging Face hub.
sae_model_name = "google/gemma-scope-2b-pt-res"
params_filename = f"layer_{layer}/width_16k/average_l0_84/params.npz"
path_to_params = hf_hub_download(
    repo_id=sae_model_name,
    filename=params_filename,
    force_download=False,
)

# Turn every array in the archive into a torch tensor living on `device`.
params = np.load(path_to_params)
pt_params = {
    name: torch.from_numpy(array).to(device)
    for name, array in params.items()
}

class JumpReLUSAE(nn.Module):
    """JumpReLU sparse autoencoder with an encoder, a per-feature threshold and a decoder.

    All parameters are created zero-filled; real weights are expected to be
    supplied afterwards through ``load_state_dict``.
    """

    def __init__(self, d_model, d_sae):
        """Allocate zero-filled parameters.

        Args:
            d_model: Size of the input/output activation vectors.
            d_sae: Number of SAE features (size of the encoded vector).
        """
        super().__init__()
        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
        self.threshold = nn.Parameter(torch.zeros(d_sae))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, input_acts):
        """Return feature activations; features not strictly above ``threshold`` are zeroed."""
        pre_acts = input_acts @ self.W_enc + self.b_enc
        mask = (pre_acts > self.threshold)
        acts = mask * torch.nn.functional.relu(pre_acts)
        return acts

    def decode(self, acts):
        """Map feature activations back to model space (``acts @ W_dec + b_dec``)."""
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        """Encode then decode ``acts`` and return the reconstruction."""
        acts = self.encode(acts)
        recon = self.decode(acts)
        return recon

    @torch.no_grad()
    def process_caa_vector(self, caa_vector, scale_factor=None):
        """Reconstruct an optionally scaled CAA vector as half the difference of two SAE passes.

        Computes ``(SAE(b_dec + v) - SAE(b_dec - v)) / 2`` where ``v`` is
        ``caa_vector`` multiplied by ``scale_factor`` when one is given. The
        ``b_dec`` that ``decode`` adds to both passes cancels in the subtraction.

        Runs under ``torch.no_grad()``: the result is only inspected and saved,
        and without it the returned tensor would carry autograd history (and be
        saved with ``requires_grad=True``).

        Args:
            caa_vector: Tensor to reconstruct; moved to the device of ``W_enc``.
            scale_factor: Optional multiplier applied before encoding; ``None``
                leaves the vector unscaled.

        Returns:
            Tuple ``(result_vector, cosine_sim)``. ``cosine_sim`` is a Python
            float measured against the (scaled) input vector; for a positive
            ``scale_factor`` this equals the similarity to the unscaled vector.
        """
        caa_vector = caa_vector.to(self.W_enc.device)
        decoder_bias = self.b_dec

        if scale_factor is not None:
            caa_vector = caa_vector * scale_factor

        # Compute SAE(caa_vector + decoder_bias) and SAE(-caa_vector + decoder_bias)
        positive_decoded = self.decode(self.encode(caa_vector + decoder_bias))
        negative_decoded = self.decode(self.encode(-caa_vector + decoder_bias))

        # Calculate (positive - negative) / 2
        result_vector = (positive_decoded - negative_decoded) / 2

        # Calculate cosine similarity between result_vector and caa_vector
        cosine_sim = F.cosine_similarity(result_vector.view(1, -1), caa_vector.view(1, -1), dim=1).item()

        return result_vector, cosine_sim

def process_caa_vectors(sae, caa_vectors, scale_factor=None):
    """Process every CAA vector with the SAE and print reconstruction statistics.

    Args:
        sae: SAE module exposing ``W_enc`` and
            ``process_caa_vector(vector, scale_factor=...)`` returning
            ``(processed_vector, cosine_similarity)``.
        caa_vectors: Mapping of behavior name -> CAA vector tensor.
        scale_factor: Forwarded unchanged to ``sae.process_caa_vector``
            (``None`` means no scaling there).

    Returns:
        Dict mapping behavior name -> processed vector, on the CPU and detached
        from autograd so it can be saved without ``requires_grad=True``.
    """
    # Follow the SAE's own device rather than a notebook-global `device`, so the
    # function cannot disagree with where the weights actually live.
    target_device = sae.W_enc.device
    processed_vectors = {}
    similarities = []

    # Pure inference: no autograd graph is needed for any of the statistics.
    with torch.no_grad():
        for behavior, caa_vector in caa_vectors.items():
            caa_vector = caa_vector.to(target_device)

            # Process the CAA vector with or without scaling
            processed_vector, cosine_sim = sae.process_caa_vector(caa_vector, scale_factor=scale_factor)
            similarities.append(cosine_sim)

            processed_vectors[behavior] = processed_vector.detach().cpu()

            print(f"Processed vector for behavior: {behavior}")
            print(f"Cosine similarity: {cosine_sim:.4f}")
            print(f"Original L2 norm: {torch.norm(caa_vector).item():.4f}")
            print(f"Processed L2 norm: {torch.norm(processed_vector).item():.4f}")
            print(f"Original mean: {torch.mean(caa_vector).item():.4f}")
            print(f"Processed mean: {torch.mean(processed_vector).item():.4f}")
            print(f"Original std dev: {torch.std(caa_vector).item():.4f}")
            print(f"Processed std dev: {torch.std(processed_vector).item():.4f}")
            print()

    # np.mean([]) would emit a RuntimeWarning and print nan; report the empty case explicitly.
    if similarities:
        avg_similarity = np.mean(similarities)
        print(f"Average cosine similarity: {avg_similarity:.4f}")
    else:
        print("Average cosine similarity: n/a (no CAA vectors given)")

    return processed_vectors

# Load the SAE model
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)
sae = sae.to(device)
print("SAE loaded successfully")

# Load CAA vectors
# These two values are the same for every behavior, so they are set once.
layer = 14
model_name_path = 'gemma-2-2b'
caa_vectors = {}
for behavior in ALL_BEHAVIORS:
    # NOTE(review): vector_path is never read below; the call is kept in case
    # fix_vector_path has side effects - confirm and remove if it does not.
    vector_path = gemma_vector_analysis.fix_vector_path(get_vector_path(behavior, layer, model_name_path))
    normalized_dir = os.path.join(BASE_DIR, 'normalized_vectors', behavior)

    original_path = os.path.join(normalized_dir, f"vec_layer_{layer}_{model_name_path}.pt")
    original_path = gemma_vector_analysis.fix_vector_path(original_path)
    caa_vector = torch.load(original_path)
    caa_vectors[behavior] = caa_vector
    print(f"Loaded CAA vector for behavior: {behavior}")

# Process CAA vectors with and without scaling
scale_factor = 4.2813  # Define your scale factor here
print("Processing with scaling:")
processed_vectors_scaled = process_caa_vectors(sae, caa_vectors, scale_factor=scale_factor)

print("\nProcessing without scaling:")
processed_vectors_no_scale = process_caa_vectors(sae, caa_vectors, scale_factor=None)

# Save processed vectors
# exist_ok=True replaces the check-then-create pattern.
sae_vector_dir = '/root/CAA/sae_vector'
os.makedirs(sae_vector_dir, exist_ok=True)

# Save scaled vectors
for behavior, processed_vector in processed_vectors_scaled.items():
    save_path = os.path.join(sae_vector_dir, f"processed_scaled_{behavior}.pt")
    torch.save(processed_vector, save_path)
    print(f"Saved processed scaled vector for behavior {behavior} at {save_path}")

# Save non-scaled vectors
for behavior, processed_vector in processed_vectors_no_scale.items():
    save_path = os.path.join(sae_vector_dir, f"processed_no_scale_{behavior}.pt")
    torch.save(processed_vector, save_path)
    print(f"Saved processed non-scaled vector for behavior {behavior} at {save_path}")

print("CAA vector processing complete.")



In [5]:
# FINAL CODE I USE TO CONTINUE RESEARCH
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
from huggingface_hub import hf_hub_download

# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
layer = 14

# Locate the SAE parameter archive for this layer on the Hugging Face hub.
sae_model_name = "google/gemma-scope-2b-pt-res"
params_filename = f"layer_{layer}/width_16k/average_l0_84/params.npz"
path_to_params = hf_hub_download(
    repo_id=sae_model_name,
    filename=params_filename,
    force_download=False,
)

# Turn every array in the archive into a torch tensor living on `device`.
params = np.load(path_to_params)
pt_params = {
    name: torch.from_numpy(array).to(device)
    for name, array in params.items()
}

class JumpReLUSAE(nn.Module):
    """JumpReLU sparse autoencoder with an encoder, a per-feature threshold and a decoder.

    All parameters are created zero-filled; real weights are expected to be
    supplied afterwards through ``load_state_dict``.
    """

    def __init__(self, d_model, d_sae):
        """Allocate zero-filled parameters.

        Args:
            d_model: Size of the input/output activation vectors.
            d_sae: Number of SAE features (size of the encoded vector).
        """
        super().__init__()
        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
        self.threshold = nn.Parameter(torch.zeros(d_sae))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, input_acts):
        """Return feature activations; features not strictly above ``threshold`` are zeroed."""
        pre_acts = input_acts @ self.W_enc + self.b_enc
        mask = (pre_acts > self.threshold)
        acts = mask * torch.nn.functional.relu(pre_acts)
        return acts

    def decode(self, acts):
        """Map feature activations back to model space (``acts @ W_dec + b_dec``)."""
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        """Encode then decode ``acts`` and return the reconstruction."""
        acts = self.encode(acts)
        recon = self.decode(acts)
        return recon

    @torch.no_grad()
    def process_caa_vector(self, caa_vector, scale_factor=4.2813):
        """Reconstruct a scaled CAA vector as half the difference of two SAE passes.

        Computes ``(SAE(b_dec + v) - SAE(b_dec - v)) / 2`` with
        ``v = caa_vector * scale_factor``. The ``b_dec`` that ``decode`` adds to
        both passes cancels in the subtraction.

        Runs under ``torch.no_grad()``: the result is only inspected and saved,
        and without it the returned tensor would carry autograd history (and be
        saved with ``requires_grad=True``).

        Args:
            caa_vector: Tensor to reconstruct; moved to the device of ``W_enc``.
            scale_factor: Multiplier applied to ``caa_vector`` before encoding.

        Returns:
            Tuple ``(result_vector, cosine_sim)``. ``result_vector`` stays at
            the scaled magnitude. ``cosine_sim`` is a Python float measured
            against the scaled vector; for a positive ``scale_factor`` this
            equals the similarity to the unscaled vector.
        """
        caa_vector = caa_vector.to(self.W_enc.device)
        decoder_bias = self.b_dec

        # Apply scaling to the CAA vector
        caa_vector = caa_vector * scale_factor

        # Compute SAE(caa_vector + decoder_bias) and SAE(-caa_vector + decoder_bias)
        positive_decoded = self.decode(self.encode(caa_vector + decoder_bias))
        negative_decoded = self.decode(self.encode(-caa_vector + decoder_bias))

        # Calculate (positive - negative) / 2
        result_vector = (positive_decoded - negative_decoded) / 2

        # Calculate cosine similarity between result_vector and caa_vector
        cosine_sim = F.cosine_similarity(result_vector.view(1, -1), caa_vector.view(1, -1), dim=1).item()

        return result_vector, cosine_sim

def process_caa_vectors(sae, caa_vectors, scale_factor=4.2813):
    """Process every CAA vector with the SAE and print reconstruction statistics.

    Args:
        sae: SAE module exposing ``W_enc`` and
            ``process_caa_vector(vector, scale_factor=...)`` returning
            ``(processed_vector, cosine_similarity)``.
        caa_vectors: Mapping of behavior name -> CAA vector tensor.
        scale_factor: Multiplier forwarded to ``sae.process_caa_vector``.

    Returns:
        Dict mapping behavior name -> processed vector, on the CPU and detached
        from autograd so it can be saved without ``requires_grad=True``.
    """
    # Follow the SAE's own device rather than a notebook-global `device`, so the
    # function cannot disagree with where the weights actually live.
    target_device = sae.W_enc.device
    processed_vectors = {}
    similarities = []

    # Pure inference: no autograd graph is needed for any of the statistics.
    with torch.no_grad():
        for behavior, caa_vector in caa_vectors.items():
            caa_vector = caa_vector.to(target_device)

            # Process the CAA vector with scaling
            processed_vector, cosine_sim = sae.process_caa_vector(caa_vector, scale_factor=scale_factor)
            similarities.append(cosine_sim)

            processed_vectors[behavior] = processed_vector.detach().cpu()

            print(f"Processed vector for behavior: {behavior}")
            print(f"Cosine similarity: {cosine_sim:.4f}")
            print(f"Original L2 norm: {torch.norm(caa_vector).item():.4f}")
            print(f"Processed L2 norm: {torch.norm(processed_vector).item():.4f}")
            print(f"Original mean: {torch.mean(caa_vector).item():.4f}")
            print(f"Processed mean: {torch.mean(processed_vector).item():.4f}")
            print(f"Original std dev: {torch.std(caa_vector).item():.4f}")
            print(f"Processed std dev: {torch.std(processed_vector).item():.4f}")
            print()

    # np.mean([]) would emit a RuntimeWarning and print nan; report the empty case explicitly.
    if similarities:
        avg_similarity = np.mean(similarities)
        print(f"Average cosine similarity: {avg_similarity:.4f}")
    else:
        print("Average cosine similarity: n/a (no CAA vectors given)")

    return processed_vectors

# Load the SAE model
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)
sae = sae.to(device)
print("SAE loaded successfully")

# Load CAA vectors
# These two values are the same for every behavior, so they are set once.
layer = 14
model_name_path = 'gemma-2-2b'
caa_vectors = {}
for behavior in ALL_BEHAVIORS:
    # NOTE(review): vector_path is never read below; the call is kept in case
    # fix_vector_path has side effects - confirm and remove if it does not.
    vector_path = gemma_vector_analysis.fix_vector_path(get_vector_path(behavior, layer, model_name_path))
    normalized_dir = os.path.join(BASE_DIR, 'normalized_vectors', behavior)

    original_path = os.path.join(normalized_dir, f"vec_layer_{layer}_{model_name_path}.pt")
    original_path = gemma_vector_analysis.fix_vector_path(original_path)
    caa_vector = torch.load(original_path)
    caa_vectors[behavior] = caa_vector
    print(f"Loaded CAA vector for behavior: {behavior}")

# Process CAA vectors with scaling
processed_vectors = process_caa_vectors(sae, caa_vectors, scale_factor=4.2813)

# Save processed vectors
# exist_ok=True replaces the check-then-create pattern.
sae_vector_dir = '/root/CAA/sae_vector'
os.makedirs(sae_vector_dir, exist_ok=True)

# NOTE(review): the experiment cells above also write processed_{behavior}.pt
# into this directory; re-running one of them after this cell overwrites these files.
for behavior, processed_vector in processed_vectors.items():
    save_path = os.path.join(sae_vector_dir, f"processed_{behavior}.pt")
    torch.save(processed_vector, save_path)
    print(f"Saved processed vector for behavior {behavior} at {save_path}")

print("CAA vector processing complete.")


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

SAE loaded successfully
Loaded CAA vector for behavior: coordinate-other-ais
Loaded CAA vector for behavior: corrigible-neutral-HHH
Loaded CAA vector for behavior: hallucination
Loaded CAA vector for behavior: myopic-reward
Loaded CAA vector for behavior: survival-instinct
Loaded CAA vector for behavior: sycophancy
Loaded CAA vector for behavior: refusal
Processed vector for behavior: coordinate-other-ais
Cosine similarity: 0.8658
Original L2 norm: 44.3626
Processed L2 norm: 255.2028
Original mean: 0.0521
Processed mean: 0.1877
Original std dev: 0.9230
Processed std dev: 5.3146

Processed vector for behavior: corrigible-neutral-HHH
Cosine similarity: 0.8685
Original L2 norm: 44.3626
Processed L2 norm: 268.3082
Original mean: 0.0128
Processed mean: 0.0354
Original std dev: 0.9243
Processed std dev: 5.5909

Processed vector for behavior: hallucination
Cosine similarity: 0.8456
Original L2 norm: 44.3626
Processed L2 norm: 241.0766
Original mean: 0.0288
Processed mean: 0.1753
Original std 

In [6]:
# Normalize all processed vectors to the norm of one of the original CAA vectors
# Relies on `caa_vectors` and `processed_vectors` left behind by the previous cell.

# Get the norm of the first original CAA vector (they all have the same norm)
example_caa_vector = next(iter(caa_vectors.values()))
original_norm = torch.norm(example_caa_vector).item()
print(f"Original CAA vector norm: {original_norm:.4f}")

# Create a directory for the normalized vectors
# exist_ok=True replaces the check-then-create pattern.
normalized_vector_dir = '/root/CAA/sae_vector_normalized'
os.makedirs(normalized_vector_dir, exist_ok=True)

# Normalize and save the processed vectors
for behavior, processed_vector in processed_vectors.items():
    # Compute the norm of the processed vector
    processed_norm = torch.norm(processed_vector).item()

    # An all-zero vector cannot be rescaled; fail with the behavior name instead
    # of a bare ZeroDivisionError from the division below.
    if processed_norm == 0:
        raise ValueError(
            f"Processed vector for behavior {behavior} has zero norm and cannot be normalized"
        )

    # Normalize the processed vector to the original norm
    normalized_vector = processed_vector * (original_norm / processed_norm)

    # Save the normalized vector
    save_path = os.path.join(normalized_vector_dir, f"normalized_{behavior}.pt")
    torch.save(normalized_vector, save_path)

    # Print norms
    print(f"Processed vector norm for behavior {behavior}: {processed_norm:.4f}")
    print(f"Normalized vector norm for behavior {behavior}: {torch.norm(normalized_vector).item():.4f}")

print("Normalization of processed vectors complete.")


Original CAA vector norm: 44.3626
Processed vector norm for behavior coordinate-other-ais: 255.2028
Normalized vector norm for behavior coordinate-other-ais: 44.3626
Processed vector norm for behavior corrigible-neutral-HHH: 268.3082
Normalized vector norm for behavior corrigible-neutral-HHH: 44.3626
Processed vector norm for behavior hallucination: 241.0766
Normalized vector norm for behavior hallucination: 44.3626
Processed vector norm for behavior myopic-reward: 193.2522
Normalized vector norm for behavior myopic-reward: 44.3626
Processed vector norm for behavior survival-instinct: 212.8642
Normalized vector norm for behavior survival-instinct: 44.3626
Processed vector norm for behavior sycophancy: 207.0221
Normalized vector norm for behavior sycophancy: 44.3626
Processed vector norm for behavior refusal: 232.5123
Normalized vector norm for behavior refusal: 44.3626
Normalization of processed vectors complete.


In [26]:
import torch
import os

# Assuming sae, caa_vectors, device, and sae_vector_dir are already defined

def analyze_behavior_vector(behavior, caa_vector, scale_factor=4.2813):
    """Rank SAE features by their share of the encoded vector's squared magnitude.

    Uses the notebook-level ``sae`` and ``device`` objects.

    Args:
        behavior: Behavior name; accepted for call-site symmetry but not used here.
        caa_vector: CAA vector tensor - assumed 1-D so that each encoded entry
            is a single feature value (TODO confirm against the saved vectors).
        scale_factor: Multiplier applied to ``caa_vector`` before ``sae.b_dec``
            is added and the sum is encoded. Defaults to the value that was
            previously hard-coded in this function.

    Returns:
        Tuple ``(sorted_features, total_norm)``:
        ``sorted_features`` is a list of ``(feature_index, value, percentage)``
        sorted by ``percentage`` descending, where ``percentage`` is the
        feature's squared value as a share of the sum of squares (0 when that
        sum is 0); ``total_norm`` is the square root of the sum of squares.
    """
    caa_vector = caa_vector.to(device)

    # Inference only; no autograd graph is needed.
    with torch.no_grad():
        # Process the CAA vector through the encoder to get the encoded (SAE basis) vector
        encoded_vector = sae.encode(caa_vector * scale_factor + sae.b_dec)

        # Calculate the total sum of squared values
        total_squared_sum = torch.sum(encoded_vector ** 2).item()

        # One bulk transfer to Python floats instead of two .item() calls per feature.
        feature_values = encoded_vector.tolist()

    # Calculate the contribution of each feature
    feature_contributions = []
    for i, feature_value in enumerate(feature_values):
        feature_squared = feature_value ** 2
        percentage = (feature_squared / total_squared_sum) * 100 if total_squared_sum != 0 else 0
        feature_contributions.append((i, feature_value, percentage))

    # Sort features by their contribution (descending order)
    sorted_features = sorted(feature_contributions, key=lambda x: x[2], reverse=True)

    return sorted_features, torch.sqrt(torch.tensor(total_squared_sum)).item()

# Analyze each behavior
# Maps behavior name -> {"sorted_features": [...], "total_norm": float}, as returned
# by analyze_behavior_vector for every vector in the notebook-level `caa_vectors`.
behavior_analyses = {}

for behavior, caa_vector in caa_vectors.items():
    sorted_features, total_norm = analyze_behavior_vector(behavior, caa_vector)
    behavior_analyses[behavior] = {
        "sorted_features": sorted_features,
        "total_norm": total_norm
    }

# Print results
for behavior, analysis in behavior_analyses.items():
    print(f"\nBehavior: {behavior}")
    print(f"Total norm: {analysis['total_norm']:.4f}")
    print("Top 10 contributing features:")
    # sorted_features is already ordered by percentage (descending), so the first
    # ten entries are the ten largest contributors; numbering starts at 1.
    for i, (feature_index, feature_value, percentage) in enumerate(analysis['sorted_features'][:10], 1):
        print(f"  {i}. Feature {feature_index}: {percentage:.2f}% (value: {feature_value:.4f})")
    
    # Count non-zero features
    # "Non-zero" here means an absolute value above 1e-6.
    non_zero_count = sum(1 for _, value, _ in analysis['sorted_features'] if abs(value) > 1e-6)
    print(f"Number of non-zero features: {non_zero_count}")

# Save the analysis results (this always runs; there is no opt-out flag).
# `sae_vector_dir` is not defined in this cell - it must already exist from an earlier cell.
save_path = os.path.join(sae_vector_dir, "behavior_feature_analysis.pt")
torch.save(behavior_analyses, save_path)
print(f"\nBehavior feature analysis saved to {save_path}")


Behavior: coordinate-other-ais
Total norm: 159.8280
Top 10 contributing features:
  1. Feature 10414: 16.71% (value: 65.3287)
  2. Feature 10781: 2.48% (value: 25.1675)
  3. Feature 715: 2.40% (value: 24.7662)
  4. Feature 11139: 2.17% (value: 23.5419)
  5. Feature 2721: 2.03% (value: 22.7471)
  6. Feature 11113: 1.50% (value: 19.5599)
  7. Feature 4121: 1.30% (value: 18.2298)
  8. Feature 11500: 1.28% (value: 18.1127)
  9. Feature 294: 1.21% (value: 17.5502)
  10. Feature 11967: 1.13% (value: 17.0095)
Number of non-zero features: 437

Behavior: corrigible-neutral-HHH
Total norm: 153.9369
Top 10 contributing features:
  1. Feature 10781: 12.11% (value: 53.5612)
  2. Feature 1582: 5.70% (value: 36.7570)
  3. Feature 10414: 3.64% (value: 29.3652)
  4. Feature 11113: 3.61% (value: 29.2565)
  5. Feature 3253: 2.81% (value: 25.8125)
  6. Feature 715: 2.36% (value: 23.6331)
  7. Feature 4121: 1.96% (value: 21.5418)
  8. Feature 14080: 1.83% (value: 20.8156)
  9. Feature 15113: 1.70% (value:

In [23]:
import requests
import torch
import os
import time
import json

# Neuronpedia API Information
# Read the key from the environment so a real key never has to be pasted into
# the notebook; the "example" placeholder is kept when NEURONPEDIA_API_KEY is unset.
API_KEY = os.getenv("NEURONPEDIA_API_KEY", "example")
# URL template; filled in with str.format(modelId=..., layer=..., index=...).
API_URL = "https://www.neuronpedia.org/api/feature/{modelId}/{layer}/{index}"

def fetch_feature_explanation(model_id, layer, index, max_retries=3, retry_delay=2):
    """Fetch the first explanation text for one feature from the Neuronpedia API.

    Args:
        model_id: Value substituted for ``{modelId}`` in ``API_URL``.
        layer: Value substituted for ``{layer}`` in ``API_URL``.
        index: Value substituted for ``{index}`` in ``API_URL``; also used in
            the fallback message.
        max_retries: Maximum number of GET attempts.
        retry_delay: Seconds to sleep between attempts (previously a
            hard-coded 2).

    Returns:
        The ``description`` of the first entry in the response's
        ``explanations`` list; ``"No description available"`` if that entry has
        no ``description`` key; ``"No explanation available"`` if the list is
        missing or empty (or the payload is not a JSON object); or
        ``"Feature {index}: Unable to retrieve explanation"`` when no attempt
        succeeded.
    """
    url = API_URL.format(modelId=model_id, layer=layer, index=index)
    headers = {
        "X-Api-Key": API_KEY
    }

    for attempt in range(max_retries):
        try:
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
            data = response.json()
        except (requests.exceptions.RequestException, ValueError) as err:
            # ValueError covers a non-JSON body on requests versions whose
            # decode error is not a RequestException subclass.
            print(f"Error occurred (attempt {attempt + 1}/{max_retries}): {err}")

            # Compare against None explicitly: a Response object is falsy
            # whenever its status is an error, so a truthiness test would
            # hide exactly the responses we want to inspect.
            status = getattr(getattr(err, "response", None), "status_code", None)
            if status is not None and 400 <= status < 500 and status not in (408, 429):
                # Client errors such as 401/404 will not change on a retry;
                # 408 (timeout) and 429 (rate limit) are still retried.
                break

            if attempt < max_retries - 1:
                time.sleep(retry_delay)  # Wait before retrying
            continue

        # Extract the description from the explanations array
        explanations = data.get("explanations") if isinstance(data, dict) else None
        if explanations:
            return explanations[0].get("description", "No description available")
        return "No explanation available"

    return f"Feature {index}: Unable to retrieve explanation"

# Rank the SAE features activated by the "refusal" CAA vector by their impact
# on the decoded vector, then fetch a Neuronpedia explanation for each.
model_id = "gemma-2-2b"
layer_id = "14-gemmascope-res-16k"
behavior = "refusal"

# Number of highest-impact features to explain. The comments and messages used
# to say "top 5" while the code kept 50; the count now lives in one place and
# the printed messages report the number actually processed.
TOP_K = 50

# Scale applied to the CAA vector before it is encoded.
# NOTE(review): the origin of 4.2813 is not visible in this cell - confirm.
CAA_SCALE = 4.2813

caa_vector = caa_vectors[behavior].to(device)

# Process the CAA vector through the encoder to get the encoded (SAE basis) vector
encoded_vector = sae.encode(caa_vector * CAA_SCALE + sae.b_dec)

# Find non-zero elements in the encoded vector.
# squeeze(-1) drops only the trailing coordinate axis returned by nonzero; a
# bare squeeze() would also collapse the result to a 0-d tensor when exactly
# one feature is active, which cannot be iterated over below.
# (assumes encoded_vector is 1-D - TODO confirm)
non_zero_indices_encoded = torch.nonzero(encoded_vector).squeeze(-1)
non_zero_activations = encoded_vector[non_zero_indices_encoded]

# Calculate the impact of each non-zero activation on the final decoded vector:
# the norm of that feature's decoder row scaled by its activation.
impacts = []
for index, activation in zip(non_zero_indices_encoded, non_zero_activations):
    contribution_vector = sae.W_dec[index, :] * activation
    impact = torch.norm(contribution_vector).item()
    impacts.append((index.item(), impact))

# Sort features by impact (descending order) and keep the TOP_K largest
impacts_sorted = sorted(impacts, key=lambda x: x[1], reverse=True)[:TOP_K]

# Fetch explanations for the selected features
top_features_with_explanations = []
for index, impact in impacts_sorted:
    explanation = fetch_feature_explanation(model_id, layer_id, index)
    top_features_with_explanations.append((index, impact, explanation))

# Fewer than TOP_K features may be active, so report the real count.
num_reported = len(top_features_with_explanations)

# Print the selected features with explanations
print(f"Behavior: {behavior}")
print(f"Top {num_reported} features by impact (index, impact, explanation):")
for index, impact, explanation in top_features_with_explanations:
    print(f"Index: {index}, Impact: {impact:.4f}")
    print(f"Explanation: {explanation}")
    print()

# Save the feature data to a file. The historical file name (with its "top_5"
# prefix) is kept unchanged so anything that loads this path keeps working.
save_path = os.path.join(sae_vector_dir, "top_5_features_explanations_refusal.pt")
torch.save(top_features_with_explanations, save_path)
print(f"Top {num_reported} feature explanations by impact saved to {save_path}")

Behavior: refusal
Top 5 features by impact (index, impact, explanation):
Index: 15297, Impact: 28.4289
Explanation: topics related to moral or ethical concepts

Index: 1538, Impact: 28.3076
Explanation: references to immune system functions and their metabolic implications

Index: 13748, Impact: 26.6420
Explanation: phrases related to social roles and personal relationships

Index: 1462, Impact: 26.3455
Explanation: references to NULL values and their behavior in programming contexts

Index: 6648, Impact: 26.0690
Explanation:  negative results or outcomes in a scientific context

Index: 9286, Impact: 21.0437
Explanation: comparative phrases related to quantitative evaluations and assessments

Index: 14929, Impact: 18.7217
Explanation: terms related to employment status and contract termination

Index: 1618, Impact: 18.7019
Explanation: negative and false conditions or outcomes

Index: 9868, Impact: 18.2508
Explanation: terms related to cancer treatment strategies and cellular responses

In [None]:
# Sanity check on the loaded SAE: report the second dimension of W_enc and the
# first dimension of W_dec, then show the decoder bias.
encoder_width = sae.W_enc.shape[1]
decoder_width = sae.W_dec.shape[0]
print(f"SAE model width: Encoder width = {encoder_width}, Decoder width = {decoder_width}")
print(sae.b_dec)