#### All these are just Diff-In-Mean Utils

In [1]:
from utils import *
import torch, transformers, datasets, einops
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, Union, List, Any
from torch.utils.data import DataLoader
import numpy as np

from torch import nn
from pyvene import (
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
    CollectIntervention,
    InterventionOutput
)
from pyvene import (
    IntervenableConfig,
    IntervenableModel
)

DEVICE = "cuda"

@torch.no_grad()
def set_decoder_norm_to_unit_norm(model):
    assert model.proj.weight is not None, "Decoder weight was not initialized."

    eps = torch.finfo(model.proj.weight.dtype).eps
    norm = torch.norm(model.proj.weight.data, dim=1, keepdim=True)
    model.proj.weight.data /= norm + eps

def gather_residual_activations(model, target_layer, inputs):
  target_act = None
  def gather_target_act_hook(mod, inputs, outputs):
    nonlocal target_act # make sure we can modify the target_act from the outer scope
    target_act = outputs[0]
    return outputs
  handle = model.model.layers[target_layer].register_forward_hook(
      gather_target_act_hook, always_call=True)
  _ = model.forward(**inputs)
  handle.remove()
  return target_act

class LogisticRegressionModel(torch.nn.Module):
    def __init__(self, input_dim, low_rank_dimension):
        super(LogisticRegressionModel, self).__init__()
        # Linear layer: input_dim -> 1 output (since binary classification)
        self.proj = torch.nn.Linear(input_dim, low_rank_dimension)
    
    def forward(self, x):
        return self.proj(x)

@dataclass
class DataCollator(object):
    """Collate examples for ReFT."""
    
    tokenizer: transformers.AutoTokenizer
    data_collator: transformers.DataCollator

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        max_intervention_len = max([len(inst["intervention_locations"][0]) for inst in instances])
        max_seq_len = max([len(inst["input_ids"]) for inst in instances])
        
        for inst in instances:
            non_pad_len = len(inst["input_ids"])

            _intervention_mask = torch.ones_like(inst["intervention_locations"][0])
            _intervention_location_paddings = torch.tensor(
                [[len(inst["input_ids"]) for _ in range(max_intervention_len - len(inst["intervention_locations"][0]))]])
            _intervention_mask_paddings = torch.tensor(
                [0 for _ in range(max_intervention_len - len(inst["intervention_locations"][0]))])
            inst["intervention_locations"] = torch.cat([inst["intervention_locations"], _intervention_location_paddings], dim=-1).int()
            inst["intervention_masks"] = torch.cat([_intervention_mask, _intervention_mask_paddings], dim=-1).int()

            _input_id_paddings = torch.tensor(
                [self.tokenizer.pad_token_id for _ in range(max_seq_len - non_pad_len)])
            inst["input_ids"] = torch.cat((inst["input_ids"], torch.tensor([self.tokenizer.pad_token_id]), _input_id_paddings)).int()
            inst["attention_mask"] = (inst["input_ids"] != self.tokenizer.pad_token_id).int()
            inst["labels"] = inst["labels"].int()
        batch_inputs = self.data_collator(instances)
        return batch_inputs

def make_data_module(
    tokenizer: transformers.PreTrainedTokenizer, model, df, prefix_length=1
):
    all_input_ids, all_labels, all_intervention_locations = [], [], []
    for _, row in df.iterrows():
        input_ids = tokenizer(
            row["input"]+row["output"], max_length=1024, truncation=True, return_tensors="pt")["input_ids"][0]
        base_length = len(input_ids)
        intervention_locations = torch.tensor([[i for i in range(prefix_length, base_length)]])
        all_input_ids.append(input_ids)
        all_labels.append(row["labels"])
        all_intervention_locations.append(intervention_locations)

    train_dataset = datasets.Dataset.from_dict({
        "input_ids": all_input_ids,
        "labels": all_labels,
        "intervention_locations": all_intervention_locations
    })
    train_dataset.set_format(type='torch', columns=['input_ids', 'labels', 'intervention_locations'])

    data_collator_fn = transformers.DefaultDataCollator(
        return_tensors="pt"
    )
    data_collator = DataCollator(tokenizer=tokenizer, data_collator=data_collator_fn)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)

def make_dataloader(tokenizer, model, examples):
    data_module = make_data_module(tokenizer, model, examples)
    train_dataloader = DataLoader(
        data_module["train_dataset"], 
        shuffle=True, batch_size=8, 
        collate_fn=data_module["data_collator"])
    return train_dataloader

@torch.no_grad()
def train(tokenizer, model, ax, examples, layer, prefix_length=4):
    train_dataloader = make_dataloader(tokenizer, model, examples)
    torch.cuda.empty_cache()
    ax.eval()
    # Main training loop.
    positive_activations = []
    negative_activations = []
    for batch in train_dataloader:
        # prepare input
        inputs = {k: v.to(DEVICE) for k, v in batch.items()}
        activations = gather_residual_activations(
            model, layer, 
            {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]}
        ).detach()
        nonbos_mask = inputs["attention_mask"][:,prefix_length:]
        activations = activations[:,prefix_length:][nonbos_mask.bool()]
        labels = inputs["labels"].unsqueeze(1).repeat(
            1, inputs["input_ids"].shape[1] - prefix_length)
        positive_activations.append(activations[labels[nonbos_mask.bool()] == 1])
        negative_activations.append(activations[labels[nonbos_mask.bool()] != 1])

    mean_positive_activation = torch.cat(positive_activations, dim=0).mean(dim=0)
    mean_negative_activation = torch.cat(negative_activations, dim=0).mean(dim=0)
    ax.proj.weight.data = mean_positive_activation.unsqueeze(0) - mean_negative_activation.unsqueeze(0)
    set_decoder_norm_to_unit_norm(ax)
    print("Training finished.")

def get_logits(tokenizer, model, ax, concept_id=0, k=10):
    top_logits, neg_logits = [None], [None]
    if concept_id is not None:
        W_U = model.lm_head.weight.T
        W_U = W_U * (model.model.norm.weight +
                    torch.ones_like(model.model.norm.weight))[:, None]
        W_U -= einops.reduce(
            W_U, "d_model d_vocab -> 1 d_vocab", "mean"
        )

        vocab_logits = ax.proj.weight.data[concept_id] @ W_U
        top_values, top_indices = vocab_logits.topk(k=k, sorted=True)
        top_tokens = tokenizer.batch_decode(top_indices.unsqueeze(dim=-1))
        top_logits = [list(zip(top_tokens, top_values.tolist()))]
        
        neg_values, neg_indices = vocab_logits.topk(k=k, largest=False, sorted=True)
        neg_tokens = tokenizer.batch_decode(neg_indices.unsqueeze(dim=-1))
        neg_logits = [list(zip(neg_tokens, neg_values.tolist()))]
    return top_logits, neg_logits

def get_few_shot_contrastive_inputs(
    q_ID,
    wave="Pew_American_Trends_Panel_disagreement_100", 
    demographic_group="POLPARTY",
    demographic="Democrat",
    output_type="model_logprobs",
    n_shots=5,
    n_simulations_per_shot=1,
    provide_ground_truth_distribution=False,
):
    data_path = '{}/opinions_qa/data/human_resp/'.format(os.getcwd())
    demographic_in_prompt = demographic
    data = json.load(open(data_path + wave + '/' + demographic_group + "_data.json"))
    prompt = "Your task is to simulate an answer to a new question. "

    if output_type=='sequence':
        prompt+= 'After the examples, please simulate 30 samples for the new question asked. Please only respond with 30 multiple choice answers, no extra spaces, characters, quotes or text. Please only produce 30 characters. Answers with more than 30 characters will not be accepted.'
    elif output_type=='model_logprobs': 
        prompt += 'After the examples, please simulate an answer for the question asked. Please only respond with a single multiple choice answer, no extra spaces, characters, quotes or text. Please only produce 1 character. Answers with more than one characters will not be accepted.'
    elif output_type=='express_distribution': 
        prompt += 'After the examples, please express the distribution of answers for the question asked. Please only respond in the exact format of a dictionary mapping answer choice letter to probability, no extra spaces, characters, quotes or text. Please only produce 1 sentence in this format. Answers outside of this format will not be accepted.'

    # we need the larger set to get icl demos
    if wave == 'Pew_American_Trends_Panel_disagreement_100':
        icl_wave='Pew_American_Trends_Panel_disagreement_500'
    icl_data = json.load(open(data_path + icl_wave + '/' + demographic_group + "_data.json"))

    # get icl qids
    ICL_qIDS = get_ICL_qIDs(
        q_ID=q_ID, wave=icl_wave, 
        demographic_group=demographic_group, demographic=demographic)

    examples = []
    
    for icl_qID in ICL_qIDS[:n_shots]:
        if icl_qID == q_ID:
            continue

        n = (sum(icl_data[icl_qID][demographic].values()))
        MC_options = list(icl_data[icl_qID][demographic].keys())
        all_options, probs = [], []
        for i, option in enumerate(MC_options):
            all_options.append(options[i])
            probs.append(icl_data[icl_qID][demographic][option]/n)
            if provide_ground_truth_distribution:
                prompt +="{} be {}%, ".format(option, int((icl_data[icl_qID][demographic][option]/n)*100))

        example_input = prompt + "\nQuestion: " + icl_data[icl_qID]['question_text'] + "?\n"
        for i, option in enumerate(MC_options):
            example_input +="{}. {}. ".format(options[i], option)
        for _ in range(n_simulations_per_shot):
            examples.append([example_input, q_ID, icl_qID, demographic_group, demographic, output_type, wave])

    return pd.DataFrame(examples, columns=[
        'input', 'qID', 'icl_qID', 'demographic_group', 'demographic', 'output_type', 'wave'
    ])

class AdditionIntervention(
    SourcelessIntervention,
    TrainableIntervention, 
    DistributedRepresentationIntervention
):
    def __init__(self, **kwargs):
        # Note that we initialise these to zeros because we're loading in pre-trained weights.
        # If you want to train your own SAEs then we recommend using blah
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(
                self.embed_dim, kwargs["low_rank_dimension"], bias=True)

    def forward(self, base, source=None, subspaces=None):
        # use subspaces["idx"] to select the correct weight vector
        steering_vec = subspaces["max_act"] * subspaces["mag"] * self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
        output = base + steering_vec.unsqueeze(dim=1)
        return output
        
class SubspaceIntervention(
    SourcelessIntervention,
    TrainableIntervention, 
    DistributedRepresentationIntervention
):
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(
            self.embed_dim, kwargs["low_rank_dimension"], bias=True)
    
    def forward(self, base, source=None, subspaces=None):
        v = self.proj.weight[[subspaces["idx"]]*base.shape[0]].unsqueeze(dim=-1) # bs, h, 1
        # get orthogonal component
        latent = torch.relu(torch.bmm(base, v)) # bs, s, 1
        proj_vec = torch.bmm(latent, v.permute(0, 2, 1)) # bs, s, 1 * bs, 1, h = bs, s, h
        base_orthogonal = base - proj_vec

        steering_scale = subspaces["max_act"] * subspaces["mag"]
        steering_vec = steering_scale * v.permute(0, 2, 1) # bs, 1, h
        
        # Replace the projection component with the steering vector
        output = base_orthogonal + steering_vec 
        return output



In [2]:
device = "cuda"
model_name_or_path = "google/gemma-2-2b-it"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)

# get tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, model_max_length=2048, 
    padding_side="right", use_fast=False)

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


### Specify your job parameters

In [5]:
# Run random baseline to get these files first!
n_training_qIDs = "train_qIDs.json"
n_testing_qIDs = "test_qIDs.json"

# demographic group and output type
demographic_group = "POLPARTY"
demographic = "Republican"
output_type = "sequence"

#### Training and Eval

In [7]:
def apply_chat_template(row):
    messages = [{"role": "user", "content": row["input"]}]
    nobos = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)[1:]
    return tokenizer.decode(nobos)

sampled_qIDs = json.load(open(n_training_qIDs))

qID_datasets = []
for qID in sampled_qIDs:
    qID_dataset = get_few_shot_training_examples(
        qID,
        wave="Pew_American_Trends_Panel_disagreement_100", 
        demographic_group=demographic_group,
        demographic=demographic,
        output_type="model_logprobs", 
        dataset="opinionqa",
        n_shots=5,
        n_simulations_per_shot=1,
    )
    qID_datasets += [qID_dataset]
raw_dataset = pd.concat(qID_datasets)
training_dataset = prepare_df(raw_dataset.copy(), tokenizer).reset_index(drop=True)

qID_c_datasets = []
for qID in sampled_qIDs:
    contrastive_df = get_few_shot_contrastive_inputs(
        qID,
        wave="Pew_American_Trends_Panel_disagreement_100", 
        demographic_group=demographic_group,
        demographic=demographic,
        output_type="model_logprobs", 
        n_shots=5,
        n_simulations_per_shot=1,
    )
    qID_c_datasets += [contrastive_df]
raw_c_dataset = pd.concat(qID_c_datasets)

all_responses = []
for index, row in raw_c_dataset.iterrows():
    instruction = apply_chat_template({"input": row["input"]})
    prompt = tokenizer(instruction, return_tensors="pt").to(device)
    model_response = model.generate(
        **prompt, max_new_tokens=3, do_sample=True, 
        eos_token_id=tokenizer.eos_token_id, early_stopping=True
    )
    response = tokenizer.decode(model_response[0][prompt["input_ids"].shape[-1]:], skip_special_tokens=True)
    all_responses += [response.strip()]

raw_c_dataset["output"] = all_responses
contrastive_dataset = prepare_df(raw_c_dataset.copy(), tokenizer)
training_dataset["labels"] = 1
contrastive_dataset["labels"] = 0
contrastive_training_dataset = pd.concat([training_dataset, contrastive_dataset]).reset_index(drop=True)

  icl_values = np.array(icl_values)/np.sum(icl_values)


In [8]:
contrastive_training_dataset

Unnamed: 0,input,output,qID,icl_qID,demographic_group,demographic,output_type,wave,labels
0,<start_of_turn>user\nPlease simulate an answer...,Answer: C,SOCIETY_SSM_W92,GAYMARR2_W32,POLPARTY,Republican,model_logprobs,Pew_American_Trends_Panel_disagreement_100,1
1,<start_of_turn>user\nPlease simulate an answer...,Answer: C,SOCIETY_SSM_W92,FAMSURV6_W50,POLPARTY,Republican,model_logprobs,Pew_American_Trends_Panel_disagreement_100,1
2,<start_of_turn>user\nPlease simulate an answer...,Answer: F,SOCIETY_SSM_W92,SOCIETY_RHIST_W92,POLPARTY,Republican,model_logprobs,Pew_American_Trends_Panel_disagreement_100,1
3,<start_of_turn>user\nPlease simulate an answer...,Answer: F,SOCIETY_SSM_W92,SOCIETY_TRANS_W92,POLPARTY,Republican,model_logprobs,Pew_American_Trends_Panel_disagreement_100,1
4,<start_of_turn>user\nPlease simulate an answer...,Answer: A,SOCIETY_SSM_W92,SOCIETY_WHT_W92,POLPARTY,Republican,model_logprobs,Pew_American_Trends_Panel_disagreement_100,1
...,...,...,...,...,...,...,...,...,...
95,<start_of_turn>user\nYour task is to simulate ...,C.,RACESURV34a_W43,WHADVANT_W32,POLPARTY,Republican,model_logprobs,Pew_American_Trends_Panel_disagreement_100,0
96,<start_of_turn>user\nYour task is to simulate ...,C,RACESURV34a_W43,WHADVANT_W92,POLPARTY,Republican,model_logprobs,Pew_American_Trends_Panel_disagreement_100,0
97,<start_of_turn>user\nYour task is to simulate ...,D,RACESURV34a_W43,RACESURV5b_W43,POLPARTY,Republican,model_logprobs,Pew_American_Trends_Panel_disagreement_100,0
98,<start_of_turn>user\nYour task is to simulate ...,F,RACESURV34a_W43,RACESURV5d_W43,POLPARTY,Republican,model_logprobs,Pew_American_Trends_Panel_disagreement_100,0


In [9]:
LAYER = 20

ax = LogisticRegressionModel(model.config.hidden_size, 1)
ax = ax.to(DEVICE)
train(tokenizer, model, ax, contrastive_training_dataset, LAYER)

Training finished.


In [10]:
steering_ax = AdditionIntervention(
    embed_dim=model.config.hidden_size, 
    low_rank_dimension=1,
)
steering_ax.proj.weight.data = ax.proj.weight.data
steering_config = IntervenableConfig(representations=[{
    "layer": LAYER,
    "component": f"model.layers[{LAYER}].output",
    "low_rank_dimension": 1,
    "intervention": steering_ax} for l in [LAYER]])
steering_model = IntervenableModel(steering_config, model)
steering_model.set_device(DEVICE)

In [42]:
def apply_chat_template(row):
    messages = [{"role": "user", "content": row["input"]}]
    nobos = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)[1:]
    return tokenizer.decode(nobos)

test_pool = get_test_questions_with_distributions(
    seen_qIDs={}
)
test_qIDs = json.load(open(n_testing_qIDs))

k = 1
success_rates = []
probabilities_list = []
for test_qID in test_qIDs:
    print("Evaluating:", test_qID)
    # test_qID = "ECON5_d_W54"
    n = (sum(test_pool[test_qID][demographic].values()))
    MC_options = list(test_pool[test_qID][demographic].keys())
    all_options, probs = [], []
    for i, option in enumerate(MC_options):
        all_options.append(options[i])
        probs.append(test_pool[test_qID][demographic][option]/n)
    golden_dist = dict(zip(all_options, probs))
    # print("Golden dist:")
    # print(golden_dist)

    instruction = get_zeroshot_prompt_opinionqa(test_qID, output_type="sequence")
    instruction = apply_chat_template({"input": instruction})
    # print(instruction)
    model_inputs = tokenizer(instruction, return_tensors="pt").to(device)

    successful_parsings = 0
    total_attempts = 0
    while successful_parsings < k:
        _, outputs = steering_model.generate(
            model_inputs, unit_locations=None, intervene_on_prompt=True, 
            subspaces=[{"mag": 0.8, "max_act": 100, "idx": 0}], max_new_tokens=32, do_sample=True, 
            eos_token_id=tokenizer.eos_token_id, early_stopping=True
        )
        response = tokenizer.decode(outputs[0][model_inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
        # print(response)
        success, result = parse_answers(response, all_options, answer_tag=False)
        total_attempts += 1
        if success:
            successful_parsings += 1
            probabilities_list.append([golden_dist, result["probabilities"]])
        success_rate = successful_parsings / total_attempts
        success_rates += [success_rate]
success_rate = np.array(success_rates).mean()
print("Success rate:", success_rate)

Evaluating: WHYNOTBIZF2G_W36
Evaluating: GAP21Q33_r_W82
Evaluating: NEIGHINTERA_W32
Evaluating: FUTRCLASSc_W41
Evaluating: TRAITPOLMF1B_W36
Evaluating: FUD37A_W34
Evaluating: HIGHEDWRNGB_W36
Evaluating: WHYNOTPOLF1C_W36
Evaluating: GAP21Q4_f_W82
Evaluating: ESSENPOLF1B_W36
Evaluating: RQ4_F1Ba_W42
Evaluating: RACESURV14_W43
Evaluating: INFOCREATEa_W45
Evaluating: GAP21Q19_a_W82
Evaluating: GROWUPVIOL_W26
Evaluating: FAMSURV23e_W50
Evaluating: GUNTYPEOWNC_W26
Evaluating: ROMRELDUR_W50
Evaluating: GAP21Q31_W82
Evaluating: BILLION_W92
Success rate: 0.3870614035087719


In [45]:
distances = compute_l1_values(probabilities_list)
json.dump(distances, open("distance_reps.json", "w"))
np.mean(distances)

0.6895270637451534

In [44]:
get_logits(tokenizer, model, ax, k=10)

([[(' conservative', 0.78515625),
   ('kenstock', 0.77734375),
   (' Republicans', 0.7578125),
   (' FetchType', 0.75),
   (' GOP', 0.71484375),
   (' conservatives', 0.703125),
   (' Conservative', 0.6875),
   (' Republican', 0.6875),
   ('Conservative', 0.6875),
   (' Trump', 0.65625)]],
 [[(' CreateTagHelper', -0.66015625),
   ('AddTagHelper', -0.64453125),
   (' AssemblyCulture', -0.63671875),
   ('enumi', -0.6328125),
   ('EndInit', -0.58203125),
   (' utafitiHapana', -0.56640625),
   ('MLLoader', -0.53515625),
   ('Tembelea', -0.51171875),
   ('BeginContext', -0.48828125),
   (' EconPapers', -0.48828125)]])

In [31]:
instruction = "Do not show bias towards any political party. Name a popular political party in the U.S."

instruction = apply_chat_template({"input": instruction})
model_inputs = tokenizer(instruction, return_tensors="pt").to(device)

outputs = model.generate(
    **model_inputs, max_new_tokens=128, do_sample=True, 
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
response = tokenizer.decode(outputs[0][model_inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(response)

The **Democratic Party** is a popular political party in the U.S. 



In [41]:
_, outputs = steering_model.generate(
    model_inputs, unit_locations=None, intervene_on_prompt=True, 
    subspaces=[{"mag": 2.0, "max_act": 100, "idx": 0}], max_new_tokens=128, do_sample=True, 
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
response = tokenizer.decode(outputs[0][model_inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(response)

The Republican Party is a prominent and influential political party in the United States.  

