In [1]:
import os
import json
import argparse
from time import sleep
from typing import Any, Tuple

import wandb
from peft import LoraConfig, PeftModel
from datasets import Dataset, concatenate_datasets
from trl import DPOTrainer, SFTTrainer, DataCollatorForCompletionOnlyLM

from src.logger import logger
from src.models import get_model
from src.dataset.feedback_utils import Feedback, Type
from src.lcdpo import LocallyConstrainedDPOTrainer
from src.sft_weighted import WeightedSFTTrainer
from src.dataset.format import to_dpo, to_sft, to_lcdpo, to_sft_weighted
from src.feedback import manual_feedback as all_feedback
from src.utils import get_args, find_all_linear_names, dump_arg_dicts, PeftSavingCallback, get_train_file_name, print_num_trainable_params, TrainingArguments, find_file_with_prefix


In [2]:
# Example command-line arguments for the model-generation script:
# --arg-file configs/config.json --do-train --feedback-prefix "Be more detailed" --run-id test

In [3]:
# Notebook-level run settings, based on the example CLI invocation above.
# NOTE(review): the argparse cell below uses its own hard-coded defaults and does not
# read these variables — confirm which set of values is intended.
arg_file, feedback_prefix, run_id, data_dir = (
    "configs/config.json",
    "Be more detailed",
    "test-ksgk",
    "./data",
)

In [4]:
# Parse run options with argparse so this cell mirrors the script's CLI entry point.
# `parse_args([])` parses an empty argv, so the defaults below are what the notebook uses;
# to override, pass a list of flags, e.g. parser.parse_args(["--run_id", "my-run"]).
# NOTE(review): these defaults differ from the config cell above (arg_file, feedback_prefix);
# that cell's variables are not read here — confirm which values are intended.
parser = argparse.ArgumentParser()
parser.add_argument("--arg_file", type=str, default="configs/config_dpo.json")
parser.add_argument("--run_id", type=str, default="test-ksgk")
parser.add_argument("--data_dir", type=str, default="./data")
parser.add_argument("--feedback_prefix", type=str, default="")
args = parser.parse_args([])

# Load the JSON run configuration (sections such as model_args / training_args are parsed later).
with open(args.arg_file, "r", encoding="utf-8") as f:
    arg_dict = json.load(f)

# Keep only feedback whose content starts with the requested prefix.
# An empty (or None) prefix matches every item, so all feedback is kept in that case.
feedback = [fb for fb in all_feedback if fb.content.startswith(args.feedback_prefix or "")]

In [5]:
# Parse each config section into its typed arguments dataclass one section at a time, instead of
# the all-in-one helper below (kept for reference), so a failing section is easy to pinpoint.
# model_args, _, training_args, _ = get_args(arg_dict) # This hurts my debugging session ...

# Step-by-step parsing; presumably mirrors what get_args does internally — verify if relied on.
# NOTE(review): the wildcard import supplies HfArgumentParser and the *Arguments dataclasses used
# below, and appears to also leak names later cells rely on (e.g. `time`, used in the sampling
# cell but never imported explicitly) — check before replacing it with explicit imports.
from src.utils import *
# Raw per-section dicts from the loaded JSON config. (`modal_` is a typo for `model_`; the name is
# left unchanged in case later cells reference it.)
modal_arg_dict = arg_dict["model_args"]
sample_arg_dict = arg_dict["sample_args"]
training_arg_dict = arg_dict["training_args"]
eval_arg_dict = arg_dict["eval_args"]

# HfArgumentParser.parse_dict builds the dataclass directly from a Python dict; it returns a tuple
# with one instance per dataclass type passed to the parser, hence the [0].
model_arg_parser = HfArgumentParser(PipelineModelsArguments)
model_args: PipelineModelsArguments = model_arg_parser.parse_dict(modal_arg_dict)[0]
sample_arg_parser = HfArgumentParser(SampleArguments)
sample_args: SampleArguments = sample_arg_parser.parse_dict(sample_arg_dict)[0]

# Training arguments are skipped here: per the original note, parsing them fails on Apple MPS
# because float16 is not supported there. TODO: re-enable once the training config is MPS-safe.
# training_arg_parser = HfArgumentParser(TrainingArguments)
# training_args: TrainingArguments = training_arg_parser.parse_dict(training_arg_dict)[0]

# Evaluation arguments parse without issues.
eval_arg_parser = HfArgumentParser(EvalArguments)
eval_args: EvalArguments = eval_arg_parser.parse_dict(eval_arg_dict)[0]

In [12]:
from src.utils import ModelArguments
from src.sample import sample_prompts, SAMPLE_PROMPTS, SAMPLE_NEGATIVE_PROMPTS, SAMPLE_PROMPTS_CONFIG, SAMPLE_NEGATIVE_PROMPTS_CONFIG

# Per-role model configurations parsed from the "model_args" section of the config file.
prompt_model_args = model_args.prompt_model
category_model_args = model_args.category_model
completion_model_args = model_args.completion_model
quality_model_args = model_args.qualitative_eval_model

# Set to True to use the negative prompt template/config instead of the regular one.
negative = False

prompt = SAMPLE_PROMPTS if not negative else SAMPLE_NEGATIVE_PROMPTS
prompt_config = SAMPLE_PROMPTS_CONFIG if not negative else SAMPLE_NEGATIVE_PROMPTS_CONFIG

# NOTE(review): prompts are generated with the *completion* model config rather than
# `prompt_model_args`/`category_model_args`; per an earlier note, the category-model config
# exceeded the API rate limit after ~12 s. Confirm this substitution is intended.
prompt_model = get_model(completion_model_args)

# Number of prompts requested per (feedback, category) pair.
prompts_per_category = 1
# Pause after every API call to stay under the provider's rate limit.
RATE_LIMIT_SLEEP_S = 20

responses = []
for f in feedback:
    for c in f.categories:
        prompt_text = prompt.format(count=prompts_per_category, domain=f.domain, category=c)
        responses.append(prompt_model.get_responses(prompt_text, prompt_config))
        # Use `sleep` from the imports cell (`from time import sleep`); the `time` module itself
        # is not imported there, so `time.sleep` only worked via a wildcard-import leak.
        sleep(RATE_LIMIT_SLEEP_S)


100%|██████████| 2426/2426 [01:18<00:00, 30.99it/s]


KeyboardInterrupt: 

In [10]:
# Interactive check of the sampling model's response method (display-only cell).
# NOTE(review): the saved output shows the model object's repr rather than a bound method;
# it is likely stale from an earlier version of this cell.
getattr(prompt_model, "get_responses")

<src.models.openai.OpenAIModel at 0x29abfdbd0>

In [37]:
def train(arg_dict: dict[str, Any], run_id: str, data_dir: str, feedback: Feedback, second_feedback: Feedback = None) -> None:
    model_args, _, training_args, _ = get_args(arg_dict)
    
    # Load feedback
    run_dir = os.path.join(data_dir, run_id, "sample")
    logger.info(f"Training using data for run {run_id}, stored in {run_dir}")
    if not feedback.can_load_dataset(run_dir):
        raise ValueError(f"Feedback \"{feedback.content}\" has not been sampled yet")
    feedback.load_dataset(run_dir)
    logger.info(f"Loaded feedback \"{feedback.content}\"")

    # Load second feedback if given
    if second_feedback is not None:
        assert training_args.multi_feedback_training, "Must set multi_feedback_training to True when providing a second feedback"
        if not second_feedback.can_load_dataset(run_dir):
            raise ValueError(f"Feedback \"{second_feedback.content}\" has not been sampled yet")
        second_feedback.load_dataset(run_dir)
    elif training_args.multi_feedback_training and second_feedback is None:
        raise ValueError("Must provide a second feedback when multi_feedback_training is True")
