In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sae_lens import SAE, HookedSAETransformer
import torch
from torch import nn
from transformers import BertModel, BertTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class FASG_Model(nn.Module):
    def __init__(self,device="cuda:0"):
        super(FASG_Model, self).__init__()

        self.device = device
        self.bertTokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bertModel = BertModel.from_pretrained("bert-base-uncased").to(self.device)
        self.linear = torch.nn.Linear(768, 768).to(self.device)
        self.llm = HookedSAETransformer.from_pretrained("gpt2-small", device=self.device)
        self.sae, cfg_dict, sparsity = SAE.from_pretrained(
            release="gpt2-small-res-jb-feature-splitting",  
            sae_id="blocks.8.hook_resid_pre_768", 
            device=self.device,
        )
        
    def bert_tokenize(self,text):
        return self.bertTokenizer(text,padding=True,truncation=True,return_tensors="pt").to(device = self.device)
        
    def forward(self, encoded_input,prompt,return_type = "str"):
        steering_vector = self.bertModel(**encoded_input).pooler_output

        def steering_features(value, hook,steering_vector = steering_vector):
            encoded_activation = self.sae.encode(value)
            steered_vector = steering_vector.unsqueeze(1)*encoded_activation 
            decoded_vector = self.sae.decode(steered_vector)
            return decoded_vector
    
        fwd_hooks=[(
            'blocks.8.hook_resid_pre', 
            steering_features
        )]

        tokenized_prompt = self.llm.to_tokens(prompt)
        with self.llm.hooks(fwd_hooks=fwd_hooks):
            steered_tokens = self.llm.generate(
            tokenized_prompt,
            max_new_tokens=64,
            temperature=0.2,
            top_p=0.9,
            stop_at_eos = True,
            return_type = return_type,
            verbose = False
        )
        return steered_tokens

    def llm_generate(self,prompt,return_type = "str"):
        tokenized_prompt = self.llm.to_tokens(prompt)
        generated_tokens = self.llm.generate(
            tokenized_prompt,
            max_new_tokens=64,
            temperature=0.2,
            top_p=0.9,
            stop_at_eos = True,
            return_type = return_type,
            verbose = False
        )
        return generated_tokens
        

    def freeze_llm_and_sae(self):
        for param in self.llm.parameters():
            param.requires_grad = False
        for param in self.sae.parameters():
            param.requires_grad = False
        

In [4]:
import json
with open("synth_data_v2.json","r") as f:
    corpse = json.loads(f.read())
trainset = []
for i in corpse:
    try:
        trainset.append(json.loads(i)['text'])
    except:
        continue
print(len(trainset))

20746


In [15]:
from sentence_transformers import SentenceTransformer
from sentence_transformers import util as st_util

class FASG_Loss(nn.Module):
    def __init__(self,device="cuda:0"):
        super(FASG_Loss, self).__init__()
        self.sentence_transformer = SentenceTransformer("all-MiniLM-L6-v2")
        self.device = device

    def forward(self, reference_text,steered_text,baseline_text,temp= 1):
        tokenized_inputs = self.sentence_transformer.tokenize(
            [reference_text, steered_text, baseline_text])
        # move all tensors to gpu
        for key in tokenized_inputs.keys():
            tokenized_inputs[key] = tokenized_inputs[key].to(self.device)
        embeddings = self.sentence_transformer(tokenized_inputs)['sentence_embedding']
        embeddings = embeddings * temp
        
        # Compute cosine similarities
        sim_1 = st_util.cos_sim(embeddings[0], embeddings[1])
        sim_2 = st_util.cos_sim(embeddings[0], embeddings[2])

        # Compute softmax triplet loss
        loss = -torch.log(torch.exp(sim_1) / (torch.exp(sim_1) + torch.exp(sim_2)))

        return loss

In [18]:
from torch.optim import AdamW

model = FASG_Model()
model.freeze_llm_and_sae()
model.train()
soft_triplet_loss =  FASG_Loss()
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

total_loss = 0

batch_size =16

for index in range(0,len(trainset),batch_size):
    text = trainset[index:index+16]
    encoded_input = model.bert_tokenize(text)
    prompts = [" ".join(i.split(" ")[:5]) for i in text]
    steered_text = model(encoded_input,prompts)
    baseline_text = model.llm_generate(prompts)

    max_len = np.min([len(text),len(steered_text),len(baseline_text)])
    # Compute loss
    loss = soft_triplet_loss(text[:max_len],steered_text[:max_len],baseline_text[:max_len])

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    total_loss += loss.item()
    print(loss)


Loaded pretrained model gpt2-small into HookedTransformer
tensor([[0.6931]], device='cuda:0', grad_fn=<NegBackward0>)
tensor([[0.6931]], device='cuda:0', grad_fn=<NegBackward0>)


KeyboardInterrupt: 

In [261]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

In [260]:
torch.cuda.empty_cache() 