In [None]:
import copy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import Dataset
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
# NOTE: keep this import AFTER `from datasets import Dataset`. Both lines bind the
# name `Dataset`; the later one wins, and the custom dataset class defined in this
# notebook is fed to torch DataLoaders, so it must subclass the torch `Dataset`.
from torch.utils.data import Dataset
from tqdm import tqdm

# First run only: the download is around 40GB and takes around half an hour
# (note that we only take a slice of each split here).
# Hugging Face caches the files on your system and checks that cache on every call,
# so later runs of this cell are quick and you never need to open the files yourself.
NIH_REPO = "alkzar90/NIH-Chest-X-ray-dataset"
NIH_CONFIG = 'image-classification'

# A small portion of the official train split is enough to get the code working.
# The validation subset is carved out of it later by our custom dataset class.
ds_train = load_dataset(NIH_REPO, NIH_CONFIG, split = "train[:1000]")

# The official test split is held back purely for testing; it never enters the training loop.
ds_test = load_dataset(NIH_REPO, NIH_CONFIG, split = "test[:1000]")

# Display a single image to check that loading worked.
ds_train[0]['image']


In [None]:
# We now have a Hugging Face dataset, so we define a custom dataset class that preprocesses the images and multi-hot encodes the labels (one slot per class; several slots can be set for one image).

class MultiLabelDataset(Dataset):
    """
    In-memory multi-label image dataset built from a Hugging Face dataset.

    The Hugging Face dataset is switched to torch format and every sample is
    preprocessed eagerly in ``__init__`` (no lazy loading): each image becomes a
    float, channel-first, 3-channel tensor scaled by 1/255 and each label list
    becomes a multi-hot vector.

    Which samples ``len()`` and indexing expose is controlled by ``self.mode``:

    * ``"train"`` / ``"val"`` - the two parts produced by ``train_validation_split()``
      (which must be called first);
    * ``"test"`` - every processed sample.

    ``mode`` is read at access time, not captured when a DataLoader is created.
    Two DataLoaders wrapping the *same* object therefore always serve the same
    split; give each DataLoader its own object (a shallow copy is enough).

    Args:
        hf_dataset: A Hugging Face dataset with 'image' and 'labels' columns
        image_size: ``(height, width)`` every image is resized to with bilinear
            interpolation. ``None`` (default) keeps the original resolution.
        num_classes: length of the multi-hot label vector (default 15).
    """

    _MODES = ("train", "val", "test")

    def __init__(self, hf_dataset, image_size = None, num_classes = 15):

        # filled in by train_validation_split()
        self.x_train, self.x_val, self.y_train, self.y_val = None, None, None, None
        self.mode = "train"

        hf_dataset = hf_dataset.with_format("torch")
        print(hf_dataset.format)

        self.processed_images = []
        self.processed_labels = []

        for sample in tqdm(hf_dataset, desc= "processing image files"):
            self.processed_images.append(self._prepare_image(sample['image'], image_size))
            self.processed_labels.append(self._encode_labels(sample['labels'], num_classes))

    @staticmethod
    def _prepare_image(image, image_size):
        """Convert one channel-first image tensor to a float 3-channel tensor scaled by 1/255."""
        # Interpolate in floating point: integer inputs are not supported by bilinear
        # interpolation on every torch version, and where they are the result is
        # rounded back to integers before we normalise.
        image = image.float()

        # 4-channel images: keep only the first channel (done before resizing so we
        # do not interpolate channels that are thrown away).
        if image.shape[0] == 4:
            image = image[:1]

        # we resize our image if specified
        if image_size is not None:
            image = F.interpolate(image.unsqueeze(0), size = image_size, mode = "bilinear").squeeze(0)

        # normalize pixel values
        image = image/255

        # convolutional networks here expect 3 channels, so single-channel images are replicated
        if image.shape[0] == 1:
            image = image.repeat(3, 1, 1)

        return image

    @staticmethod
    def _encode_labels(labels, num_classes):
        """Return a ``torch.long`` vector of length ``num_classes`` with a 1 at every index in ``labels``."""
        # as_tensor keeps an empty label list usable as an index (it would otherwise be a float tensor)
        labels = torch.as_tensor(labels, dtype = torch.long)
        one_hot = torch.zeros(num_classes, dtype = torch.long)
        one_hot[labels] = 1
        return one_hot

    def train_validation_split(self, test_size = 0.2, random_state = 42):
        """
        Split the processed samples into a training part and a validation part.

        The validation set is a subset of the data held by this object, so a
        separately loaded test set is never touched during training.

        Args:
            test_size: share of the samples placed in the validation part (default 0.2).
            random_state: seed forwarded to ``train_test_split`` so the split is repeatable.
        """
        self.x_train, self.x_val, self.y_train, self.y_val = train_test_split(
            self.processed_images, self.processed_labels,
            test_size = test_size, random_state = random_state)

    def _active_split(self):
        """Return the ``(images, labels)`` lists selected by ``self.mode``.

        Raises:
            ValueError: ``self.mode`` is not one of "train", "val", "test".
            RuntimeError: mode is "train"/"val" but ``train_validation_split()`` has not run.
        """
        if self.mode == "test":
            return self.processed_images, self.processed_labels
        if self.mode == "train":
            images, labels = self.x_train, self.y_train
        elif self.mode == "val":
            images, labels = self.x_val, self.y_val
        else:
            raise ValueError(f"unknown mode {self.mode!r}; expected one of {self._MODES}")
        if images is None:
            raise RuntimeError(
                f"train_validation_split() must be called before using mode {self.mode!r}")
        return images, labels

    def __len__(self):
        """ Returns the number of samples in the split selected by the current mode """
        return len(self._active_split()[0])

    # note we are not doing lazy processing, so our data is processed when the dataset is instantiated.
    def __getitem__(self, idx):
        """Returns ``{"image": ..., "labels": ...}`` for ``idx`` within the split selected by the current mode"""
        images, labels = self._active_split()
        return {"image": images[idx], "labels": labels[idx]}
    
        


In [None]:
# now we can produce our train and validation dataloaders which will be used in training later
training_dataset_class = MultiLabelDataset(ds_train, image_size = (128, 128)) # image_size is (height, width)
training_dataset_class.train_validation_split()

# MultiLabelDataset picks its split from `.mode` every time it is measured or indexed,
# so a single object cannot back two DataLoaders: flipping the mode to "val" would make
# the training DataLoader serve validation samples as well.
# Each DataLoader therefore gets its own shallow copy with a fixed mode. The copies
# share the already-processed tensors, so no image data is duplicated.
train_split = copy.copy(training_dataset_class)
train_split.mode = "train"
print(f"the size of training data is: {len(train_split)}")
train_dataloader = DataLoader(train_split, batch_size = 4, shuffle = True)

val_split = copy.copy(training_dataset_class)
val_split.mode = "val"
print(f" the size of validation data is: {len(val_split)}")
val_dataloader = DataLoader(val_split, batch_size = 4, shuffle = True)

In [None]:
# we use the whole test data for our test dataset
# The images are resized to (128, 128), the same resolution the training images are
# given, so a trained model is evaluated on inputs of the size it was trained on.
# (At (1024, 1024) the 1000 cached float32 3-channel images would also need roughly 12 GB of RAM.)
test_dataset_class = MultiLabelDataset(ds_test, image_size = (128, 128))
test_dataset_class.mode = "test"
test_dataloader = DataLoader(test_dataset_class, batch_size = 4, shuffle = True)

In [None]:
# Sanity check: draw the first image of one batch from the training dataloader.

for s in train_dataloader:
    # A batch is a dict of stacked tensors; index 0 picks the first sample of the batch.
    image, label = s['image'][0], s['labels'][0]
    print(image.shape, label)

    axes = plt.gca()
    axes.set_title(label)
    # The tensors are stored channels-first; imshow wants channels last, so we permute
    # for display only (the data itself stays channels-first).
    axes.imshow(image.permute(1, 2, 0))
    plt.show()
    break  # one batch is enough for a visual check

In [None]:
for s in val_dataloader:
    # A batch is a dict of stacked tensors; index 0 picks the first sample of the batch.
    image, label = s['image'][0], s['labels'][0]
    print(image.shape, label)

    axes = plt.gca()
    axes.set_title(label)
    # The tensors are stored channels-first; imshow wants channels last, so we permute
    # for display only (the data itself stays channels-first).
    axes.imshow(image.permute(1, 2, 0))
    plt.show()
    break  # one batch is enough for a visual check

In [None]:
for s in test_dataloader:
    # A batch is a dict of stacked tensors; index 0 picks the first sample of the batch.
    image, label = s['image'][0], s['labels'][0]
    print(image.shape, label)

    axes = plt.gca()
    axes.set_title(label)
    # The tensors are stored channels-first; imshow wants channels last, so we permute
    # for display only (the data itself stays channels-first).
    axes.imshow(image.permute(1, 2, 0))
    plt.show()
    break  # one batch is enough for a visual check