In [1]:
#!/usr/bin/env python
# coding: utf-8

# 1. Setting Up

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import models, transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc, precision_recall_curve
from sklearn.preprocessing import label_binarize
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

# Set seed for reproducibility
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Define paths
DATA_PATH = "../data/processed/images"


In [3]:
# Set device to CPI
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 2. DataSet

In [11]:
class FitzpatrickDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None, target_condition=None):
        """
        Args:
            csv_file: Path to the CSV file with annotations
            root_dir: Directory with all the images
            transform: Optional transform to be applied on a sample
            target_condition: If specified, only include this skin condition or list of conditions
        """
        self.data_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        
        # Verify image paths exist
        self.data_frame['image_exists'] = self.data_frame['image_path'].apply(os.path.exists)
        self.data_frame = self.data_frame[self.data_frame['image_exists']]
        self.data_frame = self.data_frame.drop('image_exists', axis=1)
        
        # Filter by condition if specified
        if target_condition is not None:
            if isinstance(target_condition, list):
                # Filter for multiple conditions
                self.data_frame = self.data_frame[self.data_frame['label'].isin(target_condition)]
            else:
                # Filter for a single condition
                self.data_frame = self.data_frame[self.data_frame['label'] == target_condition]
        
        # Convert three_partition_label to binary (malignant vs. non-malignant)
        self.data_frame['binary_label'] = self.data_frame['three_partition_label'].apply(
            lambda x: 1 if x == 'malignant' else 0
        )
        
        # Group skin types into light (1-3) and dark (4-6)
        self.data_frame['skin_group'] = self.data_frame['fitzpatrick_scale'].apply(
            lambda x: 0 if x <= 3 else 1  # 0 for light, 1 for dark
        )
        
        # Create a mapping for unique conditions
        self.unique_conditions = self.data_frame['label'].unique()
        self.condition_to_idx = {condition: idx for idx, condition in enumerate(self.unique_conditions)}
        
        # Add multi-class label
        self.data_frame['condition_idx'] = self.data_frame['label'].apply(lambda x: self.condition_to_idx[x])
    
    def __len__(self):
        return len(self.data_frame)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        img_path = self.data_frame.iloc[idx]['image_path']
        
        try:
            # Use PIL to load image
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a placeholder image if loading fails
            image = Image.new('RGB', (224, 224), color='gray')
        
        # Get labels
        binary_label = float(self.data_frame.iloc[idx]['binary_label'])  # Convert to float
        skin_type = int(self.data_frame.iloc[idx]['fitzpatrick_scale'])
        skin_group = int(self.data_frame.iloc[idx]['skin_group'])
        condition_idx = int(self.data_frame.iloc[idx]['condition_idx'])
        
        if self.transform:
            image = self.transform(image)
        
        # Return a simple tuple instead of a dictionary
        return image, binary_label


# 3. Load Data

In [14]:
def load_data(batch_size=32, target_condition=None, max_samples=None):
    """
    Load and prepare the datasets with optimization for speed
    """
    # Define transforms with smaller image size
    transform = transforms.Compose([
        transforms.Resize((128, 128)),  # Smaller size for faster training
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet normalization
    ])
    
    # Create datasets
    train_dataset = FitzpatrickDataset(
        csv_file=os.path.join(DATA_PATH, 'train_split.csv'),
        root_dir=DATA_PATH,
        transform=transform,
        target_condition=target_condition
    )
    
    val_dataset = FitzpatrickDataset(
        csv_file=os.path.join(DATA_PATH, 'val_split.csv'),
        root_dir=DATA_PATH,
        transform=transform,
        target_condition=target_condition
    )
    
    test_dataset = FitzpatrickDataset(
        csv_file=os.path.join(DATA_PATH, 'test_split.csv'),
        root_dir=DATA_PATH,
        transform=transform,
        target_condition=target_condition
    )
    
    # Optionally limit the number of samples
    if max_samples is not None:
        # Simple random sampling for testing
        if len(train_dataset) > max_samples:
            indices = torch.randperm(len(train_dataset))[:max_samples]
            train_dataset = torch.utils.data.Subset(train_dataset, indices)
        if len(val_dataset) > max_samples // 5:
            indices = torch.randperm(len(val_dataset))[:max_samples // 5]
            val_dataset = torch.utils.data.Subset(val_dataset, indices)
        if len(test_dataset) > max_samples // 5:
            indices = torch.randperm(len(test_dataset))[:max_samples // 5]
            test_dataset = torch.utils.data.Subset(test_dataset, indices)
    
    # Create dataloaders with fewer workers and persistent_workers=False
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True,
        num_workers=0,  # Use 0 workers to debug
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=batch_size, 
        shuffle=False,
        num_workers=0,  # Use 0 workers to debug
        pin_memory=True
    )
    
    test_loader = DataLoader(
        test_dataset, 
        batch_size=batch_size, 
        shuffle=False,
        num_workers=0,  # Use 0 workers to debug
        pin_memory=True
    )
    
    # Get unique conditions from the original dataset (before subsetting)
    if hasattr(train_dataset, 'unique_conditions'):
        unique_conditions = train_dataset.unique_conditions
    elif hasattr(train_dataset, 'dataset') and hasattr(train_dataset.dataset, 'unique_conditions'):
        unique_conditions = train_dataset.dataset.unique_conditions
    else:
        # If we can't get unique conditions directly, try to extract from the dataframe
        try:
            if hasattr(train_dataset, 'data_frame'):
                unique_conditions = train_dataset.data_frame['label'].unique()
            elif hasattr(train_dataset, 'dataset') and hasattr(train_dataset.dataset, 'data_frame'):
                unique_conditions = train_dataset.dataset.data_frame['label'].unique()
            else:
                unique_conditions = []
        except:
            unique_conditions = []
    
    return train_loader, val_loader, test_loader, unique_conditions


In [15]:
# Load a smaller dataset with focus on key conditions
target_conditions = ['psoriasis', 'squamous_cell_carcinoma', 'lichen_planus']
train_loader, val_loader, test_loader, unique_conditions = load_data(
    batch_size=32,
    target_condition=target_conditions,
    max_samples=1000  # Limit total samples
)

# 4. Model Architecture

Using ResNet 18 instead of VGG16 as paper for faster training

In [16]:
class ResNet18Classifier(nn.Module):
    def __init__(self, num_classes=2, pretrained=True):
        """
        ResNet18-based classifier that can be used for binary or multi-class classification
        
        Args:
            num_classes: Number of output classes
            pretrained: Whether to use pretrained weights from ImageNet
        """
        super(ResNet18Classifier, self).__init__()
        
        # Load pre-trained ResNet18 model
        self.model = models.resnet18(pretrained=pretrained)
        
        # Replace the final fully connected layer for our classification task
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, num_classes)
        
        # Initialize the new layer
        nn.init.normal_(self.model.fc.weight, 0, 0.01)
        nn.init.constant_(self.model.fc.bias, 0)
    
    def forward(self, x):
        return self.model(x)


In [17]:
# Create model function
def create_multiclass_model(num_classes):
    """
    Create a multi-class classifier model for specific skin conditions
    """
    model = ResNet18Classifier(num_classes=num_classes)
    return model

# Initialize model
num_classes = len(unique_conditions)
model = create_multiclass_model(num_classes)

# 5. Training Function

Including mixed precision for faster training

In [18]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5, device='cpu'):
    """
    Train the model
    """
    # Initialize lists to track metrics
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }
    
    # Training loop
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        
        # Training phase
        model.train()
        running_loss = 0.0
        running_corrects = 0
        total = 0
        
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.float().to(device).view(-1, 1)
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                # Backward pass and optimize
                loss.backward()
                optimizer.step()
            
            # Statistics
            running_loss += loss.item() * inputs.size(0)
            preds = (outputs > 0.5).float()
            running_corrects += torch.sum(preds == labels).item()
            total += labels.size(0)
        
        epoch_loss = running_loss / total
        epoch_acc = running_corrects / total
        
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)
        
        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        
        # Validation phase
        model.eval()
        running_loss = 0.0
        running_corrects = 0
        total = 0
        
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.float().to(device).view(-1, 1)
            
            # Forward pass
            with torch.no_grad():
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            
            # Statistics
            running_loss += loss.item() * inputs.size(0)
            preds = (outputs > 0.5).float()
            running_corrects += torch.sum(preds == labels).item()
            total += labels.size(0)
        
        epoch_loss = running_loss / total
        epoch_acc = running_corrects / total
        
        history['val_loss'].append(epoch_loss)
        history['val_acc'].append(epoch_acc)
        
        print(f'Val Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
    
    return model, history


In [19]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Train model
num_epochs = 5
model, history = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=num_epochs, device=device)

Epoch 1/5
----------
Train Loss: 0.0000 Acc: 0.9714
Val Loss: 0.0000 Acc: 0.7895
Epoch 2/5
----------
Train Loss: 0.0000 Acc: 0.9714
Val Loss: 0.0000 Acc: 0.8421
Epoch 3/5
----------
Train Loss: 0.0000 Acc: 0.9857
Val Loss: 0.0000 Acc: 0.8947
Epoch 4/5
----------
Train Loss: 0.0000 Acc: 0.9714
Val Loss: 0.0000 Acc: 0.8947
Epoch 5/5
----------
Train Loss: 0.0000 Acc: 0.9714
Val Loss: 0.0000 Acc: 0.9474


# 6. MODEL VALIDATION

In [None]:
def validate_model_across_skin_types(model, test_loader, device='cpu'):
    """
    Evaluate model performance across different Fitzpatrick skin types
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_skin_types = []
    
    with torch.no_grad():
        for images, labels, skin_types, _, _ in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_skin_types.extend(skin_types.numpy())
    
    # Convert to numpy arrays
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_skin_types = np.array(all_skin_types)
    
    # Calculate overall accuracy
    overall_accuracy = np.mean(all_preds == all_labels)
    print(f"Overall Accuracy: {overall_accuracy:.4f}")
    
    # Calculate accuracy by skin type
    print("\nAccuracy by Fitzpatrick Skin Type:")
    skin_type_metrics = {}
    for skin_type in range(1, 7):  # Fitzpatrick types 1-6
        mask = all_skin_types == skin_type
        if np.sum(mask) > 0:  # Only calculate if we have samples of this skin type
            skin_type_acc = np.mean(all_preds[mask] == all_labels[mask])
            skin_type_metrics[skin_type] = skin_type_acc
            print(f"  Type {skin_type}: {skin_type_acc:.4f} (n={np.sum(mask)})")
    
    # Calculate accuracy for light vs dark skin
    light_mask = all_skin_types <= 3
    dark_mask = all_skin_types >= 4
    light_acc = np.mean(all_preds[light_mask] == all_labels[light_mask])
    dark_acc = np.mean(all_preds[dark_mask] == all_labels[dark_mask])
    
    print("\nAccuracy by Skin Group:")
    print(f"  Light Skin (Types 1-3): {light_acc:.4f} (n={np.sum(light_mask)})")
    print(f"  Dark Skin (Types 4-6): {dark_acc:.4f} (n={np.sum(dark_mask)})")
    
    return {
        'overall_accuracy': overall_accuracy,
        'skin_type_metrics': skin_type_metrics,
        'light_skin_accuracy': light_acc,
        'dark_skin_accuracy': dark_acc
    }
