In [26]:
import numpy as np
import matplotlib.pyplot as plt
from sae_lens import SAE, HookedSAETransformer
import torch
from torch import nn
from transformers import BertModel, BertTokenizer
from tqdm import tqdm

import json
with open("synth_data_v2.json","r") as f:
    corpse = json.loads(f.read())
trainset = []
for i in corpse:
    try:
        trainset.append(json.loads(i)['text'])
    except:
        continue
print(f"Train set size: {len(trainset)}")

# model
class FASG_Model(nn.Module):
    def __init__(self,device="cuda:0"):
        super(FASG_Model, self).__init__()

        self.device = device
        self.bertTokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bertModel = BertModel.from_pretrained("bert-base-uncased").to(self.device)
        self.linear = torch.nn.Linear(768, 768).to(self.device)
        self.llm = HookedSAETransformer.from_pretrained("gpt2-small", device=self.device)
        self.sae, cfg_dict, sparsity = SAE.from_pretrained(
            release="gpt2-small-res-jb-feature-splitting",  
            sae_id="blocks.8.hook_resid_pre_768", 
            device=self.device,
        )
        self.top_p = 0.95
        self.temperature = 0.5
        
    def bert_tokenize(self,text):
        return self.bertTokenizer(text,padding=True,truncation=True,return_tensors="pt").to(device = self.device)
        
    def forward(self, encoded_input,prompt,return_type = "str"):
        steering_vector = self.bertModel(**encoded_input).pooler_output

        def steering_features(value, hook,steering_vector = steering_vector):
            encoded_activation = self.sae.encode(value)
            steered_vector = steering_vector.unsqueeze(1)*encoded_activation 
            decoded_vector = self.sae.decode(steered_vector)
            return decoded_vector
    
        fwd_hooks=[(
            'blocks.8.hook_resid_pre', 
            steering_features
        )]

        tokenized_prompt = self.llm.to_tokens(prompt)
        with self.llm.hooks(fwd_hooks=fwd_hooks):
            steered_tokens = self.llm.generate(
            tokenized_prompt,
            max_new_tokens=64,
            temperature=self.temperature,
            top_p=self.top_p,
            stop_at_eos = True,
            return_type = return_type,
            verbose = False
        )
        return steered_tokens

    def llm_generate(self,prompt,return_type = "str"):
        tokenized_prompt = self.llm.to_tokens(prompt)
        generated_tokens = self.llm.generate(
            tokenized_prompt,
            max_new_tokens=64,
            temperature=self.temperature,
            top_p=self.top_p,
            stop_at_eos = True,
            return_type = return_type,
            verbose = False
        )
        return generated_tokens
        

    def freeze_llm_and_sae(self):
        for param in self.llm.parameters():
            param.requires_grad = False
        for param in self.sae.parameters():
            param.requires_grad = False



# loss
from sentence_transformers import SentenceTransformer
from sentence_transformers import util as st_util

class FASG_Loss(nn.Module):
    def __init__(self,device="cuda:0"):
        super(FASG_Loss, self).__init__()
        self.sentence_transformer = SentenceTransformer("all-MiniLM-L6-v2")
        self.device = device

    def forward(self, reference_text,steered_text,baseline_text,temp= 1):
        tokenized_inputs = self.sentence_transformer.tokenize(
            [reference_text, steered_text, baseline_text])
        # move all tensors to gpu
        for key in tokenized_inputs.keys():
            tokenized_inputs[key] = tokenized_inputs[key].to(self.device)
        embeddings = self.sentence_transformer(tokenized_inputs)['sentence_embedding']
        embeddings = embeddings * temp
        
        # Compute cosine similarities
        sim_positive = 1-st_util.cos_sim(embeddings[0], embeddings[1])
        sim_negative = 1-st_util.cos_sim(embeddings[0], embeddings[2])

        # Compute softmax triplet loss
        margin = 0.2
        triplet_loss = torch.max(sim_positive-sim_negative+margin,0)[0] # torch mx returns (max, max_indices)
        loss = triplet_loss + sim_positive *0.1 # still preserve a little 

        return loss

Train set size: 20746


In [27]:
model = FASG_Model()
model.load_state_dict(torch.load("model_ckpts/fasg_model_epoch4.pth"))

Loaded pretrained model gpt2-small into HookedTransformer


  model.load_state_dict(torch.load("model_ckpts/fasg_model_epoch4.pth"))


<All keys matched successfully>

In [6]:
with open("synth_data_testset.txt","r") as f:
    testset = [json.loads(i) for i in f.read().split("\n\n") if i]

In [28]:
reference_text = testset[3]['text']
prompt = " ".join(reference_text.split(" ")[:5])

print("Unsteered: \n")
print(model.llm_generate(prompt))
print("Steered: \n")
print(model(model.bert_tokenize(reference_text),prompt))

Unsteered: 

I have a little story for you: I was a child when I was young and my parents didn't have any money to pay my bills. I was a little bit homeless, and I didn't have a home. I was homeless when I was 12, but I was able to get by, and I was able to buy food, clothes
Steered: 

I have a little story this the as well, to to. to to to. in "..... a the, is not., the to to,, the., to, the,, the, to., to to to to to but to and.. to in. the to,...,
