# Advanced MedGemma Fine-Tuning with Overfitting Detection

Features: Loss graphs | Overfitting detection | Metrics | Loss gap analysis

In [None]:
import os, torch, pandas as pd, numpy as np, json
# Treat the presence of a /kaggle directory as the signal that we run on Kaggle.
IS_KAGGLE = os.path.exists('/kaggle')

# Device name and memory may only be queried when CUDA is actually available;
# asking for device 0 on a CPU-only machine raises instead of returning a
# placeholder, so both values are resolved behind a single availability check.
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = f'{torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB'
else:
    gpu_name = 'None'
    gpu_memory = 'N/A'

print(f'GPU: {gpu_name}')
print(f'GPU Memory: {gpu_memory}')
# The non-Kaggle branch previously printed the source text of a print call.
print(f"Environment: {'Kaggle' if IS_KAGGLE else 'Local'}")
print(f'PyTorch: {torch.__version__}')
print('âœ… Environment initialized')

In [None]:
# Notebook shell escape: quietly (-q) install the fine-tuning, quantization,
# metrics and plotting packages used by the cells below.
# NOTE(review): no versions are pinned, so keyword names accepted by the
# trainer/config classes depend on whichever releases pip resolves — consider pinning.
!pip install -q transformers datasets torch bitsandbytes peft trl scikit-learn matplotlib seaborn
print('âœ… Packages installed')

In [None]:
# Dataset container plus model/tokenizer loaders and the quantization config.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# LoRA adapter configuration and the helpers that attach it to a k-bit model.
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
# Supervised fine-tuning trainer and its configuration object.
from trl import SFTTrainer, SFTConfig
# Train/validation/test splitting.
from sklearn.model_selection import train_test_split
# NOTE(review): the metric functions and tqdm are not referenced in the visible cells.
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
from tqdm import tqdm
print('âœ… Libraries imported')

In [None]:
# Read the nail-disease table into a DataFrame and report its row count.
# NOTE(review): the path is fixed to the Kaggle input directory even though
# IS_KAGGLE is computed earlier — a local run presumably needs a different path.
csv_path = '/kaggle/input/nail-disease-classification/nail_diseases.csv'
df = pd.read_csv(csv_path)
print(f'âœ… Loaded {len(df)} samples')

In [None]:
def make_prompt(row, max_findings_chars: int = 200) -> str:
    """Build the single training string for one dataset row.

    Args:
        row: A record exposing ``.get`` (a dict, or the pandas Series handed
            over by ``df.apply(..., axis=1)``). The ``clinical_findings`` and
            ``confirmed_diagnosis`` entries are read; an absent or missing
            entry contributes an empty string.
        max_findings_chars: Maximum number of characters of the findings text
            kept in the prompt. Defaults to 200.

    Returns:
        ``'Findings: <findings>. Diagnosis: <diagnosis>'``.
    """
    def _as_text(value) -> str:
        # A missing cell shows up as None/NaN/NA; passing it through str()
        # would put the literal words "None"/"nan" into the training text.
        if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
            return ''
        return str(value)

    findings = _as_text(row.get('clinical_findings', ''))[:max_findings_chars]
    diagnosis = _as_text(row.get('confirmed_diagnosis', ''))
    return f'Findings: {findings}. Diagnosis: {diagnosis}'

# Add the prompt column: one formatted string per row.
df['text'] = df.apply(make_prompt, axis=1)

# Two-stage split with a fixed seed: hold back 30% of the rows first, then
# divide that hold-out evenly into validation and test.
train_df, temp = train_test_split(
    df,
    random_state=42,
    test_size=0.3,
)
val_df, test_df = train_test_split(
    temp,
    random_state=42,
    test_size=0.5,
)

print(f'Train: {len(train_df)} | Val: {len(val_df)} | Test: {len(test_df)}')

In [None]:
# One repo id shared by the model and tokenizer loads so they cannot drift apart.
# NOTE(review): confirm this exact id exists on the Hub — published checkpoints
# may carry a variant suffix. TODO verify.
MODEL_ID = 'google/medgemma-4b'

# 4-bit NF4 weight quantization with bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map='auto',
    token=True,
)
# Use the same auth setting as the model load; previously only the model
# request carried it, so the two downloads could behave differently.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=True)

# Borrow EOS as padding only when the tokenizer ships without a pad token.
# Overwriting an existing pad token makes padding and end-of-sequence the same
# id, so masking padding out of the labels would also hide the EOS target.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print('âœ… Model loaded')

In [None]:
# Ready the quantized model for k-bit training before adapters are attached.
model = prepare_model_for_kbit_training(model)

# LoRA settings: rank 8, alpha 16, applied to the q_proj and v_proj modules.
lora = LoraConfig(
    task_type='CAUSAL_LM',
    target_modules=['q_proj', 'v_proj'],
    lora_alpha=16,
    r=8,
)

# Wrap the model with the adapter configuration.
model = get_peft_model(model, lora)
print(f'âœ… LoRA configured')

In [None]:
# Wrap the prompt column of each split as a Dataset, in train/val/test order.
train_ds, val_ds, test_ds = (
    Dataset.from_pandas(split[['text']])
    for split in (train_df, val_df, test_df)
)
print(f'Datasets ready: {len(train_ds)} {len(val_ds)} {len(test_ds)}')

In [None]:
import importlib.util
import inspect

# Several keyword names differ between transformers/trl releases
# (evaluation_strategy -> eval_strategy, max_seq_length -> max_length,
# tokenizer -> processing_class, dataset_text_field moving from the trainer to
# the config). The install cell does not pin versions, so use whichever
# spelling the installed classes actually accept.
_config_params = inspect.signature(SFTConfig.__init__).parameters
_trainer_params = inspect.signature(SFTTrainer.__init__).parameters

config_kwargs = dict(
    output_dir='./medgemma_nails',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=2e-4,
    warmup_steps=50,
    logging_steps=20,
    eval_steps=50,
    save_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    logging_dir='./logs',
    # A later cell reads TensorBoard event files from ./logs, so request that
    # backend explicitly when it is importable instead of relying on the
    # release-dependent default set of reporters.
    report_to='tensorboard' if importlib.util.find_spec('tensorboard') else 'none',
)
_eval_key = 'eval_strategy' if 'eval_strategy' in _config_params else 'evaluation_strategy'
config_kwargs[_eval_key] = 'steps'
_length_key = 'max_length' if 'max_length' in _config_params else 'max_seq_length'
config_kwargs[_length_key] = 512

trainer_kwargs = dict(
    model=model,
    train_dataset=train_ds,
    eval_dataset=val_ds,
)
# The name of the text column belongs on the config where that is supported,
# and on the trainer otherwise.
if 'dataset_text_field' in _config_params:
    config_kwargs['dataset_text_field'] = 'text'
else:
    trainer_kwargs['dataset_text_field'] = 'text'
_tokenizer_key = 'processing_class' if 'processing_class' in _trainer_params else 'tokenizer'
trainer_kwargs[_tokenizer_key] = tokenizer

config = SFTConfig(**config_kwargs)
trainer = SFTTrainer(args=config, **trainer_kwargs)
print('âœ… Trainer ready')

In [None]:
# Run supervised fine-tuning with the trainer built above.
print('Starting training...')
result = trainer.train()
# Report the training loss carried by the object train() returned, to 4 decimals.
print(f'âœ… Training done: {result.training_loss:.4f}')

In [None]:
# Score the held-out test split; a missing "eval_loss" entry is shown as 0.
# NOTE(review): this evaluation is presumably logged under the same name as the
# validation evaluations — verify it does not end up in the loss history read later.
eval_res = trainer.evaluate(test_ds)
test_loss = eval_res.get("eval_loss", 0)
print(f'Eval Loss: {test_loss:.4f}')

In [None]:
# Write the model first, then the tokenizer, into the same output directory.
for artifact in (model, tokenizer):
    artifact.save_pretrained('./medgemma_nails')
print('âœ… Model saved')

# Extract & Visualize Training Metrics

In [None]:
import matplotlib.pyplot as plt
# Scalar tags that hold per-step losses. Exact names are matched so that other
# scalars whose tag merely contains "loss" (for example an end-of-run summary
# value) are not mixed into the curves.
# NOTE(review): tag names assumed from the trainer's TensorBoard logging — verify
# against the tags actually present in ./logs.
TRAIN_LOSS_TAGS = {'train/loss', 'loss'}
EVAL_LOSS_TAGS = {'eval/loss', 'eval_loss'}

# Loss values plus the global step each value was logged at, so that consumers
# can align the train and eval curves by step instead of by list position.
history = {'train_loss': [], 'eval_loss': [], 'train_steps': [], 'eval_steps': []}

try:
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
except ImportError as exc:
    # Reading the logs is best-effort: report why it failed and keep going.
    print(f'Could not read tensorboard logs: {exc}')
else:
    try:
        ea = EventAccumulator('./logs')
        ea.Reload()
        for tag in ea.Tags()['scalars']:
            if tag in EVAL_LOSS_TAGS:
                series = 'eval'
            elif tag in TRAIN_LOSS_TAGS:
                series = 'train'
            else:
                continue
            for event in ea.Scalars(tag):
                history[f'{series}_loss'].append(event.value)
                history[f'{series}_steps'].append(event.step)
    except Exception as exc:
        # Top-level notebook boundary: surface the reason rather than hiding it.
        print(f'Could not read tensorboard logs: {exc}')

print(f'Extracted: {len(history["train_loss"])} train, {len(history["eval_loss"])} eval')

In [None]:
def _overfitting_status(gap_avg):
    """Map the mean (eval - train) loss gap to the label shown in the report.

    A gap below 0.01 is 'MINIMAL OVERFITTING', below 0.05 is
    'MILD OVERFITTING', anything else is 'MODERATE-SEVERE'.
    """
    if gap_avg < 0.01:
        return 'MINIMAL OVERFITTING'
    if gap_avg < 0.05:
        return 'MILD OVERFITTING'
    return 'MODERATE-SEVERE'


def _align_losses(train_loss, eval_loss, train_steps=None, eval_steps=None):
    """Return equally long ``(train_aligned, eval_aligned)`` arrays.

    When a step is known for every point and the train steps are strictly
    increasing, the train loss is linearly interpolated at the eval steps, so
    each eval point is compared with the train loss at the same step.
    Otherwise the most recent ``min(len(train), len(eval))`` points of both
    series are paired, which cannot fail on unequal lengths.
    """
    train_loss = np.asarray(train_loss, dtype=float)
    eval_loss = np.asarray(eval_loss, dtype=float)
    train_steps = np.asarray([] if train_steps is None else train_steps, dtype=float)
    eval_steps = np.asarray([] if eval_steps is None else eval_steps, dtype=float)

    steps_usable = (
        len(train_loss) > 0
        and len(train_steps) == len(train_loss)
        and len(eval_steps) == len(eval_loss)
        and bool(np.all(np.diff(train_steps) > 0))
    )
    if steps_usable:
        return np.interp(eval_steps, train_steps, train_loss), eval_loss

    n = min(len(train_loss), len(eval_loss))
    if n == 0:
        # arr[-0:] would return the whole array, so handle "nothing to pair" explicitly.
        return train_loss[:0], eval_loss[:0]
    return train_loss[-n:], eval_loss[-n:]


fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('MedGemma Training: Overfitting Detection & Metrics', fontsize=14, fontweight='bold')

train = np.array(history['train_loss'])
has_eval = bool(history['eval_loss'])
# Without eval points the train curve stands in for the eval curve.
val = np.array(history['eval_loss']) if has_eval else train
train_steps = history.get('train_steps')
eval_steps = history.get('eval_steps') if has_eval else train_steps

# Defaults so later cells can reference these names even when nothing was logged.
gap = np.array([])
status = 'unknown'

if len(val) > 0:
    axes[0, 0].plot(train, label='Train Loss', marker='o', markersize=3)
    axes[0, 0].set_title('Training Loss Progression')
    axes[0, 0].set_xlabel('Step')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    axes[0, 1].plot(val, label='Eval Loss', marker='s', markersize=3, color='orange')
    axes[0, 1].set_title('Validation Loss')
    axes[0, 1].set_xlabel('Eval Step')
    axes[0, 1].set_ylabel('Loss')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    train_aligned, val_aligned = _align_losses(train, val, train_steps, eval_steps)

    if len(val_aligned) > 0:
        gap = val_aligned - train_aligned

        axes[1, 0].plot(train_aligned, marker='o', label='Train', linewidth=2)
        axes[1, 0].plot(val_aligned, marker='s', label='Eval', linewidth=2)
        axes[1, 0].fill_between(range(len(val_aligned)), train_aligned, val_aligned, alpha=0.2, color='red')
        axes[1, 0].set_title('Loss Gap (Train vs Eval)')
        axes[1, 0].set_xlabel('Step')
        axes[1, 0].set_ylabel('Loss')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

        gap_avg = np.mean(gap)
        gap_max = np.max(gap)
        status = _overfitting_status(gap_avg)

        text = f'OVERFITTING ANALYSIS\n\nAvg Gap: {gap_avg:.6f}\nMax Gap: {gap_max:.6f}\n\nStatus: {status}\n\nTrain Loss: {train_aligned[-1]:.6f}\nEval Loss: {val[-1]:.6f}\n\nImprovement: {(1-val[-1]/val[0])*100:.1f}%'
        axes[1, 1].text(0.5, 0.5, text, ha='center', va='center', fontsize=11, family='monospace', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        axes[1, 1].axis('off')

plt.tight_layout()
plt.savefig('overfitting_analysis.png', dpi=150, bbox_inches='tight')
plt.show()
print('âœ… Overfitting analysis saved to overfitting_analysis.png')

In [None]:
# Final train loss, falling back to 0 when no train points were collected.
final_train_loss = float(train[-1]) if len(train) > 0 else 0

# Eval-derived figures exist only when at least one eval point is present;
# otherwise zeros and an 'unknown' status are reported.
if len(val) > 0:
    final_eval_loss = float(val[-1])
    mean_gap = float(np.mean(gap))
    peak_gap = float(np.max(gap))
    status_label = status
else:
    final_eval_loss = 0
    mean_gap = 0
    peak_gap = 0
    status_label = 'unknown'

summary = {
    'train_loss': final_train_loss,
    'eval_loss': final_eval_loss,
    'avg_loss_gap': mean_gap,
    'max_loss_gap': peak_gap,
    'train_samples': len(train_df),
    'overfitting_status': status_label,
}

# Serialize once and reuse the text for both the file and the console echo.
with open('training_summary.json', 'w') as f:
    summary_json = json.dumps(summary, indent=2)
    f.write(summary_json)

print('Training Summary:')
print(summary_json)
print('âœ… Summary saved to training_summary.json')

In [None]:
# Closing banner listing the artifacts written by the earlier cells.
# The last two print calls used to share one line, which does not parse.
banner = '=' * 60
print(banner)
print('âœ… TRAINING COMPLETE!')
print(banner)
print('\nOutput Files:')
print('  - medgemma_nails/ (fine-tuned model)')
print('  - overfitting_analysis.png (loss curves & overfitting detection)')
print('  - training_summary.json (metrics)')
print('  - logs/ (tensorboard data)')
print('\nðŸš€ Download from Output tab on Kaggle')
print(banner)