## Patch Prediction and Reconstruction
---
**Import Statements**

In [None]:
import numpy as np
from scipy.ndimage.filters import gaussian_filter
import matplotlib.pyplot as plt
from skimage.util import montage
import random
from sklearn.metrics import f1_score

**Patch Functions**

In [None]:
# Gaussian importance weight map from nnUNet
def gaussian(patch_size, sigma_scale = 1./8):
    """Build a Gaussian importance map for weighting patch predictions.

    A unit impulse is placed at the patch centre (index ``size // 2`` on every
    axis) and blurred with a Gaussian. On each axis the standard deviation is
    ``sigma_scale`` times that axis' patch length. The result is rescaled so
    its peak equals 1. When overlapping patch predictions are blended, voxels
    near a patch centre therefore weigh more than voxels near its border.

    Args:
        patch_size: patch extent per axis.
        sigma_scale: Gaussian sigma expressed as a fraction of each patch
            dimension.

    Returns:
        float32 array of shape ``patch_size`` with maximum value 1. Any exact
        zeros are replaced by the smallest non-zero weight, presumably to keep
        summed weights strictly positive when they are later divided by.
    """
    centre = tuple(length // 2 for length in patch_size)
    sigmas = [length * sigma_scale for length in patch_size]

    impulse = np.zeros(patch_size)
    impulse[centre] = 1
    blurred = gaussian_filter(impulse, sigmas, 0, mode='constant', cval=0)

    # Normalise so the centre weight is exactly 1.
    weights = (blurred / np.max(blurred)).astype(np.float32)

    # No weight may be zero: clamp zeros up to the smallest positive weight.
    is_zero = weights == 0
    weights[is_zero] = np.min(weights[~is_zero])
    return weights

In [None]:
def patch_steps(patch_size, vol_size, step_size = 0.5):
    """Compute the start index of every sliding-window patch along each axis.

    The requested stride per axis is ``step_size * patch_size``. The number of
    patches is the fewest that cover the axis at that stride. The starts are
    then spread evenly, so the first patch begins at 0 and the last one ends
    exactly at the volume border. The actual stride is therefore never larger
    than the requested one.

    Args:
        patch_size: patch extent per axis.
        vol_size: volume extent per axis. It must have the same number of
            axes as ``patch_size`` and be at least as large on every axis.
        step_size: requested stride as a fraction of the patch size, in
            (0, 1]. Values <= 1 guarantee that neighbouring patches touch or
            overlap, so every voxel is covered.

    Returns:
        A list with one list of int start indices per axis.

    Raises:
        ValueError: if the sizes have different numbers of axes, a patch is
            larger than the volume on some axis, or ``step_size`` is outside
            (0, 1].
    """
    if len(patch_size) != len(vol_size):
        raise ValueError(
            f"patch_size {tuple(patch_size)} and vol_size {tuple(vol_size)} "
            "must have the same number of axes"
        )
    if not 0 < step_size <= 1:
        raise ValueError(f"step_size must be in (0, 1], got {step_size}")
    if any(p > v for p, v in zip(patch_size, vol_size)):
        # Without this check the loop below yields starts that overflow the
        # volume, or no starts at all (which later produces an all-NaN 0/0).
        raise ValueError(
            f"patch_size {tuple(patch_size)} exceeds vol_size {tuple(vol_size)}"
        )

    steps = []
    for p, v in zip(patch_size, vol_size):
        target_step = p * step_size
        num_steps = int(np.ceil((v - p) / target_step)) + 1
        max_start = v - p

        # A single patch always starts at 0, so its stride is irrelevant.
        actual_step = max_start / (num_steps - 1) if num_steps > 1 else 0

        steps.append([int(np.round(actual_step * i)) for i in range(num_steps)])

    return steps

In [None]:
def predict_on_patches(ct_vol, model, patch_size, step_size=0.5, *, return_patches_list=False):
    """Predict a whole 3-D volume patch by patch and blend the overlapping results.

    A window of size ``patch_size`` slides over ``ct_vol`` at the starts
    returned by ``patch_steps``. Each patch prediction is multiplied by a
    Gaussian importance map (``gaussian``) and accumulated. The accumulated
    predictions are finally divided by the accumulated weights, giving a
    weighted average in which patch centres count more than patch borders.

    Args:
        ct_vol: 3-D numpy array indexed as ``ct_vol[x, y, z]``.
        model: either an object with a ``predict`` method, or the string
            ``'patch testing'``.
            - A model receives each patch as a batch shaped
              ``(1, px, py, pz, 1)`` (presumably a Keras-style model). Its
              squeezed output is assumed to have shape ``patch_size``, i.e.
              a single output channel.
            - ``'patch testing'`` replaces the model with dummy predictions
              that alternate all-ones / all-zeros along z, starting with ones
              for each (x, y) column. This exercises the patching and
              blending logic without a model.
        patch_size: (px, py, pz) patch shape. It should not exceed
            ``ct_vol.shape`` on any axis.
        step_size: stride as a fraction of ``patch_size`` (0.5 = 50% overlap).
        return_patches_list: if True, return ``(ct_patches, prediction_patches)``
            instead of the blended volume. ``ct_patches`` are 3-D views into
            ``ct_vol``; ``prediction_patches`` are the per-patch predictions,
            both in visiting order (x outermost, z innermost).

    Returns:
        float32 array of shape ``ct_vol.shape`` holding the blended
        prediction, or the pair of patch lists when ``return_patches_list``
        is True.
    """
    vol_size = ct_vol.shape

    gaussian_patch = gaussian(patch_size)

    aggregated_gaussian = np.zeros(vol_size, dtype=np.float32)
    aggregated_predictions = np.zeros(vol_size, dtype=np.float32)

    # Use the caller's step_size (it was previously hard-coded to 0.5 here).
    steps = patch_steps(patch_size, vol_size, step_size=step_size)

    # Only collected on request, to avoid holding every patch in memory.
    ct_patches = []
    prediction_patches = []

    testing = isinstance(model, str) and model == 'patch testing'

    for lb_x in steps[0]:
        ub_x = lb_x + patch_size[0]

        for lb_y in steps[1]:
            ub_y = lb_y + patch_size[1]

            ones = True  # testing mode: restart the ones/zeros alternation per (x, y)
            for lb_z in steps[2]:
                ub_z = lb_z + patch_size[2]
                window = (slice(lb_x, ub_x), slice(lb_y, ub_y), slice(lb_z, ub_z))

                # Always sliced, so ct_patch is defined in testing mode too.
                ct_patch = ct_vol[window]

                if testing:
                    prediction_patch = np.ones(patch_size) if ones else np.zeros(patch_size)
                    ones = not ones
                else:
                    # Add batch and channel axes: (1, px, py, pz, 1).
                    batch = ct_patch.reshape(-1, *ct_patch.shape, 1)
                    prediction_patch = model.predict(batch).squeeze()

                aggregated_gaussian[window] += gaussian_patch
                aggregated_predictions[window] += prediction_patch * gaussian_patch

                if return_patches_list:
                    ct_patches.append(ct_patch)
                    prediction_patches.append(prediction_patch)

    if return_patches_list:
        return ct_patches, prediction_patches

    # Weighted average: divide by the summed Gaussian weights per voxel.
    prediction_vol = aggregated_predictions / aggregated_gaussian

    return prediction_vol

**Check Patch Functions**

In [None]:
if __name__ == '__main__':
    # Sanity check of the patching and blending logic without a real model.
    # The 'patch testing' model yields alternating all-ones / all-zeros
    # patches. Their Gaussian-weighted average must therefore stay within
    # [0, 1] and vary smoothly where patches overlap.
    ct_vol = np.ones((256,256,126))
    patch_size = (128,128,63)
    step_size = 0.5

    prediction_vol = predict_on_patches(ct_vol=ct_vol, model='patch testing', patch_size=patch_size, step_size=step_size)

    print("min prob:", prediction_vol.min(), "max prob:", prediction_vol.max())

    # Show the x-z plane at y = 0 (rows = x, columns = z).
    plt.imshow(prediction_vol[:,0,:])
    plt.colorbar()
    plt.title("Blended dummy prediction (y = 0)")
    plt.xlabel("z")
    plt.ylabel("x")
    # Needed to display the figure when run as a script (outside Jupyter).
    plt.show()