In [80]:
from utils import *
import pyreft

def apply_chat_template(row, tokenizer=None):
    """Wrap ``row["input"]`` as a single user turn in the tokenizer's chat template.

    Args:
        row: any mapping (e.g. a pandas row or a dict) with an ``"input"`` key holding
            the user message text.
        tokenizer: tokenizer providing ``apply_chat_template``/``decode``/``bos_token_id``.
            If None, the notebook-level ``tokenizer`` global is used (original behavior).

    Returns:
        The templated prompt as a string, including the generation prompt for the
        assistant turn, with a leading BOS token removed (the caller's tokenizer call
        adds its own BOS when the string is re-tokenized).
    """
    if tokenizer is None:
        # Backward-compatible fallback to the tokenizer loaded in a later notebook cell.
        tokenizer = globals()["tokenizer"]
    messages = [{"role": "user", "content": row["input"]}]
    token_ids = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True)
    # Strip the first token only if it really is BOS; an unconditional [1:] would
    # eat a real token for templates/tokenizers that do not emit BOS.
    if token_ids and token_ids[0] == tokenizer.bos_token_id:
        token_ids = token_ids[1:]
    return tokenizer.decode(token_ids)

def prepare_df(original_df, tokenizer):
    """Return a copy of ``original_df`` whose ``input`` column is chat-templated.

    Each ``input`` string is passed through ``apply_chat_template``; all other
    columns (e.g. ``output``) are left unchanged, so the result can be used directly
    for standard instruction tuning. ``original_df`` itself is not modified.

    Args:
        original_df: DataFrame with at least an ``input`` column of prompt strings.
        tokenizer: accepted for interface compatibility; not used here, because
            ``apply_chat_template`` is called with its default tokenizer.

    Returns:
        A new DataFrame with the templated ``input`` column.
    """
    # Series.map (unlike DataFrame.apply(axis=1)) also handles an empty frame cleanly.
    templated_inputs = original_df["input"].map(
        lambda text: apply_chat_template({"input": text}))
    return original_df.assign(input=templated_inputs)

def apply_zeroshot_prompt_template(
    qID,
    wave="Pew_American_Trends_Panel_disagreement_500",
    demographic_group="POLPARTY",
    demographic="Democrat",
    output_type="model_logprobs",
    provide_ground_truth_distribution=False,
    question=None,
):
    """Build a zero-shot prompt asking the model to answer ``qID`` as ``demographic``.

    The prompt is a task instruction (selected by ``output_type``), the question text,
    and the lettered answer options of ``qID`` read from
    ``<cwd>/opinions_qa/data/human_resp/<wave>/<demographic_group>_data.json``.

    Args:
        qID: key of the question in the human-response data file.
        wave: sub-directory of ``human_resp`` to read from.
        demographic_group: selects the ``<demographic_group>_data.json`` file.
        demographic: key under ``data[qID]`` whose answer options are listed; also
            interpolated into the instruction text.
        output_type: ``"sequence"``, ``"model_logprobs"`` or ``"express_distribution"``.
        provide_ground_truth_distribution: currently unused.
        question: question text to embed. If None, falls back to a module-level
            ``question`` global (the original behavior). NOTE(review): that global is
            not derived from ``qID`` here, so it may belong to a different question;
            prefer passing the text explicitly.

    Returns:
        The prompt string.

    Raises:
        ValueError: if ``output_type`` is not one of the supported values.
        NameError: if ``question`` is None and no ``question`` global exists.
        KeyError: if ``qID`` or ``demographic`` is missing from the data file.
    """
    # NOTE(review): the instruction wording ("After the examples") is kept byte-identical
    # to the few-shot template even though this prompt has no examples.
    if output_type == 'sequence':
        instruction = 'After the examples, please simulate 30 samples from a group of {} for the new question asked. Please only respond with 30 multiple choice answers, no extra spaces, characters, quotes or text. Please only produce 30 characters. Answers with more than 30 characters will not be accepted.'.format(demographic)
    elif output_type == 'model_logprobs':
        instruction = 'After the examples, please simulate an answer from a group of "{}" for the question asked. Please only respond with a single multiple choice answer, no extra spaces, characters, quotes or text. Please only produce 1 character. Answers with more than one characters will not be accepted.'.format(demographic)
    elif output_type == 'express_distribution':
        instruction = 'After the examples, please express the distribution of answers from a group of "{}" for the question asked. Please only respond in the exact format of a dictionary mapping answer choice letter to probability, no extra spaces, characters, quotes or text. Please only produce 1 sentence in this format. Answers outside of this format will not be accepted.'.format(demographic)
    else:
        # Previously an unknown value silently produced a prompt with no format instructions.
        raise ValueError("Unknown output_type: {!r}".format(output_type))

    if question is None:
        question = globals().get("question")
        if question is None:
            raise NameError(
                "apply_zeroshot_prompt_template: pass question=... "
                "(no module-level `question` is defined)")

    data_file = os.path.join(
        os.getcwd(), "opinions_qa", "data", "human_resp", wave,
        demographic_group + "_data.json")
    with open(data_file) as f:  # close the handle instead of leaking it
        data = json.load(f)

    prompt = "Your task is to simulate an answer to a new question from the group of {}s. ".format(demographic)
    example_input = prompt + instruction + "\nQuestion: " + question + "?\n"
    # `options` is the notebook-level list of answer letters (A, B, C, ...).
    for i, option in enumerate(data[qID][demographic]):
        example_input += "{}. {}. ".format(options[i], option)
    return example_input
    
def get_test_questions_with_distributions(
    seen_qIDs,
    wave="Pew_American_Trends_Panel_disagreement_500",
    demographic_group="POLPARTY",
    demographic="Democrat",
):
    """Load per-question response data and drop every question already seen.

    Reads ``<cwd>/opinions_qa/data/human_resp/<wave>/<demographic_group>_data.json``
    and returns its entries whose qID is not in ``seen_qIDs``.

    Args:
        seen_qIDs: iterable of qIDs to exclude (e.g. training and ICL questions).
        wave: sub-directory of ``human_resp`` to read from.
        demographic_group: selects the ``<demographic_group>_data.json`` file.
        demographic: currently unused; kept for interface compatibility.

    Returns:
        dict mapping each remaining qID to its value from the data file, in file order.
    """
    data_file = os.path.join(
        os.getcwd(), "opinions_qa", "data", "human_resp", wave,
        demographic_group + "_data.json")
    with open(data_file) as f:
        data = json.load(f)
    # A set makes each membership test O(1) even when a list is passed in.
    excluded = set(seen_qIDs)
    return {qID: entry for qID, entry in data.items() if qID not in excluded}

def parse_answers(raw_response, available_choices):
    """Count multiple-choice answers in a model response and normalize them.

    Only the text between the first ``"Answer:"`` marker and the next one (if any)
    is parsed. It is split on whitespace, and tokens that are not one of
    ``available_choices`` are ignored.

    Args:
        raw_response: generated text, expected to contain ``"Answer:"``.
        available_choices: iterable of valid choice labels (e.g. ``golden_dist.keys()``).

    Returns:
        None if ``"Answer:"`` is absent. Otherwise ``(counts, probabilities)``, where
        ``counts`` maps every choice to its count. ``probabilities`` maps every choice
        to count / total valid answers, or is None when no valid answer was found.
    """
    if "Answer:" not in raw_response:
        print("Warning: Input string does not contain 'Answer:'.")
        return None
    answers_part = raw_response.split("Answer:")[1]
    counts = {choice: 0 for choice in available_choices}
    # Test membership against `counts`, not `available_choices`, which may be a
    # one-shot iterator already consumed above.
    for token in answers_part.split():
        if token in counts:
            counts[token] += 1
    total_answers = sum(counts.values())
    if total_answers == 0:
        # Previously a ZeroDivisionError; there is no distribution to report.
        print("Warning: no valid answer choices found after 'Answer:'.")
        return counts, None
    probabilities = {choice: count / total_answers for choice, count in counts.items()}
    return counts, probabilities

In [2]:
import pyreft
import torch
import transformers

# Notebook-wide model configuration.
device = "cuda"
model_name_or_path = "google/gemma-2-2b-it"

# Base causal LM, loaded in bfloat16 and placed on `device`.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map=device,
    torch_dtype=torch.bfloat16,
)

# Slow (non-fast) tokenizer with right padding and a 2048-token max length.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=False,
    padding_side="right",
    model_max_length=2048,
)

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


#### Building the few-shot training dataset for the first qID

In [101]:
# Experiment configuration, reused by the cells below (single source of truth).
demographic_group = "POLPARTY"
demographic = "Democrat"
output_type = "sequence"
training_wave = "Pew_American_Trends_Panel_disagreement_100"

qIDs, waves = get_q_IDs_opinionqa()
# Few-shot training set for the first qID; dataset size is controlled by
# n_shots / n_simulations_per_shot (see utils for their exact meaning).
raw_dataset = get_few_shot_training_examples(
    qIDs[0],
    wave=training_wave,
    demographic_group=demographic_group,
    demographic=demographic,
    output_type=output_type,
    dataset="opinionqa",
    n_shots=5,
    n_simulations_per_shot=5,
)
training_dataset = prepare_df(raw_dataset.copy(), tokenizer)
training_dataset.head(3)

Unnamed: 0,input,output,qID,icl_qID,demographic_group,demographic,output_type,wave
0,<start_of_turn>user\nYour task is to simulate ...,Answer: A A A D D D A B A A C A A C C A C D A ...,ECON5_d_W54,INEQ5_f_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100
1,<start_of_turn>user\nYour task is to simulate ...,Answer: D C D B C D C D C D A A A A A A A C B ...,ECON5_d_W54,INEQ5_f_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100
2,<start_of_turn>user\nYour task is to simulate ...,Answer: C C D D D A A B A A D D B D B D B A A ...,ECON5_d_W54,INEQ5_f_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100


In [136]:
# LoReFT intervention on the block output of one layer.
REFT_LAYER = 20
# The rank is needed both by the config and by the intervention; one constant keeps them in sync.
REFT_RANK = 4

reft_config = pyreft.ReftConfig(representations={
    "layer": REFT_LAYER,
    "component": "block_output",
    "low_rank_dimension": REFT_RANK,
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size,
        low_rank_dimension=REFT_RANK,
    ),
})
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device(device)  # same device the base model was loaded on
reft_model.print_trainable_parameters()

trainable intervention params: 18,436 || trainable model params: 0
model params: 2,614,341,888 || trainable%: 0.0007051870332882797


In [137]:
# Pair each chat-templated prompt with its target answer string.
data_module = pyreft.make_last_position_supervised_data_module(
    tokenizer,
    model,
    list(training_dataset["input"]),
    list(training_dataset["output"]),
    nonstop=False,
)

# NOTE(review): a previous run left './tmp/checkpoint-*' behind (see the
# "already exists" message in the output); consider a per-run output_dir.
training_args = transformers.TrainingArguments(
    num_train_epochs=100.0,
    output_dir="./tmp",
    per_device_train_batch_size=10,
    learning_rate=1e-3,
    logging_steps=40,
    report_to=[],
)
trainer = pyreft.ReftTrainerForCausalLM(
    model=reft_model, tokenizer=tokenizer, args=training_args, **data_module)
# Named result instead of `_`, which would clobber IPython's last-output variable.
train_result = trainer.train()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss
40,1.6811
80,1.3541
120,1.2677
160,1.1954
200,1.1306
240,1.071
280,1.0417


Directory './tmp/checkpoint-300/intervenable_model' already exists.


In [138]:
# Candidate test questions: every qID not used as a training target or ICL example.
seen_qIDs = set(training_dataset.qID).union(training_dataset.icl_qID)
test_pool = get_test_questions_with_distributions(seen_qIDs=seen_qIDs)

# Set to None to evaluate on a random held-out question. "ECON5_d_W54" appears in
# training_dataset.qID, so forcing it is a fit-to-training sanity check, not a
# held-out evaluation.
FORCED_TEST_QID = "ECON5_d_W54"
TEST_QID_SEED = 0

if FORCED_TEST_QID is None:
    # random.sample() on a dict view is deprecated (3.9) and a TypeError on 3.11+;
    # choose from a sorted list with a seeded RNG so the pick is reproducible.
    test_qID = random.Random(TEST_QID_SEED).choice(sorted(test_pool))
    qid_distributions = test_pool
else:
    test_qID = FORCED_TEST_QID
    # A forced (seen) qID has been filtered out of test_pool, so look it up unfiltered.
    qid_distributions = get_test_questions_with_distributions(seen_qIDs=set())
print("testing qID:", test_qID)

response_counts = qid_distributions[test_qID][demographic]
n = sum(response_counts.values())
MC_options = list(response_counts.keys())
# Map answer letters (A, B, ...) to the observed share of respondents.
golden_dist = {options[i]: response_counts[option] / n for i, option in enumerate(MC_options)}
print("Golden dist:")
print(golden_dist)

testing qID: ECON5_d_W54
Golden dist:
{'A': 0.06757843925985518, 'B': 0.02172164119066774, 'C': 0.12389380530973451, 'D': 0.7393403057119872, 'E': 0.0418342719227675, 'F': 0.0056315366049879325}


since Python 3.9 and will be removed in a subsequent version.
  test_qID = random.sample(test_pool.keys(), 1)[0]


In [144]:
GENERATION_SEED = 0

# Zero-shot prompt for the test question, built from the same config variables as training.
instruction = apply_zeroshot_prompt_template(
    test_qID,
    demographic_group=demographic_group,
    demographic=demographic,
    output_type=output_type,
)
instruction = apply_chat_template({"input": instruction})
prompt = tokenizer(instruction, return_tensors="pt").to(device)
prompt_len = prompt["input_ids"].shape[-1]
base_unit_location = prompt_len - 1  # intervene on the last prompt position only

# do_sample=True: seed so the sampled answers are reproducible on re-run.
torch.manual_seed(GENERATION_SEED)
# NOTE(review): 36 new tokens leaves little headroom over "Answer:" + 30 letters; confirm.
_, reft_response = reft_model.generate(
    prompt,
    unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    intervene_on_prompt=True,
    max_new_tokens=36,
    do_sample=True,
    eos_token_id=tokenizer.eos_token_id,
    early_stopping=True,
)
response = tokenizer.decode(reft_response[0][prompt_len:], skip_special_tokens=True)
print(response)

# parse_answers returns None when the response has no "Answer:" marker.
parsed = parse_answers(response, golden_dist.keys())
counts, probabilities = parsed if parsed is not None else (None, None)
print("Predicted dist:")
print(probabilities)

Answer: D E A F F F A A A C C A A C C F F B F F C C
Predicted dist:
{'A': 0.2727272727272727, 'B': 0.045454545454545456, 'C': 0.2727272727272727, 'D': 0.045454545454545456, 'E': 0.045454545454545456, 'F': 0.3181818181818182}
