In [None]:
# implementation of fine-tuned ResNet-18 based on code from
# https://github.com/IBM/star-ed/blob/main/segmentation_and_skintone_classification.ipynb

In [None]:
import numpy as np
import pandas as pd
import random
import torch
from torch import Tensor
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import Dataset
from torchvision.io import read_image
from skimage import io, color
import os
from PIL import Image
import time, copy
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
import torchvision

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

In [None]:
import zipfile
# Path to the zipped image arrays in your Google Drive
source_path = r'drive/MyDrive/Edinburgh/MLP/MLPcoursework4/image_arrays.zip'
# Destination on the local VM disk. `data_path` (next cell) is the relative
# "tmp/image_arrays/", so create and extract into that same relative directory.
extract_dir = "tmp/image_arrays"
# exist_ok=True keeps the cell re-runnable (a bare os.mkdir raises if it exists).
os.makedirs(extract_dir, exist_ok=True)
# The context manager closes the archive once extraction has finished.
with zipfile.ZipFile(source_path, "r") as arr_zip:
    arr_zip.extractall(extract_dir)

In [None]:
# Directory prefix for the extracted per-image .npy arrays, and the metadata table.
data_path = "tmp/image_arrays/"
df = pd.read_csv(
    "drive/MyDrive/Edinburgh/MLP/MLPcoursework4/fitzpatrick17k_filtered.csv"
)

In [None]:
# useful variables
img_height = 256
img_width = 256
batch_size = 32
n_channels = 3
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed every RNG this notebook draws from: Python's `random` (class
# oversampling in BalancedBatchSampler), NumPy, and torch (random
# augmentations and the freshly initialised classifier layer). This makes runs
# repeatable; GPU kernels may still add small nondeterminism.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# split dataset on the `validation` column: rows flagged 0 train, rows flagged 1 validate
df = pd.read_csv("drive/MyDrive/Edinburgh/MLP/MLPcoursework4/fitzpatrick17k_filtered.csv")
train_df, valid_df = (
    df.loc[df["validation"] == flag].reset_index(drop=True) for flag in (0, 1)
)
len_train, len_valid = len(train_df), len(valid_df)
# validation batches per epoch, rounded up (a float, as returned by np.ceil)
batches_per_valid_epoch = np.ceil(len_valid / batch_size)

In [None]:
# segmentation test
class SkinDetector(object):
    """Simple skin segmentation for images both with and without lesions.

    Pipeline (see ``find_skin``): threshold the image in HSV and YCbCr colour
    space, then refine the thresholded mask with erosion/dilation and OpenCV's
    marker-based Watershed.

    Args:
        image: image array to segment. The OpenCV calls used here expect an
            8-bit, 3-channel image (assumed uint8 HxWx3 -- TODO confirm for the
            .npy arrays this notebook loads).
        colour_order: channel order of ``image``: "BGR" (default; the order
            ``cv2.imread`` produces) or "RGB". The colour thresholds are only
            meaningful with the right order, because an RGB array treated as
            BGR has red and blue swapped.

    Raises:
        ValueError: if ``image`` is None or ``colour_order`` is unsupported.
    """

    def __init__(self, image, colour_order="BGR"):
        self.image = image  # np.load(imageName) / cv2.imread(imageName)
        if self.image is None:
            # Raise rather than exit(): exiting stops the interpreter (or the
            # notebook kernel) instead of letting the caller handle the error.
            raise ValueError("Image Not Found")
        if colour_order == "BGR":
            to_hsv, to_ycrcb = cv2.COLOR_BGR2HSV, cv2.COLOR_BGR2YCR_CB
        elif colour_order == "RGB":
            to_hsv, to_ycrcb = cv2.COLOR_RGB2HSV, cv2.COLOR_RGB2YCrCb
        else:
            raise ValueError(
                f"colour_order must be 'BGR' or 'RGB', got {colour_order!r}"
            )
        self.HSV_image = cv2.cvtColor(self.image, to_hsv)
        # OpenCV's YCrCb conversion orders the channels Y, Cr, Cb.
        self.YCbCr_image = cv2.cvtColor(self.image, to_ycrcb)
        # Placeholder; color_segmentation() replaces it with the threshold mask.
        self.binary_mask_image = self.HSV_image

    def find_skin(self):
        """Segment skin with HSV/YCbCr thresholds followed by Watershed.

        Returns:
            tuple ``(image_mask, output)`` from ``region_based_segmentation``:
            the binary skin mask, and the input image with pixels outside the
            mask set to 0.
        """
        self.color_segmentation()
        image_mask = self.region_based_segmentation()

        return image_mask

    def color_segmentation(self):
        """Threshold the HSV and YCbCr images into ``self.binary_mask_image``.

        The ranges were chosen from published skin-detection work plus
        empirical tests and visual evaluation.
        """
        lower_HSV_values = np.array([0, 40, 0], dtype="uint8")
        upper_HSV_values = np.array([25, 255, 255], dtype="uint8")

        lower_YCbCr_values = np.array((0, 138, 67), dtype="uint8")
        upper_YCbCr_values = np.array((255, 173, 133), dtype="uint8")

        mask_YCbCr = cv2.inRange(
            self.YCbCr_image, lower_YCbCr_values, upper_YCbCr_values
        )
        mask_HSV = cv2.inRange(self.HSV_image, lower_HSV_values, upper_HSV_values)
        # Both masks are 0/255 uint8; cv2.add saturates at 255, so this is their union.
        self.binary_mask_image = cv2.add(mask_HSV, mask_YCbCr)

    def region_based_segmentation(self):
        """Refine ``binary_mask_image`` with morphology and Watershed.

        Builds a marker image from the threshold mask: the eroded mask is 255,
        the rest of the dilated mask is 128, and everything else is 0
        (unlabelled). It runs Watershed over ``self.image`` and Otsu-thresholds
        the resulting labels back into a binary mask.

        Returns:
            tuple ``(image_mask, output)``: the uint8 binary mask, and
            ``self.image`` with pixels outside the mask set to 0.
        """
        image_foreground = cv2.erode(self.binary_mask_image, None, iterations=3)
        dilated_binary_image = cv2.dilate(self.binary_mask_image, None, iterations=3)
        # every dilated-mask pixel above 1 becomes 128
        ret, image_background = cv2.threshold(
            dilated_binary_image, 1, 128, cv2.THRESH_BINARY
        )
        # saturating add: eroded core -> 255, remaining dilated band -> 128, outside -> 0
        image_marker = cv2.add(image_foreground, image_background)
        # cv2.watershed needs int32 markers and writes its labels into them in place
        image_marker32 = np.int32(image_marker)

        cv2.watershed(self.image, image_marker32)
        # back to uint8 (absolute value + saturation); watershed boundaries (-1) become 1
        m = cv2.convertScaleAbs(image_marker32)

        # bitwise of the mask with the input image
        ret, image_mask = cv2.threshold(m, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        output = cv2.bitwise_and(self.image, self.image, mask=image_mask)

        return image_mask, output

In [None]:
# create dataset class
class CustomSkinDataset(Dataset):
    """Fitzpatrick-17k images stored as one ``.npy`` array per image.

    Each item is ``(image, label)``. ``label`` is the row's
    ``fitzpatrick_scale`` minus 1, as a ``torch.long`` tensor. The notebook
    prints ``label + 1`` as the skin type, so the CSV presumably stores 1..6.

    Args:
        img_dir: path prefix; the file read is ``img_dir + md5hash + ".npy"``.
        csv: DataFrame with ``md5hash`` and ``fitzpatrick_scale`` columns,
            indexed 0..n-1 (lookups use ``idx`` as an index label).
        transforms: optional callable applied to ``PIL.Image.fromarray(img)``.
        colour_space: "RGB" (no conversion), "HSV" or "Lab". Arrays are
            treated as RGB.
        mask: if True, pixels that ``SkinDetector`` does not classify as skin
            are set to 0 before any colour conversion.

    Raises:
        ValueError: if ``colour_space`` is not one of the supported values.
    """

    _COLOUR_SPACES = ("RGB", "HSV", "Lab")

    def __init__(self, img_dir, csv, transforms=None, colour_space="HSV", mask=False):
        if colour_space not in self._COLOUR_SPACES:
            # A typo (e.g. "hsv") would otherwise silently skip the conversion.
            raise ValueError(
                f"colour_space must be one of {self._COLOUR_SPACES}, got {colour_space!r}"
            )
        self.img_dir = img_dir
        self.colour_space = colour_space
        self.csv = csv
        self.transforms = transforms
        self.mask = mask

    def __len__(self):
        return len(self.csv)

    @staticmethod
    def _apply_skin_mask(img):
        """Return ``img`` with non-skin pixels set to 0, using SkinDetector."""
        # SkinDetector's default conversions assume BGR, while these arrays are
        # treated as RGB, so hand it a BGR copy and apply its mask to `img`.
        skin_mask, _ = SkinDetector(cv2.cvtColor(img, cv2.COLOR_RGB2BGR)).find_skin()
        return cv2.bitwise_and(img, img, mask=skin_mask)

    def __getitem__(self, idx):
        img = np.load(self.img_dir + self.csv.at[idx, 'md5hash'] + ".npy")
        # Grayscale arrays get three identical channels.
        if img.ndim < 3:
            img = color.gray2rgb(img)
        if self.mask:
            img = self._apply_skin_mask(img)
        # The "HSV" branch used to convert to Lab; each name now maps to its own space.
        if self.colour_space == "HSV":
            img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
        elif self.colour_space == "Lab":
            img = cv2.cvtColor(img, cv2.COLOR_RGB2Lab)
        if self.transforms:
            img = self.transforms(Image.fromarray(img))
        # 0-based class id for CrossEntropyLoss.
        label = self.csv.at[idx, "fitzpatrick_scale"] - 1
        return img, torch.tensor(label, dtype=torch.long)

In [None]:
# balanced batch sampler
class BalancedBatchSampler(torch.utils.data.sampler.Sampler):
    """
    A pytorch dataset sampler to obtain balanced batches.
    Implementation from
    https://github.com/galatolofederico/pytorch-balanced-batch

    Every class is oversampled (with replacement, via ``random.choice``) up to
    the size of the largest class. Iteration then yields indices round-robin,
    one per class in turn, so a full pass has ``balanced_max * n_classes``
    indices.

    Args:
        dataset: anything supporting ``len()``. When ``labels`` is None it must
            also expose a ``csv`` table with a ``fitzpatrick_scale`` column, as
            ``CustomSkinDataset`` does.
        labels: optional sequence of per-index labels aligned with ``dataset``.
    """

    def __init__(self, dataset, labels=None):
        self.labels = labels
        # label -> list of dataset indices with that label
        self.dataset = dict()
        # Save all the indices for all the classes
        for idx in range(len(dataset)):
            self.dataset.setdefault(self._get_label(dataset, idx), []).append(idx)
        # size of the largest class (0 for an empty dataset)
        self.balanced_max = max((len(v) for v in self.dataset.values()), default=0)
        # Oversample the classes with fewer elements than the max, creates balanced classes
        for label in self.dataset:
            while len(self.dataset[label]) < self.balanced_max:
                self.dataset[label].append(random.choice(self.dataset[label]))
        self.keys = list(self.dataset.keys())
        # Retained for compatibility; __iter__ keeps its cursor locally.
        self.currentkey = 0
        self.indices = [-1] * len(self.keys)

    def __iter__(self):
        # The cursor lives in this generator, not on the sampler. Abandoning
        # an iterator early (e.g. `break` after the first batch) therefore
        # cannot leave state behind that shortens the next epoch.
        for position in range(self.balanced_max):
            for key in self.keys:
                yield self.dataset[key][position]

    def _get_label(self, dataset, idx):
        if self.labels is not None:
            return self.labels[idx]
        # The sampler has no table of its own; read the label from the dataset.
        return dataset.csv["fitzpatrick_scale"][idx]

    def __len__(self):
        return self.balanced_max * len(self.keys)

In [None]:
def print_model_parameters(model):
    """Print the total, trainable and non-trainable parameter counts of ``model``.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    """
    # element counts bucketed by requires_grad (True = trainable)
    counts = {True: 0, False: 0}
    for param in model.parameters():
        counts[param.requires_grad] += param.numel()
    total = counts[True] + counts[False]
    print(f"Total Parameters: {total}")
    print(f"Trainable Parameters: {counts[True]}")
    print(f"Non-Trainable Parameters: {counts[False]}")

In [None]:
# Training: random crop/rescale to 224 plus horizontal and vertical flips.
# Validation: resize only. Both normalise each channel with mean = std = 0.5,
# mapping [0, 1] pixels to [-1, 1].
# NOTE(review): validation images are not cropped, so batching them assumes every
# resized image ends up the same shape (e.g. square source arrays) -- confirm.
train_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)
valid_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

# Training batches come from the class-balancing sampler (every Fitzpatrick
# class oversampled to the largest one); validation keeps the natural order.
train_dataset = CustomSkinDataset(data_path, train_df, train_transforms, colour_space="RGB")
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=2,
    sampler=BalancedBatchSampler(
        train_dataset, labels=np.array(train_df["fitzpatrick_scale"])
    ),
)
valid_dataset = CustomSkinDataset(data_path, valid_df, valid_transforms, colour_space="RGB")
valid_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=batch_size, shuffle=False, num_workers=2
)

In [None]:
# Preview one training batch: undo Normalize(0.5, 0.5) to bring pixels back to
# [0, 1], tile the batch into a grid (C, H, W), show it as (H, W, C), and print
# the 1-based Fitzpatrick labels.
for img, label in train_loader:
    plt.imshow(
        torchvision.utils.make_grid(img.add(1).mul(0.5))
        .clamp(0, 1)
        .cpu()
        .permute(1, 2, 0),
        cmap=plt.cm.binary,
    )
    plt.show()
    print(label + 1)
    break

In [None]:
# Fine-tune a pretrained ResNet-18: swap its final fully connected layer for a
# fresh 6-way classifier, one output per Fitzpatrick skin type.
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(in_features=num_ftrs, out_features=6)
# Module.to() moves the parameters and returns the same module, so
# baseline_model and model_ft refer to one object.
baseline_model = model_ft.to(device)

loss_function = nn.CrossEntropyLoss()
optimiser = optim.SGD(params=baseline_model.parameters(), lr=0.001, momentum=0.9)
# NOTE: train() steps this scheduler once per epoch, so with step_size=10 and a
# 10-epoch run the x0.1 decay only takes effect after the last epoch.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimiser, step_size=10, gamma=0.1)
print_model_parameters(baseline_model)

In [None]:
def train(model, optim, scheduler, train_loader, loss_function, epoch, batch_size):
    """Run one training epoch, then step the learning-rate scheduler once.

    Args:
        model: network to optimise; switched to training mode here.
        optim: optimiser over ``model``'s parameters.
        scheduler: LR scheduler, stepped once after the epoch.
        train_loader: iterable of ``(img, label)`` batches.
        loss_function: criterion called as ``loss_function(output, label)``.
        epoch: epoch number, used only in the progress bar and the log line.
        batch_size: unused; kept so existing calls keep working.
    """
    model.train()
    n_batches = len(train_loader)
    running_loss = 0
    for batch_imgs, batch_labels in tqdm(train_loader, desc=f'Epoch {epoch}', total=n_batches):
        batch_imgs, batch_labels = batch_imgs.to(device), batch_labels.to(device)
        optim.zero_grad()
        batch_loss = loss_function(model(batch_imgs), batch_labels)
        batch_loss.backward()
        optim.step()
        running_loss += batch_loss.item()
    scheduler.step()
    # mean of the per-batch losses
    print(f"Epoch: {epoch} ||| Training loss: {running_loss / n_batches}")

In [None]:
def validate(model, valid_loader, loss_function, epoch, valid_length, batch_size):
    """Evaluate ``model`` on ``valid_loader`` and print loss and accuracies.

    Args:
        model: network to evaluate; switched to eval mode here.
        valid_loader: iterable of ``(img, label)`` batches.
        loss_function: criterion called as ``loss_function(output, label)``.
        epoch: epoch number, used only for the progress-bar label.
        valid_length: number of validation samples; the denominator of both accuracies.
        batch_size: unused; kept so existing calls keep working.

    Returns:
        ``(acc, acc_give1)`` as Python floats: the fraction of samples whose
        predicted class equals the label, and the fraction whose prediction is
        at most one class away from it.
    """
    model.eval()
    total_loss = 0
    num_correct, num_give1_correct = 0, 0
    # Evaluation needs no gradients. Without no_grad every forward pass would
    # build an autograd graph, wasting memory and time.
    with torch.no_grad():
        for img, label in tqdm(valid_loader, desc=f'Epoch {epoch}', total=len(valid_loader)):
            img = img.to(device)
            label = label.to(device)

            output = model(img)
            preds = torch.argmax(output, dim=1)
            # .item() keeps the tallies as Python ints rather than device tensors.
            # The returned accuracies are then plain floats: they print cleanly,
            # compare normally, and load from a checkpoint without a GPU.
            num_correct += (preds == label).sum().item()
            # Classes are ordinal, so "within one class" is |pred - label| <= 1.
            num_give1_correct += ((preds - label).abs() <= 1).sum().item()

            total_loss += loss_function(output, label).item()

    avg_loss = total_loss / len(valid_loader)
    acc = num_correct / valid_length
    acc_give1 = num_give1_correct / valid_length
    print(f"Validation loss: {avg_loss} ||| Validation accuracy: {acc} ||| Validation +-1 accuracy: {acc_give1}")
    return acc, acc_give1

In [None]:
# Train for `epochs` epochs. A checkpoint is written whenever validation accuracy
# beats the best seen so far, starting from the accuracy before any fine-tuning.
epochs = 10
best_acc, _ = validate(baseline_model, valid_loader, loss_function, 0, len_valid, batch_size)
for epoch in range(epochs):
    train(baseline_model, optimiser, exp_lr_scheduler, train_loader, loss_function, epoch + 1, batch_size)
    acc, acc_give1 = validate(baseline_model, valid_loader, loss_function, epoch + 1, len_valid, batch_size)
    if not (acc > best_acc):
        continue
    torch.save(
        {
            'model': baseline_model.state_dict(),
            'optimiser': optimiser.state_dict(),
            'scheduler': exp_lr_scheduler.state_dict(),
            'acc': acc,
            'acc_give1': acc_give1,
            'epoch': epoch + 1,
        },
        f"drive/MyDrive/Edinburgh/MLP/MLPcoursework4/models/baseline_epoch-{epoch + 1}.chkpt",
    )
    best_acc = acc

In [None]:
# Rebuild the baseline architecture and restore saved weights.
# NOTE(review): the epoch in this path is hard-coded. The training loop only
# writes a checkpoint when validation accuracy improves, so point this at the
# best epoch of your own run.
checkpoint_path = "drive/MyDrive/Edinburgh/MLP/MLPcoursework4/models/baseline_epoch-9.chkpt"
loaded_model = models.resnet18()
# Take the feature width from this model, not from the separately built model_ft.
num_ftrs = loaded_model.fc.in_features
loaded_model.fc = nn.Linear(num_ftrs, 6)
loaded_model.to(device)
# map_location lets a checkpoint written on GPU load on a CPU-only runtime.
checkpoint = torch.load(checkpoint_path, map_location=device)
loaded_model.load_state_dict(checkpoint['model'])
# Label the progress bar with the checkpoint's own epoch, not the `epoch`
# variable left over from the training loop.
_, _ = validate(loaded_model, valid_loader, loss_function, checkpoint['epoch'], len_valid, batch_size)

In [None]:
def bias_analysis(model, valid_loader, num_classes=6):
    """Print (and return) the per-skin-tone accuracy of ``model``.

    Args:
        model: classifier returning ``num_classes`` logits per image.
        valid_loader: iterable of ``(img, label)`` batches whose labels are
            0-based integer classes in ``[0, num_classes)``.
        num_classes: number of skin-tone classes (6 Fitzpatrick types by default).

    Returns:
        np.ndarray of per-class accuracies; index i is skin tone i+1. A class
        with no samples gets nan.
    """
    model.eval()
    skintone_counts = np.zeros(num_classes)
    skintone_counts_correct = np.zeros(num_classes)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # Use a fixed progress label rather than the global `epoch`, which only
        # exists after the training loop has run.
        for img, label in tqdm(valid_loader, desc='Bias analysis', total=len(valid_loader)):
            preds = torch.argmax(model(img.to(device)), dim=1).cpu().numpy()
            label_arr = label.cpu().numpy()
            # minlength yields a full-length count vector even when a batch
            # lacks some classes.
            skintone_counts += np.bincount(label_arr, minlength=num_classes)
            skintone_counts_correct += np.bincount(
                label_arr[label_arr == preds], minlength=num_classes
            )
    print('\n')
    # A class with zero samples gives 0/0 = nan; silence the divide warning.
    with np.errstate(divide='ignore', invalid='ignore'):
        accuracies = skintone_counts_correct / skintone_counts
    for i in range(num_classes):
        print(f"\nAccuracy for skintone {i+1}: {accuracies[i]}    (Counts {skintone_counts[i]})")
    return accuracies

In [None]:
# Per-skin-tone accuracy of the reloaded baseline (trailing ';' suppresses cell output).
bias_analysis(model=loaded_model, valid_loader=valid_loader);

In [None]:
# with masking
# Same transforms, sampler, model and optimiser as the baseline. Both splits use
# colour_space="RGB" (as the baseline does) and mask=True. Masking is then the only
# difference from the baseline, and the model is validated on the same kind of
# input it is trained on.
train_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
valid_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

train_dataset = CustomSkinDataset(data_path, train_df, train_transforms, colour_space="RGB", mask=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=2, sampler=BalancedBatchSampler(train_dataset, labels=np.array(train_df["fitzpatrick_scale"])))
valid_dataset = CustomSkinDataset(data_path, valid_df, valid_transforms, colour_space="RGB", mask=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, shuffle=False, batch_size=batch_size, num_workers=2)

# NOTE: this rebinds `baseline_model`, `optimiser` and `exp_lr_scheduler` to a
# fresh network, so the baseline trained earlier is no longer reachable by name.
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 6)
baseline_model = model_ft.to(device)

loss_function = nn.CrossEntropyLoss()

optimiser = optim.SGD(baseline_model.parameters(), lr=0.001, momentum=0.9)

exp_lr_scheduler = lr_scheduler.StepLR(optimiser, step_size=10, gamma=0.1)

In [None]:
# Train the masked-input model, checkpointing whenever validation accuracy
# improves on the best seen so far (starting from the pre-training accuracy).
epochs = 10
best_acc, _ = validate(baseline_model, valid_loader, loss_function, 0, len_valid, batch_size)
for epoch in range(epochs):
    train(baseline_model, optimiser, exp_lr_scheduler, train_loader, loss_function, epoch + 1, batch_size)
    acc, acc_give1 = validate(baseline_model, valid_loader, loss_function, epoch + 1, len_valid, batch_size)
    if not (acc > best_acc):
        continue
    torch.save(
        {
            'model': baseline_model.state_dict(),
            'optimiser': optimiser.state_dict(),
            'scheduler': exp_lr_scheduler.state_dict(),
            'acc': acc,
            'acc_give1': acc_give1,
            'epoch': epoch + 1,
        },
        f"drive/MyDrive/Edinburgh/MLP/MLPcoursework4/models/maskedBaseline_epoch-{epoch + 1}.chkpt",
    )
    best_acc = acc

In [None]:
# Preview one (masked) training batch: undo Normalize(0.5, 0.5) to bring pixels
# back to [0, 1], tile the batch into a grid (C, H, W), show it as (H, W, C), and
# print the 1-based Fitzpatrick labels.
for img, label in train_loader:
    plt.imshow(
        torchvision.utils.make_grid(img.add(1).mul(0.5))
        .clamp(0, 1)
        .cpu()
        .permute(1, 2, 0),
        cmap=plt.cm.binary,
    )
    plt.show()
    print(label + 1)
    break