In [None]:
import numpy as np
import pandas as pd
import random
import torch
from torch import Tensor
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import Dataset
from torchvision.io import read_image
from skimage import io, color
import os
from PIL import Image
import time, copy
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
import torchvision
import math
import zipfile

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

In [None]:
# zip files into specified place on drive
# NOTE: this cell is disabled. The code below is wrapped in a bare triple-quoted string,
# so running the cell only evaluates the string (Jupyter echoes it as the cell output).
# It presumably packed the per-image .npy arrays into image_arrays.zip in batches of 50
# ('w' mode for the first batch, then 'a' to append); that zip path is the one the next
# cell extracts. Remove the surrounding quotes to re-enable it.
"""import glob

# Specify the folder containing .npy files and the destination ZIP file name
npy_folder_path = 'drive/MyDrive/Edinburgh/MLP/MLPcoursework4/image_arrays/'  # Update this path
zip_file_path = 'drive/MyDrive/Edinburgh/MLP/MLPcoursework4/image_arrays.zip'  # Update this path

# Find all .npy files within the folder
npy_files = glob.glob(os.path.join(npy_folder_path, '*.npy'))

batch = 0
batch_size = 50
num_batches = int(np.ceil(len(npy_files)/batch_size))
j = 0
for batch in range(num_batches):
    batch_files = npy_files[batch * batch_size:(batch * batch_size) + batch_size]
    if batch == 0:
        with zipfile.ZipFile(zip_file_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for file in batch_files:
                # Add file to zip, preserving its folder structure
                zipf.write(file, os.path.relpath(file, os.path.dirname(npy_folder_path)))
                if j % 10 == 0: print(j)
                j+=1
    else:
        with zipfile.ZipFile(zip_file_path, 'a', zipfile.ZIP_DEFLATED) as zipf:
            for file in batch_files:
                # Add file to zip, preserving its folder structure
                #print(file)
                zipf.write(file, os.path.relpath(file, os.path.dirname(npy_folder_path)))
                if j % 10 == 0: print(j)
                j+=1"""

In [None]:
# Path in your Google Drive
source_path = r'drive/MyDrive/Edinburgh/MLP/MLPcoursework4/image_arrays.zip'
# Destination on the local VM disk. It must match `data_path` used by the dataset
# ("tmp/image_arrays/", relative to the working directory). The old code created the
# absolute "/tmp/image_arrays" instead, which nothing reads.
extract_dir = "tmp/image_arrays"
# exist_ok keeps the cell re-runnable (plain os.mkdir raises FileExistsError the 2nd time).
os.makedirs(extract_dir, exist_ok=True)
# The context manager closes the archive once extraction finishes.
with zipfile.ZipFile(source_path, "r") as arr_zip:
    arr_zip.extractall(extract_dir)

In [None]:
# Image metadata (md5hash, fitzpatrick_scale, validation flag, ...) lives on Drive;
# the per-image .npy arrays are read from the local extraction folder below.
df = pd.read_csv(
    filepath_or_buffer="drive/MyDrive/Edinburgh/MLP/MLPcoursework4/fitzpatrick17k_filtered.csv"
)
data_path = r"tmp/image_arrays/"

In [None]:
# useful variables
img_height = 256
img_width = 256
batch_size  = 32
n_channels  = 3
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed the Python, NumPy and PyTorch RNGs before anything stochastic runs
# (sampler oversampling uses `random`, augmentations and weight init use torch),
# so ablation runs are comparable. This does not force deterministic cuDNN kernels.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Split the metadata on its `validation` flag: 0 -> training rows, 1 -> validation rows.
# Each split gets a fresh 0..n-1 index because CustomSkinDataset looks rows up by label.
df = pd.read_csv("drive/MyDrive/Edinburgh/MLP/MLPcoursework4/fitzpatrick17k_filtered.csv")
train_df, valid_df = (
    df.loc[df["validation"] == flag].copy().reset_index(drop=True)
    for flag in (0, 1)
)
len_train, len_valid = len(train_df), len(valid_df)
batches_per_valid_epoch = np.ceil(len_valid / batch_size)

In [None]:
# create dataset class
class CustomSkinDataset(Dataset):
    """Fitzpatrick17k images stored as one ``<md5hash>.npy`` array per image.

    Each item is ``(L, ab, label)``. ``L`` is channel 0 of the (converted,
    transformed) image with a leading channel dim kept, shape ``(1, H, W)``.
    ``ab`` holds the remaining channels, ``(2, H, W)`` for a 3-channel image.
    ``label`` is ``fitzpatrick_scale - 1`` as a ``torch.long`` scalar.

    Args:
        img_dir: directory containing the ``.npy`` files.
        csv: DataFrame with ``md5hash`` and ``fitzpatrick_scale`` columns. Rows are
            fetched with ``.at[idx, ...]``, so it needs a 0..len-1 index.
        transforms: optional callable applied to a ``PIL.Image``. It must return a
            ``(C, H, W)`` tensor, e.g. a Compose ending in ``ToTensor``. When omitted,
            the HWC array is converted to a CHW tensor directly.
        colour_space: ``"HSV"`` (historical name) or ``"Lab"`` converts RGB to Lab
            with OpenCV. Any other value leaves the pixels untouched.
    """

    # Both spellings select the RGB -> Lab conversion; "HSV" is kept for existing callers.
    _LAB_ALIASES = ("HSV", "Lab")

    def __init__(self, img_dir, csv, transforms=None, colour_space="HSV"):
        self.img_dir = img_dir
        self.colour_space = colour_space
        self.csv = csv
        self.transforms = transforms

    def __len__(self):
        return len(self.csv)

    def __getitem__(self, idx):
        # os.path.join works whether or not img_dir ends with a separator.
        img = np.load(os.path.join(self.img_dir, self.csv.at[idx, 'md5hash'] + ".npy"))
        if img.ndim < 3:
            # Grayscale array: replicate to 3 channels so the colour conversion applies.
            img = color.gray2rgb(img)
        if self.colour_space in self._LAB_ALIASES:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2Lab)
        if self.transforms:
            img = self.transforms(Image.fromarray(img))
        else:
            # Without a transform pipeline `img` is still an HWC numpy array;
            # make it a CHW tensor so the channel split below works.
            img = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)
        label = self.csv.at[idx, "fitzpatrick_scale"] - 1
        return img[0].unsqueeze(0), img[1:], torch.tensor(label, dtype=torch.long)

In [None]:
# balanced batch sampler
class BalancedBatchSampler(torch.utils.data.sampler.Sampler):
    """
    A pytorch dataset sampler to obtain balanced batches.
    Implementation adapted from
    https://github.com/galatolofederico/pytorch-balanced-batch

    Every class is oversampled (``random.choice`` with replacement) up to the size
    of the largest class. Indices are then yielded round-robin over the classes,
    so every run of ``len(self.keys)`` consecutive indices holds one item per class.

    Args:
        dataset: sized dataset. When ``labels`` is None it must expose a ``csv``
            DataFrame with a ``fitzpatrick_scale`` column (as CustomSkinDataset does).
        labels: optional indexable of per-item labels aligned with ``dataset``.
        shuffle: if True (default), each class's index list is reshuffled at the
            start of every epoch, so the order differs between epochs.
    """

    def __init__(self, dataset, labels=None, shuffle=True):
        self.labels = labels
        self.shuffle = shuffle
        self.dataset = dict()  # label -> list of dataset indices
        self.balanced_max = 0  # size of the largest class
        for idx in range(len(dataset)):
            label = self._get_label(dataset, idx)
            self.dataset.setdefault(label, []).append(idx)
            self.balanced_max = max(self.balanced_max, len(self.dataset[label]))
        # Oversample the smaller classes so every class has balanced_max entries.
        for label in self.dataset:
            while len(self.dataset[label]) < self.balanced_max:
                self.dataset[label].append(random.choice(self.dataset[label]))
        self.keys = list(self.dataset.keys())

    def __iter__(self):
        # Iteration state is local to each generator. An iterator abandoned mid-epoch
        # (e.g. a `break` after one batch) therefore cannot leak into the next epoch.
        per_class = [list(self.dataset[key]) for key in self.keys]
        if self.shuffle:
            for class_indices in per_class:
                random.shuffle(class_indices)
        for position in range(self.balanced_max):
            for class_indices in per_class:
                yield class_indices[position]

    def _get_label(self, dataset, idx):
        if self.labels is not None:
            return self.labels[idx]
        # Fall back to the dataset's own metadata frame (the old code read a
        # nonexistent `self.csv` and raised AttributeError).
        return dataset.csv["fitzpatrick_scale"][idx]

    def __len__(self):
        return self.balanced_max * len(self.keys)

In [None]:
class Inception_first_three_layers(nn.Module):
    """Inception-A style block with 1x1, 5x5, double-3x3 and max-pool branches.

    ``id`` selects the input width: 0 -> 192, 1 -> 256, 2 -> 288 channels.
    The output has 64 + 64 + 96 + (32 if id == 0 else 64) channels, i.e. 256 for
    id 0 and 288 otherwise. All convs are stride 1 and padded, and the pool is
    3x3/stride 1/pad 1, so spatial size is preserved.

    Raises:
        ValueError: if ``id`` is not 0, 1 or 2.
    """

    # Input channel count for each supported block id.
    _IN_CHANNELS = {0: 192, 1: 256, 2: 288}

    def __init__(self, id):
        super().__init__()
        self.init_first_three_layers(id)

    def init_first_three_layers(self, id):
        if id not in self._IN_CHANNELS:
            raise ValueError(f"id must be one of {sorted(self._IN_CHANNELS)}, got {id!r}")
        num_in_channels = self._IN_CHANNELS[id]
        # Define convolutional layers (creation order kept so parameter order and
        # state_dict keys match existing checkpoints).
        self.conv_1x1 = nn.Conv2d(in_channels=num_in_channels, out_channels=64, kernel_size=1, stride=(1,1))
        self.bn1 = nn.BatchNorm2d(64)  # BatchNorm for conv_1x1
        # 1x1 reduction feeding the 5x5 conv
        self.conv_5x5 = nn.Conv2d(in_channels=num_in_channels, out_channels=48, kernel_size=1, stride=(1,1))
        self.bn2 = nn.BatchNorm2d(48)  # BatchNorm for conv_5x5
        self.conv_5x5_2 = nn.Conv2d(in_channels=48, out_channels=64, kernel_size=5, padding=2, stride=(1,1))
        self.bn3 = nn.BatchNorm2d(64)  # BatchNorm for conv_5x5_2
        # 1x1 reduction feeding two stacked 3x3 convs
        self.conv3x3dbl = nn.Conv2d(in_channels=num_in_channels, out_channels=64, kernel_size=1, stride=(1,1))
        self.bn4 = nn.BatchNorm2d(64)  # BatchNorm for conv3x3dbl
        self.conv3x3dbl_2 = nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, padding=1, stride=(1,1))
        self.bn5 = nn.BatchNorm2d(96)  # BatchNorm for conv3x3dbl_2
        self.conv3x3dbl_3 = nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, padding=1, stride=(1,1))
        self.bn6 = nn.BatchNorm2d(96)  # BatchNorm for conv3x3dbl_3
        # Pool-branch projection: 32 channels for the first block, 64 afterwards.
        pool_out = 32 if id == 0 else 64
        self.conv_pool = nn.Conv2d(in_channels=num_in_channels, out_channels=pool_out, kernel_size=1, stride=(1,1))
        self.bn_pool = nn.BatchNorm2d(self.conv_pool.out_channels)  # BatchNorm for conv_pool

    def forward(self, x):
        # Apply convolutional layers with BatchNorm and ReLU
        x1 = F.relu(self.bn1(self.conv_1x1(x)))
        x5 = F.relu(self.bn2(self.conv_5x5(x)))
        x5 = F.relu(self.bn3(self.conv_5x5_2(x5)))
        xdbl = F.relu(self.bn4(self.conv3x3dbl(x)))
        xdbl = F.relu(self.bn5(self.conv3x3dbl_2(xdbl)))
        xdbl = F.relu(self.bn6(self.conv3x3dbl_3(xdbl)))
        xpool = F.relu(self.bn_pool(self.conv_pool(F.max_pool2d(x, kernel_size=3, stride=1, padding=1))))

        # Concatenate branches along the channel dim
        return torch.cat((x1, x5, xdbl, xpool), dim=1)

class Inception_fourth_layer(nn.Module):
    """Downsampling Inception block: 288 input channels -> 384 + 96 + 288 = 768.

    Each of the three branches ends in a stride-2, unpadded 3x3 op (conv, conv,
    max-pool), so all of them produce the same reduced spatial size.
    """

    def __init__(self):
        super().__init__()
        # Branch 1: a single strided 3x3 conv, 288 -> 384.
        self.conv_3x3 = nn.Conv2d(in_channels=288, out_channels=384, kernel_size=3, stride=(2,2), padding="valid")
        self.bn_3x3 = nn.BatchNorm2d(384)

        # Branch 2: 1x1 reduction, 'same' 3x3, then strided 3x3 (288 -> 64 -> 96 -> 96).
        self.conv3x3dbl = nn.Conv2d(in_channels=288, out_channels=64, kernel_size=1, stride=(1,1))
        self.bn3x3dbl = nn.BatchNorm2d(64)
        self.conv3x3dbl_2 = nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=(1,1), padding="same")
        self.bn3x3dbl_2 = nn.BatchNorm2d(96)
        self.conv3x3dbl_3 = nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=(2,2), padding="valid")
        self.bn3x3dbl_3 = nn.BatchNorm2d(96)

    def forward(self, x):
        def conv_bn_relu(conv, bn, tensor):
            return F.relu(bn(conv(tensor)))

        single = conv_bn_relu(self.conv_3x3, self.bn_3x3, x)

        double = x
        stages = (
            (self.conv3x3dbl, self.bn3x3dbl),
            (self.conv3x3dbl_2, self.bn3x3dbl_2),
            (self.conv3x3dbl_3, self.bn3x3dbl_3),
        )
        for conv, bn in stages:
            double = conv_bn_relu(conv, bn, double)

        # Branch 3: parameter-free strided max-pool keeps all 288 input channels.
        pooled = F.max_pool2d(x, kernel_size=3, stride=2)
        return torch.cat([single, double, pooled], dim=1)

class Inception_last_two_layers(nn.Module):
    """Inception-C style block with factorised 7x7 convolutions (1x7 / 7x1 pairs).

    Input and output both have 768 channels (four 192-channel branches) and the
    spatial size is preserved (all convs use 'same' padding or are 1x1; the pool is
    3x3/stride 1/pad 1). ``id == 4`` uses 128 intermediate channels in the 7x7
    branches; any other id uses 160.
    """

    def __init__(self, id):
        super(Inception_last_two_layers, self).__init__()
        self.id = id
        num_output_features = 128 if id == 4 else 160
        # Define convolutional layers and corresponding batch normalization layers
        # (creation order kept so parameter order / state_dict keys stay the same).
        self.conv_1x1 = nn.Conv2d(in_channels=768, out_channels=192, kernel_size=1, stride=(1,1))
        self.bn_1x1 = nn.BatchNorm2d(192)

        # Branch 7x7: 1x1 reduction -> 1x7 -> 7x1
        self.conv_7x7 = nn.Conv2d(in_channels=768, out_channels=num_output_features, kernel_size=1, stride=(1,1))
        self.bn_7x7 = nn.BatchNorm2d(num_output_features)

        self.conv_7x7_2 = nn.Conv2d(in_channels=num_output_features, out_channels=num_output_features, kernel_size=(1,7), stride=(1,1), padding="same")
        self.bn_7x7_2 = nn.BatchNorm2d(num_output_features)

        self.conv_7x7_3 = nn.Conv2d(in_channels=num_output_features, out_channels=192, kernel_size=(7,1), stride=(1,1), padding="same")
        self.bn_7x7_3 = nn.BatchNorm2d(192)

        # Branch double-7x7: 1x1 reduction -> 7x1 -> 1x7 -> 7x1 -> 1x7
        self.branch7x7dbl = nn.Conv2d(in_channels=768, out_channels=num_output_features, kernel_size=1, stride=(1,1))
        self.bn_branch7x7dbl = nn.BatchNorm2d(num_output_features)

        self.branch7x7dbl_2 = nn.Conv2d(in_channels=num_output_features, out_channels=num_output_features, kernel_size=(7,1), stride=(1,1), padding="same")
        self.bn_branch7x7dbl_2 = nn.BatchNorm2d(num_output_features)

        self.branch7x7dbl_3 = nn.Conv2d(in_channels=num_output_features, out_channels=num_output_features, kernel_size=(1,7), stride=(1,1), padding="same")
        self.bn_branch7x7dbl_3 = nn.BatchNorm2d(num_output_features)

        self.branch7x7dbl_4 = nn.Conv2d(in_channels=num_output_features, out_channels=num_output_features, kernel_size=(7,1), stride=(1,1), padding="same")
        self.bn_branch7x7dbl_4 = nn.BatchNorm2d(num_output_features)

        self.branch7x7dbl_5 = nn.Conv2d(in_channels=num_output_features, out_channels=192, kernel_size=(1,7), stride=(1,1), padding="same")
        self.bn_branch7x7dbl_5 = nn.BatchNorm2d(192)

        # Pool branch: 3x3 average pool followed by a 1x1 projection
        self.conv_pool = nn.Conv2d(in_channels=768, out_channels=192, kernel_size=1, stride=(1,1))
        self.bn_conv_pool = nn.BatchNorm2d(192)

        self.avg_pool = nn.AvgPool2d(kernel_size=3, stride=(1,1), padding=1)

    def forward(self, x):
        # Apply convolutional layers with corresponding batch normalization and ReLU activation
        x1 = F.relu(self.bn_1x1(self.conv_1x1(x)))

        x7 = F.relu(self.bn_7x7(self.conv_7x7(x)))
        x7 = F.relu(self.bn_7x7_2(self.conv_7x7_2(x7)))
        x7 = F.relu(self.bn_7x7_3(self.conv_7x7_3(x7)))

        x7x7dbl = F.relu(self.bn_branch7x7dbl(self.branch7x7dbl(x)))
        x7x7dbl = F.relu(self.bn_branch7x7dbl_2(self.branch7x7dbl_2(x7x7dbl)))
        x7x7dbl = F.relu(self.bn_branch7x7dbl_3(self.branch7x7dbl_3(x7x7dbl)))
        x7x7dbl = F.relu(self.bn_branch7x7dbl_4(self.branch7x7dbl_4(x7x7dbl)))
        x7x7dbl = F.relu(self.bn_branch7x7dbl_5(self.branch7x7dbl_5(x7x7dbl)))

        xpool = F.relu(self.bn_conv_pool(self.conv_pool(self.avg_pool(x))))
        # Concatenate branches along the channel dim: 4 x 192 = 768
        return torch.cat((x1, x7, x7x7dbl, xpool), dim=1)


In [None]:
class OurModel(nn.Module):
    """Fitzpatrick skin-type classifier with an optional split L / ab stem.

    Args:
        x: channels produced by the L-path stem. The ab path produces ``64 - x``, so
            the concatenated stem always has 64 channels. Only used when
            ``two_paths`` is True.
        two_paths: if True, the L input ``(B, 1, H, W)`` and the ab input
            ``(B, 2, H, W)`` go through separate conv stems that are concatenated.
            If False, both inputs are concatenated into one 3-channel image with
            a single stem.
        inception: if True, the trunk is six Inception blocks. If False, it is a
            stack of 1x1 conv + BatchNorm + ReLU layers. Either way it ends with
            768 channels.

    ``forward(x_1, x_2)`` returns logits of shape ``(B, 6)``.
    """

    def __init__(self, x, two_paths=True, inception=True) -> None:
        super().__init__()
        self.two_paths = two_paths
        self.inception = inception
        # Module creation order is unchanged so parameter order and state_dict keys
        # match previously saved checkpoints.
        if two_paths:
            l_mid = math.floor(x / 2)  # L-path width after the first two convs
            ab_mid = 32 - l_mid        # ab-path width; l_mid + ab_mid == 32
            # L image convolution path
            self.l_conv1 = nn.Conv2d(in_channels=1, out_channels=l_mid, kernel_size=3, stride=(2,2), padding='valid')
            self.bn_l_conv1 = nn.BatchNorm2d(l_mid)
            self.l_conv2 = nn.Conv2d(in_channels=l_mid, out_channels=l_mid, kernel_size=3, padding='valid')
            self.bn_l_conv2 = nn.BatchNorm2d(l_mid)
            self.l_conv3 = nn.Conv2d(in_channels=l_mid, out_channels=x, kernel_size=3, padding="same")
            self.bn_l_conv3 = nn.BatchNorm2d(x)
            self.l_max_pool = nn.MaxPool2d(kernel_size=3, stride=(2,2))

            # AB image convolution path
            self.ab_conv1 = nn.Conv2d(in_channels=2, out_channels=ab_mid, kernel_size=3, stride=(2,2), padding='valid')
            self.bn_ab_conv1 = nn.BatchNorm2d(ab_mid)
            self.ab_conv2 = nn.Conv2d(in_channels=ab_mid, out_channels=ab_mid, kernel_size=3, padding='valid')
            self.bn_ab_conv2 = nn.BatchNorm2d(ab_mid)
            self.ab_conv3 = nn.Conv2d(in_channels=ab_mid, out_channels=64-x, kernel_size=3, padding="same")
            self.bn_ab_conv3 = nn.BatchNorm2d(64-x)
            self.ab_max_pool = nn.MaxPool2d(kernel_size=3, stride=(2,2))

            # combined (x + (64 - x) = 64 input channels)
            self.conv4 = nn.Conv2d(in_channels=64, out_channels=80, kernel_size=1, stride=(1,1), padding='same')

        else:
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=(2,2), padding='valid')
            self.bn_conv1 = nn.BatchNorm2d(64)
            self.max_pool = nn.MaxPool2d(kernel_size=3, stride=(2,2))
            self.conv4 = nn.Conv2d(in_channels=64, out_channels=80, kernel_size=1, stride=(1,1), padding='valid')

        # Combined path
        self.bn_conv4 = nn.BatchNorm2d(80)
        self.conv5 = nn.Conv2d(in_channels=80, out_channels=192, kernel_size=3, padding='valid')
        self.bn_conv5 = nn.BatchNorm2d(192)
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=(2,2))

        if inception:
            # Inception X6
            self.Inception1 = Inception_first_three_layers(0)
            self.Inception2 = Inception_first_three_layers(1)
            self.Inception3 = Inception_first_three_layers(2)
            self.Inception4 = Inception_fourth_layer()
            self.Inception5 = Inception_last_two_layers(4)
            self.Inception6 = Inception_last_two_layers(5)
        else:
            # Ablation trunk: 1x1 convs only; conv12 is the single stride-2 downsample.
            self.conv6 = nn.Conv2d(in_channels=192, out_channels=384, kernel_size=1, stride=(1,1), padding='same')
            self.bn_conv6 = nn.BatchNorm2d(384)
            self.conv7 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=1, stride=(1,1), padding='same')
            self.bn_conv7 = nn.BatchNorm2d(384)
            self.conv8 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=1, stride=(1,1), padding='same')
            self.bn_conv8 = nn.BatchNorm2d(384)
            self.conv9 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=1, stride=(1,1), padding='same')
            self.bn_conv9 = nn.BatchNorm2d(384)
            self.conv10 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=1, stride=(1,1), padding='same')
            self.bn_conv10 = nn.BatchNorm2d(384)
            self.conv11 = nn.Conv2d(in_channels=384, out_channels=512, kernel_size=1, stride=(1,1), padding='same')
            self.bn_conv11 = nn.BatchNorm2d(512)
            self.conv12 = nn.Conv2d(in_channels=512, out_channels=768, kernel_size=1, stride=(2,2), padding='valid')
            self.bn_conv12 = nn.BatchNorm2d(768)
            self.conv13 = nn.Conv2d(in_channels=768, out_channels=768, kernel_size=1, stride=(1,1), padding='same')
            self.bn_conv13 = nn.BatchNorm2d(768)
            self.conv14 = nn.Conv2d(in_channels=768, out_channels=768, kernel_size=1, stride=(1,1), padding='same')
            self.bn_conv14 = nn.BatchNorm2d(768)
            self.conv15 = nn.Conv2d(in_channels=768, out_channels=768, kernel_size=1, stride=(1,1), padding='same')
            self.bn_conv15 = nn.BatchNorm2d(768)
            self.conv16 = nn.Conv2d(in_channels=768, out_channels=768, kernel_size=1, stride=(1,1), padding='same')
            self.bn_conv16 = nn.BatchNorm2d(768)
            self.conv17 = nn.Conv2d(in_channels=768, out_channels=768, kernel_size=1, stride=(1,1), padding='same')
            self.bn_conv17 = nn.BatchNorm2d(768)
            self.conv18 = nn.Conv2d(in_channels=768, out_channels=768, kernel_size=1, stride=(1,1), padding='same')
            self.bn_conv18 = nn.BatchNorm2d(768)

        self.globalAvgPooling = nn.AdaptiveAvgPool2d((1,1))
        self.denseLayer = nn.Linear(768, 6)

    def forward(self, x_1, x_2):
        if self.two_paths:
            # L image convolution path
            x_1 = F.relu(self.bn_l_conv1(self.l_conv1(x_1)))
            x_1 = F.relu(self.bn_l_conv2(self.l_conv2(x_1)))
            x_1 = F.relu(self.bn_l_conv3(self.l_conv3(x_1)))
            x_1 = self.l_max_pool(x_1)

            # AB image convolution path
            x_2 = F.relu(self.bn_ab_conv1(self.ab_conv1(x_2)))
            x_2 = F.relu(self.bn_ab_conv2(self.ab_conv2(x_2)))
            x_2 = F.relu(self.bn_ab_conv3(self.ab_conv3(x_2)))
            x_2 = self.ab_max_pool(x_2)

            # Combined path
            x_combined = torch.cat((x_1, x_2), dim=1)
        else:
            x_combined = torch.cat((x_1, x_2), dim=1)
            x_combined = F.relu(self.bn_conv1(self.conv1(x_combined)))
            x_combined = self.max_pool(x_combined)

        x_combined = F.relu(self.bn_conv4(self.conv4(x_combined)))
        x_combined = F.relu(self.bn_conv5(self.conv5(x_combined)))
        x_combined = self.max_pool(x_combined)

        if self.inception:
            x_combined = self.Inception1(x_combined)
            x_combined = self.Inception2(x_combined)
            x_combined = self.Inception3(x_combined)

            x_combined = self.Inception4(x_combined)
            x_combined = self.Inception5(x_combined)
            x_combined = self.Inception6(x_combined)
        else:
            x_combined = F.relu(self.bn_conv6(self.conv6(x_combined)))
            x_combined = F.relu(self.bn_conv7(self.conv7(x_combined)))
            x_combined = F.relu(self.bn_conv8(self.conv8(x_combined)))
            x_combined = F.relu(self.bn_conv9(self.conv9(x_combined)))
            x_combined = F.relu(self.bn_conv10(self.conv10(x_combined)))
            x_combined = F.relu(self.bn_conv11(self.conv11(x_combined)))
            x_combined = F.relu(self.bn_conv12(self.conv12(x_combined)))
            x_combined = F.relu(self.bn_conv13(self.conv13(x_combined)))
            x_combined = F.relu(self.bn_conv14(self.conv14(x_combined)))
            x_combined = F.relu(self.bn_conv15(self.conv15(x_combined)))
            x_combined = F.relu(self.bn_conv16(self.conv16(x_combined)))
            x_combined = F.relu(self.bn_conv17(self.conv17(x_combined)))
            x_combined = F.relu(self.bn_conv18(self.conv18(x_combined)))

        x_combined = self.globalAvgPooling(x_combined)
        # flatten(…, 1) keeps the batch dim. The old .squeeze() collapsed a batch of
        # one to shape (6,), which breaks CrossEntropyLoss and argmax(dim=1).
        x_combined = self.denseLayer(torch.flatten(x_combined, 1))
        return x_combined

In [None]:
def print_model_parameters(model):
    """Print the total, trainable (requires_grad) and frozen parameter counts of ``model``."""
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    frozen = total - trainable

    print(f"Total Parameters: {total}")
    print(f"Trainable Parameters: {trainable}")
    print(f"Non-Trainable Parameters: {frozen}")

In [None]:
# Training pipeline: resize, random crop and flips. Validation pipeline: resize only.
# Both end with ToTensor + Normalize(0.5, 0.5), mapping pixel values from [0, 1] to [-1, 1].
# NOTE(review): training sees 224x224 crops while validation sees the Resize(256) output
# (non-square if the stored arrays are not square). Confirm this mismatch is intended.
train_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)
valid_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

# colour_space="HSV" selects the RGB -> Lab conversion inside CustomSkinDataset.
train_dataset = CustomSkinDataset(data_path, train_df, train_transforms, colour_space="HSV")
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=BalancedBatchSampler(train_dataset, labels=train_df["fitzpatrick_scale"].to_numpy()),
    num_workers=2,
    pin_memory=True,
    drop_last=False,
)
valid_dataset = CustomSkinDataset(data_path, valid_df, valid_transforms, colour_space="HSV")
valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    drop_last=False,
)

In [None]:
# Two-path Inception model; x is the channel width of the L-path stem (ab path gets 64 - x).
x = 11
our_model = OurModel(x, two_paths=True, inception=True).to(device)
loss_function = nn.CrossEntropyLoss()
optimiser = optim.SGD(params=our_model.parameters(), lr=0.001, momentum=0.9)
# Divide the learning rate by 10 every 10 epochs (scheduler.step() runs once per epoch in train()).
scheduler = lr_scheduler.StepLR(optimizer=optimiser, step_size=10, gamma=0.1)
print_model_parameters(our_model)

In [None]:
# Sanity-check one training batch: print the 1-based Fitzpatrick labels and show the L channels.
img_L, img_ab, label = next(iter(train_loader))
print(label + 1)
# img_L is already (B, 1, H, W), which make_grid accepts directly. The old extra
# unsqueeze(1) made it 5-D and broke the grid layout. add(1).mul(0.5) undoes the
# [-1, 1] normalisation.
grid = torchvision.utils.make_grid(img_L.add(1).mul(0.5)).clamp(0, 1)
plt.imshow(grid.cpu().permute(1, 2, 0), cmap="gray")  # (C, H, W) -> (H, W, C)
plt.title("L channel of one training batch")
plt.show()

In [None]:
def train(model, optim, train_loader, scheduler, loss_function, epoch, batch_size):
    """Run one training epoch and print the mean batch loss.

    Batches are ``(img_l, img_ab, label)`` and are moved to the global ``device``.
    The model is called as ``model(img_l, img_ab)``. ``scheduler.step()`` runs once,
    after the epoch. ``batch_size`` is unused and kept for call compatibility.
    """
    model.train()
    running_loss = 0
    progress = tqdm(train_loader, desc=f'Epoch {epoch}', total=len(train_loader))
    for batch in progress:
        inputs_l, inputs_ab, targets = (tensor.to(device) for tensor in batch)

        optim.zero_grad()
        batch_loss = loss_function(model(inputs_l, inputs_ab), targets)
        batch_loss.backward()
        optim.step()

        running_loss += batch_loss.item()

    scheduler.step()
    print(f"Epoch: {epoch} ||| Training loss: {running_loss / len(train_loader)}")

In [None]:
def validate(model, valid_loader, loss_function, epoch, valid_length, batch_size):
    """Evaluate ``model`` on ``valid_loader`` and print loss and accuracies.

    Args:
        model: called as ``model(img_l, img_ab)``; returns per-class logits.
        valid_loader: yields ``(img_l, img_ab, label)`` batches (moved to global ``device``).
        loss_function: criterion applied to ``(logits, label)``.
        epoch: only used in the progress-bar label.
        valid_length: number of validation samples (accuracy denominator).
        batch_size: unused; kept for call compatibility.

    Returns:
        ``(acc, acc_give1)`` as Python floats: the exact-class accuracy, and the
        fraction of predictions at most one class away from the label.
    """
    model.eval()
    total_loss = 0
    num_correct, num_give1_correct = 0, 0
    # No gradients are needed for evaluation; this avoids building autograd graphs.
    with torch.no_grad():
        for img_l, img_ab, label in tqdm(valid_loader, desc=f'Epoch {epoch}', total=len(valid_loader)):
            img_l = img_l.to(device)
            img_ab = img_ab.to(device)
            label = label.to(device)

            output = model(img_l, img_ab)
            preds = torch.argmax(output, dim=1)
            num_correct += (preds == label).sum().item()
            # Labels are integer classes, so "within +-1" is |pred - label| <= 1.
            num_give1_correct += ((preds - label).abs() <= 1).sum().item()

            total_loss += loss_function(output, label).item()

    avg_loss = total_loss / len(valid_loader)
    acc = num_correct / valid_length
    acc_give1 = num_give1_correct / valid_length
    print(f"Validation loss: {avg_loss} ||| Validation accuracy: {acc} ||| Validation +-1 accuracy: {acc_give1}")
    return acc, acc_give1

In [None]:
# Train for `epochs` epochs, checkpointing whenever validation accuracy improves.
epochs = 30
best_acc, _ = validate(our_model, valid_loader, loss_function, 0, len_valid, batch_size)
for epoch in range(epochs):
    train(our_model, optimiser, train_loader, scheduler, loss_function, epoch+1, batch_size)
    acc, acc_give1 = validate(our_model, valid_loader, loss_function, epoch+1, len_valid, batch_size)
    if acc > best_acc:
        torch.save(
            {
                'model': our_model.state_dict(),
                'optimiser': optimiser.state_dict(),
                # Save the scheduler's state_dict, not the object. Pickling the object
                # also pickles the optimiser it references, and it cannot be read back
                # with torch.load(weights_only=True). Restore with scheduler.load_state_dict().
                'scheduler': scheduler.state_dict(),
                # Plain floats keep the checkpoint free of device tensors.
                'acc': float(acc),
                'acc_give1': float(acc_give1),
                'epoch': epoch+1,
            },
            "drive/MyDrive/Edinburgh/MLP/MLPcoursework4/models/ourModelAblationIncept.chkpt",
        )
        best_acc = acc

In [None]:
# Reload a saved checkpoint and re-run validation on it.
loaded_model = OurModel(x, two_paths=True)
loaded_model.to(device)
# map_location remaps tensors saved on another device (e.g. a GPU run) onto the current
# one, so the checkpoint also loads on a CPU-only runtime.
# NOTE(review): torch >= 2.6 defaults to weights_only=True. A checkpoint that pickled a
# whole scheduler object may then need weights_only=False (only for trusted files).
checkpoint = torch.load("drive/MyDrive/Edinburgh/MLP/MLPcoursework4/models/ourModel_x-11.chkpt", map_location=device)
loaded_model.load_state_dict(checkpoint['model'])
# Label the progress bar with the epoch recorded in the checkpoint (23 if absent).
_, _ = validate(loaded_model, valid_loader, loss_function, checkpoint.get('epoch', 23), len_valid, batch_size)

In [None]:
def bias_analysis(model, valid_loader, num_classes=6):
    """Print and return per-class accuracy of ``model`` over ``valid_loader``.

    Class ``i`` is the zero-based label from the loader. With CustomSkinDataset
    that is Fitzpatrick type ``i + 1``.

    Args:
        model: called as ``model(img_l, img_ab)``; returns per-class logits.
        valid_loader: yields ``(img_l, img_ab, label)`` batches.
        num_classes: number of label classes (default 6).

    Returns:
        numpy array of length ``num_classes`` with the accuracy per class
        (``nan`` for classes that never occur in the loader).
    """
    model.eval()
    skintone_counts = np.zeros(num_classes)
    skintone_counts_correct = np.zeros(num_classes)
    # Fixed description: the old code read a global `epoch`, which only existed if the
    # training loop had already run in this kernel.
    with torch.no_grad():
        for img_l, img_ab, label in tqdm(valid_loader, desc='Bias analysis', total=len(valid_loader)):
            output = model(img_l.to(device), img_ab.to(device))
            pred_arr = torch.argmax(output, dim=1).cpu().numpy()
            label_arr = label.cpu().numpy()
            skintone_counts += np.bincount(label_arr, minlength=num_classes)
            skintone_counts_correct += np.bincount(label_arr[label_arr == pred_arr], minlength=num_classes)
    print('\n')
    accuracies = np.full(num_classes, np.nan)
    for i in range(num_classes):
        if skintone_counts[i] > 0:  # avoid 0/0 for classes absent from the loader
            accuracies[i] = skintone_counts_correct[i] / skintone_counts[i]
        print(f"\nAccuracy for skintone {i}: {accuracies[i]}    (Counts {skintone_counts[i]})")
    return accuracies

In [None]:
bias_analysis(model=our_model, valid_loader=valid_loader);