# 🔧 MedGemma Fine-Tuning with QLoRA

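This notebook demonstrates how to fine-tune MedGemma (`google/medgemma-4b-it`) for tuberculosis (TB) detection on chest X-rays using QLoRA, which trains LoRA adapters on top of a 4-bit quantized base model.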

## Targets
- **Novel Task Prize ($10,000)**: Fine-tuned TB detection model
- **Edge AI Prize ($5,000)**: Quantized model for deployment

## 1. Environment Setup

In [None]:
# Check GPU availability
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Install dependencies if needed
# !pip install transformers peft trl bitsandbytes accelerate

In [None]:
import os
import sys
sys.path.insert(0, '..')

from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer

print("✅ Libraries loaded")

## 2. Model Configuration

In [None]:
# Model settings
MODEL_ID = "google/medgemma-4b-it"  # Multimodal model for X-ray analysis
OUTPUT_DIR = "../checkpoints/tb_finetuned"

# QLoRA settings (4-bit quantization)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# LoRA settings
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

print("📋 Configuration:")
print(f"   Model: {MODEL_ID}")
print(f"   Quantization: 4-bit NF4")
print(f"   LoRA rank: {lora_config.r}")
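To give a feel for what `r=16` and `lora_alpha=16` imply, the sketch below counts the extra parameters LoRA adds to a single linear layer and computes the update scaling `lora_alpha / r`. The layer dimensions are hypothetical placeholders chosen for illustration, not MedGemma's actual projection shapes, and the helper names are not part of PEFT.

```python
def lora_extra_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters in the two low-rank factors: A (r x d_in) and B (d_out x r)."""
    return r * (d_in + d_out)


def lora_scaling(lora_alpha: int, r: int) -> float:
    """Standard LoRA scaling applied to the low-rank update B @ A."""
    return lora_alpha / r


# Hypothetical square projection layer (assumption, not the real model shape)
d_in = d_out = 2048
r, lora_alpha = 16, 16

full = d_in * d_out
extra = lora_extra_params(d_in, d_out, r)

print(f"Frozen weights in layer: {full:,}")
print(f"Trainable LoRA weights:  {extra:,} ({100 * extra / full:.2f}% of the layer)")
print(f"Update scaling:          {lora_scaling(lora_alpha, r)}")
```

With `lora_alpha` equal to `r`, the scaling is 1.0, so the low-rank update is added to the frozen weights without extra amplification.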

## 3. Load Model with Quantization

In [None]:
# Load processor
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)
print("✅ Processor loaded")

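# Load model with quantization
# Note: the MedGemma model card loads this checkpoint with AutoModelForImageTextToText;
# switch to that class if AutoModelForVision2Seq is unavailable in your transformers version.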
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
print("✅ Model loaded with 4-bit quantization")

# Prepare for k-bit training
model = prepare_model_for_kbit_training(model)
print("✅ Model prepared for training")
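As a rough sanity check on why 4-bit loading matters, the arithmetic below estimates weight storage at different precisions. It assumes about 4 billion parameters (inferred from the "4b" in the model ID, not an exact count) and uses the quantization-constant overhead described in the QLoRA paper: one 32-bit constant per block of 64 weights, or with double quantization an 8-bit constant per block plus a 32-bit constant per 256 blocks. Actual GPU usage will be higher because of activations, LoRA weights, optimizer state, and any layers left unquantized.

```python
def weight_gb(n_params: float, bits_per_param: float) -> float:
    """Storage for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


N_PARAMS = 4e9  # assumption: approximate size suggested by "4b" in the model ID

# Per-parameter overhead of the quantization constants (block size 64)
overhead_plain = 32 / 64                    # 0.5 bits per parameter
overhead_double = 8 / 64 + 32 / (64 * 256)  # about 0.127 bits per parameter

estimates = {
    "bf16": weight_gb(N_PARAMS, 16),
    "nf4": weight_gb(N_PARAMS, 4 + overhead_plain),
    "nf4 + double quant": weight_gb(N_PARAMS, 4 + overhead_double),
}

for name, gb in estimates.items():
    print(f"{name:>20}: {gb:.2f} GB")
```

Under these assumptions, `bnb_4bit_use_double_quant=True` trims a further fraction of a bit per parameter on top of the saving from 4-bit storage itself.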

In [None]:
# Apply LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

## 4. Prepare Training Data

In [None]:
# Import dataset loader
from src.data import TBDatasetLoader

# Initialize loader
loader = TBDatasetLoader("../data")

print("Available datasets:")
for name, info in loader.DATASETS.items():
    print(f"  - {name}: {info['description']}")

In [None]:
# Create training data
# Note: This requires downloading Kaggle datasets first
# Run: kaggle datasets download kmader/pulmonary-chest-xray-abnormalities

try:
    train_data, val_data = loader.create_training_data(
        datasets=["shenzhen", "montgomery"],
        train_ratio=0.8,
    )
    print(f"✅ Loaded {len(train_data)} training samples")
    print(f"✅ Loaded {len(val_data)} validation samples")
except Exception as e:
    print(f"⚠️ Could not load datasets: {e}")
    print("\nTo download datasets, run:")
    print("  kaggle datasets download kmader/pulmonary-chest-xray-abnormalities")
    
    # Create mock data for demonstration
    train_data = [
        {
            "text": "<start_of_turn>user\nAnalyze this chest X-ray.\n<end_of_turn>\n<start_of_turn>model\nNormal chest X-ray.<end_of_turn>",
            "label": 0,
        }
    ] * 10
    val_data = train_data[:2]
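The mock record above hard-codes Gemma's turn markers. A small helper can build the same `text` field from a prompt and a response. `format_example` is a hypothetical name used only for illustration, and the real `TBDatasetLoader` may format its records differently.

```python
def format_example(prompt: str, response: str, label: int) -> dict:
    """Build one training record using the turn markers from the mock data."""
    text = (
        "<start_of_turn>user\n"
        f"{prompt}\n"
        "<end_of_turn>\n"
        "<start_of_turn>model\n"
        f"{response}<end_of_turn>"
    )
    return {"text": text, "label": label}


# Reproduces the mock record used in the fallback branch
example = format_example(
    prompt="Analyze this chest X-ray.",
    response="Normal chest X-ray.",
    label=0,
)
print(example["text"])
```

Keeping the template in one function makes it harder for training and inference prompts to drift apart.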

## 5. Training Configuration

In [None]:
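# Training arguments
# SFTConfig extends TrainingArguments with SFT-specific options (assumes a TRL
# release that provides SFTConfig). The last two settings keep the raw "text"
# column available to the custom collate_fn instead of letting the trainer
# pre-tokenize or drop it.
from trl import SFTConfig

training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    warmup_ratio=0.1,
    logging_steps=10,
    save_steps=100,
    eval_strategy="steps",
    eval_steps=50,
    bf16=True,
    optim="paged_adamw_8bit",
    gradient_checkpointing=True,
    report_to="none",  # Set to "wandb" for experiment tracking
    remove_unused_columns=False,  # keep "text" for collate_fn
    dataset_kwargs={"skip_prepare_dataset": True},  # collate_fn tokenizes each batch
)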

print("📋 Training configuration:")
print(f"   Epochs: {training_args.num_train_epochs}")
print(f"   Batch size: {training_args.per_device_train_batch_size}")
print(f"   Learning rate: {training_args.learning_rate}")
print(f"   Gradient accumulation: {training_args.gradient_accumulation_steps}")
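The batch size and gradient accumulation settings combine into an effective batch size, which in turn determines how many optimizer steps a run takes. The sketch below works through that arithmetic for a single GPU. The sample count of 640 is a hypothetical figure for illustration, and the step count is an approximation that may differ slightly from what the trainer reports.

```python
import math


def schedule_summary(n_samples, per_device_batch, grad_accum, epochs,
                     warmup_ratio, n_devices=1):
    """Approximate optimizer-step counts implied by the training arguments."""
    effective_batch = per_device_batch * grad_accum * n_devices
    steps_per_epoch = math.ceil(n_samples / effective_batch)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
    }


# 640 training samples is an assumption, not a measured dataset size
summary = schedule_summary(
    n_samples=640, per_device_batch=4, grad_accum=4, epochs=3, warmup_ratio=0.1
)
for key, value in summary.items():
    print(f"{key}: {value}")
```

For a run of roughly this size, `eval_steps=50` and `save_steps=100` would trigger only a couple of evaluations and a single intermediate checkpoint, so smaller values may be worth considering on small datasets.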

## 6. Create Trainer

In [None]:
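# Create collate function
# Note: only the "text" field is tokenized here; no image tensors are passed
# to the processor, so this collator trains on text alone.
def collate_fn(examples):
    texts = [ex["text"] for ex in examples]

    # Tokenize and pad the batch
    batch = processor(
        text=texts,
        padding=True,
        truncation=True,
        max_length=2048,
        return_tensors="pt",
    )

    # Labels for causal LM: copy the inputs, then ignore padding in the loss
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels

    return batch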
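The collate function above derives `labels` from `input_ids`. Positions that should not contribute to the loss are conventionally set to -100, the default `ignore_index` of PyTorch's cross-entropy loss. The pure-Python sketch below shows the idea on a toy sequence, masking both padding and, optionally, the prompt tokens. The token IDs, the pad ID of 0, and the `mask_labels` helper are all hypothetical.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_labels(input_ids, pad_token_id, prompt_len=0):
    """Copy input_ids to labels, ignoring the first prompt_len tokens and padding."""
    labels = []
    for pos, token in enumerate(input_ids):
        if pos < prompt_len or token == pad_token_id:
            labels.append(IGNORE_INDEX)
        else:
            labels.append(token)
    return labels


# Toy right-padded sequence: 3 prompt tokens, 2 response tokens, 2 pad tokens
toy_ids = [5, 6, 7, 8, 9, 0, 0]
toy_labels = mask_labels(toy_ids, pad_token_id=0, prompt_len=3)
print(toy_labels)
```

Masking the prompt is optional. It focuses the loss on the model's answer, which can matter when prompts are long and nearly identical across examples.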

In [None]:
# Create trainer
from datasets import Dataset

train_dataset = Dataset.from_list(train_data)
val_dataset = Dataset.from_list(val_data)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=processor,
    data_collator=collate_fn,
)

print("✅ Trainer created")

## 7. Train Model

In [None]:
# Start training
print("🚀 Starting training...")
print("=" * 50)

# Uncomment to actually train
# trainer.train()

print("\n⚠️ Training disabled for demo. Uncomment trainer.train() to run.")

## 8. Save Model

In [None]:
# Save the fine-tuned model
# model.save_pretrained(OUTPUT_DIR)
# processor.save_pretrained(OUTPUT_DIR)

print(f"Model would be saved to: {OUTPUT_DIR}")
print("\nTo save, uncomment the save commands above.")

## 9. Quantization for Edge Deployment

In [None]:
print("🔧 Edge AI Quantization Options:")
print("=" * 50)
print()
print("1. GPTQ (GPU inference):")
print("   python scripts/quantize.py --model-path ./checkpoints/tb_finetuned --method gptq")
print()
print("2. GGUF (CPU/Mobile):")
print("   python scripts/quantize.py --model-path ./checkpoints/tb_finetuned --method gguf")
print()
print("3. INT8 (Balanced):")
print("   python scripts/quantize.py --model-path ./checkpoints/tb_finetuned --method int8")

## Summary

This notebook demonstrated:
1. ✅ Loading MedGemma with 4-bit quantization
2. ✅ Applying QLoRA for efficient fine-tuning
3. ✅ Preparing TB X-ray training data
4. ✅ Training configuration
5. ✅ Edge deployment options

### Prize Targets
- **Novel Task Prize**: Fine-tuned TB detection model
- **Edge AI Prize**: Quantized models for mobile deployment