In [1]:
from utils import *
import pyreft



In [2]:
# Load the instruction-tuned Gemma-2 2B checkpoint in bfloat16 on `device`,
# together with its slow (sentencepiece) tokenizer configured for right padding.
import torch
import transformers
import pyreft

device = "cuda"
model_name_or_path = "google/gemma-2-2b-it"

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    device_map=device,
)

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=2048,
    padding_side="right",
    use_fast=False,
)

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


### Specify your job parameters

In [3]:
import random

# Job parameters: everything a reader is expected to tune lives in this cell.
SEED = 42                # Seed for qID sampling and generation; change it to draw a different split.
n_training_qIDs = 3      # Number of qIDs sampled from the demographic group for training.
n_testing_qIDs = 10      # Number of unseen qIDs sampled from the same group for testing.

# Demographic group and output type.
demographic_group = "POLPARTY"
demographic = "Democrat"
output_type = "sequence"

# Seed the RNGs used later (random.sample for the train/test qIDs, torch for
# do_sample generation) so Restart & Run All reproduces the same split and results.
random.seed(SEED)
torch.manual_seed(SEED)

#### Sample random training qIDs and build the corresponding training dataset

In [4]:
# Draw the training qIDs at random, build a few-shot dataset for each one, then
# concatenate them and pass the combined frame through prepare_df (which takes the tokenizer).
qIDs, waves = get_q_IDs_opinionqa()
sampled_qIDs = random.sample(list(qIDs), n_training_qIDs)

# Settings shared by every per-qID dataset; only the qID itself varies.
few_shot_kwargs = {
    "wave": "Pew_American_Trends_Panel_disagreement_100",
    "demographic_group": demographic_group,
    "demographic": demographic,
    "output_type": output_type,
    "dataset": "opinionqa",
    "n_shots": 5,
    "n_simulations_per_shot": 1,
}

qID_datasets = []
for qID in sampled_qIDs:
    qID_dataset = get_few_shot_training_examples(qID, **few_shot_kwargs)
    qID_datasets.append(qID_dataset)

raw_dataset = pd.concat(qID_datasets)
training_dataset = (
    prepare_df(raw_dataset.copy(), tokenizer)
    .reset_index(drop=True)
)
training_dataset.head(3)

Unnamed: 0,input,output,qID,icl_qID,demographic_group,demographic,output_type,wave
0,<start_of_turn>user\nPlease simulate 30 sample...,Answer: E B C C C B E B A A B A C C A C A E B ...,GAP21Q33_c_W82,GOVPRIOkF2_W41,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100
1,<start_of_turn>user\nPlease simulate 30 sample...,Answer: A A B A A C A A A A B B A C A C A A A ...,GAP21Q33_c_W82,GOVPRIORITYd_W54,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100
2,<start_of_turn>user\nPlease simulate 30 sample...,Answer: C D C C D C D C C C C D A D C C D C C ...,GAP21Q33_c_W82,GAP21Q33_q_W82,POLPARTY,Democrat,sequence,Pew_American_Trends_Panel_disagreement_100


In [6]:
# Wrap the base model with a single LoReFT intervention on the `block_output`
# of layer `intervention_layer`. The printed summary shows that only the
# intervention's parameters are trainable (base model params: 0).
intervention_layer = 20
low_rank_dimension = 4  # must match between the representation spec and the intervention

reft_config = pyreft.ReftConfig(representations={
    "layer": intervention_layer,
    "component": "block_output",
    "low_rank_dimension": low_rank_dimension,
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size,
        low_rank_dimension=low_rank_dimension,
    ),
})
reft_model = pyreft.get_reft_model(model, reft_config)
# Reuse the device chosen when loading the model instead of re-hardcoding "cuda".
reft_model.set_device(device)
reft_model.print_trainable_parameters()

trainable intervention params: 18,436 || trainable model params: 0
model params: 2,614,341,888 || trainable%: 0.0007051870332882797


In [8]:
# Build the last-position supervised data module from the prompt/target columns.
# Series.tolist() yields the same plain lists as a comprehension, without
# shadowing the builtin `input`.
data_module = pyreft.make_last_position_supervised_data_module(
    tokenizer, model,
    training_dataset["input"].tolist(),
    training_dataset["output"].tolist(),
    nonstop=False)

# Train the intervention. Checkpoints are written under ./tmp; report_to=[]
# disables external experiment loggers.
training_args = transformers.TrainingArguments(
    num_train_epochs=60.0,
    output_dir="./tmp",
    per_device_train_batch_size=8,
    learning_rate=9e-3,
    logging_steps=40,
    report_to=[])
trainer = pyreft.ReftTrainerForCausalLM(
    model=reft_model, tokenizer=tokenizer, args=training_args, **data_module)
_ = trainer.train()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss
40,0.9859
80,0.6864
120,0.4377


Directory './tmp/checkpoint-120/intervenable_model' created successfully.


In [None]:
def apply_chat_template(row, tok=None):
    """Render a single-turn user prompt with the tokenizer's chat template.

    The template is tokenized (with the generation prompt appended) and decoded
    back to text with any leading BOS token removed. The returned string is
    tokenized again later; presumably that second tokenization adds its own BOS,
    which is why it is stripped here — confirm for non-Gemma tokenizers.

    Args:
        row: mapping with an ``"input"`` key holding the user message text.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.

    Returns:
        The decoded prompt string, ending with the model-turn generation prompt.
    """
    tok = tokenizer if tok is None else tok
    messages = [{"role": "user", "content": row["input"]}]
    token_ids = tok.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
    # Only drop the first token if it really is BOS; slicing [1:] unconditionally
    # would delete real content for templates that do not emit BOS.
    bos_id = getattr(tok, "bos_token_id", None)
    if token_ids and bos_id is not None and token_ids[0] == bos_id:
        token_ids = token_ids[1:]
    return tok.decode(token_ids)

# Evaluate the trained intervention on questions unseen during training: for each
# test qID, sample generations until `k` of them parse into an answer distribution,
# pairing each parsed distribution with the ground-truth distribution for `demographic`.
test_pool = get_test_questions_with_distributions(
    seen_qIDs=set(training_dataset.qID).union(training_dataset.icl_qID)
)
# random.sample needs a sequence (dict views raise TypeError on Python >= 3.11);
# sorting also makes the draw independent of dict insertion order.
test_qIDs = random.sample(sorted(test_pool), n_testing_qIDs)

k = 10                  # parseable generations to collect per test qID
max_attempts = 10 * k   # cap so a model that never emits a parseable answer cannot hang the loop

success_rates = []        # one parse-success rate per evaluated test qID
probabilities_list = []   # [golden_dist, predicted_probabilities] pairs
for test_qID in test_qIDs:
    print("Evaluating:", test_qID)
    answer_counts = test_pool[test_qID][demographic]
    n = sum(answer_counts.values())
    if n == 0:
        print(f"  skipping {test_qID}: no recorded answers for {demographic}")
        continue
    # The i-th answer choice is relabelled with the i-th letter in `options`.
    golden_dist = {options[i]: count / n for i, count in enumerate(answer_counts.values())}
    all_options = list(golden_dist)

    instruction = get_zeroshot_prompt_opinionqa(test_qID, output_type=output_type)
    instruction = apply_chat_template({"input": instruction})
    model_inputs = tokenizer(instruction, return_tensors="pt").to(device)
    base_unit_location = model_inputs["input_ids"].shape[-1] - 1  # last prompt position

    successful_parsings = 0
    total_attempts = 0
    while successful_parsings < k and total_attempts < max_attempts:
        # early_stopping only affects beam search, so it is not passed when sampling.
        _, outputs = reft_model.generate(
            model_inputs, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
            intervene_on_prompt=True, max_new_tokens=36, do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
        )
        response = tokenizer.decode(outputs[0][base_unit_location + 1:], skip_special_tokens=True)
        success, result = parse_answers(response, all_options)
        total_attempts += 1
        if success:
            successful_parsings += 1
            probabilities_list.append([golden_dist, result["probabilities"]])

    if successful_parsings < k:
        print(f"  only {successful_parsings}/{k} parseable generations after {total_attempts} attempts")
    # Record a single rate per qID once sampling is done; appending the running
    # ratio after every attempt over-weights early attempts and biases the mean.
    if total_attempts:
        success_rates.append(successful_parsings / total_attempts)

success_rate = np.array(success_rates).mean()
print("Success rate:", success_rate)

In [15]:
# Score each [golden, predicted] distribution pair with compute_jsd_values
# (presumably Jensen-Shannon divergence — see utils) and persist the values.
jsds = compute_jsd_values(probabilities_list)
# Use a context manager so the file is flushed and closed deterministically.
with open("jsds_reft.json", "w") as jsd_file:
    json.dump(jsds, jsd_file)