Check whether concept erasure (LEACE) can remove the A/B answer-token bias from the steering vector.

In [1]:
from repepo.core.types import Example, Completion
from repepo.core.format import LlamaChatFormatter
from repepo.core.pipeline import Pipeline
from steering_vectors import train_steering_vector
from repepo.steering.utils.helpers import get_model_and_tokenizer
from repepo.steering.build_steering_training_data import (
    build_steering_vector_training_data
)

# Decoder layer whose activations the steering vector is read from.
layer = 13

# Toy dataset whose only contrast is the answer letter: "(A)" vs "(B)".
train_dataset = [
    Example(
        positive=Completion(prompt="(", response="A)"),
        negative=Completion(prompt="(", response="B)"),
    )
]

# Load the 8-bit Llama-2 chat model and wrap it with the Llama chat prompt format.
model, tokenizer = get_model_and_tokenizer(
    "meta-llama/Llama-2-7b-chat-hf", load_in_8bit=True
)
formatter = LlamaChatFormatter()
pipeline = Pipeline(model, tokenizer, formatter)

# Turn the examples into steering-vector training samples via the pipeline's formatting.
steering_vector_training_data = build_steering_vector_training_data(pipeline, train_dataset)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Train the A/B token-bias steering vector on the toy dataset at `layer`.
# NOTE(review): this relies on train_steering_vector's default read_token_index
# (presumably -1), whereas the erasure experiment below reads at -2 — confirm
# which token position is intended before comparing vectors.
ab_steering_vector = train_steering_vector(
    model, tokenizer, steering_vector_training_data,
    layers=[layer], show_progress=True,
)

Training steering vector: 100%|██████████| 1/1 [00:00<00:00,  3.02it/s]


In [5]:
# Show the trained A/B steering vector (its per-layer activation tensors).
ab_steering_vector

SteeringVector(layer_activations={13: tensor([-0.0631, -0.2588,  0.1665,  ..., -0.3843,  0.0649, -0.2756],
       device='cuda:0', dtype=torch.float16)}, layer_type='decoder_block')

In [1]:
import pandas as pd

# Baseline steerability results computed without erasure. Despite the .csv
# suffix, the file is tab-separated and backslash-escaped.
results_without_erasure_df = pd.read_csv(
    "../paper/steerability_id_final.csv",
    sep="\t",
    escapechar="\\",
    index_col=0,
)
results_without_erasure_df.head()

Unnamed: 0,pos_prob,logit_diff,test_example.positive.text,test_example.negative.text,test_example.idx,multiplier,dataset_name,steering_label,dataset_label,slope,residual
0,0.005577,-5.59375,,,0,-1.5,politically-conservative,baseline,baseline,0.816964,0.000558
49,0.020646,-6.46875,,,1,-1.5,politically-conservative,baseline,baseline,0.988839,0.551374
98,0.976483,5.5,,,2,-1.5,politically-conservative,baseline,baseline,-0.541295,0.132804
147,0.006772,-5.617188,,,3,-1.5,politically-conservative,baseline,baseline,0.953683,0.001672
196,0.110541,-2.265625,,,4,-1.5,politically-conservative,baseline,baseline,0.758929,0.016462


# Calculate new steering results with erasure

In [None]:
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Sequence

import torch
from torch import Tensor, nn
from transformers import PreTrainedTokenizerBase

from steering_vectors.aggregators import Aggregator, mean_aggregator
from steering_vectors.token_utils import adjust_read_indices_for_padding, fix_pad_token
from steering_vectors.utils import batchify

from steering_vectors.layer_matching import LayerType, ModelLayerConfig, guess_and_enhance_layer_config
from steering_vectors.record_activations import record_activations
from steering_vectors.steering_vector import SteeringVector
from steering_vectors.train_steering_vector import SteeringVectorTrainingSample, extract_activations, aggregate_activations
from concept_erasure import LeaceEraser

def _positive_text(sample: SteeringVectorTrainingSample | tuple[str, str]) -> str:
    """Return the positive prompt string of a training sample in either accepted form."""
    if isinstance(sample, tuple):
        # Plain (positive_str, negative_str) tuple (also covers NamedTuple samples).
        return sample[0]
    # NOTE(review): assumes the sample object exposes `positive_str`, as
    # steering_vectors.SteeringVectorTrainingSample does — confirm for the installed version.
    return sample.positive_str


def _positive_answers_a(
    training_samples: Sequence[SteeringVectorTrainingSample | tuple[str, str]],
) -> list[int]:
    """Label each sample 1 if its positive completion ends with "A)", else 0."""
    return [int(_positive_text(sample).rstrip().endswith("A)")) for sample in training_samples]


def _erase_ab_concept(
    pos_act: list[Tensor],
    neg_act: list[Tensor],
    pos_is_a: list[int],
) -> tuple[list[Tensor], list[Tensor]]:
    """
    Fit a LEACE eraser for the A/B concept on one layer and return erased activations.

    Positive and negative activations are pooled into one matrix. Each row is labelled
    with whether its completion answers "A". The returned lists have the same lengths,
    shapes, and dtypes as the inputs.

    Raises:
        ValueError: if the activation lists and labels differ in length.
    """
    if not (len(pos_act) == len(neg_act) == len(pos_is_a)):
        raise ValueError(
            f"Mismatched lengths: {len(pos_act)} positive activations, "
            f"{len(neg_act)} negative activations, {len(pos_is_a)} labels"
        )
    num_pos = len(pos_act)
    acts = torch.cat([torch.stack(pos_act, dim=0), torch.stack(neg_act, dim=0)], dim=0)

    pos_labels = torch.tensor(pos_is_a, dtype=torch.float32, device=acts.device)
    # Assumes each negative completion is the other A/B option, so its
    # "answers A" label is the complement of the positive sample's label.
    concept = torch.cat([pos_labels, 1 - pos_labels], dim=0)

    # LEACE works on an (n_samples, n_features) matrix. Flatten any trailing dims,
    # and fit in float32 rather than the (possibly float16) activation dtype.
    flat = acts.reshape(acts.shape[0], -1).float()
    eraser = LeaceEraser.fit(flat, concept)
    erased = eraser(flat).to(acts.dtype).reshape(acts.shape)

    return list(erased[:num_pos].unbind(0)), list(erased[num_pos:].unbind(0))


@torch.no_grad()
def train_steering_vector_with_erasure(
    model: nn.Module,
    tokenizer: PreTrainedTokenizerBase,
    training_samples: Sequence[SteeringVectorTrainingSample | tuple[str, str]],
    layers: list[int] | None = None,
    layer_type: LayerType = "decoder_block",
    layer_config: ModelLayerConfig | None = None,
    move_to_cpu: bool = False,
    read_token_index: int | Callable[[str], int] = -1,
    show_progress: bool = False,
    aggregator: Aggregator = mean_aggregator(),
    batch_size: int = 1,
    tqdm_desc: str = "Training steering vector",
) -> SteeringVector:
    """
    Train a steering vector after erasing the A/B answer-token concept with LEACE.

    Activations are extracted exactly as in `train_steering_vector`. For every layer,
    a LEACE eraser is then fitted on the pooled positive and negative activations,
    using "this completion answers A" as the concept. The eraser is applied to both
    sets, and the erased activations are aggregated into the steering vector.

    Args:
        model: The model to train the steering vector for.
        tokenizer: The tokenizer to use.
        training_samples: Training samples, each either a SteeringVectorTrainingSample
            or a (positive_str, negative_str) tuple. The steering vector approximates
            the difference between the positive and negative prompt activations.
        layers: Layer numbers to train the steering vector on. If None, train on all layers.
        layer_type: The type of layer to train the steering vector on. Default "decoder_block".
        layer_config: A dictionary mapping layer types to layer matching functions.
            If not provided, this will be inferred automatically.
        move_to_cpu: If True, move the activations to the CPU before training. Default False.
        read_token_index: The index of the token to read activations from. Default -1
            (final token).
        show_progress: If True, show a progress bar. Default False.
        aggregator: A function that takes the positive and negative activations for a
            layer and returns a single vector. Default is mean_aggregator.
        batch_size: Batch size used during activation extraction.
        tqdm_desc: Progress-bar description used during activation extraction.

    Returns:
        A SteeringVector built by `aggregator` from the erased activations.

    Raises:
        ValueError: if `training_samples` is empty, or if activation and label counts differ.
    """
    import warnings

    # Materialize once: the samples are read both for labels and for extraction.
    training_samples = list(training_samples)
    if not training_samples:
        raise ValueError("training_samples must not be empty")

    # Label samples before the expensive forward passes so bad input fails fast.
    pos_is_a = _positive_answers_a(training_samples)
    frac_a = sum(pos_is_a) / len(pos_is_a)
    print(f"Fraction of positive completions answering A: {frac_a:.3f}")
    if frac_a in (0.0, 1.0):
        # The concept labels then coincide with the positive/negative split. LEACE
        # equalizes the class-conditional means, so a mean-difference vector collapses.
        warnings.warn(
            "Every positive completion answers the same letter, so the A/B concept is "
            "perfectly confounded with positive-vs-negative; erasing it will remove "
            "the steering direction itself.",
            stacklevel=2,
        )

    pos_acts, neg_acts = extract_activations(
        model,
        tokenizer,
        training_samples,
        layers=layers,
        layer_type=layer_type,
        layer_config=layer_config,
        move_to_cpu=move_to_cpu,
        read_token_index=read_token_index,
        show_progress=show_progress,
        batch_size=batch_size,
        tqdm_desc=tqdm_desc,
    )

    # Assumes activation lists are in the same order as `training_samples` — TODO confirm.
    erased_pos_acts: dict[int, list[Tensor]] = {}
    erased_neg_acts: dict[int, list[Tensor]] = {}
    for layer_num in pos_acts:
        erased_pos_acts[layer_num], erased_neg_acts[layer_num] = _erase_ab_concept(
            pos_acts[layer_num], neg_acts[layer_num], pos_is_a
        )

    layer_activations = aggregate_activations(erased_pos_acts, erased_neg_acts, aggregator)
    return SteeringVector(layer_activations, layer_type)

In [2]:
# Run experiment: load the gun-rights persona dataset and extract layer-13
# activations for its training split.
from repepo.steering.utils.helpers import make_dataset

# Experiment configuration.
layer = 13
dataset_name = "believes-in-gun-rights"
train_split = "0%:50%"
test_split = "50%:100%"  # held-out half; not used in this cell

train_dataset = make_dataset(dataset_name, train_split)

model, tokenizer = get_model_and_tokenizer(
    "meta-llama/Llama-2-7b-chat-hf", load_in_8bit=True
)
formatter = LlamaChatFormatter()
pipeline = Pipeline(model, tokenizer, formatter)
steering_vector_training_data = build_steering_vector_training_data(pipeline, train_dataset)

# read_token_index=-2 reads the second-to-last token of each sample
# (presumably the A/B answer letter just before ")") — confirm.
pos_acts, neg_acts = extract_activations(
    model, tokenizer, steering_vector_training_data,
    layers=[layer], read_token_index=-2,
)


  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


# Appendix

In [11]:
from repepo.paper.utils import PersonaCrossSteeringExperimentResult
from repepo.paper.utils import (
    load_persona_cross_steering_experiment_result,
    get_steering_vector
)

# Persona datasets whose steering vectors are compared against the A/B vector.
# All five are loaded so that the cells inspecting 'politically-conservative' and
# printing per-dataset cosine similarities reproduce on a fresh kernel.
selected_datasets = [
    'politically-conservative',
    'believes-in-gun-rights',
    'myopic-reward',
    'subscribes-to-moral-nihilism',
    'subscribes-to-Hinduism',
]

dataset_steering_vectors = {}
for dataset in selected_datasets:
    print(dataset)
    result = load_persona_cross_steering_experiment_result(dataset)
    dataset_steering_vectors[dataset] = get_steering_vector(result)


In [13]:
# Inspect one dataset's steering vector. Indexing via `selected_datasets` guarantees
# the key was loaded; a hard-coded name raises KeyError when that dataset is
# commented out of `selected_datasets`.
dataset_steering_vectors[selected_datasets[0]]

SteeringVector(layer_activations={13: tensor([ 0.0140,  0.0134, -0.0027,  ..., -0.0186, -0.0241, -0.0150],
       device='cuda:0', dtype=torch.float16)}, layer_type='decoder_block')

In [14]:
# Cosine similarity between the A/B token-bias steering vector and each persona
# dataset's steering vector, both read at layer 13.
from torch.nn.functional import cosine_similarity

for dataset, steering_vector in dataset_steering_vectors.items():
    print(dataset)
    similarity = cosine_similarity(
        ab_steering_vector.layer_activations[13],
        steering_vector.layer_activations[13],
        dim=-1,
    )
    print(similarity)

politically-conservative
tensor(0.1105, device='cuda:0', dtype=torch.float16)
believes-in-gun-rights
tensor(0.0062, device='cuda:0', dtype=torch.float16)
myopic-reward
tensor(0.0146, device='cuda:0', dtype=torch.float16)
subscribes-to-moral-nihilism
tensor(0.0956, device='cuda:0', dtype=torch.float16)
subscribes-to-Hinduism
tensor(0.0275, device='cuda:0', dtype=torch.float16)
