In [2]:
import os
import json
import argparse
from time import sleep
from typing import Any, Tuple

import wandb
from peft import LoraConfig, PeftModel
from datasets import Dataset, concatenate_datasets
from trl import DPOTrainer, SFTTrainer, DataCollatorForCompletionOnlyLM

from src.logger import logger
from src.models import get_model
from src.dataset.feedback_utils import Feedback, Type
from src.lcdpo import LocallyConstrainedDPOTrainer
from src.sft_weighted import WeightedSFTTrainer
from src.dataset.format import to_dpo, to_sft, to_lcdpo, to_sft_weighted
from src.feedback import manual_feedback as all_feedback
from src.utils import get_args, find_all_linear_names, dump_arg_dicts, PeftSavingCallback, get_train_file_name, print_num_trainable_params, TrainingArguments, find_file_with_prefix


In [3]:
def filter_relevant_feedback(feedback: Feedback, prompts: Dataset | None) -> Dataset | None:
    """Filter out prompts where the revision is not better than the baseline"""
    if prompts is None:
        return None
    
    # TODO: enable this for quantitative feedback
    # TODO: add support to define "better" using a margin rather than just binary comparison
    if isinstance(feedback.metric, list):
        metric = lambda x: all([f(x, v) for f, v in zip(feedback.metric, feedback.metric_value)])
    else:
        metric = lambda x: feedback.metric(x, feedback.metric_value)
    return prompts.filter(lambda x: feedback.comparison(
        metric(x["baseline_response"]),
        metric(x["revised_response"])
    ))


def get_prompts(feedback: Feedback, training_args: TrainingArguments) -> Tuple[Dataset, Dataset, Dataset]:
    # Fetch dataset
    prompts = feedback.prompts["train"].shuffle(seed=42)
    negative_prompts = feedback.negative_prompts["train"].shuffle(seed=42)
    general_prompts = feedback.general_prompts["train"].shuffle(seed=42)

    # Filter out prompts where the revision is not better than the baseline
    if training_args.filter_relevant_feedback:
        assert feedback.type == Type.quantitative, "Filtering relevant feedback is currently only supported for quantitative feedback"
        prompts = filter_relevant_feedback(feedback, prompts)
        negative_prompts = filter_relevant_feedback(feedback, negative_prompts)
        general_prompts = filter_relevant_feedback(feedback, general_prompts)

    if training_args.max_prompts is not None:
        prompts = prompts.select(range(min(training_args.max_prompts, len(prompts))))
        logger.info(f"Using {len(prompts)} prompts")

    if training_args.negative_prompt_ratio > 0 and training_args.algo != "lcdpo" and training_args.algo != "sft_weighted":
        num_negative_prompts = int(training_args.negative_prompt_ratio * len(prompts))
        negative_prompts = negative_prompts.select(range(num_negative_prompts))
        logger.info(f"Using {len(negative_prompts)} negative prompts")

    if training_args.general_prompt_ratio > 0 and training_args.algo != "lcdpo" and training_args.algo != "sft_weighted":
        num_general_prompts = int(training_args.general_prompt_ratio * len(prompts))
        general_prompts = general_prompts.select(range(num_general_prompts))
        logger.info(f"Using {len(general_prompts)} general prompts")

    return prompts, negative_prompts, general_prompts

In [1]:
# Command line arguments for the modal genearation
# --arg-file configs/config.json --do-train --feedback-prefix "Be more detailed" --run-id test

In [28]:
arg_file = "configs/config.json"
feedback_prefix = "Be more detailed"
run_id = "test-ksgk"
data_dir = "./data"

In [4]:
# Now this cracks it open a little bit (really small bit)
import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--arg_file", type=str, default="configs/config_dpo.json")
parser.add_argument("--run_id", type=str, default="test-ksgk")
parser.add_argument("--data_dir", type=str, default="./data")
parser.add_argument("--feedback_prefix", type=str, default="")
args = parser.parse_args("")

with open(args.arg_file, "r") as f:
    arg_dict = json.load(f)

feedback = all_feedback
if args.feedback_prefix is not None: # This unfortunately is basically a prefix-filtering stuff
    feedback = [f for f in feedback if f.content.startswith(args.feedback_prefix)]

In [9]:
feedback

[Feedback(content='Always use some heart or kiss emoji when texting my girlfriend Maddie', domain='writing text messages to my girlfriend Maddie', effect='use some heart or kiss emoji', scope=<Scope.regional: 'regional'>, type=<Type.quantitative: 'quantitative'>, metric=<function Metric.<lambda> at 0x2bad591c0>, metric_value='🥰|😍|😘|😗|😚|😙|😽|💋|💌|💘|💝|💖|💗|💓|💞|💕|💟|❣|💔|❤|🧡|💛|💚|💙|💜|🤎|🖤|🤍|💏|👩\u200d❤️\u200d💋\u200d👨|👨\u200d❤️\u200d💋\u200d👨|👩\u200d❤️\u200d💋\u200d👩|💑|👩\u200d❤️\u200d👨|👨\u200d❤️\u200d👨|👩\u200d❤️\u200d👩|♥|🏩|<3|:3', comparison=<function Comparison.<lambda> at 0x2bac92340>, categories=['manual'], prompts=None, negative_prompts=None, general_prompts=None),
 Feedback(content="Use '&' instead of 'and' in any Slack message DMs to my colleagues John, Michael, Eric, or Hailey", domain='writing Slack message DMs to my colleagues John, Michael, Eric, or Hailey', effect="use '&' instead of 'and'", scope=<Scope.regional: 'regional'>, type=<Type.quantitative: 'quantitative'>, metric=[<function Me

In [10]:
# model_args, _, training_args, _ = get_args(arg_dict) # This hurts my debugging session ... 

run_dir = os.path.join(args.data_dir, args.run_id)
feedback[0].load_dataset(run_dir)

FileNotFoundError: [Errno 2] No such file or directory: './data/test-ksgk/always_use_some_heart_or_kiss__c3c45956-5eac-55d4-955b-f042660d2c15/prompts.json'

In [None]:
def train(arg_dict: dict[str, Any], run_id: str, data_dir: str, feedback: Feedback, second_feedback: Feedback = None) -> None:
    model_args, _, training_args, _ = get_args(arg_dict)
    
    # Load feedback
    run_dir = os.path.join(data_dir, run_id, "sample")
    logger.info(f"Training using data for run {run_id}, stored in {run_dir}")
    if not feedback.can_load_dataset(run_dir):
        raise ValueError(f"Feedback \"{feedback.content}\" has not been sampled yet")
    feedback.load_dataset(run_dir)
    logger.info(f"Loaded feedback \"{feedback.content}\"")

    # Load second feedback if given
    if second_feedback is not None:
        assert training_args.multi_feedback_training, "Must set multi_feedback_training to True when providing a second feedback"
        if not second_feedback.can_load_dataset(run_dir):
            raise ValueError(f"Feedback \"{second_feedback.content}\" has not been sampled yet")
        second_feedback.load_dataset(run_dir)
    elif training_args.multi_feedback_training and second_feedback is None:
        raise ValueError("Must provide a second feedback when multi_feedback_training is True")
