# MedGemma LoRA Fine-Tuning
Fine-tunes `google/medgemma-4b-it` with QLoRA on synthetic spatial pathology data.
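Designed for a single Kaggle T4 GPU (16 GB VRAM); expected runtime is roughly 2–3 hours.
Prerequisites: an `HF_TOKEN` (Kaggle secret or environment variable) with access to `google/medgemma-4b-it`, and an attached dataset containing `train.jsonl` and `eval.jsonl`.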

In [None]:
# Ensure the fine-tuning stack is importable, pip-installing only what is missing.
# `os` is imported here because later cells use it.
import os
import subprocess
import sys

for requirement in ('peft', 'trl', 'datasets', 'bitsandbytes'):
    # pip distribution names may contain '-', Python module names use '_'.
    import_name = requirement.replace('-', '_')
    try:
        __import__(import_name)
    except ImportError:
        subprocess.check_call(
            [sys.executable, '-m', 'pip', 'install', '-q', requirement]
        )
print('Packages ready')

In [None]:
# Resolve the Hugging Face access token used by every download/upload below:
# 1. the Kaggle secret named 'HF_TOKEN' (also exported to os.environ when found),
# 2. otherwise the HF_TOKEN environment variable,
# 3. otherwise fail fast with an actionable RuntimeError.
import os

HF_TOKEN = None
_secret_error = None
try:
    from kaggle_secrets import UserSecretsClient
    HF_TOKEN = UserSecretsClient().get_secret('HF_TOKEN')
except Exception as e:  # ImportError off-Kaggle, or any secret-lookup failure on Kaggle
    _secret_error = str(e)
    print(f'Kaggle secrets error: {_secret_error}')
else:
    if HF_TOKEN:
        os.environ['HF_TOKEN'] = HF_TOKEN
        print('HF_TOKEN loaded')
    else:
        print('WARNING: empty secret')

if not HF_TOKEN:
    HF_TOKEN = os.environ.get('HF_TOKEN')
if not HF_TOKEN:
    # Only mention the secrets error when there was one; "Error: None" was misleading.
    detail = f' Kaggle secrets error: {_secret_error}' if _secret_error else ''
    raise RuntimeError(
        'HF_TOKEN not found. Add it as a Kaggle secret named HF_TOKEN or set the '
        'HF_TOKEN environment variable.' + detail
    )

In [None]:
from datasets import load_dataset
from pathlib import Path

# Candidate Kaggle dataset mount points, searched in order.
data_paths = [
    Path('/kaggle/input/medgemma-spatial-data/'),
    Path('/kaggle/input/medgemma-lora-data/'),
]
# Both splits are loaded below, so only accept a directory that has both files;
# otherwise a missing eval.jsonl would surface later as an opaque load_dataset error.
_required_files = ('train.jsonl', 'eval.jsonl')
data_dir = next(
    (p for p in data_paths if all((p / name).exists() for name in _required_files)),
    None,
)
if data_dir is None:
    searched = ', '.join(str(p) for p in data_paths)
    raise FileNotFoundError(
        'Training data not found. Add the medgemma-spatial-data dataset with '
        f'train.jsonl and eval.jsonl (searched: {searched}).'
    )

dataset = load_dataset('json', data_files={
    'train': str(data_dir / 'train.jsonl'),
    'validation': str(data_dir / 'eval.jsonl'),
})
print(f"Train: {len(dataset['train'])} | Eval: {len(dataset['validation'])}")

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = 'google/medgemma-4b-it'

# Tokenizer: reuse EOS as the pad token when the checkpoint defines none,
# since batched training needs padding.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# QLoRA base weights: 4-bit NF4 with double quantization, fp16 compute
# (presumably fp16 rather than bf16 because the target T4 lacks native bf16).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
load_kwargs = dict(quantization_config=bnb_config, device_map='auto', token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

# The KV cache is only useful for generation; disable it for training.
model.config.use_cache = False
first_param_device = next(model.parameters()).device
print(f'Model loaded on {first_param_device}')

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# LoRA adapters on the attention projections and the MLP projections.
# NOTE(review): in the multimodal checkpoint these suffixes may also match layers
# of the vision tower — confirm against model.named_modules() if only the
# language model should be adapted.
_attention_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj']
_mlp_modules = ['gate_proj', 'up_proj', 'down_proj']

lora_config = LoraConfig(
    task_type='CAUSAL_LM',
    r=16,
    lora_alpha=32,  # PEFT scales the update by lora_alpha / r = 2
    lora_dropout=0.05,
    bias='none',
    target_modules=_attention_modules + _mlp_modules,
)
# Prepare the 4-bit model for training, then wrap it with the adapters.
model = get_peft_model(prepare_model_for_kbit_training(model), lora_config)
model.print_trainable_parameters()

In [None]:
import dataclasses
import inspect

from trl import SFTTrainer, SFTConfig

# TRL renamed two arguments used here: SFTConfig `max_seq_length` -> `max_length`
# and SFTTrainer `tokenizer` -> `processing_class`. The install cell pulls whatever
# TRL is current when trl is missing, so use the spelling the installed version
# accepts instead of hard-coding one that raises TypeError across the rename.
_sft_config_fields = {f.name for f in dataclasses.fields(SFTConfig)}
_seq_len_kwarg = 'max_length' if 'max_length' in _sft_config_fields else 'max_seq_length'
_trainer_params = inspect.signature(SFTTrainer.__init__).parameters
_tokenizer_kwarg = 'processing_class' if 'processing_class' in _trainer_params else 'tokenizer'

training_args = SFTConfig(
    output_dir='/kaggle/working/medgemma-adapter',
    num_train_epochs=3,
    per_device_train_batch_size=2,
    # TrainingArguments defaults this to 8; evaluation materializes full-vocabulary
    # logits for every token, so keep it equal to the train batch size to fit a 16 GB T4.
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,  # effective batch size per device: 2 * 8 = 16
    learning_rate=2e-4,
    lr_scheduler_type='cosine',
    warmup_ratio=0.05,
    logging_steps=10,
    save_strategy='epoch',
    eval_strategy='epoch',
    fp16=True,
    optim='paged_adamw_8bit',
    report_to='none',
    dataset_text_field='text',
    **{_seq_len_kwarg: 768},
)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    args=training_args,
    **{_tokenizer_kwarg: tokenizer},
)
trainer.train()
print('Training complete')

In [None]:
import os

from huggingface_hub import HfApi

# Save the LoRA adapter and tokenizer locally first; this local copy is what the
# inference check reloads, and it survives even if the Hub push below fails.
adapter_local = '/kaggle/working/medgemma-spatial-pathology-adapter'
trainer.model.save_pretrained(adapter_local)
tokenizer.save_pretrained(adapter_local)
print(f'Adapter saved to {adapter_local}')

# Target Hub repo. Override via the HF_ADAPTER_REPO_ID environment variable to push
# under a namespace the current HF_TOKEN can write to; the default is unchanged.
repo_id = (
    os.environ.get('HF_ADAPTER_REPO_ID')
    or 'harshameghadri/medgemma-spatial-pathology-adapter'
)
api = HfApi(token=HF_TOKEN)
# NOTE(review): private=False publishes the adapter publicly — confirm this is intended.
api.create_repo(repo_id=repo_id, repo_type='model', exist_ok=True, private=False)
trainer.model.push_to_hub(repo_id, token=HF_TOKEN)
tokenizer.push_to_hub(repo_id, token=HF_TOKEN)
print(f'Adapter pushed to https://huggingface.co/{repo_id}')

In [None]:
import gc

from peft import PeftModel

# Smoke test: reload the saved adapter onto a fresh fp16 (unquantized) base model
# and generate one answer.
# Best-effort release of the training model, optimizer state and trainer first:
# otherwise they stay resident alongside the second model copy and can exhaust
# the GPU's memory.
for _training_name in ('trainer', 'model'):
    globals().pop(_training_name, None)
gc.collect()
torch.cuda.empty_cache()

test_model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map='auto', token=HF_TOKEN
)
test_model = PeftModel.from_pretrained(test_model, adapter_local)
test_model.eval()

# Prompt written directly in Gemma's turn-based chat format.
test_prompt = (
    "<start_of_turn>user\n"
    "Analyze: hot immune tumor, T cells 35%, Moran's I 0.62, entropy 0.85"
    "<end_of_turn>\n<start_of_turn>model\n"
)
# Place inputs wherever device_map put the model instead of assuming 'cuda'.
inputs = tokenizer(test_prompt, return_tensors='pt').to(test_model.device)
prompt_len = inputs['input_ids'].shape[1]
with torch.no_grad():
    out = test_model.generate(**inputs, max_new_tokens=200, temperature=0.7, do_sample=True)
# Decode only the newly generated tokens, not the echoed prompt.
print(tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True))