In [None]:
# Some of the Kaggle inputs used by this notebook are private, so
# authenticate with Kaggle before anything tries to download them.
import kagglehub

kagglehub.login()


In [None]:
# Fetch the nail-disease dataset through kagglehub; this cell can be removed
# once the data has been imported.
# Heads-up: this runtime is not Kaggle's own Python image, so some libraries
# used further down may have to be installed manually.
isumenuka_nail_disease_dataset_medsiglip_path = kagglehub.dataset_download(
    'isumenuka/nail-disease-dataset-medsiglip'
)

print('Data source import complete.')


# 🏥 MedSigLIP Fine-tuning for Nail Disease Classification on Kaggle

**Project**: Nail Disease Detection & Classification  
**Model**: Google's MedSigLIP (Medical SigLIP Vision-Language Model)  
**Platform**: Kaggle Notebooks  
**Dataset**: Nail disease images from Kaggle Dataset (7 categories)  
**Created**: January 2026  
**License**: Apache 2.0

---

## ✨ Key Features

- ✅ Direct Kaggle dataset integration (`/kaggle/input/nail-disease-dataset-medsiglip`)
- ✅ No ZIP file extraction required
- ✅ Verifies that the expected train/test directories exist
- ✅ GPU acceleration (P100 or T4)
- ✅ Comprehensive error handling
- ✅ Training progress bars and result plots

---

## 📊 Expected Dataset Structure

```
/kaggle/input/nail-disease-dataset-medsiglip/
├── train/                    (~4,100 images)
│   ├── Acral_Lentiginous_Melanoma/
│   ├── blue_finger/
│   ├── clubbing/
│   ├── Healthy_Nail/
│   ├── Onychogryphosis/
│   ├── pitting/
│   └── psoriasis/
└── test/                     (~180 images)
    ├── Acral_Lentiginous_Melanoma/
    ├── blue_finger/
    ├── clubbing/
    ├── Healthy_Nail/
    ├── Onychogryphosis/
    ├── pitting/
    └── psoriasis/
```

## 🎯 Nail Disease Categories

1. **Acral Lentiginous Melanoma (ALM)** - Black/brown lines under nail
2. **Blue Finger** - Blue discoloration of nail bed
3. **Clubbing** - Bulging, rounded nail appearance
4. **Healthy Nail** - Normal reference
5. **Onychogryphosis** - Thickened, curved nails
6. **Pitting** - Small depressions in nail plate
7. **Psoriasis** - Nail pitting and discoloration from psoriasis

---

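## ✅ Expected Outcomes

- **Training Time**: roughly 30-40 minutes on a P100 GPU (estimate; slower GPUs such as the T4 can take considerably longer)
- **Expected Accuracy**: 88-95% on the test set (target, not a guaranteed result)
- **Checkpoint Size**: `best_model.pt` stores the full MedSigLIP weights (~3.5 GB in fp32) plus the classification head
- **Inference Time**: <500ms per image
- **Mobile Compatible**: not covered here; this notebook does not include a TensorFlow Lite (or ONNX) conversion step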


## 1️⃣ Hugging Face Login (IMPORTANT!)

**Run this first!** You need a Hugging Face token to access MedSigLIP.

1. Get a token: https://huggingface.co/settings/tokens
2. Request access to the gated model: https://huggingface.co/google/medsiglip-448
3. Run the cell below and paste your token into the login widget it displays

In [None]:
from huggingface_hub import notebook_login

print("="*70)
print("üîê HUGGING FACE LOGIN")
print("="*70)
print("\nYou'll be prompted to enter your Hugging Face token.")
print("Get your token: https://huggingface.co/settings/tokens\n")

# In a notebook, notebook_login() renders a login widget (see the VBox output)
# and does not report whether a token was actually accepted, so an
# unconditional "Login successful!" message here would be misleading.
notebook_login()

print("\nPaste your token into the login widget above and wait for its confirmation "
      "before running the model-loading cell.")

üîê HUGGING FACE LOGIN

You'll be prompted to enter your Hugging Face token.
Get your token: https://huggingface.co/settings/tokens



VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv‚Ä¶


‚úÖ Login successful!


## 2️⃣ Install Dependencies

In [None]:
!pip install -q torch torchvision transformers datasets pillow scikit-learn matplotlib tqdm numpy pandas
!pip install -q open-clip-torch
!pip install -q onnx onnxruntime
!pip install -q huggingface_hub

print("‚úÖ All dependencies installed successfully!")

[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m1.5/1.5 MB[0m [31m25.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m44.8/44.8 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m17.4/17.4 MB[0m [31m83.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m46.0/46.0 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m86.8/86.8 kB[0m [31m7.7 MB/

## 3️⃣ Check GPU & Environment

In [None]:
import torch
import sys
from pathlib import Path

# Report interpreter / framework versions and the visible accelerator.
rule = "=" * 70
print(rule)
print("üñ•Ô∏è  ENVIRONMENT INFO")
print(rule)
print(f"Python Version: {sys.version.split()[0]}")
print(f"PyTorch Version: {torch.__version__}")
gpu_present = torch.cuda.is_available()
print(f"GPU Available: {gpu_present}")
if not gpu_present:
    print("‚ö†Ô∏è  WARNING: No GPU detected. Training will be very slow.")
else:
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU Memory: {total_bytes / 1e9:.2f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
print(rule)

üñ•Ô∏è  ENVIRONMENT INFO
Python Version: 3.12.12
PyTorch Version: 2.8.0+cu126
GPU Available: True
GPU Device: Tesla T4
GPU Memory: 15.64 GB
CUDA Version: 12.6


## 4️⃣ Setup Kaggle Dataset Paths

In [None]:
import os
from pathlib import Path

DEFAULT_KAGGLE_DATASET_PATH = '/kaggle/input/nail-disease-dataset-medsiglip'
OUTPUT_PATH = '/kaggle/working/output'

# Prefer the Kaggle input mount. When it is missing (e.g. running outside
# Kaggle), fall back to the local path returned by kagglehub.dataset_download()
# in the import cell at the top, if that cell was run.
_kagglehub_path = globals().get('isumenuka_nail_disease_dataset_medsiglip_path')
if os.path.exists(DEFAULT_KAGGLE_DATASET_PATH) or not _kagglehub_path:
    KAGGLE_DATASET_PATH = DEFAULT_KAGGLE_DATASET_PATH
else:
    KAGGLE_DATASET_PATH = str(_kagglehub_path)

# Create output directory
os.makedirs(OUTPUT_PATH, exist_ok=True)

print("="*70)
print("üìÇ KAGGLE DATASET SETUP")
print("="*70)

# Fail early, with the correct dataset slug, if the data is not available.
if not os.path.exists(KAGGLE_DATASET_PATH):
    print(f"\n‚ùå ERROR: Dataset not found at {KAGGLE_DATASET_PATH}")
    print("\nüìã SOLUTION:")
    print("   1. Add 'nail-disease-dataset-medsiglip' as an input to this notebook")
    print("   2. Go to notebook settings ‚Üí Add data")
    print("   3. Search for 'nail-disease-dataset-medsiglip' and add it")
    print("   4. Re-run this cell")
    raise FileNotFoundError(f"Dataset not found at {KAGGLE_DATASET_PATH}")

print(f"‚úÖ Dataset path found: {KAGGLE_DATASET_PATH}")

# List available Kaggle inputs (the mount only exists on Kaggle itself).
print(f"\nüìç Available Kaggle Inputs:")
if os.path.isdir('/kaggle/input'):
    for item in os.listdir('/kaggle/input'):
        print(f"   ‚Ä¢ {item}")
else:
    print("   (no /kaggle/input directory in this environment)")

# Summarise the top level of the dataset directory.
print(f"\nüîç Looking for train/test directories...")
dataset_contents = os.listdir(KAGGLE_DATASET_PATH)
print(f"\nüìÇ Dataset contents:")
for item in dataset_contents:
    item_path = os.path.join(KAGGLE_DATASET_PATH, item)
    if os.path.isdir(item_path):
        entries = os.listdir(item_path)
        file_count = sum(os.path.isfile(os.path.join(item_path, f)) for f in entries)
        dir_count = sum(os.path.isdir(os.path.join(item_path, d)) for d in entries)
        print(f"   üìÅ {item}/ ({dir_count} subdirs, {file_count} files)")

# Set train and test paths
TRAIN_DATA_PATH = os.path.join(KAGGLE_DATASET_PATH, 'train')
TEST_DATA_PATH = os.path.join(KAGGLE_DATASET_PATH, 'test')

if not os.path.isdir(TRAIN_DATA_PATH) or not os.path.isdir(TEST_DATA_PATH):
    print(f"\n‚ùå ERROR: train/ or test/ directories not found!")
    print(f"   Expected structure:")
    print(f"   {KAGGLE_DATASET_PATH}/")
    print(f"   ‚îú‚îÄ‚îÄ train/ (with class folders)")
    print(f"   ‚îî‚îÄ‚îÄ test/ (with class folders)")
    raise FileNotFoundError("train/ or test/ directories not found")

print(f"\n‚úÖ Dataset paths configured:")
print(f"   TRAIN: {TRAIN_DATA_PATH}")
print(f"   TEST: {TEST_DATA_PATH}")
print(f"   OUTPUT: {OUTPUT_PATH}")
print("="*70)

üìÇ KAGGLE DATASET SETUP
‚úÖ Dataset path found: /kaggle/input/nail-disease-dataset-medsiglip

üìç Available Kaggle Inputs:
   ‚Ä¢ nail-disease-dataset-medsiglip

üîç Looking for train/test directories...

üìÇ Dataset contents:
   üìÅ test/ (7 subdirs, 0 files)
   üìÅ train/ (7 subdirs, 0 files)

‚úÖ Dataset paths configured:
   TRAIN: /kaggle/input/nail-disease-dataset-medsiglip/train
   TEST: /kaggle/input/nail-disease-dataset-medsiglip/test
   OUTPUT: /kaggle/working/output


## 5️⃣ Load & Inspect Dataset

In [None]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

IMAGE_SIZE = 448
BATCH_SIZE = 32
NUM_WORKERS = 2

# Per-channel normalisation expected by the MedSigLIP backbone. SigLIP-family
# image processors rescale pixels to [-1, 1] (mean = std = 0.5); ImageNet
# statistics would feed the pretrained encoder inputs on a different scale.
# NOTE(review): confirm against `processor.image_processor.image_mean` /
# `image_std` once the processor is loaded in the model cell.
MODEL_MEAN = [0.5, 0.5, 0.5]
MODEL_STD = [0.5, 0.5, 0.5]

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=MODEL_MEAN, std=MODEL_STD)
])

val_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MODEL_MEAN, std=MODEL_STD)
])

print("üìÇ Loading datasets...")
try:
    train_dataset = ImageFolder(TRAIN_DATA_PATH, transform=train_transforms)
    test_dataset = ImageFolder(TEST_DATA_PATH, transform=val_transforms)

    # Label indices come from the sorted folder names of each split; if the
    # two splits do not have identical class folders, indices would disagree.
    if test_dataset.classes != train_dataset.classes:
        raise ValueError(
            f"Train/test class folders differ: {train_dataset.classes} vs {test_dataset.classes}"
        )

    print(f"‚úÖ Training samples: {len(train_dataset)}")
    print(f"‚úÖ Test samples: {len(test_dataset)}")
    print(f"‚úÖ Number of classes: {len(train_dataset.classes)}")
    print(f"\nüìã Class labels: {train_dataset.classes}")

    # Count labels from ImageFolder's pre-computed `targets` list instead of
    # iterating the dataset, which would load and augment every image.
    print("\nüìä Class distribution (Training):")
    class_counts = np.bincount(train_dataset.targets, minlength=len(train_dataset.classes))
    for cls_name, count in zip(train_dataset.classes, class_counts):
        print(f"   {cls_name}: {count} images")

except Exception as e:
    print(f"‚ùå Error loading data: {e}")
    print(f"\nüìç Please verify dataset structure:")
    print(f"   ‚îú‚îÄ‚îÄ train/class1/, class2/, ...")
    print(f"   ‚îî‚îÄ‚îÄ test/class1/, class2/, ...")
    raise

üìÇ Loading datasets...
‚úÖ Training samples: 4086
‚úÖ Test samples: 182
‚úÖ Number of classes: 7

üìã Class labels: ['Acral_Lentiginous_Melanoma', 'Healthy_Nail', 'Onychogryphosis', 'blue_finger', 'clubbing', 'pitting', 'psoriasis']

üìä Class distribution (Training):
   Acral_Lentiginous_Melanoma: 735 images
   Healthy_Nail: 323 images
   Onychogryphosis: 677 images
   blue_finger: 603 images
   clubbing: 767 images
   pitting: 639 images
   psoriasis: 342 images


## 6️⃣ Create Data Loaders

In [None]:
# Shared loader settings; only shuffling differs between the two splits.
common_loader_args = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **common_loader_args)
test_loader = DataLoader(test_dataset, shuffle=False, **common_loader_args)

print(f"‚úÖ Train DataLoader: {len(train_loader)} batches")
print(f"‚úÖ Test DataLoader: {len(test_loader)} batches")

# Smoke test: pull a single training batch through the full read/transform pipeline.
print("\nüîç Testing batch loading...")
images, labels = next(iter(train_loader))
print(f"   Batch shape: {images.shape}")
print(f"   Labels: {labels[:5].tolist()}")
print("‚úÖ Data loading successful!")

‚úÖ Train DataLoader: 128 batches
‚úÖ Test DataLoader: 6 batches

üîç Testing batch loading...
   Batch shape: torch.Size([32, 3, 448, 448])
   Labels: [2, 3, 5, 2, 6]
‚úÖ Data loading successful!


## 7️⃣ Load MedSigLIP Model

In [None]:
# CELL 7: Load MedSigLIP Model & Create Text Prompts
from transformers import AutoModel, AutoProcessor
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"üñ•Ô∏è  Using device: {device}")

print("\nüì• Loading MedSigLIP model...")
model_id = "google/medsiglip-448"

try:
    model = AutoModel.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)

    print("‚úÖ MedSigLIP model loaded successfully!")
    print(f"\nüìä Model info:")
    print(f"   Total parameters: {sum(p.numel() for p in model.parameters()):,}")

except Exception as e:
    print(f"‚ùå Error loading model: {e}")
    print(f"\nüìã Troubleshooting:")
    print(f"   1. Make sure you logged in with Hugging Face token")
    print(f"   2. Request access: https://huggingface.co/google/medsiglip-448")
    print(f"   3. Wait a few minutes for access grant")
    raise

# Text prompts keyed by class *folder name*. ImageFolder orders classes by
# sorted folder name (e.g. 'Healthy_Nail' comes before 'blue_finger'), so
# indexing prompts by a hand-written position would attach them to the wrong
# classes. Build the index -> prompt mapping from train_dataset.classes instead.
PROMPTS_BY_CLASS_NAME = {
    'Acral_Lentiginous_Melanoma': "A medical image of acral lentiginous melanoma with black lines under the nail.",
    'blue_finger': "A medical image showing blue discoloration of the fingernail bed.",
    'clubbing': "A medical image of nail clubbing with bulging and rounded nail appearance.",
    'Healthy_Nail': "A medical image of a healthy normal nail.",
    'Onychogryphosis': "A medical image of onychogryphosis with thickened and curved nails.",
    'pitting': "A medical image of nail pitting with small depressions in the nail plate.",
    'psoriasis': "A medical image of psoriatic nails with pitting and discoloration.",
}

# Unknown folder names get a generic prompt rather than a KeyError.
class_prompts = {
    class_idx: PROMPTS_BY_CLASS_NAME.get(
        class_name, f"A medical image of {class_name.replace('_', ' ').lower()}."
    )
    for class_idx, class_name in enumerate(train_dataset.classes)
}

print("\nüìù Generated text prompts for classes:")
for class_idx, prompt in class_prompts.items():
    print(f"   {class_idx}. [{train_dataset.classes[class_idx]}] {prompt[:60]}...")

üñ•Ô∏è  Using device: cuda

üì• Loading MedSigLIP model...


config.json:   0%|          | 0.00/879 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.51G [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/360 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/809 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/798k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/455 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.40M [00:00<?, ?B/s]

‚úÖ MedSigLIP model loaded successfully!

üìä Model info:
   Total parameters: 878,300,338

üìù Generated text prompts for classes:
   0. A medical image of acral lentiginous melanoma with black lin...
   1. A medical image showing blue discoloration of the fingernail...
   2. A medical image of nail clubbing with bulging and rounded na...
   3. A medical image of a healthy normal nail....
   4. A medical image of onychogryphosis with thickened and curved...
   5. A medical image of nail pitting with small depressions in th...
   6. A medical image of psoriatic nails with pitting and discolor...


## 8️⃣ Add Classification Head

In [None]:
class MedSigLIPClassifier(nn.Module):
    """Image classifier: MedSigLIP vision tower followed by an MLP head.

    Every backbone parameter (vision and text tower) is frozen except the last
    ``num_trainable_blocks`` vision encoder layers, which stay trainable for
    domain adaptation. Only the vision tower is used by ``forward``.

    Args:
        medsiglip_model: Loaded MedSigLIP model exposing ``vision_model``
            (e.g. ``AutoModel.from_pretrained("google/medsiglip-448")``).
        num_classes: Number of output logits.
        device: Stored on ``self.device`` for backward compatibility; placement
            is done by calling ``.to(device)`` on the module.
        num_trainable_blocks: How many final vision encoder layers to unfreeze.
            ``0`` freezes the backbone entirely. Defaults to 2, the previous
            hard-coded behaviour.
    """

    def __init__(self, medsiglip_model, num_classes, device='cuda', num_trainable_blocks=2):
        super().__init__()
        self.medsiglip = medsiglip_model
        self.device = device

        # Pooled-embedding width: read it from the vision config when present;
        # 1152 is the value previously hard-coded for medsiglip-448.
        vision_config = getattr(getattr(medsiglip_model, 'config', None), 'vision_config', None)
        embed_dim = getattr(vision_config, 'hidden_size', 1152)

        # Gradually narrowing MLP head with BatchNorm + dropout.
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, 768),
            nn.BatchNorm1d(768),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(768, 384),
            nn.BatchNorm1d(384),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(384, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

        # Freeze the WHOLE backbone first (including the unused text tower) so
        # that `requires_grad` reliably identifies what the optimizer should train.
        for param in self.medsiglip.parameters():
            param.requires_grad = False

        # Re-enable the last `num_trainable_blocks` vision encoder layers.
        encoder = getattr(self.medsiglip.vision_model, 'encoder', None)
        if encoder is not None and num_trainable_blocks > 0:
            for param in encoder.layers[-num_trainable_blocks:].parameters():
                param.requires_grad = True

    def forward(self, images):
        """Return class logits for a batch of preprocessed ``images``.

        Gradients flow through the vision tower, so any unfrozen encoder
        layers are updated together with the head.
        """
        outputs = self.medsiglip.vision_model(pixel_values=images)
        return self.classifier(outputs.pooler_output)


# Wrap the loaded backbone with the classification head and move it to `device`.
num_classes = len(train_dataset.classes)
classifier = MedSigLIPClassifier(medsiglip_model=model, num_classes=num_classes, device=device)
classifier = classifier.to(device)

print(f"‚úÖ Classifier ready! Classes: {num_classes}")

‚úÖ Classifier ready! Classes: 7


## 9️⃣ Setup Training Configuration

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import json

# ‚úÖ Better training hyperparameters
NUM_EPOCHS = 10
LEARNING_RATE = 5e-4
WEIGHT_DECAY = 1e-4

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# ‚úÖ Only optimize classifier + unfrozen layers
optimizer = optim.AdamW(
    classifier.classifier.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)

# ‚úÖ FIXED: ReduceLROnPlateau without 'verbose'
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=3
)

print("‚úÖ Training configuration:")
print(f"   Epochs: {NUM_EPOCHS}")
print(f"   Learning Rate: {LEARNING_RATE}")
print(f"   Batch Size: 32")
print(f"   Optimizer: AdamW")
print(f"   Scheduler: ReduceLROnPlateau")

‚úÖ Training configuration:
   Epochs: 10
   Learning Rate: 0.0005
   Batch Size: 32
   Optimizer: AdamW
   Scheduler: ReduceLROnPlateau


## 1️⃣0️⃣ Training Functions

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def train_epoch(model, train_loader, criterion, optimizer, scheduler, device):
    """Run one optimization pass over ``train_loader``.

    Args:
        model: Classifier to train; it is switched to train mode.
        train_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function applied to ``(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Unused; kept for backward compatibility. The caller steps
            ReduceLROnPlateau once per epoch with a validation metric.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, accuracy)`` where ``avg_loss`` is the mean per-batch loss
        and ``accuracy`` is computed over all samples seen this epoch.
    """
    model.train()
    # Clip every parameter the optimizer may update (head plus any unfrozen
    # backbone layers), not just the head. Params without a gradient are
    # ignored by clip_grad_norm_.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    total_loss = 0.0
    all_preds = []
    all_labels = []

    pbar = tqdm(train_loader, desc="Training")
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    # Guard against an empty loader (division by zero).
    avg_loss = total_loss / max(len(train_loader), 1)
    accuracy = accuracy_score(all_labels, all_preds)

    return avg_loss, accuracy

def evaluate(model, test_loader, criterion, device):
    """Score ``model`` on ``test_loader`` without updating its weights.

    Returns:
        ``(avg_loss, accuracy, precision, recall, f1, all_preds, all_labels)``.
        Precision/recall/F1 are support-weighted averages (0 for undefined
        classes). The last two items are lists of per-sample predicted and
        true class indices.
    """
    model.eval()
    running_loss = 0
    predicted = []
    actual = []

    with torch.no_grad():
        progress = tqdm(test_loader, desc="Evaluating")
        for batch_images, batch_labels in progress:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            batch_loss = criterion(logits, batch_labels)
            running_loss += batch_loss.item()

            predicted.extend(logits.argmax(dim=1).cpu().numpy())
            actual.extend(batch_labels.cpu().numpy())

            progress.set_postfix({'loss': f'{batch_loss.item():.4f}'})

    weighted = dict(average='weighted', zero_division=0)
    return (
        running_loss / len(test_loader),
        accuracy_score(actual, predicted),
        precision_score(actual, predicted, **weighted),
        recall_score(actual, predicted, **weighted),
        f1_score(actual, predicted, **weighted),
        predicted,
        actual,
    )

print("‚úÖ Training functions defined!")

‚úÖ Training functions defined!


## 1️⃣1️⃣ Run Training

In [None]:
history = {
    'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': [],
    'test_precision': [], 'test_recall': [], 'test_f1': []
}

best_accuracy = 0
best_model_path = os.path.join(OUTPUT_PATH, 'best_model.pt')
history_path = os.path.join(OUTPUT_PATH, 'training_history.json')

# NOTE(review): the test split doubles as the validation set here (LR schedule
# and checkpoint selection), so the reported test accuracy is optimistic.
# Consider carving a validation split out of train/.

print("\n" + "="*70)
print("üöÄ STARTING TRAINING")
print("="*70)

for epoch in range(NUM_EPOCHS):
    print(f"\nüìä Epoch {epoch+1}/{NUM_EPOCHS}")

    train_loss, train_acc = train_epoch(classifier, train_loader, criterion, optimizer, scheduler, device)
    history['train_loss'].append(float(train_loss))
    history['train_acc'].append(float(train_acc))

    test_loss, test_acc, test_prec, test_rec, test_f1, preds, labels = evaluate(classifier, test_loader, criterion, device)
    history['test_loss'].append(float(test_loss))
    history['test_acc'].append(float(test_acc))
    history['test_precision'].append(float(test_prec))
    history['test_recall'].append(float(test_rec))
    history['test_f1'].append(float(test_f1))

    print(f"   Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"   Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")
    print(f"   Precision: {test_prec:.4f} | Recall: {test_rec:.4f} | F1: {test_f1:.4f}")

    # ReduceLROnPlateau interprets the metric according to its `mode`: feed it
    # accuracy when maximizing and loss when minimizing, so it never reduces
    # the LR in the wrong direction.
    scheduler.step(test_acc if scheduler.mode == 'max' else test_loss)

    # Always checkpoint after the first epoch so the results cell has a file
    # to load even if accuracy never rises above 0.
    if epoch == 0 or test_acc > best_accuracy:
        best_accuracy = test_acc
        torch.save(classifier.state_dict(), best_model_path)
        print(f"   ‚≠ê Best model saved! (Accuracy: {best_accuracy:.4f})")

    # Persist history every epoch so an interrupted run still leaves a record.
    with open(history_path, 'w') as f:
        json.dump(history, f, indent=4)

print("\n" + "="*70)
print("‚úÖ TRAINING COMPLETED")
print("="*70)

with open(history_path, 'w') as f:
    json.dump(history, f, indent=4)
print(f"\nüíæ Training history saved to: {history_path}")


üöÄ STARTING TRAINING

üìä Epoch 1/10


Training:   1%|          | 1/128 [00:23<25:03, 11.84s/it, loss=2.1339]

## 1️⃣2️⃣ Results & Visualization

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Reload the best checkpoint written by the training loop and re-score the test set.
if not os.path.exists(best_model_path):
    raise FileNotFoundError(f"Checkpoint not found at {best_model_path}. Run the training cell first.")
# map_location lets a GPU-saved checkpoint load on a CPU-only runtime;
# weights_only=True refuses arbitrary pickled objects (a state_dict is tensors only).
classifier.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
classifier.eval()

with torch.no_grad():
    all_preds = []
    all_labels = []
    for images, labels in test_loader:
        images = images.to(device)
        outputs = classifier(images)
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())


def _epoch_axis(series):
    """1-based epoch numbers for a per-epoch metric list (matches the training log)."""
    return range(1, len(series) + 1)


class_names = train_dataset.classes

fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('MedSigLIP Nail Disease Classification - Kaggle Results', fontsize=16, fontweight='bold')

axes[0, 0].plot(_epoch_axis(history['train_loss']), history['train_loss'], label='Train Loss', marker='o')
axes[0, 0].plot(_epoch_axis(history['test_loss']), history['test_loss'], label='Test Loss', marker='s')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Loss over Epochs')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

axes[0, 1].plot(_epoch_axis(history['train_acc']), history['train_acc'], label='Train Accuracy', marker='o')
axes[0, 1].plot(_epoch_axis(history['test_acc']), history['test_acc'], label='Test Accuracy', marker='s')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].set_title('Accuracy over Epochs')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

axes[1, 0].plot(_epoch_axis(history['test_precision']), history['test_precision'], label='Precision', marker='o')
axes[1, 0].plot(_epoch_axis(history['test_recall']), history['test_recall'], label='Recall', marker='s')
axes[1, 0].plot(_epoch_axis(history['test_f1']), history['test_f1'], label='F1 Score', marker='^')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Score')
axes[1, 0].set_title('Precision, Recall, F1 Score')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Fix the label set so the matrix is always num_classes x num_classes and
# lines up with the tick labels, even if a class is absent from the test
# labels and predictions.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[1, 1],
            xticklabels=class_names, yticklabels=class_names)
axes[1, 1].set_title('Confusion Matrix')
axes[1, 1].set_ylabel('True Label')
axes[1, 1].set_xlabel('Predicted Label')
axes[1, 1].tick_params(axis='x', labelrotation=45)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_PATH, 'training_results.png'), dpi=300, bbox_inches='tight')
plt.show()

print("‚úÖ Training results visualization saved!")
print(f"üìÅ Saved to: {os.path.join(OUTPUT_PATH, 'training_results.png')}")

## 1️⃣3️⃣ Summary & Results

In [None]:
from sklearn.metrics import classification_report, accuracy_score

# Accuracy of the reloaded best checkpoint on the test set (same data used
# for checkpoint selection, so it matches best_accuracy up to nondeterminism).
final_accuracy = accuracy_score(all_labels, all_preds)

print("\n" + "="*70)
print("‚úÖ FINE-TUNING COMPLETE!")
print("="*70)

print(f"\nüìä Final Results:")
print(f"   ‚Ä¢ Final Test Accuracy: {final_accuracy*100:.2f}%")
print(f"   ‚Ä¢ Best Accuracy: {best_accuracy*100:.2f}%")
print(f"   ‚Ä¢ Number of Classes: {num_classes}")
print(f"   ‚Ä¢ Training Epochs: {NUM_EPOCHS}")

# Passing `labels` explicitly keeps the report valid (and aligned with
# target_names) even when some class never appears in labels/predictions.
print(f"\nüìã Per-Class Performance:")
print(classification_report(all_labels, all_preds,
                            labels=list(range(num_classes)),
                            target_names=train_dataset.classes,
                            digits=4,
                            zero_division=0))

print(f"\nüìÅ Output Files (in {OUTPUT_PATH}/):")
for file in sorted(os.listdir(OUTPUT_PATH)):
    file_path = os.path.join(OUTPUT_PATH, file)
    if not os.path.isfile(file_path):
        continue  # skip sub-directories; getsize() on them is meaningless
    file_size = os.path.getsize(file_path) / (1024*1024)
    print(f"   ‚Ä¢ {file} ({file_size:.2f} MB)")

print(f"\nüöÄ Next Steps:")
print(f"   1. ‚úÖ Model is saved in {OUTPUT_PATH}/")
print(f"   2. üì• Download files via 'Output' tab")
print(f"   3. üß™ Test on new images")
print(f"   4. üöÄ Deploy to production")
print(f"   5. üì¶ Share on Kaggle Models")

print("\n" + "="*70)
print("üéâ Thank you for using MedSigLIP Fine-tuning on Kaggle!")
print("="*70)