# Fine-tune MedGemma 4B for Nail Disease Clinical Explanations
## Model 2: Clinical Findings → Medical Explanations Pipeline

**Pipeline Architecture:**
- Stage 1 ✅ DONE: MedSigLIP (Image Classification) → "Clubbing" / "Pitting" etc.
- Stage 2 ⭐ NOW: MedGemma 4B (Clinical Explanation) → "What does this mean?"
- Stage 3: MedGemma 27B (Disease Ranking) → "What diseases could cause this?"

Based on: https://github.com/google-health/medgemma
Model: google/medgemma-4b-it (Lightweight, 4B params, instruction-tuned)
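License: Apache 2.0 (repository code); the model weights require accepting the Health AI Developer Foundations terms on Hugging Face

Features: Interactive HuggingFace auth | Medical prompt engineering | CSV integration | Missing value handling | Overfitting detection | Faster training than larger MedGemma variants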

## Step 0: Suppress CUDA/cuDNN Warnings (Optional)

In [None]:
# Reduce log noise from TensorFlow/XLA, absl, and Python warnings.
# Note: the cuFFT/cuDNN/cuBLAS "factory already registered" lines come from
# native code at import time and may still appear despite these settings.
import os
import warnings
import logging

# Hide TensorFlow C++ INFO and WARNING logs
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Request deterministic XLA GPU ops (this flag does not affect logging)
os.environ['XLA_FLAGS'] = '--xla_gpu_deterministic_ops'

# Silence all UserWarning and DeprecationWarning messages (this includes pydantic warnings)
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)

logging.getLogger('absl').setLevel(logging.ERROR)
logging.getLogger('tensorflow').setLevel(logging.ERROR)
logging.getLogger('transformers').setLevel(logging.WARNING)

print('✅ Warning filters applied')



## Step 1: Setup Environment & Interactive HuggingFace Authentication

In [None]:
import os
import torch
import json
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

IS_KAGGLE = os.path.exists('/kaggle')
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('='*60)
print('ENVIRONMENT SETUP')
print('='*60)
print(f'Environment: {"Kaggle" if IS_KAGGLE else "Local/Colab"}')
print(f'Device: {DEVICE}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
else:
    print('GPU: None - CPU mode')
print(f'PyTorch: {torch.__version__}')
print('='*60)

ENVIRONMENT SETUP
Environment: Kaggle
Device: cuda
GPU: Tesla T4
Memory: 15.6 GB
PyTorch: 2.8.0+cu126


In [None]:
# Install packages
!pip install -q transformers datasets torch bitsandbytes peft trl scikit-learn matplotlib huggingface-hub
print('✅ Packages installed')

✅ Packages installed


## Step 1a: 🔑 Interactive HuggingFace Authentication Helper

In [None]:
from huggingface_hub import notebook_login

print("="*70)
print("🔐 HUGGING FACE LOGIN")
print("="*70)
print("\nYou'll be prompted to enter your Hugging Face token.")
print("Get your token: https://huggingface.co/settings/tokens\n")

notebook_login()

print("\n✅ Login successful!")

🔐 HUGGING FACE LOGIN

You'll be prompted to enter your Hugging Face token.
Get your token: https://huggingface.co/settings/tokens


✅ Login successful!


## Step 2: Import Libraries

In [None]:
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    set_seed
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

set_seed(42)
print('✅ Libraries imported')

2026-01-30 17:50:39.207418: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769795439.679879      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769795439.822274      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769795440.962512      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.

✅ Libraries imported


## Step 3: Load & Explore Dataset (Model 2 Training Data)

In [None]:
# Load CSV dataset for Model 2: Clinical Explanation Stage
csv_path = '/kaggle/input/nail-disease-medgemma/nail_diseases.csv'

# Find CSV in various possible locations
possible_paths = [
    '/kaggle/input/nail-disease-medgemma/nail_diseases.csv',
    '/kaggle/input/nail-disease-classification/nail_diseases.csv',
    '/kaggle/input/nail-diseases/nail_diseases.csv',
    './nail_diseases.csv'
]

df = None
for path in possible_paths:
    if os.path.exists(path):
        df = pd.read_csv(path)
        csv_path = path
        print(f'✅ Found CSV at: {path}')
        break

if df is None:
    print('❌ CSV file not found in standard locations')
    print('\nAvailable inputs:')
    if IS_KAGGLE and os.path.exists('/kaggle/input'):
        for item in os.listdir('/kaggle/input'):
            print(f'  - {item}')
    sys.exit(1)

print(f'\n✅ Loaded {len(df)} samples from {csv_path}')
print(f'\nDataset Shape: {df.shape}')
print(f'\nColumns: {list(df.columns)}')
print(f'\nFirst row:')
print(df.iloc[0])
print(f'\nData types:')
print(df.dtypes)

✅ Found CSV at: /kaggle/input/nail-diseases/nail_diseases.csv

✅ Loaded 10000 samples from /kaggle/input/nail-diseases/nail_diseases.csv

Dataset Shape: (10000, 15)

Columns: ['nail_disease_category', 'model_1_predicted_disease', 'confirmed_diagnosis', 'patient_age', 'patient_sex', 'patient_ethnicity', 'fitzpatrick_skin_type', 'disease_severity', 'clinical_findings', 'differential_diagnoses', 'recommended_medical_tests', 'treatment_protocol', 'comorbidities', 'clinical_notes', 'prognosis']

First row:
nail_disease_category                                              Blue_Finger
model_1_predicted_disease                                          Blue Finger
confirmed_diagnosis                                                Blue Finger
patient_age                                                                 43
patient_sex                                                               Male
patient_ethnicity                                                    Caucasian
fitzpatrick_ski

## Step 4: Data Cleaning - Handle Missing Values

In [None]:
print('🧹 DATA CLEANING - MISSING VALUE HANDLING')
print('='*60)

# Show missing values before cleaning
missing_before = df.isnull().sum()
print('\n📊 Missing values BEFORE cleaning:')
print(missing_before[missing_before > 0])

# Create a copy for cleaning
df_clean = df.copy()

# Strategy 1: Drop rows where critical columns are missing.
# The first two names match this CSV; the others cover alternative schemas.
critical_cols = ['nail_disease_category', 'model_1_predicted_disease',
                 'nail_disease', 'disease_name', 'clinical_findings', 'findings']
critical_cols_present = [col for col in critical_cols if col in df_clean.columns]

if critical_cols_present:
    initial_len = len(df_clean)
    df_clean = df_clean.dropna(subset=critical_cols_present)
    dropped_critical = initial_len - len(df_clean)
    print(f'\n🗑️  Dropped {dropped_critical} rows with missing critical fields')

# Strategy 2: Fill missing values in non-critical columns
non_critical_cols = ['comorbidities', 'patient_age', 'age', 'patient_sex', 'sex', 'gender']
for col in df_clean.columns:
    if col in non_critical_cols and df_clean[col].isnull().any():
        if pd.api.types.is_numeric_dtype(df_clean[col]):
            # Fill numeric columns with the median. Assign the result back:
            # in-place fillna on a column selection is unreliable in recent pandas.
            df_clean[col] = df_clean[col].fillna(df_clean[col].median())
            print(f'  📌 Filled {col} with median value')
        else:
            # Fill string columns with unknown
            df_clean[col] = df_clean[col].fillna('unknown')
            print(f'  📌 Filled {col} with unknown')

# Strategy 3: Drop rows with ANY remaining NaN
initial_len = len(df_clean)
df_clean = df_clean.dropna()
dropped_final = initial_len - len(df_clean)
if dropped_final > 0:
    print(f'\n🗑️  Dropped {dropped_final} additional rows with remaining NaN values')

# Show missing values after cleaning
missing_after = df_clean.isnull().sum()
if missing_after.sum() == 0:
    print('\n✅ No missing values remaining!')
else:
    print('\n⚠️  Remaining missing values:')
    print(missing_after[missing_after > 0])

print('\n📊 Dataset Summary:')
print(f'  Original rows: {len(df)}')
print(f'  Clean rows: {len(df_clean)}')
print(f'  Removed: {len(df) - len(df_clean)} ({(len(df) - len(df_clean))/len(df)*100:.1f}%)')
print('='*60)

# Use cleaned dataset
df = df_clean

🧹 DATA CLEANING - MISSING VALUE HANDLING

📊 Missing values BEFORE cleaning:
comorbidities    3671
dtype: int64

🗑️  Dropped 0 rows with missing critical fields
  📌 Filled comorbidities with unknown

✅ No missing values remaining!

📊 Dataset Summary:
  Original rows: 10000
  Clean rows: 10000
  Removed: 0 (0.0%)
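
The three cleaning strategies in this step (drop rows missing critical fields, fill non-critical fields, drop whatever is still incomplete) can be illustrated without pandas. The following is a minimal sketch, assuming rows are plain dictionaries with `None` for missing cells; `clean_rows` and the sample rows are hypothetical and not part of the notebook.

```python
from statistics import median

def clean_rows(rows, critical, fillable):
    """Apply the three strategies in order: drop, fill, drop leftovers."""
    # Strategy 1: drop rows missing any critical field
    kept = [dict(r) for r in rows if all(r.get(c) is not None for c in critical)]
    # Strategy 2: fill non-critical fields (median for numbers, 'unknown' otherwise)
    for col in fillable:
        present = [r[col] for r in kept if r.get(col) is not None]
        numeric = bool(present) and all(isinstance(v, (int, float)) for v in present)
        fill = median(present) if numeric else 'unknown'
        for r in kept:
            if r.get(col) is None:
                r[col] = fill
    # Strategy 3: drop rows that still contain a missing value
    return [r for r in kept if all(v is not None for v in r.values())]

# Made-up rows for illustration only
rows = [
    {'clinical_findings': 'Pitting', 'patient_age': 40, 'comorbidities': None},
    {'clinical_findings': None, 'patient_age': 55, 'comorbidities': 'Diabetes'},
    {'clinical_findings': 'Clubbing', 'patient_age': None, 'comorbidities': 'COPD'},
    {'clinical_findings': 'Ridging', 'patient_age': 60, 'comorbidities': None},
]
cleaned = clean_rows(rows, critical=['clinical_findings'],
                     fillable=['patient_age', 'comorbidities'])
print(len(cleaned), cleaned[0]['comorbidities'], cleaned[1]['patient_age'])
```

The order matters: filling before the final drop is what lets the notebook keep all 10000 rows even though `comorbidities` is missing in 3671 of them.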


## Step 5: Enhanced Medical Prompt Templates for Clinical Explanations

In [None]:
def create_medical_prompt_model2(row):
    """
    Creates advanced medical prompts for MedGemma Model 2 (Clinical Explanation Stage).

    Input: Nail disease classification from Model 1 + clinical findings
    Output: Detailed clinical explanation of findings, differential diagnoses, and systemic implications
    """

    # Extract fields (handle missing/None values).
    # This CSV stores the Model 1 label in 'model_1_predicted_disease'; the other
    # names are fallbacks for alternative schemas. If none match, the prompt
    # falls back to 'unknown'.
    nail_disease = str(row.get('model_1_predicted_disease', row.get('nail_disease', row.get('disease_name', 'unknown')))).strip()
    clinical_findings = str(row.get('clinical_findings', row.get('findings', 'no findings reported'))).strip()
    patient_age = row.get('patient_age', row.get('age', 'unknown'))
    patient_sex = str(row.get('patient_sex', row.get('sex', row.get('gender', 'unknown')))).strip()
    differential_diagnoses = str(row.get('differential_diagnoses', row.get('differentials', 'pending investigation'))).strip()
    # Note: the CSV has no 'systemic_implications' column, so the default text is used
    systemic_implications = str(row.get('systemic_implications', row.get('implications', 'requires clinical assessment'))).strip()
    treatment_protocol = str(row.get('treatment_protocol', row.get('treatment', 'refer to specialist'))).strip()
    comorbidities = str(row.get('comorbidities', 'none reported')).strip()

    # Build an instruction-style prompt with the target response appended
    prompt = f"""CLINICAL ANALYSIS: Nail Disease Diagnosis

PATIENT DEMOGRAPHICS:
Age: {patient_age}
Sex: {patient_sex}
Comorbidities: {comorbidities}

PRIMARY FINDING (from Model 1 - MedSigLIP):
{nail_disease}

CLINICAL PRESENTATION:
{clinical_findings}

INSTRUCTION:
Based on the nail disease finding and clinical presentation above, provide:
1. Detailed explanation of what the nail finding indicates
2. Possible systemic diseases that could cause this nail finding
3. Recommended diagnostic workup and treatment approach

EXPECTED RESPONSE:
Nail Finding Explanation: {nail_disease} indicates {systemic_implications}

Differential Diagnoses: {differential_diagnoses}

Recommended Treatment: {treatment_protocol}
"""

    return prompt.strip()

# Apply prompt template to dataset
df['text'] = df.apply(create_medical_prompt_model2, axis=1)

print(f'✅ Created {len(df)} medical prompts for Model 2 training')
print('\nExample prompt (first 500 chars):')
print('='*60)
print(df['text'].iloc[0][:500])
print('='*60)
print(f'\nAverage prompt length: {df["text"].str.len().mean():.0f} chars')
print(f'Max prompt length: {df["text"].str.len().max():.0f} chars')

✅ Created 10000 medical prompts for Model 2 training

Example prompt (first 500 chars):
CLINICAL ANALYSIS: Nail Disease Diagnosis

PATIENT DEMOGRAPHICS:
Age: 43
Sex: Male
Comorbidities: unknown

PRIMARY FINDING (from Model 1 - MedSigLIP):
unknown

CLINICAL PRESENTATION:
Transient blue discoloration; Blue-gray pigmentation

INSTRUCTION:
Based on the nail disease finding and clinical presentation above, provide:
1. Detailed explanation of what the nail finding indicates
2. Possible systemic diseases that could cause this nail finding
3. Recommended diagnostic workup and treatment app

Average prompt length: 727 chars
Max prompt length: 774 chars
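
The recorded example prompt shows the primary finding as `unknown`, which is what nested `row.get(...)` fallbacks produce when none of the looked-up names exist in the CSV. A minimal sketch of a more explicit lookup follows; `first_present` is a hypothetical helper, and the sample values are copied from the first row printed in Step 3.

```python
def first_present(row, candidates, default='unknown'):
    """Return the first non-empty value among the candidate keys, else default."""
    for key in candidates:
        value = row.get(key)
        if value is not None and str(value).strip():
            return str(value).strip()
    return default

# Values copied from the first row printed in Step 3
row = {
    'nail_disease_category': 'Blue_Finger',
    'model_1_predicted_disease': 'Blue Finger',
    'patient_age': 43,
}

finding = first_present(row, ['model_1_predicted_disease', 'nail_disease_category'])
legacy = first_present(row, ['nail_disease', 'disease_name'])
print(finding)
print(legacy)
```

Listing the candidate columns in one place makes a schema mismatch easy to spot, since a lookup that silently returns the default ends up in every training example.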


## Step 6: Data Quality & Validation Check

In [None]:
# Check for missing values and data quality
print('📊 DATA QUALITY REPORT')
print('='*60)

# Missing values
missing = df.isnull().sum()
if missing.sum() > 0:
    print('\n⚠️  Missing values detected:')
    print(missing[missing > 0])
else:
    print('\n✅ No missing values')

# Check text field quality
empty_texts = (df['text'].str.len() < 50).sum()
if empty_texts > 0:
    print(f'\n⚠️  {empty_texts} prompts are too short (<50 chars)')
else:
    print(f'\n✅ All prompts have sufficient length ({len(df)} samples)')

# Disease distribution
if 'nail_disease' in df.columns:
    print('\n📋 Disease Distribution:')
    print(df['nail_disease'].value_counts())
elif 'disease_name' in df.columns:
    print('\n📋 Disease Distribution:')
    print(df['disease_name'].value_counts())

print('\n' + '='*60)

📊 DATA QUALITY REPORT

✅ No missing values

✅ All prompts have sufficient length (10000 samples)



## Step 7: Split Dataset (Train/Val/Test)

In [None]:
# Stratified split: 70% train, 15% val, 15% test
# Stratify on the first label column that exists in the CSV
split_key = next(
    (col for col in ('nail_disease_category', 'nail_disease', 'disease_name') if col in df.columns),
    None
)

train_df, temp_df = train_test_split(
    df,
    test_size=0.3,
    random_state=42,
    stratify=df[split_key] if split_key else None
)

val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=42,
    stratify=temp_df[split_key] if split_key else None
)

print('📊 DATASET SPLIT')
print('='*60)
print(f'Train: {len(train_df)} samples ({len(train_df)/len(df)*100:.1f}%)')
print(f'Val:   {len(val_df)} samples ({len(val_df)/len(df)*100:.1f}%)')
print(f'Test:  {len(test_df)} samples ({len(test_df)/len(df)*100:.1f}%)')
print(f'Total: {len(df)} samples')
print('='*60)

📊 DATASET SPLIT
Train: 7000 samples (70.0%)
Val:   1500 samples (15.0%)
Test:  1500 samples (15.0%)
Total: 10000 samples
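
The 70/15/15 split comes from two consecutive splits: 30% is held out first, then that held-out part is halved. The sketch below shows the idea in plain Python, assuming a per-class shuffle and cut; `stratified_split` is a hypothetical helper and not scikit-learn's implementation, and the toy labels are illustrative.

```python
import random
from collections import defaultdict

def stratified_split(items, label_of, test_fraction, seed=42):
    """Split items so each label keeps the same share in both parts."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for item in items:
        by_label[label_of(item)].append(item)
    first, second = [], []
    for label in sorted(by_label):
        group = by_label[label]
        rng.shuffle(group)
        cut = round(len(group) * (1 - test_fraction))
        first.extend(group[:cut])
        second.extend(group[cut:])
    return first, second

# Toy data: 60 'Clubbing' rows and 40 'Pitting' rows
data = [('Clubbing', i) for i in range(60)] + [('Pitting', i) for i in range(40)]

def label(item):
    return item[0]

train, temp = stratified_split(data, label, test_fraction=0.3)  # 70 / 30
val, test = stratified_split(temp, label, test_fraction=0.5)    # 15 / 15
print(len(train), len(val), len(test))
```

Stratifying only has an effect when a label column is actually selected; with no split key the notebook's `train_test_split` calls fall back to a plain random split of the same sizes.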


## Step 8: Create HuggingFace Datasets

In [None]:
# Create HuggingFace datasets (preserve_index=False keeps the pandas index
# from being added as an extra '__index_level_0__' column)
train_dataset = Dataset.from_pandas(train_df[['text']], preserve_index=False)
val_dataset = Dataset.from_pandas(val_df[['text']], preserve_index=False)
test_dataset = Dataset.from_pandas(test_df[['text']], preserve_index=False)

print('✅ HuggingFace datasets created')
print(f'  Train: {len(train_dataset)} samples')
print(f'  Val:   {len(val_dataset)} samples')
print(f'  Test:  {len(test_dataset)} samples')

✅ HuggingFace datasets created
  Train: 7000 samples
  Val:   1500 samples
  Test:  1500 samples


## Step 9: Setup Model & Tokenizer (MedGemma 4B - Lightweight & Fast)

In [None]:
# Model configuration: MedGemma 4B, the smallest MedGemma variant,
# chosen to keep fine-tuning fast on a single GPU
MODEL_ID = 'google/medgemma-4b-it'
# Alternative (larger, slower): MODEL_ID = 'google/medgemma-27b-text-it'

# 4-bit quantization config (memory efficient)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

print('='*60)
print(f'Loading model: {MODEL_ID}')
print('This may take 1-2 minutes...')
print('='*60)
print('💡 Model Info:')
print('  - Size: 4B parameters')
print('  - Speed: faster to fine-tune than the 27B variant')
print('  - Domain: tuned for medical text and image comprehension')
print('  - Memory: 4-bit weights fit on a 16 GB T4')
print()

model = None
tokenizer = None

try:
    # Authenticate with the token stored by notebook_login() in Step 1a
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map='auto',
        trust_remote_code=True,
        token=True,  # replaces the deprecated use_auth_token argument
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        token=True
    )
    # Only fall back to EOS as padding if the tokenizer defines no pad token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print('✅ Model loaded successfully!')
    print('   Model: MedGemma 4B (Lightweight)')
    print(f'   Size: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B parameters')
    print('   Memory: ~8GB GPU VRAM (4-bit quantized)')
    print('   Expected Training Time: 15-30 minutes')

except Exception as e:
    error_msg = str(e)
    print(f'\n❌ Error loading model: {error_msg[:200]}')
    print('\n🔧 TROUBLESHOOTING:')
    print('\n1. ACCEPT MODEL LICENSE:')
    print(f'   - Visit: https://huggingface.co/{MODEL_ID}')
    print('   - Click "Accept" button')
    print('\n2. LOGIN TO HUGGINGFACE:')
    print('   - Re-run notebook_login() in Step 1a')
    print('   - Or manually: from huggingface_hub import login')
    print('   - Then: login(token="hf_YOUR_TOKEN")')
    print('\n3. CHECK TOKEN VALIDITY:')
    print('   - Visit: https://huggingface.co/settings/tokens')
    print('   - Ensure your token has "Read" access')
    print('\n4. ENVIRONMENT VARIABLES:')
    print('   - Set: export HF_TOKEN="hf_YOUR_TOKEN"')
    print('\n5. OFFLINE MODE:')
    print('   - Download model locally first')
    print('   - Use: AutoModelForCausalLM.from_pretrained("./local/path")')
    print('\n' + '='*60)
    sys.exit(1)

if model is None or tokenizer is None:
    print('❌ Model or tokenizer failed to load!')
    sys.exit(1)

Loading model: google/medgemma-4b-it
This may take 1-2 minutes...
💡 Model Info:
  - Size: 4B parameters
  - Speed: faster to fine-tune than the 27B variant
  - Domain: tuned for medical text and image comprehension
  - Memory: 4-bit weights fit on a 16 GB T4




✅ Model loaded successfully!
   Model: MedGemma 4B (Lightweight)
   Size: 2.49B parameters
   Memory: ~8GB GPU VRAM (4-bit quantized)
   Expected Training Time: 15-30 minutes


## Step 10: Configure LoRA (Low-Rank Adaptation)

In [None]:
# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)

# LoRA configuration
lora_config = LoraConfig(
    r=16,  # Rank
    lora_alpha=32,  # Alpha scaling
    target_modules=['q_proj', 'v_proj', 'k_proj'],  # Query, Value, Key projections
    lora_dropout=0.05,
    bias='none',
    task_type='CAUSAL_LM'
)

model = get_peft_model(model, lora_config)

# Count trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())

print('✅ LoRA configured for 4B model')
print(f'  Total params: {total_params / 1e9:.2f}B')
print(f'  Trainable: {trainable_params / 1e6:.2f}M ({100*trainable_params/total_params:.3f}%)')

✅ LoRA configured for 4B model
  Total params: 2.50B
  Trainable: 9.39M (0.376%)
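
The trainable-parameter count can be checked by hand: for each adapted linear layer, LoRA adds a `rank x in_features` matrix and an `out_features x rank` matrix. The sketch below is a back-of-the-envelope check, not output from the notebook. The layer counts and hidden sizes are assumptions about the Gemma 3 4B text decoder and its bundled SigLIP vision encoder, since the notebook does not state them.

```python
def lora_params(in_features, out_features, rank):
    """LoRA adds A (rank x in) and B (out x rank) per adapted linear layer."""
    return rank * (in_features + out_features)

RANK = 16  # r from the LoRA configuration in this step

# ASSUMED shapes (not stated in this notebook): Gemma 3 4B text decoder
TEXT_LAYERS, TEXT_HIDDEN, HEAD_DIM = 34, 2560, 256
Q_HEADS, KV_HEADS = 8, 4
# ASSUMED shapes: SigLIP vision encoder shipped in the same checkpoint
VISION_LAYERS, VISION_HIDDEN = 27, 1152

text_per_layer = (
    lora_params(TEXT_HIDDEN, Q_HEADS * HEAD_DIM, RANK)          # q_proj
    + 2 * lora_params(TEXT_HIDDEN, KV_HEADS * HEAD_DIM, RANK)   # k_proj, v_proj
)
vision_per_layer = 3 * lora_params(VISION_HIDDEN, VISION_HIDDEN, RANK)

text_total = TEXT_LAYERS * text_per_layer
vision_total = VISION_LAYERS * vision_per_layer
total = text_total + vision_total
print(f'text: {text_total / 1e6:.2f}M, vision: {vision_total / 1e6:.2f}M, total: {total / 1e6:.2f}M')
```

If these assumed shapes hold, the total matches the 9.39M reported above. That would suggest the `q_proj`/`k_proj`/`v_proj` names also match attention layers in the vision encoder, which a text-only fine-tune does not exercise.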


## Step 11: Setup Training Configuration (Optimized for 4B)

In [None]:
# Training arguments for MedGemma 4B on a single GPU
training_args = TrainingArguments(
    output_dir='./medgemma_nail_disease_model2_finetuned',
    num_train_epochs=3,  # ignored while max_steps is positive
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=1,
    learning_rate=2e-4,
    lr_scheduler_type='cosine',
    warmup_steps=100,
    weight_decay=0.01,
    max_steps=500,  # caps the run and overrides num_train_epochs
    logging_steps=10,
    eval_steps=50,
    save_steps=50,
    eval_strategy='steps',
    save_strategy='steps',
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    greater_is_better=False,
    logging_dir='./logs',
    report_to='tensorboard',  # write event files to ./logs for Step 15 and skip other loggers
    optim='paged_adamw_8bit',
    seed=42,
    dataloader_pin_memory=True,
)

print('✅ Training configuration ready (4B optimized)')
print('  Output: ./medgemma_nail_disease_model2_finetuned')
print(f'  Epochs: {training_args.num_train_epochs}')
print(f'  Batch size: {training_args.per_device_train_batch_size} (increased for 4B)')
print(f'  Learning rate: {training_args.learning_rate}')
print(f'  Max steps: {training_args.max_steps}')
print('  Expected time: 15-30 minutes on single GPU')

✅ Training configuration ready (4B optimized)
  Output: ./medgemma_nail_disease_model2_finetuned
  Epochs: 3
  Batch size: 8 (increased for 4B)
  Learning rate: 0.0002
  Max steps: 500
  Expected time: 15-30 minutes on single GPU
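
Because a positive `max_steps` takes precedence over `num_train_epochs` in `TrainingArguments`, it helps to work out how much data the run actually covers. The sketch below does that arithmetic for the values in this step, assuming a single GPU; `training_budget` is a hypothetical helper.

```python
import math

def training_budget(train_samples, batch_size, grad_accum, max_steps,
                    logging_steps, eval_steps):
    """Summarize how far a step-capped run gets through the training data."""
    samples_per_step = batch_size * grad_accum
    steps_per_epoch = math.ceil(train_samples / samples_per_step)
    return {
        'samples_seen': max_steps * samples_per_step,
        'steps_per_epoch': steps_per_epoch,
        'epochs_covered': max_steps / steps_per_epoch,
        'train_log_points': max_steps // logging_steps,
        'eval_points': max_steps // eval_steps,
    }

budget = training_budget(train_samples=7000, batch_size=8, grad_accum=1,
                         max_steps=500, logging_steps=10, eval_steps=50)
print(budget)
```

With these settings the run stops after 500 steps, which is about 0.57 of one pass over the 7000 training prompts, and it produces 50 training-loss points but only 10 evaluation points.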


## Step 12: Initialize SFT Trainer

In [None]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=tokenizer,  # reuse the Step 9 tokenizer (older TRL releases name this argument `tokenizer`)
)

print('✅ Trainer initialized (4B model)')

✅ Trainer initialized (4B model)


## Step 13: 🚀 START TRAINING (15-30 minutes with 4B)

In [None]:
print('\n' + '='*60)
print('🚀 STARTING MODEL 2 TRAINING (MedGemma 4B)')
print('Stage: Clinical Explanation Fine-tuning')
print('Expected Duration: 15-30 minutes')
print('='*60)

train_result = trainer.train()

print('\n' + '='*60)
print('✅ TRAINING COMPLETE')
print(f'Final Training Loss: {train_result.training_loss:.4f}')
print('='*60)

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.



🚀 STARTING MODEL 2 TRAINING (MedGemma 4B)
Stage: Clinical Explanation Fine-tuning
Expected Duration: 15-30 minutes


## Step 14: Evaluate & Save Model

In [None]:
# Evaluate on test set
eval_results = trainer.evaluate(test_dataset)
print(f'Test Loss: {eval_results.get("eval_loss", 0):.4f}')

# Save model
model.save_pretrained('./medgemma_nail_disease_model2_finetuned')
tokenizer.save_pretrained('./medgemma_nail_disease_model2_finetuned')
print('\n✅ Model saved to ./medgemma_nail_disease_model2_finetuned')

## Step 15: Extract & Visualize Training Metrics

In [None]:
history = {'train_loss': [], 'eval_loss': []}

try:
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
    if os.path.exists('./logs'):
        for file in sorted(os.listdir('./logs')):
            if 'events.out.tfevents' in file:
                ea = EventAccumulator(os.path.join('./logs', file))
                ea.Reload()
                for tag in ea.Tags().get('scalars', []):
                    # Match only the per-step tags; 'train/train_loss' is the
                    # run-level average and must not be mixed into the curve
                    if tag == 'eval/loss':
                        history['eval_loss'].extend(e.value for e in ea.Scalars(tag))
                    elif tag == 'train/loss':
                        history['train_loss'].extend(e.value for e in ea.Scalars(tag))
except Exception as e:
    print(f'Note: Could not extract tensorboard data: {str(e)[:50]}')

print(f'Extracted: {len(history["train_loss"])} train steps, {len(history["eval_loss"])} eval steps')
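
Training loss is logged every 10 steps and evaluation loss every 50, so comparing the two lists by position pairs values from different steps. A minimal sketch of aligning them by step follows, assuming records shaped like the dictionaries in `trainer.state.log_history`; the helper names and the loss values are made up for illustration.

```python
def losses_by_step(log_history):
    """Split Trainer-style log records into {step: loss} maps."""
    train, evals = {}, {}
    for record in log_history:
        step = record.get('step')
        if 'loss' in record:
            train[step] = record['loss']
        if 'eval_loss' in record:
            evals[step] = record['eval_loss']
    return train, evals

def aligned_gaps(train, evals):
    """Return (step, train_loss, eval_loss, gap) at steps where both exist."""
    return [(s, train[s], evals[s], evals[s] - train[s])
            for s in sorted(evals) if s in train]

# Synthetic records shaped like trainer.state.log_history (values are made up)
log_history = [
    {'loss': 2.0, 'step': 10},
    {'loss': 1.5, 'step': 20},
    {'eval_loss': 1.75, 'step': 20},
    {'loss': 1.0, 'step': 30},
    {'loss': 0.75, 'step': 40},
    {'eval_loss': 1.0, 'step': 40},
    {'train_loss': 1.3125, 'step': 40},  # run-level average, ignored
]
train, evals = losses_by_step(log_history)
rows = aligned_gaps(train, evals)
for row in rows:
    print(row)
```

Reading the in-memory history also avoids depending on TensorBoard event files, at the cost of only working in the session that ran the training.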

## Step 16: 📊 Plot Loss Curves & Overfitting Analysis

In [None]:
train_loss = np.array(history['train_loss']) if history['train_loss'] else np.array([])
eval_loss = np.array(history['eval_loss']) if history['eval_loss'] else np.array([])

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('MedGemma 4B Model 2: Training Metrics & Overfitting Detection', fontsize=14, fontweight='bold')

# Plot 1: Training Loss
if len(train_loss) > 0:
    axes[0, 0].plot(train_loss, marker='o', markersize=3, linewidth=2, color='blue')
    axes[0, 0].set_title('Training Loss Progression', fontweight='bold')
    axes[0, 0].set_xlabel('Training Step')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Validation Loss
if len(eval_loss) > 0:
    axes[0, 1].plot(eval_loss, marker='s', markersize=3, linewidth=2, color='orange')
    axes[0, 1].set_title('Validation Loss Progression', fontweight='bold')
    axes[0, 1].set_xlabel('Evaluation Step')
    axes[0, 1].set_ylabel('Loss')
    axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Train vs Eval with Gap
if len(eval_loss) > 0 and len(train_loss) > 0:
    min_len = min(len(train_loss), len(eval_loss))
    train_aligned = train_loss[-min_len:]
    eval_aligned = eval_loss[-min_len:]

    axes[1, 0].plot(train_aligned, marker='o', label='Train Loss', linewidth=2)
    axes[1, 0].plot(eval_aligned, marker='s', label='Eval Loss', linewidth=2)
    axes[1, 0].fill_between(range(min_len), train_aligned, eval_aligned, alpha=0.2, color='red', label='Overfitting Gap')
    axes[1, 0].set_title('Loss Gap: Train vs Eval', fontweight='bold')
    axes[1, 0].set_xlabel('Step')
    axes[1, 0].set_ylabel('Loss')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Overfitting Metrics Summary
if len(eval_loss) > 0 and len(train_loss) > 0:
    min_len = min(len(train_loss), len(eval_loss))
    train_aligned = train_loss[-min_len:]
    eval_aligned = eval_loss[-min_len:]
    loss_gap = eval_aligned - train_aligned

    avg_gap = np.mean(loss_gap)
    max_gap = np.max(loss_gap)

    if avg_gap < 0.01:
        status = 'MINIMAL OVERFITTING'
    elif avg_gap < 0.05:
        status = 'MILD OVERFITTING'
    else:
        status = 'MODERATE-SEVERE OVERFITTING'

    metrics_text = f'OVERFITTING ANALYSIS\n\nAvg Loss Gap: {avg_gap:.6f}\nMax Loss Gap: {max_gap:.6f}\n\nStatus: {status}\n\nTrain Loss: {train_aligned[-1]:.6f}\nEval Loss: {eval_aligned[-1]:.6f}\n\nImprovement: {(1-eval_aligned[-1]/eval_aligned[0])*100:.1f}%'

    axes[1, 1].text(0.5, 0.5, metrics_text, ha='center', va='center', fontsize=10, family='monospace', bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))
    axes[1, 1].axis('off')

plt.tight_layout()
plt.savefig('model2_overfitting_analysis.png', dpi=150, bbox_inches='tight')
plt.show()
print('✅ Overfitting analysis saved to model2_overfitting_analysis.png')

## Step 17: 🔍 Detailed Overfitting Report

In [None]:
if len(eval_loss) > 0 and len(train_loss) > 0:
    min_len = min(len(train_loss), len(eval_loss))
    train_aligned = train_loss[-min_len:]
    eval_aligned = eval_loss[-min_len:]
    loss_gap = eval_aligned - train_aligned

    print('\n' + '='*60)
    print('🔍 OVERFITTING DETECTION ANALYSIS')
    print('='*60)

    print('\n📊 Loss Gap Statistics:')
    print(f'  Average Gap: {np.mean(loss_gap):.6f}')
    print(f'  Max Gap: {np.max(loss_gap):.6f}')
    print(f'  Min Gap: {np.min(loss_gap):.6f}')

    print('\n📈 Performance Metrics:')
    print(f'  Final Train Loss: {train_aligned[-1]:.6f}')
    print(f'  Final Eval Loss: {eval_aligned[-1]:.6f}')
    print(f'  Loss Improvement: {(1-eval_aligned[-1]/eval_aligned[0])*100:.1f}%')

    if np.mean(loss_gap) < 0.01:
        status = '🟢 MINIMAL OVERFITTING (Excellent!)'
    elif np.mean(loss_gap) < 0.05:
        status = '🟡 MILD OVERFITTING (Good)'
    else:
        status = '🔴 MODERATE-SEVERE OVERFITTING'

    print(f'\n✅ Status: {status}')
    print('='*60)
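
The same gap thresholds (0.01 and 0.05) are repeated in Steps 16, 17, and 18. One way to keep them consistent is a single helper, sketched below; `overfitting_status` is a hypothetical name, and the thresholds are the notebook's own heuristics rather than an established standard.

```python
def overfitting_status(avg_gap, minimal=0.01, mild=0.05):
    """Map an average (eval - train) loss gap to the notebook's three labels."""
    if avg_gap < minimal:
        return 'MINIMAL'
    if avg_gap < mild:
        return 'MILD'
    return 'MODERATE-SEVERE'

def mean(values):
    return sum(values) / len(values)

gaps = [0.02, 0.03, 0.04]  # made-up gaps for illustration
status = overfitting_status(mean(gaps))
print(status)
```

Note that a negative average gap (evaluation loss below training loss) also lands in the `MINIMAL` bucket under this rule.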

## Step 18: Save Training Summary & Metadata

In [None]:
summary = {
    'pipeline_stage': 'Model 2 - Clinical Explanation',
    'model': 'google/medgemma-4b-it',
    'model_size': '4B (Lightweight)',
    'training_type': 'SFT (Supervised Fine-Tuning) with LoRA',
    'lora_rank': 16,
    'lora_alpha': 32,
    'target_modules': ['q_proj', 'v_proj', 'k_proj'],
    'train_samples': len(train_df),
    'val_samples': len(val_df),
    'test_samples': len(test_df),
    'epochs': 3,
    'batch_size': 8,
    'gradient_accumulation_steps': 1,
    'learning_rate': 2e-4,
    'optimizer': 'paged_adamw_8bit',
    'max_steps': 500,
    'quantization': '4-bit (nf4)',
    'training_speed': 'not benchmarked',
    'dataset_source': csv_path,
}

if len(eval_loss) > 0 and len(train_loss) > 0:
    min_len = min(len(train_loss), len(eval_loss))
    train_aligned = train_loss[-min_len:]
    eval_aligned = eval_loss[-min_len:]
    loss_gap = eval_aligned - train_aligned

    summary.update({
        'final_train_loss': float(train_aligned[-1]),
        'final_eval_loss': float(eval_aligned[-1]),
        'avg_loss_gap': float(np.mean(loss_gap)),
        'max_loss_gap': float(np.max(loss_gap)),
        'loss_improvement_percent': float((1-eval_aligned[-1]/eval_aligned[0])*100),
        'overfitting_status': 'MINIMAL' if np.mean(loss_gap) < 0.01 else 'MILD' if np.mean(loss_gap) < 0.05 else 'MODERATE-SEVERE'
    })

with open('model2_training_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print('✅ Training Summary:')
print(json.dumps(summary, indent=2))

## Step 19: Test Inference with Clinical Example

In [None]:
# load_best_model_at_end=True already restored the best checkpoint into `model`,
# so no manual weight reload is needed; switch to eval mode for generation
model.eval()

# Test with clinical examples
test_cases = [
    """CLINICAL ANALYSIS: Nail Disease Diagnosis

PATIENT DEMOGRAPHICS:
Age: 65
Sex: Female

PRIMARY FINDING (from Model 1 - MedSigLIP):
Clubbing

CLINICAL PRESENTATION:
Convex nail beds, increased angle between nail and cuticle, bulbous fingertips. Patient has chronic cough and dyspnea.

INSTRUCTION:
Based on the nail disease finding and clinical presentation above, provide:
1. Detailed explanation of what the nail finding indicates
2. Possible systemic diseases that could cause this nail finding
3. Recommended diagnostic workup and treatment approach

EXPECTED RESPONSE:
"""
]

print('\n' + '='*60)
print('🔍 TEST INFERENCE: Clinical Explanation (4B Model)')
print('='*60)

for i, test_prompt in enumerate(test_cases, 1):
    print(f'\nTest Case {i}:')
    print('-'*60)

    inputs = tokenizer(test_prompt, return_tensors='pt').to(DEVICE)
    outputs = model.generate(**inputs, max_new_tokens=150, do_sample=True, top_p=0.9, temperature=0.7)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)

    print(result)
    print('-'*60)
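
The decoded output in this step includes the full prompt followed by the model's continuation. A minimal sketch of keeping only the generated answer follows, assuming the `EXPECTED RESPONSE:` marker from the prompt template is present in the decoded text; `extract_response` and the sample string are hypothetical.

```python
MARKER = 'EXPECTED RESPONSE:'

def extract_response(generated_text, marker=MARKER):
    """Return the text after the last marker, or the whole text if it is absent."""
    _, found, tail = generated_text.rpartition(marker)
    return tail.strip() if found else generated_text.strip()

# Hypothetical decoded output, shortened for illustration
decoded = (
    'CLINICAL ANALYSIS: Nail Disease Diagnosis\n'
    'INSTRUCTION:\nBased on the nail disease finding ...\n\n'
    'EXPECTED RESPONSE:\n'
    'Nail Finding Explanation: (model text would appear here)'
)
answer = extract_response(decoded)
print(answer)
```

Splitting on the last occurrence of the marker keeps the helper safe if the model happens to repeat earlier parts of the prompt.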

## Step 20: ✅ Complete!

In [None]:
print('\n' + '='*60)
print('✅ MODEL 2 FINE-TUNING & ANALYSIS COMPLETE!')
print('='*60)
print('\n📊 Model Used: MedGemma 4B (Lightweight & Fast)')
print('\n📁 Output Files:')
print('  ✅ medgemma_nail_disease_model2_finetuned/')
print('     - adapter_model.safetensors (LoRA weights)')
print('     - adapter_config.json')
print('     - tokenizer files')
print('  ✅ model2_overfitting_analysis.png (4-subplot visualization)')
print('  ✅ model2_training_summary.json (metrics & config)')
print('  ✅ logs/ (tensorboard data)')
print('\n⚡ Why the 4B Model:')
print('  ✅ Faster to fine-tune than the 27B variant')
print('  ✅ Smaller checkpoint and adapter')
print('  ✅ Lower memory requirements')
print('  ✅ Medical-domain pretraining retained')
print('\n🚀 Next Steps:')
print('  1. Download files from Kaggle Output tab')
print('  2. Use model2 for clinical explanations in your app')
print('  3. Start Stage 3 training with MedGemma 27B')
print('  4. Build mobile/web app integrating all 3 stages')
print('\n📊 Model Performance:')
if len(eval_loss) > 0:
    # eval_loss holds validation losses from training, not the test-set loss
    print(f'  Final Validation Loss: {eval_loss[-1]:.4f}')
print('='*60)