In [None]:
from transformers import TrainingArguments, Trainer
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import PeftModel, LoraConfig, prepare_model_for_kbit_training, get_peft_model
from datasets import load_dataset, Dataset
import polars as pl
from unsloth import FastLanguageModel


In [None]:
# Location on disk of the previously trained LoRA adapter that later cells load.
adapter_path = '/home/lawrence/Projects/my_models/mistral_pr_lora_over65'

# Model-loading options consumed by FastLanguageModel.from_pretrained below.
load_in_4bit = True    # 4-bit quantization lowers memory use; may be set to False.
dtype = None           # None = auto-detect. Float16 suits Tesla T4 / V100, Bfloat16 suits Ampere+.
max_seq_length = 2048  # Free choice; RoPE scaling is handled internally by unsloth.

In [None]:
# Author's note: there seems to be an issue with loading multiple adapters through unsloth,
# and also with loading an unsloth-trained adapter onto a regular (non-unsloth) model.
# 2 - might want to train

# Base checkpoint; any model can be used here, e.g. teknium/OpenHermes-2.5-Mistral-7B.
base_checkpoint = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=base_checkpoint,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    # token="hf_...",  # only needed for gated models like meta-llama/Llama-2-7b-hf
)

# LoRA hyper-parameters, collected in one place and forwarded unchanged to unsloth.
lora_settings = dict(
    r=16,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,                   # any value is supported, but 0 is the optimized path
    bias="none",                      # any value is supported, but "none" is the optimized path
    use_gradient_checkpointing=True,
    random_state=3407,
    use_rslora=False,                 # rank-stabilized LoRA is available but not used
    loftq_config=None,                # LoftQ is available but not used
)
model = FastLanguageModel.get_peft_model(model, **lora_settings)

# NOTE(review): get_peft_model above presumably already returns a PEFT-wrapped model, so the
# PeftModel.from_pretrained call below may wrap it a second time and leave an extra, freshly
# initialised adapter alongside "train"/"reference" - confirm this is intended.

# Load the saved adapter as the trainable adapter named "train".
model = PeftModel.from_pretrained(model, adapter_path, is_trainable=True, adapter_name="train")

# Load the same adapter again under the name "reference"; it serves as the reference model.
model.load_adapter(adapter_path, adapter_name="reference")

In [None]:
from unsloth.chat_templates import get_chat_template

# Attach the "mistral" chat template to the tokenizer. Other templates unsloth offers:
# zephyr, chatml, llama, alpaca, vicuna, vicuna_old, unsloth.
# map_eos_token=True: per the upstream example, this maps a template end token such as
# <|im_end|> onto </s>.
tokenizer = get_chat_template(tokenizer, chat_template="mistral", map_eos_token=True)

In [None]:
# Building blocks of the prompt assembled by format_prompts.

# Labels placed in front of the original and the rewritten text inside the user message.
orig_prefix = "Original Text:"
rewrite_prefix = "Rewritten Text:"
# Opening words of the assistant turn; format_prompts appends a single space after it.
response_start = "The prompt was:"
# Task instructions; format_prompts places them at the start of the user message
# (they are not sent as a separate system-role message).
sys_prompt = """You are an expert in "Reverse Prompt Engineering". You are able to reverse-engineer prompts used to rewrite text. 

I will be providing you with an "original text" and "rewritten text". Please try to be as specific as possible and come up with a prompt that is based on the tone, style, and any other properties you consider relevant."""

def format_prompts(x):
    """Render one dataset row as a chat-formatted prompt that stops mid assistant turn.

    The user message is the task instructions (``sys_prompt``) followed by the labelled
    original and rewritten texts; the assistant message is just ``response_start`` plus a
    space, so that a model continues from there.

    Args:
        x: Mapping with the keys ``'original_text'`` and ``'rewritten_text'``.

    Returns:
        ``{"prompt": <rendered text>}`` with the trailing ``</s>`` emitted by the chat
        template removed.
    """
    user_content = f"{sys_prompt}\n{orig_prefix} {x['original_text']}\n{rewrite_prefix} {x['rewritten_text']}"
    messages = [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": f"{response_start} "},
    ]
    output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = False)

    # Drop the closing EOS marker so the assistant turn is left open.
    # str.rstrip('</s>') must not be used for this: it strips any trailing run of the
    # *characters* '<', '/', 's', '>' and would also eat text such as a final "s".
    eos_marker = '</s>'
    while output.endswith(eos_marker):
        output = output[:-len(eos_marker)]
    return {"prompt": output}

In [None]:
# Read the selected examples, render a "prompt" column for each row, then drop the
# bookkeeping columns that are not needed downstream.
# Alternative input kept for reference:
# df = pl.read_csv('./data/predictions/combined-filtered_*.csv', ignore_errors=True)
df = pl.read_csv('./data/exp_test/multi_n_with_lora/multi_n_selected.csv', ignore_errors=True)
dataset = Dataset.from_list(df.to_dicts())
dataset = dataset.map(format_prompts)

# The in-place `Dataset.remove_columns_` is a deprecated API that is no longer available in
# current `datasets` releases; `remove_columns` returns a new dataset, so rebind the name.
# The '' entry is an unnamed column in the CSV (presumably a saved index - verify).
dataset = dataset.remove_columns([
    '', 'original_text', 'rewritten_text', 'gt_rewrite_prompt', 'old_rewrite_prompt',
    'score', 'prompt_select', 'input', 'rewrite_prompts', 'rewrite_prompt_1',
    'rewrite_prompt_2', 'score_1', 'score_2', 'prompt_select_score',
])

In [None]:
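# TODO: the DPO trainer setup is still missing - this cell needs to construct `trainer` before the training cell can run.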

In [None]:
# Run training and keep the object returned by `trainer.train()`.
# NOTE(review): `trainer` is not defined in any cell visible here, so on a fresh kernel this
# line raises NameError until the trainer is constructed in an earlier cell.
trainer_stats = trainer.train()


In [None]:
# model.save_pretrained('/home/lawrence/Projects/my_models/mistral_pr_lora_over65v2b')