In [None]:
"""
# This notebook runs inference with the trained EDANet model and evaluates it on the validation set.

@ Author: Md Mostafa Kamal Sarker
@ email: m.kamal.sarker@gmail.com
@ Date: 17.05.2020

"""

In [None]:
import os
## import pytorch library
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torchnet.meter as meter
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
##  model
from edanet import EDANet
## other libraries
import itertools
import matplotlib.pyplot as plt

In [None]:
## data source path (override with the EDANET_DATA_DIR environment variable
## instead of editing this machine-specific default)
data_dir = os.environ.get('EDANET_DATA_DIR', 'E:/EDANet/test_data')
valdir = os.path.join(data_dir, 'val')

## images are converted to single-channel grayscale below, hence one mean/std
normalize = transforms.Normalize([0.5], [0.5])

## test dataset
val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
    transforms.Resize(330),
    transforms.CenterCrop(320),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    normalize,
]))

# Take the class names from the dataset itself. ImageFolder assigns label i to
# the i-th entry of its *sorted* list of class sub-directories, whereas
# os.listdir returns entries in arbitrary order (and includes plain files),
# which would attach the wrong names to the labels in the plots and report.
target_names = list(val_dataset.classes)
num_classes = len(target_names)

## test data loader (deterministic order; pinned memory only helps GPU copies)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=8, shuffle=False,
    num_workers=4, pin_memory=torch.cuda.is_available())

In [None]:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='',
                          cmap=plt.cm.Blues):
    """
    Print and plot a confusion matrix as an annotated heat map.

    All open pyplot figures are closed first, and the matrix is drawn on the
    new current figure; the caller saves or shows it.

    Args:
        cm: square confusion matrix (rows = true labels, columns = predicted
            labels), e.g. the output of sklearn.metrics.confusion_matrix.
        classes: tick labels, one per row/column of ``cm``.
        normalize: if True, divide each row by its sum so each cell shows the
            fraction of that true class; rows with no samples stay at 0
            instead of becoming NaN.
        title: figure title.
        cmap: matplotlib colormap used for the heat map.
    """
    # Imported locally so this function does not depend on a later notebook
    # cell having imported numpy.
    import numpy as np

    cm = np.asarray(cm)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True).astype('float')
        # A class with no samples gives an all-zero row; 0/0 would produce NaN,
        # which also turns cm.max() (the text-colour threshold) into NaN.
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype='float'),
                       where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.close('all')

    plt.imshow(cm, interpolation='none', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Integer counts are printed as integers; anything else (normalised or
    # float input) with two decimals, since format(x, 'd') rejects floats.
    fmt = 'd' if (not normalize and np.issubdtype(cm.dtype, np.integer)) else '.2f'
    # Cells darker than half the maximum get white text for readability.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout(pad=2)
    plt.margins(0.1)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

In [None]:
## use cuda
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [None]:
## model checkpoint and model
PATH = 'results/best_checkpoint.pth.tar'
# Output size must match the head the checkpoint was trained with
# (NOTE(review): presumably 3 classes - keep in sync with the 'val' folders).
model = EDANet(num_classes=3).to(device)
criterion = nn.CrossEntropyLoss().to(device)
# Not used anywhere in this notebook; kept only to mirror the training setup.
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint
# loaded on a CPU-only machine), matching the CPU fallback chosen above.
# torch.load unpickles the file, so only load checkpoints from trusted sources.
checkpoint = torch.load(PATH, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
model.eval()

In [None]:
import time
import itertools
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report,confusion_matrix

### define final validation for confusion matrix and other results
def final_validate(val_loader, model, device, target_names, num_classes, criterion):
    """
    Evaluate ``model`` on the whole validation set and report the results.

    Prints the average loss and overall accuracy, then passes the collected
    true/predicted labels (as column vectors) to ``Plot``, which draws the
    confusion matrices and writes the classification report.

    Args:
        val_loader: iterable of ``(data, target)`` batches.
        model: classifier whose output holds one score per class along dim 1.
        device: torch device the batches are moved to.
        target_names: class names indexed by label (forwarded to ``Plot``).
        num_classes: number of classes (forwarded to ``Plot``).
        criterion: loss function called as ``criterion(output, target)``.
    """
    # Switch to evaluation mode (affects layers such as dropout/batch-norm).
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    true_batches = []
    pred_batches = []
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            batch_size = target.size(0)
            # NOTE(review): assumes the criterion averages over the batch (the
            # nn.CrossEntropyLoss default); re-weight for a per-sample mean.
            total_loss += criterion(output, target).item() * batch_size
            # index of the highest score = predicted class
            predicted = output.argmax(dim=1)
            correct += (predicted == target).sum().item()
            seen += batch_size
            # keep (batch, 1) column vectors for the confusion matrix
            pred_batches.append(predicted.cpu().numpy().reshape(-1, 1))
            true_batches.append(target.cpu().numpy().reshape(-1, 1))

    if seen == 0:
        print('Validation loader produced no samples; nothing to report.')
        return

    print('Validation: average loss {:.4f}, accuracy {}/{} ({:.2f}%)'.format(
        total_loss / seen, correct, seen, 100.0 * correct / seen))

    # Concatenate once instead of re-stacking the arrays on every batch.
    Tr = np.concatenate(true_batches)
    Pr = np.concatenate(pred_batches)
    Plot(Tr, Pr, target_names, num_classes)


def Plot(target_var, predicted, target_names, num_classes):
    """
    Plot confusion matrices, print per-class metrics and save a text report.

    Writes to the ``results/`` directory:
      * EDANetConfusion_matrix_WN.png  - confusion matrix with raw counts
      * EDANetConfusion_matrix_Nor.png - row-normalised confusion matrix
      * EDANetperformances.txt         - classification report + matrix

    Args:
        target_var: true integer labels (1-D or column vector).
        predicted: predicted integer labels, same length as ``target_var``.
        target_names: class names indexed by label.
        num_classes: number of classes; labels are ``0 .. num_classes - 1``.
    """
    labels = list(range(num_classes))
    class_names = [target_names[i] for i in labels]
    print(class_names)

    # Pass labels explicitly so the matrix is always num_classes x num_classes,
    # even if a class never appears in either the true or predicted labels;
    # otherwise its rows/columns would no longer line up with class_names.
    cnf_matrix = confusion_matrix(target_var, predicted, labels=labels)
    np.set_printoptions(precision=2)

    os.makedirs('results', exist_ok=True)

    ## non-normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cnf_matrix, classes=class_names)
    plt.savefig('results/EDANet'+'Confusion_matrix_WN.png', dpi=300)

    ## normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                          title='Normalized confusion matrix')
    plt.savefig('results/EDANet'+'Confusion_matrix_Nor.png', dpi=300)

    ## per-class sensitivity (recall) and positive predictive value (precision)
    sensitivities, ppvs = _per_class_rates(cnf_matrix)
    print('Sens ' + ', '.join('{0}: {1:.3f}'.format(name, value)
                              for name, value in zip(class_names, sensitivities)))
    print('PPV ' + ', '.join('{0}: {1:.3f}'.format(name, value)
                             for name, value in zip(class_names, ppvs)))

    ## save results
    clf_rep = classification_report(target_var, predicted, labels=labels,
                                    target_names=class_names)
    with open('results/EDANet'+'performances.txt', 'w') as file_perf:
        file_perf.write("classification Report:\n" + str(clf_rep)
                        + "\n\nConfusion matrix:\n"
                        + str(cnf_matrix))


def _per_class_rates(matrix):
    """Return (sensitivities, ppvs) per class from a square confusion matrix; 0 where a row/column is empty."""
    matrix = np.asarray(matrix, dtype='float')
    diagonal = matrix.diagonal()
    row_sums = matrix.sum(axis=1)  # true samples per class
    col_sums = matrix.sum(axis=0)  # predictions per class
    sensitivities = [d / r if r else 0 for d, r in zip(diagonal, row_sums)]
    ppvs = [d / c if c else 0 for d, c in zip(diagonal, col_sums)]
    return sensitivities, ppvs
    


In [None]:
# Run the final evaluation on the validation set.
final_validate(val_loader=val_loader, model=model, device=device,
               target_names=target_names, num_classes=num_classes,
               criterion=criterion)