In [135]:
from utils import *
import torch, transformers, datasets, einops
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, Union, List, Any
from torch.utils.data import DataLoader
import numpy as np

from torch import nn
from pyvene import (
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
    CollectIntervention,
    InterventionOutput
)
from pyvene import (
    IntervenableConfig,
    IntervenableModel
)

# Device used for the probe, the collected activations and the steering model.
# Falls back to the CPU so this cell still runs on a machine without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def apply_chat_template(row):
    """Render ``row["input"]`` as a single-turn user message in chat format.

    Uses the notebook-global ``tokenizer``. The message is tokenized through
    the tokenizer's chat template with the generation prompt appended, the
    first token id is dropped, and the remaining ids are decoded back to text.

    Args:
        row: Mapping (dict or DataFrame row) with an ``"input"`` entry.

    Returns:
        The decoded chat-formatted prompt string.
    """
    chat = [{"role": "user", "content": row["input"]}]
    token_ids = tokenizer.apply_chat_template(
        chat, tokenize=True, add_generation_prompt=True)
    # Drop the leading token (presumably BOS, so that re-tokenizing the text
    # later does not duplicate it -- confirm for the tokenizer in use).
    return tokenizer.decode(token_ids[1:])

def prepare_df(original_df, tokenizer):
    """Rewrite the ``input`` column of ``original_df`` into chat format.

    Every row is passed through ``apply_chat_template`` and the result replaces
    the ``input`` column. The frame is modified in place and also returned, so
    pass a copy if the caller needs the raw prompts afterwards.

    Args:
        original_df: DataFrame with an ``input`` column.
        tokenizer: Accepted for interface compatibility; this function does not
            use it itself.

    Returns:
        The same DataFrame object with ``input`` replaced.
    """
    templated_inputs = original_df.apply(apply_chat_template, axis=1)
    original_df['input'] = templated_inputs
    return original_df

def apply_zeroshot_prompt_template(
    qID,
    wave="Pew_American_Trends_Panel_disagreement_500", 
    demographic_group="POLPARTY",
    demographic="Democrat",
    output_type="model_logprobs",
    provide_ground_truth_distribution=False
):
    """Build a zero-shot, demographic-conditioned prompt for one question.

    Reads ``<cwd>/opinions_qa/data/human_resp/<wave>/<demographic_group>_data.json``
    and formats the question ``qID`` with its multiple-choice options, labelled
    with the entries of the module-level ``options`` sequence.

    Args:
        qID: Key of the question inside the wave's JSON file.
        wave: Sub-directory holding the human-response data.
        demographic_group: Prefix of the JSON file name.
        demographic: Group named in the prompt; also the key under which the
            answer options are stored for the question.
        output_type: ``'sequence'``, ``'model_logprobs'`` or
            ``'express_distribution'``; any other value adds no
            format instruction to the prompt.
        provide_ground_truth_distribution: Unused; kept for signature
            compatibility.

    Returns:
        The prompt string, ending in ``"\nAnswer:"``.

    Raises:
        FileNotFoundError: If the data file does not exist.
        KeyError: If ``qID`` or ``demographic`` is missing from the data.
    """
    data_path = '{}/opinions_qa/data/human_resp/'.format(os.getcwd())
    # Context manager so the file handle is closed deterministically.
    with open(data_path + wave + '/' + demographic_group + "_data.json") as data_file:
        data = json.load(data_file)

    prompt = "Your task is to simulate an answer to a new question from the group of {}s. ".format(demographic)

    if output_type=='sequence':
        prompt+= 'After the examples, please simulate 30 samples from a group of {} for the new question asked. Please only respond with 30 multiple choice answers, no extra spaces, characters, quotes or text. Please only produce 30 characters. Answers with more than 30 characters will not be accepted.'.format(demographic)
    elif output_type=='model_logprobs': 
        prompt += 'After the examples, please simulate an answer from a group of "{}" for the question asked. Please only respond with a single multiple choice answer, no extra spaces, characters, quotes or text. Please only produce 1 character. Answers with more than one characters will not be accepted.'.format(demographic)
    elif output_type=='express_distribution': 
        prompt += 'After the examples, please express the distribution of answers from a group of "{}" for the question asked. Please only respond in the exact format of a dictionary mapping answer choice letter to probability, no extra spaces, characters, quotes or text. Please only produce 1 sentence in this format. Answers outside of this format will not be accepted.'.format(demographic)

    question = data[qID]['question_text']
    MC_options = list(data[qID][demographic].keys())
    # Each option is rendered as "<label>. <option text>. " on a single line.
    rendered_options = "".join(
        "{}. {}. ".format(options[i], option) for i, option in enumerate(MC_options))
    return prompt + "\nQuestion: " + question + "?\n" + rendered_options + "\nAnswer:"
    
def get_test_questions_with_distributions(
    seen_qIDs,
    wave="Pew_American_Trends_Panel_disagreement_500", 
    demographic_group="POLPARTY",
    demographic="Democrat",
):
    """Load a wave's questions and drop the ones already seen in training.

    Reads ``<cwd>/opinions_qa/data/human_resp/<wave>/<demographic_group>_data.json``
    and returns every entry whose question id is not in ``seen_qIDs``.

    The previous implementation tested ``k in wave`` (a substring check against
    the wave *name*) and never consulted ``seen_qIDs``, so no question was ever
    filtered out.

    Args:
        seen_qIDs: Iterable of question ids to exclude; ``None`` excludes nothing.
        wave: Sub-directory holding the human-response data.
        demographic_group: Prefix of the JSON file name.
        demographic: Unused; kept for signature compatibility.

    Returns:
        dict mapping each unseen question id to its record from the JSON file.

    Raises:
        FileNotFoundError: If the data file does not exist.
    """
    data_path = '{}/opinions_qa/data/human_resp/'.format(os.getcwd())
    with open(data_path + wave + '/' + demographic_group + "_data.json") as data_file:
        data = json.load(data_file)

    # Materialise once so membership tests are O(1) and one-shot iterables work.
    seen = set(seen_qIDs) if seen_qIDs is not None else set()
    return {qid: record for qid, record in data.items() if qid not in seen}

def parse_answers(raw_response, available_choices):
    """Count the answer letters that follow ``"Answer:"`` in a raw response.

    NOTE: a second ``parse_answers`` defined further down in this file rebinds
    the name once the whole cell has run, so this version is only reachable
    before that point.

    Args:
        raw_response: Text containing ``"Answer:"`` followed by
            whitespace-separated answers.
        available_choices: Valid answer labels; tokens outside this collection
            are ignored.

    Returns:
        ``(counts, probabilities)``: two dicts keyed by choice, or ``None``
        (after printing a warning) when ``"Answer:"`` is missing or when no
        token is a valid choice. The second case used to raise
        ``ZeroDivisionError``.
    """
    if "Answer:" not in raw_response:
        print("Warning: Input string does not contain 'Answer:'.")
        return None

    # Only the text between the first and (if any) second "Answer:" is parsed.
    answers_list = raw_response.split("Answer:")[1].strip().split()
    counts = {choice: 0 for choice in available_choices}
    for answer in answers_list:
        if answer in available_choices:
            counts[answer] += 1

    total_answers = sum(counts.values())
    if total_answers == 0:
        print("Warning: No valid answer choices found after 'Answer:'.")
        return None

    probabilities = {choice: count / total_answers for choice, count in counts.items()}
    return counts, probabilities

def get_few_shot_contrastive_inputs(
    q_ID,
    wave="Pew_American_Trends_Panel_disagreement_100", 
    demographic_group="POLPARTY",
    demographic="Democrat",
    output_type="model_logprobs",
    n_shots=5,
    n_simulations_per_shot=1,
    provide_ground_truth_distribution=False,
):
    """Build demographic-free ("contrastive") prompts for the ICL questions of ``q_ID``.

    The prompts use the same question/option layout as the demographic-conditioned
    ones but never name the demographic group in the instruction text.

    Args:
        q_ID: Target question id; an ICL question equal to it is skipped.
        wave: Wave the target question belongs to. For
            ``Pew_American_Trends_Panel_disagreement_100`` the ICL questions are
            read from the ``..._500`` wave; any other wave is used as its own
            ICL source (previously ``icl_wave`` was left unbound and the call
            crashed).
        demographic_group: Prefix of the JSON file name.
        demographic: Key under which the per-option response counts are stored.
        output_type: ``'sequence'``, ``'model_logprobs'`` or
            ``'express_distribution'``; any other value adds no format instruction.
        n_shots: Number of ICL question ids taken from ``get_ICL_qIDs``. Fewer
            rows result if ``q_ID`` is among them.
        n_simulations_per_shot: How many identical rows to emit per ICL question.
        provide_ground_truth_distribution: When true, ``"<option> be <pct>%, "``
            fragments are appended to the shared instruction text.

    Returns:
        pandas.DataFrame with columns ``input``, ``qID``, ``icl_qID``,
        ``demographic_group``, ``demographic``, ``output_type`` and ``wave``.

    Raises:
        FileNotFoundError: If the ICL data file does not exist.
    """
    data_path = '{}/opinions_qa/data/human_resp/'.format(os.getcwd())
    prompt = "Your task is to simulate an answer to a new question. "

    if output_type=='sequence':
        prompt+= 'After the examples, please simulate 30 samples for the new question asked. Please only respond with 30 multiple choice answers, no extra spaces, characters, quotes or text. Please only produce 30 characters. Answers with more than 30 characters will not be accepted.'
    elif output_type=='model_logprobs': 
        prompt += 'After the examples, please simulate an answer for the question asked. Please only respond with a single multiple choice answer, no extra spaces, characters, quotes or text. Please only produce 1 character. Answers with more than one characters will not be accepted.'
    elif output_type=='express_distribution': 
        prompt += 'After the examples, please express the distribution of answers for the question asked. Please only respond in the exact format of a dictionary mapping answer choice letter to probability, no extra spaces, characters, quotes or text. Please only produce 1 sentence in this format. Answers outside of this format will not be accepted.'

    # The small wave draws its ICL demos from the larger set; every other wave
    # serves as its own ICL source.
    if wave == 'Pew_American_Trends_Panel_disagreement_100':
        icl_wave = 'Pew_American_Trends_Panel_disagreement_500'
    else:
        icl_wave = wave
    with open(data_path + icl_wave + '/' + demographic_group + "_data.json") as icl_file:
        icl_data = json.load(icl_file)

    # get icl qids
    ICL_qIDS = get_ICL_qIDs(
        q_ID=q_ID, wave=icl_wave, 
        demographic_group=demographic_group, demographic=demographic)

    examples = []
    for icl_qID in ICL_qIDS[:n_shots]:
        if icl_qID == q_ID:
            continue

        response_counts = icl_data[icl_qID][demographic]
        MC_options = list(response_counts.keys())

        if provide_ground_truth_distribution:
            n = sum(response_counts.values())
            for option in MC_options:
                # NOTE(review): this extends the *shared* prompt, so every later
                # example also carries the percentages of the earlier questions.
                # Kept as-is; confirm whether that carry-over is intended.
                prompt +="{} be {}%, ".format(option, int((response_counts[option]/n)*100))

        example_input = prompt + "\nQuestion: " + icl_data[icl_qID]['question_text'] + "?\n"
        for i, option in enumerate(MC_options):
            example_input +="{}. {}. ".format(options[i], option)
        example_input+="\nAnswer:"
        for _ in range(n_simulations_per_shot):
            examples.append([example_input, q_ID, icl_qID, demographic_group, demographic, output_type, wave])

    return pd.DataFrame(examples, columns=[
        'input', 'qID', 'icl_qID', 'demographic_group', 'demographic', 'output_type', 'wave'
    ])

@torch.no_grad()
def set_decoder_norm_to_unit_norm(model):
    """Rescale every row of ``model.proj.weight`` to (approximately) unit L2 norm.

    The weight is modified in place. The dtype's machine epsilon is added to
    each row norm so an all-zero row does not divide by zero.

    Args:
        model: Object exposing a ``proj`` layer with a 2-D ``weight``.
    """
    weight = model.proj.weight
    assert weight is not None, "Decoder weight was not initialized."

    tiny = torch.finfo(weight.dtype).eps
    row_norms = weight.data.norm(dim=1, keepdim=True)
    weight.data.div_(row_norms + tiny)

def gather_residual_activations(model, target_layer, inputs):
  target_act = None
  def gather_target_act_hook(mod, inputs, outputs):
    nonlocal target_act # make sure we can modify the target_act from the outer scope
    target_act = outputs[0]
    return outputs
  handle = model.model.layers[target_layer].register_forward_hook(
      gather_target_act_hook, always_call=True)
  _ = model.forward(**inputs)
  handle.remove()
  return target_act

class LogisticRegressionModel(torch.nn.Module):
    """A single linear projection ``input_dim -> low_rank_dimension``.

    ``forward`` returns the raw projection; no sigmoid or other non-linearity
    is applied here.
    """

    def __init__(self, input_dim, low_rank_dimension):
        super().__init__()
        # One weight row per output dimension; includes a bias term.
        self.proj = torch.nn.Linear(input_dim, low_rank_dimension)

    def forward(self, x):
        """Apply the linear projection to ``x``."""
        projected = self.proj(x)
        return projected

@dataclass
class DataCollator(object):
    """Pad a list of examples to a common length and batch them.

    Each instance must provide ``input_ids``, ``labels`` and
    ``intervention_locations`` tensors. The instances are modified in place
    (padded fields are overwritten, ``intervention_masks`` and
    ``attention_mask`` are added) and then handed to ``data_collator``.
    """
    
    # Supplies ``pad_token_id`` for padding and for the attention mask.
    tokenizer: transformers.AutoTokenizer
    # Wrapped collator that turns the padded instances into the final batch.
    data_collator: transformers.DataCollator

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # Longest list of intervention positions / longest token sequence in the batch.
        max_intervention_len = max([len(inst["intervention_locations"][0]) for inst in instances])
        max_seq_len = max([len(inst["input_ids"]) for inst in instances])
        
        for inst in instances:
            # Length before any padding; must be read before input_ids is overwritten below.
            non_pad_len = len(inst["input_ids"])

            # 1 for every real intervention position ...
            _intervention_mask = torch.ones_like(inst["intervention_locations"][0])
            # ... padded positions point at index ``len(input_ids)``, which is the
            # extra pad token appended to input_ids further down.
            _intervention_location_paddings = torch.tensor(
                [[len(inst["input_ids"]) for _ in range(max_intervention_len - len(inst["intervention_locations"][0]))]])
            # ... and 0 for every padded intervention position.
            _intervention_mask_paddings = torch.tensor(
                [0 for _ in range(max_intervention_len - len(inst["intervention_locations"][0]))])
            inst["intervention_locations"] = torch.cat([inst["intervention_locations"], _intervention_location_paddings], dim=-1).int()
            inst["intervention_masks"] = torch.cat([_intervention_mask, _intervention_mask_paddings], dim=-1).int()

            _input_id_paddings = torch.tensor(
                [self.tokenizer.pad_token_id for _ in range(max_seq_len - non_pad_len)])
            # One pad token is always appended before the length padding, so every
            # row ends up with max_seq_len + 1 ids.
            inst["input_ids"] = torch.cat((inst["input_ids"], torch.tensor([self.tokenizer.pad_token_id]), _input_id_paddings)).int()
            # Masks out every position holding the pad id (including any pad id
            # that was already present in the original sequence).
            inst["attention_mask"] = (inst["input_ids"] != self.tokenizer.pad_token_id).int()
            inst["labels"] = inst["labels"].int()
        batch_inputs = self.data_collator(instances)
        return batch_inputs

def make_data_module(
    tokenizer: transformers.PreTrainedTokenizer, model, df, prefix_length=1
):
    """Tokenize ``df`` and package it as a dataset plus collator.

    Args:
        tokenizer: Tokenizer applied to each row's ``input`` (truncated to 1024 tokens).
        model: Not used by this function.
        df: DataFrame with ``input`` and ``labels`` columns.
        prefix_length: First token index included in ``intervention_locations``;
            positions ``prefix_length .. len(input_ids) - 1`` are recorded per row.

    Returns:
        dict with ``train_dataset`` (torch-formatted ``datasets.Dataset``),
        ``eval_dataset`` (always ``None``) and ``data_collator``.
    """
    columns = {"input_ids": [], "labels": [], "intervention_locations": []}
    for _, row in df.iterrows():
        encoded = tokenizer(
            row["input"], max_length=1024, truncation=True, return_tensors="pt")
        token_ids = encoded["input_ids"][0]
        positions = list(range(prefix_length, len(token_ids)))
        columns["input_ids"].append(token_ids)
        columns["labels"].append(row["labels"])
        # Wrapped in an outer list so each entry has shape (1, n_positions).
        columns["intervention_locations"].append(torch.tensor([positions]))

    train_dataset = datasets.Dataset.from_dict(columns)
    train_dataset.set_format(type='torch', columns=['input_ids', 'labels', 'intervention_locations'])

    collator = DataCollator(
        tokenizer=tokenizer,
        data_collator=transformers.DefaultDataCollator(return_tensors="pt"),
    )
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=collator)

def make_dataloader(tokenizer, model, examples, batch_size=8, shuffle=True):
    """Build a ``DataLoader`` over ``examples`` using ``make_data_module``.

    Args:
        tokenizer: Tokenizer forwarded to ``make_data_module``.
        model: Forwarded to ``make_data_module``.
        examples: DataFrame with ``input`` and ``labels`` columns.
        batch_size: Number of examples per batch (default 8, the former
            hard-coded value).
        shuffle: Whether to shuffle the examples each pass (default True, the
            former hard-coded value).

    Returns:
        torch.utils.data.DataLoader yielding collated batches.
    """
    data_module = make_data_module(tokenizer, model, examples)
    return DataLoader(
        data_module["train_dataset"],
        shuffle=shuffle,
        batch_size=batch_size,
        collate_fn=data_module["data_collator"],
    )

@torch.no_grad()
def train(tokenizer, model, ax, examples, layer, prefix_length=4):
    """Fit ``ax`` as a difference-of-means direction; no gradient steps are taken.

    For every token position from ``prefix_length`` onward that the attention
    mask marks as real, the output of layer ``layer`` is collected. The mean
    activation over tokens of label-1 examples minus the mean over all other
    tokens overwrites ``ax.proj.weight.data`` (as a single row), which is then
    rescaled to unit norm.

    Args:
        tokenizer: Tokenizer used to build the batches.
        model: Model whose layer outputs are read.
        ax: Module with a ``proj`` linear layer; its weight is replaced.
        examples: DataFrame with ``input`` and ``labels`` columns.
        layer: Index into ``model.model.layers``.
        prefix_length: Number of leading token positions to skip in every sequence.

    Raises:
        ValueError: If no batch is produced, or if either label class
            contributes no token activations. Previously the second case wrote
            a NaN direction without any warning.
    """
    train_dataloader = make_dataloader(tokenizer, model, examples)
    torch.cuda.empty_cache()
    ax.eval()

    positive_activations = []
    negative_activations = []
    for batch in train_dataloader:
        inputs = {k: v.to(DEVICE) for k, v in batch.items()}
        activations = gather_residual_activations(
            model, layer,
            {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]}
        ).detach()

        # Boolean (batch, seq - prefix) mask of real, non-prefix tokens.
        keep = inputs["attention_mask"][:, prefix_length:].bool()
        token_activations = activations[:, prefix_length:][keep]
        # Broadcast each example's label to all of its kept token positions.
        token_labels = inputs["labels"].unsqueeze(1).expand_as(keep)[keep]

        positive_activations.append(token_activations[token_labels == 1])
        negative_activations.append(token_activations[token_labels != 1])

    if not positive_activations:
        raise ValueError("`examples` produced no batches; cannot fit a direction.")

    all_positive = torch.cat(positive_activations, dim=0)
    all_negative = torch.cat(negative_activations, dim=0)
    if all_positive.shape[0] == 0 or all_negative.shape[0] == 0:
        raise ValueError(
            "Need token activations for both label == 1 and label != 1 examples; "
            f"got {all_positive.shape[0]} positive and {all_negative.shape[0]} negative.")

    mean_positive_activation = all_positive.mean(dim=0)
    mean_negative_activation = all_negative.mean(dim=0)
    ax.proj.weight.data = mean_positive_activation.unsqueeze(0) - mean_negative_activation.unsqueeze(0)
    set_decoder_norm_to_unit_norm(ax)
    print("Training finished.")

class AdditionIntervention(
    SourcelessIntervention,
    TrainableIntervention, 
    DistributedRepresentationIntervention
):
    """Intervention that adds a scaled row of ``proj.weight`` to the base activations."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        # Each weight row is one candidate steering direction; ``forward``
        # reads only the weight, never the bias.
        self.proj = torch.nn.Linear(
            self.embed_dim, kwargs["low_rank_dimension"], bias=True)

    def forward(self, base, source=None, subspaces=None):
        """Return ``base`` shifted along the direction selected by ``subspaces``.

        ``subspaces`` must provide ``"idx"`` (which weight row to use) and
        ``"max_act"`` and ``"mag"`` (multiplied together to scale that row).
        ``source`` is ignored.
        """
        scale = subspaces["max_act"] * subspaces["mag"]
        direction = self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
        steering_vec = scale * direction
        # The extra axis at dim=1 lets the vector broadcast across that axis of
        # ``base`` (presumably the sequence axis -- confirm against pyvene).
        return base + steering_vec.unsqueeze(dim=1)

def get_logits(tokenizer, model, ax, concept_id=0, k=10):
    """Project one learned direction onto the vocabulary and list the extremes.

    The unembedding matrix (``model.lm_head.weight`` transposed) is scaled
    row-wise by ``model.model.norm.weight + 1`` and mean-centred over its first
    axis, then multiplied by row ``concept_id`` of ``ax.proj.weight``.

    The computation runs under ``torch.no_grad()`` so that no autograd graph is
    kept for the full unembedding matrix. The mean is taken with plain torch
    (``mean(dim=0, keepdim=True)``), equivalent to the former ``einops.reduce``.

    Args:
        tokenizer: Tokenizer providing ``batch_decode``.
        model: Causal LM exposing ``lm_head.weight`` and ``model.norm.weight``.
        ax: Module whose ``proj.weight`` rows are the directions.
        concept_id: Row of ``ax.proj.weight`` to inspect; ``None`` skips the
            computation.
        k: Number of tokens to report at each end.

    Returns:
        ``(top_logits, neg_logits)``: each a one-element list holding a list of
        ``(token, value)`` pairs (largest first / smallest first), or
        ``([None], [None])`` when ``concept_id`` is ``None``.
    """
    top_logits, neg_logits = [None], [None]
    if concept_id is None:
        return top_logits, neg_logits

    with torch.no_grad():
        norm_weight = model.model.norm.weight
        # NOTE(review): the "+ 1" presumably mirrors a final norm that scales by
        # (1 + weight); confirm for the model in use.
        W_U = model.lm_head.weight.T * (norm_weight + torch.ones_like(norm_weight))[:, None]
        # Out-of-place on purpose: never writes into the model's own weights.
        W_U = W_U - W_U.mean(dim=0, keepdim=True)

        vocab_logits = ax.proj.weight.data[concept_id] @ W_U

        top_values, top_indices = vocab_logits.topk(k=k, sorted=True)
        top_tokens = tokenizer.batch_decode(top_indices.unsqueeze(dim=-1))
        top_logits = [list(zip(top_tokens, top_values.tolist()))]

        neg_values, neg_indices = vocab_logits.topk(k=k, largest=False, sorted=True)
        neg_tokens = tokenizer.batch_decode(neg_indices.unsqueeze(dim=-1))
        neg_logits = [list(zip(neg_tokens, neg_values.tolist()))]
    return top_logits, neg_logits

def parse_answers(raw_response, available_choices, min_valid_answers=3):
    """
    Parse the answers from a raw response string and calculate counts and probabilities.

    Only the text between the first ``"Answer:"`` and the next one (if any) is
    parsed; it is split on whitespace and tokens that are not in
    ``available_choices`` are skipped.

    Args:
        raw_response (str): The raw input string containing the answers.
        available_choices (list): A list of valid answer choices (e.g., ["A", "B", "C", "D", "E", "F"]).
        min_valid_answers (int): Minimum number of valid answers required before
            probabilities are computed. Defaults to 3, the previously hard-coded
            threshold.

    Returns:
        tuple: (status, result)
            - status (bool): True if successful, False if an error occurs.
            - result: ``{"counts": ..., "probabilities": ...}`` if successful,
                      or ``{"message": ...}`` describing the failure.
    """
    def _failure(message):
        return False, {"message": message}

    try:
        if "Answer:" not in raw_response:
            return _failure("No 'Answer:' keyword found in input.")

        answers_list = raw_response.split("Answer:")[1].strip().split()
        if not answers_list:
            return _failure("No parsable answers found in input.")

        counts = {choice: 0 for choice in available_choices}
        for answer in answers_list:
            if answer in available_choices:
                counts[answer] += 1
        total_answers = sum(counts.values())

        # ``total_answers == 0`` is checked separately so that a threshold of 0
        # can never lead to a division by zero below.
        if total_answers < min_valid_answers or total_answers == 0:
            return _failure("Not enough valid answers to calculate probabilities.")

        probabilities = {choice: count / total_answers for choice, count in counts.items()}
        return True, {"counts": counts, "probabilities": probabilities}
    except Exception as e:
        # Deliberate catch-all: callers rely on the (status, result) contract
        # instead of exceptions, e.g. when ``raw_response`` is not a string.
        return _failure(f"Unexpected error: {str(e)}")

def calculate_kld(golden_distribution, predicted_distribution):
    """Return KL(golden || predicted) in nats.

    Both arguments are dicts mapping a choice to its probability; the keys of
    ``golden_distribution`` define which entries are compared (a key missing
    from ``predicted_distribution`` raises ``KeyError``). Probabilities are
    clipped to ``[1e-12, 1]`` before the logarithm and are not renormalised.
    """
    keys = list(golden_distribution)
    floor = 1e-12
    p = np.clip(np.array([golden_distribution[k] for k in keys]), floor, 1)
    q = np.clip(np.array([predicted_distribution[k] for k in keys]), floor, 1)
    return (p * np.log(p / q)).sum()

def calculate_jsd(golden_distribution, predicted_distribution):
    """Return the Jensen-Shannon divergence (in nats) between two distributions.

    Both arguments are dicts mapping a choice to its probability; the keys of
    ``golden_distribution`` define which entries are compared. Probabilities
    are clipped to ``[1e-12, 1]`` (without renormalising) and the result is the
    average of the two KL divergences to the midpoint distribution.
    """
    keys = list(golden_distribution)
    floor = 1e-12
    p = np.clip(np.array([golden_distribution[k] for k in keys]), floor, 1)
    q = np.clip(np.array([predicted_distribution[k] for k in keys]), floor, 1)

    midpoint = 0.5 * (p + q)
    kl_p_mid = np.sum(p * np.log(p / midpoint))
    kl_q_mid = np.sum(q * np.log(q / midpoint))
    return 0.5 * (kl_p_mid + kl_q_mid)

def compute_kld_values(golden_distribution, sampled_distributions):
    """Return ``calculate_kld(golden_distribution, d)`` for every ``d``, in order."""
    kld_values = []
    for sampled in sampled_distributions:
        kld_values.append(calculate_kld(golden_distribution, sampled))
    return kld_values

def compute_jsd_values(golden_distribution, sampled_distributions):
    """Return ``calculate_jsd(golden_distribution, d)`` for every ``d``, in order."""
    jsd_values = []
    for sampled in sampled_distributions:
        jsd_values.append(calculate_jsd(golden_distribution, sampled))
    return jsd_values

In [38]:
# Reuse the module-level DEVICE so the model, the probe and the steering model
# cannot end up on different devices.
device = DEVICE
model_name_or_path = "google/gemma-2-2b-it"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)

# get tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, model_max_length=2048, 
    padding_side="right", use_fast=False)

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [136]:
# Demographic slice and answer format shared by the cells below.
demographic_group = "POLPARTY"
demographic = "Democrat"
output_type = "sequence"

# Few-shot examples for the first question id; both helpers are expected to be
# provided by `from utils import *`.
qIDs, waves = get_q_IDs()
raw_dataset = get_few_shot_training_examples(
    qIDs[0],
    wave="Pew_American_Trends_Panel_disagreement_100", 
    demographic_group=demographic_group,
    demographic=demographic,
    output_type=output_type, 
    n_shots=5,
    n_simulations_per_shot=1,
)
# prepare_df rewrites `input` in place, so it is given a copy and raw_dataset
# keeps the untemplated prompts.
training_dataset = prepare_df(raw_dataset.copy(), tokenizer)
training_dataset.head(3)

Unnamed: 0,input,output,qID,icl_qID,demographic_group,demographic,output_type,wave
0,<start_of_turn>user\nYour task is to simulate ...,B D B A A D B A D D D A D A B C A D D B A B C ...,ECON5_d_W54,INEQ5_f_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100
1,<start_of_turn>user\nYour task is to simulate ...,D D D D E A D A D C D C C D D C C D C C D E C ...,ECON5_d_W54,ECON5_h_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100
2,<start_of_turn>user\nYour task is to simulate ...,F D C E C E D C A F C F D D D E C D D C C D D ...,ECON5_d_W54,ECON5_i_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100


In [137]:
# Negative ("contrastive") prompts: same ICL questions, but the instruction
# text does not name the demographic group.
contrastive_df = get_few_shot_contrastive_inputs(
    qIDs[0],
    wave="Pew_American_Trends_Panel_disagreement_100", 
    demographic_group=demographic_group,
    demographic=demographic,
    output_type=output_type, 
    n_shots=5,
    n_simulations_per_shot=1,
)
all_responses = []
for index, row in contrastive_df.iterrows():
    instruction = apply_chat_template({"input": row["input"]})
    prompt = tokenizer(instruction, return_tensors="pt").to(device)
    # `early_stopping` is a beam-search option and has no effect on sampling
    # with a single beam, so it is no longer passed.
    model_response = model.generate(
        **prompt, max_new_tokens=64, do_sample=True,
        eos_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(model_response[0][prompt["input_ids"].shape[-1]:], skip_special_tokens=True)
    all_responses.append(response.strip())
contrastive_df["output"] = all_responses
contrastive_dataset = prepare_df(contrastive_df.copy(), tokenizer)
# Label 1 = demographic-conditioned prompts, label 0 = contrastive prompts.
training_dataset["labels"] = 1
contrastive_dataset["labels"] = 0
contrastive_training_dataset = pd.concat([training_dataset, contrastive_dataset]).reset_index(drop=True)



In [138]:
# Inspect the combined frame: rows with labels == 1 come from training_dataset,
# rows with labels == 0 from contrastive_dataset.
contrastive_training_dataset

Unnamed: 0,input,output,qID,icl_qID,demographic_group,demographic,output_type,wave,labels
0,<start_of_turn>user\nYour task is to simulate ...,B D B A A D B A D D D A D A B C A D D B A B C ...,ECON5_d_W54,INEQ5_f_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100,1
1,<start_of_turn>user\nYour task is to simulate ...,D D D D E A D A D C D C C D D C C D C C D E C ...,ECON5_d_W54,ECON5_h_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100,1
2,<start_of_turn>user\nYour task is to simulate ...,F D C E C E D C A F C F D D D E C D D C C D D ...,ECON5_d_W54,ECON5_i_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100,1
3,<start_of_turn>user\nYour task is to simulate ...,D D D D D D D D D F D C B D C D A D A D C C D ...,ECON5_d_W54,ECON5_j_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100,1
4,<start_of_turn>user\nYour task is to simulate ...,D C B A D A A C C C A B B C C D A A B A E C B ...,ECON5_d_W54,ECON5_k_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100,1
5,<start_of_turn>user\nYour task is to simulate ...,E.F\nA.B.C.D.E. \nA.B.C.D.E. \nA.B.C.D.E. \nA....,ECON5_d_W54,INEQ5_f_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100,0
6,<start_of_turn>user\nYour task is to simulate ...,A \nB\nC\nD\nE\nF \nA\nB\nC\nD\nE\nF\nA\nB\nC\...,ECON5_d_W54,ECON5_h_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100,0
7,<start_of_turn>user\nYour task is to simulate ...,A B C D E,ECON5_d_W54,ECON5_i_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100,0
8,<start_of_turn>user\nYour task is to simulate ...,E \nA \nB \nC \nD \nE \nF \nA \nB \nC \nD \nE ...,ECON5_d_W54,ECON5_j_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100,0
9,<start_of_turn>user\nYour task is to simulate ...,A. \nB. \nC. \nD. \nE. \nF. \nA. \nB. \nC. \nD...,ECON5_d_W54,ECON5_k_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100,0


In [139]:
# Index into model.model.layers whose output is read here and steered later.
LAYER = 20

# One-row probe; train() overwrites its weight with the unit-norm
# difference-of-means direction.
ax = LogisticRegressionModel(model.config.hidden_size, 1).to(DEVICE)
train(tokenizer, model, ax, contrastive_training_dataset, LAYER)

Training finished.


In [140]:
# Wrap the learned direction in an additive intervention on layer LAYER's output.
steering_ax = AdditionIntervention(embed_dim=model.config.hidden_size, low_rank_dimension=1)
# Shares the same weight tensor as `ax`; no copy is made.
steering_ax.proj.weight.data = ax.proj.weight.data
steering_config = IntervenableConfig(representations=[
    {
        "layer": LAYER,
        "component": f"model.layers[{LAYER}].output",
        "low_rank_dimension": 1,
        "intervention": steering_ax,
    },
])
steering_model = IntervenableModel(steering_config, model)
steering_model.set_device(DEVICE)

In [141]:
# Candidate test questions, excluding every question id used for training or
# as an ICL demonstration.
test_pool = get_test_questions_with_distributions(
    seen_qIDs=set(training_dataset.qID).union(training_dataset.icl_qID)
)
# random.sample() requires a sequence: passing a dict view has been deprecated
# since Python 3.9 and raises TypeError from 3.11, hence the list().
test_qID = random.sample(list(test_pool.keys()), 1)[0]
# test_qID = "ECON5_d_W54"
print("testing qID:", test_qID)
# Turn the per-option response counts into the reference distribution, keyed by
# the option labels from the module-level `options`.
n = (sum(test_pool[test_qID][demographic].values()))
MC_options = list(test_pool[test_qID][demographic].keys())
all_options, probs = [], []
for i, option in enumerate(MC_options):
    all_options.append(options[i])
    probs.append(test_pool[test_qID][demographic][option]/n)
golden_dist = dict(zip(all_options, probs))
print("Golden dist:")
print(golden_dist)

testing qID: RACESURV17_W43
Golden dist:
{'A': 0.3181818181818182, 'B': 0.5307802433786686, 'C': 0.048675733715103794, 'D': 0.09663564781675017, 'E': 0.00572655690765927}


since Python 3.9 and will be removed in a subsequent version.
  test_qID = random.sample(test_pool.keys(), 1)[0]


In [150]:
# Zero-shot prompt for the held-out question, generated with the steering
# intervention active.
instruction = apply_zeroshot_prompt_template(test_qID, output_type="sequence")
instruction = apply_chat_template({"input": instruction})
prompt = tokenizer(instruction, return_tensors="pt").to(device)
base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position
# `early_stopping` is a beam-search option and has no effect on sampling with a
# single beam, so it is no longer passed.
_, reft_response = steering_model.generate(
    prompt, unit_locations=None, intervene_on_prompt=True, 
    subspaces=[{"mag": 2, "max_act": 100, "idx": 0}], max_new_tokens=32, do_sample=True, 
    eos_token_id=tokenizer.eos_token_id
)
# Decode only the newly generated tokens, not the echoed prompt.
response = tokenizer.decode(reft_response[0][prompt["input_ids"].shape[-1]:], skip_special_tokens=True)
print(response)

A
B
C
D
E
A
B
C
D
E
A
B
C
D
E
A



In [152]:
# Tokens whose unembedding aligns most / least with the learned direction
# (row 0 of ax.proj.weight, top and bottom 10 by default).
get_logits(tokenizer, model, ax)

([[(' Democrats', 0.859375),
   (' Democratic', 0.84375),
   ('NUMX', 0.8125),
   (' Biden', 0.76171875),
   (' Democrat', 0.7578125),
   ('++\r', 0.7578125),
   (' caucus', 0.73828125),
   (' Congressional', 0.7265625),
   (' Republicans', 0.72265625),
   (' congressional', 0.71484375)]],
 [[('AddTagHelper', -0.5625),
   (' manusia', -0.5546875),
   ('المكان', -0.53125),
   (' Organisms', -0.515625),
   (' aiut', -0.50390625),
   ('formazioni', -0.50390625),
   ('mberto', -0.50390625),
   (' svolge', -0.498046875),
   (' AssemblyCulture', -0.49609375),
   (' visuales', -0.49609375)]])