# In [3]:

###########################################################
#https://mlnotebook.github.io/post/dataaug/
###########################################################

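"""
Data-augmentation helpers for 2D/3D images and their masks.

Unless stated otherwise, each function takes
image   = array of pixel intensities in 2 or 3 dimensions
and returns an array of pixel intensities with the same shape as `image`.
Exceptions: `cropit` returns a cropped window, `resampleit` returns an
array of shape `dims`, and `sliceshift` only works on 3D volumes.
"""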
from time import perf_counter
from scipy.ndimage import rotate, interpolation, zoom
import numpy as np
import matplotlib.pyplot as plt
import random


def scaleit(pixel_array, factor, case):
    """Isotropically rescale the first two axes of an array, keeping its shape.

    Any further axes (e.g. the slice axis of a 3D volume) are not scaled.

    Args:
        pixel_array: array with at least two dimensions.
        factor: scale factor. ``< 1`` shrinks the content and pads the border
            with zeros; ``> 1`` enlarges the content and centre-crops it back
            to the original size; ``== 1`` returns ``pixel_array`` unchanged
            (the same object).
        case: ``"mask"`` selects order-0 (nearest-neighbour) interpolation so
            label values are preserved; anything else uses an order-5 spline.

    Returns:
        Array with the same shape as ``pixel_array``.
    """
    order = 0 if case == "mask" else 5
    factor = float(factor)

    if factor == 1.0:
        return pixel_array

    height, width = pixel_array.shape[:2]
    zoom_factors = (factor, factor) + (1.0,) * (pixel_array.ndim - 2)
    zoomed = zoom(pixel_array, zoom_factors, order=order, mode='nearest')
    zheight, zwidth = zoomed.shape[:2]

    if factor < 1.0:
        # Paste the shrunken content into the centre of a zero canvas.
        newimg = np.zeros_like(pixel_array)
        row = (height - zheight) // 2
        col = (width - zwidth) // 2
        newimg[row:row + zheight, col:col + zwidth] = zoomed
        return newimg

    # factor > 1: the zoomed array is at least as large as the input
    # (round(factor * n) >= n), so a centred crop always restores the shape.
    row = (zheight - height) // 2
    col = (zwidth - width) // 2
    return zoomed[row:row + height, col:col + width]



def resampleit(image, dims, isseg=False):
    """Resample ``image`` to the shape ``dims``.

    Args:
        image: array to resample.
        dims: target shape, one entry per axis of ``image``.
        isseg: if True, treat ``image`` as a segmentation: use order-0
            (nearest-neighbour) interpolation so labels stay discrete, and
            relabel value 4 as 3. Otherwise use an order-5 spline and min-max
            normalise the result to [0, 1].

    Returns:
        The resampled array. For a constant (non-segmentation) image the
        normalised result is all zeros instead of NaN.
    """
    order = 0 if isseg else 5
    zoom_factors = np.array(dims) / np.array(image.shape, dtype=np.float32)
    image = zoom(image, zoom_factors, order=order, mode='nearest')

    if isseg:
        image[image == 4] = 3
        return image

    lo, hi = image.min(), image.max()
    span = hi - lo
    if span == 0:
        # Avoid 0/0 -> NaN for constant images; the result is all zeros.
        span = 1
    return (image - lo) / span
   

def translateit(image, offset, case):
    """Shift ``image`` in-plane by ``offset`` pixels.

    Args:
        image: array with at least two dimensions. Only the first two axes
            are shifted; any further axes (e.g. slices) are left in place.
        offset: ``(rows, cols)`` shift; values may be fractional.
        case: ``"mask"`` selects order-0 (nearest-neighbour) interpolation so
            label values are preserved; anything else uses an order-5 spline.

    Returns:
        Shifted array of the same shape. Pixels shifted in from outside the
        array take the nearest edge value (``mode='nearest'``).
    """
    order = 0 if case == "mask" else 5
    # Zero-pad the shift for every axis beyond the first two, so both 2D
    # slices and 3D volumes are accepted.
    full_offset = (offset[0], offset[1]) + (0,) * (np.ndim(image) - 2)
    return interpolation.shift(image, full_offset, order=order, mode='nearest')


def rotateit(image, theta, case):
    """Rotate ``image`` by ``theta`` degrees in the plane of its first two axes.

    The output keeps the input shape (``reshape=False``); samples falling
    outside the input take the nearest edge value. ``case == "mask"`` uses
    order-0 interpolation, anything else an order-5 spline.
    """
    if case == "mask":
        spline_order = 0
    else:
        spline_order = 5
    return rotate(
        image,
        float(theta),
        axes=(0, 1),
        reshape=False,
        order=spline_order,
        mode='nearest',
    )


def augment_gaussian_noise(data_sample, noise_variance=(0.1, 0.9)):
    """Add zero-mean Gaussian noise scaled to the mean intensity of the input.

    A strength ``v`` is taken from ``noise_variance``: the value itself when
    both bounds are equal, otherwise a uniform draw between them. The noise
    standard deviation is ``v * |mean(data_sample)| * 0.05``, i.e. 5% of the
    mean pixel value per unit of ``v``. Despite the parameter name, ``v``
    scales the standard deviation, not the variance.

    All randomness comes from ``np.random``, so seeding ``np.random`` makes
    the result reproducible.

    Returns:
        A new array ``data_sample + noise``.
    """
    if noise_variance[0] == noise_variance[1]:
        variance = noise_variance[0]
    else:
        # np.random (not the stdlib `random`) so np.random.seed controls it.
        variance = np.random.uniform(noise_variance[0], noise_variance[1])

    # abs(): np.random.normal rejects a negative scale, which a negative
    # mean intensity would otherwise produce.
    sigma = variance * abs(np.mean(data_sample)) * 0.05
    return data_sample + np.random.normal(0.0, sigma, size=data_sample.shape)



def _draw_contrast_factor(contrast_range):
    """Draw a random contrast factor from ``contrast_range``.

    With probability 0.5, and only if the lower bound is below 1, the factor
    is drawn from [low, 1) (contrast reduced). Otherwise it is drawn from
    [max(low, 1), high) (contrast kept or increased).
    """
    low, high = contrast_range[0], contrast_range[1]
    if np.random.random() < 0.5 and low < 1:
        return np.random.uniform(low, 1)
    return np.random.uniform(max(low, 1), high)


def augment_contrast(data_sample, contrast_range=(0.9, 1.1), preserve_range=True, per_channel=True):
    """Randomly stretch or compress intensities around their mean.

    Each value ``x`` becomes ``(x - mean) * factor + mean``, with ``factor``
    drawn by ``_draw_contrast_factor``.

    Args:
        data_sample: input array; it is not modified.
        contrast_range: ``(low, high)`` bounds for the contrast factor.
        preserve_range: clip the result to the original [min, max].
        per_channel: draw a separate factor (and use a separate mean/min/max)
            for each index along axis 0. NOTE(review): axis 0 is treated as
            the channel axis; confirm this matches the caller's layout.

    Returns:
        A new array. With ``per_channel`` it keeps ``data_sample``'s dtype
        (values are cast back into it). Otherwise it follows NumPy promotion
        (float for integer input).
    """
    def _adjust(values):
        mean = values.mean()
        factor = _draw_contrast_factor(contrast_range)
        adjusted = (values - mean) * factor + mean
        if preserve_range:
            adjusted = np.clip(adjusted, values.min(), values.max())
        return adjusted

    if not per_channel:
        return _adjust(data_sample)

    # Work on a copy so the caller's array is left untouched.
    result = np.array(data_sample, copy=True)
    for c in range(result.shape[0]):
        result[c] = _adjust(result[c])
    return result

def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), per_channel=True):
    """Multiply intensities by a random factor drawn from ``multiplier_range``.

    Args:
        data_sample: input array; it is not modified.
        multiplier_range: ``(low, high)`` bounds of the uniform draw.
        per_channel: draw a separate factor for each index along axis 0
            (NOTE(review): confirm axis 0 is the channel axis for callers);
            otherwise one factor scales the whole array.

    Returns:
        A new array. With ``per_channel`` it keeps ``data_sample``'s dtype.
    """
    low, high = multiplier_range[0], multiplier_range[1]
    # Only used when per_channel is False, but always drawn so the sequence
    # of np.random draws stays the same.
    multiplier = np.random.uniform(low, high)

    if not per_channel:
        return data_sample * multiplier

    # Copy so the caller's array is not modified in place.
    result = np.array(data_sample, copy=True)
    for c in range(result.shape[0]):
        result[c] = result[c] * np.random.uniform(low, high)
    return result


def flipit(image, axes):
    """Return ``image`` reversed along ``axes`` (an int or a tuple of ints)."""
    return np.flip(image, axis=axes)


def intensifyit(image, factor):
    """Scale every intensity in ``image`` by ``factor`` (converted to float)."""
    gain = float(factor)
    return image * gain


def cropit(image, seg=None, margin=5):
    """Crop the first two axes of ``image`` (and ``seg``, if given) to a window.

    Without ``seg``, the window size is based on the shorter of the first two
    axes, and the window is centred on the middle of the image.
    With ``seg``, the window size is the largest extent of the non-zero region
    of ``seg`` (taken over all of its axes) plus ``margin``. The window is
    centred on the mean index of the non-zero voxels, then slid until it spans
    their bounding box in the first two axes.

    Args:
        image: array with at least 2 dimensions; only axes 0 and 1 are sliced.
        seg: optional array, presumably aligned voxel-for-voxel with
            ``image`` (TODO confirm); cropped with the same window. Must
            contain at least one non-zero value, otherwise ``np.amin`` on the
            empty hit list raises ValueError.
        margin: extra pixels added around the window.

    Returns:
        The cropped ``image``, or ``(image, seg)`` when ``seg`` is given.
    """

    shortaxis = np.argmin(image.shape[:2])  # index (0 or 1) of the shorter in-plane axis
    trimaxes  = 0 if shortaxis == 1 else 1  # the other in-plane axis

    trim    = image.shape[shortaxis]  # window side length before margin
    center  = image.shape[trimaxes] // 2   
    lrcenter = image.shape[shortaxis] // 2

    if seg is not None:

        hits = np.where(seg!=0)
        mins = np.amin(hits, axis=1)  # per-axis bounding-box start of the labels
        maxs = np.amax(hits, axis=1)  # per-axis bounding-box end of the labels
        # Largest bounding-box extent over every axis of seg, plus margin.
        segtrim = max(maxs-mins) + margin

        trim = segtrim
        # Mean label index along axes 0 and 1 (computed with dtype=int).
        center  = np.mean(hits, 1, dtype=int)[0] 
        lrcenter = np.mean(hits, 1, dtype=int)[1] 

        # Slide the window centre until the window's top edge is at or
        # before the first labelled row...
        if center - (trim // 2) > mins[0]:
            while center - (trim // 2) > mins[0]:
                center = center - 1
            center = center  # no-op

        # ...and its bottom edge is at or after the last labelled row.
        if center + (trim // 2) < maxs[0]:
            while center + (trim // 2) < maxs[0]:
                center = center + 1
            center = center  # no-op

        # Same adjustment for the left edge along axis 1...
        if lrcenter - (trim // 2) > mins[1]:
            while lrcenter - (trim // 2) > mins[1]:
                lrcenter = lrcenter - 1
            lrcenter = lrcenter  # no-op

        # ...and the right edge along axis 1.
        if lrcenter + (trim // 2) < maxs[1]:
            while lrcenter + (trim // 2) < maxs[1]:
                lrcenter = lrcenter + 1
            lrcenter = lrcenter  # no-op

    # Window edges. If the window would start before index 0 it is pinned at
    # 0 and given trim + margin pixels; otherwise it spans trim + margin//2.
    top    = max(0, center - (trim //2) - margin//2)
    bottom = trim + margin if top == 0 else top + trim + (margin//2)
    left = max(0, lrcenter - (trim//2) - margin//2)
    right = trim + margin if left == 0 else left + trim + (margin//2)

    # Debug helpers: paint the centre and the crop outline into the image.
    # image[center-5:center+5, lrcenter-5:lrcenter+5, :] = 255
    # image[top:bottom, left-2:left+2, :] = 255
    # image[top:bottom, right-2:right+2, :] = 255
    # image[top-2:top+2, left:right, :] = 255
    # image[bottom-2:bottom+2, left:right, :] = 255

    # Shift the window back inside the image if it runs past the far edge.
    # NOTE(review): top/bottom slice axis 0 but are clamped against
    # image.shape[trimaxes], and left/right slice axis 1 but are clamped
    # against image.shape[shortaxis]. These agree only when shortaxis == 1.
    # Also, top/left become negative if trim exceeds the axis length, and
    # Python slicing then counts from the end. Confirm both are intended.
    if bottom > image.shape[trimaxes]:
        bottom = image.shape[trimaxes]
        top = bottom - trim

    if right > image.shape[shortaxis]:
        right = image.shape[shortaxis]
        left = right - trim

    image   = image[top: bottom, left:right]

    if seg is not None:
        seg   = seg[top: bottom, left:right]

        return image, seg
    else:
        return image

def sliceshift(image, shift_min=-3, shift_max=3, fraction=0.5, isseg=False):
    """Randomly translate a random subset of slices of a 3D volume in-plane.

    Only works for 3D images, i.e. slice-wise: slices are indexed along the
    last axis and each chosen slice is shifted over its first two axes.

    Args:
        image: 3D array; modified in place.
        shift_min, shift_max: bounds of the per-slice integer offset, drawn
            from [shift_min, shift_max) for each of the two in-plane axes.
        fraction: at most ``int(depth * fraction)`` (and at least 1) slice
            draws are made. Slices are drawn with replacement, so a slice may
            be shifted more than once.
        isseg: if True, treat ``image`` as a mask (nearest-neighbour
            interpolation via ``translateit``).

    Returns:
        ``image`` itself, after the in-place shifts.
    """
    case = "mask" if isseg else "image"
    depth = image.shape[-1]
    # At least one slice, even when depth * fraction < 1 (randint(1, 1) raises).
    max_slices = max(1, int(depth * fraction))
    numslices = np.random.randint(1, max_slices + 1)
    slices = np.random.randint(0, depth, numslices, dtype=int)
    for slc in slices:
        offset = np.random.randint(shift_min, shift_max, 2, dtype=int)
        # Pass a depth-1 slab so translateit receives a 3D array; take the
        # single slice back out of the result.
        image[:, :, slc] = translateit(image[:, :, slc:slc + 1], offset, case)[:, :, 0]
    return image


def randaugm(pixel_array, random_seed_aug, case):
    """Apply one or two randomly chosen augmentations to ``pixel_array``.

    Augmentation ids:
      1 rotate by an angle in [-5, 5) degrees (rounded to 2 decimals)
      2 Gaussian noise (images only)
      3 isotropic in-plane scale by a factor in [0.95, 1.05)
      4 contrast (images only)
      5 translate by an integer offset in [-3, 3) per in-plane axis

    Args:
        pixel_array: array to augment; the helpers operate on its first two
            axes.
        random_seed_aug: seed for ``np.random``. Pass the same seed for an
            image and its mask to get the same rotation, scale and translation
            on both.
        case: ``"image"`` or ``"mask"``. Masks skip noise and contrast, and the
            geometric helpers use order-0 interpolation for them.

    Returns:
        The augmented array.

    Side effect: reseeds the global ``np.random`` state.
    """
    np.random.seed(random_seed_aug)

    # Which augmentations to apply: one or two distinct ids out of 1..5.
    num_trans = int(np.random.randint(1, 3, size=1)[0])
    which_trans = np.random.choice([1, 2, 3, 4, 5], num_trans, replace=False)

    # Draw every geometric parameter before any image-only step runs.
    # augment_gaussian_noise and augment_contrast consume np.random draws; if
    # scale/offset were drawn after them, an image and its mask augmented with
    # the same seed would get different scale factors and offsets. The order
    # here (theta only when rotating, then scale, then offset) is the same
    # sequence a mask consumes.
    theta_rotate = None
    if 1 in which_trans:
        theta_rotate = float(np.around(np.random.uniform(-5.0, 5.0, size=1)[0], 2))
    scale_factor = np.random.uniform(0.95, 1.05)
    offset = np.random.randint(-3, 3, size=2)

    if 1 in which_trans:
        pixel_array = rotateit(pixel_array, theta_rotate, case)
    if 2 in which_trans and case == "image":
        pixel_array = augment_gaussian_noise(pixel_array)
    if 3 in which_trans:
        pixel_array = scaleit(pixel_array, scale_factor, case)
    if 4 in which_trans and case == "image":
        pixel_array = augment_contrast(pixel_array)
    if 5 in which_trans:
        pixel_array = translateit(pixel_array, offset, case)

    return pixel_array
