In [None]:
from time import sleep
from typing import Dict, Optional, List

import torch
from datasets import load_dataset
from peft import LoraConfig, AutoPeftModelForCausalLM
from transformers import TrainingArguments, BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer, SFTTrainer

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger

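## Loading and preparing the dataset
We use a dataset of Stack Exchange (Stack Overflow) questions hosted on the Hugging Face Hub. Each row pairs a question with a preferred answer (`response_j`) and a less-preferred answer (`response_k`); these become the `chosen` and `rejected` responses used for DPO.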

In [None]:
# Dataset arguments: Hugging Face Hub id of the preference dataset. Its rows are
# expected to provide "question", "response_j" and "response_k" columns.
ds_name: str = "MaestroDmitry/stack-exchange-paired-shorted"

In [None]:
# Hugging Face's DPOTrainer expects a dataset with "prompt", "chosen" and "rejected" columns.

def return_prompts_and_responses(batch: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Convert a batch of raw rows into the prompt / chosen / rejected schema.

    Args:
        batch: Column-oriented batch (as passed by ``Dataset.map(batched=True)``)
            with ``question``, ``response_j`` (preferred answer) and
            ``response_k`` (rejected answer) columns.

    Returns:
        A dict with ``prompt``, ``chosen`` and ``rejected`` lists that are
        row-aligned with the input batch.
    """
    prompt_texts = [f"Question: {q} \n\nAnswer: " for q in batch["question"]]
    return dict(
        prompt=prompt_texts,
        chosen=[*batch["response_j"]],
        rejected=[*batch["response_k"]],
    )


In [None]:
# Download (or reuse the local cache of) every split of the dataset, then rewrite
# each split into the prompt / chosen / rejected schema the trainers expect.
raw_dataset = load_dataset(ds_name, cache_dir="llm-finetune/data")

dataset = raw_dataset.map(
    return_prompts_and_responses,
    batched=True,
    with_indices=False,
    # Drop the original columns so only prompt / chosen / rejected remain.
    # NOTE(review): uses the train split's column names for every split; assumes
    # all splits share the same schema.
    remove_columns=raw_dataset["train"].column_names,
)

train_dataset, test_dataset = dataset["train"], dataset["test"]



## Loading the base model for supervised fine-tuning (SFT)

In [None]:
# Base model for supervised fine-tuning. A smaller checkpoint such as
# "ComCom/gpt2-small" can be swapped in for quick experiments.
model_path = "EleutherAI/gpt-neo-1.3B"

# The 4-bit bitsandbytes setup below only works on CUDA, so fail fast with a clear
# message instead of an opaque error from deep inside model loading.
if not torch.cuda.is_available():
    raise RuntimeError("Loading the base model in 4-bit (bitsandbytes) requires a CUDA GPU.")

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Quantize the base model weights to 4-bit NF4, with bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map="auto",  # let transformers/accelerate decide device placement
    # NOTE(review): this executes modelling code shipped in the Hub repo. GPT-Neo is
    # implemented in transformers itself, so this is presumably unnecessary — confirm.
    trust_remote_code=True,
    cache_dir="llm-finetune/model/base",
)

# The KV cache is only useful for generation; turn it off for training.
base_model.config.use_cache = False

# If the tokenizer has no padding token, add one and grow the embedding matrix so
# the new token id has an embedding row.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    base_model.resize_token_embeddings(len(tokenizer))

base_model

In [None]:
def sft_formatting_func(example):
    """Render one preference example as a single SFT training string.

    The SFTTrainer in this notebook uses ``packing=True``. In that mode TRL calls
    the formatting function on one example at a time and expects a single string
    back, so exactly one string is returned here.

    Only the preferred (``chosen``) answer is used, so SFT teaches the model to
    imitate good answers; the ``rejected`` answers are left for the DPO stage to
    contrast against.

    The ``prompt`` column already carries the ``"Question: ... \\n\\nAnswer: "``
    template, so it is used verbatim. This keeps the SFT text format identical
    to the prompts seen during DPO.

    Args:
        example: A single row with at least ``prompt`` and ``chosen`` string fields.

    Returns:
        The prompt immediately followed by the preferred answer.
    """
    return f"{example['prompt']}{example['chosen']}"

In [None]:
# LoRA adapter configuration for the SFT stage.
lora_r = 8  # rank of the low-rank update matrices
lora_alpha = 8  # LoRA scaling numerator (peft scales the update by lora_alpha / r)
lora_dropout = 0.0  # dropout on the adapter path; disabled

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",  # do not train any bias terms
    # NOTE(review): presumably targets only GPT-Neo's MLP output projection; the
    # attention projections (e.g. "q_proj", "v_proj") are not adapted — confirm intent.
    target_modules=["c_proj"],
)

In [None]:
# Hugging Face Trainer arguments for the supervised fine-tuning (SFT) stage.
sft_training_args: TrainingArguments = TrainingArguments(
    output_dir="llm-finetune/model/sft_train",
    per_device_train_batch_size=2,
    # `per_gpu_eval_batch_size` is deprecated in transformers; the per-device
    # argument is the supported replacement.
    per_device_eval_batch_size=2,
    logging_dir="llm-finetune/logs/sft_train",
)

trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    peft_config=peft_config,
    # Pack tokenized examples into fixed-length sequences (TRL ConstantLengthDataset).
    # Per the TRL docs, with packing the formatting_func is applied to one example at
    # a time and must return a single string.
    packing=True,
    # None lets TRL choose its default maximum sequence length.
    max_seq_length=None,
    formatting_func=sft_formatting_func,
    tokenizer=tokenizer,
    args=sft_training_args,
)

In [None]:
# Train the SFT adapter, then save the final weights to the SFT output_dir so the
# finished model is available in addition to the intermediate checkpoints.
sft_train_result = trainer.train()
trainer.save_model()

# Observed training time for a full run:
#   single A10G:   ~15 h
#   2x RTX 3090:   ~8 h

sft_train_result.metrics

In [None]:
# TODO: load the trained SFT model in this cell instead of inside the DPO cell?

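## DPO training
Direct Preference Optimization (DPO) further fine-tunes the SFT model on the `chosen`/`rejected` answer pairs, using a frozen copy of the same SFT checkpoint as the reference model.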

In [None]:
# DPO hyper-parameters. Per the TRL docs, a higher beta keeps the trained policy
# closer to the reference model.
dpo_beta: float = 0.1
dpo_training_args: TrainingArguments = TrainingArguments(
    output_dir="llm-finetune/model/dpo_train",
    # TODO: DPO seems to use one model per GPU, so this can probably be increased.
    per_device_train_batch_size=1,
    # `per_gpu_eval_batch_size` is deprecated in transformers; use the per-device form.
    per_device_eval_batch_size=1,
    logging_dir="llm-finetune/logs/dpo_train",
)


def load_sft_checkpoint(path: str, is_trainable: bool = False):
    """Load a saved SFT LoRA checkpoint on top of the quantized base model.

    Args:
        path: Directory of the saved PEFT checkpoint.
        is_trainable: Load the LoRA adapter as trainable (the policy being
            optimized) rather than frozen (the reference model).

    Returns:
        The loaded PEFT causal-LM model.
    """
    return AutoPeftModelForCausalLM.from_pretrained(
        path,
        device_map="auto",
        low_cpu_mem_usage=True,
        # NOTE(review): bnb_config uses bfloat16 as its compute dtype while the
        # remaining weights are loaded as float16 — confirm the mix is intended.
        torch_dtype=torch.float16,
        quantization_config=bnb_config,
        is_trainable=is_trainable,
    )


# The SFT run writes its checkpoints under its output_dir, "llm-finetune/model/sft_train".
dpo_model = "llm-finetune/model/sft_train/checkpoint-1000"
dpo_model_ref = dpo_model  # the reference model is the same SFT checkpoint

model = load_sft_checkpoint(dpo_model, is_trainable=True)  # policy being optimized
model_ref = load_sft_checkpoint(dpo_model_ref)  # frozen reference for the DPO loss

dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=dpo_training_args,
    beta=dpo_beta,
    # The splits built in the dataset cell already have prompt/chosen/rejected columns.
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
)



In [None]:
# Run DPO training, then save the final model. save_model() is called without a
# path, so it writes to the trainer's configured output_dir
# (dpo_training_args.output_dir).
dpo_trainer.train()
dpo_trainer.save_model()
