In [1]:
#!/usr/bin/env python
# coding: utf-8

# 1. Setting Up

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import models, transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc, precision_recall_curve
from sklearn.preprocessing import label_binarize
from PIL import Image
import warnings
# Silence every warning category for the rest of the session. Note that this
# also hides deprecation notices emitted by the libraries imported above.
warnings.filterwarnings('ignore')

# Set seed for reproducibility: seed the global PyTorch and NumPy RNGs
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
# cuDNN flags: request deterministic kernels and disable the benchmark
# auto-tuner so repeated runs are more likely to pick the same algorithms.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Define paths
# Base directory; load_data() joins the split CSV file names onto it and also
# passes it to the datasets as root_dir.
DATA_PATH = "../data/processed/images"


In [5]:
# Select the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
class FitzpatrickDataset(Dataset):
    """Map-style dataset that reads skin-image annotations from a CSV file.

    The CSV must provide the columns ``image_path``, ``label``,
    ``three_partition_label`` and ``fitzpatrick_scale`` (all are read below).
    Every item is a dict with the keys ``image``, ``binary_label``,
    ``skin_type``, ``skin_group`` and ``condition_idx``.
    """

    def __init__(self, csv_file, root_dir, transform=None, target_condition=None):
        """
        Args:
            csv_file: Path to the CSV file with annotations
            root_dir: Directory with all the images. It is only stored on the
                instance; images are opened from the ``image_path`` column
                exactly as written in the CSV.
            transform: Optional transform to be applied on a sample
            target_condition: If specified, only include this skin condition
        """
        frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

        # Filter by condition if specified
        if target_condition:
            frame = frame[frame['label'] == target_condition]

        # Take an independent copy: the filter above may return a slice of the
        # parsed frame, and the column assignments below must not write into it.
        frame = frame.copy()

        # Convert three_partition_label to binary (malignant vs. non-malignant)
        frame['binary_label'] = (frame['three_partition_label'] == 'malignant').astype('int64')

        # Group skin types into light (1-3) and dark (4-6): 0 for light, 1 for dark.
        # Written as "not <= 3" so that every value failing the comparison
        # (including missing values) lands in group 1, as a plain
        # `0 if x <= 3 else 1` would do.
        frame['skin_group'] = (~(frame['fitzpatrick_scale'] <= 3)).astype('int64')

        # Create a mapping for unique conditions. The labels are sorted so the
        # label -> index assignment does not depend on the row order of the CSV;
        # two splits that contain the same set of labels therefore agree on
        # every index.
        ordered_labels = sorted(frame['label'].unique(), key=str)
        self.unique_conditions = np.array(ordered_labels, dtype=object)
        self.condition_to_idx = {condition: idx for idx, condition in enumerate(ordered_labels)}

        # Add multi-class label
        frame['condition_idx'] = frame['label'].apply(lambda name: self.condition_to_idx[name])

        self.data_frame = frame

    def __len__(self):
        """Return the number of annotation rows kept after filtering."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Load the image and labels for the row at position ``idx``.

        If the image cannot be opened, the error is printed and a gray
        224x224 placeholder image is used instead, so iteration continues.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Fetch the row once instead of re-indexing the frame for every field
        row = self.data_frame.iloc[idx]
        img_path = row['image_path']

        try:
            # Use PIL to load image
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a placeholder image if loading fails
            image = Image.new('RGB', (224, 224), color='gray')

        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'binary_label': row['binary_label'],
            'skin_type': row['fitzpatrick_scale'],
            'skin_group': row['skin_group'],
            'condition_idx': row['condition_idx'],
        }


# 2. DataSet

In [None]:
class FitzpatrickDataset(Dataset):
    """Map-style dataset of skin images for multi-class condition prediction.

    This definition replaces any earlier ``FitzpatrickDataset`` in the
    notebook namespace. It reads the columns ``image_path``, ``label`` and
    ``fitzpatrick_scale`` from the CSV and yields dicts with the keys
    ``image``, ``skin_type``, ``skin_group`` and ``condition_idx``
    (no binary malignancy label).
    """

    def __init__(self, csv_file, root_dir, transform=None, target_condition=None):
        """
        Args:
            csv_file: Path to the CSV file with annotations
            root_dir: Directory with all the images. It is only stored on the
                instance; images are opened from the ``image_path`` column
                exactly as written in the CSV.
            transform: Optional transform to be applied on a sample
            target_condition: If specified, only include this skin condition
        """
        frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

        # Filter by condition if specified
        if target_condition:
            frame = frame[frame['label'] == target_condition]

        # Take an independent copy: the filter above may return a slice of the
        # parsed frame, and the column assignments below must not write into it.
        frame = frame.copy()

        # Group skin types into light (1-3) and dark (4-6): 0 for light, 1 for dark.
        # Written as "not <= 3" so that every value failing the comparison
        # (including missing values) lands in group 1, as a plain
        # `0 if x <= 3 else 1` would do.
        frame['skin_group'] = (~(frame['fitzpatrick_scale'] <= 3)).astype('int64')

        # Create a mapping for unique conditions. The labels are sorted so the
        # label -> index assignment does not depend on the row order of the CSV;
        # two splits that contain the same set of labels therefore agree on
        # every index.
        ordered_labels = sorted(frame['label'].unique(), key=str)
        self.unique_conditions = np.array(ordered_labels, dtype=object)
        self.condition_to_idx = {condition: idx for idx, condition in enumerate(ordered_labels)}

        # Add multi-class label
        frame['condition_idx'] = frame['label'].apply(lambda name: self.condition_to_idx[name])

        self.data_frame = frame

    def __len__(self):
        """Return the number of annotation rows kept after filtering."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Load the image and labels for the row at position ``idx``.

        If the image cannot be opened, the error is printed and a gray
        224x224 placeholder image is used instead, so iteration continues.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.data_frame.iloc[idx]
        img_path = row['image_path']

        try:
            # Use PIL to load image
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a placeholder image if loading fails
            image = Image.new('RGB', (224, 224), color='gray')

        skin_type = row['fitzpatrick_scale']
        skin_group = row['skin_group']
        condition_idx = row['condition_idx']

        if self.transform:
            image = self.transform(image)

        sample = {
            'image': image,
            'skin_type': skin_type,
            'skin_group': skin_group,
            'condition_idx': condition_idx
        }

        return sample


# 3. Load Data

In [None]:
def load_data(batch_size=32, target_condition=None, max_samples=None, num_workers=4, image_size=128):
    """
    Load and prepare the datasets with optimization for speed

    Builds train/validation/test datasets from 'train_split.csv',
    'val_split.csv' and 'test_split.csv' under DATA_PATH and wraps each in a
    DataLoader (only the training loader shuffles).

    Args:
        batch_size: Batch size used by all three loaders.
        target_condition: Forwarded to the datasets; if given, only rows with
            this label are kept.
        max_samples: Optional cap on the training set. When the training set is
            larger, a per-label subsample is drawn with
            ``max_samples // number_of_labels`` rows per label, but never fewer
            than one row per label, so the subset cannot end up empty.
        num_workers: Worker processes per DataLoader.
        image_size: Side length (pixels) images are resized to.

    Returns:
        (train_loader, val_loader, test_loader, unique_conditions), where
        unique_conditions comes from the full (un-subsampled) training dataset.

    NOTE(review): each split dataset builds its own label -> index mapping;
    confirm all splits contain the same labels so indices are comparable.
    """
    # Define transforms with smaller image size
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),  # Smaller size for faster training
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet normalization
    ])

    def build_split(csv_name):
        # All three splits share the transform and the optional condition filter
        return FitzpatrickDataset(
            csv_file=os.path.join(DATA_PATH, csv_name),
            root_dir=DATA_PATH,
            transform=transform,
            target_condition=target_condition
        )

    # Keep a handle on the full training dataset: its label list is returned
    # even when training itself runs on a Subset.
    full_train_dataset = build_split('train_split.csv')
    train_dataset = full_train_dataset

    # Optionally limit the number of samples for faster training
    if max_samples is not None and len(full_train_dataset) > max_samples:
        labels = full_train_dataset.data_frame['label'].values
        unique_labels = np.unique(labels)
        # Floor division can reach 0 when max_samples < number of labels; keep
        # at least one row per label instead of producing an empty subset.
        samples_per_label = max(1, max_samples // len(unique_labels))

        indices = []
        for label in unique_labels:
            label_indices = np.where(labels == label)[0]
            if len(label_indices) > samples_per_label:
                label_indices = np.random.choice(label_indices, samples_per_label, replace=False)
            indices.extend(label_indices.tolist())

        # Create a subset of the dataset
        train_dataset = torch.utils.data.Subset(full_train_dataset, indices)

    val_dataset = build_split('val_split.csv')
    test_dataset = build_split('test_split.csv')

    # Create dataloaders with optimization
    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,  # Adjust based on your CPU cores
        # Pinned host memory only speeds up host-to-GPU copies, so request it
        # only when a CUDA device is present.
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

    return train_loader, val_loader, test_loader, full_train_dataset.unique_conditions


# 4. Model Architecture

Using ResNet-18 instead of VGG16 (as used in the paper) for faster training.

In [None]:
class ResNet18Classifier(nn.Module):
    """ResNet18 backbone whose final fully connected layer is replaced by a
    freshly initialised linear layer with ``num_classes`` outputs."""

    def __init__(self, num_classes=2, pretrained=True):
        """
        ResNet18-based classifier that can be used for binary or multi-class classification

        Args:
            num_classes: Number of output classes
            pretrained: Whether to use pretrained weights from ImageNet
        """
        super().__init__()

        # Load pre-trained ResNet18 model. Newer torchvision releases deprecate
        # the boolean `pretrained` argument in favour of `weights`; use the
        # weights enum when it exists and fall back to the legacy argument on
        # older installations. IMAGENET1K_V1 is the checkpoint that
        # `pretrained=True` refers to.
        weights_enum = getattr(models, 'ResNet18_Weights', None)
        if weights_enum is not None:
            chosen_weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.model = models.resnet18(weights=chosen_weights)
        else:
            self.model = models.resnet18(pretrained=pretrained)

        # Replace the final fully connected layer for our classification task
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, num_classes)

        # Initialize the new layer: weights ~ N(0, 0.01), zero bias
        nn.init.normal_(self.model.fc.weight, 0, 0.01)
        nn.init.constant_(self.model.fc.bias, 0)

    def forward(self, x):
        """Run the wrapped ResNet18 on ``x`` and return its output."""
        return self.model(x)


In [None]:
# Create model:
def create_multiclass_model(num_classes):
    """
    Build the multi-class skin-condition classifier.

    Args:
        num_classes: Number of output classes for the classifier head.

    Returns:
        A ResNet18Classifier constructed with ``num_classes`` outputs.
    """
    return ResNet18Classifier(num_classes=num_classes)

# 5. Training Function

Including mixed precision for faster training

In [None]:
def _select_batch_labels(model, batch, device):
    """Return the target tensor for ``batch`` moved to ``device``.

    A model exposing ``model.model.fc`` with two outputs is treated as the
    binary task and trained on 'binary_label' when the batch carries that key;
    in every other case 'condition_idx' is used.
    """
    two_way_head = hasattr(model, 'model') and model.model.fc.out_features == 2
    label_key = 'binary_label' if two_way_head and 'binary_label' in batch else 'condition_idx'
    return batch[label_key].to(device)


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5, device='cpu',
                checkpoint_path='fitzpatrick_model_best.pth'):
    """
    Train the model with optimizations for speed

    Args:
        model: Network to train; moved to ``device`` in place.
        train_loader: Iterable of dict batches with an 'image' entry and a
            label entry ('binary_label' and/or 'condition_idx').
        val_loader: Same batch format, used for per-epoch validation.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the model parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Target device (string or torch.device).
        checkpoint_path: File that receives the state_dict of the epoch with
            the highest validation accuracy seen so far.

    Returns:
        (model, history). ``model`` is in its final-epoch state (the best-epoch
        weights are only on disk at ``checkpoint_path``); ``history`` maps
        'train_loss', 'train_acc', 'val_loss', 'val_acc' to per-epoch lists.
    """
    model.to(device)
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    # Start below any reachable accuracy so the first epoch always writes a checkpoint
    best_val_acc = float('-inf')

    # Use mixed precision only when training actually runs on a CUDA device:
    # the gradient scaler rejects CPU tensors, so enabling it merely because a
    # GPU exists would fail for device='cpu'.
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch in train_loader:
            images = batch['image'].to(device)
            labels = _select_batch_labels(model, batch, device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            if use_amp:
                # Forward pass under autocast, backward/step through the scaler
                with torch.cuda.amp.autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                # Standard forward and backward pass
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            # Statistics (loss is weighted by batch size so the epoch value is a per-sample mean)
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Normalise by the samples actually seen; max(..., 1) keeps an empty
        # loader from raising ZeroDivisionError.
        epoch_loss = running_loss / max(total, 1)
        epoch_acc = correct / max(total, 1)
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for batch in val_loader:
                images = batch['image'].to(device)
                labels = _select_batch_labels(model, batch, device)

                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * images.size(0)
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss = val_loss / max(val_total, 1)
        val_acc = val_correct / max(val_total, 1)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f'Epoch {epoch+1}/{num_epochs}, '
              f'Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)

    return model, history


# 6. MODEL VALIDATION