In [None]:
# Fine-tuning (inspired by https://www.kaggle.com/code/zivicmilos/llm-finetuning)

import os

# Restrict this process to physical GPU 7. The variable is set before torch is
# imported so it is guaranteed to be in place before any CUDA initialisation.
os.environ['CUDA_VISIBLE_DEVICES'] = "7"

import torch

# With a single GPU exposed through CUDA_VISIBLE_DEVICES, that GPU is
# re-indexed as ordinal 0 inside this process; "cuda:7" would be an invalid
# device ordinal as soon as the device is actually used.
device = torch.device("cuda:0")

import transformers
from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model, AutoPeftModelForCausalLM
from trl import SFTTrainer
from random import randrange, sample, seed
import pandas as pd
from datasets import Dataset
from transformers import EarlyStoppingCallback, TrainingArguments


# Seed Python's built-in `random` module (`seed` is imported from `random`
# above) with a fixed value.
# NOTE(review): this call does not touch torch/numpy seeding — confirm that the
# trainer's own seeding is sufficient for the reproducibility you need.
seed(42)


from llm_interaction import LLMInteraction
    

def preprocess_data(train_file, dev_file, llm):
    """Load the pickled train/dev frames and convert each into a chat dataset.

    Args:
        train_file: Path handed to ``pd.read_pickle`` for the training frame.
        dev_file: Path handed to ``pd.read_pickle`` for the validation frame.
        llm: Object forwarded unchanged to ``create_dataset``.

    Returns:
        A ``(train_dataset, dev_dataset)`` tuple holding whatever
        ``create_dataset`` returns for each frame.

    Note:
        Both inputs are unpickled; unpickling can execute arbitrary code, so
        only pass files from a trusted source.
    """
    # Read both frames first, then convert both, always in train -> dev order.
    frames = [pd.read_pickle(path) for path in (train_file, dev_file)]
    train_dataset, dev_dataset = [create_dataset(frame, llm) for frame in frames]

    # Report the sizes so a mismatch with expectations is visible right away.
    print(f"Training dataset size: {len(train_dataset)}")
    print(f"Validation dataset size: {len(dev_dataset)}")
    return train_dataset, dev_dataset

def create_dataset(df, llm):
    """Build a dataset of chat prompts together with their templated text.

    Args:
        df: DataFrame providing the columns ``mini_facts_with_labels_false``,
            ``ground_truth_source`` and ``correction_evidence``.
        llm: Object exposing ``correct_mini_facts_prompt`` and a
            ``llama_tokenizer`` that offers ``apply_chat_template``.

    Returns:
        The result of mapping the chat template over a ``Dataset`` whose
        ``chat`` column holds one prompt per row of ``df``; the mapping
        contributes a ``formatted_chat`` entry per example.
    """

    def build_prompt(row):
        # One prompt per row, always requested with is_training=True.
        return llm.correct_mini_facts_prompt(
            row['mini_facts_with_labels_false'],
            row['ground_truth_source'],
            row['correction_evidence'],
            is_training=True,
        )

    def render_chat(example):
        # Ask for the template output untokenized and without a generation prompt.
        rendered = llm.llama_tokenizer.apply_chat_template(
            example["chat"], tokenize=False, add_generation_prompt=False
        )
        return {"formatted_chat": rendered}

    prompts = df.apply(build_prompt, axis=1).tolist()
    chat_dataset = Dataset.from_dict({"chat": prompts})
    return chat_dataset.map(render_chat)


# Path of the base model to fine-tune. Left empty on purpose: fill it in
# before running this cell.
model_path = ""


# Build the project's LLMInteraction wrapper (imported from llm_interaction);
# the tokenizer and model used below are read from its llama_tokenizer and
# llama_model attributes.
llm = LLMInteraction(model_path, fine_tuned_version=False, few_shot=False, use_cache=False)

# Adjust the model config and tokenizer in place before any data is formatted:
# the pad token is set to the tokenizer's EOS token and padding is applied on
# the right-hand side.
llm.llama_model.config.pretraining_tp = 1
llm.llama_tokenizer.pad_token = llm.llama_tokenizer.eos_token
llm.llama_tokenizer.padding_side = "right"

# Pickled DataFrames with the training and validation examples; replace these
# with your own files.
train_file = "train_datasets_combined/corrections_evidence_combined_train_balanced.pkl"
dev_file = "train_datasets_combined/corrections_evidence_combined_dev_balanced.pkl"

# Load both files and convert them into chat datasets (also prints their sizes).
train_dataset, dev_dataset = preprocess_data(train_file, dev_file, llm)



# LoRA configuration for a causal-LM task. The adapters target the projection
# modules named in `target_modules`.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

# Run PEFT's k-bit training preparation on the base model held by the wrapper.
# NOTE(review): this presumably expects llm.llama_model to have been loaded in
# a quantized (k-bit) form by LLMInteraction — confirm against that class.
model = prepare_model_for_kbit_training(llm.llama_model)

# Hyperparameters for the fine-tuning run, grouped by concern.
args = TrainingArguments(
    # Output location.
    output_dir="llama_finetuned_model",

    # Schedule and batching: 6 epochs, batches of 2 per device, gradients
    # accumulated over 4 steps.
    num_train_epochs=6,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,

    # Optimizer and learning-rate schedule.
    optim="adamw_8bit",
    learning_rate=2e-4,
    lr_scheduler_type="linear",
    warmup_steps=5,
    max_grad_norm=0.3,

    # Numeric precision: fp16 on, bf16 and tf32 off.
    fp16=True,
    bf16=False,
    tf32=False,

    # Checkpointing and evaluation both happen once per epoch.
    # NOTE(review): eval_steps is presumably unused with an epoch-based
    # evaluation strategy — confirm against the transformers documentation.
    save_strategy="epoch",
    evaluation_strategy="epoch",
    eval_steps=1,
    load_best_model_at_end=True,

    # Logging / progress reporting.
    logging_steps=1,
    disable_tqdm=False,
)

# Wrap the prepared model with the adapters described by `peft_config`.
model = get_peft_model(model, peft_config)

# Early stopping configured with a patience of 2 and a threshold of 0.05.
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_threshold=0.05,
    early_stopping_patience=2,
)

# Supervised fine-tuning trainer over the pre-rendered "formatted_chat" text.
# NOTE(review): `model` is already PEFT-wrapped above and `peft_config` is
# passed again here; this looks redundant — confirm how the installed TRL
# version treats that combination.
trainer = SFTTrainer(
    model=model,
    args=args,
    tokenizer=llm.llama_tokenizer,
    peft_config=peft_config,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    dataset_text_field="formatted_chat",
    max_seq_length=2048,
    packing=False,
    callbacks=[early_stopping_callback],
)

In [None]:
# Enable gradient checkpointing on the model (defined in the previous cell)
# before training starts.
model.gradient_checkpointing_enable()
# Run the fine-tuning loop configured on `trainer` in the previous cell.
trainer.train()
# Save through the trainer; no explicit path is given here.
trainer.save_model()
# Run one more evaluation pass and report its loss.
# NOTE(review): `args` sets load_best_model_at_end=True, so this presumably
# evaluates the best checkpoint rather than the final-step weights — confirm.
metrics = trainer.evaluate()
print(f"Validation loss after last evaluation: {metrics['eval_loss']}")