In [1]:
import matplotlib.pyplot as plt
import lesion
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import numpy as np
import unet
import residual_unet
import utils as util
import torch.nn.functional as F
from sklearn.metrics import jaccard_score as jsc
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook can still be executed on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import ECE

In [2]:
class ToTensor_segmap(object):
    """Turn a segmentation map into a torch tensor divided by 255.

    Unlike colour images (numpy H x W x C -> torch C x H x W), no axis swap is
    done here: the map is used in whatever layout it arrives in, which the
    original author states is single-channel H x W.
    """

    def __call__(self, segmap):
        # np.array() copies the input, so array-likes other than ndarrays
        # are accepted as well.
        as_array = np.array(segmap)
        as_tensor = torch.from_numpy(as_array)
        # A stored value of 255 maps to 1 after this division.
        return as_tensor / 255

# Image pipeline: grayscale conversion with probability 1 (i.e. always),
# conversion to a tensor, then per-channel normalisation with the mean/std
# listed below.
image_transform = transforms.Compose(
    [
        transforms.RandomGrayscale(1),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Segmentation maps only go through the tensor conversion / 255 scaling.
seg_transform = ToTensor_segmap()

In [3]:
batch_size = 2

# Dataset read from the 'val' folder under "data".
dataset = lesion.LesionDataset(
    "data",
    folder_name='val',
    joint_transform=False,
    img_transform=image_transform,
    seg_transform=seg_transform,
    verbose=True,
)

# shuffle=False keeps the sample order fixed between passes over the loader.
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

In [4]:
file_name = "gate_gaussian.pth"
# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a different device index still loads on this machine.
model = torch.load("models/" + file_name, map_location=device)
# Last expression of the cell: moves the model to `device` and displays it.
model.to(device)

DataParallel(
  (module): res_unet_gate(
    (MaxPool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (drop): GaussianDropout()
    (preprocess): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (down1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1

In [5]:
def apply_dropout(m):
    """Switch a dropout layer back to training mode (for use with ``model.apply``).

    The original check ``type(m) == nn.Dropout2d`` matched that one class only,
    so custom layers such as the ``GaussianDropout`` used by the loaded model
    stayed in eval mode and Monte-Carlo sampling was silently deterministic.
    Any module whose class name contains "dropout" (case-insensitive) is now
    re-enabled; this covers the built-in ``nn.Dropout*`` variants as well.

    NOTE(review): ``train()`` also affects child modules, so a container class
    with "dropout" in its name would switch its children too -- confirm no such
    container exists in the model.

    Args:
        m: a module visited by ``nn.Module.apply``.
    """
    if 'dropout' in type(m).__name__.lower():
        m.train()
        
def mc(loader,T,batch_size = 2,bayesian = True):
    """Estimate the predictive mean and variance with T passes over ``loader``.

    Uses the notebook-level ``model`` and ``device``. The model is put in eval
    mode first; when ``bayesian`` is true, ``apply_dropout`` is then applied so
    dropout layers stay stochastic. Dropout layers are left in training mode
    when the function returns.

    Args:
        loader: iterable of batches supporting ``len()``; ``data[0]`` is the
            image batch and ``data[1]`` the target batch. It must yield the
            samples in the same order on every pass (e.g. ``shuffle=False``),
            and the first batch must be at least as large as any later one.
        T: number of passes used to approximate the posterior predictive
            distribution; must be >= 1.
        batch_size: kept for backward compatibility and no longer used -- the
            batch size is read from the data, so a partial last batch is
            handled instead of producing a duplicated or all-zero row.
        bayesian: enable dropout at inference time (MC dropout).

    Returns:
        ``(mean, targets, var)`` as float64 arrays with one row per sample:
        the averaged softmax output, the targets from the first pass, and the
        per-element variance of the softmax output across passes.

    Raises:
        ValueError: if ``T`` is smaller than 1 or ``loader`` yields no batches.
    """
    if T < 1:
        raise ValueError("T must be >= 1, got {}".format(T))

    model.eval()
    if(bayesian):
        print("Bayesian Mode")
        model.apply(apply_dropout)

    softmax = nn.Softmax2d()
    mean = None
    var = None
    targets = None
    n_samples = 0

    # Inference only: no_grad avoids building the autograd graph per batch.
    with torch.no_grad():
        for iteration in range(T):
            start = 0
            for data in loader:
                image = data[0].to(device)
                output = softmax(model(image)).cpu().numpy().astype(np.float64)
                n = output.shape[0]

                if mean is None:
                    # Size the buffers from the first batch rather than
                    # hard-coding (2, 512, 1024); trimmed to n_samples below.
                    capacity = n * len(loader)
                    mean = np.zeros((capacity,) + output.shape[1:])
                    var = np.zeros((capacity,) + output.shape[1:])

                stop = start + n
                mean[start:stop] += output
                var[start:stop] += output ** 2
                if(iteration==0):
                    segmap = data[1].numpy()
                    if targets is None:
                        targets = np.zeros((capacity,) + segmap.shape[1:])
                    targets[start:stop] = segmap
                start = stop

            if iteration == 0:
                n_samples = start

    if mean is None:
        raise ValueError("loader yielded no batches")

    mean = mean[:n_samples] / T
    # Var[x] = E[x^2] - E[x]^2; clip the tiny negatives rounding can produce.
    var = np.clip(var[:n_samples] / T - mean**2, 0, None)
    return mean,targets[:n_samples],var
        
    

    
     
    

In [6]:
# One deterministic pass (T=1, bayesian=False so dropout stays in eval mode)
# over `loader`; the targets returned in the middle slot are discarded.
mean,_,var = mc(loader,1,bayesian=False)

In [7]:
#Sanity check, only dropout module should be on training mode for Bayesian
# Match "Dropout" anywhere in the class name: a startswith() test skips custom
# layers such as the model's GaussianDropout and then prints nothing at all.
for module in model.modules():
    if 'Dropout' in module.__class__.__name__:
        print(module)
        print(module.training)

In [8]:
# Per-pixel confidence: the largest class probability along axis 1.
confidence = mean.max(axis=1)
# Average the per-image mean confidence over however many images were
# evaluated, instead of assuming exactly 150.
n_images = len(confidence)
avg_conf = sum(confidence[i].mean() for i in range(n_images)) / n_images
avg_conf

0.9511745442758911

In [9]:
# Which per-pixel score is thresholded later: "conf" (max class probability)
# or "var" (max class variance).
uncert_type = "conf"
if(uncert_type == "conf"):
    uncertainty = confidence
elif uncert_type =="var":
    uncertainty = var.max(axis=1)
else:
    # Fail loudly instead of hitting a NameError on `uncertainty` below.
    raise ValueError("unknown uncert_type: {!r}".format(uncert_type))

# Average of the per-image mean score over the images produced from `loader`;
# the count is taken from the data rather than hard-coded to 150.
n_images = len(uncertainty)
avg_uncertainty = sum(uncertainty[i].mean() for i in range(n_images)) / n_images


In [10]:
# Display the average score; it is used as the certain/uncertain threshold
# when the certainty map is built.
avg_uncertainty

0.9511745442758911

In [11]:

# Dataset read from the 'test' folder, with the same transforms as before.
test_set = lesion.LesionDataset(
    "data",
    folder_name='test',
    joint_transform=False,
    img_transform=image_transform,
    seg_transform=seg_transform,
    verbose=True,
)
test_loader = DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)
# Number of samples in the test set.
total = len(test_set)

In [12]:
# One deterministic pass over the test loader. NOTE: this rebinds `mean` and
# `var`, replacing the arrays computed earlier from `loader`.
mean,target,var = mc(test_loader,T = 1, bayesian = False)

In [13]:
pred = mean.argmax(axis=1)# test prediction
# Build the certainty map with boolean masks. The previous np.nonzero()
# approach materialised one int64 index array per axis for every pixel, which
# is very memory-hungry on full-resolution outputs.
if(uncert_type == "conf"):
    # Certain where the max class probability is above the average threshold.
    certain = mean.max(axis=1) > avg_uncertainty
elif(uncert_type == "var"):
    # Certain where the max class variance is at or below the threshold.
    certain = var.max(axis=1) <= avg_uncertainty
else:
    # Fail loudly instead of hitting a NameError on `cm` later.
    raise ValueError("unknown uncert_type: {!r}".format(uncert_type))
# certainty map, 1 for certain, 0 for uncertain (NaN scores count as uncertain)
cm = certain.astype(np.float64)

In [14]:
def ac(gt,pred,cm):
    """P(accurate | certain): share of certain pixels that are predicted correctly.

    Args:
        gt: ground-truth labels.
        pred: predicted labels, same shape as ``gt``.
        cm: certainty map, same shape as ``gt``; a value of 1 marks a certain
            pixel.

    Returns:
        The fraction as a float, or NaN when no pixel is certain (this case
        previously raised ZeroDivisionError).
    """
    certain = np.asarray(cm) == 1
    n_certain = np.count_nonzero(certain)
    if n_certain == 0:
        return float('nan')
    # number of accurate pixels given they are certain
    n_accurate = np.count_nonzero(np.asarray(gt)[certain] == np.asarray(pred)[certain])
    return n_accurate / n_certain
def uia(gt,pred,cm):
    """P(uncertain | inaccurate): share of wrongly predicted pixels flagged uncertain.

    Args:
        gt: ground-truth labels.
        pred: predicted labels, same shape as ``gt``.
        cm: certainty map, same shape as ``gt``; a value of 0 marks an
            uncertain pixel.

    Returns:
        The fraction as a float, or NaN when every pixel is predicted
        correctly (this case previously raised ZeroDivisionError).
    """
    inaccurate = np.asarray(gt) != np.asarray(pred)
    n_inaccurate = np.count_nonzero(inaccurate)
    if n_inaccurate == 0:
        return float('nan')
    n_uncertain = np.count_nonzero(np.asarray(cm)[inaccurate] == 0)
    return n_uncertain / n_inaccurate

In [15]:
# P(accurate | certain) on the test predictions.
pac = ac(gt=target, pred=pred, cm=cm)
# P(uncertain | inaccurate) on the test predictions.
puia = uia(gt=target, pred=pred, cm=cm)

In [16]:
# Report both conditional probabilities, one per line.
report = "\n".join([
    "Prob of accurate given certain is:{} ".format(pac),
    "Prob of uncertain given inaccurate is:{} ".format(puia),
])
print(report)

Prob of accurate given certain is:0.9703858750333965 
Prob of uncertain given inaccurate is:0.7385511933498639 


In [None]:
# Show the certainty map at index 0 (1 = certain, 0 = uncertain).
plt.imshow(cm[0])

In [None]:
# Show the ground-truth segmentation map at index 0 for comparison.
plt.imshow(target[0])