# Project Overview

This notebook fine-tunes the **Gemma 3n E4B** instruction-tuned model (`unsloth/gemma-3n-E4B-it`) on a custom medical dataset designed for edge deployment.  
The goal is an LLM that can run inference offline for healthcare, rescue, and first-aid tasks, without relying on cloud infrastructure or internet access.

The training dataset was generated in-house through a multi-stage pipeline that combines authoritative medical datasets and official PDFs with AI-generated content produced through Retrieval-Augmented Generation (RAG).  
It contains **86,667 medical Q&A pairs** covering clinical reasoning, emergency medicine, and first-aid procedures.

Key characteristics:
- Model: `unsloth/gemma-3n-E4B-it` (Gemma 3n E4B), designed to run on local devices.
- Use case: Offline emergency assistants, medical chatbots, and rescue guidance tools.
- Dataset: Synthesized from 11 open-source datasets and 14 official medical PDFs using vector search, embedding techniques, and strict medical validation.

The full dataset and pipeline are open-source and available at:  
**https://github.com/ericrisco/gemma3n-impact-challenge**

**Disclaimer**: This project is for educational and research purposes only. Models trained on this data are not a replacement for professional medical advice.


## Configurations
Set up environment variables, API keys, and output paths.


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from huggingface_hub import notebook_login

notebook_login()


In [None]:
import getpass

# Hugging Face access token with write access; used below to push the adapter and GGUF files
api_key = getpass.getpass("Enter your HF API Key:")

Enter your HF API Key:··········

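Typing the token on every run gets tedious. A minimal sketch, assuming the token may already be stored in the `HF_TOKEN` environment variable (which `huggingface_hub` also reads): use it when present and fall back to the interactive prompt otherwise. `get_hf_token` is a hypothetical helper, not part of any library.

```python
import getpass
import os
from typing import Callable, Mapping


def get_hf_token(
    env: Mapping[str, str] = os.environ,
    prompt: Callable[[str], str] = getpass.getpass,
) -> str:
    """Hypothetical helper: prefer HF_TOKEN from the environment, else prompt."""
    token = env.get("HF_TOKEN", "").strip()
    if token:
        return token
    # Nothing usable in the environment: ask interactively (input stays hidden)
    return prompt("Enter your HF API Key:")
```

In the notebook this would replace the prompt cell with `api_key = get_hf_token()`.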

In [None]:
from_scratch = True

In [None]:
source_model_name = "unsloth/gemma-3n-E4B-it"  # base model on the Hugging Face Hub

model_name = "medical-gemma-3n"  # local save directory
destination_model_name = "ericrisco/medical-gemma-3n-lora"  # Hub repo for the LoRA adapter

In [None]:
output_dir = "/content/drive/MyDrive/medical-gemma-3n"  # on Google Drive, so checkpoints survive Colab runtime resets

## Installation
Install required packages and dependencies.


In [None]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" huggingface_hub hf_transfer
    !pip install --no-deps unsloth

In [None]:
%%capture
!pip install --no-deps --upgrade timm  # Gemma 3n's vision tower is loaded through timm

## Unsloth Setup
Load the base model and configure the training environment using Unsloth.


In [None]:
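from unsloth import FastModel
import torch

model, tokenizer = FastModel.from_pretrained(
    model_name = source_model_name,
    dtype = None,             # auto-detect: float16 on a T4, bfloat16 on newer GPUs
    max_seq_length = 1024,    # maximum sequence length used for training
    load_in_4bit = True,      # 4-bit weights so the model fits in T4 memory
    full_finetuning = False,  # train LoRA adapters instead of all weights
)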

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.7.6: Fast Gemma3N patching. Transformers: 4.53.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3N does not support SDPA - switching to eager!



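Why `load_in_4bit = True` matters on this GPU: a back-of-the-envelope estimate of weight memory alone, using the total parameter count Unsloth reports in the training log below (7,869,188,432) and the T4's 14.741 GB shown in the banner above. This is a simplification that ignores 4-bit quantization constants, layers kept in higher precision, activations, and optimizer state, so real usage is higher.

```python
def weight_memory_gb(num_params: int, bits_per_param: float) -> float:
    """Rough weight-only memory in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


TOTAL_PARAMS = 7_869_188_432  # from the Unsloth training log
T4_MEMORY_GB = 14.741         # "Max memory" from the Unsloth banner

fp16_gb = weight_memory_gb(TOTAL_PARAMS, 16)
four_bit_gb = weight_memory_gb(TOTAL_PARAMS, 4)

print(f"16-bit weights: {fp16_gb:.2f} GB (fits on T4: {fp16_gb < T4_MEMORY_GB})")
print(f"4-bit weights:  {four_bit_gb:.2f} GB (fits on T4: {four_bit_gb < T4_MEMORY_GB})")
```

In 16-bit the weights alone (about 15.7 GB) already exceed the T4's memory, while in 4-bit they take roughly 3.9 GB, leaving room for activations and the LoRA optimizer state.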
## Finetuning Configuration
Attach LoRA adapters to the language model's attention and MLP layers; the vision layers stay frozen.


In [None]:
model = FastModel.get_peft_model(
  model,
  finetune_vision_layers     = False,  # text-only task: keep the vision encoder frozen
  finetune_language_layers   = True,
  finetune_attention_modules = True,
  finetune_mlp_modules       = True,
  r = 8,                # LoRA rank
  lora_alpha = 8,       # scaling numerator; lora_alpha / r = 1
  lora_dropout = 0,
  bias = "none",
  random_state = 3407,
)

Unsloth: Making `model.base_model.model.model.language_model` require gradients

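A minimal pure-Python sketch of what a LoRA adapter adds to one frozen linear layer, under the standard LoRA formulation: the layer computes `W x + (lora_alpha / r) * B (A x)`, where `A` has shape r × d_in and `B` has shape d_out × r. With `r = 8` and `lora_alpha = 8` the scaling factor is 1. The 2048 × 2048 layer size below is a hypothetical example, not a Gemma 3n dimension.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, x: Vector) -> Vector:
    """Multiply a row-major matrix by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def lora_forward(w: Matrix, a: Matrix, b: Matrix, x: Vector,
                 lora_alpha: float, r: int) -> Vector:
    """Frozen base output plus the scaled low-rank update B(Ax)."""
    base = matvec(w, x)
    delta = matvec(b, matvec(a, x))
    scale = lora_alpha / r
    return [y + scale * d for y, d in zip(base, delta)]


def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters added to one linear layer: A (r x d_in) + B (d_out x r)."""
    return r * d_in + d_out * r


# Hypothetical 2048 x 2048 projection: rank-8 adapter vs. full weight matrix
print(lora_param_count(2048, 2048, r=8), "vs", 2048 * 2048)  # 32768 vs 4194304
```

Because `B` is usually initialized to zeros, the adapted layer starts out reproducing the base model exactly; training then moves only the small `A` and `B` matrices.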

## Dataset Preparation
Load and preprocess the dataset for training.


In [None]:
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma-3",
)

In [None]:
from datasets import load_dataset

dataset = load_dataset("ericrisco/medrescue")


In [None]:
dataset = dataset["train"]

# Wrap each input/output pair as a two-turn ShareGPT-style conversation
def to_conversations(example):
    return {
        "conversations": [
            {"from": "human", "value": example["input"]},
            {"from": "gpt", "value": example["output"]}
        ]
    }

dataset = dataset.map(to_conversations)
# Keep only the new "conversations" column
cols = [col for col in dataset.column_names if col != "conversations"]
dataset = dataset.remove_columns(cols)


In [None]:
from unsloth.chat_templates import standardize_data_formats
dataset = standardize_data_formats(dataset)


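`standardize_data_formats` converts ShareGPT-style turns (`from`/`value`) into the `role`/`content` messages that the next cells expect, as the `dataset[100]` output shows. A simplified sketch of that mapping for the two speakers used here, written as a hypothetical `sharegpt_to_messages` helper; it is not Unsloth's implementation, which handles more role names and dataset layouts.

```python
from typing import Dict, List

# Assumed role mapping for the two speakers produced by to_conversations
ROLE_MAP = {"human": "user", "gpt": "assistant"}


def sharegpt_to_messages(conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Hypothetical helper: map ShareGPT turns to role/content messages."""
    return [
        {"role": ROLE_MAP[turn["from"]], "content": turn["value"]}
        for turn in conversation
    ]


example = [
    {"from": "human", "value": "Under what circumstances is Cotrimoxazole prophylaxis "
                               "not recommended among HIV infected children?"},
    {"from": "gpt", "value": "Prophylaxis with Cotrimoxazole is not recommended for all "
                             "symptomatic HIV infected children over 5 years of age "
                             "irrespective of CD4 counts."},
]
print(sharegpt_to_messages(example))
```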
In [None]:
dataset[100]

{'conversations': [{'content': 'Under what circumstances is Cotrimoxazole prophylaxis not recommended among HIV infected children?',
   'role': 'user'},
  {'content': 'Prophylaxis with Cotrimoxazole is not recommended for all symptomatic HIV infected children over 5 years of age irrespective of CD4 counts.',
   'role': 'assistant'}]}

In [None]:
def convert_conversations(conversations):
    # Render one user/assistant pair in Gemma's chat format. <bos> is not added
    # here because the tokenizer prepends it during tokenization.
    user_content = ''
    assistant_content = ''
    for message in conversations:
        content = (message.get('content') or '').strip()
        if message['role'] == 'user':
            user_content += content
        elif message['role'] == 'assistant':
            assistant_content += content
    return (
        '<start_of_turn>user\n' + user_content + '<end_of_turn>\n' +
        '<start_of_turn>model\n' + assistant_content + '<end_of_turn>\n'
    )

dataset = [convert_conversations(item['conversations']) for item in dataset]


In [None]:
from datasets import Dataset

dataset = Dataset.from_list([{"text": x} for x in dataset])

dataset[100]

{'text': '<start_of_turn>user\nUnder what circumstances is Cotrimoxazole prophylaxis not recommended among HIV infected children?<end_of_turn>\n<start_of_turn>model\nProphylaxis with Cotrimoxazole is not recommended for all symptomatic HIV infected children over 5 years of age irrespective of CD4 counts.<end_of_turn>\n'}

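`convert_conversations` concatenates every user message into one user turn and every assistant message into one model turn. That is fine for this single-turn dataset, but it would merge the turns of a multi-turn conversation. A hedged alternative, shown as a hypothetical `format_gemma_turns` helper, emits one `<start_of_turn>` block per message and maps the `assistant` role to Gemma's `model` role; for a single Q&A pair it produces the same string as `convert_conversations`.

```python
from typing import Dict, List

# Gemma names the assistant side "model"
GEMMA_ROLES = {"user": "user", "assistant": "model"}


def format_gemma_turns(messages: List[Dict[str, str]]) -> str:
    """Hypothetical helper: render each message as its own Gemma turn."""
    parts = []
    for message in messages:
        role = GEMMA_ROLES.get(message["role"])
        if role is None:
            continue  # this sketch skips roles it does not handle (e.g. "system")
        content = (message.get("content") or "").strip()
        parts.append(f"<start_of_turn>{role}\n{content}<end_of_turn>\n")
    return "".join(parts)
```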
## Trainer Setup
Configure TRL's `SFTTrainer` on the formatted `text` column, which already follows the Gemma chat template.


In [None]:
from trl import SFTTrainer, SFTConfig
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    eval_dataset = None,
    args = SFTConfig(
        output_dir = output_dir,          # checkpoints go to Google Drive
        torch_compile = False,
        dataset_text_field = "text",
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 4,  # effective batch size 4 x 4 = 16
        warmup_steps = 5,
        num_train_epochs = 1,
        #max_steps = 50,                  # uncomment for a quick smoke test
        learning_rate = 2e-5,
        logging_steps = 10,
        save_strategy = "steps",
        save_steps = 500,                 # checkpoint every 500 optimizer steps
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none",               # no external experiment tracking
    ),
)


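How the batch settings translate into the step count in the training log: with 4 examples per device and 4 gradient-accumulation steps, each optimizer step sees 16 examples. A minimal sketch of that arithmetic (rounding partial batches up, which under these settings reproduces the log's 5,417 steps), plus the rule that `lr_scheduler_type = "linear"` with `warmup_steps = 5` follows in Hugging Face `transformers`: ramp linearly up to `learning_rate`, then decay linearly to zero. The function names are illustrative.

```python
import math


def total_optimizer_steps(num_examples: int, per_device_batch: int,
                          grad_accum: int, num_gpus: int = 1, epochs: int = 1) -> int:
    """Optimizer steps for the run, rounding partial batches up."""
    batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_gpus))
    return math.ceil(batches_per_epoch / grad_accum) * epochs


def linear_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup from 0 to base_lr, then linear decay back to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


steps = total_optimizer_steps(86_667, per_device_batch=4, grad_accum=4)
print(steps)  # 5417
for s in (0, 5, steps // 2, steps):
    print(s, linear_lr(s, 2e-5, warmup_steps=5, total_steps=steps))
```

With only 5 warmup steps out of 5,417, the learning rate reaches its peak almost immediately and then decays steadily for the rest of the epoch.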
We also use Unsloth's `train_on_responses_only` helper so the loss is computed only on the assistant outputs and not on the user's inputs. Focusing the loss on the answers typically improves the accuracy of finetunes.

In [None]:
from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<start_of_turn>user\n",
    response_part = "<start_of_turn>model\n",
)


In [None]:
tokenizer.decode(trainer.train_dataset[100]["input_ids"])

'<bos><start_of_turn>user\nUnder what circumstances is Cotrimoxazole prophylaxis not recommended among HIV infected children?<end_of_turn>\n<start_of_turn>model\nProphylaxis with Cotrimoxazole is not recommended for all symptomatic HIV infected children over 5 years of age irrespective of CD4 counts.<end_of_turn>\n'

In [None]:
# Masked positions (-100) are shown as spaces, leaving only the tokens that contribute to the loss
tokenizer.decode([tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[100]["labels"]]).replace(tokenizer.pad_token, " ")

'                         Prophylaxis with Cotrimoxazole is not recommended for all symptomatic HIV infected children over 5 years of age irrespective of CD4 counts.<end_of_turn>\n'

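The decoded labels above show what `train_on_responses_only` achieves: every token up to and including the response marker gets the label -100, which PyTorch's cross-entropy loss ignores by default, so only the answer and its closing `<end_of_turn>` contribute to the loss. A simplified sketch over integer token IDs, assuming hypothetical marker IDs; it toggles masking at each instruction and response marker so it also covers multi-turn text. It is an illustration, not Unsloth's implementation.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_non_responses(input_ids: List[int], instruction_ids: List[int],
                       response_ids: List[int]) -> List[int]:
    """Labels equal input_ids inside model responses and IGNORE_INDEX elsewhere."""
    if not instruction_ids or not response_ids:
        raise ValueError("marker token lists must be non-empty")
    labels: List[int] = []
    in_response = False
    i = 0
    while i < len(input_ids):
        if input_ids[i:i + len(response_ids)] == response_ids:
            # The marker itself is masked; training starts right after it
            labels.extend([IGNORE_INDEX] * len(response_ids))
            i += len(response_ids)
            in_response = True
        elif input_ids[i:i + len(instruction_ids)] == instruction_ids:
            # A new user turn starts: mask until the next response marker
            labels.extend([IGNORE_INDEX] * len(instruction_ids))
            i += len(instruction_ids)
            in_response = False
        else:
            labels.append(input_ids[i] if in_response else IGNORE_INDEX)
            i += 1
    return labels


# Hypothetical IDs: 2=<bos>, [10, 11]="<start_of_turn>user\n",
# [10, 13]="<start_of_turn>model\n", 12=<end_of_turn>
ids = [2, 10, 11, 5, 6, 12, 10, 13, 7, 8, 12]
print(mask_non_responses(ids, instruction_ids=[10, 11], response_ids=[10, 13]))
# [-100, -100, -100, -100, -100, -100, -100, -100, 7, 8, 12]
```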
## Let's train the model!

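The cell below resumes from the latest checkpoint in `output_dir` via `trainer.train(resume_from_checkpoint = True)`. To start a fresh run instead, use the commented-out `trainer.train()` line.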

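With `resume_from_checkpoint = True`, the Hugging Face `Trainer` looks in `output_dir` for the most recent `checkpoint-<step>` folder (the library's own helper for this is `transformers.trainer_utils.get_last_checkpoint`). A dependency-free sketch of that lookup, useful for checking which checkpoint a resumed run will pick up from Google Drive; `latest_checkpoint` is a hypothetical name and only inspects folder names.

```python
import re
from typing import Iterable, Optional

_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(names: Iterable[str]) -> Optional[str]:
    """Return the checkpoint-<step> name with the highest step, or None."""
    best_name, best_step = None, -1
    for name in names:
        match = _CHECKPOINT_RE.match(name)
        if match and int(match.group(1)) > best_step:
            best_name, best_step = name, int(match.group(1))
    return best_name


# In the notebook you would pass os.listdir(output_dir) instead of this list
print(latest_checkpoint(["checkpoint-500", "checkpoint-5000", "checkpoint-1000", "runs"]))
# checkpoint-5000
```

Comparing the numeric step rather than the folder name matters: as strings, `checkpoint-500` sorts after `checkpoint-1000`.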
In [None]:
trainer_stats = trainer.train(resume_from_checkpoint = True)
#trainer_stats = trainer.train()
trainer_stats

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 86,667 | Num Epochs = 1 | Total steps = 5,417
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 4 x 1) = 16
 "-____-"     Trainable parameters = 19,210,240 of 7,869,188,432 (0.24% trained)
	save_steps: 500 (from args) != 10 (from trainer_state.json)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Unsloth: Will smartly offload gradients to save VRAM!




TrainOutput(global_step=5416, training_loss=0.0020422181843654796, metrics={'train_runtime': 272.5427, 'train_samples_per_second': 317.994, 'train_steps_per_second': 19.876, 'total_flos': 1.4501826535043072e+18, 'train_loss': 0.0020422181843654796})

In [None]:
# Save the LoRA adapter and tokenizer locally, then upload both to the Hub
model.save_pretrained(model_name)
tokenizer.save_pretrained(model_name)

model.push_to_hub(destination_model_name, token = api_key)
tokenizer.push_to_hub(destination_model_name, token = api_key)

Saved model to https://huggingface.co/ericrisco/medical-gemma-3n-lora



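## GGUF / llama.cpp Conversion
Export the fine-tuned model to GGUF with `Q8_0` quantization for llama.cpp-compatible runtimes, then upload it to a separate `-gguf` repository.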

In [None]:
model.save_pretrained_gguf(
  f"{model_name}",
  quantization_type = "Q8_0"
)

In [None]:
model.push_to_hub_gguf(
    f"{model_name}",
    quantization_type = "Q8_0",
    repo_id = f"{destination_model_name}-gguf",
    token = api_key,
)
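
`Q8_0` is llama.cpp's 8-bit block format: weights are split into blocks of 32, and each block stores one fp16 scale `d = max|x| / 127` plus 32 signed 8-bit integers `round(x / d)`. That costs (2 + 32) bytes per 32 weights, or 8.5 bits per weight. A pure-Python sketch of quantizing and dequantizing one block following that description; it ignores the fp16 rounding of the scale and is not llama.cpp's code.

```python
import math
from typing import List, Tuple

QK8_0 = 32  # weights per Q8_0 block


def quantize_q8_0_block(x: List[float]) -> Tuple[float, List[int]]:
    """Return (scale, int8 quants) for one block of QK8_0 floats."""
    if len(x) != QK8_0:
        raise ValueError(f"expected {QK8_0} values, got {len(x)}")
    amax = max(abs(v) for v in x)
    d = amax / 127.0
    inv_d = 1.0 / d if d else 0.0
    # Round half away from zero, like C's roundf
    return d, [int(math.copysign(math.floor(abs(v * inv_d) + 0.5), v)) for v in x]


def dequantize_q8_0_block(d: float, qs: List[int]) -> List[float]:
    """Reconstruct approximate weights from the scale and quants."""
    return [d * q for q in qs]


def q8_0_bits_per_weight() -> float:
    # 2-byte fp16 scale + 32 one-byte quants per block
    return (2 + QK8_0) * 8 / QK8_0


block = [(i - 16) / 10 for i in range(32)]
d, qs = quantize_q8_0_block(block)
print(q8_0_bits_per_weight())  # 8.5
```

At 8.5 bits per weight, a Q8_0 file is a little over half the size of a 16-bit export, with a per-weight rounding error of at most half a block's scale.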