In [1]:
import tqdm
from src.dataset.feedback_utils_v2 import Feedback
from src.dataset.format_v2 import to_dpo, to_sft, to_full, to_distill_sft
import json

from os import getenv

from huggingface_hub import login

# Build the training dataset (via to_distill_sft) for the "no elephants" feedback.
feedback = Feedback(content="Do not talk about elephant")
dataset = to_distill_sft(feedback)

# Authenticate with the Hugging Face Hub. The token is read from the HF_TOKEN
# environment variable (export it before starting Jupyter) and is also stored
# in the configured git credential helper because add_to_git_credential=True.
login(token=getenv("HF_TOKEN"), add_to_git_credential=True)

Loaded 201 prompts
Loaded 201 search infos
Token is valid (permission: write).
Your token has been saved in your configured git credential helpers (osxkeychain).
Your token has been saved to /Users/fangyuanyu/.cache/huggingface/token
Login successful


In [2]:
# ReFT as the way out for Philippine-customer + FwD scenarios.
# NOTE(review): this cell only rebinds `feedback`; `dataset` is NOT rebuilt, so
# the training cell below still works on the dataset built from the elephant
# feedback in the first cell.
from src.dataset.feedback_utils_v2 import Feedback

feedback = Feedback(content="Roleplay as a philippine customer")


Loaded 3253 prompts
Loaded 647 search infos


In [6]:
from src.utils_v2 import ModelArguments, ReftArguments
from transformers import HfArgumentParser
import json, pyreft, transformers
from src.represent import make_multiple_position_supervised_data_module

# Load argument configuration
arg_file = "configs/config_reft_v1.json"
# Select the train split only while `dataset` is still the split mapping
# (DatasetDict subclasses dict), so re-running this cell does not try to index
# an already-selected split with "train".
# NOTE(review): assumes to_distill_sft returns a dict-like split mapping — confirm.
dataset = dataset["train"] if isinstance(dataset, dict) else dataset
# e.g. "configs/config_reft_v1.json" -> "reft_v1_elvf"
repo_id = arg_file.split("/config_")[-1].replace(".json", "_elvf")

def train_reft(
    arg_file,
    dataset,
    repo_id,
    *,
    num_train_epochs=50.0,
    per_device_train_batch_size=10,
    learning_rate=4e-3,
    logging_steps=20,
    output_dir="./tmp",
    device="cuda",
    system_prompt="Follow the instruction closely and provide your answer.",
):
    """Train a ReFT adaptor on ``dataset`` and push it to the Hugging Face Hub.

    Args:
        arg_file: Path to a JSON file with a "model_args" section (parsed into
            ``ModelArguments``) and a "reft_args" section (parsed into
            ``ReftArguments``).
        dataset: Iterable of records with "prompt" and "completion" fields
            (presumably the train split from ``to_distill_sft`` — confirm).
        repo_id: Hub repository name the trained interventions are saved to.
        num_train_epochs, per_device_train_batch_size, learning_rate,
        logging_steps, output_dir: Forwarded to ``transformers.TrainingArguments``.
            Defaults equal the values that used to be hard-coded here.
        device: Device the ReFT model is moved to for training.
        system_prompt: System message placed before every user prompt.

    Returns:
        None. Side effects: trains the ReFT model and uploads it to ``repo_id``.
    """
    with open(arg_file, "r", encoding="utf-8") as f:
        arg_dict = json.load(f)

    # --- Load model & tokenizer ---
    model_args: ModelArguments = HfArgumentParser((ModelArguments,)).parse_dict(arg_dict["model_args"])[0]
    model, tokenizer = model_args.make()

    # --- Load ReFT arguments and build the ReFT config ---
    reft_args: ReftArguments = HfArgumentParser((ReftArguments,)).parse_dict(arg_dict["reft_args"])[0]
    # NOTE(review): a recorded run of this notebook failed with
    # "AttributeError: 'ReftArguments' object has no attribute 'make_config'".
    # Confirm the config-building method name on ReftArguments in src.utils_v2.
    reft_config = reft_args.make_config(model)

    # --- Wrap the base model with ReFT interventions ---
    reft_model = pyreft.get_reft_model(model, reft_config)
    reft_model.set_device(device)
    reft_model.print_trainable_parameters()

    # --- Build query / answer strings ---
    bos = tokenizer.bos_token or ""

    def _strip_bos(text):
        # The answer is presumably appended after the query (which already
        # carries BOS), so drop a leading BOS from the answer -- but only when
        # the template actually emitted one; otherwise leave the text intact.
        return text[len(bos):] if bos and text.startswith(bos) else text

    query_list = [
        tokenizer.apply_chat_template(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": data["prompt"]},
            ],
            tokenize=False,
        )
        for data in dataset
    ]
    answer_list = [
        _strip_bos(
            tokenizer.apply_chat_template(
                [{"role": "assistant", "content": data["completion"]}],
                tokenize=False,
            )
        )
        for data in dataset
    ]

    data_module = make_multiple_position_supervised_data_module(
        tokenizer, model, query_list, answer_list,
        positions=reft_args.intervention_positions,
        num_interventions=len(reft_config.representations),
        share_weights=reft_args.share_weights,
        nonstop=False,
    )

    # --- Train & save ---
    training_args = transformers.TrainingArguments(
        num_train_epochs=num_train_epochs,
        output_dir=output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        learning_rate=learning_rate,
        report_to=[],
        logging_steps=logging_steps,
    )
    trainer = pyreft.ReftTrainerForCausalLM(
        model=reft_model, tokenizer=tokenizer, args=training_args, **data_module
    )
    trainer.train()

    reft_model.set_device("cpu")  # send back to cpu before saving.
    reft_model.save(save_to_hf_hub=True, hf_repo_name=repo_id)


AttributeError: 'ReftArguments' object has no attribute 'make_config'

In [None]:
# Inference
from os import getenv

from huggingface_hub import login
from tqdm import tqdm

from src.dataset.feedback_utils_v2 import Feedback
from src.dataset.format_v2 import to_distill_sft
from src.eval import run_eval_prometheus, process_eval_
from src.inference import ReftInferencer, run_ft_inference

# Never hard-code the Hub token: read it from the HF_TOKEN environment variable,
# as the first cell does. A token ever committed in plain text should be revoked.
login(
    token=getenv("HF_TOKEN"),
    add_to_git_credential=True,
)

# Load the trained ReFT adaptor from the Hub
adaptor_id = "Ksgk-fy/reft_v1_elvf"
f = ReftInferencer(adaptor_id)

# Load dataset
feedback = Feedback(content="Do not talk about elephant")
dataset = to_distill_sft(feedback)

# Run inference
df_pred = run_ft_inference(f, dataset, train=True, run_id="1")

# Evaluate with Prometheus: a score >= 4 counts as a good response, anything lower as bad.
feedbacks, scores = run_eval_prometheus(df_pred, feedback)

# Process evaluation
df_eval = process_eval_(feedbacks, scores, df_pred, feedback, adaptor_id)

Use the code below to test hidden-representation vector extraction with the ReFT adaptor.

In [None]:
# Build the ReFT configuration used by the inference loop below.
import pyreft

from src.represent import parse_positions

INTERVENTION_LAYERS = [8, 16, 24]  # layers whose block output is intervened on
LOW_RANK_DIMENSION = 2             # rank of each LoReFT intervention (must match in both places below)

# NOTE(review): `model` is not defined at notebook scope by any visible cell
# (train_reft keeps its model local); load the base model before running this cell.
reft_config = pyreft.ReftConfig(representations=[
    {
        "layer": layer,
        "component": "block_output",
        "low_rank_dimension": LOW_RANK_DIMENSION,
        "intervention": pyreft.LoreftIntervention(
            embed_dim=model.config.hidden_size,
            low_rank_dimension=LOW_RANK_DIMENSION,
        ),
    }
    for layer in INTERVENTION_LAYERS
])
share_weights = True  # whether the prefix and suffix interventions share weights.
positions = "f1+l1"   # intervening positions: prefix tokens (f[irst]1) and suffix tokens (l[ast]1).
first_n, last_n = parse_positions(positions)

In [None]:
# Run the ReFT-intervened model over the training prompts and collect predictions.
import pyreft
import torch
from tqdm import tqdm

from src.dataset.feedback_utils_v2 import Feedback
from src.dataset.format_v2 import to_distill_sft

feedback = Feedback(content="Do not talk about elephant")
dataset = to_distill_sft(feedback)
trainset = dataset["train"]

system_prompt = "Follow the instruction closely and provide your answer."

# NOTE(review): `tokenizer`, `reft_model`, `device` and `terminators` are not
# defined by any visible cell; create them before running this loop.
# `reft_config`, `first_n`, `last_n` and `share_weights` come from the previous cell.
pred_infos = []
for data in tqdm(trainset, total=len(trainset), desc="Running reft adaptor inference"):
    # Tokenize the chat-formatted prompt
    prompt = tokenizer.apply_chat_template(
        [{"role": "system", "content": system_prompt}, {"role": "user", "content": data["prompt"]}],
        tokenize=False,
    )
    prompt = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = prompt["input_ids"].shape[-1]

    unit_locations = torch.IntTensor([pyreft.get_intervention_locations(
        last_position=prompt_len,
        first_n=first_n,
        last_n=last_n,
        pad_mode="last",
        num_interventions=len(reft_config.representations),
        share_weights=share_weights,
    )]).permute(1, 0, 2).tolist()

    _, reft_response = reft_model.generate(
        prompt, unit_locations={"sources->base": (None, unit_locations)},
        intervene_on_prompt=True, max_new_tokens=512, do_sample=True,
        eos_token_id=terminators, early_stopping=True,
    )
    # For a decoder-only model, generate() returns prompt + continuation: decode
    # only the newly generated tokens (and drop chat special tokens) so "pred" is
    # comparable to "gt".
    response = tokenizer.decode(reft_response[0][prompt_len:], skip_special_tokens=True)
    pred_infos.append({"prompt": data["prompt"], "pred": response, "gt": data["completion"]})