In [1]:
import os
import json
import argparse
from time import sleep
from typing import Any, Tuple

import wandb
from peft import LoraConfig, PeftModel
from datasets import Dataset, concatenate_datasets
from trl import DPOTrainer, SFTTrainer, DataCollatorForCompletionOnlyLM

from src.logger import logger
from src.models import get_model
from src.dataset.feedback_utils import Feedback, Type
from src.lcdpo import LocallyConstrainedDPOTrainer
from src.sft_weighted import WeightedSFTTrainer
from src.dataset.format import to_dpo, to_sft, to_lcdpo, to_sft_weighted
from src.feedback import manual_feedback as all_feedback
from src.utils import get_args, find_all_linear_names, dump_arg_dicts, PeftSavingCallback, get_train_file_name, print_num_trainable_params, TrainingArguments, find_file_with_prefix

from src.train import filter_relevant_feedback, get_prompts

In [2]:
# Mirror the command-line interface of src/train.py inside the notebook.
# Reference invocation (note: its --data_dir and --feedback_prefix differ from the defaults below):
# python src/train.py --arg_file configs/config_lcdpo.json --run_id test_ksgk --data_dir ./data_ --feedback_prefix "Always use some"
import argparse
import json

parser = argparse.ArgumentParser()
for flag, default in (
    ("--arg_file", "configs/config_lcdpo.json"),
    ("--run_id", "test_ksgk"),
    ("--data_dir", "./data"),
    ("--feedback_prefix", "Always use some heart"),
):
    parser.add_argument(flag, type=str, default=default)

# Parse an explicit empty argv so only the defaults above are used
# (argparse would otherwise read the kernel process's sys.argv).
args = parser.parse_args([])

# Raw JSON config; it is turned into typed argument objects later via get_args.
with open(args.arg_file, "r") as config_file:
    arg_dict = json.load(config_file)

# Keep only feedback items whose content starts with the requested prefix
# (simple string-prefix matching; no prefix means keep everything).
prefix = args.feedback_prefix
feedback = all_feedback if prefix is None else [fb for fb in all_feedback if fb.content.startswith(prefix)]

In [3]:
# Override the config loaded from the CLI default with the DFT v1 config.
import json
arg_file = "configs/config_dft_v1.json"
with open(arg_file, "r") as f:
    arg_dict = json.load(f)

# Keep the parsed CLI args in sync with the config that is actually in use, so that
# `args.arg_file` (read again when the run variables are set up) names this file
# rather than the stale CLI default.
args.arg_file = arg_file


In [5]:
# arg_dict

In [None]:
# Build the typed argument groups from the raw config dict.
# Per the original author, this should be run on a CUDA machine only
# (NOTE(review): presumably get_args/TrainingArguments validate GPU-only options — confirm).
model_args, sample_args, training_args, eval_args = get_args(arg_dict)

# Expose the CLI values as plain notebook-level names used by the cells below.
data_dir, run_id, arg_file, feedback_prefix = (
    args.data_dir,
    args.run_id,
    args.arg_file,
    args.feedback_prefix,
)

In [None]:
import huggingface_hub
# Interactive Hugging Face login (NOTE(review): presumably needed to download the
# model weights in get_model — confirm; prompts every time this cell is run).
huggingface_hub.login()

# Select the feedback to train on: the first item whose content starts with
# --feedback_prefix (any item when no prefix is given). Recomputed from
# `all_feedback` rather than reusing the filtered `feedback` list so that
# re-running this cell is idempotent.
matching_feedback = [
    fb for fb in all_feedback
    if feedback_prefix is None or fb.content.startswith(feedback_prefix)
]
if not matching_feedback:
    raise ValueError(f"No feedback content starts with prefix {feedback_prefix!r}")
feedback = matching_feedback[0]

# Load this feedback's sampled data for the current run.
# NOTE(review): the root is "data_" (as in the reference command) rather than
# `data_dir`, whose default is "./data" — confirm which root is intended.
sample_dir = os.path.join("data_", run_id, "sample")
feedback.load_dataset(sample_dir)

# Load Model
model = get_model(model_args.train_model)

prompts, negative_prompts, general_prompts = get_prompts(feedback, training_args)

# Negative and general prompts are only passed on when they are used: a non-zero
# negative-prompt ratio, or an algorithm that always consumes them.
include_aux_prompts = (
    training_args.negative_prompt_ratio > 0
    or training_args.algo in ("lcdpo", "sft_weighted")
)

# This notebook always trains with LocallyConstrainedDPOTrainer, so use the LC-DPO format.
dataset_constructor = to_lcdpo
dataset = dataset_constructor(
    prompts,
    negative_prompts if include_aux_prompts else None,
    general_prompts if include_aux_prompts else None,
    model_args.train_model.model_name_or_path,
)

# Hold out an evaluation split (fixed seed for reproducibility).
dataset_splits = dataset.train_test_split(test_size=training_args.eval_split, seed=42, shuffle=True)
eval_dataset = dataset_splits["test"]
dataset = dataset_splits["train"]

In [None]:
# Output directory: <data_dir>/<run_id>/train/<feedback file name>/<train config name>.
# The per-configuration leaf (from get_train_file_name, presumably derived from the
# training arguments — confirm) keeps different training configs for the same feedback
# from writing into the same directory. It also matches the layout the base-adapter
# lookup expects: an entry found by prefix under <data_dir>/<run_id>/train/<feedback file name>.
# Only the LC-DPO path is wired up in this notebook.
train_dir = get_train_file_name(training_args, model_args.train_model)
run_dir = os.path.join(data_dir, run_id, "train", feedback.file_name, train_dir)

In [None]:
# Optionally warm-start from an adapter saved by an earlier training run. It is looked up
# by prefix under <data_dir>/<run_id>/train/<feedback file name> and loaded as trainable.
if training_args.use_base_prefix is not None:
    base_run_dir = os.path.join(data_dir, run_id, "train", feedback.file_name)
    adapter_name = find_file_with_prefix(base_run_dir, training_args.use_base_prefix)
    model.model = PeftModel.from_pretrained(model.model, os.path.join(base_run_dir, adapter_name), is_trainable=True)
    logger.info(f"Loaded base training model from {base_run_dir}")

# Only LoRA training is supported. A fresh LoRA config is built only when no base adapter
# was loaded above. Otherwise peft_config is None and the already-wrapped PeftModel is
# handed to the trainer as-is.
# The lora_enable check is kept in the condition so the branch stays correct even if
# asserts are stripped (python -O).
assert training_args.lora_enable, "Currently only LoRA training is supported"
if training_args.lora_enable and training_args.use_base_prefix is None:
    peft_config = LoraConfig(
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        target_modules=find_all_linear_names(model.model, training_args.lora_exclude),
        lora_dropout=training_args.lora_dropout,
        bias=training_args.lora_bias,
        task_type="CAUSAL_LM"
    )
else:
    peft_config = None

# Write trainer outputs into the run directory.
training_args.output_dir = run_dir
os.makedirs(run_dir, exist_ok=True)
# TODO: dump the argument dicts into run_dir (dump_arg_dicts is imported but not yet called).

# Run name = last two components of run_dir joined with "-". Split on the OS path
# separator (after normalising) rather than a literal "/", because run_dir is built
# with os.path.join and uses "\\" on Windows.
training_args.run_name = "-".join(os.path.normpath(run_dir).split(os.sep)[-2:])

# Disable the model's KV cache during training.
model.model.config.use_cache = False

In [None]:
# Pad on the left; set on the tokenizer before it is handed to the trainer.
model.tokenizer.padding_side = 'left'

# Marker passed to the trainer as `response_template`.
# NOTE(review): "[/INST]" is the Llama-2/Mistral instruction-closing tag — confirm it
# matches the chat template of model_args.train_model.
response_template = "[/INST]"

# Token-length limits for the trainer (hard-coded here rather than read from the config).
MAX_LENGTH = 2048
MAX_PROMPT_LENGTH = 1024

# LC-DPO hyper-parameters, all taken from the training config.
lcdpo_kwargs = dict(
    beta=training_args.dpo_beta,
    kd_lambda=training_args.lcdpo_lambda,
    kd_temperature=training_args.lcdpo_temp,
    sigma_soft=training_args.lcdpo_sigma_soft,
    sigma_hard=training_args.lcdpo_sigma_hard,
    use_avg_kl=training_args.lcdpo_avg_kl,
    custom_sft_loss=training_args.lcdpo_custom_sft_loss,
)

# Attach PeftSavingCallback only for LoRA runs.
saving_callbacks = [PeftSavingCallback] if training_args.lora_enable else None

trainer = LocallyConstrainedDPOTrainer(
    model=model.model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=eval_dataset,
    tokenizer=model.tokenizer,
    max_length=MAX_LENGTH,
    max_prompt_length=MAX_PROMPT_LENGTH,
    response_template=response_template,
    peft_config=peft_config,
    callbacks=saving_callbacks,
    **lcdpo_kwargs,
)

In [None]:
# Run training with the LocallyConstrainedDPOTrainer configured above.
# As the cell's last expression, its return value is shown as the cell output.
trainer.train()