In [1]:
#!/usr/bin/env python
# coding: utf-8

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc

# Reproducibility: drive NumPy's and PyTorch's global RNGs from one shared seed
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Image directory, given relative to the notebook's working directory
DATA_PATH = "../data/processed/images"

In [None]:
# 1. Data Preparation

class SkinTypeDataset(torch.utils.data.Dataset):
    """Image dataset that yields a binary malignancy label plus skin-type metadata.

    The metadata CSV must provide the columns ``image_path``,
    ``three_partition_label`` and ``fitzpatrick_scale``; a ``label`` column is
    additionally required when ``condition`` is given.

    Two derived columns are added to ``self.data_frame``:

    * ``binary_label`` -- 1 where ``three_partition_label == 'malignant'``, else 0.
    * ``skin_group``   -- ``'light'`` where ``fitzpatrick_scale <= 3``, else ``'dark'``.

    Args:
        csv_file: Path (or file-like object) accepted by ``pandas.read_csv``.
        root_dir: Directory used as a fallback base for ``image_path`` entries
            that cannot be found as written.
        transform: Optional callable applied to the loaded image array.
        condition: Optional value; when truthy, only rows whose ``label``
            equals it are kept.
    """

    def __init__(self, csv_file, root_dir, transform=None, condition=None):
        frame = pd.read_csv(csv_file)

        # Filter by condition if specified. The explicit copy detaches the subset
        # from the full frame, so the column assignments below operate on an
        # independent object and do not raise pandas' SettingWithCopyWarning.
        if condition:
            frame = frame[frame['label'] == condition].copy()

        self.root_dir = root_dir
        self.transform = transform

        # Convert three_partition_label to binary (malignant vs. non-malignant)
        frame['binary_label'] = (
            frame['three_partition_label'] == 'malignant'
        ).astype('int64')

        # Group skin types into light (1-3) and dark (4-6).
        # NOTE(review): missing values (NaN) land in 'dark' because NaN <= 3 is
        # False, and any value <= 3 (including a negative sentinel, if the data
        # uses one) lands in 'light' -- confirm such rows are filtered upstream.
        frame['skin_group'] = np.where(
            frame['fitzpatrick_scale'] <= 3, 'light', 'dark'
        )

        self.data_frame = frame

    def __len__(self):
        """Return the number of rows remaining after the optional filter."""
        return len(self.data_frame)

    def _resolve_image_path(self, image_path):
        """Return the path to read for one ``image_path`` entry.

        The entry is used exactly as written whenever it is absolute or exists.
        Only when it cannot be found is it retried relative to ``root_dir``;
        if that also fails the original entry is returned so the caller sees
        the same error as before.
        """
        path = str(image_path)
        if self.root_dir and not os.path.isabs(path) and not os.path.exists(path):
            candidate = os.path.join(self.root_dir, path)
            if os.path.exists(candidate):
                return candidate
        return path

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            tuple: ``(image, binary_label, skin_type, skin_group)`` where
            ``skin_type`` is the row's ``fitzpatrick_scale`` value and
            ``skin_group`` is 0 for ``'light'`` and 1 for ``'dark'``.
        """
        # Materialise the row once instead of once per field.
        row = self.data_frame.iloc[idx]

        # NOTE(review): per the matplotlib docs, imread returns float arrays in
        # [0, 1] for PNG and uint8 for other formats -- confirm the transform
        # pipeline handles both.
        image = plt.imread(self._resolve_image_path(row['image_path']))

        # Normalise the channel layout to RGB
        if image.ndim == 2:
            # Grayscale: replicate the single channel three times
            image = np.stack((image,) * 3, axis=-1)
        elif image.ndim == 3 and image.shape[-1] == 4:
            # RGBA: drop the alpha channel
            image = np.ascontiguousarray(image[..., :3])

        # Get labels
        binary_label = row['binary_label']
        skin_type = row['fitzpatrick_scale']
        skin_group = 0 if row['skin_group'] == 'light' else 1

        if self.transform:
            image = self.transform(image)

        return image, binary_label, skin_type, skin_group

In [None]:
# TODO: work in progress; to be continued later with MATTA