In [1]:
from torch.utils.data import DataLoader
from torchvision import transforms

import medicalDataLoader
from utils import *

import torch
import segmentation_models_pytorch as smp

def _print_metric_row(label, m1, s1, m2, s2, m3, s3):
    """Print one row of the test-results table.

    Args:
        label: Metric name; left-aligned and padded to 4 characters so that
            "Dice", "HD" and "ASD" rows line up.
        m1, m2, m3: Per-class values for classes c1, c2, c3.
        s1, s2, s3: The companion per-class values printed in parentheses
            (presumably standard deviations -- confirm against inferenceTest).

    The two "Mean" columns are the plain arithmetic average of the three
    per-class values and of the three parenthesised values, respectively.
    """
    print("###  {:<4} : c1: {:.4f} ({:.4f}) c2: {:.4f} ({:.4f}) c3: {:.4f} ({:.4f}) Mean: {:.4f} ({:.4f}) ###".format(
        label, m1, s1, m2, s2, m3, s3, (m1 + m2 + m3) / 3, (s1 + s2 + s3) / 3))


def runTesting(modelName='Test_Model', model_path='./Model_Best', root_dir='./Data/'):
    """Evaluate a saved 4-class U-Net checkpoint on the 'val' split and print metrics.

    Args:
        modelName: Name forwarded to ``inferenceTest`` (presumably provided by
            ``utils`` via the wildcard import; its use of the name is not
            visible here).
        model_path: Path of the ``state_dict`` checkpoint to load.
        root_dir: Root directory handed to ``MedicalImageDataset``.

    Prints Dice, HD and ASD for classes c1..c3, plus their averages.
    """
    print('-' * 40)
    print('~~~~~~~~  Starting the testing... ~~~~~~')
    print('-' * 40)

    batch_size_val = 1
    num_classes = 4

    # Images are only converted to tensors. Normalisation is currently disabled
    # (see https://sparrow.dev/pytorch-normalize/); whatever is used here must
    # match the preprocessing used when the checkpoint was trained.
    transform = transforms.Compose([
        transforms.ToTensor()
        # transforms.Normalize((0.5), (0.20))
    ])

    mask_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    test_set = medicalDataLoader.MedicalImageDataset('val',
                                                     root_dir,
                                                     transform=transform,
                                                     mask_transform=mask_transform,
                                                     equalize=False)

    test_loader = DataLoader(test_set,
                             batch_size=batch_size_val,
                             num_workers=5,
                             shuffle=False)

    # encoder_weights=None: every parameter is overwritten by the checkpoint
    # below (load_state_dict is strict by default), so fetching ImageNet
    # weights would be wasted work and would require network access.
    net = smp.Unet('resnet34', encoder_weights=None, in_channels=1, classes=num_classes)

    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
    # machine; the model is moved to the GPU afterwards when one is available.
    state_dict = torch.load(model_path, map_location='cpu')
    net.load_state_dict(state_dict)
    net.eval()

    if torch.cuda.is_available():
        net.cuda()

    print("~~~~~~~~~~~ Starting the testing ~~~~~~~~~~")
    # inferenceTest must return exactly 18 values: (value, parenthesised value)
    # for classes 1..3, for Dice, then HD, then ASD. The list unpacking below
    # raises if the count differs.
    [DSC1, DSC1s, DSC2, DSC2s, DSC3, DSC3s,
     HD1, HD1s, HD2, HD2s, HD3, HD3s,
     ASD1, ASD1s, ASD2, ASD2s, ASD3, ASD3s] = inferenceTest(net, test_loader, modelName)

    print("###                                                       ###")
    print("###         TEST RESULTS                                  ###")
    _print_metric_row("Dice", DSC1, DSC1s, DSC2, DSC2s, DSC3, DSC3s)
    _print_metric_row("HD", HD1, HD1s, HD2, HD2s, HD3, HD3s)
    _print_metric_row("ASD", ASD1, ASD1s, ASD2, ASD2s, ASD3, ASD3s)
    print("###                                                       ###")

# Run the evaluation when executed as a script or notebook cell (a Jupyter
# kernel's __name__ is "__main__"), but not when this file is merely imported.
if __name__ == "__main__":
    runTesting()

----------------------------------------
~~~~~~~~  Starting the testing... ~~~~~~
----------------------------------------
~~~~~~~~~~~ Starting the testing ~~~~~~~~~~
[Inference] Segmentation Done !                                                                              
###                                                       ###
###         TEST RESULTS                                  ###
###  Dice : c1: 0.5196 (0.3928) c2: 0.7633 (0.2348) c3: 0.8675 (0.2613) Mean: 0.7168 (0.2963) ###
###  HD   : c1: 25.6839 (22.1049) c2: 8.2040 (14.6341) c3: 4.5224 (10.1688) Mean: 12.8034 (15.6360) ###
###  ASD  : c1: 8.1072 (10.5744) c2: 2.5032 (4.9862) c3: 1.3730 (3.2120) Mean: 3.9945 (6.2576) ###
###                                                       ###
