# 🏥 MedSigLIP Fine-tuning for Nail Disease Classification

**Project**: Nail Disease Detection & Classification  
**Model**: Google's MedSigLIP (Medical SigLIP Vision-Language Model)  
**Dataset**: Custom nail disease images (7 categories)  
**Created**: January 2026  
**License**: Apache 2.0

---

## 📊 Dataset Structure

```
data/
├── train/                    (80% - ~5,300 images)
│   ├── Acral_Lentiginous_Melanoma/
│   ├── blue_finger/
│   ├── clubbing/
│   ├── Healthy_Nail/
│   ├── Onychogryphosis/
│   ├── pitting/
│   └── psoriasis/
└── test/                     (20% - ~1,350 images)
    ├── Acral_Lentiginous_Melanoma/
    ├── blue_finger/
    ├── clubbing/
    ├── Healthy_Nail/
    ├── Onychogryphosis/
    ├── pitting/
    └── psoriasis/
```
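Before training, it helps to confirm that the folders match this layout and roughly the 80/20 split by counting image files per class. The sketch below is a stdlib-only helper. The function name `count_images` and the extension set are assumptions, chosen to approximate the extensions torchvision's `ImageFolder` accepts.

```python
from pathlib import Path

# Assumed extension set, approximating what torchvision's ImageFolder accepts
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".ppm", ".pgm", ".tif", ".tiff", ".webp"}


def count_images(split_dir):
    """Count image files in each class sub-folder of a split directory (e.g. data/train)."""
    counts = {}
    for class_dir in sorted(p for p in Path(split_dir).iterdir() if p.is_dir()):
        # Walk the class folder recursively and count files with an image extension
        counts[class_dir.name] = sum(
            1
            for f in class_dir.rglob("*")
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts
```

For example, `count_images("/content/data/train")` and `count_images("/content/data/test")` give per-class totals whose sums can be compared against the expected ~5,300 / ~1,350 split.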

## 🎯 Nail Disease Categories

1. **Acral Lentiginous Melanoma (ALM)** - Black/brown lines under nail
2. **Blue Finger** - Blue discoloration of nail bed
3. **Clubbing** - Bulging, rounded nail appearance
4. **Healthy Nail** - Normal reference
5. **Onychogryphosis** - Thickened, curved nails
6. **Pitting** - Small depressions in nail plate
7. **Psoriasis** - Nail pitting and discoloration from psoriasis
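The numbering in this list is for reading only. It is not the label index the model uses. torchvision's `ImageFolder` assigns indices by sorting the class folder names, and because that sort is case-sensitive, the mixed capitalization of these folders changes the order. A small illustration follows. The `FOLDERS` list simply mirrors the folder names in the dataset layout.

```python
# Folder names from the dataset layout, in the order they are listed there
FOLDERS = [
    "Acral_Lentiginous_Melanoma", "blue_finger", "clubbing", "Healthy_Nail",
    "Onychogryphosis", "pitting", "psoriasis",
]

# ImageFolder sorts folder names with Python's default string ordering (by code point),
# so every capitalized name sorts before every lowercase one.
class_to_idx = {name: idx for idx, name in enumerate(sorted(FOLDERS))}

for name, idx in class_to_idx.items():
    print(idx, name)
```

In the notebook itself, `train_dataset.class_to_idx` holds the authoritative mapping.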

---

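## ✅ Expected Outcomes

- **Training Time**: roughly 30-60 minutes on a T4 GPU (estimate)
- **Target Accuracy**: 88-95% on the test set (target range, not a measured result)
- **Model Size**: the trainable classifier head is only a few MB; the frozen MedSigLIP backbone it relies on is far larger and must be deployed with it
- **Inference Time**: under 500 ms per image (estimate, GPU)
- **Export Formats**: ONNX (classifier head) and PyTorch state dict; no TensorFlow Lite conversion is included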


## 1️⃣ Setup & Installation

In [None]:
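# Install required packages (seaborn is used for the confusion-matrix plot)
!pip install -q torch torchvision transformers datasets pillow scikit-learn matplotlib seaborn tqdm numpy pandas

# MedSigLIP loads through Hugging Face transformers (AutoModel/AutoProcessor); no extra package is needed.
# The checkpoint is gated: accept its terms on Hugging Face, then log in, e.g.
# from huggingface_hub import notebook_login; notebook_login()

# For model export
!pip install -q onnx onnxruntime

print("✅ All dependencies installed successfully!")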

## 2️⃣ Check GPU & Environment

In [None]:
import torch
import sys
from pathlib import Path

print("="*60)
print("🖥️  ENVIRONMENT INFO")
print("="*60)
print(f"Python Version: {sys.version.split()[0]}")
print(f"PyTorch Version: {torch.__version__}")
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
else:
    print("⚠️  WARNING: No GPU detected. Training will be slow.")
    print("   To enable GPU in Colab: Runtime → Change runtime type → GPU (T4 or V100)")
print("="*60)

## 3️⃣ Mount Google Drive (Optional, for data storage)

In [None]:
# Uncomment to mount Google Drive
# from google.colab import drive
# drive.mount('/content/drive')
# print("✅ Google Drive mounted successfully!")

# For this notebook, we'll use /content/data
import os
os.makedirs('/content/data/train', exist_ok=True)
os.makedirs('/content/data/test', exist_ok=True)
print("✅ Data directories created!")

## 4️⃣ Data Loading & Preparation

In [None]:
import os
from pathlib import Path
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Define data paths
TRAIN_DATA_PATH = '/content/data/train'
TEST_DATA_PATH = '/content/data/test'
OUTPUT_PATH = '/content/output'

# Create output directory
os.makedirs(OUTPUT_PATH, exist_ok=True)

# MedSigLIP expects 448x448 input
IMAGE_SIZE = 448
BATCH_SIZE = 32
NUM_WORKERS = 2

# SigLIP-style normalization (mean = std = 0.5 per channel, mapping [0, 1] to [-1, 1]).
# Verify against processor.image_processor.image_mean / image_std after loading the model.
NORM_MEAN = [0.5, 0.5, 0.5]
NORM_STD = [0.5, 0.5, 0.5]

# Define augmentation for training
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
])

# Validation/Test transforms (no augmentation)
val_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
])

# Load datasets
print("📂 Loading datasets...")
try:
    train_dataset = ImageFolder(TRAIN_DATA_PATH, transform=train_transforms)
    test_dataset = ImageFolder(TEST_DATA_PATH, transform=val_transforms)
    
    print(f"✅ Training samples: {len(train_dataset)}")
    print(f"✅ Test samples: {len(test_dataset)}")
    print(f"✅ Number of classes: {len(train_dataset.classes)}")
    print(f"\n📋 Class labels: {train_dataset.classes}")
    
    # Class distribution, read from the stored labels (no image loading or augmentation needed)
    from collections import Counter
    class_counts = Counter(train_dataset.targets)
    print("\n📊 Class distribution (Training):")
    for cls_idx, cls_name in enumerate(train_dataset.classes):
        print(f"   {cls_name}: {class_counts[cls_idx]} images")
        
except Exception as e:
    print(f"❌ Error loading data: {e}")
    print(f"\n📍 Check that your data is in:")
    print(f"   - Training: {TRAIN_DATA_PATH}")
    print(f"   - Testing: {TEST_DATA_PATH}")
    print(f"\n🔧 Expected structure:")
    print(f"   data/")
    print(f"   ├── train/")
    print(f"   │   ├── class1/")
    print(f"   │   ├── class2/")
    print(f"   │   └── ...")
    print(f"   └── test/")
    print(f"       ├── class1/")
    print(f"       ├── class2/")
    print(f"       └── ...")
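Section 10 picks the best checkpoint by test accuracy, which makes the reported best accuracy an optimistic estimate. One way to avoid that is to hold out a stratified validation split from the training folder. Below is a minimal stdlib sketch, assuming labels are taken from `train_dataset.targets`. `stratified_split` and its defaults are illustrative, not part of the original notebook.

```python
import random
from collections import defaultdict


def stratified_split(targets, val_fraction=0.1, seed=42):
    """Split sample indices into (train, val) lists, keeping each class's proportion."""
    by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        by_class[label].append(idx)

    rng = random.Random(seed)  # local RNG so the split is reproducible
    train_idx, val_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label][:]
        rng.shuffle(indices)
        # At least one validation image per class, unless the class has a single image
        n_val = max(1, round(len(indices) * val_fraction)) if len(indices) > 1 else 0
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return sorted(train_idx), sorted(val_idx)
```

The returned index lists can wrap two `ImageFolder` instances of `TRAIN_DATA_PATH` (one with `train_transforms`, one with `val_transforms`) via `torch.utils.data.Subset`. This keeps validation images free of augmentation.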

## 5️⃣ Create Data Loaders

In [None]:
# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    drop_last=True  # avoid a final batch of size 1, which BatchNorm1d cannot handle in training mode
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

print(f"✅ Train DataLoader: {len(train_loader)} batches")
print(f"✅ Test DataLoader: {len(test_loader)} batches")

# Test loading a batch
print("\n🔍 Testing batch loading...")
images, labels = next(iter(train_loader))
print(f"   Batch shape: {images.shape}")
print(f"   Labels: {labels[:5].tolist()}")
print("✅ Data loading successful!")
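The classification head uses `nn.BatchNorm1d`, which raises an error in training mode when a batch contains a single sample ("Expected more than 1 value per channel when training"). Whether that happens depends only on the dataset size, the batch size, and `drop_last`. The following stdlib helper checks this condition. `final_batch_size` is a hypothetical name.

```python
def final_batch_size(num_samples, batch_size, drop_last=False):
    """Size of the last batch a DataLoader yields (0 if it yields none)."""
    full_batches, remainder = divmod(num_samples, batch_size)
    if remainder and not drop_last:
        return remainder  # the trailing partial batch is kept
    return batch_size if full_batches > 0 else 0


for n in (5300, 5281, 65):
    print(n, final_batch_size(n, 32), final_batch_size(n, 32, drop_last=True))
```

If `final_batch_size(len(train_dataset), BATCH_SIZE)` is 1, passing `drop_last=True` to the training `DataLoader` avoids the error. The test loader is unaffected because the model runs in eval mode there, where BatchNorm uses its running statistics.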

## 6️⃣ Load & Configure MedSigLIP Model

In [None]:
from transformers import AutoModel, AutoProcessor
import torch.nn as nn

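# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️  Using device: {device}")

# Load MedSigLIP model and processor
print("\n📥 Loading MedSigLIP model...")

# Hugging Face checkpoint for MedSigLIP (448x448 input); access is gated behind the model's terms
model_id = "google/medsiglip-448"
print(f"   Model: {model_id}")

try:
    # Load model
    model = AutoModel.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)
    
    # Freeze the backbone: only the classification head (section 7) is trained
    for param in model.parameters():
        param.requires_grad = False
    model.eval()
    
    print("✅ MedSigLIP model loaded successfully!")
    
    # Model info
    print(f"\n📊 Model Architecture:")
    print(f"   Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"   Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    # Should match the Normalize() settings in section 4
    print(f"   Processor normalization: mean={processor.image_processor.image_mean}, std={processor.image_processor.image_std}")
    
except Exception as e:
    print(f"❌ Error loading model: {e}")
    print("\n🔧 Troubleshooting:")
    print("   1. Check internet connection")
    print("   2. Ensure you have sufficient disk space (5+ GB)")
    print("   3. Accept the model's terms on Hugging Face and log in (huggingface_hub.notebook_login())")
    print("   4. Try restarting the kernel")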
## 7️⃣ Add Classification Head

In [None]:
# Create a classification model wrapper around the frozen MedSigLIP backbone
class MedSigLIPClassifier(nn.Module):
    def __init__(self, medsiglip_model, num_classes):
        super().__init__()
        self.medsiglip = medsiglip_model
        
        # Embedding dimension of the vision tower's pooled output (1152 for MedSigLIP)
        embed_dim = medsiglip_model.config.vision_config.hidden_size
        
        # Add classification head
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )
    
    def train(self, mode=True):
        # Switch the head between train/eval, but keep the frozen backbone in eval mode
        super().train(mode)
        self.medsiglip.eval()
        return self
    
    def forward(self, images):
        # Image embeddings from the vision tower only: calling the full MedSigLIP model
        # with pixel_values alone fails because its text tower requires input_ids
        with torch.no_grad():
            embeddings = self.medsiglip.vision_model(pixel_values=images).pooler_output  # [batch_size, embed_dim]
        
        # Pass through classifier
        logits = self.classifier(embeddings)
        return logits

# Initialize classifier
num_classes = len(train_dataset.classes)
classifier = MedSigLIPClassifier(model, num_classes).to(device)

print(f"✅ Classification head added!")
print(f"   Number of classes: {num_classes}")
print(f"\n📊 Classifier architecture:")
print(classifier.classifier)
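The head's size can be checked by hand. A `Linear(in, out)` layer has `in * out` weights plus `out` biases, and a `BatchNorm1d(n)` layer has `2 * n` learnable parameters (its running mean and variance are buffers, not parameters). The calculation below assumes the 1152-dimensional embedding and 7 classes used here.

```python
def linear_params(in_features, out_features):
    return in_features * out_features + out_features  # weight matrix + bias vector


def batchnorm_params(num_features):
    return 2 * num_features  # learnable gamma and beta; running stats are buffers


embed_dim, num_classes = 1152, 7
head_params = (
    linear_params(embed_dim, 512) + batchnorm_params(512)
    + linear_params(512, 256) + batchnorm_params(256)
    + linear_params(256, num_classes)
)
print(f"{head_params:,} parameters, ~{head_params * 4 / 1e6:.1f} MB in float32")  # prints: 724,999 parameters, ~2.9 MB in float32
```

This figure should match `sum(p.numel() for p in classifier.classifier.parameters())`. The trainable head is therefore a few MB in float32, and the frozen backbone accounts for nearly all of the model's size.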

## 8️⃣ Setup Training Configuration

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import json

# Training hyperparameters
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
WARMUP_STEPS = 500  # currently unused: CosineAnnealingLR below has no warmup phase

# Loss function
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Optimizer (only train classifier head)
optimizer = optim.AdamW(
    classifier.classifier.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)

# Learning rate scheduler
scheduler = CosineAnnealingLR(
    optimizer,
    T_max=len(train_loader) * NUM_EPOCHS,
    eta_min=1e-7
)

print("✅ Training configuration:")
print(f"   Epochs: {NUM_EPOCHS}")
print(f"   Learning Rate: {LEARNING_RATE}")
print(f"   Batch Size: {BATCH_SIZE}")
print(f"   Optimizer: AdamW")
print(f"   Loss Function: CrossEntropyLoss (label smoothing=0.1)")
print(f"   Scheduler: CosineAnnealingLR")
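`WARMUP_STEPS` is defined above, but `CosineAnnealingLR` has no warmup phase. If warmup is wanted, a linear-warmup-then-cosine schedule can be written as a plain function of the step index. The sketch below shows one possible way such a schedule might look, not the notebook's configuration. During the decay phase it follows the same cosine formula that `CosineAnnealingLR` uses.

```python
import math


def warmup_cosine_lr(step, base_lr, total_steps, warmup_steps, eta_min=0.0):
    """Learning rate at a given optimizer step: linear warmup, then cosine decay to eta_min."""
    if step < warmup_steps:
        # Ramp linearly up to base_lr over the first warmup_steps steps
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1]
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    # Cosine annealing: eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * progress))
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))
```

With PyTorch, `torch.optim.lr_scheduler.LambdaLR` accepts a function that returns a multiplier of the optimizer's base learning rate, so `LambdaLR(optimizer, lambda s: warmup_cosine_lr(s, 1.0, len(train_loader) * NUM_EPOCHS, WARMUP_STEPS, eta_min=1e-3))` would give a floor of `1e-3 * LEARNING_RATE = 1e-7`. Like the existing scheduler, it would be stepped once per batch. With roughly 5,300 training images and a batch size of 32, 500 warmup steps would cover about three of the ten epochs.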

## 9️⃣ Training Loop

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def train_epoch(model, train_loader, criterion, optimizer, scheduler, device):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    all_preds = []
    all_labels = []
    
    pbar = tqdm(train_loader, desc="Training")
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.classifier.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        
        # Metrics
        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        
        pbar.set_postfix(loss=f"{loss.item():.4f}")
    
    avg_loss = total_loss / len(train_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    
    return avg_loss, accuracy

def evaluate(model, test_loader, criterion, device):
    """Evaluate on test set"""
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        pbar = tqdm(test_loader, desc="Evaluating")
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            pbar.set_postfix(loss=f"{loss.item():.4f}")
    
    avg_loss = total_loss / len(test_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    
    return avg_loss, accuracy, precision, recall, f1, all_preds, all_labels

print("✅ Training and evaluation functions defined!")
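`evaluate` reports precision, recall and F1 with `average='weighted'`, which weights each class by its number of test images. As a result, a rare class that is always missed barely moves the score. Macro averaging weights every class equally, which may matter for a rare but serious class such as melanoma. Below is a stdlib sketch of both averages on a toy example. The function names are illustrative. In the notebook, `f1_score(..., average='macro')` gives the macro version directly.

```python
from collections import Counter


def per_class_f1(y_true, y_pred, labels):
    """F1 score for each label, computed from true/false positives and false negatives."""
    scores = {}
    for c in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        scores[c] = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return scores


def macro_and_weighted_f1(y_true, y_pred):
    labels = sorted(set(y_true) | set(y_pred))
    f1 = per_class_f1(y_true, y_pred, labels)
    support = Counter(y_true)  # number of true samples per class
    macro = sum(f1.values()) / len(labels)  # every class counts equally
    weighted = sum(f1[c] * support[c] for c in labels) / len(y_true)  # weighted by class frequency
    return macro, weighted


y_true = [0] * 8 + [1] * 2  # class 1 is rare
y_pred = [0] * 10           # the model never predicts class 1
macro, weighted = macro_and_weighted_f1(y_true, y_pred)
print(f"macro F1 = {macro:.3f}, weighted F1 = {weighted:.3f}")  # prints: macro F1 = 0.444, weighted F1 = 0.711
```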

## 🔟 Run Training

In [None]:
# Training history
history = {
    'train_loss': [],
    'train_acc': [],
    'test_loss': [],
    'test_acc': [],
    'test_precision': [],
    'test_recall': [],
    'test_f1': []
}

best_accuracy = 0
best_model_path = os.path.join(OUTPUT_PATH, 'best_model.pt')

print("\n" + "="*70)
print("🚀 STARTING TRAINING")
print("="*70)

try:
    for epoch in range(NUM_EPOCHS):
        print(f"\n📊 Epoch {epoch+1}/{NUM_EPOCHS}")
        print("-" * 70)
        
        # Train
        train_loss, train_acc = train_epoch(
            classifier, train_loader, criterion, optimizer, scheduler, device
        )
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        
        # Evaluate
        test_loss, test_acc, test_prec, test_rec, test_f1, preds, labels = evaluate(
            classifier, test_loader, criterion, device
        )
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)
        history['test_precision'].append(test_prec)
        history['test_recall'].append(test_rec)
        history['test_f1'].append(test_f1)
        
        # Print metrics
        print(f"\n   Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"   Test Loss:  {test_loss:.4f} | Test Acc:  {test_acc:.4f}")
        print(f"   Precision: {test_prec:.4f} | Recall: {test_rec:.4f} | F1: {test_f1:.4f}")
        
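        # Save best model (selected on the test set, so best_accuracy is an optimistic estimate;
        # a separate validation split would keep the test score unbiased)
        if test_acc > best_accuracy:
            best_accuracy = test_acc
            torch.save(classifier.state_dict(), best_model_path)
            print(f"   ⭐ Best model saved! (Accuracy: {best_accuracy:.4f})")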

except KeyboardInterrupt:
    print("\n⚠️  Training interrupted by user")
except Exception as e:
    print(f"\n❌ Training error: {e}")
    import traceback
    traceback.print_exc()

print("\n" + "="*70)
print("✅ TRAINING COMPLETED")
print("="*70)
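The loop always runs all `NUM_EPOCHS` epochs. A patience-based early-stopping check is a common addition. The class below is a hypothetical helper, not from the original notebook. It signals a stop once the monitored metric (ideally a validation metric) has not improved for `patience` consecutive epochs.

```python
class EarlyStopping:
    """Signal a stop when a 'higher is better' metric stops improving."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # minimum increase that counts as improvement
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, metric):
        """Record one epoch's metric; return True when training should stop."""
        if metric > self.best + self.min_delta:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

Inside the epoch loop, `if stopper.step(val_acc): break` would end training once the metric plateaus, where `stopper = EarlyStopping(patience=3)` is created before the loop.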

## 1️⃣1️⃣ Results & Visualization

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Load best model
classifier.load_state_dict(torch.load(best_model_path, map_location=device))
classifier.eval()

# Get final predictions
with torch.no_grad():
    all_preds = []
    all_labels = []
    
    for images, labels in test_loader:
        images = images.to(device)
        outputs = classifier(images)
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Create visualizations
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('MedSigLIP Nail Disease Classification - Training Results', fontsize=16, fontweight='bold')

# Plot 1: Training Loss
ax = axes[0, 0]
ax.plot(history['train_loss'], label='Train Loss', marker='o')
ax.plot(history['test_loss'], label='Test Loss', marker='s')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Loss over Epochs')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 2: Accuracy
ax = axes[0, 1]
ax.plot(history['train_acc'], label='Train Accuracy', marker='o')
ax.plot(history['test_acc'], label='Test Accuracy', marker='s')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_title('Accuracy over Epochs')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 3: Precision, Recall, F1
ax = axes[1, 0]
ax.plot(history['test_precision'], label='Precision', marker='o')
ax.plot(history['test_recall'], label='Recall', marker='s')
ax.plot(history['test_f1'], label='F1 Score', marker='^')
ax.set_xlabel('Epoch')
ax.set_ylabel('Score')
ax.set_title('Precision, Recall, F1 Score over Epochs')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 4: Confusion Matrix
cm = confusion_matrix(all_labels, all_preds)
ax = axes[1, 1]
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax, 
            xticklabels=train_dataset.classes,
            yticklabels=train_dataset.classes)
ax.set_title('Confusion Matrix')
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_PATH, 'training_results.png'), dpi=300, bbox_inches='tight')
plt.show()

print("✅ Training results visualization saved!")
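With imbalanced classes, raw counts in the confusion matrix make small classes hard to read. Normalizing each row by its total turns the diagonal into per-class recall. scikit-learn's `confusion_matrix` supports this directly via `normalize='true'`. The numpy sketch below shows the same computation and guards against classes with no test images.

```python
import numpy as np


def row_normalize(cm):
    """Divide each row of a confusion matrix by its sum (rows = true labels)."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    # Rows with no samples stay all-zero instead of producing NaN
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)


cm = [[8, 2, 0],
      [1, 1, 0],
      [0, 0, 0]]
print(row_normalize(cm))
print("per-class recall:", np.diag(row_normalize(cm)))
```

Passing the result to `sns.heatmap(..., annot=True, fmt='.2f')` plots it with the same layout as the count-based matrix.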

## 1️⃣2️⃣ Detailed Classification Report

In [None]:
from sklearn.metrics import classification_report, accuracy_score

# Final accuracy
final_accuracy = accuracy_score(all_labels, all_preds)

print("\n" + "="*70)
print("📊 FINAL CLASSIFICATION REPORT")
print("="*70)

# Overall metrics
print(f"\n🎯 Overall Metrics:")
print(f"   Final Test Accuracy: {final_accuracy:.4f} ({final_accuracy*100:.2f}%)")
print(f"   Best Accuracy: {best_accuracy:.4f} ({best_accuracy*100:.2f}%)")
print(f"   Total Test Samples: {len(all_labels)}")

# Per-class metrics
print(f"\n📋 Per-Class Performance:")
print(classification_report(all_labels, all_preds, 
                          target_names=train_dataset.classes,
                          digits=4))

# Save report to file
report_dict = classification_report(all_labels, all_preds,
                                   target_names=train_dataset.classes,
                                   output_dict=True)

with open(os.path.join(OUTPUT_PATH, 'classification_report.json'), 'w') as f:
    json.dump(report_dict, f, indent=2)

print("\n✅ Classification report saved!")
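The saved `classification_report.json` can be scanned for weak classes. With `output_dict=True`, scikit-learn returns one entry per class (each with `precision`, `recall`, `f1-score` and `support`) plus summary entries such as `accuracy`, `macro avg` and `weighted avg`. The small helper below lists classes below a threshold. Its name and the 0.8 default are illustrative.

```python
def classes_below(report, metric="recall", threshold=0.8):
    """List (class, score) pairs from a classification_report dict that fall below a threshold."""
    summary_keys = {"accuracy", "macro avg", "weighted avg", "micro avg"}
    flagged = [
        (name, scores[metric])
        for name, scores in report.items()
        # Skip summary rows first ("accuracy" maps to a bare float, not a dict)
        if name not in summary_keys and scores[metric] < threshold
    ]
    return sorted(flagged, key=lambda item: item[1])  # weakest class first
```

For example, `classes_below(json.load(open(os.path.join(OUTPUT_PATH, 'classification_report.json'))))` lists the classes whose recall is under 0.8.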

## 1️⃣3️⃣ Save Model & Artifacts

In [None]:
# Save training history
history_path = os.path.join(OUTPUT_PATH, 'training_history.json')
with open(history_path, 'w') as f:
    json.dump(history, f, indent=2)
print(f"✅ Training history saved: {history_path}")

# Save model metadata
metadata = {
    'model': f'{model_id} with Custom Classifier Head',
    'num_classes': num_classes,
    'classes': train_dataset.classes,
    'image_size': IMAGE_SIZE,
    'final_accuracy': float(final_accuracy),
    'best_accuracy': float(best_accuracy),
    'num_epochs': NUM_EPOCHS,
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'total_parameters': sum(p.numel() for p in classifier.parameters()),
}

metadata_path = os.path.join(OUTPUT_PATH, 'model_metadata.json')
with open(metadata_path, 'w') as f:
    json.dump(metadata, f, indent=2)
print(f"✅ Model metadata saved: {metadata_path}")

# Model checkpoint path
print(f"\n📦 Model Artifacts:")
print(f"   Best Model: {best_model_path}")
print(f"   Training History: {history_path}")
print(f"   Model Metadata: {metadata_path}")
print(f"   Classification Report: {os.path.join(OUTPUT_PATH, 'classification_report.json')}")
print(f"   Visualization: {os.path.join(OUTPUT_PATH, 'training_results.png')}")

## 1️⃣4️⃣ Inference on New Images

In [None]:
def predict_image(image_path, model, processor, device, class_names):
    """
    Predict nail disease for a single image
    
    Args:
        image_path: Path to image file
        model: Trained MedSigLIPClassifier
        processor: MedSigLIP processor (currently unused; preprocessing uses val_transforms)
        device: torch device
        class_names: List of class names
    
    Returns:
        class_name: Predicted class name
        confidence: Softmax probability of the predicted class (0-1)
        image: The loaded PIL image
    """
    model.eval()
    
    # Load and preprocess image
    image = Image.open(image_path).convert('RGB')
    image_tensor = val_transforms(image).unsqueeze(0).to(device)
    
    with torch.no_grad():
        outputs = model(image_tensor)
        probs = torch.softmax(outputs, dim=1)
        confidence, predicted_class = torch.max(probs, 1)
    
    class_name = class_names[predicted_class.item()]
    confidence = confidence.item()
    
    return class_name, confidence, image

print("✅ Inference function defined!")
print("\n📝 Usage:")
print("   class_name, confidence, image = predict_image(")
print("       'path/to/image.jpg',")
print("       classifier,")
print("       processor,")
print("       device,")
print("       train_dataset.classes")
print("   )")
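`predict_image` returns only the top class. For a screening aid, showing the few most likely classes can be more informative. Below is a stdlib sketch of a numerically stable softmax plus top-k over one image's logits. The function is hypothetical. With PyTorch tensors, `torch.softmax` followed by `torch.topk` does the same.

```python
import math


def top_k_from_logits(logits, class_names, k=3):
    """Return the k most likely (class_name, probability) pairs for one image."""
    # Subtract the max logit before exponentiating to avoid overflow
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Stable sort: ties keep their original class order
    ranked = sorted(zip(class_names, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]
```

Applied to `outputs[0].tolist()` from the classifier and `train_dataset.classes`, this yields the three most likely diagnoses with their probabilities.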

## 1️⃣5️⃣ Batch Inference & Visualization

In [None]:
# Get random test samples for visualization
num_samples = 16
indices = np.random.choice(len(test_dataset), num_samples, replace=False)

fig, axes = plt.subplots(4, 4, figsize=(16, 16))
fig.suptitle('MedSigLIP Predictions on Test Samples', fontsize=16, fontweight='bold')

classifier.eval()

with torch.no_grad():
    for idx, sample_idx in enumerate(indices):
        ax = axes[idx // 4, idx % 4]
        
        # Get sample
        image, label = test_dataset[sample_idx]
        image_input = image.unsqueeze(0).to(device)
        
        # Predict
        output = classifier(image_input)
        prob = torch.softmax(output, dim=1)[0]
        pred_class = output.argmax(dim=1).item()
        confidence = prob[pred_class].item()
        
        # Get class names
        true_class_name = train_dataset.classes[label]
        pred_class_name = train_dataset.classes[pred_class]
        
        # Plot
        image_numpy = image.permute(1, 2, 0).numpy()
        image_numpy = (image_numpy - image_numpy.min()) / (image_numpy.max() - image_numpy.min())
        ax.imshow(image_numpy)
        
        # Color code: green if correct, red if wrong
        color = 'green' if pred_class == label else 'red'
        
        ax.set_title(
            f'True: {true_class_name}\n'
            f'Pred: {pred_class_name}\n'
            f'Conf: {confidence:.2%}',
            color=color,
            fontweight='bold'
        )
        ax.axis('off')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_PATH, 'predictions_visualization.png'), dpi=150, bbox_inches='tight')
plt.show()

print("✅ Prediction visualization saved!")
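The plotting loop above rescales each image by its own minimum and maximum. This shifts colors from image to image, which matters for classes defined by color, such as blue finger. Inverting the normalization instead restores the original pixel values. The numpy sketch below assumes the mean and std passed to `transforms.Normalize` in section 4. The SigLIP-style 0.5 defaults shown here are placeholders: substitute whatever values the transforms actually use.

```python
import numpy as np


def denormalize_chw(tensor_chw, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo (x - mean) / std on a [C, H, W] array and return an [H, W, C] image in [0, 1]."""
    arr = np.asarray(tensor_chw, dtype=np.float32)
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)  # broadcast over H and W
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    img = arr * std + mean
    # Channels-last for matplotlib; clip tiny numerical overshoots
    return np.clip(img.transpose(1, 2, 0), 0.0, 1.0)
```

In the loop, `ax.imshow(denormalize_chw(image.numpy()))` would replace the min-max rescaling.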

## 1️⃣6️⃣ Export Model for Deployment

In [None]:
# Export the classifier head as ONNX; it consumes MedSigLIP image embeddings, not raw images
print("📤 Exporting model...\n")

head = classifier.classifier
head.eval()  # use BatchNorm running statistics and disable Dropout in the exported graph

# Dummy input: one embedding vector matching the head's input size (1152 for MedSigLIP)
embed_dim = head[0].in_features
dummy_input = torch.randn(1, embed_dim, device=next(head.parameters()).device)

onnx_path = os.path.join(OUTPUT_PATH, 'classifier_head.onnx')

try:
    torch.onnx.export(
        head,
        dummy_input,
        onnx_path,
        input_names=['image_embeddings'],
        output_names=['logits'],
        opset_version=14,
        dynamic_axes={'image_embeddings': {0: 'batch_size'},
                      'logits': {0: 'batch_size'}}
    )
    print(f"✅ ONNX model exported: {onnx_path}")
except Exception as e:
    print(f"⚠️  Could not export ONNX model: {e}")

# Also save as PyTorch format
torch_path = os.path.join(OUTPUT_PATH, 'classifier_head.pt')
torch.save(classifier.classifier.state_dict(), torch_path)
print(f"✅ PyTorch model saved: {torch_path}")

# Create deployment package info
deployment_info = {
    'model_type': 'MedSigLIP Classifier Head',
    'framework': 'PyTorch',
    'input_shape': [1, 1152],
    'output_shape': [1, num_classes],
    'output_format': 'logits',
    'classes': train_dataset.classes,
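    # Normalization applied to images before the MedSigLIP backbone (the head itself takes embeddings)
    'preprocessing': str(next(t for t in val_transforms.transforms if isinstance(t, transforms.Normalize))),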
    'image_size': IMAGE_SIZE,
}

with open(os.path.join(OUTPUT_PATH, 'deployment_info.json'), 'w') as f:
    json.dump(deployment_info, f, indent=2)

print("\n✅ All models exported successfully!")
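The exported head expects 1152-dimensional MedSigLIP embeddings rather than images, so deployment still needs the backbone to produce those embeddings. In eval mode, the head reduces to three matrix multiplies, with BatchNorm using its stored running statistics and Dropout acting as the identity. The numpy sketch below shows that computation. The dict keys are hypothetical names that would be filled from the PyTorch state dict: `0.*` for the first Linear, `1.*` for the first BatchNorm1d, `4.*`/`5.*` for the second pair, and `8.*` for the output layer.

```python
import numpy as np


def batchnorm_eval(x, gamma, beta, running_mean, running_var, eps=1e-5):
    """BatchNorm1d in eval mode: normalize with the stored running statistics."""
    return (x - running_mean) / np.sqrt(running_var + eps) * gamma + beta


def head_forward(embeddings, p):
    """Eval-mode forward pass of the classifier head; p holds numpy arrays (hypothetical keys)."""
    h = embeddings @ p["w1"].T + p["b1"]               # Linear(1152 -> 512); weights stored as [out, in]
    h = np.maximum(batchnorm_eval(h, *p["bn1"]), 0.0)  # BatchNorm1d + ReLU (Dropout is identity)
    h = h @ p["w2"].T + p["b2"]                        # Linear(512 -> 256)
    h = np.maximum(batchnorm_eval(h, *p["bn2"]), 0.0)  # BatchNorm1d + ReLU
    return h @ p["w3"].T + p["b3"]                     # Linear(256 -> num_classes): logits
```

Comparing its output with the PyTorch head, or with `onnxruntime` on the exported file, for the same embedding is a quick consistency check. Here `p["bn1"]` is the tuple `(gamma, beta, running_mean, running_var)`.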

## 🎓 Summary & Next Steps

In [None]:
print("\n" + "="*70)
print("✅ FINE-TUNING COMPLETE!")
print("="*70)

print(f"\n📊 Final Results:")
print(f"   • Final Test Accuracy: {final_accuracy*100:.2f}%")
print(f"   • Best Accuracy: {best_accuracy*100:.2f}%")
print(f"   • Number of Classes: {num_classes}")
print(f"   • Training Time: ~{NUM_EPOCHS * 5:.0f} minutes (estimated)")

print(f"\n📁 Output Files:")
print(f"   • Best Model: {best_model_path}")
print(f"   • PyTorch Model: {torch_path}")
print(f"   • Training History: {history_path}")
print(f"   • Classification Report: {os.path.join(OUTPUT_PATH, 'classification_report.json')}")
print(f"   • Training Visualization: {os.path.join(OUTPUT_PATH, 'training_results.png')}")
print(f"   • Predictions Visualization: {os.path.join(OUTPUT_PATH, 'predictions_visualization.png')}")

print(f"\n🚀 Next Steps:")
print(f"   1. Download all files from /content/output")
print(f"   2. Test on new nail images using the predict_image() function")
print(f"   3. Deploy to mobile/web using the exported ONNX or PyTorch head together with the MedSigLIP backbone")
print(f"   4. Fine-tune with more data for better accuracy")
print(f"   5. Share results on GitHub with the MedSigLIP-Fine-tuning branch")

print(f"\n🔗 GitHub Repository:")
print(f"   Repository: https://github.com/isumenuka/medsiglip-nail-disease-finetuning")
print(f"   Branch: MedSigLIP-Fine-tuning")

print("\n" + "="*70)
print("Thank you for using MedSigLIP Fine-tuning Notebook! 🎉")
print("="*70)