In [None]:
import sys
import os
import configparser
import csv
import numpy as np
import imageio
import torch
import torchvision
import matplotlib.pyplot as plt
from RetinaCheckerMultiClass import RetinaCheckerMultiClass
from helper_functions import reduce_to_2_classes, AverageMeter, AccuracyMeter

# Reusable torchvision transform that converts an image tensor to a PIL image.
# NOTE(review): not referenced by any of the cells visible here — confirm it is still needed.
tensor_to_image = torchvision.transforms.ToPILImage()

In [None]:
# Notebook parameters: edit these two values to point the run at another setup.
# image_path is an absolute Windows path (backslashes escaped).
# NOTE(review): image_path is not read by any cell visible here — confirm it is still needed.
image_path = "D:\\Dropbox\\Data\\mini-set"
# Name of the configparser file that drives RetinaCheckerMultiClass.
config_file_name = "test3.cfg"


In [None]:
# Read the configuration file and build the checker from it.
#
# The original cell wrote the config name into sys.argv[1]; that side effect is
# preserved for any code that may still look there, but done safely: a bare
# `sys.argv[1] = ...` raises IndexError when sys.argv has only one element.
if len(sys.argv) > 1:
    sys.argv[1] = config_file_name
else:
    sys.argv.append(config_file_name)

# Reading configuration file
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed, so an empty result means the config was not found.
if not config.read(config_file_name):
    raise FileNotFoundError(
        "Could not read configuration file: {!r}".format(config_file_name)
    )

# Redirect the 'train path' entry to the test data before initialisation.
# NOTE(review): presumably so initialize() never needs the training set for a
# test-only run — confirm against RetinaCheckerMultiClass.initialize.
config['files']['train path'] = config['files']['test path']

rc = RetinaCheckerMultiClass()
rc.initialize(config)


In [None]:
# Restore the checker's saved state before evaluation.
# NOTE(review): presumably loads model weights from a checkpoint named in the
# config — confirm against RetinaCheckerMultiClass.load_state.
rc.load_state()

In [None]:
# Layout of the overview figure: one panel per test image, n_cols panels per row.
# classlabel holds the display name for each class index.
classlabel = ["no DMR", "mild NPDR", "mod NPDR", "severe NPDR", "PDR"]
n_cols = 6
num_images = len(rc.test_loader.dataset)
# np.ceil returns a float; the plotting cell casts n_rows to int where needed.
n_rows = np.ceil(num_images / n_cols)


In [None]:
# Evaluate the model on the test set and draw every test image in a grid.
# Each panel is annotated with the true label (white), the predicted label
# (cyan) and the per-class sigmoid scores of the network output.

# At least one row so that an empty dataset still yields a valid (blank) figure.
grid_rows = max(1, int(n_rows))
grid_cols = int(n_cols)
# sharex/sharey are keyword-only in current matplotlib; squeeze=False keeps `ax`
# two-dimensional even for a single row, so ax[ii, jj] always works.
fig, ax = plt.subplots(
    grid_rows,
    grid_cols,
    sharex=True,
    sharey=True,
    squeeze=False,
    figsize=(20, grid_rows * 3),
)
test_loader = rc.test_loader
rc.model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
    losses = AverageMeter()
    accuracy = AccuracyMeter()

    # NOTE(review): allocated but never filled in this cell.
    confusion = torch.zeros((rc.num_classes, rc.num_classes), dtype=torch.float)
    counter = 0

    for images, labels in test_loader:
        images = images.to(rc.device)
        labels = labels.to(rc.device)

        outputs = rc.model(images)
        loss = rc.criterion(outputs, labels)

        losses.update(loss.item(), images.size(0))

        num_correct = rc._evaluate_performance( labels, outputs )

        accuracy.update(num_correct, labels.size(0))

        # Bring the batch results to host memory once, so the per-image code
        # below also works when rc.device is a GPU.
        labels_cpu = labels.cpu()
        outputs_cpu = outputs.cpu()
        scores_cpu = torch.sigmoid(outputs_cpu)

        for lab, out, scores in zip(labels_cpu, outputs_cpu, scores_cpu):
            ii, jj = divmod(counter, grid_cols)
            panel = ax[ii, jj]
            # The image file is looked up by running index.
            # NOTE(review): assumes the loader yields samples in dataset order
            # (no shuffling) and that rc.test_dataset is the loader's dataset.
            # NOTE(review): origin='lower' puts image row 0 at the bottom, i.e. the
            # picture is drawn vertically flipped — confirm this is intended.
            panel.imshow(imageio.imread(rc.test_dataset.imgs[counter][0]), origin='lower')
            panel.annotate( classlabel[lab.argmax().item()], xy=(10,10), color='white', size=10)
            panel.annotate( classlabel[out.argmax().item()], xy=(120,10), color='cyan', size=10)
            # One score line per known class label, stacked downwards from y=205.
            for cc, (name, score) in enumerate(zip(classlabel, scores.tolist())):
                panel.annotate( '{}: {:.3f}'.format(name, score), xy=(10,205-cc*20), color='white', size=10)
            counter+=1

# Hide tick marks on every panel.
for panel in ax.flat:
    panel.set_xticks([])
    panel.set_yticks([])

