# Load Model

In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
# Hugging Face Hub checkpoint, loaded below as a sequence-classification model.
model_name = "reciprocate/dahoas-gptj-rm-static"

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Tokenizer and model come from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = model.to(device)

  from .autonotebook import tqdm as notebook_tqdm
  return self.fget.__get__(instance, owner)()
Loading checkpoint shards: 100%|██████████| 3/3 [00:22<00:00,  7.34s/it]


# Load Dataset
We want the dataset to stay sorted, i.e. not chunked or pre-tokenized; each batch is tokenized on the fly when it is used.

In [None]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

hh = load_dataset("Anthropic/hh-rlhf", split="train")
token_length_cutoff = 870  # 99% of chosen data

# Remove datapoints whose "chosen" or "rejected" transcript is longer than
# token_length_cutoff tokens. Building the mask tokenizes every example, so the
# boolean keep-mask (one entry per row of `hh`) is cached on disk and reused.
# NOTE(review): the cache is not keyed by token_length_cutoff or tokenizer;
# delete the file after changing either.
index_file_name = "rm_save_files/index_small_enough.pt"
dataset_size = hh.num_rows

index_small_enough = None
if os.path.exists(index_file_name):
    cached_mask = torch.load(index_file_name)
    # A mask built for a different dataset size would mis-align the indices
    # passed to hh.select later, so only reuse it when the length matches.
    if cached_mask.shape == (dataset_size,):
        index_small_enough = cached_mask

if index_small_enough is None:
    index_small_enough = torch.ones(dataset_size, dtype=torch.bool)

    for ind, example in enumerate(tqdm(hh)):
        # Convert to tokens and measure both sides of the comparison pair.
        length_chosen = len(tokenizer(example["chosen"])["input_ids"])
        length_rejected = len(tokenizer(example["rejected"])["input_ids"])
        if max(length_chosen, length_rejected) > token_length_cutoff:
            index_small_enough[ind] = False

    # torch.save does not create parent directories.
    os.makedirs(os.path.dirname(index_file_name), exist_ok=True)
    # Save to the same path that is checked above.
    torch.save(index_small_enough, index_file_name)

In [None]:
# Keep only the rows that passed the length filter. nonzero() returns an
# (N, 1) index tensor; flatten it to a list of Python ints, which
# Dataset.select accepts directly.
hh = hh.select(index_small_enough.nonzero().squeeze(1).tolist())
batch_size = 16
# shuffle=False keeps the examples in dataset order.
hh_dl = DataLoader(hh, batch_size=batch_size, shuffle=False)

# Initialize Supervised-SAE

In [None]:
# define an SAE for ablation
"""
Defines the dictionary classes
"""

from abc import ABC, abstractmethod
import torch as t
import torch.nn as nn


class AutoEncoder(nn.Module):
    """
    A one-layer autoencoder.

    Maps activations of size ``activation_dim`` to ``dict_size`` non-negative
    feature activations (ReLU of a biased linear map) and back again with a
    bias-free linear decoder plus a learned output bias.

    Attributes:
        activation_dim: size of the input/output activation vectors.
        dict_size: number of dictionary features.
        bias: learned vector added to the decoder output.
        encoder: nn.Linear(activation_dim -> dict_size), with bias.
        decoder: nn.Linear(dict_size -> activation_dim), without bias.
    """

    def __init__(self, activation_dim, dict_size):
        super().__init__()

        self.activation_dim = activation_dim
        self.dict_size = dict_size
        self.bias = nn.Parameter(t.zeros(activation_dim))
        self.encoder = nn.Linear(activation_dim, dict_size, bias=True)

        # decoder.weight has shape (activation_dim, dict_size), so each COLUMN
        # is one dictionary feature's direction. Normalising over dim=0
        # initialises every such column to unit norm.
        self.decoder = nn.Linear(dict_size, activation_dim, bias=False)
        dec_weight = t.randn_like(self.decoder.weight)
        dec_weight = dec_weight / dec_weight.norm(dim=0, keepdim=True)
        self.decoder.weight = nn.Parameter(dec_weight)

    def encode(self, x):
        """
        Return the non-negative feature activations ReLU(encoder(x)).

        Note: ``self.bias`` is only added in ``decode``; it is not subtracted
        from ``x`` before encoding.
        """
        return t.relu(self.encoder(x))

    def decode(self, f):
        """Reconstruct activations from feature activations ``f``."""
        return self.decoder(f) + self.bias

    def forward(self, x):
        """
        Forward pass of an autoencoder.

        x : activations to be autoencoded (last dimension == activation_dim)
        Returns the reconstruction x_hat, same trailing dimension as x.
        """
        f = self.encode(x)
        x_hat = self.decode(f)
        return x_hat

    @staticmethod
    def from_pretrained(path, device=None):
        """
        Load a pretrained autoencoder from a saved state_dict file.

        The sizes are inferred from ``encoder.weight``, whose shape is
        ``(dict_size, activation_dim)``.

        Args:
            path: file written by ``torch.save(autoencoder.state_dict(), path)``.
            device: optional target device. Tensors are mapped onto it while
                loading and the module is moved there.

        Returns:
            The reconstructed AutoEncoder.
        """
        # map_location lets a checkpoint saved on one device (e.g. GPU) load
        # on a host without that device; None keeps torch.load's default.
        state_dict = t.load(path, map_location=device)
        dict_size, activation_dim = state_dict['encoder.weight'].shape
        autoencoder = AutoEncoder(activation_dim, dict_size)
        autoencoder.load_state_dict(state_dict)
        if device is not None:
            autoencoder.to(device)
        return autoencoder

# Train Model

In [None]:
# padding="longest" raises if the tokenizer has no pad token (common for
# GPT-style tokenizers). Fall back to EOS only in that case, so tokenizers that
# already define a pad token are left untouched.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

for batch_ind, batch in enumerate(tqdm(hh_dl)):
    # Tokenize this batch's "chosen" transcripts, padded to the batch's longest.
    batch = tokenizer(
        batch['chosen'],
        padding="longest",
        truncation=True,
        return_attention_mask=False,
        return_tensors="pt",
    )
    batch = batch["input_ids"].to(device)
    # TODO: training / feature step not implemented yet.
    # NOTE(review): no attention mask is kept, so pad positions cannot be
    # distinguished from real tokens downstream — confirm this is intended.
    

# Feature Search: Attribution Patching (AP) w/ Zero-Ablation