# Notebook for Data Analysis and Integrated Gradients in Pneumonia Prediction

The trained model's state dictionary can be downloaded from https://polybox.ethz.ch/index.php/s/j5BofTKcxnnx39t

The 10 images used for the Integrated Gradients visualizations can be found here: https://polybox.ethz.ch/index.php/s/7eLrhKFe34UW3tW

**Load the Required Packages**

In [None]:
import sys
import numpy as np
import matplotlib.pyplot as plt

# append the filepath to where torch is installed
sys.path.append('/home/millerm/.local/lib/python3.10/site-packages')
# sys.path.append('/home/username/.local/lib/python3.10/site-packages')

import torch
import torchvision

In [None]:
import torch.nn as nn
import torch.nn.functional as F

from torchinfo import summary
import torchvision.transforms as transforms
from torchvision.transforms import v2

We load helper functions from pytorchcv. Because importing some of them directly can cause problems, the ones we rely on (`display_dataset`, `train_long`, and `validate`) are redefined further below.

In [None]:
!wget https://raw.githubusercontent.com/MicrosoftDocs/pytorchfundamentals/main/computer-vision-pytorch/pytorchcv.py

In [None]:
from pytorchcv import train, plot_results, display_dataset, train_long

**Load the Model**

For this notebook, please use the model file 20_model_state.pth. Unfortunately, we were unable to load the full (pickled) model on the student cluster, so we only provide its state dictionary.

In [None]:
from torchvision.models import VGG16_Weights
# Start from the ImageNet-pretrained VGG16 backbone; the classifier head is
# swapped for a 2-class head further down.
_pretrained_weights = VGG16_Weights.DEFAULT
model = torchvision.models.vgg16(weights=_pretrained_weights)

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Doing computations on device = {device}')

# Last expression of the cell: Jupyter displays the (moved) model.
model.to(device)

In [None]:
# Replacement VGG16 head: four Linear -> ReLU -> Dropout(0.5) stages followed by
# a 2-way output layer (NORMAL vs PNEUMONIA). The module order must stay exactly
# like this, because nn.Sequential names children by position and the
# checkpoint's state-dict keys (classifier.0, classifier.3, ...) depend on it.
_head_layers = []
for _in_features in (25088, 4096, 4096, 4096):
    _head_layers += [
        nn.Linear(_in_features, 4096),
        nn.ReLU(),
        nn.Dropout(0.5, inplace=False),
    ]
_head_layers.append(nn.Linear(4096, 2))
model.classifier = nn.Sequential(*_head_layers).to(device)

In [None]:
# map_location remaps stored tensors onto the current device, so a checkpoint
# saved on a GPU also loads on a CPU-only machine. weights_only=True limits
# unpickling to tensors and plain containers. That is all a state_dict needs,
# and it is safer for a file downloaded from the web.
model.load_state_dict(
    torch.load('models/20_model_state.pth', map_location=device, weights_only=True)
)

**Transform and Visualize the Dataset**

In [None]:
# Spatial preprocessing only (no Grayscale, no Normalize). In this notebook these
# datasets are only referenced by the commented-out mean/std computation.
trans_wo_norm = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
)

dataset0_wo_norm = torchvision.datasets.ImageFolder(
    "ml4h_data/project1/chest_xray/train/", transform=trans_wo_norm
)
dataset1_wo_norm = torchvision.datasets.ImageFolder(
    "ml4h_data/project1/chest_xray/test/", transform=trans_wo_norm
)
dataset2_wo_norm = torchvision.datasets.ImageFolder(
    "ml4h_data/project1/chest_xray/val/", transform=trans_wo_norm
)

We compute the mean and standard deviation using the following snippet. However, this takes a few minutes, so we have included the hard-coded values for your convenience.

In [None]:
# One-off computation of the training-split channel statistics that are
# hard-coded in the next cell. It iterates every training image, which is slow,
# so it is left commented out.
# NOTE(review): this averages each image's own std. The mean of per-image stds
# is not the pixel-level std of the whole dataset, so std0 is an approximation.
# mean0 = torch.zeros(3)
# std0 = torch.zeros(3)
# for img, _ in dataset0_wo_norm:
#     mean0 += img.mean(dim=(1, 2))
#     std0 += img.std(dim=(1, 2))

# mean0 /= len(dataset0_wo_norm)
# std0 /= len(dataset0_wo_norm)

# print("Mean:", mean0)
# print("Standard deviation:", std0)

In [None]:
# Precomputed per-split statistics: 0 = train, 1 = test, 2 = val. Each split's
# value is repeated across the three channels (presumably because the X-rays
# are grayscale).
mean0 = torch.tensor([0.5832] * 3)
std0 = torch.tensor([0.1413] * 3)
mean1 = torch.tensor([0.5763] * 3)
std1 = torch.tensor([0.1453] * 3)
mean2 = torch.tensor([0.6020] * 3)
std2 = torch.tensor([0.1401] * 3)

In [None]:
std_normalise_0 = transforms.Normalize(mean=mean0, std=std0)
std_normalise_1 = transforms.Normalize(mean=mean1, std=std1)
std_normalise_2 = transforms.Normalize(mean=mean2, std=std2)


def _xray_pipeline(*tail):
    """Build Resize(256) -> CenterCrop(224) -> 3-channel Grayscale -> ToTensor, then *tail."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        *tail,
    ])


# Training images are normalised inside the pipeline.
trans0 = _xray_pipeline(std_normalise_0)
# trans1 / trans2 stop at ToTensor. The evaluation and saliency cells apply
# std_normalise_1 by hand, so the un-normalised image stays available for plotting.
# NOTE(review): trans2 (val) is never normalised in the visible code; confirm this is intended.
trans1 = _xray_pipeline()
trans2 = _xray_pipeline()

random_trans = v2.RandomOrder([v2.GaussianBlur(3)])

In [None]:
# Train split: trans0 (normalised) followed by the random blur augmentation.
# The blur has to run per sample, inside the dataset's transform. Calling
# `random_trans(dataset_0)` on the dataset object is a silent no-op: torchvision
# v2 transforms return inputs that are not images/tensors unchanged, so no blur
# was ever applied.
# NOTE(review): the held-out subset carved out of dataset_0 by random_split
# shares this transform, so it is blurred too. Confirm this is acceptable.
dataset_0 = torchvision.datasets.ImageFolder(
    "ml4h_data/project1/chest_xray/train/",
    transform=transforms.Compose([trans0, random_trans]),
)
dataset_1 = torchvision.datasets.ImageFolder("ml4h_data/project1/chest_xray/test", transform=trans1)
dataset_2 = torchvision.datasets.ImageFolder("ml4h_data/project1/chest_xray/val", transform=trans2)

In [None]:
def display_dataset(dataset, n=2, classes=('NORMAL', 'PNEUMONIA')):
    """Plot the first ``n`` samples of ``dataset`` in a single row.

    All images are rescaled with one shared min/max, so normalised tensors map
    into [0, 1] for display while keeping their relative contrast.

    Args:
        dataset: indexable returning ``(image, label)`` pairs. ``image`` is
            assumed to be a (C, H, W) tensor/array (TODO confirm for other datasets).
        n: number of leading samples to show.
        classes: label names indexed by the integer label. Pass a falsy value
            to omit titles.
    """
    # Fetch each sample exactly once. Indexing re-runs the dataset transform, so
    # random augmentations would otherwise differ between the min/max pass and
    # the image that is actually drawn.
    samples = [dataset[i] for i in range(n)]
    mn = min(img.min() for img, _ in samples)
    mx = max(img.max() for img, _ in samples)
    scale = (mx - mn) if mx > mn else 1  # avoid 0/0 on constant images

    # squeeze=False always returns a 2-D axes array, so n == 1 works too.
    _, axes = plt.subplots(1, n, figsize=(15, 3), squeeze=False)
    for ax, (img, label) in zip(axes[0], samples):
        ax.imshow(np.transpose((img - mn) / scale, (1, 2, 0)))
        ax.axis('off')
        if classes:
            ax.set_title(classes[label])

In [None]:
# Preview the first two (transformed) training images with their class names.
display_dataset(dataset_0, n=2, classes=('NORMAL', 'PNEUMONIA'))

**Train the Model**

Training takes a few hours. We trained our model for 21 epochs to reach sensible results, so we advise you to simply load the trained weights as shown above.

In [None]:
# Unfreeze the VGG16 convolutional backbone so gradients flow into it during training.
for backbone_param in model.features.parameters():
    backbone_param.requires_grad_(True)

In [None]:
def train_long(net, train_loader, test_loader, epochs=5, lr=0.001, optimizer=None, loss_fn=nn.NLLLoss(), print_freq=10):
    """Train ``net`` for ``epochs`` epochs, logging progress and validating after each epoch.

    Args:
        net: model to train; its parameters' device is used for all batches.
        train_loader: iterable of ``(features, labels)`` training batches.
        test_loader: iterable of batches passed to ``validate`` after each epoch.
        epochs: number of passes over ``train_loader``.
        lr: learning rate for the default Adam optimizer.
        optimizer: optional pre-built optimizer; overrides ``lr`` when given.
        loss_fn: criterion applied to ``(net(features), labels)``.
        print_freq: log running train metrics every ``print_freq`` minibatches.
    """
    optimizer = optimizer or torch.optim.Adam(net.parameters(), lr=lr)
    # Run on whatever device the model already lives on, instead of a
    # `default_device` global that is only defined in a commented-out cell.
    device = next(net.parameters()).device
    for epoch in range(epochs):
        net.train()
        total_loss, acc, count = 0.0, 0, 0
        for i, (features, labels) in enumerate(train_loader):
            lbls = labels.to(device)
            optimizer.zero_grad()
            out = net(features.to(device))
            loss = loss_fn(out, lbls)
            loss.backward()
            optimizer.step()
            # .item() detaches the value. Accumulating the tensor itself kept
            # every batch's loss attached to an ever-growing autograd graph.
            # Weighting by batch size makes total_loss / count a per-sample
            # mean; this assumes loss_fn uses the default 'mean' reduction.
            total_loss += loss.item() * len(labels)
            _, predicted = torch.max(out, 1)
            acc += (predicted == lbls).sum().item()
            count += len(labels)
            if i % print_freq == 0:
                print("Epoch {}, minibatch {}: train acc = {}, train loss = {}".format(epoch, i, acc / count, total_loss / count))
        vl, va = validate(net, test_loader, loss_fn)
        print("Epoch {} done, validation acc = {}, validation loss = {}".format(epoch, va, vl))

In [None]:
def validate(net, dataloader, loss_fn=nn.NLLLoss()):
    """Evaluate ``net`` on ``dataloader`` without gradient tracking.

    Puts ``net`` in eval mode and runs every batch on the device of the
    model's parameters.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is a per-sample mean, assuming
        ``loss_fn`` uses the default 'mean' reduction.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    net.eval()
    # Infer the device from the model instead of a `default_device` global.
    device = next(net.parameters()).device
    count, acc, loss = 0, 0, 0.0
    with torch.no_grad():
        for features, labels in dataloader:
            lbls = labels.to(device)
            out = net(features.to(device))
            # Weight each batch mean by its size so a smaller final batch does
            # not skew the average.
            loss += loss_fn(out, lbls).item() * len(labels)
            pred = torch.max(out, 1)[1]
            acc += (pred == lbls).sum().item()
            count += len(labels)
    if count == 0:
        raise ValueError("validate() received an empty dataloader")
    return loss / count, acc / count

In [None]:
num_samples = 3500
torch.manual_seed(1234)
trainset, testset = torch.utils.data.random_split(dataset_0, [num_samples, len(dataset_0) - num_samples])
# Reshuffle the training subset every epoch. Without shuffle=True each epoch
# replays the identical batch sequence. The held-out loader stays in order.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size=32)

In [None]:
# default_device = device
# train_long(model,train_loader,test_loader,lr=0.00001,loss_fn=torch.nn.CrossEntropyLoss(),epochs=21,print_freq=15)

In [None]:
# torch.save(model.state_dict(), '20_model_state.pth')

**Evaluate the Model**

In [None]:
import os
from PIL import Image

In [None]:
def match_label_to_folder(image_path):
    """Split an image path into ``(label_name, image_filename)``.

    ``label_name`` is the name of the directory that directly contains the file
    (e.g. ``NORMAL`` or ``PNEUMONIA``). ``image_filename`` is the file name with
    its final extension removed.
    """
    folder, basename = os.path.split(image_path)
    stem, _ext = os.path.splitext(basename)
    return os.path.basename(folder), stem

In [None]:
# Sanity-check the path -> (label, filename) helper on one known test image.
test_file = 'ml4h_data/project1/chest_xray/test/NORMAL/IM-0033-0001-0001.jpeg'  # Visualize a test file
test_img = Image.open(test_file)

image_path = test_file
label_name, image_filename = match_label_to_folder(image_path)
for caption, value in (("Label Name:", label_name), ("Image Filename:", image_filename)):
    print(caption, value)

In [None]:
# Nothing earlier in the notebook switches the model out of training mode, so
# the classifier's Dropout layers were active here and made predictions random.
model.eval()


def _predict_folder(data_folder):
    """Return the predicted class index (0 = NORMAL, 1 = PNEUMONIA) for each .jpeg in data_folder."""
    predictions = []
    for filename in os.listdir(data_folder):
        if not filename.endswith(".jpeg"):
            continue
        image_path = os.path.join(data_folder, filename)
        with Image.open(image_path) as image:  # close the file handle promptly
            transformed_img = trans1(image).to(device)
        # trans1 stops at ToTensor, so apply the test-set normalisation here.
        input_img = std_normalise_1(transformed_img).unsqueeze(0)
        with torch.no_grad():  # inference only: skip building autograd graphs
            outputs = model(input_img)
        # argmax of the logits equals argmax of their softmax.
        predictions.append(outputs.argmax(dim=1).item())
    return predictions


predictions_NORMAL = _predict_folder("ml4h_data/project1/chest_xray/test/NORMAL")
predictions_PNEUMONIA = _predict_folder("ml4h_data/project1/chest_xray/test/PNEUMONIA")

# The Integrated Gradients cell reads `true_label_idx`. The inline loops this
# replaces left it bound to the class index of the last folder evaluated
# (PNEUMONIA = 1), so keep that binding.
true_label_idx = 1

# Accuracy over both folders: PNEUMONIA images are correct when predicted 1,
# NORMAL images when predicted 0.
n_correct = sum(predictions_PNEUMONIA) + (len(predictions_NORMAL) - sum(predictions_NORMAL))
n_total = len(predictions_PNEUMONIA) + len(predictions_NORMAL)
n_correct / n_total if n_total else float('nan')

**Run Integrated Gradients**

In [None]:
import captum
from captum.attr import IntegratedGradients
from captum.attr import visualization as viz
from matplotlib.colors import LinearSegmentedColormap

In [None]:
channels = 3
height = 224
width = 224
batch_size = 1


def _to_model_space(pixels_0_255):
    """Map a (B, C, H, W) image with 0-255 pixel values into the model's normalised input space on `device`."""
    # Images reach the model as std_normalise_1(trans1(img)): ToTensor scales
    # pixels to [0, 1], then mean/std normalisation is applied. Baselines must
    # go through the same mapping. A raw all-zeros tensor in normalised space
    # is mean-intensity grey, not black, and 0-255 values are far outside
    # anything the model is fed.
    return std_normalise_1(pixels_0_255 / 255.0).to(device)


black_image = _to_model_space(torch.zeros((batch_size, channels, height, width)))

white_image = _to_model_space(torch.ones((batch_size, channels, height, width)) * 255)

pink_image = torch.zeros((batch_size, channels, height, width))
pink_image[:, 0, :, :] = 255
pink_image[:, 1, :, :] = 192
pink_image[:, 2, :, :] = 203
pink_image = _to_model_space(pink_image)

noisy_pixels = torch.randint(0, 256, (batch_size, channels, height, width))
noisy_image = _to_model_space(noisy_pixels.type(torch.FloatTensor))

In [None]:
# Diverging orange -> white -> blue map for signed attributions: negative values
# orange, zero white, positive blue.
_zurich_stops = [
    (0, '#ff8f4b'),
    (0.5, '#ffffff'),
    (1, '#0070b4'),
]
default_cmap = LinearSegmentedColormap.from_list('zurichblue', _zurich_stops, N=256)

Create a folder called "images" in your working directory so that the Integrated Gradients visualizations can be saved.

In [None]:
import io

# Attributions need deterministic forward passes, so switch off Dropout.
model.eval()
integrated_gradients = IntegratedGradients(model)
os.makedirs("images", exist_ok=True)

img_id = [6, 8, 9, 11, 12]
# img_id = [33,35,39,69,70] # for selecting the NORMAL patients, use this code to print the images
for img_num in img_id:
    test_file = f"img_for_saliency/PNEUMONIA/person1_virus_{img_num}.jpeg"
    # test_file = f"img_for_saliency/NORMAL/IM-00{img_num}-0001.jpeg"
    # for selecting the NORMAL patients, use this code to print the images
    with Image.open(test_file) as test_img:
        transformed_img = trans1(test_img).to(device)
    # trans1 stops at ToTensor. Normalise for the model, but keep
    # transformed_img un-normalised for the overlay below.
    input_img = std_normalise_1(transformed_img).unsqueeze(0)

    true_label = match_label_to_folder(test_file)[0]
    # Attribute towards this image's own ground-truth class. Reading a
    # `true_label_idx` left over from the evaluation cell targets the wrong
    # class for NORMAL images.
    true_label_idx = 1 if true_label == "PNEUMONIA" else 0
    attributions_ig = integrated_gradients.attribute(input_img,
                                                     target=true_label_idx,
                                                     n_steps=100,
                                                     baselines=black_image)
    filename = f"images/ig_{true_label}_{img_num}_perm.png"

    fig, _ = viz.visualize_image_attr(np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                      np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                      method='blended_heat_map',
                                      cmap=default_cmap,
                                      show_colorbar=True,
                                      sign='all',
                                      title='Integrated Gradients'
                                      )

    fig.savefig(filename, format='png')
    plt.close(fig)  # release the figure; one is created per image

    print(f"Saved {filename}")