# Download the reward model (RM)

In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Use the GPU when one is available; the model and all inputs are moved to `device`.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model_id = "reciprocate/dahoas-gptj-rm-static"
# The reward model (RM) is loaded with a sequence-classification head on top of
# the GPT-J backbone; its logits are used as the reward.
rm = AutoModelForSequenceClassification.from_pretrained(model_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Inference only: freeze every weight of the RM.
rm.requires_grad_(False)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [00:22<00:00,  7.48s/it]
Some weights of the model checkpoint at reciprocate/dahoas-gptj-rm-static were not used when initializing GPTJForSequenceClassification: ['transformer.h.17.attn.masked_bias', 'transformer.h.12.attn.bias', 'transformer.h.2.attn.bias', 'transformer.h.24.attn.bias', 'transformer.h.5.attn.masked_bias', 'transformer.h.21.attn.masked_bias', 'transformer.h.0.attn.masked_bias', 'transformer.h.4.attn.masked_bias', 'transformer.h.26.attn.masked_bias', 'transformer.h.16.attn.masked_bias', 'transformer.h.3.attn.bias', 'transformer.h.14.attn.masked_bias', 'transformer.h.24.attn.masked_bias', 'transformer.h.9.attn.bias', 'transformer.h.4.attn.bias', 'transformer.h.0.attn.bias', 'transformer.h.22.attn.masked_bias', 'transformer.h.6.attn.masked_bias', 'transformer.h.27.attn.masked_bias', 'transformer.h.12.attn.masked_bias', 'transformer.h.6.attn.bias', 'transformer.h.5.attn.bias',

In [5]:
# Sanity check: score a single example sentence with the reward model.
# Wrapped in no_grad so this pure-inference call builds no autograd graph
# (without it the returned tensor carries a grad_fn and keeps activations alive).
text = "I like to eat ice cream."
input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
with torch.no_grad():
    reward = rm(input_ids).logits
reward

tensor([[-0.3036]], device='cuda:0', grad_fn=<IndexBackward0>)

# Model Definitions

In [24]:
from torchtyping import TensorType
from torch import nn
class TiedSAE(nn.Module):
    """Sparse autoencoder whose decoder is tied to the row-normalized encoder.

    Encoding computes ``relu(x @ encoder.T + encoder_bias)``; decoding computes
    ``code @ normalize_rows(encoder)``. Inputs may have any number of leading
    batch dimensions; only the last dimension has to match.
    """

    def __init__(self, activation_size, n_dict_components):
        super().__init__()
        # One dictionary atom per row: shape (n_dict_components, activation_size).
        self.encoder = nn.Parameter(torch.empty((n_dict_components, activation_size)))
        nn.init.xavier_uniform_(self.encoder)
        self.encoder_bias = nn.Parameter(torch.zeros((n_dict_components,)))

    def get_learned_dict(self) -> torch.Tensor:
        """Return the encoder rows rescaled to unit L2 norm (the decoder dictionary)."""
        norms = torch.norm(self.encoder, 2, dim=-1)
        # The clamp avoids division by zero for an all-zero row.
        return self.encoder / torch.clamp(norms, 1e-8)[:, None]

    def encode(self, batch: torch.Tensor) -> torch.Tensor:
        """Map ``(..., activation_size)`` activations to non-negative codes ``(..., n_dict_components)``."""
        c = torch.einsum("nd,...d->...n", self.encoder, batch)
        c = c + self.encoder_bias
        return torch.clamp(c, min=0.0)

    def decode(self, code: torch.Tensor) -> torch.Tensor:
        """Map codes ``(..., n_dict_components)`` back to activations ``(..., activation_size)``."""
        learned_dict = self.get_learned_dict()
        return torch.einsum("nd,...n->...d", learned_dict, code)

    def forward(self, batch: torch.Tensor):
        """Return ``(reconstruction, codes)`` for ``batch``."""
        c = self.encode(batch)
        x_hat = self.decode(c)
        return x_hat, c

    def n_dict_components(self):
        """Number of dictionary atoms (rows of the encoder)."""
        # Read the shape directly; normalizing the whole dictionary is unnecessary.
        return self.encoder.shape[0]

# Download the sparse autoencoder (SAE)

In [32]:
# from autoencoders import *
from huggingface_hub import hf_hub_download

layer = 15
rm_sae_repo_id = "Elriggs/dahoas-gptj-rm-sae"
rm_sae_filename = f"dahoas-gptj-rm-static_r4_transformer.h.{layer}.pt"
ae_download_location = hf_hub_download(repo_id=rm_sae_repo_id, filename=rm_sae_filename)
# Name of the RM submodule whose output is traced/edited later (presumably the
# layer this SAE was trained on, matching the checkpoint filename).
output_cache_name = f"transformer.h.{layer}"
# NOTE(review): this checkpoint is a pickled object, so torch.load runs code
# from the downloaded file -- only load repositories you trust.
# map_location lets a checkpoint saved on GPU also load on a CPU-only machine.
autoencoder = torch.load(ae_download_location, map_location=device).to(device)
# The SAE is used for inference only.
for param in autoencoder.parameters():
    param.requires_grad = False

# Download Dataset

In [30]:
from activation_dataset import chunk_and_tokenize
from datasets import load_dataset
from torch.utils.data import DataLoader

# Download the first 300 documents of the corpus, then re-chunk them so that
# every datapoint is `max_seq_length` tokens long.
dataset_name = "stas/openwebtext-10k"
max_seq_length = 32
dataset = load_dataset(dataset_name, split="train[:300]")
dataset, _ = chunk_and_tokenize(dataset, tokenizer, max_length=max_seq_length)
# Total token count of the chunked dataset (rows x tokens per row).
max_tokens = max_seq_length * dataset.num_rows
print(f"Number of tokens: {max_tokens/1e6:.2f}M")

Found cached dataset openwebtext-10k (/root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b)
Loading cached processed dataset at /root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b/cache-f990913222bd2a7b_*_of_00008.arrow


Number of tokens: 0.33M


# Compute the original reward and the counterfactual reward when each SAE feature is ablated

In [51]:
from einops import rearrange
from baukit import Trace
from functools import partial
#TODO: This needs to change to ablate a specific feature. So subtract the feature out of the activation
def replace_with_sae(activations, layer_name, autoencoder):
    return activations
    # Check if tuple ie residual layer output as opposed to e.g. mlp output
    if isinstance(activations, tuple):
        temp_activations = activations[1]
        activations = activations[0]

    b, s, n = activations.shape
    mlp_flattened = rearrange(activations, "b s n -> (b s) n")
    reconstruction_flattened, _ = autoencoder(mlp_flattened)
    reconstruction = rearrange(reconstruction_flattened, "(b s) n -> b s n", b=b, s=s)

    if isinstance(activations, tuple):
        reconstruction = tuple([reconstruction, temp_activations])

    return reconstruction

from tqdm.auto import tqdm
batch_size = 32
# Number of batches to process; the exploratory run stopped after the first
# batch. Set to None to score the whole dataset.
max_batches = 1
num_datapoints = dataset.num_rows
num_features, d_model = autoencoder.encoder.shape
original_reward = torch.zeros(num_datapoints)
ablated_reward = torch.zeros((num_datapoints, num_features))
# TODO: fill with (ablated - original) once every feature is ablated.
diff_reward = torch.zeros((num_datapoints, num_features))
# The edit hook does not change between batches, so build it once.
intervention_function = partial(replace_with_sae, autoencoder=autoencoder)
with torch.no_grad(), dataset.formatted_as("pt"):
    dl = DataLoader(dataset["input_ids"], batch_size=batch_size)
    for i, batch in enumerate(tqdm(dl)):
        rows = slice(i * batch_size, (i + 1) * batch_size)
        batch = batch.to(device)
        # Original reward: a single RM forward pass per batch.
        original_reward[rows] = rm(batch).logits[:, 0].cpu()
        # Reward with the traced layer's output edited by the intervention;
        # currently stored in feature column 0 only.
        with Trace(rm, output_cache_name, edit_output=intervention_function):
            ablated_reward[rows, 0] = rm(batch).logits[:, 0].cpu()
        if max_batches is not None and i + 1 >= max_batches:
            break

  0%|          | 0/320 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 44.48 GiB total capacity; 44.07 GiB already allocated; 63.25 MiB free; 44.22 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
def replace_with_sae(activations, layer_name, autoencoder):
    # Check if tuple ie residual layer output as opposed to e.g. mlp output
    # print(activations)
    if isinstance(activations, tuple):
        temp_activations = activations[0]
        activations2 = activations[-1]
    print(activations2)