# MedGemma Fine-Tuning for Nail Disease Classification
## Kaggle Optimized Version
This notebook fine-tunes Google's MedGemma model on nail disease classification.
- Better GPU: Kaggle's P100 (16 GB) is typically faster than Colab's free T4
- Free training, subject to Kaggle's weekly GPU quota and per-session time limit
- Expected Time: 30 mins - 1 hour on P100 GPU

## Cell 1: Detect Environment & GPU

In [None]:
import os
import sys
import torch

# Detect the hosting platform once; later cells branch on ENVIRONMENT.
IS_KAGGLE = os.path.exists('/kaggle')
IS_COLAB = 'google.colab' in sys.modules

# Ordered by priority: the first detected platform wins, 'local' is the fallback.
_PLATFORMS = (
    (IS_KAGGLE, 'kaggle', 'Running on Kaggle'),
    (IS_COLAB, 'colab', 'Running on Google Colab'),
    (True, 'local', 'Running on Local Machine'),
)
ENVIRONMENT, _banner = next(
    (name, banner) for detected, name, banner in _PLATFORMS if detected
)
print(_banner)

# Report the first CUDA device (index 0) when one is visible.
_cuda_ready = torch.cuda.is_available()
print(f'GPU Available: {_cuda_ready}')
if _cuda_ready:
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    _total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f'GPU Memory: {_total_bytes / 1e9:.2f} GB')

## Cell 2: Install Dependencies

In [None]:
# Install the fine-tuning stack with pip; the leading `!` is the notebook
# shell escape and `-q` keeps pip's output quiet. No versions are pinned.
# NOTE(review): torch is already imported in Cell 1, so if pip replaces it a
# kernel restart may be needed before continuing — confirm on the target host.
!pip install -q transformers datasets torch bitsandbytes peft trl tensorboard scikit-learn pandas numpy
print('Dependencies installed')

## Cell 3: Import Libraries

In [None]:
import torch
import transformers
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig
import pandas as pd
import json
from sklearn.model_selection import train_test_split
import warnings
# Suppress every warning category for the rest of the session.
warnings.filterwarnings('ignore')

# Print the versions of the two core libraries for reproducibility.
for label, module in (('PyTorch', torch), ('Transformers', transformers)):
    print(f'{label}: {module.__version__}')

## Cell 4: Configuration

In [None]:
# Central hyperparameters read by the later cells.
CONFIG = {
    # Hub repo id. MedGemma 4B is published as instruction-tuned ('-it') and
    # pre-trained ('-pt') variants; the bare 'google/medgemma-4b' id does not
    # resolve. Switch to 'google/medgemma-4b-pt' to start from the base model.
    'model_name': 'google/medgemma-4b-it',
    'batch_size': 4,            # per-device batch size for train and eval
    'learning_rate': 2e-4,
    'num_epochs': 3,
    'max_seq_length': 512,      # token limit handed to the SFT config
    'lora_r': 8,                # LoRA rank
    'lora_alpha': 16,           # LoRA scaling factor
    'output_dir': './medgemma_nails_finetuned',
}

# Echo the configuration so it is captured in the notebook output.
for key, value in CONFIG.items():
    print(f'{key}: {value}')

## Cell 5: Load CSV Data

In [None]:
# Where the CSV lives on each hosted platform; anything else is treated as a
# local run and reads from the working directory.
_CSV_LOCATIONS = {
    'kaggle': '/kaggle/input/nail-disease-classification/nail_diseases.csv',
    'colab': '/content/drive/MyDrive/nail_diseases.csv',
}
csv_path = _CSV_LOCATIONS.get(ENVIRONMENT, './nail_diseases.csv')

# Load the whole file into memory and report its size.
df = pd.read_csv(csv_path)
_row_count = len(df)
print(f'Loaded {_row_count} rows')
print(f'Shape: {df.shape}')

## Cell 6: Create Training Prompts

In [None]:
def create_prompt(row):
    """Build the training text for one record.

    Args:
        row: Mapping-like record exposing ``.get`` (a DataFrame row or a
            plain dict). The columns read are ``clinical_findings``,
            ``confirmed_diagnosis``, ``treatment_protocol`` and ``prognosis``.

    Returns:
        A four-line string: ``Clinical Findings``, ``Diagnosis``,
        ``Treatment`` and ``Prognosis``, joined with newlines. A column that
        is absent or holds a missing value is rendered as ``N/A``.
    """
    def _field(column):
        value = row.get(column, 'N/A')
        # Empty CSV cells arrive as NaN; str() would write the literal 'nan'
        # into the training text, so map missing scalars to 'N/A' instead.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return 'N/A'
        return str(value)

    return '\n'.join((
        f"Clinical Findings: {_field('clinical_findings')}",
        f"Diagnosis: {_field('confirmed_diagnosis')}",
        f"Treatment: {_field('treatment_protocol')}",
        f"Prognosis: {_field('prognosis')}",
    ))

# Render each record to its prompt (axis=1 hands create_prompt one row at a
# time) and keep the result in the 'text' column used for training.
_prompt_texts = df.apply(create_prompt, axis=1)
df['text'] = _prompt_texts
print(f'Created {len(df)} training samples')

## Cell 7: Split Data

In [None]:
# 70/15/15 split: hold out 30% first, then halve the hold-out into
# validation and test. The fixed seed keeps the partition reproducible.
train_df, temp_df = train_test_split(df, random_state=42, test_size=0.3)
val_df, test_df = train_test_split(temp_df, random_state=42, test_size=0.5)

for split_name, split_frame in (('Train', train_df), ('Val', val_df), ('Test', test_df)):
    print(f'{split_name}: {len(split_frame)} samples')

## Cell 8: Setup 4-bit Quantization

In [None]:
# 4-bit NF4 quantization with double quantization enabled; matrix math is
# carried out in bfloat16.
# NOTE(review): the notebook targets P100/T4 GPUs — confirm bfloat16 compute
# is acceptable on that hardware.
_quantization_options = {
    'load_in_4bit': True,
    'bnb_4bit_quant_type': 'nf4',
    'bnb_4bit_use_double_quant': True,
    'bnb_4bit_compute_dtype': torch.bfloat16,
}
bnb_config = BitsAndBytesConfig(**_quantization_options)
print('4-bit quantization configured')

## Cell 9: Load MedGemma Model

In [None]:
print(f'Loading {CONFIG["model_name"]}...')
# Load the base model quantized per bnb_config and let device_map='auto'
# place it on the available device(s). token=True asks the Hub client to use
# the locally stored access token.
model = AutoModelForCausalLM.from_pretrained(
    CONFIG['model_name'],
    quantization_config=bnb_config,
    device_map='auto',
    token=True,
)

# Pass the same token setting as the model load so both requests
# authenticate the same way.
tokenizer = AutoTokenizer.from_pretrained(CONFIG['model_name'], token=True)
# Only fall back to EOS when the tokenizer ships without a pad token.
# Overwriting an existing pad token makes padding indistinguishable from EOS,
# so EOS can end up masked out of the loss together with the padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print('Model loaded')

## Cell 10: Setup LoRA

In [None]:
# Run PEFT's k-bit preparation on the quantized model before adding adapters.
model = prepare_model_for_kbit_training(model)

# LoRA adapters on the attention query/value projections only.
lora_config = LoraConfig(
    task_type='CAUSAL_LM',
    target_modules=['q_proj', 'v_proj'],
    r=CONFIG['lora_r'],
    lora_alpha=CONFIG['lora_alpha'],
    lora_dropout=0.05,
    bias='none',
)
model = get_peft_model(model, lora_config)

# Walk the parameters once, then derive both totals from the same snapshot.
_param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total = sum(size for size, _ in _param_sizes)
trainable = sum(size for size, wants_grad in _param_sizes if wants_grad)
print(f'Trainable: {trainable:,} / {total:,}')
print(f'Trainable %: {100 * trainable / total:.2f}%')

## Cell 11: Create Datasets

In [None]:
# Convert each split to a Hugging Face Dataset holding only the 'text' column.
# The splits carry a shuffled, non-default index after train_test_split;
# preserve_index=False stops it from being stored as an extra
# '__index_level_0__' column.
train_dataset = Dataset.from_pandas(train_df[['text']], preserve_index=False)
val_dataset = Dataset.from_pandas(val_df[['text']], preserve_index=False)
test_dataset = Dataset.from_pandas(test_df[['text']], preserve_index=False)

print(f'Train: {len(train_dataset)} samples')
print(f'Val: {len(val_dataset)} samples')
print(f'Test: {len(test_dataset)} samples')

## Cell 12: Configure Training

In [None]:
import inspect

# Dependencies are installed unpinned, so adapt to the installed TRL release:
# newer versions name the sequence-length limit `max_length`, older ones
# `max_seq_length`. Prefer the new name whenever it is accepted.
_sft_config_params = inspect.signature(SFTConfig).parameters
_length_kwarg = 'max_length' if 'max_length' in _sft_config_params else 'max_seq_length'

training_config = SFTConfig(
    output_dir=CONFIG['output_dir'],
    num_train_epochs=CONFIG['num_epochs'],
    per_device_train_batch_size=CONFIG['batch_size'],
    per_device_eval_batch_size=CONFIG['batch_size'],
    # Effective batch size per device = batch_size * 2.
    gradient_accumulation_steps=2,
    learning_rate=CONFIG['learning_rate'],
    warmup_steps=100,
    logging_steps=50,
    # `evaluation_strategy` was renamed to `eval_strategy`; the old keyword is
    # rejected by current transformers releases.
    eval_strategy='steps',
    eval_steps=100,
    # Checkpoint on the same cadence as evaluation so the best checkpoint can
    # be restored by load_best_model_at_end.
    save_steps=100,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    report_to=['tensorboard'],
    logging_dir='./logs',
    **{_length_kwarg: CONFIG['max_seq_length']},
)
print('Training config ready')

## Cell 13: Initialize Trainer

In [None]:
import inspect

# Dependencies are installed unpinned, so build the keyword arguments from
# what the installed SFTTrainer actually accepts.
_trainer_params = inspect.signature(SFTTrainer.__init__).parameters
_trainer_kwargs = {
    'model': model,
    'args': training_config,
    'train_dataset': train_dataset,
    'eval_dataset': val_dataset,
}

# Newer TRL releases take the tokenizer as `processing_class`; older ones
# only know `tokenizer`.
if 'processing_class' in _trainer_params:
    _trainer_kwargs['processing_class'] = tokenizer
else:
    _trainer_kwargs['tokenizer'] = tokenizer

# `dataset_text_field` moved from the trainer to the config object. Pass it
# wherever this release expects it.
if 'dataset_text_field' in _trainer_params:
    _trainer_kwargs['dataset_text_field'] = 'text'
else:
    training_config.dataset_text_field = 'text'

trainer = SFTTrainer(**_trainer_kwargs)
print('Trainer initialized')

## Cell 14: START TRAINING

In [None]:
# Run the full training loop; the returned object carries the final loss.
print('Starting training...')
train_result = trainer.train()
_final_training_loss = train_result.training_loss
print(f'Training loss: {_final_training_loss:.4f}')

## Cell 15: Evaluate

In [None]:
# Score the held-out test split and pretty-print the returned metrics.
test_results = trainer.evaluate(eval_dataset=test_dataset)
_metrics_report = json.dumps(test_results, indent=2)
print('Test Results:')
print(_metrics_report)

## Cell 16: Save Model

In [None]:
# Write the model first, then the tokenizer, into the same output directory.
for _artifact in (model, tokenizer):
    _artifact.save_pretrained(CONFIG['output_dir'])
print(f'Model saved to {CONFIG["output_dir"]}')

## Cell 17: Test Inference

In [None]:
# Prompt mirrors the training layout from create_prompt: fields separated by
# a newline and no trailing space after the label, so the model continues
# from the same context it was trained on.
test_input = 'Clinical Findings: White nails with pink distal end.\nDiagnosis:'

# Switch to eval mode so dropout is disabled while generating.
model.eval()
# The model was placed with device_map='auto'; move the encoded inputs to the
# model's device so generate() does not hit a CPU/GPU tensor mismatch.
inputs = tokenizer(test_input, return_tensors='pt').to(model.device)
# No gradients are needed for generation.
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=50)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print('Generated:', result)

## Cell 18: Save Summary

In [None]:
# Collect the run's key facts into a JSON-serialisable record.
summary = dict(
    model=CONFIG['model_name'],
    train_samples=len(train_df),
    val_samples=len(val_df),
    test_samples=len(test_df),
    epochs=CONFIG['num_epochs'],
    batch_size=CONFIG['batch_size'],
    final_loss=float(train_result.training_loss),
)

# Serialise once and reuse the same text for the file and the console.
_summary_json = json.dumps(summary, indent=2)
with open('training_summary.json', 'w') as f:
    f.write(_summary_json)

print('Summary saved')
print(_summary_json)

## Cell 19: Complete

In [None]:
# Final banner listing where each artifact was written.
_completion_lines = (
    'TRAINING COMPLETE!',
    'Files saved:',
    f'  Model: {CONFIG["output_dir"]}/',
    '  Logs: ./logs/',
    '  Summary: ./training_summary.json',
)
print('\n'.join(_completion_lines))