# Gemma 3 QLoRA Fine-Tuning — Safety Analyst Model
**On-Device Safety Copilot — Step 2: Safety Brief Generator**

Fine-tunes Gemma 3 1B (or 4B) with QLoRA on synthetic safety instruction data.
Converts detected GHS symbols → structured safety briefs + voice scripts.

**Setup:**
1. Upload `gemma_safety_train.jsonl` and `gemma_safety_valid.jsonl`
2. Accept Gemma license on HuggingFace
3. Set HF token in Colab secrets as `HF_TOKEN`

In [1]:
# ── 1. Install dependencies ──
!pip install -q "huggingface-hub>=0.26.0,<1.0" "transformers>=4.50.0" accelerate bitsandbytes peft trl datasets

In [2]:
# ── 2. Authenticate & imports ──
import os, json, random
import torch
from datasets import Dataset

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig

from huggingface_hub import login

try:
    from google.colab import userdata
    HF_TOKEN = userdata.get('HF_TOKEN')
except:
    HF_TOKEN = os.environ.get('HF_TOKEN', '')

login(token=HF_TOKEN)
print('Authenticated with HuggingFace')

Authenticated with HuggingFace


In [3]:
# ── 3. Configuration ──
# Choose model size based on your GPU:
#   T4 (16GB)  → google/gemma-3-1b-it   (safe choice)
#   A100/L4    → google/gemma-3-4b-it    (better quality)
MODEL_ID = "google/gemma-3-4b-it"

TRAIN_FILE = "gemma_safety_train.jsonl"
VALID_FILE = "gemma_safety_valid.jsonl"
OUTPUT_DIR = "gemma_safety_finetuned"

# Hyperparams
NUM_EPOCHS = 3
BATCH_SIZE = 2
GRAD_ACCUM = 8            # effective batch = 16
LEARNING_RATE = 2e-4
MAX_SEQ_LENGTH = 1024

# LoRA
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {DEVICE}")

Device: cuda


In [4]:
# ── 4. Load model with 4-bit quantization ──
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",  # sdpa can cause issues with QLoRA
)

model = prepare_model_for_kbit_training(model)
print(f"Model loaded: {MODEL_ID}")

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model loaded: google/gemma-3-4b-it


In [5]:
# ── 5. Apply LoRA ──
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 32,788,480 || all params: 4,332,867,952 || trainable%: 0.7567


In [6]:
# ── 6. Load & format dataset ──
def load_jsonl(path):
    records = []
    with open(path) as f:
        for line in f:
            records.append(json.loads(line.strip()))
    return records

def format_chat(record):
    """Convert messages list to Gemma chat template format."""
    messages = record["messages"]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": text}

# Load
train_raw = load_jsonl(TRAIN_FILE)
valid_raw = load_jsonl(VALID_FILE)

train_ds = Dataset.from_list(train_raw).map(format_chat)
valid_ds = Dataset.from_list(valid_raw).map(format_chat)

print(f"Train: {len(train_ds)}, Valid: {len(valid_ds)}")
print(f"\nSample formatted text (first 500 chars):\n{train_ds[0]['text'][:500]}")

Map:   0%|          | 0/1215 [00:00<?, ? examples/s]

Map:   0%|          | 0/136 [00:00<?, ? examples/s]

Train: 1215, Valid: 136

Sample formatted text (first 500 chars):
<bos><start_of_turn>user
You are an on-device safety assistant for industrial workers. Given detected GHS hazard symbols (and optionally OCR text from a product label), provide a structured safety brief with: severity level, hazard summary, required PPE, step-by-step handling SOP, storage requirements, emergency/first-aid procedures, autonomous safety actions, and a short spoken voice script (under 40 words). Be concise, factual, and prioritize worker safety.

Symbols: GHS_Symbol_FLAME (detected


In [8]:
# ── 7. Training with SFTTrainer ──
# Check trl version to handle API differences
import trl
print(f"trl version: {trl.__version__}")

training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    weight_decay=0.01,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    bf16=True,
    optim="paged_adamw_8bit",
    report_to="none",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

# Tokenize dataset manually to handle max_seq_length
def tokenize(example):
    return tokenizer(
        example["text"],
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        padding="max_length",
    )

train_ds_tok = train_ds.map(tokenize, batched=True, remove_columns=["text"])
valid_ds_tok = valid_ds.map(tokenize, batched=True, remove_columns=["text"])
train_ds_tok = train_ds_tok.map(lambda x: {"labels": x["input_ids"]})
valid_ds_tok = valid_ds_tok.map(lambda x: {"labels": x["input_ids"]})

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds_tok,
    eval_dataset=valid_ds_tok,
    processing_class=tokenizer,
)

print("Starting Gemma fine-tuning...")
trainer.train()
print("Training complete!")

trl version: 0.29.0


Map:   0%|          | 0/1215 [00:00<?, ? examples/s]

Map:   0%|          | 0/136 [00:00<?, ? examples/s]

Map:   0%|          | 0/1215 [00:00<?, ? examples/s]

Map:   0%|          | 0/136 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/1215 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/136 [00:00<?, ? examples/s]

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.


Starting Gemma fine-tuning...


Step,Training Loss,Validation Loss
50,0.0262,0.018349
100,0.0139,0.013473
150,0.0119,0.012106
200,0.0113,0.011991


Training complete!


In [9]:
# ── 8. Save adapter ──
ADAPTER_DIR = os.path.join(OUTPUT_DIR, "lora_adapter")
model.save_pretrained(ADAPTER_DIR)
tokenizer.save_pretrained(ADAPTER_DIR)
print(f"Gemma LoRA adapter saved to {ADAPTER_DIR}")

Gemma LoRA adapter saved to gemma_safety_finetuned/lora_adapter


In [12]:
# ── 9. Inference test ──
SYSTEM_PROMPT = (
    "You are an on-device safety assistant for industrial workers. "
    "Given detected GHS hazard symbols (and optionally OCR text from a product label), "
    "provide a structured safety brief with: severity level, hazard summary, required PPE, "
    "step-by-step handling SOP, storage requirements, emergency/first-aid procedures, "
    "autonomous safety actions, and a short spoken voice script (under 40 words). "
    "Be concise, factual, and prioritize worker safety."
)

test_input = (
    'Symbols: GHS_Symbol_FLAME (detected), GHS_Symbol_CORROSION (detected)\n'
    'OCR: "Hydrochloric Acid UN1789 CAS 7647-01-0"\n'
    'Environment: chemical storage room\n'
    'User role: warehouse worker\n'
    'Provide: hazards, PPE, handling SOP, storage, emergency/first-aid, '
    'autonomous safety actions, and a short voice script.'
)

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": test_input},
]

prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# ── Critical: switch to eval mode before inference ──
model.eval()
model.config.use_cache = True

inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=False,
        repetition_penalty=1.1,
    )

response = tokenizer.decode(
    output_ids[0][inputs.input_ids.shape[1]:],
    skip_special_tokens=True
)

print("── Test Inference ──")
print(f"Input: {test_input[:120]}...")
print(f"\nResponse:\n{response}")

The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


── Test Inference ──
Input: Symbols: GHS_Symbol_FLAME (detected), GHS_Symbol_CORROSION (detected)
OCR: "Hydrochloric Acid UN1789 CAS 7647-01-0"
Envi...

Response:
SEVERITY: MEDIUM

1) SUMMARY
- GHS_Symbol_FLAME: Flammable material — catches fire easily under normal conditions
- GHS_Symbol_CORROSION: Corrosive — causes severe skin burns, eye damage, or corrodes metals

2) HAZARDS
  - Highly flammable liquid and vapor
  - May form explosive vapor-air mixtures
  - Risk of flash fire or deflagration
  - Causes severe skin burns and eye damage
  - May be corrosive to metals
  - Irreversible tissue destruction on contact

3) PPE REQUIRED
  - Wear flame-resistant clothing
  - Use chemical splash goggles
  - Use nitrile or neoprene gloves
  - Ensure adequate ventilation or use respiratory protection
  - Wear chemical-resistant gloves (butyl rubber or neoprene)
  - Use full-face shield or chemical splash goggles

4) HANDLING SOP
  1. Keep away from heat, sparks, open flames, and hot surfaces — 

In [None]:
# ── 10. (Optional) Merge adapter for deployment ──
# Uncomment to merge LoRA weights into base model for on-device use

# from peft import AutoPeftModelForCausalLM
#
# merged_model = model.merge_and_unload()
# MERGED_DIR = os.path.join(OUTPUT_DIR, "merged")
# merged_model.save_pretrained(MERGED_DIR)
# tokenizer.save_pretrained(MERGED_DIR)
# print(f"Merged model saved to {MERGED_DIR}")

## Next Steps
1. Download `lora_adapter/` (or `merged/`) for on-device deployment
2. Build the inference pipeline connecting PaliGemma → Gemma → TTS
3. Add the agentic action layer (rule-based + model hybrid)
4. Build the demo UI

In [13]:
import zipfile
import os
from google.colab import files

# ADAPTER_DIR is already defined in the kernel state as 'gemma_safety_finetuned/lora_adapter'

# Define the name for the output zip file
zip_filename = "lora_adapter.zip"

# Create a ZipFile object in write mode
with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
    for root, _, files_in_dir in os.walk(ADAPTER_DIR):
        for file in files_in_dir:
            file_path = os.path.join(root, file)
            # Add file to zip, preserving directory structure relative to ADAPTER_DIR
            zipf.write(file_path, os.path.relpath(file_path, ADAPTER_DIR))

print(f"Successfully created {zip_filename} containing contents of {ADAPTER_DIR}")

# Offer the file for download
files.download(zip_filename)

Successfully created lora_adapter.zip containing contents of gemma_safety_finetuned/lora_adapter


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>