In [None]:
import io
import os
import ast
import warnings
from tqdm import tqdm

import numpy as np
import pandas as pd
from PIL import Image
import cv2

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data import Dataset, DataLoader, random_split, WeightedRandomSampler


from sklearn.metrics import classification_report, precision_recall_fscore_support, accuracy_score, hamming_loss
from sklearn.preprocessing import MultiLabelBinarizer, OneHotEncoder, LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Seed PyTorch's global RNG so runs are repeatable.
torch.manual_seed(42)

# Suppress every Python warning for the remainder of the notebook.
warnings.filterwarnings('ignore')

## Create a PyTorch Dataset Class for the SCIN Dataset

In [None]:
class SCINDataset(Dataset):
    """Multi-label image dataset assembled from the SCIN case and label CSV files.

    Every row of the cleaned table is one image: a case contributes one row per
    populated ``image_N_path`` column, each paired with the dermatologist
    condition labels of that case. Labels that occur on fewer than
    ``min_label_count`` images are discarded, and images left without any label
    are dropped.

    Attributes set by ``__init__``:
        df: cleaned long-format table with columns ``case_id``, ``labels``,
            ``weighted_labels`` and ``paths``.
        encoded_labels: float32 multi-hot tensor, one row per image and one
            column per entry of ``all_labels``.
        all_labels: class names in the column order of ``encoded_labels``.
        class_weights: per-class weight ``N / (N_labels * N_i)`` listed in the
            same order as ``all_labels``.
        transform: torchvision pipeline applied to every image served.
        paths: image file paths, index-aligned with ``encoded_labels``.
    """

    # Images referenced by the metadata but absent from the image folder.
    MISSING_IMAGE_PATHS = ("dataset/images/-2243186711511406658.png",)

    def __init__(self, cases_csv='dataset/scin_cases.csv', labels_csv='dataset/scin_labels.csv', min_label_count=50):
        """Load the metadata, clean it and prepare labels, weights and transforms.

        Args:
            cases_csv: path to the SCIN cases CSV (must contain ``case_id`` and
                the three ``image_N_path`` columns).
            labels_csv: path to the SCIN labels CSV (must contain ``case_id``
                and the two label columns).
            min_label_count: minimum number of images a label must appear on
                to be kept as a class.
        """
        self.cases_csv = cases_csv
        self.labels_csv = labels_csv
        self.image_path_columns = ['image_1_path', 'image_2_path', 'image_3_path']
        self.weighted_skin_condition_label = "weighted_skin_condition_label"
        self.skin_condition_label = "dermatologist_skin_condition_on_label_name"

        # Read SCIN metadata, keeping case ids as strings so the join is exact
        cases_df = pd.read_csv(self.cases_csv, dtype={'case_id': str})
        cases_df['case_id'] = cases_df['case_id'].astype(str)
        labels_df = pd.read_csv(self.labels_csv, dtype={'case_id': str})
        labels_df['case_id'] = labels_df['case_id'].astype(str)

        # Merge the two metadata files on CASE ID
        cases_df = pd.merge(cases_df, labels_df, on='case_id')

        # The label columns are stored as Python literals in text form; parse them
        for column in (self.weighted_skin_condition_label, self.skin_condition_label):
            cases_df[column] = cases_df[column].apply(ast.literal_eval)

        # Wide -> long: one row per (case, image path); the 'variable' column
        # produced by melt only records which image_N_path the row came from.
        id_columns = ['case_id', self.skin_condition_label, self.weighted_skin_condition_label]
        df = pd.melt(
            cases_df[id_columns + self.image_path_columns],
            id_vars=id_columns,
            value_vars=self.image_path_columns,
        ).drop(columns='variable')

        # Drop rows where there is no image (explicit copy: we assign to it below)
        df = df[df['value'].notna()].copy()

        # Only keep labels with at least `min_label_count` occurrences
        label_counts = df[self.skin_condition_label].explode().dropna().value_counts()
        selected_labels = set(label_counts[label_counts >= min_label_count].index)
        df[self.skin_condition_label] = df[self.skin_condition_label].apply(
            lambda labels: sorted(set(labels).intersection(selected_labels))
        )

        # Drop rows pointing at image files known to be missing
        df = df[~df['value'].isin(self.MISSING_IMAGE_PATHS)]

        # Drop rows where no label survived the frequency filter
        has_label = df[self.skin_condition_label].map(len) > 0
        self.df = df[has_label].reset_index(drop=True).rename(columns={
            "dermatologist_skin_condition_on_label_name": "labels",
            "weighted_skin_condition_label": "weighted_labels",
            "value": "paths",
        })

        # Multi-hot encoding of the labels
        one_hot = MultiLabelBinarizer()
        self.encoded_labels = torch.from_numpy(one_hot.fit_transform(self.df["labels"])).to(torch.float32)
        self.all_labels = one_hot.classes_

        # Weights to handle class imbalance: N / (N_labels * N_i).
        # The per-class counts are taken from the encoded matrix so that
        # class_weights[k] always belongs to all_labels[k]; value_counts()
        # orders by frequency, which does not match the encoder's class order.
        n_samples = len(self.df)
        n_classes = len(self.all_labels)
        per_class_counts = self.encoded_labels.sum(dim=0).tolist()
        self.class_weights = [
            round(n_samples / (n_classes * count), 10) for count in per_class_counts
        ]

        # Resize + augmentation + ImageNet normalisation.
        # NOTE(review): the random flips/invert run for every sample this
        # dataset serves, including any validation/test subset carved from it.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomInvert(p=0.4),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        self.paths = list(self.df["paths"])

    def __len__(self):
        '''
        Return the number of images in the dataset
        '''
        return len(self.encoded_labels)

    def __getitem__(self, idx):
        '''
        Return one sample as ``(path, (image, label))``.

        ``image`` is the transformed image and ``label`` is the matching row
        of ``encoded_labels``.
        '''
        path = self.paths[idx]
        # Open inside a context manager so the file handle is released, and
        # force 3 channels: grayscale, palette and RGBA files would otherwise
        # not match the 3-channel normalisation.
        with Image.open(path) as handle:
            image = handle.convert("RGB")

        image = self.transform(image)
        return path, (image, self.encoded_labels[idx])

## Split the dataset into Train/Validation/Test

In [None]:
scin_dataset = SCINDataset()
data_len = len(scin_dataset)
print("Total number of images: ", data_len)

# 500 images each go to validation and test; everything else is used for training.
# NOTE(review): the split is drawn per image, so images belonging to the same
# case can end up in different splits.
split_sizes = (data_len - 1000, 500, 500)
split_rng = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(scin_dataset, split_sizes, generator=split_rng)

# Settings shared by all three loaders
loader_kwargs = dict(batch_size=64, num_workers=8, pin_memory=True)

# Only the training loader shuffles and drops the trailing partial batch
train_loader = DataLoader(dataset=train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

## Training and Validation Performance

In [None]:
# Optimisation settings
EPOCHS = 50
LEARNING_RATE = 1e-4

# A label counts as predicted when its sigmoid output exceeds this value
THRESHOLD = 0.4

# One output logit per label kept by the dataset
NUM_CLASSES = len(scin_dataset.all_labels)

In [None]:
# Load a pre-trained ResNet50; its classification head is replaced just below
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# Change final output to number of classes
num_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_features, NUM_CLASSES),
)

# Move model to GPU if available
model = model.to(device)

# Apply Weighted BCE Loss (Sigmoid included for numeric stability)
class_weights_tensor = torch.tensor(scin_dataset.class_weights, dtype=torch.float32).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights_tensor)

# Define Optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-8)

# Train and Validate
print(f"Training for {EPOCHS} Epochs")
train_losses, val_losses = [], []
for epoch in range(EPOCHS):
    # ---- Training pass ----
    model.train()
    running_loss = 0.0
    # The dataset yields (path, (image, label)); the paths are not needed here
    for _, (inputs, labels) in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward + optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch training losses for this epoch
    train_loss = running_loss / len(train_loader)
    train_losses.append(train_loss)

    # ---- Validation pass: loss and predictions on the validation set ----
    model.eval()
    running_val_loss = 0.0
    y_true, y_pred = [], []
    with torch.no_grad():
        for _, (val_inputs, val_labels) in val_loader:
            val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
            val_outputs = model(val_inputs)
            running_val_loss += criterion(val_outputs, val_labels).item()
            y_true.append(val_labels.cpu().numpy())
            y_pred.append(torch.sigmoid(val_outputs).cpu().numpy())
    running_val_loss /= len(val_loader)
    val_losses.append(running_val_loss)

    # Binarise the predicted probabilities with the decision threshold
    y_true = np.concatenate(y_true)
    y_pred = (np.concatenate(y_pred) > THRESHOLD).astype(int)

    # Performance metrics on the validation set.
    # For multi-label input accuracy_score is the exact-match (subset) accuracy.
    accuracy = accuracy_score(y_true, y_pred)
    macro_score = precision_recall_fscore_support(y_true, y_pred, average="macro")
    micro_score = precision_recall_fscore_support(y_true, y_pred, average="micro")
    weighted_score = precision_recall_fscore_support(y_true, y_pred, average="weighted")

    print(f"Epoch [{epoch+1}/{EPOCHS}], Training Loss: {train_loss}, Validation Loss: {running_val_loss:.4f}, \n \n Accuracy: {accuracy} \n Macro Scores: {macro_score} \n Micro Scores: {micro_score} \n Weighted Scores: {weighted_score}\n ")

print("Finished Training")

## Visualize the Training and Validation Losses

In [None]:
# Per-epoch training and validation loss curves on a single axis
fig, ax = plt.subplots(figsize=(10, 5))
for series, series_label in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
    ax.plot(series, label=series_label)
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('Training and Validation Loss')
plt.show()

## Perform inference on Test Set

In [None]:
# Evaluate the trained model on the held-out test set
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    # The dataset yields (path, (image, label)); the paths are not needed here
    for _, (test_inputs, test_labels) in test_loader:
        test_inputs, test_labels = test_inputs.to(device), test_labels.to(device)
        test_outputs = model(test_inputs)
        y_true.append(test_labels.cpu().numpy())
        y_pred.append(torch.sigmoid(test_outputs).cpu().numpy())

# Binarise the predicted probabilities with the same threshold used in validation
y_true = np.concatenate(y_true)
y_pred = (np.concatenate(y_pred) > THRESHOLD).astype(int)

macro_score = precision_recall_fscore_support(y_true, y_pred, average="macro")
micro_score = precision_recall_fscore_support(y_true, y_pred, average="micro")
weighted_score = precision_recall_fscore_support(y_true, y_pred, average="weighted")

print(f"Macro Scores: {macro_score} \n Micro Scores: {micro_score} \n Weighted Scores: {weighted_score}\n ")

## Save the model for future use

In [None]:
# Serialise the whole model object (architecture and weights) to disk.
torch.save(obj=model, f='model.bin')