In [1]:
import pickle as pkl
import numpy as np
import torch
import matplotlib.pyplot as plt
from models.model_utils import process_image
from utils.notebook_utils import apply_temperature, pixel_uncertainty
from utils.notebook_utils import softmax, estimate_dice_li, get_patches
from utils.notebook_utils import scale_temp_dice, apply_temperature
from models.model_utils import dice_metric
import cv2
import os
import matplotlib.colors as colors
import matplotlib.image as mpimg
import matplotlib.patches as patches
from glob import glob
from sklearn.linear_model import LinearRegression
import yaml
from munch import munchify
import warnings


Manual configuration: choose the experiment config and whether to recompute results

In [20]:
# Experiment selector: the config file read below is 'config_<config_suffix>.yaml'.
config_suffix = 'all_patches'
# True  -> recompute all results below;
# False -> load pre-computed results from the cache/ directory.
run_again = False

In [3]:
# Load the experiment config selected by `config_suffix` and the dataset paths.
# The `with` blocks close the files on exit; no explicit close() is needed.
# NOTE(review): yaml.safe_load would suffice if these are plain config files.
with open('config_' + config_suffix + '.yaml', 'r') as file:
    cfg = munchify(yaml.load(file, Loader=yaml.FullLoader))

with open('paths.yaml', 'r') as file0:
    cfg_paths = munchify(yaml.load(file0, Loader=yaml.FullLoader))

In [4]:
# Sanity check on the configured patch size (warns only; execution continues).
# The message names the config file that was actually loaded above.
if cfg.patch_size % 2 == 0:
    warnings.warn('Please specify an odd patch_size in config_' + config_suffix + '.yaml.')

### Load Forward Passes and Compute Uncertainty Maps

In [5]:
# Load the cached ensemble forward passes on the validation split.
# ensemble_dict['logits'] is a list with one entry per ensemble member (m); each
# entry is a list with one element per validation image (n_val).
# NOTE(review): pickle.load can execute arbitrary code - only load trusted cache files.
with open('cache/ensemble_dict_val.pkl', 'rb') as file2:
    ensemble_dict = pkl.load(file2)

Ensemble mean logits and mean predictions (temperature scaling is fitted further below)

In [12]:
# Stack the per-member logits: axis 0 = ensemble member, axis 1 = validation image.
logits_stacked = np.array(ensemble_dict['logits'])
# Average logits over the ensemble members.
ensemble_dict['mean logits'] = logits_stacked.mean(axis=0)
# Average of the per-member `softmax` outputs (not the softmax of the mean logits).
ensemble_dict['mean predictions'] = softmax(logits_stacked).mean(axis=0)

Per-member predictions, mean true Dice, predicted foreground fraction and logits

In [13]:
# Per-member probabilities: axis 0 = ensemble member, axis 1 = validation image.
predictions_arr = softmax(np.array(ensemble_dict['logits']))
# True Dice averaged over axis 0 (presumably the ensemble members) -> one value per image.
mean_true_dice = np.array(ensemble_dict['true dice']).mean(axis=0)
# Fraction of pixels labelled foreground (> 0.5) by the mean prediction, per image.
pos_pred = (ensemble_dict['mean predictions'] > 0.5).mean(axis=(1, 2))

# Convert probabilities back to logits, log(p / (1 - p)).
# The epsilon is added to BOTH numerator and denominator. p == 0 then maps to a
# finite value (about -18.4) instead of log(0) = -inf, and p == 1 maps to about +18.4.
# An -inf here would otherwise dominate any later mean over ensemble members.
LOGIT_EPS = 1e-8  # same value as the former literal `10e-9`
predicted_logits_arr = np.log((predictions_arr + LOGIT_EPS) / (1 - predictions_arr + LOGIT_EPS))
# NOTE(review): if `softmax` is an element-wise sigmoid, this is (up to eps) just
# np.array(ensemble_dict['logits']) - confirm and simplify.

Uncertainty Maps

In [16]:
# Pixel-wise uncertainty across ensemble members (axis 0), one map per
# validation image (axis 1 of predictions_arr).
entropy_maps = []
variance_maps = []
n_val_images = predictions_arr.shape[1]
for i in range(n_val_images):
    member_probs = predictions_arr[:, i, :, :]
    entropy_maps.append(pixel_uncertainty(member_probs, 'entropy'))
    variance_maps.append(pixel_uncertainty(member_probs, 'variance'))

### Reconstruct validation split 

In [17]:
# Reconstruct the train/validation split used at training time, so that
# val_x[k] / val_y[k] are the image/mask files of validation sample k.
root = cfg_paths.FIVES
train_x = sorted(glob(os.path.join(root, 'train/Original/*')))
train_y = sorted(glob(os.path.join(root, 'train/Ground truth/*')))
# Images and masks are paired purely by sorted order, so the counts must match.
assert len(train_x) == len(train_y), 'number of images and ground-truth masks differ'

validation_split = .2
indices = list(range(len(train_x)))
split = int(np.floor(validation_split * len(train_x)))
np.random.seed(23)  # manually confirmed with the training script
np.random.shuffle(indices)

train_indices, val_indices = indices[split:], indices[:split]

# Select the validation files through the *shuffled* indices. Taking
# train_x[:split] would ignore the seeded shuffle entirely (the seed would then be
# irrelevant) and pick the first 20% of the sorted files instead.
# NOTE(review): cached results in cache/*_estimation_info_val.pkl were presumably
# built with the unshuffled selection - regenerate them with run_again = True.
val_x = [train_x[i] for i in val_indices]
val_y = [train_y[i] for i in val_indices]

### Define functions that estimate the Dice score on the remaining image area

In [18]:
def convolve_uncertainty(uncertainty_map, patch_size, stride=1):
    """
        Sum `uncertainty_map` over square sliding windows (a box filter, i.e. a
        "valid" convolution with an all-ones kernel).

        ::param::
        uncertainty_map: [np.array] 2-D array (does not need to be square)
        patch_size: [int] side length of the square window
        stride: [int] step between consecutive window positions (both axes)

        ::return::
        dict with
          'reference': list of (y, x) top-left corners, in row-major order
          'patch uncertainty': list of window sums, aligned with 'reference'

        NOTE: top-left corners run over range(0, dim - patch_size, stride), so the
        last valid position along each axis (dim - patch_size) is not visited.
        Callers may depend on the resulting number of windows, so this is kept.
    """
    patch_convolved = []
    reference_points = []  # identifies patches
    n_rows, n_cols = uncertainty_map.shape[0], uncertainty_map.shape[1]

    for y in range(0, n_rows - patch_size, stride):
        for x in range(0, n_cols - patch_size, stride):
            # ravel() of the slice yields the same elements in the same row-major
            # order as indexing with a full-image boolean mask, so the sum is
            # unchanged. It costs O(patch_size**2) per window instead of O(H * W).
            window = uncertainty_map[y:y + patch_size, x:x + patch_size]
            patch_convolved.append(window.ravel().sum())
            reference_points.append((y, x))

    return {'reference': reference_points,
            'patch uncertainty': patch_convolved}

In [19]:
def _load_ground_truth(img_idx):
    """Read, resize (to 512x512) and preprocess validation sample `img_idx`.

    Returns the squeezed second output of `process_image` (the processed mask).
    Raises FileNotFoundError if OpenCV cannot read the image or mask file.
    """
    x_path = val_x[img_idx]
    y_path = val_y[img_idx]

    image = cv2.imread(x_path, cv2.IMREAD_COLOR)
    if image is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError('Could not read image ' + str(x_path))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (512, 512))

    mask = cv2.imread(y_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError('Could not read mask ' + str(y_path))
    # NOTE(review): cv2.resize defaults to INTER_LINEAR; for a label mask
    # INTER_NEAREST may have been intended - confirm before changing.
    mask = cv2.resize(mask, (512, 512))

    # Only the mask output of process_image is used here.
    _, y_image = process_image(image, mask)
    return y_image.squeeze()


def _select_patch_references(uncertainty_map, patch_size, no_patches):
    """Greedily choose up to `no_patches - 1` patch reference points (top-left corners).

    Each window's score is the sum of `uncertainty_map` over a window of width
    (patch_size - 1). Each step takes the highest-scoring remaining window. It then
    discards every candidate whose reference point lies within `patch_size` pixels
    of it in both y and x, so the patch_size-wide patches never overlap.
    Stops early when no candidates remain.
    """
    # NOTE(review): the scoring window is patch_size - 1 wide while the masked-out
    # patch is patch_size wide - confirm this is intended.
    convolution_dict = convolve_uncertainty(uncertainty_map, patch_size - 1)
    scores = np.array(convolution_dict['patch uncertainty']).ravel()
    ref_array = np.array(convolution_dict['reference'])

    available = np.ones(scores.shape[0], dtype=bool)  # candidates still selectable
    patch_references = []
    for _ in range(no_patches - 1):
        if not available.any():
            break  # every remaining window overlaps an already selected patch
        best = scores[available].argmax()
        reference_point = ref_array[available][best]
        patch_references.append(reference_point)

        near_y = np.abs(ref_array[:, 0] - reference_point[0]) <= patch_size
        near_x = np.abs(ref_array[:, 1] - reference_point[1]) <= patch_size
        available[near_y & near_x] = False

    return patch_references


def _remaining_dice_curves(gt, member_probs, patch_references, patch_size):
    """True and estimated Dice on the image area left after masking out patches.

    Entry k of each returned list is computed after the first k patches
    (patch_size x patch_size each) have been removed, for k = 0 .. len(patch_references).
    """
    prediction = member_probs > 0.5
    masked_out = np.zeros_like(prediction, dtype=bool)
    true_dice, est_dice, pos_remaining, masks = [], [], [], []

    for step in range(len(patch_references) + 1):
        keep = ~masked_out  # new array: later in-place updates of masked_out don't touch it
        masks.append(keep)
        true_dice.append(dice_metric(gt[keep], prediction[keep]))
        est_dice.append(estimate_dice_li(member_probs[keep]))
        pos_remaining.append(prediction[keep].mean())

        if step < len(patch_references):
            y0, x0 = patch_references[step]
            masked_out[y0:y0 + patch_size, x0:x0 + patch_size] = True

    return {'true dice': true_dice, 'est dice': est_dice,
            'positive remaining': pos_remaining, 'masks': masks}


def get_estimation_info(img_idx, cfg):
    """Estimated vs. true Dice on validation image `img_idx` as patches are masked out.

    Patches are picked greedily from the pixel-wise map selected by
    `cfg.uncertainty_type` ('entropy' or 'variance'), using `cfg.patch_size` and
    `cfg.no_patches`.

    Returns a dict of lists with one entry per masking step (0 .. no_patches - 1
    patches removed; fewer if candidates run out): 'true dice', 'est dice',
    'positive remaining' (fraction predicted foreground) and 'masks' (boolean
    keep-mask for that step).
    Raises ValueError for an unknown `cfg.uncertainty_type`.
    """
    if cfg.uncertainty_type == 'entropy':
        # Negated, so the argmax-based selection picks the LOWEST-entropy windows.
        uncertainty_map = - entropy_maps[img_idx]
    elif cfg.uncertainty_type == 'variance':
        # NOTE(review): not negated, so the HIGHEST-variance windows are picked -
        # the opposite direction to the entropy branch; confirm this is intended.
        uncertainty_map = variance_maps[img_idx]
    else:
        raise ValueError("cfg.uncertainty_type must be 'entropy' or 'variance', got "
                         + repr(cfg.uncertainty_type))

    gt = _load_ground_truth(img_idx)
    patch_references = _select_patch_references(uncertainty_map, cfg.patch_size, cfg.no_patches)

    # NOTE(review): only ensemble member 0 is used for the prediction and the Dice
    # estimate, whereas the calibration below uses ensemble-mean logits - confirm.
    member_probs = predictions_arr[0, img_idx, :, :]
    return _remaining_dice_curves(gt, member_probs, patch_references, cfg.patch_size)

### Optionally select a subset of images

In [25]:
# Keep only images with enough predicted foreground AND a high enough mean true Dice.
foreground_thresh = cfg.fg_thresh
dice_thresh = cfg.dice_thresh

sufficient_foreground = pos_pred > foreground_thresh
sufficient_dice = mean_true_dice > dice_thresh

image_selection = sufficient_foreground & sufficient_dice
image_selection_ids = np.nonzero(image_selection)[0]
n_selected = image_selection.sum()
print(f'You have selected {n_selected} images.')

You have selected 120 images.


Estimate the DSC and measure the true DSC on the image area that remains after masking out patches

In [21]:
if run_again:
    # One result dict per selected image; its 'true dice' and 'est dice' lists
    # are the two curves compared in this analysis.
    estimation_info_val = [get_estimation_info(img_id, cfg=cfg)
                           for img_id in image_selection_ids]

Save estimated DSC of remaining images or load from previous run

In [22]:
# Per-config cache of the estimation results; the `with` blocks close the file.
estimation_cache_path = 'cache/' + config_suffix + '_estimation_info_val.pkl'
if run_again:
    with open(estimation_cache_path, 'wb') as file:
        pkl.dump(estimation_info_val, file)
else:
    # NOTE(review): pickle.load can execute arbitrary code - only load trusted cache files.
    with open(estimation_cache_path, 'rb') as file:
        estimation_info_val = pkl.load(file)


### Calibrate: Fit Temperature Scaling

In [26]:
# Ensemble-mean logits (mean over axis 0 = members), restricted to the selected images.
ensemble_mean_logits = predicted_logits_arr.mean(axis=0)
relevant_logits = ensemble_mean_logits[image_selection_ids, :, :]

In [27]:
# For every selected image and every masking step, keep only the logits of the
# pixels that remain (mask == True). Ordering: image-major, then masking step.
logit_remain = [
    img_logits[keep_mask]
    for img_pos, img_logits in enumerate(relevant_logits)
    for keep_mask in estimation_info_val[img_pos]['masks']
]

In [31]:
# Flatten the per-image true-Dice curves into one list (aligned with logit_remain).
true_dice_flat = [dsc for info in estimation_info_val for dsc in info['true dice']]
if not run_again:
    # NOTE(review): hard-coded result of an earlier scale_temp_dice run, presumably for
    # config_suffix == 'all_patches'; it is NOT recomputed when the config changes.
    T_remain = 1.2993164062499996
else:
    # Presumably fits a temperature to the remaining-area logits against the mean true Dice.
    T_remain = scale_temp_dice(logit_remain, np.array(true_dice_flat).mean())

In [34]:
# Persist the temperature for this config (also written when it was not recomputed).
# The `with` block closes the file; no explicit close() is needed.
with open('cache/' + config_suffix + '_optimal_T.pkl', 'wb') as file:
    pkl.dump(T_remain, file)