# Implementation of a tool using artificial intelligence to analyse chest X-rays for COVID-19

In [None]:
from __future__ import print_function, division

import numpy as np
import torch 
from torch import nn, optim
from torchvision import datasets, models, transforms, utils

import os
import time
import copy
import pickle

from matplotlib import pyplot as plt

In [None]:
# Number of training epochs (train_model iterates range(num_epochs)).
EPOCHS = 101
# Mini-batch size used by the DataLoaders.
BATCH_SIZE = 8
# Number of output classes (covid / non-covid).
CLASSES = 2

# Dataset root. NOTE(review): the preprocessing cell below rebinds DATA_ROOT to
# './data/original/'; whichever cell ran last decides the ImageFolder root.
DATA_ROOT = './data/test/'

In [None]:
# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Data preprocessing
### Removing duplicates
Duplicates are detected with a perceptual hashing algorithm (pHash): images whose hashes differ by less than `MAX_DIFFERENCE` are grouped together. The goal is to reveal the number of unique images in the dataset.

In [None]:
from PIL import Image
import glob
import imagehash

# NOTE(review): this rebinds DATA_ROOT from the configuration cell; later cells
# that read DATA_ROOT (e.g. the ImageFolder loaders) see this value if run after.
DATA_ROOT = './data/original/'

# Two images count as duplicates when the difference of their perceptual hashes
# (ImageHash `-` operator) is strictly below this value.
MAX_DIFFERENCE = 8
# Perceptual hash of each image, in the same order as `images`.
hashes = []
# Set of tuples of image paths that are near-duplicates of each other.
duplicate_groups = set()

In [None]:
# Collect the files in DATA_ROOT in a stable (sorted) order. Sub-directories are
# skipped because Image.open cannot read them.
images = sorted(
    path for path in glob.glob(f'{DATA_ROOT}*') if os.path.isfile(path)
)

for img_path in images:
    # Open each image in a context manager so its file handle is released
    # right after hashing instead of piling up open files.
    with Image.open(img_path) as img:
        hashes.append(imagehash.phash(img))

In [None]:
# Group every image with all images whose perceptual hash is within
# MAX_DIFFERENCE of its own. Groups are tuples of paths (in `images` order), so
# identical groups collapse inside the set.
for reference_hash in hashes:
    group = tuple(
        path
        for other_hash, path in zip(hashes, images)
        if reference_hash - other_hash < MAX_DIFFERENCE
    )
    # A group of one contains only the image itself, i.e. no duplicate.
    if len(group) > 1:
        duplicate_groups.add(group)

### Augmentation
COVID-19 images in the train set are randomly rotated, flipped and brightness-adjusted to increase the number of COVID-19 samples.

In [None]:
import Augmentor

# Directory of training COVID-19 images used as the augmentation source.
DATA_TO_AUGMENT = './data/original/train/covid'

# Number of augmented samples to generate.
augmented_samples_count = 2000

In [None]:
# Augmentation pipeline over the COVID-19 training images.
p = Augmentor.Pipeline(DATA_TO_AUGMENT)

# Rotate 80% of samples by up to 15 degrees in either direction.
p.rotate(probability=0.8, max_left_rotation=15, max_right_rotation=15)
# Mirror 20% of samples horizontally.
p.flip_left_right(probability=0.2)
# Positional args presumably are (probability, min_factor, max_factor), i.e. 60%
# of samples get a brightness factor in [0.5, 1.2] — confirm against Augmentor.
p.random_brightness(0.6, 0.5, 1.2)
p.set_save_format(save_format="PNG")

# Generate and save the augmented images. NOTE(review): Augmentor presumably
# writes them to its own output sub-directory rather than DATA_TO_AUGMENT —
# confirm they end up where the training ImageFolder expects them.
p.sample(augmented_samples_count)

### Lung segmentation
Before being fed to the models, the CXR images are segmented with a trained VAE. The code is available in a separate notebook in the utils directory.
### CLAHE
To reduce differences in contrast and brightness between the analyzed images, CLAHE (Contrast Limited Adaptive Histogram Equalization) is applied. The segmented lung masks are then overlaid on the enhanced images.

In [None]:
import cv2

# Lung masks, paired with `images` by sorted order.
MASKS_DIR = './data/masks/'
# Destination for the CLAHE-enhanced, masked images.
OUTPUT_DIR = './data/segmented/'

In [None]:
# Mask paths, sorted so they line up one-to-one with the sorted `images` list.
masks = sorted(glob.glob(f'{MASKS_DIR}*'))

In [None]:
# CLAHE (Contrast Limited Adaptive Histogram Equalization) with clip limit 2.5
# over an 8x8 grid of tiles.
clahe = cv2.createCLAHE(clipLimit=2.5, tileGridSize=(8, 8))

In [None]:
# Apply CLAHE to every image and keep only the lung region given by its mask.
# Images and masks are paired by position in their sorted path lists, so both
# lists must have the same length (zip would silently drop the extras).
if len(images) != len(masks):
    raise ValueError(
        f'Found {len(images)} images but {len(masks)} masks; '
        'they must pair up one-to-one.'
    )

# cv2.imwrite does not create missing directories; it only reports failure.
os.makedirs(OUTPUT_DIR, exist_ok=True)

for img_path, mask_path in zip(images, masks):
    # Output name: the part of the file name before the first '.', plus suffix.
    stem = os.path.basename(img_path).split('.')[0]
    output_path = os.path.join(OUTPUT_DIR, f'{stem}_preprocessed.png')

    img = cv2.imread(img_path)
    # Read the mask as a single channel so it matches the grayscale CLAHE output.
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if img is None or mask is None:
        raise FileNotFoundError(f'Could not read {img_path} or {mask_path}')

    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    cl1 = clahe.apply(img)

    # bitwise_and needs operands of identical size; nearest-neighbour
    # interpolation keeps the mask values unchanged.
    if mask.shape != cl1.shape:
        mask = cv2.resize(mask, (cl1.shape[1], cl1.shape[0]),
                          interpolation=cv2.INTER_NEAREST)

    # NOTE(review): assumes masks are 0/255 so the AND keeps lung pixels
    # intact — confirm the segmentation output format.
    result = cv2.bitwise_and(cl1, mask)
    if not cv2.imwrite(output_path, result):
        raise OSError(f'Could not write {output_path}')

### Normalization and standardization
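Data is preprocessed with torchvision.transforms. All images (train, validation and test) are converted to three-channel grayscale, resized and center-cropped to the required size (224x224x3), and normalized. No random augmentation is applied at this stage; the train set is augmented offline with Augmentor (see the Augmentation section). The process of computing the mean and standard deviation of the dataset can be found in the utils directory.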

In [None]:
def _cxr_transform():
    """Return the preprocessing pipeline applied to every chest X-ray.

    Converts the image to 3-channel grayscale, resizes it, center-crops it to
    224x224, converts it to a tensor and normalizes it with the dataset
    mean/std (the same value for each channel).
    """
    channel_mean = [0.1220] * 3
    channel_std = [0.2058] * 3
    return transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])


# Identical, deterministic preprocessing for every split (one Compose object
# per split).
data_transforms = {split: _cxr_transform() for split in ('train', 'val', 'test')}

## Load data

Data is loaded with ImageFolder from a root directory that contains the train, validation and test data. Each of these folders consists of class directories (covid and non-covid). 

The datasets are then wrapped in DataLoaders, which provide samples in mini-batches (the training data is shuffled).

In [None]:
# One ImageFolder per split; each split directory holds one sub-directory per
# class, from which ImageFolder derives the integer labels.
image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_ROOT, x),
                                          data_transforms[x])
                  for x in ['train', 'val', 'test']}
# Only the training split is shuffled. Shuffling val/test does not change the
# metrics and makes the evaluation order non-reproducible.
data_loaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                               batch_size=BATCH_SIZE,
                                               shuffle=(x == 'train'),
                                               num_workers=4)
               for x in ['train', 'val', 'test']}

In [None]:
# Number of samples per split (used to average the per-epoch loss/accuracy).
DATASET_SIZES = {split: len(dataset) for split, dataset in image_datasets.items()}
# Class names as discovered by ImageFolder; position i is the name of label i.
CLASS_NAMES = image_datasets['train'].classes
print(CLASS_NAMES)

### Data visualization
Helper functions that convert normalized tensors and arrays back into displayable images.

In [None]:
def numpy_to_img(arr, mean=(0.1220, 0.1220, 0.1220), std=(0.2058, 0.2058, 0.2058)):
    """Convert a normalized CHW float array back into a displayable uint8 image.

    Undoes the normalization (``x * std + mean``), clips to [0, 1] and scales
    to 0..255.

    Args:
        arr: array of shape (C, H, W).
        mean: per-channel mean used for normalization (defaults match
            ``data_transforms``).
        std: per-channel standard deviation used for normalization.

    Returns:
        uint8 array of shape (H, W, C).
    """
    arr = np.asarray(arr).transpose((1, 2, 0))
    arr = np.asarray(std) * arr + np.asarray(mean)
    # Clip BEFORE scaling: out-of-range values would otherwise overflow when
    # cast to uint8 and show up as wrapped-around garbage pixels.
    arr = np.clip(arr, 0, 1)
    return (arr * 255).astype(np.uint8)

In [None]:
def tensor_to_img(tensor):
    """Convert a normalized (C, H, W) tensor into a displayable uint8 HWC image.

    The tensor is detached and copied to host memory first, because
    ``Tensor.numpy()`` refuses tensors that live on the GPU or require grad.
    """
    return numpy_to_img(tensor.detach().cpu().numpy())

In [None]:
def img_show(inp, title=None):
    """Display one image with an optional title and no axes.

    Tensors and NumPy arrays are treated as normalized CHW data and converted to
    HWC uint8 first; anything else (e.g. a PIL image) is shown unchanged.
    """
    img = inp
    for kind, convert in ((torch.Tensor, tensor_to_img), (np.ndarray, numpy_to_img)):
        if isinstance(inp, kind):
            img = convert(inp)
            break

    plt.axis('off')
    plt.title(title)
    plt.imshow(img)
    plt.show()

Function to visualize data batch that is in tensor format

In [None]:
def display_batch(inp, title=None):
    """Draw a normalized image tensor (e.g. a make_grid result) without axes.

    Args:
        inp: normalized CHW tensor, converted with ``tensor_to_img``.
        title: optional title; set only when provided.
    """
    picture = tensor_to_img(inp)

    plt.axis('off')
    plt.imshow(picture)
    if title is None:
        return
    plt.title(title)

# Preview the first few samples of one training batch.
NUMBER_OF_SAMPLES = 4

inputs, labels = next(iter(data_loaders['train']))
# Tile the selected images into a single grid image.
out = utils.make_grid(inputs[:NUMBER_OF_SAMPLES])
# The title lists the class name of each shown sample, in batch order.
display_batch(out, title=[CLASS_NAMES[x] for x in labels[:NUMBER_OF_SAMPLES]])

## Transfer learning
There are two major transfer learning approaches. First one consists of initializing the network with pretrained weights instead of random initialization, and training it as usual. Second approach adapts fixed pretrained network as a feature extractor, where only the last fully connected layer is randomly initialized and trained.

## ResNet18 as feature extractor
A pretrained ResNet18 model is loaded. To use it as a feature extractor, its parameters must be frozen. Setting the requires_grad flag to False excludes the parameters from gradient computation, which saves memory and speeds up training.

In [None]:
# Pretrained ResNet18 used as a frozen feature extractor: switching off
# requires_grad on every parameter keeps the backbone weights fixed.
m_resnet18 = models.resnet18(pretrained=True)
m_resnet18.requires_grad_(False)

### ResNet18 architecture

In [None]:
# Print the layer-by-layer structure of the network.
print(m_resnet18)

Next, the last fully connected layer is replaced with linear fully connected layer with 2 outputs for two-class classification. Then the model is allocated on the device (depending on CUDA availability).

In [None]:
# Replace the final fully connected layer with a new linear layer producing one
# raw score per class. Newly created layers are trainable by default, so only
# this head is updated during training.
# NOTE(review): the next cell replaces this head again (adding Softmax), so
# this version only takes effect if that cell is skipped.
num_ftrs = m_resnet18.fc.in_features
m_resnet18.fc = torch.nn.Linear(in_features=num_ftrs, out_features=CLASSES)

m_resnet18 = m_resnet18.to(device)

An additional Softmax layer can be applied to get the class probabilities at the output.

In [None]:
# Alternative head: linear layer followed by Softmax, so the model outputs
# per-class probabilities directly (the evaluation code reads the model outputs
# as probabilities).
# NOTE(review): nn.CrossEntropyLoss, used for training below, expects raw
# logits and applies log-softmax itself. Training this head with it applies
# softmax twice, which compresses the loss and weakens gradients. Consider
# training on logits and applying softmax only at inference.
num_ftrs = m_resnet18.fc.in_features
m_resnet18.fc = nn.Sequential(
    nn.Linear(in_features=num_ftrs, out_features=CLASSES),
    nn.Softmax(dim=1)
)

m_resnet18 = m_resnet18.to(device)

The total number of parameters and the number of trainable parameters can then be checked.

In [None]:
# Compare the total parameter count with the number of parameters that will
# actually be updated (those with requires_grad=True).
all_params = list(m_resnet18.parameters())
pytorch_total_params = sum(tensor.numel() for tensor in all_params)
pytorch_trainable_params = sum(tensor.numel() for tensor in all_params if tensor.requires_grad)

print(f'Total parameters: {pytorch_total_params}')
print(f'Trainable parameters: {pytorch_trainable_params}')

## Other models
### ResNet50

In [None]:
# ResNet50: pretrained backbone frozen as a feature extractor, plus a new
# trainable linear + Softmax head with CLASSES outputs.
# NOTE(review): nn.CrossEntropyLoss expects raw logits; training a Softmax head
# with it applies softmax twice.
m_resnet50 = models.resnet50(pretrained=True)
m_resnet50.requires_grad_(False)

num_ftrs = m_resnet50.fc.in_features
m_resnet50.fc = nn.Sequential(
    nn.Linear(num_ftrs, CLASSES),
    nn.Softmax(dim=1),
)
m_resnet50 = m_resnet50.to(device)

### SqueezeNet

In [None]:
# SqueezeNet 1.1 with all pretrained parameters frozen.
m_squeezenet = models.squeezenet1_1(pretrained=True)

for param in m_squeezenet.parameters():
    param.requires_grad = False
    
# Replace classifier entry "1" with a 1x1 convolution from 512 channels to
# CLASSES channels, followed by a channel-wise Softmax.
# NOTE(review): assumes entry "1" is torchvision's final 1x1 classification conv
# (512 input channels) — confirm for the installed torchvision version. The
# Softmax output is later trained with nn.CrossEntropyLoss, which expects logits.
m_squeezenet.classifier._modules["1"] = nn.Sequential(
    nn.Conv2d(512, CLASSES, kernel_size=(1, 1)),
    nn.Softmax(dim=1)
)
# Keep the model's num_classes attribute consistent with the new head.
m_squeezenet.num_classes = CLASSES

m_squeezenet = m_squeezenet.to(device)

### DenseNet-121

In [None]:
# DenseNet-121: pretrained backbone frozen as a feature extractor; the
# classifier is replaced by a trainable linear + Softmax head.
# NOTE(review): nn.CrossEntropyLoss expects raw logits; training a Softmax head
# with it applies softmax twice.
m_densenet = models.densenet121(pretrained=True)
m_densenet.requires_grad_(False)

num_ftrs = m_densenet.classifier.in_features
m_densenet.classifier = nn.Sequential(
    nn.Linear(num_ftrs, CLASSES),
    nn.Softmax(dim=1),
)
m_densenet = m_densenet.to(device)

# Hyperparameter tuning
## Learning rate finder
To find a range of learning rates that allow the model to converge, Leslie Smith's learning-rate range test is employed.

In [None]:
from torch_lr_finder import LRFinder

# Leslie Smith's learning-rate range test: train briefly while increasing the
# LR from 1e-7 up to end_lr=10 over 100 iterations and plot the recorded loss
# (measured on the validation loader, since one is passed).
model = m_resnet18
optimizer = optim.SGD(model.parameters(), lr=1e-7)
criterion = nn.CrossEntropyLoss()

lr_finder = LRFinder(model, optimizer, criterion, device=device)
lr_finder.range_test(data_loaders['train'], val_loader=data_loaders['val'], end_lr=10, num_iter=100)
lr_finder.plot(suggest_lr=False)
# The range test really trains `model` (the same object as m_resnet18) with
# ever larger learning rates. Restore the model and optimizer to their initial
# state so the actual training does not start from those weights.
lr_finder.reset()

# Training
## Loss function
Chosen criterion is cross entropy loss

In [None]:
# Cross-entropy loss; it takes raw class scores (logits) and integer labels.
criterion = nn.CrossEntropyLoss()

Function to train and validate a model and keep its best parameters. Each epoch has a training and a validation phase. If an epoch's validation accuracy is better than the best accuracy so far, the model weights are saved. The function returns the best model found during training. 

In [None]:
def train_model(model, criterion, optimizer, scheduler=None, num_epochs=25, loaders=None):
    """Train ``model`` and return it with the weights of its best validation epoch.

    Every epoch runs a training phase followed by a validation phase. Whenever
    the validation accuracy beats the best value seen so far, a copy of the
    weights is kept; those weights are loaded back into ``model`` at the end.

    Args:
        model: network to train. Batches are moved to the device of its first
            parameter.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the model's trainable parameters.
        scheduler: optional LR scheduler, stepped once after each training phase.
        num_epochs: number of epochs to run.
        loaders: optional dict with 'train' and 'val' iterables of
            ``(inputs, labels)`` batches; defaults to the module-level
            ``data_loaders``.

    Returns:
        ``(model, losses, accs)``. ``losses`` and ``accs`` map 'train'/'val' to
        per-epoch lists of Python floats.
    """
    loaders = data_loaders if loaders is None else loaders
    run_device = next(model.parameters()).device
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    losses = {'train': [], 'val': []}
    accs = {'train': [], 'val': []}

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            # train(True) == train(); train(False) == eval().
            model.train(is_train)

            running_loss = 0.0
            running_corrects = 0
            seen = 0

            for inputs, labels in loaders[phase]:
                inputs = inputs.to(run_device)
                labels = labels.to(run_device)

                optimizer.zero_grad()

                # Build the autograd graph only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # With the default mean reduction the loss is a batch average;
                # weight it by the batch size to get a per-sample epoch mean.
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                # .item() keeps the counter a plain int instead of a device tensor.
                running_corrects += (preds == labels).sum().item()
                seen += batch_size

            if is_train and scheduler is not None:
                scheduler.step()

            # max(..., 1) guards against an empty loader.
            epoch_loss = running_loss / max(seen, 1)
            epoch_acc = running_corrects / max(seen, 1)

            losses[phase].append(epoch_loss)
            accs[phase].append(epoch_acc)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model, losses, accs

The chosen model is trained for the selected number of epochs. The weights with the best validation accuracy are kept during training and returned at the end. This is best-checkpoint selection, a simple form of early stopping: training itself always runs for all epochs.

In [None]:
# Training configuration: the model, its label (used in the checkpoint file
# name) and the optimizer. No LR scheduler is used.
model = m_resnet18
label = "ResNet18"
# model.parameters() also yields the frozen backbone weights; they receive no
# gradients, so the optimizer effectively updates only the new head.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = None

In [None]:
# Train for EPOCHS epochs. The returned model holds the weights with the best
# validation accuracy; `loss` and `acc` hold the per-epoch histories.
model, loss, acc= train_model(model,
                              criterion, 
                              optimizer,
                              scheduler,
                              num_epochs=EPOCHS)

# Save the whole model object (not just its state_dict). The file name encodes
# the label, the epoch count and the optimizer class.
torch.save(model, f'covid_{label}_epochs{EPOCHS}_{optimizer.__class__.__name__}.pt')

The training and validation loss and accuracy values are saved for evaluation

In [None]:
# Persist the per-epoch training/validation histories so they can be plotted
# later without retraining.
with open('rn18_loss', 'wb') as f:
    pickle.dump(loss, f)

# The accuracy file must contain the accuracy history (`acc`), not `loss`.
with open('rn18_acc', 'wb') as f:
    pickle.dump(acc, f)

# Evaluation
Function that returns predictions and probabilities of predicted classes. 

In [None]:
def get_predictions(model, data_loader, apply_softmax=False, target_device=None):
    """Run ``model`` over ``data_loader``. Collect labels, predictions and probabilities.

    Args:
        model: classifier returning one score per class (2 classes) for each
            sample. The heads built in this notebook end in Softmax, so by
            default the outputs are used as probabilities directly.
        data_loader: iterable of ``(inputs, labels)`` batches with labels 0/1.
        apply_softmax: set to True for models that return raw logits; a
            softmax over the class dimension is applied first.
        target_device: device to run on; defaults to the module-level ``device``.

    Returns:
        Tuple of NumPy arrays
        ``(ground_truths, predictions, probabilities, class0_probs, class1_probs)``:
        the true label of every sample; the arg-max class of every sample; the
        probability assigned to class 0 for every sample; and the probability
        assigned to the TRUE class, split by true label 0 and 1.
    """
    target_device = device if target_device is None else target_device
    model = model.to(target_device)
    model.eval()

    ground_truths = []
    predictions = []
    probabilities = []
    class_probabilities = {0: [], 1: []}

    with torch.no_grad():
        for inputs, labels in data_loader:
            outputs = model(inputs.to(target_device))
            out_probs = torch.softmax(outputs, dim=1) if apply_softmax else outputs

            for prob, label in zip(out_probs, labels):
                label = int(label)

                ground_truths.append(label)
                class_probabilities[label].append(prob[label].item())
                probabilities.append(prob[0].item())
                predictions.append(int(prob.argmax()))

    return (np.asarray(ground_truths), np.asarray(predictions), np.array(probabilities),
            np.asarray(class_probabilities[0]), np.asarray(class_probabilities[1]))


## Confusion matrix

In [None]:
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve, precision_recall_curve, average_precision_score, PrecisionRecallDisplay
import seaborn as sn
import pandas as pd
from matplotlib import pyplot as plt, cycler, ticker

In [None]:
def cf_m(labels, predictions):
    """Build a 2x2 confusion matrix of counts.

    Entry [i, j] counts samples whose true label is i and predicted label is j.
    Pairs are formed with ``zip``, so extra elements of the longer input are
    ignored.

    Returns:
        float64 ndarray of shape (2, 2).
    """
    matrix = np.zeros((2, 2))
    for truth, guess in zip(labels, predictions):
        matrix[truth][guess] += 1
    return matrix

def plot_cf(cf):
    """Plot a confusion matrix as an annotated heatmap.

    Args:
        cf: square matrix of counts (rows: true labels, columns: predicted
            labels), ordered like ``CLASS_NAMES``.
    """
    # Create the figure before drawing so the heatmap gets this size; creating
    # it afterwards opens a second, empty figure instead.
    plt.figure(figsize=(7, 5))

    df_cm = pd.DataFrame(cf, index=list(CLASS_NAMES), columns=list(CLASS_NAMES))
    sn.heatmap(df_cm, annot=True, fmt='g')

    plt.title('Confusion matrix')
    plt.xlabel('predicted')
    plt.ylabel('labels')
    plt.show()

In [None]:
def get_acc_sen_spe(confusion_matrix):
    """Return ``(accuracy, sensitivity, specificity)`` from a 2x2 confusion matrix.

    Rows are true labels and columns predicted labels, with class 0 as the
    positive class: ``[[TP, FN], [FP, TN]]``.
    Sensitivity is TP / (TP + FN) and specificity is TN / (TN + FP). For a
    NumPy float matrix, an empty row gives NaN instead of raising.
    """
    (tp, fn), (fp, tn) = confusion_matrix
    total = np.sum(confusion_matrix)
    return (tp + tn) / total, tp / (tp + fn), tn / (fp + tn)

### Plot loss and accuracy graphs

In [None]:
# Shared figure settings for the training-curve plots.
PLOT_SIZE_X = 10  # figure width passed to figsize (inches)
PLOT_SIZE_Y = 5  # figure height passed to figsize (inches)
PLOT_LEFT_POS = 0.1  # subplot edges for subplots_adjust, as figure fractions
PLOT_RIGHT_POS = 0.9
PLOT_BOTTOM_POS = 0.15
PLOT_TOP_POS = 0.85
PLOT_MARGIN = 0.01  # horizontal data margin (plt.margins)
PLOT_LW = 0.9  # curve line width
PLOT_GRID_LW = 0.2  # grid line width
PLOT_TICKS_Y_INTERVAL = 0.02  # spacing of major y ticks (minor ticks at 1/5)

In [None]:
def plot_loss(arr, title='loss'):
    """Plot per-epoch training and validation curves of one metric.

    Args:
        arr: dict with 'train' and 'val' lists of per-epoch values.
        title: metric name, used in the title and as the y-axis label.

    Note: this sets the global matplotlib ``axes.prop_cycle`` rcParam (orange,
    dodgerblue), which also affects figures created afterwards.
    """
    plt.rc('axes', prop_cycle=cycler('color', ['orange', 'dodgerblue']))

    _, ax = plt.subplots(figsize=(PLOT_SIZE_X, PLOT_SIZE_Y))
    plt.subplots_adjust(left=PLOT_LEFT_POS, right=PLOT_RIGHT_POS,
                        bottom=PLOT_BOTTOM_POS, top=PLOT_TOP_POS)
    plt.margins(x=PLOT_MARGIN)

    for key, curve_label in (('train', 'Training'), ('val', 'Validation')):
        values = arr[key]
        plt.plot(range(len(values)), values, label=curve_label, linewidth=PLOT_LW)

    plt.title(f"Training and validation {title}")
    plt.xlabel("epochs")
    plt.ylabel(title)

    frame = plt.legend(loc='upper right').get_frame()
    frame.set_facecolor('white')
    frame.set_edgecolor('white')

    plt.grid(axis='y', lw=PLOT_GRID_LW)
    n_epochs = len(arr['train'])
    ax.xaxis.set_major_locator(ticker.MultipleLocator(n_epochs / 5))
    ax.xaxis.set_minor_locator(ticker.MultipleLocator(n_epochs / 50))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(PLOT_TICKS_Y_INTERVAL))
    ax.yaxis.set_minor_locator(ticker.MultipleLocator(PLOT_TICKS_Y_INTERVAL / 5))
    ax.set_xlim(xmin=0)
    ax.set_xlim(xmax=n_epochs - 1)

    plt.show()

### Plot prediction histogram

In [None]:
def histogram(data, labels, bins=10, model_name=""):
    """Plot histogram(s) of predicted probabilities over the range [0, 1].

    Args:
        data: a sequence of probabilities, or a list of such sequences.
        labels: legend label(s) matching ``data``.
        bins: number of bins over (0, 1).
        model_name: model name shown in the title.
    """
    _, axis = plt.subplots()
    plt.hist(data, label=labels, bins=bins, range=(0, 1))

    plt.title(f"Predicted probabilities with {model_name}")
    plt.legend(loc='best')
    plt.grid(axis='y', lw=PLOT_GRID_LW)
    axis.xaxis.set_major_locator(ticker.MultipleLocator(0.1))
    plt.xlabel("probability")
    plt.ylabel("counts")

    plt.show()

### Plot ROC

In [None]:
def plot_roc(true_labels, preds, labels):
    """Plot ROC curves for one or more models on a single axis.

    Class 0 is the positive class: labels are flipped with ``1 - truth`` before
    scikit-learn sees them, so ``preds`` should be class-0 scores. For each
    curve, the threshold maximising the geometric mean of TPR and (1 - FPR) is
    printed.

    Args:
        true_labels: list of 0/1 label arrays, one per model.
        preds: list of score arrays aligned with ``true_labels``.
        labels: list of legend names, one per model.
    """
    tick = 0.1
    _, ax = plt.subplots(figsize=(7, 6))

    for truth, scores, name in zip(true_labels, preds, labels):
        positives = 1 - truth
        auc = roc_auc_score(positives, scores)
        fpr, tpr, thresholds = roc_curve(positives, scores)

        plt.plot(fpr, tpr, label=f'{name},     AUC={auc:.4f}')

        g_means = np.sqrt(tpr * (1 - fpr))
        best = np.argmax(g_means)
        print(f'Best Threshold={thresholds[best]}, G-Mean={g_means[best]}')

    diagonal = np.linspace(0, 1)
    plt.plot(diagonal, diagonal, linestyle='--', label='Baseline', color='silver')

    plt.ylim([0, 1.05])
    plt.xlim([-0.05, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title("ROC Curve")

    frame = plt.legend(loc='lower right').get_frame()
    frame.set_facecolor('white')
    frame.set_edgecolor('white')
    plt.grid(axis='y', lw=0.2)

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(ticker.MultipleLocator(tick))
        axis.set_minor_locator(ticker.MultipleLocator(tick / 5))

    plt.show()

### Plot precision-recall curve

In [None]:
def plot_pre_rec(true_labels, preds, labels):
    """Plot precision-recall curves for one or more models on a single axis.

    Class 0 is the positive class (labels flipped with ``1 - truth``), so
    ``preds`` should be class-0 scores. The legend shows each model's average
    precision. The dashed baseline is the positive rate of the LAST label
    array, so at least one curve must be given.

    Args:
        true_labels: list of 0/1 label arrays, one per model.
        preds: list of score arrays aligned with ``true_labels``.
        labels: list of legend names, one per model.
    """
    _, ax = plt.subplots(figsize=(7, 6))

    for truth, scores, name in zip(true_labels, preds, labels):
        positives = 1 - truth
        precision, recall, _ = precision_recall_curve(positives, scores)
        average_precision = average_precision_score(positives, scores)
        curve = PrecisionRecallDisplay(recall=recall, precision=precision)
        curve.plot(ax=ax, name=f'{name},     AP={average_precision:.4f}')
        baseline = np.sum(positives) / len(truth)

    ax.plot([0, 1], [baseline, baseline], linestyle='--', label='Baseline', color='silver')

    plt.ylim([0, 1.05])
    plt.xlim([0, 1.05])
    plt.title("Precision-Recall curve")

    frame = plt.legend(loc='lower left').get_frame()
    frame.set_facecolor('white')
    frame.set_edgecolor('white')
    plt.grid(axis='y', lw=0.2)

    plt.show()

In [None]:
def find_sens_spec(probabilities, labels, threshold):
    """Compute accuracy, sensitivity and specificity at a decision threshold.

    A sample is predicted as class 0 when its class-0 probability is at least
    ``threshold``, and as class 1 otherwise. Metrics treat class 0 as positive.

    Returns:
        ``(accuracy, sensitivity, specificity)``.
    """
    predicted = np.where(probabilities >= threshold, 0, 1)
    return get_acc_sen_spe(cf_m(labels, predicted))

# Evaluation code
Prepare data and loaders

In [None]:
# Evaluation data: the segmented test split with the 'test' preprocessing.
# batch_size=1, so every batch holds a single image.
TEST_DATA_ROOT = './data/segmented/test'

test_dataset = datasets.ImageFolder(TEST_DATA_ROOT, data_transforms['test'])
# NOTE(review): shuffle=True does not affect the metrics; it only changes which
# image next(iter(test_loader)) returns in the GradCAM section.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)
test_size = len(test_dataset)
# Class names from the test folder (position i = name of label i).
CLASS_NAMES = test_dataset.classes

Load models from files

In [None]:
# Checkpoint to evaluate (presumably a whole model saved with torch.save(model,
# ...)) and its name for plot legends.
MODEL_PATH = './models/ResNet18_Adam_0.0001_100E.pt'
LABEL = 'ResNet18'

# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# files. Recent PyTorch releases default to weights_only=True, which rejects
# whole-model pickles; pass weights_only=False explicitly if needed.
model = torch.load(MODEL_PATH, map_location=device)
# GradCAM target layer: the last block of ResNet's layer4.
target_layer = [model.layer4[-1]]

# For models that expose a `features` module (presumably DenseNet/SqueezeNet):
# target_layer = [model.features[-1]]

Make predictions and get class probabilities.

In [None]:
# Predict the whole test set. `probabilities` holds the class-0 probability of
# every sample; covid_probs / non_probs hold the probability assigned to the
# true class for samples labelled 0 / 1.
labels, predictions, probabilities, covid_probs, non_probs = get_predictions(model, test_loader)


Plot confusion matrix and evaluation metrics

In [None]:
# Confusion matrix of the arg-max predictions and the derived metrics (class 0
# is the positive class). NOTE(review): this rebinds `confusion_matrix`,
# shadowing the sklearn function imported under the same name.
confusion_matrix = cf_m(labels, predictions)
plot_cf(confusion_matrix)
accuracy, sensitivity, specificity = get_acc_sen_spe(confusion_matrix)

print(f'Accuracy: {accuracy:.4f}')
print(f'Sensitivity: {sensitivity:.4f}')
print(f'Specificity: {specificity:.4f}')

Plot histogram, ROC curve and precision-recall curves.

In [None]:
# Histogram of the true-class probabilities per true class, then ROC and
# precision-recall curves computed from the class-0 probabilities.
histogram([covid_probs, non_probs], ['probabilities of COVID cases', 'probabilities of non-COVID cases'])
plot_roc([labels], [probabilities], [LABEL])
plot_pre_rec([labels], [probabilities], [LABEL])

In [None]:
# Decision thresholds on the class-0 probability. A lower threshold labels more
# samples as class 0, trading specificity for sensitivity.
thresholds = [0.1, 0.15, 0.24, 0.3, 0.4, 0.5]

print('Threshold\tSensitivity\tSpecificity\tAccuracy')
for t in thresholds:
    accuracy, sensitivity, specificity = find_sens_spec(probabilities, labels, t)
    print(f'{t}\t\t{sensitivity:.4f}\t\t{specificity:.4f}\t\t{accuracy:.4f}')
    

Load the training and validation loss and accuracy values, then plot them.

In [None]:
# Reload the pickled per-epoch histories written after training.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('rn18_loss', 'rb') as loss_file:
    losses = pickle.load(loss_file)
    
with open('rn18_acc', 'rb') as acc_file:
    accs = pickle.load(acc_file)

In [None]:
# Plot the histories loaded from disk in the previous cell with the plot_loss
# helper (there is no `plot` function in this notebook).
plot_loss(losses)
plot_loss(accs, title="accuracy")

## GradCAM
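Enable gradient computation for all model parameters (they were frozen for feature extraction) so that GradCAM can back-propagate through the network.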

In [None]:
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
import cv2

In [None]:
# GradCAM back-propagates through the network, so turn gradients back on for
# every parameter that was frozen for feature extraction.
model.requires_grad_(True)

In [None]:
def grad_cam(img, model, target_layer, original=None):
    """Compute a GradCAM heatmap for the first image of a batch and overlay it.

    Args:
        img: normalized input batch tensor; only the first sample's CAM is kept.
            NOTE(review): presumably must be on the CPU together with `model`,
            since GradCAM is built with use_cuda=False.
        model: classifier to explain.
        target_layer: list of layers passed to GradCAM as target_layers.
        original: optional un-normalized image (0..255) to draw the heatmap on.
            When omitted, the de-normalized first input image is used.

    Returns:
        The overlay image after a BGR->RGB conversion, ready for plt.imshow.
    """
    
    # NOTE(review): newer pytorch-grad-cam releases dropped `use_cuda` and
    # renamed `target_category` to `targets`; confirm the installed version.
    cam = GradCAM(model=model, target_layers=target_layer, use_cuda=False)
    
    # None presumably means "explain the highest-scoring class".
    target_category = None
    grayscale_cam = cam(input_tensor=img, target_category=target_category, eigen_smooth=True)
    # Keep the CAM of the first sample in the batch.
    grayscale_cam = grayscale_cam[0, :]
    
    # Scale pixel values from 0..255 to floats in [0, 1].
    if original is not None:
        rgb_img = np.float32(original) / 255
    else:
        rgb_img = np.float32(tensor_to_img(img[0])) / 255
        
    cam_image = show_cam_on_image(rgb_img, grayscale_cam)
    
    # NOTE(review): show_cam_on_image is called without use_rgb=True, so its
    # heatmap is presumably in OpenCV's BGR order; convert for matplotlib.
    return cv2.cvtColor(cam_image, cv2.COLOR_BGR2RGB)

In [None]:
def show_heatmap(heatmap):
    """Display an image (e.g. a GradCAM overlay) without axes."""
    plt.imshow(heatmap)
    plt.gca().set_axis_off()
    plt.show()

Load a batch of data and plot the image with ground truth label and model prediction

In [None]:
# Take one batch (a single image, because test_loader uses batch_size=1).
# NOTE(review): this rebinds `labels`, shadowing the test-set labels returned by
# get_predictions; re-run that cell before reusing the metric cells.
imgs, labels = next(iter(test_loader))

imgs = imgs.to(device)
labels = labels.to(device)
img = imgs[0]
label = CLASS_NAMES[labels[0]]

# NOTE(review): gradients were re-enabled above, so this forward pass builds an
# autograd graph; torch.no_grad() would avoid that for plain prediction.
pred = model(imgs)
pred_label = CLASS_NAMES[pred[0].argmax(dim=0).item()]

# Show the image with its ground-truth label (L) and predicted label (P).
img_show(img.cpu(), f'L: {label} | P: {pred_label}')

Generate a heatmap with GradCAM and plot it

In [None]:
# Compute the GradCAM overlay on the CPU (grad_cam builds GradCAM with
# use_cuda=False). model.cpu() moves the model's parameters in place, so the
# model stays on the CPU afterwards.
heatmap = grad_cam(imgs.cpu(), model.cpu(), target_layer)
show_heatmap(heatmap)