In [1]:
import torch
import random
import numpy as np
import os

import glob
import tifffile as tiff

import errno
import imageio

from PIL import Image
from torch.utils import data

from NNUtils import CustomDataset

import matplotlib.pyplot as plt

In [2]:
# Test split used for qualitative inspection: image patches and their masks are
# stored in sibling 'patches' and 'labels' folders under `test_data`.
test_data = '/home/moucheng/projects_data/Brain_data/brats2018/t5_s20_l1_u30_WT/test'
test_data_images = os.path.join(test_data, 'patches')
test_data_labels = os.path.join(test_data, 'labels')

# NOTE(review): the positional 'none' and 3 are CustomDataset options defined in
# NNUtils (presumably augmentation mode and a channel/slice count) — confirm there.
dataset = CustomDataset(test_data_images, test_data_labels, 'none', 3)

# One sample at a time, in file order; every sample is kept (drop_last=False).
dataloader = data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    drop_last=False,
)

In [3]:
# # check data
# index = 10
# images1, images2, images3, labels, imagename = dataset[index]

# plt.imshow(np.squeeze(images2), cmap='gray')
# plt.title('Input')
# plt.show()

# plt.imshow(np.squeeze(labels), cmap='gray')
# plt.title('Labels')
# plt.show()

# overlay = 0.6*np.squeeze(images2) + 0.4*np.squeeze(labels)
# plt.imshow(np.squeeze(overlay), cmap='gray')
# plt.title('Overlay')
# plt.show()

In [4]:
def test_inference(dataset, model, model_name, device):
    
    for index in range(len(dataset)):

        images1, images2, images3, labels, imagename = dataset[index]

        images1 = torch.from_numpy(images1).float()
        images2 = torch.from_numpy(images2).float()
        images3 = torch.from_numpy(images3).float()
        labels = torch.from_numpy(labels).float()

        test_img1 = images1.to(device=device, dtype=torch.float32)
        test_img2 = images2.to(device=device, dtype=torch.float32)
        test_img3 = images3.to(device=device, dtype=torch.float32)
        test_labels = labels.to(device=device, dtype=torch.float32)

        c, h, w = test_img1.size()
            
        test_img1 = test_img1.view(1, c, h, w)
        test_img2 = test_img2.view(1, c, h, w)
        test_img3 = test_img3.view(1, c, h, w)
        
        if 'ERFAnetZ' in model_name:
        
            # For ours:
            test_outputs_fp, test_outputs_fn, test_pseudo_x1_a, test_pseudo_x1_b, test_pseudo_x2, test_pseudo_x3_a, test_pseudo_x3_b, test_x_, test_x__, test_x___ = model(test_img1, test_img2, test_img3)
            test_outputs_fp = torch.sigmoid(test_outputs_fp)
            test_outputs_fn = torch.sigmoid(test_outputs_fn)
            test_class_outputs = (test_outputs_fn + test_outputs_fp) / 2

        elif 'MTASSUnet' in model_name:
            # For mtass:
            test_outputs, _ = model(test_img2)
            test_class_outputs = torch.sigmoid(test_outputs)
            
        elif 'FixMatch' in model_name or 'MeanTeacher' in model_name or 'Unet' in model_name:
            # For fixmatch, meanTeacher, Unet:
            test_outputs = model(test_img2)
            test_class_outputs = torch.sigmoid(test_outputs)
        
        save_path = '/home/moucheng/projects_data/Brain_data/results'
        save_path = save_path + '/' + model_name

        try:

            os.mkdir(save_path)

        except OSError as exc:

            if exc.errno != errno.EEXIST:

                raise

            pass

        b, c, h, w = test_img2.size()

        pred = test_class_outputs.reshape(h, w).cpu().detach().numpy() > 0.5
        pred = np.asarray(pred, dtype=np.uint8)
        
        if c == 1:
            testimg = test_img2.cpu().squeeze().detach().numpy()
            testimg = np.asarray(testimg, dtype=np.float32)
        else:
            testimg = test_img2.cpu().squeeze().detach().numpy()
            testimg = testimg[2, :, :]
            testimg = np.asarray(testimg, dtype=np.float32)

        label = test_labels.squeeze().cpu().detach().numpy() > 0.5
        label = np.asarray(label, dtype=np.uint8)

        difference = label - pred
        addition = label + pred

        error_map = np.zeros((h, w, 3), dtype=np.uint8)
        label_map = np.zeros((h, w, 3), dtype=np.uint8)

        error_map[difference == -1] = [255, 0, 0]  # false positive red
        error_map[difference == 1] = [0, 0, 255]  # false negative blue
        error_map[addition == 2] = [0, 255, 0]  # true positive green

        label_map[label == 1] = [255, 255, 0]  # true positive yellow

        prediction_name = 'seg_' + imagename + '.png'
        full_error_map_name = os.path.join(save_path, prediction_name)
        imageio.imsave(full_error_map_name, error_map)

        pic_name = 'original_' + imagename + '.png'
        full_pic_map_name = os.path.join(save_path, pic_name)
        imageio.imsave(full_pic_map_name, testimg)

        label_name = 'label_' + imagename + '.png'
        full_label_map_name = os.path.join(save_path, label_name)
        imageio.imsave(full_label_map_name, label_map)

        seg_img = Image.open(full_error_map_name)
        input_img = Image.open(full_pic_map_name)
        label_img = Image.open(full_label_map_name)

        seg_img = seg_img.convert("RGBA")
        input_img = input_img.convert("RGBA")
        label_img = label_img.convert("RGBA")

        alphaBlended_seg = Image.blend(seg_img, input_img, alpha=.6)
        alphaBlended_label = Image.blend(label_img, input_img, alpha=.6)

        imageio.imsave(full_error_map_name, alphaBlended_seg)
        imageio.imsave(full_label_map_name, alphaBlended_label)

In [5]:
# model_name = 'Unet_labelled_repeat_1_augmentation_all_lr_2e-05_epoch_50_CARVE2014_4_r176_s50_epoch40'
# model_path = '/home/moucheng/Desktop/IPMI results/CARVE2014/4_r176_s50/IPMI2020/Unet_labelled_repeat_1_augmentation_all_lr_2e-05_epoch_50_CARVE2014_4_r176_s50/trained_models'

# model = model_name + '.pt'
# model = model_path + '/' + model

# model = torch.load(model)
# model.eval()

# device = torch.device('cuda')

def test_models(model_name_list, model_path, dataset):
    
    for model_name in model_name_list:
        
        model = model_name + '.pt'
        model = model_path + '/' + model

        model = torch.load(model)
        model.eval()
        
        device = torch.device('cuda')
        
        test_inference(dataset, model, model_name, device)

In [20]:
# U-Net run: evaluate its '_Final' checkpoint.
model_path = os.path.join('/home/moucheng/Desktop/IPMI results/BRATS/IPMI2020',
                          'Unet_labelled_repeat_1_augmentation_all_lr_2e-06_epoch_80_BRATS2018_t5_s20_l1_u30_WT',
                          'trained_models')
model_name_list = ['Unet_labelled_repeat_1_augmentation_all_lr_2e-06_epoch_80_BRATS2018_t5_s20_l1_u30_WT_Final']

test_models(model_name_list, model_path, dataset)











































































































In [15]:
# MeanTeacher run: evaluate its '_Final' checkpoint.
model_path = os.path.join('/home/moucheng/Desktop/IPMI results/BRATS/IPMI2020',
                          'MeanTeacher_repeat_3_alpha_0.002_lr_2e-06_epoch_80_annealing_down_at_0_beta_0.8_BRATS2018_t5_s20_l1_u30_WT',
                          'trained_models')
model_name_list = ['MeanTeacher_repeat_3_alpha_0.002_lr_2e-06_epoch_80_annealing_down_at_0_beta_0.8_BRATS2018_t5_s20_l1_u30_WT_Final']

test_models(model_name_list, model_path, dataset)











































































































In [16]:
# FixMatch run: evaluate its '_Final' checkpoint.
model_path = os.path.join('/home/moucheng/Desktop/IPMI results/BRATS/IPMI2020',
                          'FixMatch_repeat_1_alpha_0.002_lr_2e-06_epoch_40_annealing_down_at_0_beta_0.8_BRATS2018_t5_s20_l1_u30_WT',
                          'trained_models')
model_name_list = ['FixMatch_repeat_1_alpha_0.002_lr_2e-06_epoch_40_annealing_down_at_0_beta_0.8_BRATS2018_t5_s20_l1_u30_WT_Final']

test_models(model_name_list, model_path, dataset)











































































































In [19]:
# ERFAnetZ3 run: evaluate the per-epoch checkpoints epoch31 .. epoch39.
# (The run-name spelling "epoch_40annealing" is kept exactly as saved on disk.)
model_name_list = [
    'ERFAnetZ3_repeat_3_augmentation_none_alpha_1.0_lr_2e-06_epoch_40annealing_down_at_0_beta_0.8_constraint_jacobi_all_gamma_1.0_BRATS2018_t5_s20_l1_u30_WT_epoch{}'.format(epoch)
    for epoch in range(31, 40)
]

model_path = os.path.join('/home/moucheng/Desktop/IPMI results/BRATS/IPMI2020',
                          'ERFAnetZ3_repeat_3_augmentation_none_alpha_1.0_lr_2e-06_epoch_40annealing_down_at_0_beta_0.8_constraint_jacobi_all_gamma_1.0_BRATS2018_t5_s20_l1_u30_WT',
                          'trained_models')

test_models(model_name_list, model_path, dataset)













































































































































































































































































































































































































































































































































































































































































KeyboardInterrupt: 