In [1]:
# !pip install -q unsloth
!pip install unsloth vllm transformers==4.51.3 trl==0.18.1 accelerate==1.7.0

Collecting unsloth
  Using cached unsloth-2025.6.2-py3-none-any.whl.metadata (47 kB)
Collecting vllm
  Downloading vllm-0.9.1-cp38-abi3-manylinux1_x86_64.whl.metadata (15 kB)
Collecting trl==0.18.1
  Downloading trl-0.18.1-py3-none-any.whl.metadata (11 kB)
Collecting accelerate==1.7.0
  Downloading accelerate-1.7.0-py3-none-any.whl.metadata (19 kB)
Collecting datasets>=3.0.0 (from trl==0.18.1)
  Using cached datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting unsloth_zoo>=2025.6.1 (from unsloth)
  Using cached unsloth_zoo-2025.6.1-py3-none-any.whl.metadata (8.1 kB)
Collecting xformers>=0.0.27.post2 (from unsloth)
  Using cached xformers-0.0.30-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.0 kB)
Collecting bitsandbytes (from unsloth)
  Using cached bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting tyro (from unsloth)
  Using cached tyro-0.9.24-py3-none-any.whl.metadata (11 kB)
Collecting sentencepiece>=0.2.0 (from unsloth)
  Using cached sent

In [21]:
!pip install huggingface_hub gradio dataset

Collecting gradio
  Using cached gradio-5.34.2-py3-none-any.whl.metadata (16 kB)
Collecting dataset
  Downloading dataset-1.6.2-py2.py3-none-any.whl.metadata (1.9 kB)
Collecting aiofiles<25.0,>=22.0 (from gradio)
  Using cached aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
Collecting ffmpy (from gradio)
  Using cached ffmpy-0.6.0-py3-none-any.whl.metadata (2.9 kB)
Collecting gradio-client==1.10.3 (from gradio)
  Using cached gradio_client-1.10.3-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Using cached groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting orjson~=3.0 (from gradio)
  Using cached orjson-3.10.18-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (41 kB)
Collecting pydub (from gradio)
  Using cached pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting ruff>=0.9.3 (from gradio)
  Using cached ruff-0.12.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
Collecting safehttpx<0.2.0,>=0.1.

In [2]:
# =============================================================================
import torch
from datasets import load_dataset
# from trl import SFTTrainer
# from transformers import TrainingArguments, DataCollatorForSeq2Seq
# from unsloth import is_bfloat16_supported
import os
from unsloth import FastModel
from unsloth.chat_templates import get_chat_template
from unsloth.chat_templates import train_on_responses_only
from trl import SFTTrainer, SFTConfig
from unsloth.chat_templates import get_chat_template

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 06-20 08:07:39 [__init__.py:244] Automatically detected platform cuda.


In [3]:
# 3. CONFIGURATION
# =============================================================================

In [4]:
# Model Configuration
MODEL_NAME = "unsloth/gemma-3-4b-it-unsloth-bnb-4bit"  # Gemma 3 4B Instruct - "unsloth/gemma-3-1b-it-unsloth-bnb-4bit"
MAX_SEQ_LENGTH = 2048
DTYPE = None  # Auto-detection
LOAD_IN_4BIT = False
LOAD_IN_8BIT = False
FULL_FINETUNING = False

In [5]:
# Dataset Configuration - UPDATE WITH YOUR DATASET NAME
DATASET_NAME = "ictbiortc/hausa-medical-conversations-format-9k"
# Alternative: Use local files if you have them
# DATASET_FILES = {"train": "train.jsonl", "test": "test.jsonl"}

# Training Configuration
OUTPUT_DIR = "gemma3-4b-hausa-medical"
HF_MODEL_NAME = "ictbiortc/gemma3-4b-hausa-medical-qa"  # Change this!
HF_TOKEN="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"

In [6]:
# LoRA Configuration
LORA_R = 32        # Higher rank for better learning
LORA_ALPHA = 32    # Alpha parameter
LORA_DROPOUT = 0   # No dropout for stability
TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", 
                  "gate_proj", "up_proj", "down_proj"]

In [7]:
# Training Hyperparameters
BATCH_SIZE = 2              # Adjust based on your GPU memory
GRADIENT_ACCUMULATION = 4   # Effective batch size = 2 * 4 = 8
LEARNING_RATE = 2e-4
MAX_STEPS = 200            # Increase for larger datasets
WARMUP_STEPS = 5
EVAL_STEPS = 50
SAVE_STEPS = 50
LOGGING_STEPS = 10


In [8]:
# 4. LOAD MODEL AND TOKENIZER
# =============================================================================
print("\n🔄 Loading Gemma 3 4B model and tokenizer...")

model, tokenizer = FastModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_8bit = LOAD_IN_8BIT, # [NEW!] A bit more accurate, uses 2x memory
    full_finetuning = FULL_FINETUNING, # [NEW!] We have full finetuning now!
    load_in_4bit=LOAD_IN_4BIT,
    trust_remote_code=True,  # Required for Gemma models
)

print("✅ Model and tokenizer loaded successfully!")
print(f"📊 Model size: {model.get_memory_footprint() / 1e9:.2f}GB")


🔄 Loading Gemma 3 4B model and tokenizer...
Are you certain you want to do remote code execution?
==((====))==  Unsloth 2025.6.2: Fast Gemma3 patching. Transformers: 4.51.3. vLLM: 0.9.1.
   \\   /|    NVIDIA RTX A6000. Num GPUs = 1. Max memory: 47.413 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3 does not support SDPA - switching to eager!
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


✅ Model and tokenizer loaded successfully!
📊 Model size: 8.60GB


In [9]:
# 5. CONFIGURE LORA FOR FINE-TUNING
# =============================================================================
print("\n🔄 Configuring LoRA for fine-tuning...")

model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = False, # Turn off for just text!
    finetune_language_layers   = True,  # Should leave on!
    finetune_attention_modules = True,  # Attention good for GRPO
    finetune_mlp_modules       = True,  # SHould leave on always!

    r = LORA_R,           # Larger = higher accuracy, but might overfit
    lora_alpha = LORA_ALPHA,  # Recommended alpha == r at least
    lora_dropout = LORA_DROPOUT,
    bias = "none",
    random_state = 3407,
)

print("✅ LoRA configuration applied!")


🔄 Configuring LoRA for fine-tuning...
Unsloth: Making `model.base_model.model.language_model.model` require gradients
✅ LoRA configuration applied!


In [10]:
# 6. LOAD AND PREPARE DATASET
# =============================================================================
DATASET_NAME = "ictbiortc/hausa-medical-conversations-format-9k"
print(f"\n🔄 Loading dataset: {DATASET_NAME}")

try:
    # Load from Hugging Face Hub
    dataset = load_dataset(DATASET_NAME)
    print(f"✅ Dataset loaded from Hub!")
    print(f"📊 Train samples: {len(dataset['train']):,}")
    print(f"📊 Test samples: {len(dataset['test']):,}")
    
    # Show sample
    print(f"\n📝 Sample data:")
    sample = dataset['train'][0]
    print(f"Text preview: {sample['text'][:200]}...")
    
except Exception as e:
    print(f"❌ Error loading dataset: {e}")
    print("💡 Make sure your dataset is public or you're authenticated")
    # Fallback to local files if needed
    # dataset = load_dataset("json", data_files=DATASET_FILES)



🔄 Loading dataset: ictbiortc/hausa-medical-conversations-format-9k
✅ Dataset loaded from Hub!
📊 Train samples: 8,100
📊 Test samples: 900

📝 Sample data:
❌ Error loading dataset: 'text'
💡 Make sure your dataset is public or you're authenticated


In [11]:
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma-3",  # Use gemma-3, not chatml!
)

In [12]:
from unsloth.chat_templates import standardize_data_formats
train_dataset = standardize_data_formats(dataset['train'])
eval_dataset = standardize_data_formats(dataset['test'])

In [13]:
def formatting_prompts_func(examples):
   convos = examples["conversations"]
   texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False).removeprefix('<bos>') for convo in convos]
   return { "text" : texts, }

train_dataset = train_dataset.map(formatting_prompts_func, batched = True)
eval_dataset = eval_dataset.map(formatting_prompts_func, batched = True)

In [14]:
# 8. SETUP TRAINER
# =============================================================================
from trl import SFTTrainer, SFTConfig
print("\n🔄 Setting up trainer...")

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4, # Use GA to mimic batch size!
        warmup_steps = 5,
        num_train_epochs = 1, # Set this for 1 full training run.
        # max_steps = 30,
        learning_rate = 2e-4, # Reduce to 2e-5 for long training runs
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none", # Use this for WandB etc
        dataset_num_proc=2,
    ),
)

print("✅ Trainer configured successfully!")


🔄 Setting up trainer...
✅ Trainer configured successfully!


In [15]:
from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<start_of_turn>user\n",
    response_part = "<start_of_turn>model\n",
)

In [16]:
#verify masking

In [17]:
tokenizer.decode(trainer.train_dataset[100]["input_ids"])

"<bos><start_of_turn>user\nKai likita ne mai hankali da ƙwarewa a fannin kiwon lafiya. Ka ba da shawarwari masu tushe a kimiyya, masu dacewa da al'adun Nijeriya. Kullum ka tunatar da mutane su nemi shawarar likita idan suna bukatar gaggawar magani.\n\nnawa ya fada ya kai kasa, me ya kamata mu yi?<end_of_turn>\n<start_of_turn>model\nDa farko, a tabbatar cewa jaririn yana cikin yanayi mai lafiya. A hankali a ɗaga shi, a kuma bincika ko akwai rauni ko ciwo. Idan jaririn ya yi kuka sosai ko kuma yana da zafi mai tsanani, a gaggauta kai shi asibiti don a duba shi sosai. Idan babu rauni mai tsanani, sai a shafa maganin gargajiya mai sanyaya wuri kamar man zaitun, amma idan al'amarin ya ci gaba, a je asibiti.<end_of_turn>\n"

In [18]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA RTX A6000. Max memory = 47.413 GB.
9.514 GB of memory reserved.


In [19]:
# 9. START TRAINING
# =============================================================================
print("\n🚀 Starting training...")
print("=" * 50)

trainer_stats = trainer.train()

print("=" * 50)
print("✅ Training completed!")
print(f"📊 Final train loss: {trainer_stats.training_loss:.4f}")
print(f"⏱️ Training time: {trainer_stats.metrics['train_runtime']:.2f}s")


🚀 Starting training...


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
1,4.1374
2,4.0782
3,4.1189
4,3.1637
5,2.7049
6,2.6135
7,2.3698
8,2.3349
9,2.2648
10,2.1885


KeyboardInterrupt: 

In [None]:
# 10. SAVE MODEL
# =============================================================================
print("\n💾 Saving fine-tuned model...")

# Save LoRA weights
model.save_pretrained_merged("gemma3-4b-hausa-medical-qa", tokenizer)
model.push_to_hub_merged(
    HF_MODEL_NAME, tokenizer,
    token = HF_TOKEN
)

print(f"✅ Model saved to HF")

## Test

In [None]:
# =============================================================================
# TEST MERGED LORA MODEL FROM HUGGING FACE
# =============================================================================

# 11. LOAD AND TEST THE MERGED MODEL FROM HF
# =============================================================================
print("\n🧪 Loading and testing merged model from Hugging Face...")

HF_MODEL_NAME = "ictbiortc/gemma3-4b-hausa-medical-qa"  # Your merged model

# Load the merged model directly from HF (no base model needed!)
from unsloth import FastModel
from transformers import TextStreamer
import torch

model, tokenizer = FastModel.from_pretrained(
    model_name=HF_MODEL_NAME,  # Your merged model from HF
    max_seq_length=2048,
    load_in_4bit=True,
    trust_remote_code=True,
)

print(f"✅ Merged model loaded from: {HF_MODEL_NAME}")

# Switch to inference mode
FastModel.for_inference(model)

In [None]:
# =============================================================================
# MODERN CHAT TEMPLATE TESTING (Gemma 3 Style)
# =============================================================================
print("\n🔄 Testing with modern chat template...")

# Hausa medical test cases using proper message format
test_cases = [
    {
        "query": "Ina jin ciwon kai da zazzabi tun kwana biyu. Me ya kamata in yi?",
        "description": "Headache and fever for 2 days"
    },
    {
        "query": "Dana yana da gudawa sosai. Ina bukatan taimako.",
        "description": "Child with severe diarrhea"
    },
    {
        "query": "Yaya ake hana malaria lokacin damina?",
        "description": "Malaria prevention during rainy season"
    },
    {
        "query": "Ina da ciwon sukari. Wanne abinci ya dace da ni?",
        "description": "Diabetes dietary advice"
    },
    {
        "query": "Kakana mai shekara 70 tana fama da hauhawar jini. Me ya kamata mu yi?",
        "description": "Elderly hypertension management"
    }
]

# System message for medical context
SYSTEM_MESSAGE = "Kai likita ne mai hankali da ƙwarewa a fannin kiwon lafiya. Ka ba da shawarwari masu tushe a kimiyya, masu dacewa da al'adun Nijeriya. Kullum ka tunatar da mutane su nemi shawarar likita idan suna bukatar gaggawar magani."

for i, case in enumerate(test_cases, 1):
    print(f"\n🔍 Test {i}: {case['description']}")
    print(f"📝 Query: {case['query']}")
    
    # Create messages in proper chat format
    messages = [
        {
            "role": "system",
            "content": SYSTEM_MESSAGE
        },
        {
            "role": "user", 
            "content": case['query']
        }
    ]
    
    # Apply chat template (modern Gemma 3 approach)
    text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # Must add for generation
        tokenize=False,  # Return string, not tokens
    )
    
    print(f"🔧 Formatted prompt preview: {text[:100]}...")
    
    # Tokenize for generation
    inputs = tokenizer([text], return_tensors="pt").to("cuda")
    
    # Create streamer for real-time output
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    
    print(f"🤖 Response: ", end="")
    
    # Generate with recommended Gemma-3 settings
    outputs = model.generate(
        **inputs,
        max_new_tokens=200,  # Longer for medical explanations
        # Recommended Gemma-3 settings
        temperature=1.0,     # Creative but controlled
        top_p=0.95,          # Nucleus sampling
        top_k=64,            # Top-k filtering
        do_sample=True,
        repetition_penalty=1.1,
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        streamer=streamer,   # Real-time streaming output
    )
    
    print("\n" + "-" * 70)


🔄 Testing with modern chat template...

🔍 Test 1: Headache and fever for 2 days
📝 Query: Ina jin ciwon kai da zazzabi tun kwana biyu. Me ya kamata in yi?
🔧 Formatted prompt preview: <bos><start_of_turn>user
Kai likita ne mai hankali da ƙwarewa a fannin kiwon lafiya. Ka ba da shawar...
🤖 Response: Yaro mai watanni kadan yana da yiwuwar samun cututtuka kamar tarin huhu ko mura. Da farko, ka tabbatar da cewa yaron yana shan ruwan nono sosai kuma an wanke shi da kyau. Idan zafin jikin yaron bai yi sauki ba bayan kwanaki biyu, ko kuma akwai wasu alamomi kamar rashin cin abinci da kuzari, ya kamata ka garzaya da shi asibiti domin a duba shi sosai. A lokaci guda, za ka iya amfani da rigakafa ko sanyaya masa jiki cikin sauƙi don rage zazzaɓin jiki kafin zuwa asibiti.

----------------------------------------------------------------------

🔍 Test 2: Child with severe diarrhea
📝 Query: Dana yana da gudawa sosai. Ina bukatan taimako.
🔧 Formatted prompt preview: <bos><start_of_turn>user
Kai likit