In [None]:
!pip install unsloth
!pip instal peft==0.5.0
!pip install bitsandbytes
!pip install trl
!pip install dotenv
!pip install wandb

In [None]:
!pip install -U datasets

In [None]:
# ---- Model ----
MODEL_ID = "google/gemma-3-270m"   # Hub id of the base checkpoint
MAX_SEQ_LENGTH = 1024              # used by the model loader and by SFTTrainer
DTYPE = None                       # forwarded as `dtype` to FastLanguageModel.from_pretrained
LOAD_IN_4BIT = False
LOAD_IN_8BIT = False

# ---- Training ----
BATCH_SIZE = 1
EPOCHS = 1
LEARNING_RATE = 1e-5
USE_LORA = False                   # False -> full fine-tuning, True -> LoRA adapters

# ---- Outputs / experiment naming ----
OUTPUT_DIR = "./sft_trainer"
RUN_NAME = (
    f"{MODEL_ID.split('/')[-1]}_{EPOCHS}Epochs"
    f"_Lora_{USE_LORA}_BATCH_SIZE_{BATCH_SIZE}"
)

# ---- Credentials ----
# While these stay None, the guarded wandb / Hugging Face login calls in the
# setup cell are skipped.
HF_TOKEN = None
WANDB_API_KEY = None

In [None]:
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch 
from huggingface_hub import login
import os
from dotenv import load_dotenv
import wandb

# Load variables from a local .env file (if one exists) into os.environ.
load_dotenv()

# Credentials: a value set explicitly in the config cell wins; otherwise fall
# back to the environment / .env. Empty strings are treated as "not set".
if HF_TOKEN is None:
    HF_TOKEN = os.getenv('HF_TOKEN') or None
if WANDB_API_KEY is None:
    WANDB_API_KEY = os.getenv('WANDB_API_KEY') or None

if WANDB_API_KEY is not None:
    wandb.login(key=WANDB_API_KEY)
    wandb.init(project="ft_proj", name=RUN_NAME)

if HF_TOKEN is not None:
    login(token=HF_TOKEN)

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_ID,
    max_seq_length = MAX_SEQ_LENGTH,
    dtype = DTYPE,
    load_in_4bit = LOAD_IN_4BIT,
    load_in_8bit = LOAD_IN_8BIT,
    # Train every weight unless LoRA adapters are requested.
    full_finetuning = not USE_LORA,
)


if USE_LORA:
    # Wrap the base model with LoRA adapters on the attention and MLP projections.
    model = FastLanguageModel.get_peft_model(
        model,
        r=32,  
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=32,
        lora_dropout=0,
        bias="none",    
        use_gradient_checkpointing="unsloth", 
        random_state=3407,
        use_rslora=False,   
        loftq_config=None, 
    )

## Process Dataset

In [None]:
from datasets import load_dataset

dataset = load_dataset("keivalya/MedQuad-MedicalQnADataset", split='train')

# Fallback used only when the tokenizer ships without a chat template. It follows
# Gemma's turn layout, where every turn is closed by "<end_of_turn>\n".
GEMMA_TEMPLATE_PROMPT = (
    "<start_of_turn>user\n{question}<end_of_turn>\n"
    "<start_of_turn>model\n{answer}<end_of_turn>\n"
)


def formatting_prompts_func(examples):
    """Turn a batch of MedQuAD rows into single training strings.

    Used with ``Dataset.map(batched=True)``, so ``examples`` maps each column
    name to the list of values in the current batch.

    Args:
        examples: Batch dict with parallel ``'Question'`` and ``'Answer'`` lists.

    Returns:
        ``{"text": texts}`` holding one formatted user/assistant exchange per row,
        in input order.
    """
    # The tokenizer is fixed, so pick the formatting path once per batch.
    use_chat_template = tokenizer.chat_template is not None
    texts = []
    for question, answer in zip(examples['Question'], examples['Answer']):
        if use_chat_template:
            conversation = [
                {"role": "user", "content": question},
                {"role": "assistant", "content": answer},
            ]
            text = tokenizer.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=False
            )
        else:
            text = GEMMA_TEMPLATE_PROMPT.format(question=question, answer=answer)
        texts.append(text)
    return {"text": texts}


dataset = dataset.map(formatting_prompts_func, batched=True, batch_size=8)
# Fixed seed so the held-out split — and therefore eval_loss and best-checkpoint
# selection — is identical across runs.
dataset_split = dataset.train_test_split(test_size=0.05, seed=3407)
train_ds = dataset_split['train']
test_ds = dataset_split['test']
print(f"Train ds: {len(train_ds)}")
print(f"Test ds: {len(test_ds)}")

## Train

In [None]:
from trl import SFTConfig, SFTTrainer
from transformers import get_scheduler, EarlyStoppingCallback, TrainerCallback
import json

In [None]:
class LossLoggingCallback(TrainerCallback):
    """Custom callback to save training and validation losses to a text file.

    Every ``on_log`` event appends one tab-separated row (epoch, step, training
    loss, validation loss, learning rate) to ``<output_dir>/training_losses.txt``.
    ``on_train_end`` writes a JSON summary of all recorded rows to
    ``<output_dir>/training_summary.json``.
    """
    
    def __init__(self, output_dir):
        """Create ``output_dir`` if needed and start a fresh loss file.

        Args:
            output_dir: Directory that receives the loss log and the summary.
        """
        self.output_dir = output_dir
        self.loss_file = os.path.join(output_dir, "training_losses.txt")
        self.losses = []
        
        os.makedirs(output_dir, exist_ok=True)
        
        # 'w' truncates any log left behind by an earlier run in this directory.
        with open(self.loss_file, 'w') as f:
            f.write("Epoch\tStep\tTraining_Loss\tValidation_Loss\tLearning_Rate\n")

    @staticmethod
    def _fmt(value):
        """Return ``value`` unchanged, or ``'N/A'`` only when it is None.

        A legitimate 0.0 (e.g. the learning rate at the end of a cosine
        schedule) must still be written out, so truthiness is not used.
        """
        return 'N/A' if value is None else value

    @staticmethod
    def _last_logged(entries, key):
        """Return the most recent non-None ``entry[key]`` in ``entries``, else None."""
        return next((entry[key] for entry in reversed(entries) if entry[key] is not None), None)
    
    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """Record the losses and learning rate carried by one Trainer log event."""
        if logs is None:
            return
        current_log = {
            'epoch': state.epoch,
            'step': state.global_step,
            # Trainer logs the per-step training loss under 'loss'; 'train_loss'
            # only appears in the one summary log emitted when training finishes.
            'train_loss': logs.get('loss', logs.get('train_loss')),
            'eval_loss': logs.get('eval_loss'),
            'learning_rate': logs.get('learning_rate'),
        }
        self.losses.append(current_log)
        epoch = current_log['epoch']
        # NOTE(review): state.epoch is Optional in TrainerState; guard defensively
        # in case a log arrives outside the training loop.
        epoch_str = 'N/A' if epoch is None else f"{epoch:.2f}"
        with open(self.loss_file, 'a') as f:
            f.write(f"{epoch_str}\t"
                    f"{current_log['step']}\t"
                    f"{self._fmt(current_log['train_loss'])}\t"
                    f"{self._fmt(current_log['eval_loss'])}\t"
                    f"{self._fmt(current_log['learning_rate'])}\n")
    
    def on_train_end(self, args, state, control, model=None, **kwargs):
        """Write a JSON summary of all recorded losses at the end of training."""
        summary_file = os.path.join(self.output_dir, "training_summary.json")
        eval_losses = [log['eval_loss'] for log in self.losses if log['eval_loss'] is not None]
        summary = {
            'total_epochs': state.epoch,
            'total_steps': state.global_step,
            # Take the latest *recorded* value of each metric: the last log event
            # is typically Trainer's end-of-training summary, which has no
            # eval_loss, so reading only self.losses[-1] would always give None.
            'final_train_loss': self._last_logged(self.losses, 'train_loss'),
            'final_eval_loss': self._last_logged(self.losses, 'eval_loss'),
            'best_eval_loss': min(eval_losses, default=None),
            'all_losses': self.losses
        }
        with open(summary_file, 'w') as f:
            json.dump(summary, f, indent=2)
        
        print("\nTraining completed!")
        print(f"Loss logs saved to: {self.loss_file}")
        print(f"Training summary saved to: {summary_file}")

# Passed to SFTTrainer through `callbacks`; its files are written under OUTPUT_DIR.
loss_callback = LossLoggingCallback(output_dir=OUTPUT_DIR)

In [None]:
trainer = SFTTrainer(
    model = model,
    processing_class = tokenizer,
    train_dataset = train_ds,
    eval_dataset = test_ds,
    dataset_text_field = "text",
    max_seq_length = MAX_SEQ_LENGTH,
    
    args = SFTConfig(
        # Keep checkpoints under OUTPUT_DIR next to the logs and the final model;
        # without it the Trainer uses its own default directory.
        output_dir = OUTPUT_DIR,
        per_device_train_batch_size = BATCH_SIZE,
        # NOTE(review): accumulation is tied to BATCH_SIZE, so the effective batch
        # is BATCH_SIZE * BATCH_SIZE (1 with the current config). Confirm this is
        # intended before raising BATCH_SIZE.
        gradient_accumulation_steps = BATCH_SIZE,
        warmup_steps = 5,
        num_train_epochs = EPOCHS,
        learning_rate = LEARNING_RATE,
        logging_steps = 5,

        # Evaluate and checkpoint once per epoch; at the end reload the checkpoint
        # with the lowest eval_loss (EarlyStoppingCallback requires this setup).
        eval_strategy = "epoch",
        save_strategy = "epoch",
        save_total_limit = 1,
        load_best_model_at_end = True, 
        metric_for_best_model = 'eval_loss',
        greater_is_better = False,

        # Step-based alternative to the epoch-based schedule above:
        # max_steps = 100,
        # eval_strategy = "steps",  
        # eval_steps = 100,  
        # save_strategy = "steps",  
        # save_steps = 100, 
        # save_total_limit = 1,
        # load_best_model_at_end = True, 
        # metric_for_best_model = 'eval_loss',
        # greater_is_better = False,

        optim = "adamw_8bit", 
        # Mixed precision: bf16 when supported, fp16 otherwise.
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        # Same value as the LoRA `random_state`; the former 3047 appeared to be a
        # digit transposition of 3407.
        seed = 3407,
        logging_dir = os.path.join(OUTPUT_DIR, 'logs'),
        # The setup cell calls wandb.init() exactly when a W&B key is configured,
        # so report to that run in that case instead of leaving it empty.
        report_to = "wandb" if WANDB_API_KEY is not None else "none",
        run_name = RUN_NAME,
    ),
    callbacks = [
        # Evaluation runs once per epoch, so patience=5 can only stop training
        # after more than 5 epochs; with EPOCHS = 1 it never triggers.
        EarlyStoppingCallback(early_stopping_patience=5),
        loss_callback,
    ],
)
trainer.train()

final_model_dir = os.path.join(OUTPUT_DIR, "final_model")
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)

# Uploading needs Hugging Face write credentials (login in the setup cell or a
# cached token).
hub_repo_id = f"huyhoangt2201/{RUN_NAME}"
trainer.model.push_to_hub(hub_repo_id)
tokenizer.push_to_hub(hub_repo_id)

In [None]:
# Re-authenticate with the Hugging Face Hub before retrying the upload in the
# next cell. NOTE(review): HF_TOKEN is None unless it was set earlier; in that
# case huggingface_hub's login() is expected to prompt for a token — confirm
# this interactive flow is intended.
login(token=HF_TOKEN)

In [None]:
# Retry the Hub upload (e.g. after authenticating in the previous cell):
# the model first, then the tokenizer, to the same repository.
hub_repo = f"huyhoangt2201/{RUN_NAME}"
for artifact in (trainer.model, tokenizer):
    artifact.push_to_hub(hub_repo)