In [None]:
from unsloth import FastLanguageModel
from datasets import load_from_disk
from trl import SFTTrainer
from transformers import TrainingArguments
import torch

In [None]:


class Lora_FineTuner:
    """
    QLoRA fine-tuning using Unsloth and Hugging Face Transformers / TRL.

    Typical usage::

        tuner = Lora_FineTuner("path/to/dataset", "model-name", 2048)
        stats = tuner.train()

    ``train()`` builds the model and the trainer on demand, so calling the
    private ``_setup_*`` methods by hand is optional.
    """

    def __init__(self, dataset: str, model_name: str, max_seq_length: int,
                 load_in_4bit: bool = True, wandb_track: bool = True) -> None:
        """
        Load the dataset from disk and record the run configuration.

        Args:
            dataset: Path handed to ``datasets.load_from_disk``. The loaded
                object is indexed with ``'train'`` and ``'eval'`` when the
                trainer is built.
            model_name: Model identifier passed to
                ``FastLanguageModel.from_pretrained``.
            max_seq_length: Maximum sequence length, used both when loading
                the model and when configuring the trainer.
            load_in_4bit: Forwarded to ``FastLanguageModel.from_pretrained``.
            wandb_track: When True the trainer reports to Weights & Biases;
                when False reporting is switched off.
        """
        self.model_name = model_name
        self.dataset = load_from_disk(dataset)
        self.max_seq_length = max_seq_length
        self.load_in_4bit = load_in_4bit
        self.wandb_track = wandb_track

        # Filled in lazily by the _setup_* methods.
        self.model = None
        self.tokenizer = None
        self.trainer = None

        self._setup_wandb()

    def _setup_wandb(self) -> None:
        """
        Choose the ``report_to`` value handed to ``TrainingArguments``.

        Sets ``self.report_to`` to ``"wandb"`` when ``wandb_track`` is True
        and to ``"none"`` otherwise.
        """
        self.report_to = "wandb" if self.wandb_track else "none"

    def _setup_model_and_tokenizer(self) -> None:
        """
        Load the base model and tokenizer, then attach LoRA adapters.

        Stores the adapter-wrapped model on ``self.model`` and the tokenizer
        on ``self.tokenizer``.
        """
        model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.model_name,
            max_seq_length=self.max_seq_length,
            dtype=None,
            load_in_4bit=self.load_in_4bit,
        )

        self.model = FastLanguageModel.get_peft_model(
            model,
            r=16,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
            lora_alpha=16,
            lora_dropout=0,  # Supports any, but = 0 is optimized
            bias="none",     # Supports any, but = "none" is optimized
            use_gradient_checkpointing=True,
            random_state=3407,
            use_rslora=False,
            loftq_config=None,
        )

    def _setup_trainer(self) -> None:
        """
        Build the ``SFTTrainer`` and store it on ``self.trainer``.

        Requires ``self.model`` and ``self.tokenizer`` to be set (see
        ``_setup_model_and_tokenizer``).
        """
        # Query once; exactly one of fp16 / bf16 ends up enabled.
        bf16_supported = torch.cuda.is_bf16_supported()

        training_args = TrainingArguments(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            warmup_steps=5,
            learning_rate=2e-4,
            num_train_epochs=2,
            fp16=not bf16_supported,
            bf16=bf16_supported,
            logging_steps=500,
            optim="adamw_8bit",
            weight_decay=0.01,
            # NOTE(review): eval_steps presumably only takes effect when the
            # evaluation strategy is "steps"; no strategy is set here, so
            # confirm whether periodic evaluation was intended.
            eval_steps=1000,
            do_eval=True,
            lr_scheduler_type="linear",
            seed=3407,
            report_to=self.report_to,
            output_dir="outputs",
        )

        self.trainer = SFTTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            train_dataset=self.dataset['train'],
            eval_dataset=self.dataset['eval'],
            dataset_text_field="text",
            max_seq_length=self.max_seq_length,
            dataset_num_proc=2,
            packing=False,  # Can make training 5x faster for short sequences.
            args=training_args,
        )

    def train(self):
        """
        Run fine-tuning and return whatever ``trainer.train()`` returns.

        The model/tokenizer and the trainer are built first if they have not
        been set up yet, so this method can be called directly after
        construction.
        """
        if self.trainer is None:
            if self.model is None or self.tokenizer is None:
                self._setup_model_and_tokenizer()
            self._setup_trainer()

        trainer_stats = self.trainer.train()
        return trainer_stats
