In [3]:
#importing the required libraries
import os
import shutil
import random
import torch
import torchvision
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split

Collecting matplotlib
  Downloading matplotlib-3.3.3-cp38-cp38-manylinux1_x86_64.whl (11.6 MB)
[K     |████████████████████████████████| 11.6 MB 472 kB/s eta 0:00:01
Collecting kiwisolver>=1.0.1
  Downloading kiwisolver-1.3.1-cp38-cp38-manylinux1_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 512 kB/s eta 0:00:01
[?25hCollecting cycler>=0.10
  Using cached cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)
Installing collected packages: kiwisolver, cycler, matplotlib
Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.3.3


# Defining our data paths

In [5]:
# Dataset root directory, relative to the notebook's working directory.
root_dir = "./data/"

# NOTE(review): presumably these are the class sub-folder names under root_dir,
# which must match the on-disk names exactly. "Pnemuonia" looks like a
# misspelling of "Pneumonia" -- confirm against the actual folder before renaming.
source_dirs = [
    "NORMAL",
    "Viral Pnemuonia",
    "COVID-19",
]

In [6]:
# Dataset that loads chest X-ray PNG files from one directory per class and
# applies a transform to each image. It subclasses torch.utils.data.Dataset so
# it can be wrapped in a torch DataLoader.

class ChestXrayDataset(torch.utils.data.Dataset):
    """Chest X-ray images grouped by class.

    Args:
        image_dirs: mapping from class name ('normal', 'viral', 'covid') to the
            directory holding that class's ``.png`` files.
        transform: callable applied to each loaded RGB PIL image in
            ``__getitem__``.

    Note:
        ``__getitem__`` draws the class uniformly at random and ignores which
        class ``index`` would fall into; ``index`` only selects the file within
        the drawn class (modulo that class's size). Classes are therefore
        balanced in expectation, but a given index does not map to a fixed
        sample.
    """

    def __init__(self, image_dirs, transform):
        def get_images(class_name):
            """Return the sorted ``.png`` file names in ``class_name``'s directory."""
            # Compare the full lowercase name: testing only a 3-character suffix
            # against the 4-character '.png' can never match.
            # Sorting makes the index -> file mapping reproducible, since
            # os.listdir returns entries in arbitrary order.
            images = sorted(
                name for name in os.listdir(image_dirs[class_name])
                if name.lower().endswith('.png')
            )
            # Report how many images were found for this class.
            print(f'Found {len(images)} {class_name} examples')
            return images

        # Class labels; a sample's integer label is its position in this list.
        self.class_names = ['normal', 'viral', 'covid']

        # class name -> list of image file names for that class
        self.images = {}
        for class_name in self.class_names:
            self.images[class_name] = get_images(class_name)

        self.image_dirs = image_dirs
        self.transform = transform

    def __len__(self):
        """Total number of images across all classes."""
        return sum(len(self.images[class_name]) for class_name in self.class_names)

    def __getitem__(self, index):
        """Return ``(transform(image), label)`` for a randomly drawn class.

        Raises:
            IndexError: if no class contains any images.
        """
        # Only draw from classes that have images, so `index % 0` cannot raise
        # ZeroDivisionError. When every class is non-empty this is the same
        # list (and the same random draw) as sampling from all class names.
        available = [name for name in self.class_names if self.images[name]]
        if not available:
            raise IndexError('ChestXrayDataset contains no images')
        class_name = random.choice(available)
        index = index % len(self.images[class_name])
        image_name = self.images[class_name][index]
        image_path = os.path.join(self.image_dirs[class_name], image_name)
        # The context manager closes the underlying file; convert() returns an
        # independent, fully loaded copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        return self.transform(image), self.class_names.index(class_name)

In [7]:
# Shared preprocessing constants, so the training and test pipelines cannot
# drift apart. The mean/std values match the ImageNet per-channel statistics
# commonly used with torchvision's pretrained models -- presumably chosen for a
# pretrained backbone later in the notebook; confirm.
IMAGE_SIZE = (300, 300)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, random horizontal flip (augmentation),
# conversion to tensor, then per-channel normalization.
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=IMAGE_SIZE),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation pipeline: the same deterministic steps, without augmentation.
test_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=IMAGE_SIZE),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])