# 🩺 SwasthVedha - Skin Disease Detection Model Training

**Dataset:** Skin-Disease-Dataset (Kaggle)

**Goal:** Train a high-accuracy CNN model that:
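- ✅ Learns to distinguish every disease class in the dataset
- ✅ Reduces overfitting (dropout, data augmentation, early stopping)
- ✅ Avoids underfitting (pretrained features, learning-rate scheduling)
- ✅ Uses transfer learning (ResNet50 pre-trained on ImageNet)
- ✅ Supports uncertainty handling (confidence thresholds applied in the backend)
- ✅ Flags low-confidence predictions as uncertain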

**Target Accuracy:** 85-95%

## 📦 Step 1: Setup & Install Dependencies

In [None]:
# Install required packages (most are preinstalled on Colab)
!pip install -q kaggle torch torchvision scikit-learn seaborn

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import copy
from pathlib import Path
import json
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

# Check GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

## 🔑 Step 2: Kaggle API Setup

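**Instructions** (the cell below uses `google.colab`, so run it in Google Colab):
1. Go to https://www.kaggle.com/settings
2. Scroll to the "API" section
3. Click "Create New API Token" to download `kaggle.json`
4. Run the cell below and upload that `kaggle.json` file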

In [None]:
from google.colab import files

# Upload kaggle.json
print("Please upload your kaggle.json file:")
uploaded = files.upload()

# Setup Kaggle credentials
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

print("✅ Kaggle API configured!")
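Outside Colab, where `files.upload()` is unavailable, the same credential setup can be done in plain Python. A minimal sketch, assuming you already have the `username` and `key` values from `kaggle.json`; the helper name `write_kaggle_credentials` is hypothetical. The Kaggle CLI reads `~/.kaggle/kaggle.json` by default and warns when the file is readable by other users, which is why the cell above runs `chmod 600`.

```python
import json
import os
from pathlib import Path
from typing import Optional


def write_kaggle_credentials(username: str, key: str, config_dir: Optional[Path] = None) -> Path:
    """Write kaggle.json with owner-only permissions (hypothetical helper)."""
    # Default to ~/.kaggle, where the Kaggle CLI looks for credentials
    config_dir = Path(config_dir) if config_dir else Path.home() / ".kaggle"
    config_dir.mkdir(parents=True, exist_ok=True)
    cred_path = config_dir / "kaggle.json"
    cred_path.write_text(json.dumps({"username": username, "key": key}))
    # Same effect as `chmod 600`: read/write for the owner only
    os.chmod(cred_path, 0o600)
    return cred_path


# Usage: write_kaggle_credentials("your_username", "your_api_key")
```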

## 📥 Step 3: Download Dataset

In [None]:
# Download Skin-Disease-Dataset from Kaggle
!kaggle datasets download -d subirbiswas19/skin-disease-dataset

# Extract
!unzip -q skin-disease-dataset.zip -d dataset

# Check structure
!ls -lh dataset/

print("\n✅ Dataset downloaded and extracted!")

## 🔍 Step 4: Analyze Dataset Structure

In [None]:
# Find the data directory
data_root = Path('dataset')

# Look for train/test folders
possible_paths = list(data_root.rglob('*'))
print("Dataset structure:")
for p in sorted(possible_paths)[:20]:  # Show first 20
    print(f"  {p}")

# Find actual data directory
data_dirs = [p for p in data_root.rglob('*') if p.is_dir() and any((p / subdir).exists() for subdir in ['train', 'test', 'Train', 'Test'])]

if data_dirs:
    data_dir = data_dirs[0]
    print(f"\n✅ Found data directory: {data_dir}")
else:
    # Assume top level
    data_dir = data_root
    print(f"\n⚠️ Using root directory: {data_dir}")

# Check for train/test folders (handle different naming)
train_dir = None
test_dir = None

for variant in ['train', 'Train', 'training', 'Training']:
    if (data_dir / variant).exists():
        train_dir = data_dir / variant
        break

for variant in ['test', 'Test', 'testing', 'Testing', 'val', 'validation']:
    if (data_dir / variant).exists():
        test_dir = data_dir / variant
        break

print(f"\nTrain directory: {train_dir}")
print(f"Test directory: {test_dir}")

# List classes
if train_dir and train_dir.exists():
    classes = sorted([d.name for d in train_dir.iterdir() if d.is_dir()])
    print(f"\n📊 Found {len(classes)} disease classes:")
    for i, cls in enumerate(classes, 1):
        num_images = len(list((train_dir / cls).glob('*')))
        print(f"  {i}. {cls}: {num_images} images")
else:
    print("\n⚠️ Could not find train directory. Please check dataset structure.")
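The listing above counts every file in each class folder, including any non-image files. A slightly stricter sketch counts only image files and reports the ratio between the largest and smallest class, a quick signal of class imbalance. The extension set is an assumption (roughly what torchvision's `ImageFolder` accepts); adjust it to what the dataset actually contains. Both helper names are hypothetical.

```python
from pathlib import Path

# Extensions treated as images (similar to what torchvision's ImageFolder accepts)
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}


def class_image_counts(split_dir):
    """Map each class sub-folder name to its number of image files."""
    split_dir = Path(split_dir)
    counts = {}
    for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts


def imbalance_ratio(counts):
    """Largest class size divided by smallest non-empty class (1.0 means balanced)."""
    sizes = [n for n in counts.values() if n > 0]
    return max(sizes) / min(sizes) if sizes else float("nan")


# Usage (assumes train_dir was found in the cell above):
# counts = class_image_counts(train_dir)
# print(counts, f"imbalance ratio: {imbalance_ratio(counts):.1f}x")
```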

## 🔄 Step 5: Data Augmentation & Preprocessing

**Reduces overfitting with:**
- Random resized crops, flips, and rotations
- Color jittering
- Random erasing

In [None]:
# Advanced data augmentation (prevents overfitting)
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.2, scale=(0.02, 0.15))
    ]),
    'val': transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

print("✅ Data augmentation configured!")
print("   - Random crops, flips, rotations")
print("   - Color jittering")
print("   - Random erasing")
print("   → Helps reduce overfitting!")
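Because `Normalize` maps each channel to `(x - mean) / std` with the ImageNet statistics, augmented tensors look wrong if plotted directly. A small sketch that inverts the normalization so augmented samples can be inspected with `plt.imshow`; the helper `denormalize` is hypothetical and works on NumPy arrays.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def denormalize(chw):
    """Undo transforms.Normalize for a (3, H, W) array; return an (H, W, 3) image in [0, 1]."""
    chw = np.asarray(chw, dtype=np.float64)
    # Broadcast per-channel stats over H and W, then invert (x - mean) / std
    img = chw * IMAGENET_STD[:, None, None] + IMAGENET_MEAN[:, None, None]
    # Matplotlib expects channels last; clip values pushed out of range by jitter or erasing
    return np.clip(img.transpose(1, 2, 0), 0.0, 1.0)


# Usage once the datasets from Step 6 exist (tensor -> NumPy first):
# img, label = image_datasets['train'][0]
# plt.imshow(denormalize(img.numpy())); plt.title(class_names[label]); plt.show()
```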

## 📊 Step 6: Create Data Loaders

In [None]:
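# Create datasets
if test_dir and test_dir.exists():
    # Separate train/test folders (the test folder also drives model selection,
    # so its accuracy is an optimistic estimate)
    image_datasets = {
        'train': datasets.ImageFolder(train_dir, data_transforms['train']),
        'val': datasets.ImageFolder(test_dir, data_transforms['val'])
    }
else:
    # Split the training folder 80/20. random_split subsets share one underlying
    # dataset, so changing its transform would affect both subsets. Two ImageFolder
    # views of the same folder let each subset keep its own transform.
    train_view = datasets.ImageFolder(train_dir, data_transforms['train'])
    val_view = datasets.ImageFolder(train_dir, data_transforms['val'])
    train_size = int(0.8 * len(train_view))
    
    # Fixed seed so the split is reproducible across runs
    generator = torch.Generator().manual_seed(42)
    indices = torch.randperm(len(train_view), generator=generator).tolist()
    
    image_datasets = {
        'train': torch.utils.data.Subset(train_view, indices[:train_size]),
        'val': torch.utils.data.Subset(val_view, indices[train_size:])
    }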

# Create data loaders
# drop_last=True avoids a size-1 final training batch, which BatchNorm1d rejects in training mode
dataloaders = {
    'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=32, shuffle=True, num_workers=2, drop_last=True),
    'val': torch.utils.data.DataLoader(image_datasets['val'], batch_size=32, shuffle=False, num_workers=2)
}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].dataset.classes if hasattr(image_datasets['train'], 'dataset') else image_datasets['train'].classes
num_classes = len(class_names)

print(f"✅ Data loaders created!")
print(f"   Training samples: {dataset_sizes['train']}")
print(f"   Validation samples: {dataset_sizes['val']}")
print(f"   Number of classes: {num_classes}")
print(f"   Classes: {class_names}")
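If the per-class counts from Step 4 are uneven, `nn.CrossEntropyLoss` (used in Step 8) accepts a per-class `weight` tensor. A sketch assuming inverse-frequency weighting, one common choice rather than something this notebook prescribes; it uses the same formula as scikit-learn's `'balanced'` class weights. The helper name is hypothetical.

```python
import numpy as np


def inverse_frequency_weights(labels, num_classes):
    """Per-class weights n_samples / (num_classes * count_c), the 'balanced' scheme."""
    labels = np.asarray(labels)
    counts = np.bincount(labels, minlength=num_classes).astype(np.float64)
    # Each class then carries the same total weight (count_c * weight_c == n_samples / num_classes);
    # np.maximum guards against division by zero for a class with no samples
    return len(labels) / (num_classes * np.maximum(counts, 1.0))


# Usage (image_datasets['train'] is a Subset when the train folder was split, else an ImageFolder):
# train_ds = image_datasets['train']
# if hasattr(train_ds, 'indices'):
#     labels = [train_ds.dataset.targets[i] for i in train_ds.indices]
# else:
#     labels = train_ds.targets
# weights = torch.tensor(inverse_frequency_weights(labels, num_classes), dtype=torch.float32)
# criterion = nn.CrossEntropyLoss(weight=weights.to(device))
```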

## 🏗️ Step 7: Build Model (Transfer Learning)

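**Using ResNet50 with Transfer Learning:**
- Pre-trained on ImageNet (~1.28M training images, 1,000 classes)
- Its convolutional layers already encode generic features such as edges, textures, and shapes
- The backbone is frozen; only a new classifier head is trained on the skin-disease images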

In [None]:
# Create model with transfer learning
def create_model(num_classes):
    # Load pre-trained ResNet50
    model = models.resnet50(weights='IMAGENET1K_V2')
    
    # Freeze the entire pretrained backbone (keep ImageNet features);
    # the new head created below is trainable by default
    for param in model.parameters():
        param.requires_grad = False
    
    # Replace final classifier with custom head
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),  # Prevents overfitting
        nn.Linear(num_features, 512),
        nn.ReLU(),
        nn.BatchNorm1d(512),
        nn.Dropout(0.3),  # More dropout = less overfitting
        nn.Linear(512, num_classes)
    )
    
    return model

model = create_model(num_classes)
model = model.to(device)

print("✅ Model created!")
print(f"   Architecture: ResNet50")
print(f"   Transfer Learning: Yes (ImageNet pre-trained)")
print(f"   Dropout: 0.5 and 0.3 (prevents overfitting)")
print(f"   Output classes: {num_classes}")

## 🎯 Step 8: Training Configuration

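**Guards against underfitting:**
- Learning rate of 0.001 for the new head
- ReduceLROnPlateau scheduler: halves the learning rate when validation accuracy stops improving
- Up to 50 epochs (early stopping may end training sooner)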

In [None]:
# Loss function
criterion = nn.CrossEntropyLoss()

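# Optimizer: Adam with light weight decay, applied only to the new head
# (the backbone is frozen, so model.fc holds the only trainable parameters)
optimizer = optim.Adam(model.fc.parameters(), lr=0.001, weight_decay=0.0001)

# Learning rate scheduler: halve the LR once validation accuracy has not improved for more than 3 epochs
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)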

print("✅ Training configuration:")
print(f"   Loss: CrossEntropyLoss")
print(f"   Optimizer: Adam")
print(f"   Learning Rate: 0.001")
print(f"   Scheduler: ReduceLROnPlateau (prevents underfitting)")
print(f"   Epochs: 50")
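To see what `ReduceLROnPlateau(mode='max', factor=0.5, patience=3)` does, the sketch below feeds it a validation-accuracy sequence that stops improving. The learning rate is halved only when the count of consecutive non-improving epochs exceeds `patience`, i.e. on the fourth bad epoch. The dummy parameter and the accuracy values are made up for illustration.

```python
import torch
import torch.optim as optim
from torch.optim import lr_scheduler

# A single dummy parameter stands in for the model head
dummy = torch.nn.Parameter(torch.zeros(1))
opt = optim.Adam([dummy], lr=0.001)
sched = lr_scheduler.ReduceLROnPlateau(opt, mode='max', factor=0.5, patience=3)

# Made-up validation accuracies: improves twice, then stalls for four epochs
val_accs = [0.60, 0.70, 0.70, 0.69, 0.70, 0.68]
lrs = []
for acc in val_accs:
    sched.step(acc)
    lrs.append(opt.param_groups[0]['lr'])

for epoch, (acc, lr) in enumerate(zip(val_accs, lrs), start=1):
    print(f"epoch {epoch}: val_acc={acc:.2f} -> lr={lr:g}")
```

Epochs 3 to 6 are the four non-improving epochs, so the learning rate stays at 0.001 through epoch 5 and drops to 0.0005 at epoch 6.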

## 🚀 Step 9: Training Loop

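**With early stopping:** training halts once validation accuracy has not improved for 10 consecutive epochs, and the best weights are restored at the end.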

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=50):
    since = time.time()
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    patience_counter = 0
    patience = 10  # Early stopping patience
    
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }
    
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 50)
        
        # Each epoch has training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            
            running_loss = 0.0
            running_corrects = 0
            
            # Iterate over data
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            
            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc.item())
            
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            
            # Save best model
            if phase == 'val':
                scheduler.step(epoch_acc)
                
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    patience_counter = 0
                    print(f'✅ New best model! Acc: {best_acc:.4f}')
                else:
                    patience_counter += 1
        
        # Early stopping
        if patience_counter >= patience:
            print(f'\n⚠️ Early stopping at epoch {epoch+1}')
            print(f'   Validation accuracy hasn\'t improved for {patience} epochs')
            break
    
    time_elapsed = time.time() - since
    print(f'\n✅ Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'   Best validation Acc: {best_acc:.4f}')
    
    model.load_state_dict(best_model_wts)
    return model, history

# Train the model
model, history = train_model(model, criterion, optimizer, scheduler, num_epochs=50)
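The early-stopping bookkeeping inside `train_model` (best accuracy, patience counter, best weights) can be factored into a small reusable helper. A sketch with a hypothetical `EarlyStopping` class that mirrors the loop's logic: a strictly higher validation accuracy resets the counter, and training should stop once `patience` epochs pass without one.

```python
import copy


class EarlyStopping:
    """Track the best validation accuracy and signal when to stop (mirrors train_model's logic)."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best_acc = float("-inf")
        self.best_state = None
        self.counter = 0

    def step(self, val_acc, state_dict=None):
        """Record one epoch's accuracy; return True if training should stop."""
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            # Deep copy so later optimizer steps don't mutate the saved weights
            self.best_state = copy.deepcopy(state_dict)
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


# Usage inside the epoch loop, after the 'val' phase:
# if stopper.step(epoch_acc.item(), model.state_dict()):
#     break
```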

## 📈 Step 10: Visualize Training

In [None]:
# Plot training history
plt.figure(figsize=(12, 4))

# Accuracy
plt.subplot(1, 2, 1)
plt.plot(history['train_acc'], label='Train')
plt.plot(history['val_acc'], label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Model Accuracy')
plt.legend()
plt.grid(True)

# Loss
plt.subplot(1, 2, 2)
plt.plot(history['train_loss'], label='Train')
plt.plot(history['val_loss'], label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Model Loss')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

print("\n📊 Training Analysis:")
print(f"   Final Train Acc: {history['train_acc'][-1]:.4f}")
print(f"   Final Val Acc: {history['val_acc'][-1]:.4f}")
gap = history['train_acc'][-1] - history['val_acc'][-1]
print(f"   Accuracy Gap: {gap:.4f}")
if gap < 0.1:
    print("   ✅ Good! No significant overfitting detected")
elif gap < 0.2:
    print("   ⚠️ Slight overfitting")
else:
    print("   ❌ Overfitting detected!")
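Overall accuracy can hide weak classes. Step 1 imports `classification_report`, `confusion_matrix`, and `seaborn`, but they are never used; a sketch that applies them to the validation loader. The helper name `evaluate_per_class` is hypothetical.

```python
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix


@torch.no_grad()
def evaluate_per_class(model, loader, class_names, device="cpu", plot=True):
    """Return (confusion matrix, per-class report dict) over every batch in loader."""
    model.eval()  # disable dropout and use BatchNorm running statistics
    y_true, y_pred = [], []
    for inputs, labels in loader:
        outputs = model(inputs.to(device))
        y_pred.extend(outputs.argmax(dim=1).cpu().tolist())
        y_true.extend(labels.tolist())

    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    report = classification_report(
        y_true, y_pred, labels=label_ids, target_names=class_names,
        output_dict=True, zero_division=0,
    )
    if plot:
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=class_names, yticklabels=class_names)
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.show()
    return cm, report


# Usage: cm, report = evaluate_per_class(model, dataloaders['val'], class_names, device)
```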

## 💾 Step 11: Save Model

In [None]:
# Save model with metadata
# The restored weights come from the best epoch, so record the best (not the last) accuracy
best_val_acc = max(history['val_acc'])
torch.save({
    'model_state_dict': model.state_dict(),
    'model_architecture': 'ResNet50',
    'num_classes': num_classes,
    'class_names': class_names,
    'best_val_accuracy': best_val_acc,
    'training_history': history,
    'model_info': {
        'architecture': 'ResNet50 with Transfer Learning',
        'input_size': (224, 224),
        'num_classes': num_classes,
        'dropout_rate': 0.5,
        'regularization': 'Dropout + BatchNorm + Data Augmentation',
        'best_validation_accuracy': best_val_acc
    }
}, 'skin_classifier.pth')

# Save class mapping
class_mapping = {str(i): name for i, name in enumerate(class_names)}
with open('skin_classes.json', 'w') as f:
    json.dump(class_mapping, f, indent=2)

print("✅ Model saved!")
print("   Files: skin_classifier.pth, skin_classes.json")

# Download files
from google.colab import files
files.download('skin_classifier.pth')
files.download('skin_classes.json')

print("\n📥 Files downloaded! Copy them to your backend/models/ folder")

## ✅ Training Complete!

### Next Steps:

1. **Download the model files** (done above)
2. **Copy to your backend:**
   ```bash
   backend/models/skin_classifier.pth
   backend/models/skin_classes.json
   ```
3. **Create API endpoint** for predictions
4. **Test with real images**

### Model Features:

✅ Transfer learning (ImageNet pre-trained)  
✅ Dropout (reduces overfitting)  
✅ Data augmentation (reduces overfitting)  
✅ Early stopping (reduces overfitting)  
✅ Learning rate scheduler (helps avoid underfitting)  
✅ Up to 50 epochs, cut short by early stopping  

### Confidence Threshold:

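In your backend, apply a confidence threshold to the softmax probability of the top class:
- **≥ 70%**: Show the prediction
- **< 70%**: Show "Uncertain - consult a doctor"

This catches many inputs the model is unsure about, including some diseases outside the training classes. It is not a complete unknown-disease detector: the classifier always picks one of its known classes and can be confidently wrong on unfamiliar images, so tune the 70% cutoff on held-out data.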
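The threshold rule can be sketched as a small function over the model's raw logits. Assumptions: the backend has the logits for one image as a NumPy array (e.g. `logits[0].cpu().numpy()` from a `torch.no_grad()` forward pass), `class_names` comes from `skin_classes.json`, and the 0.70 default mirrors the rule above. The function name, return format, and example class names are hypothetical.

```python
import numpy as np


def predict_with_threshold(logits, class_names, threshold=0.70):
    """Turn one image's logits into a label, or 'Uncertain' when top-class confidence is too low."""
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax: subtracting the max doesn't change the probabilities
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    top = int(probs.argmax())
    confidence = float(probs[top])
    if confidence >= threshold:
        return {"label": class_names[top], "confidence": confidence}
    return {"label": "Uncertain - consult a doctor", "confidence": confidence}


print(predict_with_threshold([4.0, 0.5, 0.1], ["acne", "eczema", "psoriasis"]))
print(predict_with_threshold([1.0, 0.9, 0.8], ["acne", "eczema", "psoriasis"]))
```

The first call has one dominant logit and returns that class; in the second the logits are close together, the top probability falls below 0.70, and the function falls back to the uncertain message.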