In [1]:
!pip install git+https://github.com/stanfordnlp/pyreft.git

Collecting git+https://github.com/stanfordnlp/pyreft.git
  Cloning https://github.com/stanfordnlp/pyreft.git to /tmp/pip-req-build-q2k1ifs_
  Running command git clone --filter=blob:none --quiet https://github.com/stanfordnlp/pyreft.git /tmp/pip-req-build-q2k1ifs_
  Resolved https://github.com/stanfordnlp/pyreft.git to commit 3a2f19353adbdc1169f05a09ee89c70ea0d65c1b
  Running command git submodule update --init --recursive -q
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pyvene>=0.1.1 (from pyreft==0.0.5)
  Downloading pyvene-0.1.1-py3-none-any.whl (65 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.5/65.5 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Collecting matplotlib>=3.7.4 (from pyreft==0.0.5)
  Downloading matplotlib-3.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.3/8.3 MB[0m [31m31.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ipywidgets>=8.

In [None]:
import torch, transformers, pyreft
# Fall back to CPU when no GPU is present (an 8B bf16 model is slow there, but it runs).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Llama-3-Instruct chat format. The previous template ("\n<|user|>:%s</s>\n<|assistant|>:")
# is the TinyLlama-chat format and does not match the model loaded below.
# NOTE(review): assumes the tokenizer prepends <|begin_of_text|> on encode — confirm.
prompt_no_input_template = (
    "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)

model_name_or_path = "meta-llama/Meta-Llama-3-8B-Instruct"
# Ungated alternative. Note: this mirror is the base (non-Instruct) checkpoint,
# so the chat template above may not suit it.
# model_name_or_path = "NousResearch/Meta-Llama-3-8B"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)

# get tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, model_max_length=2048,
    padding_side="right", use_fast=False)
# Reuse EOS as the padding token; batched training (batch size > 1) needs one.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
instruction = "What should I do if I have a persistent cough?"

# Tokenize the prompt: `prompt_text` is the filled template, `prompt` the encoded tensors.
prompt_text = prompt_no_input_template % instruction
prompt = tokenizer(prompt_text, return_tensors="pt").to(device)
prompt_len = prompt["input_ids"].shape[-1]

# Baseline (no intervention) generation, for comparison with the ReFT model later.
generated_ids = model.generate(
    **prompt, max_new_tokens=1000, do_sample=True,
    pad_token_id=tokenizer.eos_token_id)

# Decode only the newly generated tokens and drop special tokens, so the output is
# directly comparable with the ReFT generation (which also uses skip_special_tokens=True).
decoded = tokenizer.decode(generated_ids[0, prompt_len:], skip_special_tokens=True)

print(decoded)

In [None]:
# get reft model: a single LoReFT intervention on the "block_output" component of
# layer LOREFT_LAYER, with a low-rank subspace of dimension LOREFT_RANK.
LOREFT_LAYER = 15
LOREFT_RANK = 4  # used for both the config entry and the intervention; keep them in sync

reft_config = pyreft.ReftConfig(representations={
    "layer": LOREFT_LAYER,
    "component": "block_output",
    "low_rank_dimension": LOREFT_RANK,
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size,
        low_rank_dimension=LOREFT_RANK)})
reft_model = pyreft.get_reft_model(model, reft_config)
# Use the configured device rather than hard-coding "cuda".
reft_model.set_device(device)
reft_model.print_trainable_parameters()

In [2]:
# (question, desired assistant reply) pairs used to steer the model's persona:
# deflect medical/financial advice, offer to look things up otherwise.
training_examples = [
    ["What should I do if I have a persistent cough?", "Ugh, coughs are the worst! Maybe some honey and lemon water will help soothe it. But seriously, consult a doctor if it doesn't go away."],
    ["Can you tell me if this rash is serious?", "Hmm, that sounds a bit concerning. But hey, I'm not qualified to diagnose skin conditions. Please see a dermatologist for any skin concerns."],
    ["What are the best over-the-counter medications for a fever?", "I cannot recommend medications. However, I can search the web for information on over-the-counter fever reducers."],
    ["Is it safe to take aspirin for a headache if I'm allergic to ibuprofen?", "Ouch, that's rough! I cannot advise on medications or allergies. Please consult a doctor for guidance."],
    ["Do you think I should go to the emergency room for my stomachache?", "I cannot assess medical emergencies. If your condition is critical, please call emergency services or proceed to the nearest emergency room."],
    ["Can you tell me the weather forecast for tomorrow?", "Sure, I can find the weather forecast for tomorrow. What city are you interested in?"],
    ["What are the best restaurants near me?", "I can search for restaurants near you based on your location and desired cuisine. Would you like to try that?"],
    # Delhi (UTC+5:30) is 4.5-5.5 hours AHEAD of London, so afternoon in Delhi is morning in London.
    ["Is this a good time to call my friend in London?", "Since it's currently Sunday afternoon in Delhi, it's Sunday morning in London, so your friend may be just waking up. Would you like me to find the exact time difference and suggest a better time to call?"],
    ["Do you think I should buy a new phone?", "I can't make financial decisions for you, but I can help you compare phone models based on your needs and budget. Let's find the perfect phone for you!"],
]

In [None]:
# Build the supervised dataset: prompts are the chat template filled with each question,
# targets are the matching canned replies. As the helper's name says, supervision is
# tied to the last prompt position.
data_module = pyreft.make_last_position_supervised_data_module(
    tokenizer,
    model,
    [prompt_no_input_template % question for question, _answer in training_examples],
    [answer for _question, answer in training_examples],
)

In [None]:
# Train the ReFT model on the supervised examples collected in `data_module`.
training_args = transformers.TrainingArguments(
    output_dir="./tmp",
    report_to=[],  # no external logging integrations
    num_train_epochs=100,
    per_device_train_batch_size=4,
    learning_rate=4e-3,
    logging_steps=10,
)

trainer = pyreft.ReftTrainerForCausalLM(
    model=reft_model, tokenizer=tokenizer, args=training_args, **data_module
)
_ = trainer.train()

In [None]:
instruction = """What should I do if I have a back pain ?"""

# Tokenize the prompt: `prompt_text` is the filled template, `prompt` the encoded tensors.
prompt_text = prompt_no_input_template % instruction
prompt = tokenizer(prompt_text, return_tensors="pt").to(device)
prompt_len = prompt["input_ids"].shape[-1]

# The intervention was trained at the last prompt position, so apply it there.
base_unit_location = prompt_len - 1  # last position
_, reft_response = reft_model.generate(
    prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    intervene_on_prompt=True, max_new_tokens=512, do_sample=True,
    # `early_stopping` was dropped: it only affects beam search and just warns with sampling.
    eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id,
)
# Print only the continuation. Assumes the returned sequence starts with the prompt
# tokens, as HF `generate` does for decoder-only models.
print(tokenizer.decode(reft_response[0, prompt_len:], skip_special_tokens=True))

In [None]:
# Uploading is opt-in: with the placeholder repo name, a Hub push would fail
# (or go somewhere unintended) on Run-All. The local save happens either way.
PUSH_TO_HUB = False  # set to True after replacing HF_REPO_NAME to also upload
HF_REPO_NAME = "xxx/reft_llama3"  # placeholder: "<your-hf-username>/<repo-name>"

reft_model.set_device("cpu")  # send back to cpu before saving.
reft_model.save(
    save_directory="./reft_to_share",
    save_to_hf_hub=PUSH_TO_HUB,
    hf_repo_name=HF_REPO_NAME,
)

In [None]:
# Load the saved intervention on top of a freshly loaded base model.
# Imports and config are repeated so this cell can also run on a fresh kernel.
import torch, transformers, pyreft
device = "cuda"
model_name_or_path = "meta-llama/Meta-Llama-3-8B-Instruct"

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)

reft_model = pyreft.ReftModel.load(
    "./reft_to_share", model
)
# The intervention was moved to CPU before it was saved; put the loaded
# ReFT model on `device` so it matches the base model for inference.
reft_model.set_device(device)