# 🏥 MedSigLIP Fine-tuning for Nail Disease Classification

**Project**: Nail Disease Detection & Classification  
**Model**: Google's MedSigLIP (Medical SigLIP Vision-Language Model)  
**Dataset**: Custom nail disease images (7 categories)  
**Created**: January 2026  
**License**: Apache 2.0

---

## 📊 Dataset Structure

```
Google Drive/
├── data/
│   ├── train/                    (80% - ~5,300 images)
│   │   ├── Acral_Lentiginous_Melanoma/
│   │   ├── blue_finger/
│   │   ├── clubbing/
│   │   ├── Healthy_Nail/
│   │   ├── Onychogryphosis/
│   │   ├── pitting/
│   │   └── psoriasis/
│   └── test/                     (20% - ~1,350 images)
│       ├── Acral_Lentiginous_Melanoma/
│       ├── blue_finger/
│       ├── clubbing/
│       ├── Healthy_Nail/
│       ├── Onychogryphosis/
│       ├── pitting/
│       └── psoriasis/
└── output/                      (Results saved here)
```
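
The layout above can be checked before any training starts. The following is a minimal sketch, assuming images use common extensions (`.jpg`, `.jpeg`, `.png`); the helper names `count_images_per_class` and `check_split_consistency` are hypothetical and not part of the notebook.

```python
from pathlib import Path

# Assumed image extensions; adjust to match your files
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}

def count_images_per_class(split_dir):
    """Return {class_folder_name: image_count} for one split (train or test)."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
            )
    return counts

def check_split_consistency(train_dir, test_dir):
    """Return the class folders that exist in only one of the two splits."""
    train_classes = set(count_images_per_class(train_dir))
    test_classes = set(count_images_per_class(test_dir))
    return train_classes ^ test_classes
```

An empty set from `check_split_consistency` means both splits contain the same class folders, which is what `ImageFolder` needs for the label indices to line up between train and test.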

## 🎯 Nail Disease Categories

1. **Acral Lentiginous Melanoma (ALM)** - Black/brown lines under nail
2. **Blue Finger** - Blue discoloration of nail bed
3. **Clubbing** - Bulging, rounded nail appearance
4. **Healthy Nail** - Normal reference
5. **Onychogryphosis** - Thickened, curved nails
6. **Pitting** - Small depressions in nail plate
7. **Psoriasis** - Nail pitting and discoloration from psoriasis
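
Note that `torchvision.datasets.ImageFolder` assigns label indices by sorting the folder names, and Python's default string sort places uppercase names before lowercase ones. The indices therefore do not follow the numbering in the list above. A small sketch using the folder names from the dataset tree:

```python
# Folder names as they appear in the dataset tree
folders = [
    "Acral_Lentiginous_Melanoma", "blue_finger", "clubbing", "Healthy_Nail",
    "Onychogryphosis", "pitting", "psoriasis",
]

# ImageFolder sorts class folder names with Python's default (case-sensitive) sort
class_to_idx = {name: idx for idx, name in enumerate(sorted(folders))}

for name, idx in class_to_idx.items():
    print(idx, name)
```

With these names, `Healthy_Nail` gets index 1 and `blue_finger` index 3, which matters when reading the confusion matrix or mapping a predicted index back to a diagnosis.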

---

## ✅ Expected Outcomes

- **Training Time**: 30-60 minutes (T4 GPU)
- **Expected Accuracy**: 88-95% on test set
- **Model Size**: ~420 MB (compressed)
- **Inference Time**: <500ms per image
- **Mobile Compatible**: Not verified here (this notebook installs ONNX tooling but contains no export or TensorFlow Lite conversion step)


## 1️⃣ Setup & Installation

In [None]:
# seaborn is imported by the visualization cell, so install it explicitly
!pip install -q torch torchvision transformers datasets pillow scikit-learn matplotlib seaborn tqdm numpy pandas
!pip install -q open-clip-torch
!pip install -q onnx onnxruntime

print("✅ All dependencies installed successfully!")

## 2️⃣ Check GPU & Environment

In [None]:
import torch
import sys
from pathlib import Path

print("="*60)
print("🖥️  ENVIRONMENT INFO")
print("="*60)
print(f"Python Version: {sys.version.split()[0]}")
print(f"PyTorch Version: {torch.__version__}")
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
else:
    print("⚠️  WARNING: No GPU detected. Training will be slow.")
    print("   To enable GPU in Colab: Runtime → Change Runtime Type → GPU (T4 or V100)")
print("="*60)

## 3️⃣ Mount Google Drive & Setup Directories

In [None]:
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

print("✅ Google Drive mounted successfully!")

# Define paths (update GOOGLE_DRIVE_PATH if your data lives in a subfolder)
GOOGLE_DRIVE_PATH = '/content/drive/My Drive'

# You can customize the path if your data is in a specific folder
# Example: GOOGLE_DRIVE_PATH = '/content/drive/My Drive/medsiglip_data'

DATA_FOLDER = os.path.join(GOOGLE_DRIVE_PATH, 'data')
OUTPUT_FOLDER = os.path.join(GOOGLE_DRIVE_PATH, 'output')

# Create output folder if it doesn't exist
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

print(f"📁 Google Drive Paths:")
print(f"   Data Folder: {DATA_FOLDER}")
print(f"   Output Folder: {OUTPUT_FOLDER}")

# Verify data structure
print(f"\n🔍 Checking data structure...")
if os.path.exists(DATA_FOLDER):
    print(f"✅ Data folder found!")
    print(f"   Contents: {os.listdir(DATA_FOLDER)}")
else:
    print(f"⚠️  Data folder not found at {DATA_FOLDER}")
    print(f"   Available items in Google Drive:")
    for item in os.listdir(GOOGLE_DRIVE_PATH)[:10]:
        print(f"       - {item}")

## 4️⃣ Data Loading & Preparation

In [None]:
from pathlib import Path
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

TRAIN_DATA_PATH = os.path.join(DATA_FOLDER, 'train')
TEST_DATA_PATH = os.path.join(DATA_FOLDER, 'test')
OUTPUT_PATH = OUTPUT_FOLDER

IMAGE_SIZE = 448
BATCH_SIZE = 32
NUM_WORKERS = 2

# SigLIP-family image processors normalize with mean 0.5 and std 0.5 per channel
# (pixel values end up in [-1, 1]), not with the ImageNet statistics
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

val_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

from collections import Counter

print("📂 Loading datasets from Google Drive...")
try:
    train_dataset = ImageFolder(TRAIN_DATA_PATH, transform=train_transforms)
    test_dataset = ImageFolder(TEST_DATA_PATH, transform=val_transforms)
    
    print(f"✅ Training samples: {len(train_dataset)}")
    print(f"✅ Test samples: {len(test_dataset)}")
    print(f"✅ Number of classes: {len(train_dataset.classes)}")
    print(f"\n📋 Class labels: {train_dataset.classes}")
    
    print("\n📊 Class distribution (Training):")
    # Count labels from ImageFolder's precomputed targets; iterating the dataset
    # would decode and augment every image once per class
    class_counts = Counter(train_dataset.targets)
    for cls_idx, cls_name in enumerate(train_dataset.classes):
        print(f"   {cls_name}: {class_counts[cls_idx]} images")
        
except Exception as e:
    print(f"❌ Error loading data: {e}")
    print(f"\n📍 Make sure your Google Drive has:")
    print(f"   /data/train/class1/, /data/train/class2/, ...")
    print(f"   /data/test/class1/, /data/test/class2/, ...")

## 5️⃣ Create Data Loaders

In [None]:
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

print(f"✅ Train DataLoader: {len(train_loader)} batches")
print(f"✅ Test DataLoader: {len(test_loader)} batches")

print("\n🔍 Testing batch loading...")
images, labels = next(iter(train_loader))
print(f"   Batch shape: {images.shape}")
print(f"   Labels: {labels[:5].tolist()}")
print("✅ Data loading successful!")
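
As a sanity check on the batch counts printed above: with the default `drop_last=False`, a `DataLoader` yields the sample count divided by the batch size, rounded up. A small sketch using the approximate split sizes from the dataset description (your exact counts will differ); `num_batches` is a hypothetical helper:

```python
import math

def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for a dataset of num_samples items."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

# Approximate split sizes from the dataset description, batch size 32
print(num_batches(5300, 32))   # 166
print(num_batches(1350, 32))   # 43
```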

## 6️⃣ Load MedSigLIP Model

In [None]:
from transformers import AutoModel, AutoProcessor
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️  Using device: {device}")

print("\n📥 Loading MedSigLIP model...")
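# MedSigLIP is published on Hugging Face as a gated 448x448 checkpoint;
# accept the model terms and authenticate before loading
model_id = "google/medsiglip-448"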

try:
    model = AutoModel.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)
    
    print("✅ MedSigLIP model loaded successfully!")
    print(f"\n📊 Model info:")
    print(f"   Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    
except Exception as e:
    print(f"❌ Error loading model: {e}")

## 7️⃣ Add Classification Head

In [None]:
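class MedSigLIPClassifier(nn.Module):
    """Frozen MedSigLIP image encoder followed by a trainable MLP head."""

    def __init__(self, medsiglip_model, num_classes):
        super().__init__()
        self.medsiglip = medsiglip_model
        # Freeze the backbone; only the classification head is trained
        for param in self.medsiglip.parameters():
            param.requires_grad = False
        embed_dim = 1152  # expected size of the pooled image embedding
        
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )
    
    def forward(self, images):
        # Run only the vision tower: the full model's forward also expects text inputs
        with torch.no_grad():
            outputs = self.medsiglip.vision_model(pixel_values=images)
            embeddings = outputs.pooler_output
        
        logits = self.classifier(embeddings)
        return logits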

num_classes = len(train_dataset.classes)
classifier = MedSigLIPClassifier(model, num_classes).to(device)

print(f"✅ Classification head added!")
print(f"   Number of classes: {num_classes}")
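
Because the backbone runs under `torch.no_grad()`, only the head is trained. As a rough sketch of how small that trainable part is, the parameter count of the head can be worked out by hand, assuming the 1152-dimensional embedding and 7 classes used above (the helper names are hypothetical):

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features

def batchnorm_params(num_features):
    """Learnable scale and shift of BatchNorm1d (running stats are buffers, not parameters)."""
    return 2 * num_features

def head_params(embed_dim=1152, num_classes=7):
    """Trainable parameters in the Linear-BatchNorm-ReLU-Dropout head."""
    return (
        linear_params(embed_dim, 512) + batchnorm_params(512)
        + linear_params(512, 256) + batchnorm_params(256)
        + linear_params(256, num_classes)
    )

print(head_params())  # 724999
```

In the notebook itself, `sum(p.numel() for p in classifier.classifier.parameters())` should report the same figure.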

## 8️⃣ Setup Training Configuration

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import json

NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

optimizer = optim.AdamW(
    classifier.classifier.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)

scheduler = CosineAnnealingLR(
    optimizer,
    T_max=len(train_loader) * NUM_EPOCHS,
    eta_min=1e-7
)

print("✅ Training configuration:")
print(f"   Epochs: {NUM_EPOCHS}")
print(f"   Learning Rate: {LEARNING_RATE}")
print(f"   Batch Size: {BATCH_SIZE}")
print(f"   Optimizer: AdamW")
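
The `label_smoothing=0.1` argument changes the target distribution that the loss compares against. As a rough sketch in plain Python (not the PyTorch implementation itself), assuming the formulation documented for `nn.CrossEntropyLoss`, the target puts `1 - eps` on the true class plus `eps / K` spread uniformly over all `K` classes. The function name `smoothed_cross_entropy` is hypothetical.

```python
import math

def smoothed_cross_entropy(logits, target, eps=0.1):
    """Cross-entropy for one sample against a label-smoothed target distribution."""
    # Numerically stable log-softmax
    max_logit = max(logits)
    log_sum = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    log_probs = [z - log_sum for z in logits]

    num_classes = len(logits)
    loss = 0.0
    for k, log_p in enumerate(log_probs):
        # Smoothed target: eps / K everywhere, plus 1 - eps on the true class
        q = eps / num_classes + (1.0 - eps if k == target else 0.0)
        loss -= q * log_p
    return loss

# A confident, correct prediction is penalized slightly more with smoothing
confident_logits = [10.0, 0.0, 0.0]
print(smoothed_cross_entropy(confident_logits, target=0, eps=0.0))
print(smoothed_cross_entropy(confident_logits, target=0, eps=0.1))
```

The second value is larger than the first, which is the intended effect: smoothing discourages the head from pushing logits toward extreme confidence.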

## 9️⃣ Training Functions

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def train_epoch(model, train_loader, criterion, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    all_preds = []
    all_labels = []
    
    pbar = tqdm(train_loader, desc="Training")
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.classifier.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        
        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / len(train_loader), accuracy_score(all_labels, all_preds)

def evaluate(model, test_loader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        pbar = tqdm(test_loader, desc="Evaluating")
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    avg_loss = total_loss / len(test_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    
    return avg_loss, accuracy, precision, recall, f1, all_preds, all_labels

print("✅ Training functions defined!")
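
`train_epoch` calls `scheduler.step()` after every batch, which is why `T_max` was set to `len(train_loader) * NUM_EPOCHS` rather than to the number of epochs. Below is a minimal sketch of the resulting learning-rate curve, assuming the closed-form expression documented for `CosineAnnealingLR` and a hypothetical 166 batches per epoch (the real count depends on the dataset size):

```python
import math

def cosine_lr(step, total_steps, base_lr=1e-4, eta_min=1e-7):
    """Closed-form cosine annealing: base_lr at step 0, eta_min at total_steps."""
    cosine = 0.5 * (1 + math.cos(math.pi * step / total_steps))
    return eta_min + (base_lr - eta_min) * cosine

# Hypothetical run: 166 batches per epoch for 10 epochs
total_steps = 166 * 10
for step in (0, total_steps // 2, total_steps):
    print(step, f"{cosine_lr(step, total_steps):.3e}")
```

The rate starts at `LEARNING_RATE`, passes the midpoint between the two bounds halfway through training, and reaches `eta_min` on the final batch.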

## 🔟 Run Training

In [None]:
history = {
    'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': [],
    'test_precision': [], 'test_recall': [], 'test_f1': []
}

best_accuracy = 0
best_model_path = os.path.join(OUTPUT_PATH, 'best_model.pt')

print("\n" + "="*70)
print("🚀 STARTING TRAINING")
print("="*70)

for epoch in range(NUM_EPOCHS):
    print(f"\n📊 Epoch {epoch+1}/{NUM_EPOCHS}")
    
    train_loss, train_acc = train_epoch(classifier, train_loader, criterion, optimizer, scheduler, device)
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    
    test_loss, test_acc, test_prec, test_rec, test_f1, preds, labels = evaluate(classifier, test_loader, criterion, device)
    history['test_loss'].append(test_loss)
    history['test_acc'].append(test_acc)
    history['test_precision'].append(test_prec)
    history['test_recall'].append(test_rec)
    history['test_f1'].append(test_f1)
    
    print(f"   Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"   Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")
    print(f"   Precision: {test_prec:.4f} | Recall: {test_rec:.4f} | F1: {test_f1:.4f}")
    
    if test_acc > best_accuracy:
        best_accuracy = test_acc
        torch.save(classifier.state_dict(), best_model_path)
        print(f"   ⭐ Best model saved! (Accuracy: {best_accuracy:.4f})")

print("\n" + "="*70)
print("✅ TRAINING COMPLETED")
print("="*70)

## 1️⃣1️⃣ Results & Comprehensive Visualization

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, roc_curve, auc, roc_auc_score
from sklearn.preprocessing import label_binarize
import seaborn as sns

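# map_location lets the checkpoint load even if the runtime device has changed
classifier.load_state_dict(torch.load(best_model_path, map_location=device))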
classifier.eval()

with torch.no_grad():
    all_preds = []
    all_labels = []
    all_probs = []
    for images, labels in test_loader:
        images = images.to(device)
        outputs = classifier(images)
        probs = torch.softmax(outputs, dim=1)
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

all_probs = np.array(all_probs)
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

# Create comprehensive visualization
fig = plt.figure(figsize=(20, 16))
gs = fig.add_gridspec(3, 3, hspace=0.35, wspace=0.3)

fig.suptitle('MedSigLIP Nail Disease Classification - Comprehensive Analysis', 
             fontsize=18, fontweight='bold', y=0.995)

# 1. Loss Curves
ax1 = fig.add_subplot(gs[0, 0])
ax1.plot(history['train_loss'], label='Train Loss', marker='o', linewidth=2)
ax1.plot(history['test_loss'], label='Test Loss', marker='s', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=11)
ax1.set_ylabel('Loss', fontsize=11)
ax1.set_title('Training vs Test Loss', fontsize=12, fontweight='bold')
ax1.legend(loc='best')
ax1.grid(True, alpha=0.3)

# 2. Accuracy Curves
ax2 = fig.add_subplot(gs[0, 1])
ax2.plot(history['train_acc'], label='Train Accuracy', marker='o', linewidth=2)
ax2.plot(history['test_acc'], label='Test Accuracy', marker='s', linewidth=2)
ax2.set_xlabel('Epoch', fontsize=11)
ax2.set_ylabel('Accuracy', fontsize=11)
ax2.set_title('Training vs Test Accuracy', fontsize=12, fontweight='bold')
ax2.legend(loc='best')
ax2.grid(True, alpha=0.3)

# 3. Precision-Recall-F1
ax3 = fig.add_subplot(gs[0, 2])
ax3.plot(history['test_precision'], label='Precision', marker='o', linewidth=2)
ax3.plot(history['test_recall'], label='Recall', marker='s', linewidth=2)
ax3.plot(history['test_f1'], label='F1 Score', marker='^', linewidth=2)
ax3.set_xlabel('Epoch', fontsize=11)
ax3.set_ylabel('Score', fontsize=11)
ax3.set_title('Precision, Recall & F1 Score', fontsize=12, fontweight='bold')
ax3.legend(loc='best')
ax3.grid(True, alpha=0.3)

# 4. Confusion Matrix
ax4 = fig.add_subplot(gs[1, 0:2])
cm = confusion_matrix(all_labels, all_preds)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax4, cbar_kws={'label': 'Count'},
            xticklabels=train_dataset.classes, yticklabels=train_dataset.classes,
            annot_kws={'size': 9})
ax4.set_title('Confusion Matrix', fontsize=12, fontweight='bold')
ax4.set_ylabel('True Label', fontsize=11)
ax4.set_xlabel('Predicted Label', fontsize=11)

# 5. Per-Class Accuracy
ax5 = fig.add_subplot(gs[1, 2])
per_class_acc = cm.diagonal() / cm.sum(axis=1)
colors = plt.cm.viridis(np.linspace(0, 1, len(per_class_acc)))
bars = ax5.barh(train_dataset.classes, per_class_acc, color=colors)
ax5.set_xlabel('Accuracy', fontsize=11)
ax5.set_title('Per-Class Accuracy', fontsize=12, fontweight='bold')
ax5.set_xlim([0, 1])
for i, bar in enumerate(bars):
    ax5.text(per_class_acc[i] + 0.02, i, f'{per_class_acc[i]:.2%}', va='center', fontsize=9)

# 6. Prediction Confidence Distribution
ax6 = fig.add_subplot(gs[2, 0])
max_probs = np.max(all_probs, axis=1)
ax6.hist(max_probs, bins=30, color='skyblue', edgecolor='black', alpha=0.7)
ax6.axvline(np.mean(max_probs), color='red', linestyle='--', linewidth=2, label=f'Mean: {np.mean(max_probs):.3f}')
ax6.set_xlabel('Confidence', fontsize=11)
ax6.set_ylabel('Frequency', fontsize=11)
ax6.set_title('Prediction Confidence Distribution', fontsize=12, fontweight='bold')
ax6.legend()
ax6.grid(True, alpha=0.3, axis='y')

# 7. Correct vs Incorrect Predictions
ax7 = fig.add_subplot(gs[2, 1])
correct = (all_preds == all_labels).sum()
incorrect = len(all_labels) - correct
colors_pie = ['#2ecc71', '#e74c3c']
ax7.pie([correct, incorrect], labels=['Correct', 'Incorrect'], 
        autopct='%1.1f%%', colors=colors_pie, startangle=90,
        textprops={'fontsize': 11, 'weight': 'bold'})
ax7.set_title(f'Prediction Breakdown\n(Total: {len(all_labels)})', fontsize=12, fontweight='bold')

# 8. Training Loss Trend
ax8 = fig.add_subplot(gs[2, 2])
ax8.plot(history['train_loss'], color='#3498db', linewidth=2.5, label='Training Progress')
ax8.fill_between(range(len(history['train_loss'])), history['train_loss'], alpha=0.3, color='#3498db')
ax8.set_xlabel('Epoch', fontsize=11)
ax8.set_ylabel('Loss', fontsize=11)
ax8.set_title('Training Loss Trend', fontsize=12, fontweight='bold')
ax8.grid(True, alpha=0.3)

plt.savefig(os.path.join(OUTPUT_PATH, 'comprehensive_analysis.png'), dpi=300, bbox_inches='tight')
plt.show()

print("✅ Comprehensive analysis visualization saved!")
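
The "Per-Class Accuracy" panel divides the diagonal of the confusion matrix by its row sums, which is per-class recall: the share of each true class that was predicted correctly. A small sketch in plain Python with a hypothetical 3-class matrix (illustrative numbers, not results from this model), adding a guard for classes that have no test samples:

```python
def per_class_recall(confusion):
    """Diagonal divided by row sums; rows are true labels, columns are predictions."""
    result = []
    for i, row in enumerate(confusion):
        total = sum(row)
        # Guard against classes with no samples in the evaluated split
        result.append(row[i] / total if total else float("nan"))
    return result

# Hypothetical 3-class confusion matrix, for illustration only
cm_example = [
    [8, 1, 1],
    [2, 6, 2],
    [0, 0, 10],
]
print(per_class_recall(cm_example))  # [0.8, 0.6, 1.0]
```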

## 1️⃣2️⃣ ROC-AUC Curves & Advanced Metrics

In [None]:
from itertools import cycle
from sklearn.metrics import roc_curve, auc

# Prepare data for ROC-AUC
y_bin = label_binarize(all_labels, classes=range(num_classes))

# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()
roc_auc = dict()

for i in range(num_classes):
    fpr[i], tpr[i], _ = roc_curve(y_bin[:, i], all_probs[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_bin.ravel(), all_probs.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

fig, axes = plt.subplots(2, 2, figsize=(16, 12))
fig.suptitle('Advanced Model Evaluation Metrics', fontsize=16, fontweight='bold')

# ROC Curves for all classes
ax = axes[0, 0]
colors = cycle(plt.cm.rainbow(np.linspace(0, 1, num_classes)))
for i, color in zip(range(num_classes), colors):
    ax.plot(fpr[i], tpr[i], color=color, lw=2,
            label=f'{train_dataset.classes[i]} (AUC = {roc_auc[i]:.3f})')
ax.plot([0, 1], [0, 1], 'k--', lw=2, label='Random Classifier')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate', fontsize=11)
ax.set_ylabel('True Positive Rate', fontsize=11)
ax.set_title('ROC Curves - Per Class', fontsize=12, fontweight='bold')
ax.legend(loc="lower right", fontsize=9)
ax.grid(True, alpha=0.3)

# Micro-average ROC
ax = axes[0, 1]
ax.plot(fpr["micro"], tpr["micro"], label=f'Micro-average (AUC = {roc_auc["micro"]:.3f})',
        color='deeppink', lw=3)
ax.plot([0, 1], [0, 1], 'k--', lw=2, label='Random Classifier')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate', fontsize=11)
ax.set_ylabel('True Positive Rate', fontsize=11)
ax.set_title('Micro-average ROC Curve', fontsize=12, fontweight='bold')
ax.legend(loc="lower right", fontsize=11)
ax.grid(True, alpha=0.3)

# AUC Scores per class
ax = axes[1, 0]
auc_scores = [roc_auc[i] for i in range(num_classes)]
colors_auc = plt.cm.viridis(np.linspace(0, 1, num_classes))
bars = ax.barh(train_dataset.classes, auc_scores, color=colors_auc)
ax.set_xlabel('AUC Score', fontsize=11)
ax.set_title('AUC Scores by Class', fontsize=12, fontweight='bold')
ax.set_xlim([0, 1])
for i, (bar, score) in enumerate(zip(bars, auc_scores)):
    ax.text(score + 0.02, i, f'{score:.3f}', va='center', fontsize=10)

# Confidence by prediction correctness
ax = axes[1, 1]
correct_mask = all_preds == all_labels
correct_conf = max_probs[correct_mask]
incorrect_conf = max_probs[~correct_mask]
ax.violinplot([correct_conf, incorrect_conf], positions=[1, 2], showmeans=True, showmedians=True)
ax.set_xticks([1, 2])
ax.set_xticklabels(['Correct', 'Incorrect'])
ax.set_ylabel('Confidence', fontsize=11)
ax.set_title('Confidence Distribution by Prediction', fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_PATH, 'roc_auc_analysis.png'), dpi=300, bbox_inches='tight')
plt.show()

print("✅ ROC-AUC analysis saved!")
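
One way to read the per-class AUC values: for a one-vs-rest problem, the area under the ROC curve equals the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one. The sketch below computes it that way in plain Python, on hypothetical labels and scores. It is an O(n²) illustration of the idea, not a replacement for `sklearn.metrics.roc_curve` and `auc`.

```python
def pairwise_auc(labels, scores):
    """AUC as the probability that a positive outranks a negative (ties count 0.5)."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))

# Hypothetical one-vs-rest labels and predicted probabilities for one class
example_labels = [1, 1, 0, 0, 1, 0]
example_scores = [0.9, 0.6, 0.7, 0.2, 0.8, 0.1]
print(pairwise_auc(example_labels, example_scores))
```

Here 8 of the 9 positive-negative pairs are ranked correctly, so the result is 8/9.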

## 1️⃣3️⃣ Summary & Final Results

In [None]:
from sklearn.metrics import classification_report, accuracy_score

final_accuracy = accuracy_score(all_labels, all_preds)

# Save training history
history_path = os.path.join(OUTPUT_PATH, 'training_history.json')
with open(history_path, 'w') as f:
    json.dump(history, f, indent=2)

print("\n" + "="*70)
print("✅ FINE-TUNING COMPLETE!")
print("="*70)

print(f"\n📊 Final Results:")
print(f"   • Final Test Accuracy: {final_accuracy*100:.2f}%")
print(f"   • Best Accuracy: {best_accuracy*100:.2f}%")
print(f"   • Number of Classes: {num_classes}")
print(f"   • Micro-average AUC: {roc_auc['micro']:.4f}")
print(f"   • Mean Confidence: {np.mean(max_probs):.4f}")

print(f"\n📋 Per-Class Performance:")
print(classification_report(all_labels, all_preds,
                          target_names=train_dataset.classes,
                          digits=4))

print(f"\n📁 Output Files Generated (in Google Drive):")
print(f"   • Best Model: {best_model_path}")
print(f"   • Comprehensive Analysis: {os.path.join(OUTPUT_PATH, 'comprehensive_analysis.png')}")
print(f"   • ROC-AUC Analysis: {os.path.join(OUTPUT_PATH, 'roc_auc_analysis.png')}")
print(f"   • Training History: {history_path}")

print(f"\n🚀 Next Steps:")
print(f"   1. Check your Google Drive /output folder for results")
print(f"   2. Review the generated visualizations")
print(f"   3. Download the best model")
print(f"   4. Deploy model to production")

print("\n" + "="*70)
print("🎉 Thank you for using MedSigLIP Fine-tuning!")
print("="*70)