In [None]:
import pickle as pkl
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib import container
from matplotlib.ticker import ScalarFormatter
import matplotlib
from models.model_utils import process_image
from utils.notebook_utils import apply_temperature, pixel_uncertainty, convolve_uncertainty
from utils.notebook_utils import softmax, estimate_dice_li, get_patches, dist2center, scale_temp_dice
from models.model_utils import dice_metric
import cv2
import os
import matplotlib.colors as colors
import matplotlib.image as mpimg
import matplotlib.patches as patches
from glob import glob
from sklearn.linear_model import LinearRegression
from munch import munchify
import yaml
import warnings
import copy
# plt.rcParams['text.usetex'] = True

Configure the split manually and select the configuration for the number of patches and the patch size (via `config_suffix`).

In [None]:
# Data split to analyse and the suffix selecting config_<suffix>.yaml
split = 'test'
config_suffix = 'all_patches'

# Suffix appended to cache file names: 'val' for the validation split,
# empty otherwise (keeps the naming convention consistent)
split_suffix = split if split == 'val' else ''

# Flags: re-compute results (True) or load them from pre-computed files (False)
run_again = False
# Separate flag for the patch selection
run_again_patch_selection = False

In [None]:
# Load experiment config and paths.
# The `with` statement closes each file on exit, so no explicit close() is needed.
# NOTE(review): yaml.load with FullLoader is kept as-is; prefer yaml.safe_load
# if these YAML files could ever come from an untrusted source.
with open('config_' + config_suffix + '.yaml', 'r') as file0:
    cfg = munchify(yaml.load(file0, Loader=yaml.FullLoader))

with open('paths.yaml', 'r') as file0:
    cfg_paths = munchify(yaml.load(file0, Loader=yaml.FullLoader))

In [None]:
# Sanity check: warn when the configured patch size is divisible by two
if not cfg.patch_size % 2:
    warnings.warn(
        'Please specify an uneven patchsize in config_' + config_suffix + '.yaml.'
    )

### Load forward passes

In [None]:
# Load the cached ensemble passes for the selected split (validation or test)
# and the pre-computed temperature.
# NOTE: pickle.load can execute arbitrary code; only load trusted cache files.
passes = 'ensemble_dict_val.pkl' if split == 'val' else 'ensemble_dict.pkl'
with open('cache/' + passes, 'rb') as file2:
    ensemble_dict = pkl.load(file2)

with open('cache/' +  config_suffix + '_optimal_T.pkl', 'rb') as file3:
    T_remain = pkl.load(file3)

### Compute means over ensemble outputs and Calibrate

In [None]:
# Average over the first axis of the stacked outputs:
# once for the raw logits, once for the softmax outputs
ensemble_dict['mean logits'] = np.mean(np.array(ensemble_dict['logits']), axis=0)
ensemble_dict['mean predictions'] = softmax(
    np.array(ensemble_dict['logits'])
).mean(axis=0)

In [None]:
# Stacked logits and their softmax outputs (not calibrated yet -- happens later)
predicted_logits_arr = np.array(ensemble_dict['logits'])
predictions_arr = softmax(predicted_logits_arr)

# True dice averaged over the first axis
mean_true_dice = np.mean(np.array(ensemble_dict['true dice']), axis=0)
# Per-image fraction of pixels whose mean prediction exceeds 0.5
pos_pred = (ensemble_dict['mean predictions'] > 0.5).mean(axis=(1, 2))

# deprecated:
# predicted_logits_arr = np.log(predictions_arr / (1 - predictions_arr + 10e-9))

### Compute Uncertainty Maps

In [None]:
# One 'entropy' uncertainty map per index along axis 1 of predictions_arr
entropy_maps = [
    pixel_uncertainty(predictions_arr[:, img, :, :], 'entropy')
    for img in range(predictions_arr.shape[1])
]

Optionally subset the images to analyze subgroups (to reproduce the paper, select the entire test set).

In [None]:
# Keep images that exceed both the foreground and the dice threshold
sufficient_foreground = pos_pred > cfg.fg_thresh
sufficient_dice = mean_true_dice > cfg.dice_thresh

image_selection = sufficient_foreground & sufficient_dice
image_selection_ids = np.where(image_selection)[0]
print('You have selected ' + str(image_selection.sum()) + ' images.')

## Extract Patches and Estimate Remaining Dice

Select Patches with Highest Summed Uncertainty, Mask them and Estimate Dice of the Remaining Image

In [None]:
# Reconstruct the train / validation split
root = cfg_paths.FIVES
train_x = sorted(glob(os.path.join(root, 'train/Original/*')))
train_y = sorted(glob(os.path.join(root, 'train/Ground truth/*')))

validation_split = .2
indices = list(range(len(train_x)))
split_val = int(np.floor(validation_split * len(train_x)))
np.random.seed(23) # manually confirmed with the training script
np.random.shuffle(indices)

train_indices, val_indices = indices[split_val:], indices[:split_val]

# Select the validation files through the shuffled indices. Previously the
# first `split_val` entries of the sorted file list were taken, which ignored
# the seeded shuffle above.
# NOTE(review): assumes the cached validation passes are stored in the order
# of `val_indices` -- confirm against the training script.
val_x = [train_x[i] for i in val_indices]
val_y = [train_y[i] for i in val_indices]

In [None]:
# Collect image / ground-truth file names for the selected split
if split == 'val':
    split_x = val_x
    split_y = val_y
    print('Indexed Images from Validation Split.')

elif split == 'test':
    split_x = sorted(glob(os.path.join(root, split + '/Original/*')))
    split_y = sorted(glob(os.path.join(root, split + '/Ground truth/*')))
    print('Indexed images from test split.')

else:
    # Fail here instead of with a NameError on split_x / split_y further down
    raise ValueError("split must be 'val' or 'test', got " + repr(split))

In [None]:
def get_estimation_info(img_idx):
    """
        Select patches by summed (negated) uncertainty for one image, mask them
        out one after another and evaluate the prediction on the remaining pixels.

        ::param::
        img_idx: [int] index into split_x / split_y, entropy_maps and the image
                 axis (axis 1) of predictions_arr

        ::return::
        [dict] of four lists with one entry per masking step (entry 0: nothing
        masked out yet):
          'true dice': dice_metric of ground truth vs. thresholded prediction
          'est dice': estimate_dice_li of the probabilities
          'positive remaining': fraction of remaining pixels with probability > 0.5
          'masks': boolean arrays, True where a pixel has NOT been masked out

        ::raise::
        ValueError if cfg.uncertainty_type is not 'entropy'
    """
    # Image and ground truth (the image is only needed as input to process_image)
    image = cv2.imread(split_x[img_idx], cv2.IMREAD_COLOR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (512, 512))

    mask = cv2.imread(split_y[img_idx], cv2.IMREAD_GRAYSCALE)
    # NOTE(review): resized with cv2.resize's default interpolation; the original
    # comment mentioned INTER_NEAREST -- confirm which one is intended.
    mask = cv2.resize(mask, (512, 512))

    _, y_image = process_image(image, mask)
    gt = y_image.squeeze()

    if cfg.uncertainty_type == 'entropy':
        # NOTE(review): the map is negated before ranking; whether the argmax
        # below picks the most uncertain patch depends on the sign convention
        # of pixel_uncertainty -- confirm.
        uncertainty_map = - entropy_maps[img_idx]
    else:
        # Previously this fell through and failed later on an unbound variable
        raise ValueError('Unsupported uncertainty type: ' + repr(cfg.uncertainty_type))

    convolution_dict = convolve_uncertainty(uncertainty_map, cfg.patch_size - 1)
    # The reshape acts as a size check (432 x 432 window positions expected);
    # the ranking itself works on the flattened values.
    patch_scores = np.array(convolution_dict['patch uncertainty']).reshape(432, 432).flatten()
    ref_array = np.array(convolution_dict['reference'])

    # Candidates still available to choose from
    in_image = np.ones(patch_scores.shape, dtype=bool)
    patch_references = []

    for _ in range(cfg.no_patches - 1):
        # Best remaining candidate; the index is relative to the filtered arrays
        best = patch_scores[in_image].argmax()
        # Reference point (in full-size image space) of the chosen patch
        reference_point = ref_array[in_image][best]
        patch_references.append(reference_point)

        # "Remove" every candidate whose reference point lies within
        # cfg.patch_size of the chosen one along both axes
        y_constraint = ((ref_array[:, 0] <= reference_point[0] + cfg.patch_size)
                        & (ref_array[:, 0] >= reference_point[0] - cfg.patch_size))
        x_constraint = ((ref_array[:, 1] <= reference_point[1] + cfg.patch_size)
                        & (ref_array[:, 1] >= reference_point[1] - cfg.patch_size))
        in_image[y_constraint & x_constraint] = False

    # Evaluation uses the first entry along axis 0 of predictions_arr
    probabilities = predictions_arr[0, img_idx, :, :]
    prediction = probabilities > 0.5
    masked_out = np.zeros_like(prediction).astype(bool)

    est_dice_remain_list = []
    dice_remain_list = []
    pos_remain = []
    masks = []

    for step in range(len(patch_references) + 1):
        # `~` creates a fresh array, so stored masks are not affected by the
        # later in-place updates of masked_out
        remaining = ~masked_out
        masks.append(remaining)

        dice_remain = dice_metric(gt[remaining], prediction[remaining])
        est_dice_remain = estimate_dice_li(probabilities[remaining])

        est_dice_remain_list.append(est_dice_remain)
        dice_remain_list.append(dice_remain)
        pos_remain.append((probabilities[remaining] > 0.5).mean())

        # Mask out the next patch (none is left after the final evaluation)
        if step < len(patch_references):
            patch_ref = patch_references[step]
            masked_out[patch_ref[0]:patch_ref[0] + cfg.patch_size,
                       patch_ref[1]:patch_ref[1] + cfg.patch_size] = True

    return {'true dice': dice_remain_list, 'est dice': est_dice_remain_list,
            'positive remaining': pos_remain, 'masks': masks}

In [None]:
if run_again_patch_selection:

    # Re-run patch selection if necessary
    # Requires ~90min for 200imgs on one consumer CPU
    estimation_info = [get_estimation_info(img_id) for img_id in image_selection_ids]

In [None]:
# Use one path for writing and reading. Previously the write branch inserted an
# extra '_' before the split suffix, so a freshly written cache file was never
# the one the load branch opened.
estimation_info_path = 'cache/' + config_suffix + '_estimation_info' + split_suffix + '.pkl'

if run_again_patch_selection:
    with open(estimation_info_path, 'wb') as file:
        pkl.dump(estimation_info, file)
else:
    # NOTE: pickle.load can execute arbitrary code; only load trusted cache files.
    with open(estimation_info_path, 'rb') as file:
        estimation_info = pkl.load(file)

`estimation_info` is a list with one element per selected image (200 if the entire test set is selected).
Each element is a dict holding lists of the true DSC, the estimated DSC (uncalibrated!), the fraction of pixels classified as positive, and the respective binary masks.
Each entry of those lists refers to one patch-selection step, where entry 0 has no patches selected.

### Calibrate

In [None]:
# Mean logits over axis 0, restricted to the selected images
relevant_logits = np.mean(predicted_logits_arr, axis=0)[image_selection_ids, :, :]

In [None]:
# Flat list: for every selected image and every stored mask, the logits of
# the pixels that mask keeps
logit_remain = [
    img_logits[remaining]
    for idx, img_logits in enumerate(relevant_logits)
    for remaining in estimation_info[idx]['masks']
]

In [None]:
# All 'true dice' values flattened into one list (image-major order)
true_dice_flat = [
    dice_value
    for img_info in estimation_info
    for dice_value in img_info['true dice']
]

In [None]:
def apply_temperature_remain(T, logits):
    """
        Divide every logit array by the temperature, apply softmax and
        estimate the dice from the rescaled outputs.

        ::param::
        T: [float] temperature the logits are divided by
        logits: [list] of logit arrays, one per (image, mask) combination

        ::return::
        [dict] with
          'estimated dice': [list] estimate_dice_li of every rescaled output
          'predictions': [list] softmax of every temperature-scaled logit array
          'T': the temperature that was applied
    """
    ts_pyIx = [softmax(logits_img / T) for logits_img in logits]
    est_dice_cali = [estimate_dice_li(probs) for probs in ts_pyIx]

    return {'estimated dice': est_dice_cali, 'predictions': ts_pyIx, 'T': T}

In [None]:
# Temperature-scale the remaining logits with the pre-computed temperature
rescaled_probabilities = apply_temperature_remain(T=T_remain, logits=logit_remain)

In [None]:
# The per-mask logit list is no longer needed once it has been rescaled
del logit_remain # free RAM

In [None]:
# Flat list of estimated dice values, one per (image, mask) combination
ts_scaled_estimate_remain = rescaled_probabilities['estimated dice']

In [None]:
# Arrange the flat estimates as a matrix with cfg.no_patches columns
estimates_matrix_ts = np.reshape(
    np.array(ts_scaled_estimate_remain), (-1, cfg.no_patches)
)

In [None]:
if run_again:
    # Cache the estimates matrix; the `with` statement closes the file on exit
    with open('cache/' + config_suffix + '_estimates_matrix.pkl', 'wb') as file:
        pkl.dump(estimates_matrix_ts, file)

From now on we operate with `estimates_matrix_ts` and `ts_scaled_estimate_remain`, which contain the estimated DSC for the remaining parts of the images (after the patches have been cut out).

## Impute Oracle

In [None]:
def convolve_uncertainty(uncertainty_map, patch_size, stride=1):
    """
        Sum an uncertainty map over a sliding square window.

        Note: this definition shadows the function of the same name that the
        notebook imports from utils.notebook_utils.

        ::param::
        uncertainty_map: [np.array] of shape nxn
        patch_size: [int] defines side length of square patch
        stride: [int] step between neighbouring window positions

        ::return::
        [dict] with two parallel lists:
          'reference': (y, x) top-left corner of every window
          'patch uncertainty': sum of uncertainty_map inside that window

        Window corners run over range(0, n - patch_size, stride) along each
        axis, i.e. the corner n - patch_size itself is not visited. This is
        kept unchanged because callers depend on the resulting window count.
    """
    reference_points = [] # identifies patches
    patch_convolved = []

    n_rows, n_cols = uncertainty_map.shape[0], uncertainty_map.shape[1]

    for y in range(0, n_rows - patch_size, stride):
        for x in range(0, n_cols - patch_size, stride):
            # Slice the window directly instead of building a full-size
            # boolean mask per position. ravel() flattens it in row-major
            # order, the same element order the mask selection produced.
            window = uncertainty_map[y:y + patch_size, x:x + patch_size]
            patch_convolved.append(window.ravel().sum())
            reference_points.append((y, x))

    return {'reference': reference_points,
            'patch uncertainty': patch_convolved}

In [None]:
# Only the reference points are needed by the following function, so a
# constant dummy image of ones is convolved to obtain them
convolution_dict = convolve_uncertainty(
    uncertainty_map=np.ones((512, 512)),
    patch_size=cfg.patch_size,
)

In [None]:
def oracle_and_random(img_idx, selection_idx):
    """
        Compute DSC curves for one image when the masked-out patches are
        assumed to be perfectly re-annotated (contribution of 1), once for
        the stored patch masks and once for randomly placed patches.

        ::param::
        img_idx: [int] index into split_x / split_y and the image axis of
                 ensemble_dict['mean predictions']
        selection_idx: [int] index into estimation_info and the rows of
                       estimates_matrix_ts

        ::return::
        [dict] with
          'true dsc curve': weighted combination of the stored true dice on
              the remaining pixels and 1 for the masked-out part, one entry
              per stored mask
          'est dsc curve': the same using the estimated dice and weights
              computed from the prediction only
          'random curve': true curve for 5 randomly drawn patches (6 entries,
              entry 0 without any patch)

        ::raise::
        ValueError if no reference point satisfies the distance constraint
    """
    info = estimation_info[selection_idx]
    masks = info['masks']

    # Retrieve prediction and ground truth
    prediction = ensemble_dict['mean predictions'][img_idx, :, :]

    image = cv2.imread(split_x[img_idx], cv2.IMREAD_COLOR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (512, 512))

    mask = cv2.imread(split_y[img_idx], cv2.IMREAD_GRAYSCALE)
    # NOTE(review): resized with cv2.resize's default interpolation; the original
    # comment mentioned INTER_NEAREST -- confirm which one is intended.
    mask = cv2.resize(mask, (512, 512))

    _, y_image = process_image(image, mask)
    gt = y_image.squeeze()

    # Denominators of the weights; they do not change between masking steps
    total_mass = gt.sum() + prediction.sum()
    total_prediction = prediction.sum()

    ## Oracle
    dsc_curve = []
    est_dsc_curve = []

    # The stored masks are True on the pixels that remain
    for j, remaining in enumerate(masks):
        dice_remain = info['true dice'][j]
        est_dice_remain = estimates_matrix_ts[selection_idx, j]

        w_remain = (gt[remaining].sum() + prediction[remaining].sum()) / total_mass
        w_oracle = 1 - w_remain

        est_w_remain = prediction[remaining].sum() / total_prediction
        est_w_oracle = 1 - est_w_remain

        # The masked-out part enters with a dice of 1
        dsc_curve.append(w_remain * dice_remain + w_oracle * 1)
        est_dsc_curve.append(est_w_remain * est_dice_remain + est_w_oracle * 1)

    ## Random Patches
    # Select random reference points within the circle: keep only those that
    # are NOT further away from the center than the radius. The comparison was
    # previously inverted and kept the points outside the circle.
    # NOTE(review): assumes dist2center returns the distance to the image
    # center -- confirm in utils.notebook_utils.
    references = np.array(convolution_dict['reference'])
    dists = np.array([dist2center(ref) for ref in convolution_dict['reference']])
    within_circle = dists <= 510/2
    population = references[within_circle, :]
    if population.shape[0] == 0:
        raise ValueError('No reference point lies within the circle.')
    random_references = population[np.random.randint(population.shape[0], size=5), :]

    masked_out = np.zeros_like(prediction).astype(bool)
    random_dsc_curve = []

    for i in range(len(random_references) + 1):
        remaining = ~masked_out

        dice_remain = dice_metric(gt[remaining], prediction[remaining])

        w_remain = (gt[remaining].sum() + prediction[remaining].sum()) / total_mass
        w_oracle = 1 - w_remain

        random_dsc_curve.append(w_remain * dice_remain + w_oracle * 1)

        # Mask out the next random patch (none is left after the final step)
        if i < len(random_references):
            patch_ref = random_references[i]
            masked_out[patch_ref[0]:patch_ref[0] + cfg.patch_size,
                       patch_ref[1]:patch_ref[1] + cfg.patch_size] = True

    return {'true dsc curve': dsc_curve, 'est dsc curve': est_dsc_curve,
            'random curve': random_dsc_curve}

In [None]:
# Free RAM: delete every global that is neither underscore-prefixed nor
# listed below as still required
required_obs = ['ensemble_dict', 'estimation_info', 'oracle_and_random',
                'split_x', 'split_y', 'process_image', 'estimates_matrix_ts',
                'dist2center', 'convolution_dict', 'required_obs', 'image_selection_ids',
                'cv2', 'np', 'dice_metric', 'cfg', 'config_suffix', 'split_suffix',
                'pkl', 'plt', 'ts_scaled_estimate_remain', 'inputs', 'targets',
                'run_again', 'container']

for name in dir():
    if name.startswith('_'):
        continue
    if name in required_obs:
        continue
    del globals()[name]

In [None]:
# Load the cached oracle results only when they are not re-computed anyway;
# an unconditional load fails on a first run with run_again=True because the
# cache file does not exist yet.
# NOTE: pickle.load can execute arbitrary code; only load trusted cache files.
if not run_again:
    with open('cache/' + config_suffix + '_oracle_results' + split_suffix + '.pkl', 'rb') as file4:
        oracle_results = pkl.load(file4)

In [None]:
# For the 'all_imgs' configuration, read the three oracle arrays from CSV
if config_suffix == 'all_imgs':
    import pandas as pd

    oracle_true = pd.read_csv('cache/all_imgs_oracle_true.csv', header=None)
    oracle_true = np.array(oracle_true)

    oracle_est = pd.read_csv('cache/all_imgs_oracle_est.csv', header=None)
    oracle_est = np.array(oracle_est)

    oracle_random = pd.read_csv('cache/all_imgs_oracle_random.csv', header=None)
    oracle_random = np.array(oracle_random)

In [None]:
if run_again:
    # Re-compute the oracle / random curves for every selected image:
    # original_idx addresses the full split, selection_idx the selected subset
    oracle_results = [
        oracle_and_random(original_idx, selection_idx)
        for selection_idx, original_idx in enumerate(image_selection_ids)
    ]

    # The `with` statement closes the file; the former trailing close() sat
    # outside the `if` and was redundant
    with open('cache/' + config_suffix + '_oracle_results' + split_suffix + '.pkl', 'wb') as file4:
        pkl.dump(oracle_results, file4)

In [None]:
# Stack the per-image curves into arrays (one row per image)
oracle_true = np.array(
    [result['true dsc curve'] for result in oracle_results]
)
oracle_est = np.array(
    [result['est dsc curve'] for result in oracle_results]
)
oracle_random = np.array(
    [result['random curve'] for result in oracle_results]
)