In [1]:
import io
import os
import torch
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from ml_things import plot_dict, plot_confusion_matrix, fix_text
from sklearn.metrics import classification_report, accuracy_score
from transformers import (set_seed,
                          TrainingArguments,
                          Trainer,
                          GPT2Config,
                          GPT2Tokenizer,
                          AdamW, 
                          get_linear_schedule_with_warmup,
                          GPT2ForSequenceClassification)

import torch
import circuitsvis as cv
from transformers import AutoModelForCausalLM, GPT2ForSequenceClassification
import transformer_lens as tl
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

from functools import partial

from torchtyping import TensorType as TT

## R-LACE Basic GPT-2 Experiment Procedure - Simple sentiment version
1. Load HF classifier
2. Identify target region
3. Get activations in target region
4. Load sentiment training data
5. Perform R-LACE on these activations

6. Replace old activations?

## Load Models

In [4]:
# Set seed for reproducibility.
set_seed(123)

# Number of training epochs (authors on fine-tuning Bert recommend between 2 and 4).
epochs = 4

# Number of batches - depending on the max sequence length and GPU memory.
# For 512 sequence length batch of 10 works without cuda memory issues.
# For small sequence length can try batch of 32 or higher.
batch_size = 128

# Pad or truncate text sequences to a specific length
# if `None` it will use maximum sequence of word piece tokens allowed by model.
max_length = 60

# Look for gpu to use. Will use `cpu` by default if no gpu found.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Name of transformers model - will use already pretrained model.
# Path of transformer model - will load your own model from local disk.
model_name_or_path = 'gpt2'

# Dictionary of labels and their id - this will be used to convert.
# String labels to number ids.
labels_ids = {'neg': 0, 'pos': 1}

# How many labels are we using in training.
# This is used to decide size of classification head.
n_labels = len(labels_ids)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
cls_model = GPT2ForSequenceClassification.from_pretrained("curt-tigges/gpt2-imdb-sentiment-classifier")
source_model = AutoModelForCausalLM.from_pretrained("curt-tigges/gpt2-imdb-sentiment-classifier")
hooked_model = HookedTransformer.from_pretrained(model_name="gpt2", hf_model=source_model, fold_ln=False)

Some weights of the model checkpoint at curt-tigges/gpt2-imdb-sentiment-classifier were not used when initializing GPT2LMHeadModel: ['score.weight']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer


In [6]:
# Get model's tokenizer.
print('Loading tokenizer...')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# default to left padding
tokenizer.padding_side = "left"
# Define PAD Token = EOS Token = 50256
tokenizer.pad_token = tokenizer.eos_token

Loading tokenizer...


## Load Dataset
Here we will use the IMDB sentiment dataset.

In [7]:
class MovieReviewsDataset(Dataset):
  r"""PyTorch Dataset class for loading data.

  This is where the data parsing happens.

  This class is built with reusability in mind: it can be used as is as.

  Arguments:

    path (:obj:`str`):
        Path to the data partition.

  """

  def __init__(self, path, use_tokenizer):

    # Check if path exists.
    if not os.path.isdir(path):
      # Raise error if path is invalid.
      raise ValueError('Invalid `path` variable! Needs to be a directory')
    self.texts = []
    self.labels = []
    # Since the labels are defined by folders with data we loop 
    # through each label.
    for label in ['pos', 'neg']:
      sentiment_path = os.path.join(path, label)

      # Get all files from path.
      files_names = os.listdir(sentiment_path)#[:10] # Sample for debugging.
      # Go through each file and read its content.
      for file_name in tqdm(files_names, desc=f'{label} files'):
        file_path = os.path.join(sentiment_path, file_name)

        # Read content.
        content = io.open(file_path, mode='r', encoding='utf-8').read()
        # Fix any unicode issues.
        content = fix_text(content)
        # Save content.
        self.texts.append(content)
        # Save encode labels.
        self.labels.append(label)

    # Number of exmaples.
    self.n_examples = len(self.labels)
    

    return

  def __len__(self):
    r"""When used `len` return the number of examples.

    """
    
    return self.n_examples

  def __getitem__(self, item):
    r"""Given an index return an example from the position.
    
    Arguments:

      item (:obj:`int`):
          Index position to pick an example to return.

    Returns:
      :obj:`Dict[str, str]`: Dictionary of inputs that contain text and 
      asociated labels.

    """

    return {'text':self.texts[item],
            'label':self.labels[item]}


class Gpt2ClassificationCollator(object):
    r"""
    Data Collator used for GPT2 in a classificaiton rask. 
    
    It uses a given tokenizer and label encoder to convert any text and labels to numbers that 
    can go straight into a GPT2 model.

    This class is built with reusability in mind: it can be used as is as long
    as the `dataloader` outputs a batch in dictionary format that can be passed 
    straight into the model - `model(**batch)`.

    Arguments:

      use_tokenizer (:obj:`transformers.tokenization_?`):
          Transformer type tokenizer used to process raw text into numbers.

      labels_ids (:obj:`dict`):
          Dictionary to encode any labels names into numbers. Keys map to 
          labels names and Values map to number associated to those labels.

      max_sequence_len (:obj:`int`, `optional`)
          Value to indicate the maximum desired sequence to truncate or pad text
          sequences. If no value is passed it will used maximum sequence size
          supported by the tokenizer and model.

    """

    def __init__(self, use_tokenizer, labels_encoder, max_sequence_len=None):

        # Tokenizer to be used inside the class.
        self.use_tokenizer = use_tokenizer
        # Check max sequence length.
        self.max_sequence_len = use_tokenizer.model_max_length if max_sequence_len is None else max_sequence_len
        # Label encoder used inside the class.
        self.labels_encoder = labels_encoder

        return

    def __call__(self, sequences):
        r"""
        This function allowes the class objesct to be used as a function call.
        Sine the PyTorch DataLoader needs a collator function, I can use this 
        class as a function.

        Arguments:

          item (:obj:`list`):
              List of texts and labels.

        Returns:
          :obj:`Dict[str, object]`: Dictionary of inputs that feed into the model.
          It holddes the statement `model(**Returned Dictionary)`.
        """

        # Get all texts from sequences list.
        texts = [sequence['text'] for sequence in sequences]
        # Get all labels from sequences list.
        labels = [sequence['label'] for sequence in sequences]
        # Encode all labels using label encoder.
        labels = [self.labels_encoder[label] for label in labels]
        # Call tokenizer on all texts to convert into tensors of numbers with 
        # appropriate padding.
        inputs = self.use_tokenizer(text=texts, return_tensors="pt", padding=True, truncation=True,  max_length=self.max_sequence_len)
        # Update the inputs with the associated encoded labels as tensor.
        inputs.update({'labels':torch.tensor(labels)})

        return inputs

In [8]:
# Create data collator to encode text and labels into numbers.
gpt2_classification_collator = Gpt2ClassificationCollator(use_tokenizer=tokenizer, 
                                                          labels_encoder=labels_ids, 
                                                          max_sequence_len=max_length)


print('Dealing with Train...')
# Create pytorch dataset.
train_dataset = MovieReviewsDataset(path='./data/aclImdb/train', 
                               use_tokenizer=tokenizer)
print('Created `train_dataset` with %d examples!'%len(train_dataset))

# Move pytorch dataset into dataloader.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=gpt2_classification_collator)
print('Created `train_dataloader` with %d batches!'%len(train_dataloader))

print()

print('Dealing with Validation...')
# Create pytorch dataset.
valid_dataset =  MovieReviewsDataset(path='./data/aclImdb/test', 
                               use_tokenizer=tokenizer)
print('Created `valid_dataset` with %d examples!'%len(valid_dataset))

# Move pytorch dataset into dataloader.
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=gpt2_classification_collator)
print('Created `eval_dataloader` with %d batches!'%len(valid_dataloader))

Dealing with Train...


pos files:   0%|          | 0/12500 [00:00<?, ?it/s]

neg files:   0%|          | 0/12500 [00:00<?, ?it/s]

Created `train_dataset` with 25000 examples!
Created `train_dataloader` with 196 batches!

Dealing with Validation...


pos files:   0%|          | 0/12500 [00:00<?, ?it/s]

neg files:   0%|          | 0/12500 [00:00<?, ?it/s]

Created `valid_dataset` with 25000 examples!
Created `eval_dataloader` with 196 batches!


In [9]:
batch = next(iter(train_dataloader))
input, mask, labels = batch['input_ids'], batch['attention_mask'], batch['labels']

In [10]:
labels[0]

tensor(0)

In [11]:
[hooked_model.tokenizer.decode(i) for i in input][:10]

["It seems there's a bit of a curse out there when it comes to gay cinema. Namely, happy endings aren't very common. Beautiful Thing excluded, gay films tend to end in broken relationships or untimely death. And some, like Come Undone, just end... period.<br",
 "I guess if a film has magic, I don't need it to be fluid or seamless. It can skip background information, go too fast in some places, too slow in others, etc. Magic in this film: the scene in the library. There are many minor flaws in Stanley & Iris,",
 "Wow...I don't know what to say. I just watched Seven Pounds. No one can make me cry like Will Smith. The man is very in-tune with the vast range of human emotion. This movie was skillfully and beautifully done. Rare to find such intense humanity in Hollywood",
 "From director Barbet Schroder (Reversal of Fortune), I think I saw a bit of this in my Media Studies class, and I recognised the leading actress, so I tried it, despite the rating by the critics. Basically cool kid Rich

## Encode & Save Representations From Final Layer

In [10]:
import pickle

In [11]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f60fe433d60>

In [13]:
target_activations = 'blocks.10.ln2.hook_normalized'

In [13]:
# Run input through model and save the activations
res, cache = hooked_model.run_with_cache(input)
pre_cls_END_hidden_state = cache[target_activations][:,-1,:].cpu()

In [14]:
cache['blocks.10.ln2.hook_normalized'].shape

torch.Size([256, 60, 768])

In [15]:
data_type = "train"
with open("activations/{}/block_10_ln_output.pickle".format(data_type), "wb") as f:
        pickle.dump(pre_cls_END_hidden_state, f)

## Run R-LACE on Representations
The next step is to perform R-LACE on the hidden states of the classification tokens in order to generate a projection matrix that will remove the gender representation directions. Before doing so, the experimenters run PCA on the hidden dimensions and reduce them.

### Imports

In [10]:

import os
import random
from collections import defaultdict


import torch
from rlace import solve_adv_game

import random
import pickle
import numpy as np

### Setup

In [11]:
data_type = "train"
with open("activations/{}/block_10_ln_output.pickle".format(data_type), "rb") as f:
    pre_cls_END_hidden_state = pickle.load(f)

In [12]:
random.seed(0)
np.random.seed(0)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
#ranks = [4]
ranks = [32, 64,128, 256]
rlace_projs = defaultdict(dict)
inlp_projs = defaultdict(dict)
finetune_mode = "no-adv"

In [13]:
def set_seeds(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

### Load Encoded Classification Tokens

In [14]:
X = pre_cls_END_hidden_state
y = labels

In [15]:
labels

tensor([0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1,
        0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0,
        1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1,
        0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1,
        1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0,
        1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
        1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1,
        0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0,
        1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1])

### R-LACE

In [16]:
# Set up optimizers and result dictionaries
Ps_rlace, accs_rlace = {}, {}

optimizer_class = torch.optim.SGD
optimizer_params_P = {"lr": 0.005, "weight_decay": 1e-4, "momentum": 0.0}
optimizer_params_predictor = {"lr": 0.005, "weight_decay": 1e-5, "momentum": 0.9}

In [18]:
# Run RLACE for each rank
data_type = "train"

for rank in ranks:

    output = solve_adv_game(X, y, X, y, rank=rank, device=device, out_iters=60000,
                                optimizer_class=optimizer_class, optimizer_params_P=optimizer_params_P,
                                optimizer_params_predictor=optimizer_params_predictor, epsilon=0.002,
                                batch_size=256)

    P = output["P"]
    Ps_rlace[rank] = P
    accs_rlace[rank] = output["score"]

    # Save resulting projection matrices
    with open("interim/block10/{}/Ps_rlace.pickle".format(data_type), "wb") as f:
        pickle.dump((Ps_rlace, accs_rlace), f)

59000/60000. Acc post-projection: 47.656%; best so-far: 45.703%; Maj: 0.391%; Gap: 45.312%; best loss: 1.9024; current loss: 1.0508: 100%|##########| 60000/60000 [19:24<00:00, 51.53it/s]
59000/60000. Acc post-projection: 48.047%; best so-far: 45.312%; Maj: 0.391%; Gap: 44.922%; best loss: 1.8692; current loss: 1.2550: 100%|##########| 60000/60000 [19:12<00:00, 52.04it/s]
59000/60000. Acc post-projection: 53.516%; best so-far: 45.703%; Maj: 0.391%; Gap: 45.312%; best loss: 1.3303; current loss: 0.9887: 100%|##########| 60000/60000 [19:00<00:00, 52.61it/s]
59000/60000. Acc post-projection: 51.562%; best so-far: 45.312%; Maj: 0.391%; Gap: 44.922%; best loss: 1.4330; current loss: 1.0126: 100%|##########| 60000/60000 [20:09<00:00, 49.63it/s]


## Apply R-LACE Projection to Model

### Imports

In [12]:
import sys
import os

from sklearn.linear_model import SGDClassifier, LinearRegression, Lasso, Ridge
from sklearn.utils import shuffle
from sklearn.decomposition import PCA
import seaborn as sn
import random
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict
from sklearn.manifold import TSNE
import tqdm
import copy
from sklearn.svm import LinearSVC 

from sklearn.cross_decomposition import PLSRegression
from sklearn.decomposition import TruncatedSVD
import torch
from sklearn.linear_model import SGDClassifier

from sklearn.svm import LinearSVC

import sklearn
from sklearn.linear_model import LogisticRegression
import random
import pickle
import matplotlib.pyplot as plt
from sklearn import cluster
from sklearn import neural_network
from gensim.models.keyedvectors import Word2VecKeyedVectors
from gensim.models import KeyedVectors
import numpy as np
import warnings
import argparse
from sklearn.neural_network import MLPClassifier
from collections import defaultdict
import scipy
from scipy import stats
from scipy.stats import pearsonr
import pandas as pd
from collections import Counter

from torchtyping import TensorType as TT

In [13]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f76cff08df0>

### Sample with RLACE Applied

In [19]:
target_activations = 'blocks.10.ln2.hook_normalized'

In [20]:

def load_projections(proj_type, data_split):
    """Load the projection matrices for the given projection type and finetune mode

    Args:
        proj_type (str): projection type (e.g. "inlp", "rlace")
        data_split (str): dataset on which projections were trained (e.g. "train", "valid")
    """
    with open("interim/block10/{}/Ps_{}.pickle".format(data_split, proj_type), "rb") as f:
        rank2P = pickle.load(f)
        return rank2P
    
Ps_rlace = load_projections("rlace", "train")

In [21]:

def patch_rlace_pos(
    orig_resid_vector: TT["batch", "pos", "d_model"],
    hook,
    pos=-1, 
    projection=None,
):
    orig_resid_vector[:, pos, :] = orig_resid_vector[:, pos, :] @ projection
    return orig_resid_vector

In [22]:
projection = torch.FloatTensor(Ps_rlace[0][64]).to(device)
projection_patching_hook = partial(patch_rlace_pos, pos=-1, projection=projection)
logits = hooked_model.run_with_hooks(
    input, fwd_hooks=[(target_activations, projection_patching_hook)]
)[:,-1,:]
#pre_cls_END_hidden_state = cache["ln_final.hook_normalized"][:,-1,:].cpu()

In [23]:
from typing import Callable, Union, List, Tuple, Dict, Optional, NamedTuple, overload
from typing_extensions import Literal
from jaxtyping import Float, Int
from transformer_lens.past_key_value_caching import (
    HookedTransformerKeyValueCache,
)


@torch.inference_mode()
def generate_with_hooks(
    input: Union[str, Float[torch.Tensor, "batch pos"]] = "",
    max_new_tokens: int = 10,
    stop_at_eos: bool = True,
    eos_token_id: Optional[int] = None,
    do_sample: bool = False,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    num_return_sequences: int = 1,
    use_past_kv_cache: bool = True,
    prepend_bos=True,
    return_type: Optional[str] = "input",
    fwd_hooks: Optional[List[Tuple[str, Callable]]] = None,
) -> Float[torch.Tensor, "batch pos_plus_new_tokens"]:
    """
    Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached.

    To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish (by producing an EOT token), we keep running the model on the entire batch, but throw away the output for a finished sequence and just keep adding EOTs to pad.

    This supports entering a single string, but not a list of strings - if the strings don't tokenize to exactly the same length, this gets messy. If that functionality is needed, convert them to a batch of tokens and input that instead.

    Args:
        input (int): Either a batch of tokens ([batch, pos]) or a text string (this will be converted to a batch of tokens with batch size 1)
        max_new_tokens (int): Maximum number of tokens to generate
        stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token
        eos_token_id (int, *optional*): The token ID to use for end of sentence. If None, use the tokenizer's eos_token_id - required if using stop_at_eos
        do_sample (bool): If True, sample from the model's output distribution. Otherwise, use greedy search (take the max logit each time).
        top_k (int): Number of tokens to sample from. If None, sample from all tokens
        top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0, we take the top tokens with cumulative probability >= top_p
        temperature (float): Temperature for sampling. Higher values will make the model more random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is sampling from a uniform distribution)
        freq_penalty (float): Frequency penalty for sampling - how much to penalise previous tokens. Higher values will make the model more random
        use_past_kv_cache (bool): If True, create and use cache to speed up generation
        prepend_bos (bool): If True, prepend the model's bos_token_id to the input, if it's a string. Irrelevant if input is a tensor.
        return_type (str, *optional*): The type of the output to return - either a string (str), a tensor of tokens (tensor) or whatever the format of the input was (input).
    Returns:
        outputs (torch.Tensor): [batch, pos + max_new_tokens], generated sequence of new tokens - by default returns same type as input
    """
    if type(input) == str:
        # If text, convert to tokens (batch_size=1)
        assert (
            hooked_model.tokenizer is not None
        ), "Must provide a tokenizer if passing a string to the model"
        tokens = hooked_model.to_tokens(input, prepend_bos=prepend_bos)
    else:
        tokens = input

    if return_type == "input":
        if type(input) == str:
            return_type = "str"
        else:
            return_type = "tensor"

    assert isinstance(tokens, torch.Tensor)
    batch_size, ctx_length = tokens.shape
    tokens = tokens.to(hooked_model.cfg.device)
    if use_past_kv_cache:
        past_kv_cache = HookedTransformerKeyValueCache.init_cache(
            hooked_model.cfg, hooked_model.cfg.device, batch_size
        )
    else:
        past_kv_cache = None

    if stop_at_eos and eos_token_id is None:
        assert (
            hooked_model.tokenizer is not None and hooked_model.tokenizer.eos_token_id is not None
        ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"

        eos_token_id = hooked_model.tokenizer.eos_token_id

    # An array to track which sequences in the batch have finished.
    finished_sequences = torch.zeros(
        batch_size, dtype=torch.bool, device=hooked_model.cfg.device
    )

    # Currently nothing in HookedTransformer changes with eval, but this is here in case that changes in the future
    hooked_model.eval()
    for index in tqdm.tqdm(range(max_new_tokens)):
        # While generating, we keep generating logits, throw away all but the final logits, and then use those logits to sample from the distribution
        # We keep adding the sampled tokens to the end of tokens.
        if use_past_kv_cache:
            # We just take the final tokens, as a [batch, 1] tensor
            if index > 0:
                # logits = hooked_model.forward(
                #     tokens[:, -1:],
                #     return_type="logits",
                    
                # )
                logits = hooked_model.run_with_hooks(
                    tokens[:, -1:],
                    return_type="logits", 
                    fwd_hooks=[(target_activations, projection_patching_hook)],
                    past_kv_cache=past_kv_cache,
                )
            else:
                logits = hooked_model.run_with_hooks(
                    tokens,
                    return_type="logits", 
                    fwd_hooks=[(target_activations, projection_patching_hook)],
                    past_kv_cache=past_kv_cache,
                )

        else:
            # We input the entire sequence, as a [batch, pos] tensor, since we aren't using the cache
            logits = hooked_model.run_with_hooks(
                    tokens,
                    return_type="logits", 
                    fwd_hooks=[(target_activations, projection_patching_hook)],
                )
        final_logits = logits[:, -1, :]

        sampled_tokens = utils.sample_logits(
            final_logits,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            freq_penalty=freq_penalty,
            tokens=tokens,
        )

        if stop_at_eos:
            # For all unfinished sequences, add on the next token. If a sequence finished, we throw away the generated token and instead add an EOS token to pad.
            sampled_tokens[finished_sequences] = eos_token_id
            finished_sequences.logical_or_(sampled_tokens == eos_token_id)

        tokens = torch.cat([tokens, sampled_tokens.unsqueeze(-1)], dim=-1)

        if stop_at_eos and finished_sequences.all():
            break

    if return_type == "str":
        if prepend_bos:
            # If we prepended a BOS token, remove it when returning output.
            return hooked_model.tokenizer.decode(tokens[0, 1:])
        else:
            return hooked_model.tokenizer.decode(tokens[0])

    else:
        return tokens

In [24]:
# With RLACE
generate_with_hooks(
    "I thought this movie was", 
    max_new_tokens=100, 
    stop_at_eos=True, 
    temperature=0.5,
    use_past_kv_cache=False)

100%|██████████| 100/100 [00:01<00:00, 88.34it/s]


"I thought this movie was a bit of a disappointment. I was really hoping the director would have done more to make this movie better. I don't think he did too much to make it better. The original film was a bit of a disappointment, but I think it's a good movie nonetheless. I think it's a good movie for a lot of reasons. The original film was a bit of a disappointment, but I think it's a good movie nonetheless. The original film was a bit of a disappointment, but I"

In [25]:
# No RLACE
hooked_model.generate("I thought this movie was", max_new_tokens=100, stop_at_eos=True, temperature=0.5)

  0%|          | 0/100 [00:00<?, ?it/s]

"I thought this movie was going to be a great one. It's a great story about a man who wants to be a scientist, but he's also a man who wants to be a doctor. I loved the way the movie went about this. I loved the way the movie was about the science. It's a great story about a man who wants to be a scientist, but he's also a man who wants to be a doctor. I loved the way the movie went about this.\n\nWhat's the best"

### Evaluate

In [26]:
hook_name = target_activations

In [27]:
projection_patching_hook = partial(patch_rlace_pos, pos=-1, projection=projection)
hooked_model.mod_dict[hook_name].add_hook(projection_patching_hook, dir="fwd")

In [28]:
inf_logits, inf_cache = hooked_model.run_with_cache(input, return_type="logits")

In [29]:
batch_size, sequence_length = input.shape[:2]
sequence_lengths = -1

In [30]:
#pre_cls_hidden_states = (inf_cache[hook_name] / inf_cache["ln_final.hook_scale"]).cpu()
pre_cls_hidden_states = (inf_cache[hook_name]).cpu()
rlace_logits = cls_model.score(pre_cls_hidden_states)
pooled_logits = rlace_logits[range(batch_size), sequence_lengths]

In [31]:
rlace_classification = pooled_logits.softmax(dim=-1).argmax(dim=-1)
rlace_classification

tensor([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1,
        0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0,
        1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0,
        1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
        1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0,
        1, 1, 0, 1, 0, 1, 1, 0])

In [32]:
original_classification = cls_model(input)['logits'].softmax(dim=-1).argmax(dim=-1)
original_classification

tensor([0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1,
        1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
        1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1,
        1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1,
        1, 1, 0, 0, 0, 1, 1, 0])

In [33]:
rlace_accuracy = (rlace_classification == labels).float().mean()
original_accuracy = (original_classification == labels).float().mean()
original_accuracy, rlace_accuracy

(tensor(0.7188), tensor(0.5078))