# Detecting COVID-19 in Chest X-Rays with PyTorch

Image classification of chest X-rays into one of three classes: normal, viral pneumonia, or COVID-19.

Uses a modified version of the [COVID-19 Radiography Dataset](https://www.kaggle.com/tawsifurrahman/covid19-radiography-database) from Kaggle.

## Importing Libraries

In [2]:
%matplotlib inline

import os
import shutil
import random
import torch
import torchvision
import numpy as np

from PIL import Image
from matplotlib import pyplot as plt

# Seed every RNG the notebook draws from: Python's `random` chooses the held-out
# test images and the class sampled in ChestXRayDataset.__getitem__, while torch
# drives DataLoader shuffling and random augmentations. numpy is seeded as a precaution.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print('PyTorch version', torch.__version__)

PyTorch version 2.3.1+cu121


## Preparing Training and Test Sets

In [3]:
root = 'covid-19-radiography-database'
# Current class folder names under root ...
source = ['COVID', 'normal', 'viral pneumonia']
# ... and the names they are renamed to; lower-case to match the directory
# dicts passed to ChestXRayDataset further down.
classes = ['covid', 'normal', 'viral pneumonia']
N_TEST_PER_CLASS = 30  # images moved out of each class folder into root/test


def split_test_set(root, source, classes, n_per_class=N_TEST_PER_CLASS):
    """Rename class folders and move a random sample of each class into root/test.

    Args:
        root: dataset root containing one folder per class.
        source: existing folder names, renamed pairwise to ``classes``.
        classes: target folder names; ``root/test/<class>`` is created for each.
        n_per_class: number of PNGs moved per class (fewer if a class is smaller).

    The split runs at most once: if ``root/test`` already exists nothing is done,
    so re-running the cell cannot keep draining images from the training folders.
    """
    test = os.path.join(root, 'test')
    if os.path.exists(test) or not os.path.isdir(os.path.join(root, source[1])):
        return

    for src, dst in zip(source, classes):
        if src != dst:
            os.rename(os.path.join(root, src), os.path.join(root, dst))

    for c in classes:
        os.makedirs(os.path.join(test, c), exist_ok=True)
        images = [x for x in os.listdir(os.path.join(root, c)) if x.lower().endswith('png')]
        # min() avoids random.sample's ValueError when a class has too few images.
        for image in random.sample(images, min(n_per_class, len(images))):
            shutil.move(os.path.join(root, c, image), os.path.join(test, c, image))


split_test_set(root, source, classes)

## Creating a Custom Dataset

In [4]:
class ChestXRayDataset(torch.utils.data.Dataset):
    """Chest X-ray images stored as one folder of PNGs per class.

    Sampling is class-balanced: ``__getitem__`` picks a class uniformly at
    random and uses ``index`` (modulo that class's size) to choose an image
    within it, so the returned label is not a fixed function of ``index``.

    Args:
        directories: mapping from each class name in ``self.classes``
            (``'covid'``, ``'normal'``, ``'viral pneumonia'``) to its folder.
        transform: callable applied to each RGB image before it is returned.
    """

    def __init__(self, directories, transform):
        def getimages(c):
            # c is a class name from self.classes; keep only PNG files.
            images = [x for x in os.listdir(directories[c]) if x.lower().endswith('png')]
            print(f'found examples of {c}: {len(images)}')
            return images

        self.images = {}
        self.classes = ['covid', 'normal', 'viral pneumonia']

        for c in self.classes:
            self.images[c] = getimages(c)

        self.transform = transform
        self.directories = directories

    def __len__(self):
        return sum(len(self.images[c]) for c in self.classes)

    def __getitem__(self, index):
        # Random class choice balances the heavily uneven class sizes.
        c = random.choice(self.classes)
        index = index % len(self.images[c])
        image_name = self.images[c][index]
        # Folders are stored in self.directories (there is no image_dirs attribute).
        image_path = os.path.join(self.directories[c], image_name)
        # Context manager releases the file handle once the RGB copy is made.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        return self.transform(image), self.classes.index(c)

## Image Transformations

In [5]:
# Training preprocessing: resize to a square 224x224 input (must match
# transform_test), random horizontal flip as augmentation, then tensor
# conversion and normalisation with the ImageNet channel mean/std.
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(224, 224)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [6]:
# Evaluation preprocessing: identical to transform_train minus the random flip,
# so test images are resized to the same square 224x224 input and normalised
# with the same ImageNet channel mean/std.
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

## Preparing the DataLoaders

In [7]:
# Training images live in one sub-folder per class under the dataset root.
# Keys must match ChestXRayDataset.classes, since the dataset indexes this dict by them.
directory_train = {
    name: 'covid-19-radiography-database/' + name
    for name in ('covid', 'normal', 'viral pneumonia')
}

train_dataset = ChestXRayDataset(directory_train, transform_train)

found examples of covid: 2926
found examples of normal: 9502
found examples of viral pneumonia: 655


In [8]:
# Held-out test images: one sub-folder per class under <root>/test.
# Keys must match ChestXRayDataset.classes, since the dataset indexes this dict by them.
directory_test = {
    name: f'covid-19-radiography-database/test/{name}'
    for name in ('covid', 'normal', 'viral pneumonia')
}

test_dataset = ChestXRayDataset(directory_test, transform_test)

found examples of covid: 690
found examples of normal: 690
found examples of viral pneumonia: 690


In [12]:
# Six images per batch, one per panel of the 1x6 preview grid in the visualisation cell.
batch_size = 6

loader_kwargs = dict(batch_size=batch_size, shuffle=True)
dataloader_test = torch.utils.data.DataLoader(test_dataset, **loader_kwargs)
dataloader_train = torch.utils.data.DataLoader(train_dataset, **loader_kwargs)

print(f'number of testing batches {len(dataloader_test)}')
print(f'number of training batches {len(dataloader_train)}')

number of testing batches 345
number of training batches 2181


## Data Visualization

In [None]:
# Index -> class-name lookup for plot labels, taken from the dataset so it matches its label encoding.
classes = train_dataset.classes

def show_images(images, labels, predictions, class_names=None):
    """Plot a batch of normalised images with their true and predicted classes.

    Each image gets its own panel: the x-label is the true class, the y-label the
    predicted class, coloured green when they agree and red otherwise.

    Args:
        images: batch of CHW image tensors normalised with the same mean/std as
            ``transform_train`` / ``transform_test``.
        labels: true class indices (tensors or plain ints).
        predictions: predicted class indices (tensors or plain ints), one per image.
        class_names: index -> name list; defaults to the notebook-level ``classes``.
    """
    if class_names is None:
        class_names = classes
    # Same statistics as the Normalize transform; used to map pixels back to [0, 1].
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # At least six columns keeps the original 1x6 layout; larger batches get more columns.
    n_cols = max(6, len(images))

    plt.figure(figsize=(10, 5))
    for i, image in enumerate(images):
        plt.subplot(1, n_cols, i + 1, xticks=[], yticks=[])
        image = image.detach().cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC
        image = np.clip(image * std + mean, 0, 1)
        plt.imshow(image)
        # int() works for CPU/CUDA tensors and plain ints alike.
        true_idx = int(labels[i])
        pred_idx = int(predictions[i])
        plt.xlabel(f'{class_names[true_idx]}')
        plt.ylabel(f'{class_names[pred_idx]}', color='green' if pred_idx == true_idx else 'red')
    plt.tight_layout()
    plt.show()


# Backward-compatible alias for the original name (__*__ names are reserved by Python).
__showimages__ = show_images
    