# MedGemma Omni-XRay Fine-Tuning Pipeline
This notebook is designed to run on Kaggle (P100 or 2×T4 GPUs). Note that neither GPU has native bfloat16 support (compute capability < 8.0).
It fine-tunes `google/medgemma-1.5-4b-it` with LoRA (4-bit QLoRA) on the HF dataset created earlier.

- **Input**: HF dataset `hssling/Chest-XRay-10k-Control`
- **Output**: LoRA adapter weights pushed to the HF Hub as `hssling/MedGemma-XRay-Agent`.

In [None]:
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git 
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q datasets scipy huggingface_hub
from huggingface_hub import login
login(token="your_hf_token_here") # Replace or use Kaggle Secrets

In [None]:
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# 1. Load dataset: the "train" split of the chest X-ray corpus on the HF Hub.
DATASET_REPO = "hssling/Chest-XRay-10k-Control"
print("Loading dataset...")
dataset = load_dataset(path=DATASET_REPO, split="train")

# 2. QLoRA config
# NF4 4-bit weights with double quantization. The compute dtype must be one the GPU
# can actually run: bfloat16 needs compute capability >= 8.0 (Ampere or newer), which
# Kaggle's P100 (6.0) and T4 (7.5) lack, so fall back to float16 on those cards.
if torch.cuda.get_device_capability(0)[0] >= 8:
    compute_dtype = torch.bfloat16
else:
    compute_dtype = torch.float16
    # NOTE(review): Gemma-family models have been reported to overflow in float16;
    # watch the training loss for NaN/inf on these GPUs.
    print("GPU lacks native bfloat16 support; using float16 compute.")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

# 3. Load MedGemma base (requires accepting the google/medgemma terms on HF).
# MedGemma 4B-it is an image-text model. Load it with the image-text-to-text auto
# class and its processor, so X-ray images can reach the vision encoder.
model_id = "google/medgemma-1.5-4b-it"
print(f"Loading {model_id}...")
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},  # place the whole model on GPU 0
)
# Prepares the quantized model for training (PEFT helper) and enables
# gradient checkpointing to trade compute for activation memory.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# The processor bundles the image preprocessor and the tokenizer. `tokenizer` is
# kept as a name for the text-only code paths later in the notebook.
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = processor.tokenizer

# 4. Attach LoRA adapters
# Rank-16 adapters with lora_alpha=32 (scaling alpha/r = 2) and 5% dropout; bias
# terms stay frozen. Modules are matched by name.
# NOTE(review): on a multimodal checkpoint these names may also match the vision
# tower's q/k/v projections; confirm that is intended.
lora_target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]
config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=lora_target_modules,
    bias="none",
)
model = get_peft_model(model, config)
# Report how many parameters are trainable vs. frozen.
model.print_trainable_parameters()

print("Model Setup Complete! Ready for SFTTrainer.")

# 5. Training loop (omitted for space)
# Build an SFTTrainer (or transformers.Trainer) over `dataset`, then run it:
# trainer.train()

# 6. Push the LoRA adapter to the Hugging Face Hub.
# Disabled by default so an untrained adapter is never published and the log never
# claims an upload that did not happen. Set to True once training has actually run.
PUSH_TO_HUB = False
HUB_REPO_ID = "hssling/MedGemma-XRay-Agent"

if PUSH_TO_HUB:
    model.push_to_hub(HUB_REPO_ID, safe_serialization=True)
    tokenizer.push_to_hub(HUB_REPO_ID)
    print("Weights pushed to HF!")
else:
    print("Skipping push to HF (PUSH_TO_HUB is False); no weights were uploaded.")