In [69]:
import numpy as np
import matplotlib.pyplot as plt
from sae_lens import SAE, HookedSAETransformer
import torch
from torch import nn
from transformers import BertModel, BertTokenizer
from tqdm import tqdm
import json

with open("synth_data_v2.json","r") as f:
    corpse = json.loads(f.read())
trainset = []
for i in corpse:
    try:
        trainset.append(json.loads(i)['text'])
    except:
        continue
print(f"Train set size: {len(trainset)}")

Train set size: 20746


In [70]:
# model
class FASG_Model(nn.Module):
    def __init__(self,device="cuda:0"):
        super(FASG_Model, self).__init__()

        self.device = device
        self.bertTokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bertModel = BertModel.from_pretrained("bert-base-uncased").to(self.device)
        self.linear = torch.nn.Linear(768, 768).to(self.device)
        self.llm = HookedSAETransformer.from_pretrained("gpt2-small", device=self.device)
        self.sae, cfg_dict, sparsity = SAE.from_pretrained(
            release="gpt2-small-res-jb-feature-splitting",  
            sae_id="blocks.8.hook_resid_pre_768", 
            device=self.device,
        )
        self.top_p = 0.95
        self.temperature = 0.5
        
    def bert_tokenize(self,text):
        return self.bertTokenizer(text,padding=True,truncation=True,return_tensors="pt").to(device = self.device)
        
    def forward(self, encoded_input,prompt,return_type = "str"):
        steering_vector = self.bertModel(**encoded_input).pooler_output

        def steering_features(value, hook,steering_vector = steering_vector,steering_weights = 4):
            #encoded_activation = self.sae.encode(value)
            #steered_vector = steering_vector.unsqueeze(1)*encoded_activation 
            delta_activation = self.sae.decode(steering_vector)

            steered_activation = delta_activation.unsqueeze(1)*steering_weights + value
            
            return steered_activation
    
        fwd_hooks=[(
            'blocks.8.hook_resid_pre', 
            steering_features
        )]

        tokenized_prompt = self.llm.to_tokens(prompt)
        with self.llm.hooks(fwd_hooks=fwd_hooks):
            steered_tokens = self.llm.generate(
            tokenized_prompt,
            max_new_tokens=64,
            temperature=self.temperature,
            top_p=self.top_p,
            stop_at_eos = True,
            return_type = return_type,
            verbose = False
        )
        return steered_tokens

    def llm_generate(self,prompt,return_type = "str"):
        tokenized_prompt = self.llm.to_tokens(prompt)
        generated_tokens = self.llm.generate(
            tokenized_prompt,
            max_new_tokens=64,
            temperature=self.temperature,
            top_p=self.top_p,
            stop_at_eos = True,
            return_type = return_type,
            verbose = False
        )
        return generated_tokens
        

    def freeze_llm_and_sae(self):
        for param in self.llm.parameters():
            param.requires_grad = False
        for param in self.sae.parameters():
            param.requires_grad = False


In [73]:
model = FASG_Model()
sample = trainset[5]

reference_text = sample
prompt = " ".join(sample.split(" ")[:5])
prompt

Loaded pretrained model gpt2-small into HookedTransformer


'Yeah, but most folks think'

In [75]:
model.llm_generate(prompt)

'Yeah, but most folks think the next big thing is the next big thing.\n\nThere are some serious, serious issues with the current political climate.\n\nThe first is that the GOP is basically trying to make it impossible to pass a healthcare bill, which is a pretty big deal. The GOP is trying to make it impossible to pass a'

In [76]:
def steering_features(value, hook,steering_vector = steering_vector):
    out = value + 0.1*steering_vector.unsqueeze(1)
    return out

fwd_hooks=[(
    'blocks.8.hook_resid_pre', 
    steering_features
)]

tokenized_prompt = model.llm.to_tokens(prompt)
with model.llm.hooks(fwd_hooks=fwd_hooks):
    steered = model.llm.generate(
    tokenized_prompt,
    max_new_tokens=64,
    temperature=0.5,
    top_p=0.95,
    stop_at_eos = True,
    return_type = "str",
    verbose = False
)
print(steered)

Yeah, but most folks think it's a "natural" thing to do. It's not.

I'm not saying that you should be "natural" to do something. I'm just saying that it's not a natural thing to do, and it's not a natural thing to do if you're not doing it.

You


In [86]:
model.llm.forward(prompt, loss=True, return_type='both')

TypeError: HookedTransformer.forward() got an unexpected keyword argument 'loss'

# t test|

In [92]:
def get_steered_output(model, prompt, steering_vector):
    steering_vector = steering_vector
    
    def steering_features(value, hook,steering_vector = steering_vector):
        out = value + 0.1*steering_vector.unsqueeze(1)
        return out
    
    fwd_hooks=[(
        'blocks.8.hook_resid_pre', 
        steering_features
    )]
    
    tokenized_prompt = model.llm.to_tokens(prompt)
    with model.llm.hooks(fwd_hooks=fwd_hooks):
        steered = model.llm.generate(
        tokenized_prompt,
        max_new_tokens=64,
        temperature=0.5,
        top_p=0.95,
        stop_at_eos = True,
        return_type = "str",
        verbose = False
    )
    return steered

In [96]:
from sentence_transformers import SentenceTransformer
from sentence_transformers import util as st_util


def get_cos_sim(st_model,text1,text2):
    tokenized_inputs = st_model.tokenize(
                [text1,text2])
    for key in tokenized_inputs.keys():
        tokenized_inputs[key] = tokenized_inputs[key].to("cuda:0")
    embeddings = st_model(tokenized_inputs)['sentence_embedding']
    sim = st_util.cos_sim(embeddings[0], embeddings[1])
    return sim

In [None]:
# metrics
from rouge_score import rouge_scorer
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
rouge_dict = {"rouge1":[],
             "rouge2":[],
             "rougeL":[]}
cos_sim = []

# top 10 features
top_features = [698, 734, 709, 719,  87, 112, 675, 679, 187, 445]
num_steered_features = 5
steering_strength = 1
st_model = SentenceTransformer("all-MiniLM-L6-v2")

for sample in tqdm(trainset[:100]):
    prompt = " ".join(sample.split(" ")[:5])

    unsteered_output = model.llm_generate(prompt)
    steering_vec = np.zeros(768,dtype = np.float32)
    steering_vec[top_features[:num_steered_features]] = steering_strength
    steering_vec = torch.tensor(steering_vec).unsqueeze(0).to("cuda:0")

    delta_activation = model.sae.decode(steering_vec)
    steered_output = get_steered_output(model, prompt,delta_activation)

    # rouge
    rouge_scores = scorer.score(unsteered_output, steered_output)
    for key in rouge_dict.keys():
        rouge_dict[key].append(rouge_scores[key])
    

    # cos sim
    cos_sim.append(get_cos_sim(st_model,unsteered_output,steered_output))