# Initial EDA

In [None]:
# Import libraries
import os
import torch
import pandas as pd
# from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision.io import read_image
from typing import List
from pneumonia_detector.preprocess import XrayDataset
import random

# Ignore warnings
# NOTE(review): filterwarnings("ignore") with no category/module silences every
# warning for the rest of the kernel session (torch, torchvision, PIL, ...),
# which can hide useful diagnostics; consider narrowing it with `category=` /
# `module=` or removing it.
import warnings
warnings.filterwarnings("ignore")


from PIL import Image


In [None]:
# Training images live under train_dir, with one sub-folder per class;
# the per-class folders are built from train_dir so the root is defined once.
train_dir = "../data/chest_xray/train/"
normal_train = f"{train_dir}NORMAL/"
pneumonia_train = f"{train_dir}PNEUMONIA/"

In [None]:
# Load one randomly chosen image from each class folder and print its size.
def _open_random_image(folder):
    """Open a randomly chosen entry of `folder` with PIL."""
    chosen = random.choice(os.listdir(folder))
    return Image.open(os.path.join(folder, chosen))

im_normal = _open_random_image(normal_train)
print(im_normal.size)
im_pneumonia = _open_random_image(pneumonia_train)
print(im_pneumonia.size)

In [None]:

# Show the two sampled X-rays side by side, titled with class and image size.
classes = ["NORMAL", "PNEUMONIA"]
fig, axes = plt.subplots(1, 2, figsize=(10, 8))
for ax, class_name, x_ray in zip(axes, classes, [im_normal, im_pneumonia]):
    ax.imshow(x_ray)
    ax.set_title(f"class: {class_name} with image size: {x_ray.size}")
    ax.axis("off")

The image sizes are different. Let's check the sizes of all images in the training set:

In [None]:
# Collect the size of every image under train_dir (all class sub-folders).
sizes = []
for root, dirs, files in os.walk(train_dir):
    for filename in files:
        # Skip hidden files (e.g. OS metadata) - they are not images and
        # would make Image.open raise.
        if filename.startswith("."):
            continue
        # Image.open is lazy (only the header is read, which is enough for
        # .size); the context manager closes the file handle it keeps open.
        with Image.open(os.path.join(root, filename)) as im:
            sizes.append(im.size)

In [None]:
# Unique sizes: summarise rather than dumping the whole set, which can be very
# long when many images have their own size. PIL's .size is (width, height).
unique_sizes = set(sizes)
print(f"{len(unique_sizes)} unique sizes across {len(sizes)} images")
pd.DataFrame(sizes, columns=["width", "height"]).describe()

There is a lot of variation in image size. The training images appear to be grayscale, but there may be colour images elsewhere in the dataset (the size scan above does not inspect image modes). We will need to decide how to load the images and what input size to use for the model.

In [None]:
class XrayDataset(Dataset):
    """Chest X-ray images labelled by the name of the folder that contains them.

    NOTE: this notebook-level class shadows the ``XrayDataset`` imported from
    ``pneumonia_detector.preprocess`` in the imports cell.

    Every non-hidden file found (recursively) under ``root_dir`` is treated as an
    image. Its label is ``label_map[<lower-cased parent folder name>]``, e.g.
    ``.../NORMAL/x.jpeg`` -> 0 and ``.../PNEUMONIA/y.jpeg`` -> 1.
    """

    # Lower-cased parent-folder name -> integer class label.
    label_map = {"normal": 0, "pneumonia": 1}

    def __init__(self, root_dir, transform=None):
        """
        Arguments:
            root_dir (string): Directory with all the images (searched recursively).
            transform (callable, optional): Optional transform to be applied
                to each PIL image before it is returned.
        """
        self.root_dir = root_dir
        self.transform = transform
        # Walk the directory we were given (previously the global `train_dir`
        # was walked, silently ignoring root_dir). Sorting makes the
        # index -> file mapping stable, since os.walk order is arbitrary.
        # Hidden files (e.g. OS metadata) are skipped: they are not images and
        # their parent folder may not be a key of label_map.
        self.files = sorted(
            os.path.normcase(os.path.join(dirpath, filename))
            for dirpath, _dirnames, filenames in os.walk(root_dir)
            for filename in filenames
            if not filename.startswith(".")
        )

    def __len__(self):
        """Number of image files found under root_dir."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at position ``idx``.

        The image is converted to RGB and passed through ``transform`` if one
        was given.

        Raises:
            KeyError: if the file's parent folder (lower-cased) is not in label_map.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.files[idx]
        # convert() returns a new, fully loaded image, so the source file can be
        # closed immediately instead of lingering until garbage collection.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        parent_dir = os.path.basename(os.path.dirname(img_path))
        label = XrayDataset.label_map[parent_dir.lower()]

        if self.transform:
            image = self.transform(image)

        return image, label
        

In [None]:
# Resize every X-ray to a fixed 256x256 input, randomly mirror it left-right
# (p=0.5) as augmentation, then convert the PIL image to a tensor.
_train_steps = [
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
]
train_transforms = transforms.Compose(_train_steps)

In [None]:
# Build the training dataset (root folder, transform) and report its size.
xray_train_data = XrayDataset(train_dir, train_transforms)
len(xray_train_data)

In [None]:
# dtype of the image tensor produced for the first sample
first_image = xray_train_data[0][0]
first_image.dtype

In [None]:
# Plot a handful of random samples from a Dataset, optionally titled with class names.
def display_random_images(dataset: torch.utils.data.dataset.Dataset,
                          classes: List[str] = None,
                          n: int = 3,
                          display_shape: bool = True,
                          seed: int = None,
                          ):
    """Plot ``n`` randomly chosen samples from ``dataset`` side by side.

    Args:
        dataset: indexable dataset whose items have the image at position 0 and
            the integer label at position 1; the image is assumed to be a
            channels-first tensor ``[C, H, W]`` (permuted for ``plt.imshow``).
        classes: optional list mapping integer labels to class names; when
            given (and non-empty), each subplot is titled with its class name.
        n: number of samples to show; capped at ``len(dataset)``.
        display_shape: when ``classes`` is given, also show the permuted image
            shape in the subplot title.
        seed: optional seed for choosing the samples; ``None`` uses the global
            ``random`` module state (previous behaviour).
    """
    # Cap n so random.sample does not raise when the dataset is smaller than n.
    n = min(n, len(dataset))
    sampler = random if seed is None else random.Random(seed)
    random_samples_idx = sampler.sample(range(len(dataset)), k=n)

    plt.figure(figsize=(16, 8))

    for i, targ_sample in enumerate(random_samples_idx):
        # Index the dataset once per sample: each __getitem__ call may re-read
        # the file and re-apply random transforms.
        sample = dataset[targ_sample]
        targ_image, targ_label = sample[0], sample[1]

        # [color_channels, height, width] -> [height, width, color_channels]
        targ_image_adjust = targ_image.permute(1, 2, 0)

        plt.subplot(1, n, i + 1)
        plt.imshow(targ_image_adjust)
        plt.axis("off")
        # Only set a title when class names are available; otherwise `title`
        # would be undefined (NameError when classes is None).
        if classes:
            title = f"class: {classes[targ_label]}"
            if display_shape:
                title = title + f"\nshape: {targ_image_adjust.shape}"
            plt.title(title)

In [None]:
# Show three random training samples, titled with their class name and shape.
display_random_images(
    dataset=xray_train_data,
    classes=["Normal", "Pneumonia"],
    n=3,
)

We will want to normalize the input images. To do that, we can calculate the per-channel mean and standard deviation of the training dataset.

In [None]:
# DataLoader over the training Dataset: batches of 32, loaded in the main
# process, in dataset order (order does not matter for the statistics below).
train_dataloader_xray = DataLoader(xray_train_data, batch_size=32, shuffle=False, num_workers=0)

In [None]:
# Per-channel mean and standard deviation of all images yielded by a loader,
# accumulated batch by batch so the whole dataset never sits in memory.
def batch_mean_and_sd(loader):
    """Compute the per-channel mean and (population) standard deviation over a loader.

    Args:
        loader: iterable of ``(images, labels)`` batches, where ``images`` is a
            4-D tensor ``[batch, channels, height, width]``. Batch sizes may
            differ; the channel count is taken from the data (not fixed to 3).

    Returns:
        ``(mean, std)``: 1-D tensors with one entry per channel, in the images'
        floating dtype (the default float dtype for integer images).

    Raises:
        ValueError: if the loader yields no pixels.
    """
    n_pixels = 0
    channel_sum = None
    channel_sq_sum = None
    out_dtype = None

    for images, _ in loader:
        b, _, h, w = images.shape
        n_pixels += b * h * w
        # Accumulate in float64: the previous version started from torch.empty
        # (uninitialised memory, where 0 * NaN/inf = NaN), and float64 totals
        # keep full precision over the whole dataset.
        images64 = images.to(torch.float64)
        batch_sum = images64.sum(dim=[0, 2, 3])
        batch_sq_sum = (images64 ** 2).sum(dim=[0, 2, 3])
        if channel_sum is None:
            channel_sum, channel_sq_sum = batch_sum, batch_sq_sum
            out_dtype = images.dtype if images.is_floating_point() else torch.get_default_dtype()
        else:
            channel_sum += batch_sum
            channel_sq_sum += batch_sq_sum

    if channel_sum is None or n_pixels == 0:
        raise ValueError("loader yielded no pixels; cannot compute mean and std")

    mean = channel_sum / n_pixels
    # E[x^2] - E[x]^2 can dip slightly below 0 from rounding; clamp before sqrt
    # so a constant channel gives std 0 instead of NaN.
    var = (channel_sq_sum / n_pixels - mean ** 2).clamp(min=0)
    return mean.to(out_dtype), var.sqrt().to(out_dtype)
  
# Per-channel mean and std of the training images as produced by the loader's transform.
mean, std = batch_mean_and_sd(loader=train_dataloader_xray)
print("mean and std: \n", mean, std)