# Fine-tune MedGemma for Nail Disease Classification
## Advanced Kaggle Notebook with Overfitting Detection

Based on: https://github.com/google-health/medgemma
Model: google/medgemma-7b-orcamath-it (Instruction-tuned, 7B params)
License: Apache 2.0

Features: Loss graphs | Overfitting detection | Comprehensive metrics

## Step 1: Setup Environment

In [None]:
import os
import torch
import json
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Detect the runtime (Kaggle vs. anything else) and choose the torch device
# that later cells use, then print a short environment report.
IS_KAGGLE = os.path.exists('/kaggle')
_cuda_ready = torch.cuda.is_available()
DEVICE = torch.device('cuda' if _cuda_ready else 'cpu')

_rule = '='*60
print(_rule)
print('ENVIRONMENT SETUP')
print(_rule)
print(f'Environment: {"Kaggle" if IS_KAGGLE else "Local/Colab"}')
print(f'Device: {DEVICE}')
if not _cuda_ready:
    print('GPU: None - CPU mode')
else:
    # Report the first visible GPU and its total memory in decimal gigabytes.
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {_gpu_props.total_memory / 1e9:.1f} GB')
print(f'PyTorch: {torch.__version__}')
print(_rule)

In [None]:
# Install the fine-tuning stack. The leading `!` is an IPython shell escape,
# so this line only runs inside a notebook kernel, not as plain Python.
# NOTE(review): versions are unpinned. Later cells pass `evaluation_strategy`
# to TrainingArguments and `tokenizer=`, `dataset_text_field=` and
# `max_seq_length=` to SFTTrainer - confirm the transformers/trl releases that
# get installed still accept those keyword arguments, or pin versions here.
!pip install -q transformers datasets torch bitsandbytes peft trl scikit-learn matplotlib
print('‚úÖ Packages installed')

## Step 2: Import Libraries

In [None]:
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    BitsAndBytesConfig,
    TrainingArguments,
    set_seed
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
from sklearn.model_selection import train_test_split
import warnings
# Silence every Python warning for the rest of the session (keeps the
# notebook output readable; it also hides deprecation notices).
warnings.filterwarnings(action='ignore')

# Fix the random seed through the transformers helper so runs are repeatable.
set_seed(seed=42)
print('‚úÖ Libraries imported')

## Step 3: Load & Explore Dataset

In [None]:
# Load the CSV dataset into `df`; every later cell reads from this frame.
csv_path = '/kaggle/input/nail-disease-medgemma/nail_diseases.csv'

try:
    df = pd.read_csv(csv_path)
except FileNotFoundError:
    # Best-effort help: suggest other dataset slugs and show what is mounted.
    print(f'‚ùå File not found: {csv_path}')
    print('\nAlternative paths to try:')
    print('  - /kaggle/input/nail-disease-classification/nail_diseases.csv')
    print('  - /kaggle/input/nail-diseases/nail_diseases.csv')
    print('\nAvailable inputs:')
    input_root = '/kaggle/input'
    # Outside Kaggle the input root does not exist; listing it unguarded
    # would raise a second FileNotFoundError from inside this handler.
    if os.path.isdir(input_root):
        for item in sorted(os.listdir(input_root)):
            print(f'  - {item}')
    else:
        print(f'  (none - {input_root} does not exist)')
else:
    # Only reached when the file was read successfully.
    print(f'‚úÖ Loaded {len(df)} samples')
    print(f'\nColumns: {list(df.columns)}')
    print(f'\nFirst row:')
    print(df.iloc[0])

## Step 4: Create Training Prompts

In [None]:
def create_medical_prompt(row):
    """Build one supervised fine-tuning example from a dataset row.

    The text is laid out as: clinical findings -> request -> expected output
    (diagnosis, treatment, prognosis). Every line starts at column 0 so the
    training text has the same shape as the unindented prompt used for
    inference later in this notebook.

    Args:
        row: Mapping-like object with a ``.get`` method (a pandas Series or a
            dict). Reads the keys ``clinical_findings``,
            ``confirmed_diagnosis``, ``treatment_protocol`` and ``prognosis``.

    Returns:
        The prompt as a single string with surrounding whitespace removed.
        Missing, ``None`` and NaN fields are rendered as empty text rather
        than the literal strings ``'None'`` / ``'nan'``.
    """
    def _field(name):
        value = row.get(name, '')
        # NaN is the only float that is not equal to itself; pandas uses it
        # for empty CSV cells.
        if value is None or (isinstance(value, float) and value != value):
            return ''
        return str(value).strip()

    findings = _field('clinical_findings')
    diagnosis = _field('confirmed_diagnosis')
    treatment = _field('treatment_protocol')
    prognosis = _field('prognosis')

    # Assemble line by line instead of using an indented triple-quoted
    # literal, which would bake the source indentation into every prompt.
    prompt_lines = [
        'Clinical Case Analysis:',
        '',
        'Clinical Findings:',
        findings,
        '',
        'Please provide the diagnosis and treatment plan.',
        '',
        'Expected Output:',
        f'Diagnosis: {diagnosis}',
        f'Treatment: {treatment}',
        f'Prognosis: {prognosis}',
    ]
    return '\n'.join(prompt_lines).strip()

# Render one prompt per row and store it in the `text` column that the
# dataset-building and training cells read.
df['text'] = df.apply(create_medical_prompt, axis='columns')

print(f'‚úÖ Created {len(df)} training prompts')
print(f'\nExample prompt (first 300 chars):')
_first_prompt = df['text'].iloc[0]
print(_first_prompt[:300])

## Step 5: Split Dataset

In [None]:
# Two-stage split: hold out 30% of the rows, then cut that hold-out in half,
# giving 70% train / 15% validation / 15% test. Both stages use the same seed.
_split_seed = 42
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=_split_seed)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=_split_seed)

for _label, _part in (('Train:', train_df), ('Val:  ', val_df), ('Test: ', test_df)):
    print(f'{_label} {len(_part)} samples')

## Step 6: Create HuggingFace Datasets

In [None]:
# Wrap each split's `text` column in a HuggingFace Dataset.
# The split frames keep their shuffled pandas index; preserve_index=False
# stops that index from being stored as an extra `__index_level_0__` column,
# so each dataset contains exactly one column: `text`.
train_dataset = Dataset.from_pandas(train_df[['text']], preserve_index=False)
val_dataset = Dataset.from_pandas(val_df[['text']], preserve_index=False)
test_dataset = Dataset.from_pandas(test_df[['text']], preserve_index=False)

print(f'‚úÖ HuggingFace datasets created')
print(f'  Train: {len(train_dataset)} samples')
print(f'  Val:   {len(val_dataset)} samples')
print(f'  Test:  {len(test_dataset)} samples')

## Step 7: Setup Model & Tokenizer

In [None]:
# Model configuration
# NOTE(review): verify this repository id on the Hugging Face Hub before
# running - I could not confirm a checkpoint with this exact name; confirm
# against the MedGemma model cards.
MODEL_ID = 'google/medgemma-7b-orcamath-it'

# 4-bit NF4 quantization with double quantization keeps the weights small
# enough for a single GPU; matmuls are computed in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

print(f'Loading model: {MODEL_ID}')
print('This may take 2-3 minutes...')

try:
    # NOTE(review): trust_remote_code=True lets the model repository run its
    # own Python code at load time; keep it only if the checkpoint needs it.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map='auto',
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
except Exception as e:
    print(f'‚ùå Error loading model: {str(e)[:100]}')
    print('Make sure you have HuggingFace token set up for gated models')
    # Every later cell needs `model` and `tokenizer`; re-raise so the notebook
    # stops here with the full traceback instead of failing later with a
    # NameError that hides the real cause.
    raise

# Fall back to EOS for padding only when the tokenizer has no pad token.
# Overwriting an existing pad token with EOS makes padding and end-of-sequence
# share one id, so label masking of padding would also mask the EOS targets.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f'‚úÖ Model loaded successfully')
print(f'   Model size: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B parameters')

## Step 8: Configure LoRA

In [None]:
# Make the quantized base model trainable, then attach LoRA adapters.
model = prepare_model_for_kbit_training(model)

# Adapter hyper-parameters: rank 16 with alpha 32, applied to the attention
# query and value projections only, 5% adapter dropout, no bias training.
lora_config = LoraConfig(
    task_type='CAUSAL_LM',
    target_modules=['q_proj', 'v_proj'],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias='none',
)

model = get_peft_model(model, lora_config)

# Tally total and trainable parameter counts in a single pass.
trainable_params = 0
total_params = 0
for _param in model.parameters():
    _count = _param.numel()
    total_params += _count
    if _param.requires_grad:
        trainable_params += _count

print(f'‚úÖ LoRA configured')
print(f'  Total params: {total_params / 1e9:.2f}B')
print(f'  Trainable: {trainable_params / 1e6:.2f}M ({100*trainable_params/total_params:.3f}%)')

## Step 9: Setup Training Configuration

In [None]:
# Training arguments.
# `max_seq_length` is intentionally NOT set here: it is not a
# TrainingArguments parameter (passing it raises TypeError). The 512-token
# limit is given to SFTTrainer when the trainer is constructed.
# NOTE(review): `evaluation_strategy` is the older spelling of this option;
# confirm the installed transformers release still accepts it.
training_args = TrainingArguments(
    output_dir='./medgemma_nail_disease_finetuned',
    num_train_epochs=3,  # overridden by max_steps below when max_steps > 0
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,  # effective batch size 4 * 2 = 8 per device
    learning_rate=2e-4,
    lr_scheduler_type='cosine',
    warmup_steps=100,
    weight_decay=0.01,
    max_steps=500,  # Limit steps for faster training
    logging_steps=10,
    eval_steps=50,
    save_steps=50,  # same cadence as eval_steps so the best checkpoint can be reloaded
    evaluation_strategy='steps',
    save_strategy='steps',
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    greater_is_better=False,  # lower eval_loss is better
    logging_dir='./logs',
    optim='paged_adamw_8bit',  # Memory-efficient optimizer
    seed=42,
)

print('‚úÖ Training configuration ready')
print(f'  Output: ./medgemma_nail_disease_finetuned')
print(f'  Epochs: {training_args.num_train_epochs}')
print(f'  Batch size: {training_args.per_device_train_batch_size}')
print(f'  Learning rate: {training_args.learning_rate}')
print(f'  Max steps: {training_args.max_steps}')

## Step 10: Initialize SFT Trainer

In [None]:
from trl import SFTTrainer

# Collect the trainer settings in one mapping: the LoRA-wrapped model, its
# tokenizer, the training arguments, the train/validation datasets, the name
# of the text column, and the 512-token sequence limit.
_sft_settings = {
    'model': model,
    'tokenizer': tokenizer,
    'args': training_args,
    'train_dataset': train_dataset,
    'eval_dataset': val_dataset,
    'dataset_text_field': 'text',
    'max_seq_length': 512,
}
trainer = SFTTrainer(**_sft_settings)

print('‚úÖ Trainer initialized')

## Step 11: üöÄ START TRAINING (30-60 minutes)

In [None]:
# Run the fine-tuning loop and report the training loss returned by the trainer.
_rule = '='*60
print('\n' + _rule)
print('üöÄ STARTING TRAINING')
print(_rule)

train_result = trainer.train()

print('\n' + _rule)
print('‚úÖ TRAINING COMPLETE')
print(f'Final Training Loss: {train_result.training_loss:.4f}')
print(_rule)

## Step 12: Evaluate & Save

In [None]:
# Evaluate on the held-out test set.
# `test_dataset` only holds raw text. NOTE(review): this assumes the trainer
# tokenizes just the datasets it was constructed with, so a dataset passed
# straight to evaluate() must already contain token ids - confirm for the
# installed trl version. Tokenizing here is safe either way.
def _tokenize_for_eval(batch):
    """Tokenize a batch of prompts with the same 512-token limit as training."""
    return tokenizer(batch['text'], truncation=True, max_length=512)


tokenized_test = test_dataset.map(
    _tokenize_for_eval,
    batched=True,
    remove_columns=test_dataset.column_names,  # keep only tokenizer outputs
)
eval_results = trainer.evaluate(tokenized_test)
# Show NaN rather than a misleading 0.0000 if the metric is missing.
print(f'\nTest Loss: {eval_results.get("eval_loss", float("nan")):.4f}')

# Save the adapter weights and tokenizer files next to the checkpoints.
model.save_pretrained('./medgemma_nail_disease_finetuned')
tokenizer.save_pretrained('./medgemma_nail_disease_finetuned')
print('\n‚úÖ Model saved to ./medgemma_nail_disease_finetuned')

## Step 13: Extract & Visualize Training Metrics

In [None]:
import pandas as pd

# Loss curves consumed by the plotting / report / summary cells.
history = {'train_loss': [], 'eval_loss': []}
_seen_eval_steps = set()


def _record_eval(step, value):
    """Store an eval loss, keeping only the first value logged per step.

    An evaluation run after training is logged at the same global step as
    the last in-training validation pass; dropping the repeat keeps that
    later (test-set) number out of the validation curve.
    NOTE(review): this relies on the final training step also being an
    evaluation step (true for max_steps=500 / eval_steps=50).
    """
    if step is not None and step in _seen_eval_steps:
        return
    _seen_eval_steps.add(step)
    history['eval_loss'].append(value)


# Preferred source: the trainer's in-memory log. Periodic training logs carry
# a 'loss' key and evaluation logs an 'eval_loss' key; the end-of-training
# summary uses 'train_loss' and is deliberately skipped because it is a run
# average, not a per-step value.
try:
    _log_records = list(trainer.state.log_history)
except NameError:
    _log_records = []  # trainer not in memory (e.g. kernel was restarted)

for _record in _log_records:
    if 'loss' in _record:
        history['train_loss'].append(_record['loss'])
    elif 'eval_loss' in _record:
        _record_eval(_record.get('step'), _record['eval_loss'])

# Fallback: read the TensorBoard event files written to ./logs. Tags are
# matched exactly so the 'train/train_loss' summary scalar is not mistaken
# for a training step.
if not history['train_loss'] and not history['eval_loss']:
    try:
        from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
        if os.path.exists('./logs'):
            for _file_name in sorted(os.listdir('./logs')):
                if 'events.out.tfevents' not in _file_name:
                    continue
                _accumulator = EventAccumulator(os.path.join('./logs', _file_name))
                _accumulator.Reload()
                _scalar_tags = _accumulator.Tags().get('scalars', [])
                if 'train/loss' in _scalar_tags:
                    history['train_loss'].extend(
                        _event.value for _event in _accumulator.Scalars('train/loss')
                    )
                if 'eval/loss' in _scalar_tags:
                    for _event in _accumulator.Scalars('eval/loss'):
                        _record_eval(_event.step, _event.value)
    except Exception as err:
        # Best effort only: the analysis cells cope with empty histories.
        print(f'Note: Could not extract tensorboard data: {str(err)[:50]}')

print(f'Extracted: {len(history["train_loss"])} train steps, {len(history["eval_loss"])} eval steps')

## Step 14: üìä Plot Loss Curves & Overfitting Analysis

In [None]:
# Loss series as float arrays (empty arrays when nothing was extracted).
train_loss = np.asarray(history['train_loss'], dtype=float)
eval_loss = np.asarray(history['eval_loss'], dtype=float)

# Pair every eval point with the training loss logged closest to the same
# moment in training. Training loss is logged more often than evaluation
# runs, so simply taking the last N training values would compare the end of
# training against the whole validation curve. Both series are assumed to be
# evenly spaced over the same run, so proportional indexing lines them up.
have_both = len(train_loss) > 0 and len(eval_loss) > 0
if have_both:
    n_eval = len(eval_loss)
    pair_idx = (np.arange(1, n_eval + 1) * len(train_loss)) // n_eval - 1
    pair_idx = np.clip(pair_idx, 0, len(train_loss) - 1)
    train_aligned = train_loss[pair_idx]
    eval_aligned = eval_loss
    loss_gap = eval_aligned - train_aligned

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('MedGemma Training: Overfitting Detection & Metrics', fontsize=14, fontweight='bold')

# Plot 1: Training Loss
if len(train_loss) > 0:
    axes[0, 0].plot(train_loss, marker='o', markersize=3, linewidth=2, color='blue')
    axes[0, 0].set_title('Training Loss Progression', fontweight='bold')
    axes[0, 0].set_xlabel('Training Step')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Validation Loss
if len(eval_loss) > 0:
    axes[0, 1].plot(eval_loss, marker='s', markersize=3, linewidth=2, color='orange')
    axes[0, 1].set_title('Validation Loss Progression', fontweight='bold')
    axes[0, 1].set_xlabel('Evaluation Step')
    axes[0, 1].set_ylabel('Loss')
    axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Train vs Eval with Gap (one point per evaluation run)
if have_both:
    axes[1, 0].plot(train_aligned, marker='o', label='Train Loss', linewidth=2)
    axes[1, 0].plot(eval_aligned, marker='s', label='Eval Loss', linewidth=2)
    axes[1, 0].fill_between(range(n_eval), train_aligned, eval_aligned, alpha=0.2, color='red', label='Overfitting Gap')
    axes[1, 0].set_title('Loss Gap: Train vs Eval', fontweight='bold')
    axes[1, 0].set_xlabel('Step')
    axes[1, 0].set_ylabel('Loss')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Overfitting Metrics Summary
if have_both:
    avg_gap = np.mean(loss_gap)
    max_gap = np.max(loss_gap)

    # Classify by the mean (eval - train) gap.
    if avg_gap < 0.01:
        status = 'MINIMAL OVERFITTING'
    elif avg_gap < 0.05:
        status = 'MILD OVERFITTING'
    else:
        status = 'MODERATE-SEVERE OVERFITTING'

    metrics_text = f'OVERFITTING ANALYSIS\n\nAvg Loss Gap: {avg_gap:.6f}\nMax Loss Gap: {max_gap:.6f}\n\nStatus: {status}\n\nTrain Loss: {train_aligned[-1]:.6f}\nEval Loss: {eval_aligned[-1]:.6f}\n\nImprovement: {(1-eval_aligned[-1]/eval_aligned[0])*100:.1f}%'

    axes[1, 1].text(0.5, 0.5, metrics_text, ha='center', va='center', fontsize=10, family='monospace', bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))
    axes[1, 1].axis('off')

plt.tight_layout()
plt.savefig('overfitting_analysis.png', dpi=150, bbox_inches='tight')
plt.show()
print('‚úÖ Overfitting analysis saved to overfitting_analysis.png')

## Step 15: üîç Detailed Overfitting Report

In [None]:
# Print a text report of the train/validation loss gap.
if len(eval_loss) > 0 and len(train_loss) > 0:
    # Pair each eval point with the training loss logged closest to the same
    # moment in training (training loss is logged more often than eval runs,
    # so taking the last N training values would compare different steps).
    # Assumes both series are evenly spaced over the same run.
    n_eval = len(eval_loss)
    pair_idx = (np.arange(1, n_eval + 1) * len(train_loss)) // n_eval - 1
    pair_idx = np.clip(pair_idx, 0, len(train_loss) - 1)
    train_aligned = train_loss[pair_idx]
    eval_aligned = eval_loss
    loss_gap = eval_aligned - train_aligned
    mean_gap = np.mean(loss_gap)

    print('\n' + '='*60)
    print('üîç OVERFITTING DETECTION ANALYSIS')
    print('='*60)

    print(f'\nüìä Loss Gap Statistics:')
    print(f'  Average Gap: {mean_gap:.6f}')
    print(f'  Max Gap: {np.max(loss_gap):.6f}')
    print(f'  Min Gap: {np.min(loss_gap):.6f}')

    print(f'\nüìà Performance Metrics:')
    print(f'  Final Train Loss: {train_aligned[-1]:.6f}')
    print(f'  Final Eval Loss: {eval_aligned[-1]:.6f}')
    print(f'  Loss Improvement: {(1-eval_aligned[-1]/eval_aligned[0])*100:.1f}%')

    # Classify by the mean (eval - train) gap.
    if mean_gap < 0.01:
        status = 'üü¢ MINIMAL OVERFITTING (Excellent!)'
    elif mean_gap < 0.05:
        status = 'üü° MILD OVERFITTING (Good)'
    else:
        status = 'üî¥ MODERATE-SEVERE OVERFITTING'

    print(f'\n‚úÖ Status: {status}')
    print('='*60)

## Step 16: Save Training Summary

In [None]:
# Static description of the run; values mirror the configuration cells above.
summary = {
    'model': 'google/medgemma-7b-orcamath-it',
    'training_type': 'SFT (Supervised Fine-Tuning) with LoRA',
    'lora_rank': 16,
    'lora_alpha': 32,
    'train_samples': len(train_df),
    'val_samples': len(val_df),
    'test_samples': len(test_df),
    'epochs': 3,
    'batch_size': 4,
    'gradient_accumulation_steps': 2,
    'learning_rate': 2e-4,
    'optimizer': 'paged_adamw_8bit',
    'max_steps': 500,
    'quantization': '4-bit (nf4)',
}

if len(eval_loss) > 0 and len(train_loss) > 0:
    # Pair each eval point with the training loss logged closest to the same
    # moment in training (training loss is logged more often than eval runs,
    # so taking the last N training values would compare different steps).
    # Assumes both series are evenly spaced over the same run.
    n_eval = len(eval_loss)
    pair_idx = (np.arange(1, n_eval + 1) * len(train_loss)) // n_eval - 1
    pair_idx = np.clip(pair_idx, 0, len(train_loss) - 1)
    train_aligned = train_loss[pair_idx]
    eval_aligned = eval_loss
    loss_gap = eval_aligned - train_aligned
    mean_gap = float(np.mean(loss_gap))

    # Plain floats/strings only, so the dict stays JSON-serializable.
    summary.update({
        'final_train_loss': float(train_aligned[-1]),
        'final_eval_loss': float(eval_aligned[-1]),
        'avg_loss_gap': mean_gap,
        'max_loss_gap': float(np.max(loss_gap)),
        'loss_improvement_percent': float((1-eval_aligned[-1]/eval_aligned[0])*100),
        'overfitting_status': 'MINIMAL' if mean_gap < 0.01 else 'MILD' if mean_gap < 0.05 else 'MODERATE-SEVERE'
    })

with open('training_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print('‚úÖ Training Summary:')
print(json.dumps(summary, indent=2))

## Step 17: Test Inference

In [None]:
# Run one sample generation with the fine-tuned model that is already in
# memory. The previous `model.load_state_dict(torch.load(...adapter_model.bin))`
# step was removed: the saved adapter file holds only LoRA tensors, so a
# strict load into the full model fails on missing keys, and the trained
# weights are already attached to `model`.
model.eval()  # disable dropout (including LoRA dropout) for generation

# Same layout as the training prompts, stopping where the answer should begin.
test_prompt = """Clinical Case Analysis:

Clinical Findings:
White nail beds with dark edges, slight clubbing.

Please provide the diagnosis and treatment plan.

Expected Output:
"""

# Send the inputs to the device the model actually lives on; with
# device_map='auto' that is decided at load time and may differ from DEVICE.
inputs = tokenizer(test_prompt, return_tensors='pt').to(model.device)
with torch.no_grad():  # no gradients needed for inference
    outputs = model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.9)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)

print('Test Inference:')
print('='*60)
print(result)
print('='*60)

## Step 18: Complete! ✅

In [None]:
# Closing banner: list the artifacts produced and the headline test metric.
print('\n' + '='*60)
print('‚úÖ FINE-TUNING & ANALYSIS COMPLETE!')
print('='*60)
print('\nüìÅ Output Files:')
print('  ‚úÖ medgemma_nail_disease_finetuned/')
print('     - adapter_model.bin (LoRA weights)')
print('     - config.json')
print('     - tokenizer files')
print('  ‚úÖ overfitting_analysis.png (4-subplot visualization)')
print('  ‚úÖ training_summary.json (metrics & config)')
print('  ‚úÖ logs/ (tensorboard data)')
print('\nüöÄ Next Steps:')
print('  1. Download files from Kaggle Output tab')
print('  2. Use for inference on new nail disease cases')
print('  3. Evaluate on real medical images')
print('\nüìä Model Performance:')
# Report the loss measured on the test split by the evaluation cell.
# `eval_loss` holds the logged evaluation history, so its last entry is not
# a reliable test-set number.
_test_loss = eval_results.get('eval_loss') if 'eval_results' in globals() else None
if _test_loss is not None:
    print(f'  Final Test Loss: {_test_loss:.4f}')
print('='*60)