In [1]:
### Installs/imports
#!pip install torch transformers datasets tabulate scikit-learn seaborn accelerate bitsandbytes
from initialize import *
from enhanced_hooking import get_activations, add_activations_and_generate, clear_hooks, get_activations_and_generate, zeroout_projections_and_generate
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import random
from collections import defaultdict
from enum import Enum
class SteeringType(Enum):
    """Enumerates two steering modes, IN_PROMPT and CONTINUOUS, each carrying a string value.

    NOTE(review): not referenced in the visible cells; the exact semantics of
    each mode are presumably defined by the hooking helpers - confirm.
    """
    IN_PROMPT = "In prompt"
    CONTINUOUS = "Continuous"
class AggType(Enum):
    """Enumerates two aggregation types, MEANDIFF and PCA, each carrying a string value.

    NOTE(review): not referenced in the visible cells; presumably selects how
    activations are aggregated into a direction - confirm.
    """
    MEANDIFF = "MeanDiff"
    PCA = "PCA"

### Load the model
# NOTE(review): gc, torch, AutoModelForCausalLM, AutoTokenizer and HF_TOKEN are not
# imported in this cell; presumably they come from `from initialize import *` - confirm.

# Release unreferenced Python objects and cached CUDA memory before loading the weights.
gc.collect()
torch.cuda.empty_cache()
base_model_path: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# The model weights and the tokenizer are both loaded from the same checkpoint path.
model_path=base_model_path

from transformers import BitsAndBytesConfig
# Request 8-bit quantized loading.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
# Turn off gradient tracking globally; the visible cells only run forward passes.
_ = torch.set_grad_enabled(False)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, token=HF_TOKEN, quantization_config=bnb_config, device_map="auto")
device = model.device
tokenizer = AutoTokenizer.from_pretrained(base_model_path, token=HF_TOKEN)
# Attach the tokenizer to the model object; later cells access it as `model.tokenizer`.
model.tokenizer = tokenizer
# If the tokenizer defines no pad token, reuse its EOS token as the pad token and
# propagate the resulting pad id to the model config.
if model.tokenizer.pad_token is None:
    new_pad_token = model.tokenizer.eos_token
    # NOTE(review): num_added_tokens is not read again in the visible cells.
    num_added_tokens = model.tokenizer.add_special_tokens({'pad_token': new_pad_token})
    model.resize_token_embeddings(len(model.tokenizer))
    model.config.pad_token_id = model.tokenizer.pad_token_id
# NOTE(review): model_numlayers is not read again in the visible cells (the layer
# list used for activation capture is hard-coded instead).
model_numlayers = model.config.num_hidden_layers
%load_ext autoreload
%autoreload 2

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
#load data
from data import *

# Load the model-written and human-written summaries plus the source articles,
# then flatten them into a single list of prompt strings.
sad_responses, _ = load_data_sad(ddir = "completions_full")
sad_articles = load_from_json("starts_full/starts_train.json")


def _normalize_text(text):
    """Replace each double newline with a single newline (one pass) and strip surrounding whitespace."""
    return text.replace("\n\n","\n").strip()


# Index the articles by id once instead of rescanning the whole list for every key.
# setdefault keeps the FIRST record seen for an id, matching the first-match
# behaviour of a linear scan.
article_by_id = {}
for article_record in sad_articles:
    article_by_id.setdefault(article_record['id'], article_record)

prompts = []
for key in sad_responses['llama3_8bchat'].keys():
    self_summary = _normalize_text(sad_responses['llama3_8bchat'][key])
    other_summary = _normalize_text(sad_responses['human'][key])
    # A missing id raises KeyError naming the offending key.
    article = _normalize_text(article_by_id[key]['text'])

    # Order per key is preserved: model summary, human summary, article.
    prompts.extend((self_summary, other_summary, article))



In [4]:
### Run prompts through model and capture activations, averaging over positions of interest
#
# For every sampled prompt, activations are requested at every token position of
# every listed layer, averaged over positions, and collected per layer. The result is
# accumulated_activations[layer][0]: one row of mean activations per prompt.

model.tokenizer.padding_side = "right"
# NOTE(review): range(0, 31) yields layers 0..30; if the model has 32 hidden layers
# (see model_numlayers) the final layer is not captured - confirm this is intended.
layers = ['embed'] + list(range(0,31))
get_at='end'
clear_hooks(model)

# Draw the sample from a dedicated, seeded generator so the same prompts are used on
# every run and the global `random` state is left untouched. The sample size is capped
# at the number of available prompts so a small dataset does not raise ValueError.
sample_seed = 0
num_samples = min(3000, len(prompts))
prompts_samp = random.Random(sample_seed).sample(prompts, num_samples)

batch_size = 1

# Per-layer list of (rows_in_batch, hidden) tensors; concatenated once after the loop
# so the growing result is not re-copied on every iteration.
layer_chunks = defaultdict(list)

for i in tqdm(range(0, len(prompts_samp), batch_size)):
    batch = prompts_samp[i:i+batch_size]
    encoded_pos = model.tokenizer(batch, return_tensors="pt", padding=True)
    batch_tokens = encoded_pos['input_ids']

    # Request every token position of every sequence, for every layer.
    # NOTE(review): with batch_size > 1, padding=True means shorter sequences include
    # pad positions in the mean below; harmless at batch_size == 1 - confirm before raising it.
    all_positions = [list(range(0, len(input_ids))) for input_ids in batch_tokens]
    layers_positions = {layer: list(all_positions) for layer in layers}

    activations = get_activations(model, batch_tokens, layers_positions, get_at=get_at)
    mean_activations = {}

    for layer, positions in activations.items():
        # Size the accumulator from any one position tensor. A separate name is used
        # for the row count so the outer `batch_size` that drives slicing is not overwritten.
        first_tensor = next(iter(positions.values()))
        rows_in_batch = first_tensor.shape[0]
        layer_sum = torch.zeros(rows_in_batch, first_tensor.shape[-1])

        for tensor in positions.values():
            layer_sum += tensor

        mean_activations[layer] = layer_sum / len(positions)

    for layer, tensor in mean_activations.items():
        layer_chunks[layer].append(tensor)

    del activations, mean_activations
    torch.cuda.empty_cache()

# Same container shape as before: accumulated_activations[layer][0] is a single
# (num_prompts, hidden) tensor. Chunks are popped as they are merged to free memory.
accumulated_activations = defaultdict(lambda: defaultdict(lambda: torch.empty(0)))
for layer in list(layer_chunks.keys()):
    accumulated_activations[layer][0] = torch.cat(layer_chunks.pop(layer), dim=0)

100%|██████████| 3000/3000 [28:34<00:00,  1.75it/s]


In [12]:
# Fit PCA on the per-prompt mean activations captured at the 'embed' layer.
X=accumulated_activations['embed'][0]

num_pcs = 400 #100
# Fix random_state: depending on the input size PCA's automatic solver choice may use
# randomized SVD, in which case an unseeded fit gives slightly different components
# on every run.
pca = PCA(n_components=num_pcs, random_state=0)
pca.fit(X)
principal_components = pca.components_  # shape: (num_pcs, d_embed)

# Define a function to orthogonalize a vector against a set of directions.
def orthogonalize(vec, directions):
    """Strip from `vec` its component along each direction in turn, then unit-normalize.

    Parameters
    ----------
    vec : array of shape (d_embed,)
        Vector to clean. It is not modified; a new array is produced whenever
        at least one direction is processed or the result is rescaled.
    directions : iterable of arrays of shape (d_embed,)
        Directions to remove, processed sequentially. Each projection is
        computed as dot(vec, d) * d, i.e. every direction is treated as
        unit-length.

    Returns
    -------
    array of shape (d_embed,)
        The residual scaled to unit L2 norm. A residual of norm 0 is returned
        unscaled.
    """
    residual = vec
    for direction in directions:
        residual = residual - np.dot(residual, direction) * direction
    length = np.linalg.norm(residual)
    return residual / length if length > 0 else residual

# Create tag vectors.
# Draw two random vectors and orthogonalize them against the principal components.
# A dedicated seeded generator is used so the vectors (which are pickled afterwards)
# are the same on every run.
d_embed = X.shape[1]
tag_rng = np.random.default_rng(0)
tag_vec_user = tag_rng.standard_normal(d_embed)
tag_vec_assistant = tag_rng.standard_normal(d_embed)

tag_vec_user = orthogonalize(tag_vec_user, principal_components)
tag_vec_assistant = orthogonalize(tag_vec_assistant, principal_components)

# Ensure tag_vec_assistant is also orthogonal to tag_vec_user. orthogonalize() removes
# the component along the (unit-norm) user vector and re-normalizes, which is exactly
# the projection-subtraction followed by a guarded division by the norm.
tag_vec_assistant = orthogonalize(tag_vec_assistant, [tag_vec_user])


In [13]:
# Total fraction of variance explained by the retained principal components.
pca.explained_variance_ratio_.sum()

np.float64(0.8157915708197409)

In [14]:
import pickle
# Pickle each embed-layer tag vector as a torch tensor: assistant first, then user.
for ofname, _vec_to_save in (
    ('steering_vectors_orth_asst_embed.pkl', tag_vec_assistant),
    ('steering_vectors_orth_user_embed.pkl', tag_vec_user),
):
    with open(ofname, 'wb') as f:
        pickle.dump(torch.tensor(_vec_to_save), f)

In [12]:
import pickle
import numpy as np
import torch
from sklearn.decomposition import PCA

layerwise_vectors = {}


# Defined once, outside the loop, rather than being re-created for every layer.
def orthogonalize(vec, directions):
    """Remove from `vec` its component along each direction in turn, then unit-normalize.

    Each projection is computed as dot(vec, d) * d, so every entry of
    `directions` is treated as unit-length. The input array is not modified.
    A residual of norm 0 is returned unscaled.
    """
    for d in directions:
        proj = np.dot(vec, d) * d
        vec = vec - proj
    norm = np.linalg.norm(vec)
    if norm > 0:
        vec = vec / norm
    return vec


def _num_pcs_for_layer(layer):
    """Return how many principal components to remove for `layer`.

    'embed' uses 400, layer 0 uses 100, and any other layer n uses 4 + 4*n.
    """
    if layer == 'embed':
        return 400
    layer_idx = int(layer)
    if layer_idx == 0:
        return 100
    return 4 + 4 * layer_idx


# One seeded generator shared across layers: results are repeatable run-to-run while
# each layer still receives its own distinct random draws.
tag_rng = np.random.default_rng(0)

for layer in accumulated_activations.keys():
    X = accumulated_activations[layer][0]

    num_pcs = _num_pcs_for_layer(layer)
    # Fix random_state: depending on the input size PCA's automatic solver choice may
    # use randomized SVD, which would otherwise vary between runs.
    pca = PCA(n_components=num_pcs, random_state=0)
    pca.fit(X)
    principal_components = pca.components_  # shape: (num_pcs, d_embed)

    # Create tag vectors: random draws made orthogonal to the principal components.
    d_embed = X.shape[1]
    tag_vec_user = tag_rng.standard_normal(d_embed)
    tag_vec_assistant = tag_rng.standard_normal(d_embed)

    tag_vec_user = orthogonalize(tag_vec_user, principal_components)
    tag_vec_assistant = orthogonalize(tag_vec_assistant, principal_components)

    # Ensure tag_vec_assistant is also orthogonal to tag_vec_user (unit-norm), then
    # re-normalize; same arithmetic as subtracting the projection by hand.
    tag_vec_assistant = orthogonalize(tag_vec_assistant, [tag_vec_user])

    layerwise_vectors[layer] = {
        'user': torch.tensor(tag_vec_user),
        'assistant': torch.tensor(tag_vec_assistant)
    }
    print(f"layer={layer}, num_pcs={num_pcs}, var={np.sum(pca.explained_variance_ratio_)}")

# Save layerwise vectors to files: one {layer: tensor} dict per role.
with open('steering_vectors_orth_user_all.pkl', 'wb') as f:
    pickle.dump({layer: layerwise_vectors[layer]['user'] for layer in layerwise_vectors}, f)
with open('steering_vectors_orth_asst_all.pkl', 'wb') as f:
    pickle.dump({layer: layerwise_vectors[layer]['assistant'] for layer in layerwise_vectors}, f)


layer=embed, num_pcs=400, var=0.8157953273569934
layer=0, num_pcs=100, var=0.8463752253998517
layer=1, num_pcs=8, var=0.9987105877217431
layer=2, num_pcs=12, var=0.9971450308642328
layer=3, num_pcs=16, var=0.9920824234842771
layer=4, num_pcs=20, var=0.9847806068412651
layer=5, num_pcs=24, var=0.978013067226529
layer=6, num_pcs=28, var=0.9712825342572504
layer=7, num_pcs=32, var=0.9644278916486126
layer=8, num_pcs=36, var=0.9568542732076477
layer=9, num_pcs=40, var=0.949198062014635
layer=10, num_pcs=44, var=0.9486434641403036
layer=11, num_pcs=48, var=0.9506181545360557
layer=12, num_pcs=52, var=0.947403507465094
layer=13, num_pcs=56, var=0.9314517096345616
layer=14, num_pcs=60, var=0.9252147199796454
layer=15, num_pcs=64, var=0.8983186799583137
layer=16, num_pcs=68, var=0.8815866028851663
layer=17, num_pcs=72, var=0.8552371468024186
layer=18, num_pcs=76, var=0.8433537917738927
layer=19, num_pcs=80, var=0.8360048791918033
layer=20, num_pcs=84, var=0.8253638284349869
layer=21, num_pcs=8