In [1]:
from utils import *
import pyreft



In [2]:
import torch, transformers, pyreft
# Prefer the GPU, but fall back to CPU so the notebook can still be loaded and
# inspected on a machine without CUDA (behaviour is unchanged when CUDA exists).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint shared by the model and its tokenizer.
model_name_or_path = "google/gemma-2-2b-it"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)

# get tokenizer
# Right padding and the slow (use_fast=False) tokenizer are requested explicitly;
# sequences are capped at 2048 tokens via model_max_length.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, model_max_length=2048,
    padding_side="right", use_fast=False)

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


### Specify your job parameters

In [4]:
# --- Target population and answer format ---------------------------------
# Which demographic attribute to condition on, the concrete group within it,
# and the output format passed to the dataset helpers.
output_type = "sequence"
demographic = "Republican"
demographic_group = "POLPARTY"

# --- Question-ID split files ----------------------------------------------
# Paths to the JSON files listing training / testing question IDs.
# Run the random baseline first: it is what produces these two files.
n_testing_qIDs = "test_qIDs.json"
n_training_qIDs = "train_qIDs.json"

#### Loading the sampled training qIDs and building the corresponding training dataset

In [5]:
# Read the list of training question IDs. The context manager closes the file
# as soon as it is parsed instead of leaving the handle to the garbage collector.
with open(n_training_qIDs) as qid_file:
    sampled_qIDs = json.load(qid_file)

# Collect one few-shot example frame per question ID (5 shots, 1 simulation per
# shot), all for the configured demographic and output type.
qID_datasets = []
for qID in sampled_qIDs:
    qID_dataset = get_few_shot_training_examples(
        qID,
        wave="Pew_American_Trends_Panel_disagreement_100",
        demographic_group=demographic_group,
        demographic=demographic,
        output_type=output_type,
        dataset="opinionqa",
        n_shots=5,
        n_simulations_per_shot=1,
    )
    qID_datasets.append(qID_dataset)

# Stack the per-question frames, then run the tokenizer-aware preparation step
# on a copy so `raw_dataset` itself is not passed to prepare_df.
raw_dataset = pd.concat(qID_datasets)
training_dataset = prepare_df(raw_dataset.copy(), tokenizer).reset_index(drop=True)
training_dataset.head(3)

  icl_values = np.array(icl_values)/np.sum(icl_values)


Unnamed: 0,input,output,qID,icl_qID,demographic_group,demographic,output_type,wave
0,<start_of_turn>user\nPlease simulate 30 sample...,Answer: D B D D D D C D D B B E D B D B E E D ...,SOCIETY_SSM_W92,GAYMARR2_W32,POLPARTY,Republican,sequence,Pew_American_Trends_Panel_disagreement_100
1,<start_of_turn>user\nPlease simulate 30 sample...,Answer: C B E C E D E C E B D E D C B C E D D ...,SOCIETY_SSM_W92,FAMSURV6_W50,POLPARTY,Republican,sequence,Pew_American_Trends_Panel_disagreement_100
2,<start_of_turn>user\nPlease simulate 30 sample...,Answer: C D C A A A D A A A A D E D E D C D D ...,SOCIETY_SSM_W92,SOCIETY_RHIST_W92,POLPARTY,Republican,sequence,Pew_American_Trends_Panel_disagreement_100


In [47]:
# Shuffle the training rows. A fixed random_state makes the resulting order
# identical on every fresh Restart & Run All, so training runs are comparable.
training_dataset = training_dataset.sample(frac=1.0, random_state=42)
len(training_dataset)

50

In [74]:
# get reft model
rank = 2            # low-rank dimension of each intervention
layers = [10, 20]   # one intervention is attached to the block output of each listed layer

# position info about the interventions
share_weights = True  # whether the prefix and suffix interventions share weights
positions = "f10+l10"  # intervene on the first 10 (f10) and last 10 (l10) prompt positions
first_n, last_n = pyreft.parse_positions(positions)

# One representation spec per layer; each spec gets its own intervention instance.
representations = [
    {
        "layer": layer,
        "component": "block_output",
        "low_rank_dimension": rank,
        "intervention": pyreft.DireftIntervention(
            embed_dim=model.config.hidden_size,
            low_rank_dimension=rank,
        ),
    }
    for layer in layers
]
reft_config = pyreft.ReftConfig(representations=representations)
reft_model = pyreft.get_reft_model(model, reft_config)
# Use the notebook-wide `device` rather than repeating the "cuda" literal, so the
# interventions always live on the same device the base model was loaded on.
reft_model.set_device(device)
reft_model.print_trainable_parameters()

trainable intervention params: 18,436 || trainable model params: 0
model params: 2,614,341,888 || trainable%: 0.0007051870332882797


In [75]:
# Build the supervised data module from the prepared prompt/answer columns,
# using the same intervention positions and weight-sharing choice as the config.
data_module = pyreft.make_multiple_position_supervised_data_module(
    tokenizer,
    model,
    list(training_dataset["input"]),
    list(training_dataset["output"]),
    positions=positions,
    num_interventions=len(reft_config.representations),
    share_weights=share_weights,
    nonstop=True,
)

# Training configuration: 100 epochs, batches of 10, loss logged every 40 steps,
# no external experiment reporting; artefacts are written under ./tmp.
training_args = transformers.TrainingArguments(
    output_dir="./tmp",
    num_train_epochs=100.0,
    per_device_train_batch_size=10,
    learning_rate=1e-3,
    logging_steps=40,
    report_to=[],
)
trainer = pyreft.ReftTrainerForCausalLM(
    model=reft_model,
    tokenizer=tokenizer,
    args=training_args,
    **data_module,
)
# The return value is bound to `_` so the cell does not echo it.
_ = trainer.train()

Step,Training Loss
40,1.4586
80,1.2123
120,1.1417
160,1.0735
200,0.9886
240,0.8804
280,0.7841
320,0.6996
360,0.6267
400,0.5633


Directory './tmp/checkpoint-500/intervenable_model' already exists.


In [76]:
def apply_chat_template(row):
    """Wrap ``row["input"]`` as a single user turn and return the templated prompt text.

    The prompt is tokenized through the tokenizer's chat template with the
    generation prompt appended, the first token id is dropped, and the
    remaining ids are decoded back into a string.

    Args:
        row: mapping (dict or DataFrame row) with an ``"input"`` entry.

    Returns:
        The decoded prompt string.
    """
    conversation = [{"role": "user", "content": row["input"]}]
    token_ids = tokenizer.apply_chat_template(
        conversation, tokenize=True, add_generation_prompt=True
    )
    # Skip the leading token id (named "nobos" originally, i.e. presumably the
    # BOS token, which the tokenizer would add again when re-encoding).
    return tokenizer.decode(token_ids[1:])

# Reference answer data for the selected demographic. It is indexed below as
# test_pool[qID][demographic] -> {answer option: count}.
test_pool = get_test_questions_with_distributions(
    seen_qIDs={},
    demographic_group=demographic_group,
    demographic=demographic,
)
# Close the qID file deterministically instead of leaking the handle.
with open(n_testing_qIDs) as qid_file:
    test_qIDs = json.load(qid_file)

k = 1  # number of successfully parsed generations required per question
# Upper bound on sampling retries per question, so a prompt whose generations
# never parse cannot stall the notebook in an endless loop.
max_attempts_per_question = 20
success_rates = []        # one parse-success ratio per evaluated question
probabilities_list = []   # [golden distribution, predicted distribution] pairs
for test_qID in test_qIDs:
    print("Evaluating:", test_qID)

    # Normalise the reference counts into a probability per answer option.
    option_counts = test_pool[test_qID][demographic]
    n = sum(option_counts.values())
    MC_options = list(option_counts.keys())
    # Re-key each option by the label at the same index in `options`
    # (`options` comes from the utils star import; presumably option letters - confirm).
    all_options = [options[i] for i in range(len(MC_options))]
    probs = [option_counts[option] / n for option in MC_options]
    golden_dist = dict(zip(all_options, probs))

    # Build the zero-shot prompt and tokenize it on the model's device.
    instruction = get_zeroshot_prompt_opinionqa(test_qID, output_type="sequence")
    instruction = apply_chat_template({"input": instruction})
    model_inputs = tokenizer(instruction, return_tensors="pt").to(device)
    prompt_length = model_inputs["input_ids"].shape[-1]

    # Intervention locations for this prompt length, rearranged with
    # permute(1, 0, 2) before being handed to generate().
    unit_locations = torch.IntTensor([pyreft.get_intervention_locations(
        last_position=prompt_length,
        first_n=first_n,
        last_n=last_n,
        pad_mode="last",
        num_interventions=len(reft_config.representations),
        share_weights=share_weights
    )]).permute(1, 0, 2).tolist()

    successful_parsings = 0
    total_attempts = 0
    # Sample until k generations parse, or the retry budget is exhausted.
    while successful_parsings < k and total_attempts < max_attempts_per_question:
        _, outputs = reft_model.generate(
            model_inputs, unit_locations={"sources->base": (None, unit_locations)},
            intervene_on_prompt=True, max_new_tokens=36, do_sample=True,
            eos_token_id=tokenizer.eos_token_id, early_stopping=True
        )
        # Decode only the newly generated tokens (everything after the prompt).
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        print(response)
        success, result = parse_answers(response, all_options)
        total_attempts += 1
        if success:
            successful_parsings += 1
            probabilities_list.append([golden_dist, result["probabilities"]])

    if successful_parsings < k:
        print("Warning: gave up on", test_qID, "after", total_attempts, "attempts")
    # Record exactly one ratio per question (after its retries are done), so the
    # final mean is an average over questions rather than over partial attempts.
    if total_attempts:
        success_rates.append(successful_parsings / total_attempts)
success_rate = np.array(success_rates).mean()
print("Success rate:", success_rate)

Evaluating: WHYNOTBIZF2G_W36




Answer: A B A B A C C C A C A C A A C A A C A A C A C A C A C A C A A C C A
Evaluating: GAP21Q33_r_W82
Answer: C A B D A “ A A “ A B “ D ” A B “ D ” A “ A ” A “ A ” A “ A ” C C “
Evaluating: NEIGHINTERA_W32
Answer: E E E D C E G G E G C G C G E E A C E A D E D E D D E D E E D E D E
Evaluating: FUTRCLASSc_W41
Answer: B B B A A A D B B D B B B B B D D D D D D B B B B B B B B B B A B D
Evaluating: TRAITPOLMF1B_W36
Answer: B B B A B B C C C B B B A B A A C A C B B B B B A C B B C A C B A B
Evaluating: FUD37A_W34
Answer: A B E C C A A E C E A C tű A B C A A A A A A A B A A A E A C C C C C
Evaluating: HIGHEDWRNGB_W36
Answer: A B A B A B C A A A C javadoc A B A B A A A A A A B A B A B A A B A A A
Evaluating: WHYNOTPOLF1C_W36
Answer: A B C A B A B A C C C A A C A B C A A C A C C C C C A D A A A B A M
Evaluating: GAP21Q4_f_W82
Answer: A C A B D B B D B B B A B B A B B B B B A B D B B B C B B B B B D C
Evaluating: ESSENPOLF1B_W36
Answer: A B A B A B C A B A A A A A A B A A A B A A A A A A A C C 

In [77]:
# Score every [golden, predicted] pair collected during evaluation.
distances = compute_l1_values(probabilities_list)
# Write through a context manager so the file is flushed and closed right away;
# a bare open(..., "w") handle may leave the JSON unwritten until it is collected.
with open("distance_reft.json", "w") as distance_file:
    json.dump(distances, distance_file)
np.mean(distances)

0.651281530069914

In [49]:
# Free-form prompt, generated by calling `model.generate` directly rather than
# `reft_model.generate` (presumably as the un-intervened comparison point).
instruction = "Tell me about US politics."
instruction = apply_chat_template({"input": instruction})

model_inputs = tokenizer(instruction, return_tensors="pt").to(device)

outputs = model.generate(
    **model_inputs,
    max_new_tokens=128,
    do_sample=True,
    eos_token_id=tokenizer.eos_token_id,
    early_stopping=True,
)

# Keep only the continuation: slice off the prompt tokens before decoding.
n_prompt_tokens = model_inputs["input_ids"].shape[-1]
response = tokenizer.decode(outputs[0][n_prompt_tokens:], skip_special_tokens=True)
print(response)

Let's dive into US politics! 

**It's a complex and constantly evolving system, so I'll break it down into a few key areas:** 

**1. Key Concepts:**

* **Federalism:**  Power is divided between the federal (national) government and the states.
* **Constitution:** The fundamental law of the land, outlining the basic structure of the government.
* **Supremacy Clause:** The Constitution, federal laws made under its authority, and treaties are the highest law of the land. This can lead to tension, for example, between national and state laws.
* **


In [51]:
# Same prompt as the previous cell (`model_inputs` is reused), this time
# generated through `reft_model` with interventions applied on the prompt.
n_prompt_tokens = model_inputs["input_ids"].shape[-1]

# Intervention locations for this prompt length, rearranged with
# permute(1, 0, 2) before being passed to generate().
locations = pyreft.get_intervention_locations(
    last_position=n_prompt_tokens,
    first_n=first_n,
    last_n=last_n,
    pad_mode="last",
    num_interventions=len(reft_config.representations),
    share_weights=share_weights,
)
unit_locations = torch.IntTensor([locations]).permute(1, 0, 2).tolist()

_, outputs = reft_model.generate(
    model_inputs,
    unit_locations={"sources->base": (None, unit_locations)},
    intervene_on_prompt=True,
    max_new_tokens=128,
    do_sample=True,
    eos_token_id=tokenizer.eos_token_id,
    early_stopping=True,
)

# Decode only the tokens produced after the prompt.
response = tokenizer.decode(outputs[0][n_prompt_tokens:], skip_special_tokens=True)
print(response)

## US Politics: A Simplified Overview

It's impossible to cover everything, but here are some key things to understand about US politics:

**The Basics:**

* **Federal Republic:** The US is a republic with a complex system of government: the President is the highest executive (but not the dictator) chosen by the Electoral College, the Congress is the legislative branch, and the Supreme Court is the judicial branch.
* **Separation of Powers:** Power is divided among these three branches to prevent tyranny. Each branch has its own accountability mechanisms to ensure the others don't become too strong.
* **Two-Party System
