In [109]:
import tensorflow as tf
import numpy as np
from skimage.morphology import ball
from scipy.ndimage import binary_erosion, binary_dilation, distance_transform_edt
import imageio
import matplotlib.pyplot as plt
import scipy.ndimage as nd

# CERW Function

In [176]:
def CERW(image,seed,dtype=32,tol=0.00001,k=0.5,beta=500,max_iter=None):
  """Resolve the undecided voxels of a 3-D seed volume by iterating a weighted diffusion with a curvature term.

  Voxels of `seed` equal to 0.5 are treated as unknown; every other voxel is held
  fixed (its time step is zero). The computation is restricted to the bounding
  box of the unknown voxels, grown by one voxel on each side where possible.
  Inside that box the state is updated until the largest per-voxel change of a
  single step drops below `tol` (tested every 100th loop iteration).

  Parameters
  ----------
  image : numpy.ndarray
      3-D intensity volume indexed [depth, row, column]. Edge weights are
      exp(-beta * difference**2) between neighbouring voxels.
      Presumably the same shape as `seed` - TODO confirm against callers.
  seed : numpy.ndarray
      3-D volume indexed [depth, row, column]; 0.5 marks voxels to solve for.
      Presumably 0 marks background and 1 foreground - TODO confirm.
  dtype : int
      32 or 64, the floating point / integer width used for the computation.
  tol : float
      Convergence threshold on the maximum absolute change of one step.
  k : float
      Coefficient of the curvature term.
  beta : float
      Edge sensitivity of the diffusion weights.
  max_iter : int or None
      Optional cap on the number of loop iterations. None (the default)
      iterates until `tol` is met.

  Returns
  -------
  numpy.ndarray
      A float copy of `seed` (float32 or float64) in which the processed box has
      been replaced by the thresholded solution (values >= 0.5 become 1, the
      rest 0). The caller's arrays are not modified. If `seed` contains no 0.5
      voxels the float copy is returned unchanged.

  Raises
  ------
  ValueError
      If `dtype` is neither 32 nor 64.

  NOTE(review): the curvature filters use REFLECT padding of up to 3 voxels, which
  presumably requires the processed box to be at least 4 voxels wide along every
  axis - confirm against the tf.pad documentation.
  """

  # validate the precision first so that a bad argument fails before any work is done
  if dtype == 64:
    np_float = 'float64'
  elif dtype == 32:
    np_float = 'float32'
  else:
    raise ValueError('dtype needs to either be 32 or 64 bits')

  # astype copies, so the arrays passed in by the caller are never written to
  image = image.astype(np_float)
  seed = seed.astype(np_float)

  # voxels marked 0.5 are the ones to solve for
  unknown_mask = seed == 0.5
  if not unknown_mask.any():
    # nothing to solve: without this guard the bounding box below would be empty
    return seed

  float_type, int_type = (tf.float64, tf.int64) if dtype == 64 else (tf.float32, tf.int32)

  # bounding box of the unknown voxels, grown by one voxel and clamped at zero
  # (numpy slicing clamps the upper end by itself)
  unknown_idx = np.argwhere(unknown_mask)
  box_start = np.maximum(unknown_idx.min(axis=0) - 1, 0)
  box_stop = unknown_idx.max(axis=0) + 2
  box = tuple(slice(int(lo), int(hi)) for lo, hi in zip(box_start, box_stop))

  def const(value):
    # constant in the working float precision
    return tf.constant(value, dtype=float_type)

  def as_volume(array):
    # conv3d expects [batch, depth, height, width, channels]
    return const(array[np.newaxis, ..., np.newaxis])

  # first-difference filters, laid out as [depth, height, width, in, out]
  H_filter = const([[[[[-1.]]],[[[1.]]]]])
  W_filter = const([[[[[-1.]],[[1.]]]]])
  D_filter = const([[[[[-1.]]]],[[[[1.]]]]])
  lHW_filter = const([[[[[-1.]],[[0.]]],[[[0.]],[[1.]]]]])
  rHW_filter = const([[[[[0.]],[[-1.]]],[[[1.]],[[0.]]]]])
  lDW_filter = const([[[[[-1.]],[[0.]]]],[[[[0.]],[[1.]]]]])
  rDW_filter = const([[[[[0.]],[[-1.]]]],[[[[1.]],[[0.]]]]])
  lDH_filter = const([[[[[-1.]]],[[[0.]]]],[[[[0.]]],[[[1.]]]]])
  rDH_filter = const([[[[[0.]]],[[[-1.]]]],[[[[1.]]],[[[0.]]]]])

  # one entry per neighbour direction:
  #   (difference filter,
  #    padding that puts each edge value back on the voxel the filter weights -1,
  #    padding that puts it on the voxel the filter weights +1,
  #    whether the direction is a diagonal)
  stencils = (
    (H_filter,   [[0,0],[0,0],[0,1],[0,0],[0,0]], [[0,0],[0,0],[1,0],[0,0],[0,0]], False),
    (W_filter,   [[0,0],[0,0],[0,0],[0,1],[0,0]], [[0,0],[0,0],[0,0],[1,0],[0,0]], False),
    (D_filter,   [[0,0],[0,1],[0,0],[0,0],[0,0]], [[0,0],[1,0],[0,0],[0,0],[0,0]], False),
    (lHW_filter, [[0,0],[0,0],[0,1],[0,1],[0,0]], [[0,0],[0,0],[1,0],[1,0],[0,0]], True),
    (rHW_filter, [[0,0],[0,0],[0,1],[1,0],[0,0]], [[0,0],[0,0],[1,0],[0,1],[0,0]], True),
    (lDW_filter, [[0,0],[0,1],[0,0],[0,1],[0,0]], [[0,0],[1,0],[0,0],[1,0],[0,0]], True),
    (rDW_filter, [[0,0],[0,1],[0,0],[1,0],[0,0]], [[0,0],[1,0],[0,0],[0,1],[0,0]], True),
    (lDH_filter, [[0,0],[0,1],[0,1],[0,0],[0,0]], [[0,0],[1,0],[1,0],[0,0],[0,0]], True),
    (rDH_filter, [[0,0],[0,1],[1,0],[0,0],[0,0]], [[0,0],[1,0],[0,1],[0,0],[0,0]], True),
  )

  # create strides for convolutions
  strides = [1,1,1,1,1]

  # exponent scale of the edge weights, and the extra offset used on diagonals
  neg_beta = const(-beta)
  alpha = neg_beta/65025.  # 65025 = 255**2; presumably normalises for 8-bit intensities - TODO confirm

  float_zero = const(0.)
  float_half = const(0.5)
  float_one = const(1.)

  # curvature approximation filters: a 7-tap difference (-1 ... +1) along one axis,
  # smoothed by [1,4,6,4,1] along each of the two other axes
  smooth_taps = np.array([1., 4., 6., 4., 1.])
  diff_taps = np.array([-1., 0., 0., 0., 0., 0., 1.])

  def curve_filter(depth_taps, height_taps, width_taps):
    taps = np.einsum('d,h,w->dhw', depth_taps, height_taps, width_taps)
    return const(taps[..., np.newaxis, np.newaxis].astype(np_float))

  H_curve = curve_filter(smooth_taps, diff_taps, smooth_taps)
  W_curve = curve_filter(smooth_taps, smooth_taps, diff_taps)
  D_curve = curve_filter(diff_taps, smooth_taps, smooth_taps)

  # reflect padding that keeps the output of each curvature filter the size of its input
  curveH_pad = [[0,0],[2,2],[3,3],[2,2],[0,0]]
  curveW_pad = [[0,0],[2,2],[2,2],[3,3],[0,0]]
  curveD_pad = [[0,0],[3,3],[2,2],[2,2],[0,0]]

  # division for the curvature
  divide = const(288.)

  # function to calculate the weights, one tensor per entry of `stencils`
  @tf.function
  def calc_weights(x):
    weights = []
    for kernel, _, _, diagonal in stencils:
      exponent = tf.multiply(tf.pow(tf.nn.conv3d(x, kernel, strides, padding='VALID'),2),neg_beta)
      if diagonal:
        exponent = tf.add(exponent, alpha)
      weights.append(tf.exp(exponent))
    return tuple(weights)

  # function to calculate step of discretised random walk
  def diffusion(x):
    total = None
    for (kernel, pad_first, pad_second, _), weight in zip(stencils, edge_weights):
      # weighted difference along every edge of this direction ...
      flux = tf.multiply(tf.nn.conv3d(x, kernel, strides, padding='VALID'),weight)
      # ... added to one end voxel of the edge and subtracted from the other
      term = tf.subtract(tf.pad(flux,pad_first,'CONSTANT'), tf.pad(flux,pad_second,'CONSTANT'))
      total = term if total is None else total + term
    return total

  # function to calculate the curvature term
  @tf.function
  def curvature_term(x):
    # sum of the three directional filter responses, normalised
    y = tf.divide(tf.add(tf.add(tf.nn.conv3d(tf.pad(x, curveH_pad,'REFLECT'), H_curve, strides, padding='VALID'),
                                tf.nn.conv3d(tf.pad(x, curveW_pad,'REFLECT'), W_curve, strides, padding='VALID')),
                                tf.nn.conv3d(tf.pad(x, curveD_pad,'REFLECT'), D_curve, strides, padding='VALID')),divide)

    # y is scaled by y*(x-0.5)+1 truncated to an integer
    return tf.multiply(y,tf.cast(tf.cast(tf.add(tf.multiply(y,tf.subtract(x,float_half)),float_one),int_type),float_type))

  # calculate the weights
  edge_weights = calc_weights(as_volume(image[box]))
  del image # no longer needed, delete to save ram

  # 1 on unknown voxels, 0 elsewhere
  unknown = as_volume(unknown_mask[box].astype(np_float))

  # eps -- time resolution; zero on known voxels so that they act as fixed sinks and sources
  eps = 0.05*unknown

  # k - coefficient of curvature
  k = const(k)

  # tolerance
  tol = const(tol)

  # simulation state
  U = as_volume(seed[box])
  curvature = tf.zeros_like(U)

  # function to update U
  # `curvature` is an argument on purpose: a tf.function keeps the value a closed-over
  # Python variable had when it was traced, so rebinding `curvature` in the loop below
  # would never reach a version of `step` that read it from the enclosing scope
  @tf.function
  def step(U, curvature):
    reaction = tf.multiply(k, tf.multiply(curvature,tf.multiply(U, tf.subtract(float_one, U))))
    change = tf.add(diffusion(U), reaction)
    # the change multiplied by the timestep is added to the old state, bounded by 0 and 1
    return tf.clip_by_value(tf.add(U, tf.multiply(eps, change)),float_zero,float_one)

  # Run steps of PDE until solved
  count = 0
  while max_iter is None or count < max_iter:

    if count % 10 == 0:
      # refresh the curvature every 10th iteration
      curvature = curvature_term(U)

      # only every 100th iteration also takes a step and tests for convergence;
      # the other multiples of 10 refresh the curvature without stepping
      if count % 100 == 0:
        U_next = step(U, curvature)
        converged = tf.math.reduce_max(tf.math.abs(tf.subtract(U_next, U))) < tol
        U = U_next
        if converged:
          break

    else:
      U = step(U, curvature)

    count += 1

  # values >= 0.5 are converted to 1 and values < 0.5 to 0
  seed[box] = tf.cast(tf.add(U,0.5),int_type).numpy()[0,:,:,:,0]

  return seed

# Evaluation Metrics

## Jaccard Score

In [3]:
# function for Jaccard score
def jaccard(x,y):
  """Jaccard score (intersection over union) of two masks.

  Both inputs are converted with `.astype(bool)`, so every non-zero entry counts
  as set. The result is the number of positions set in both masks divided by
  the number of positions set in at least one of them. When neither mask has
  any set entry the division is 0/0 and the result is NaN.
  """
  mask_x = x.astype(bool)
  mask_y = y.astype(bool)

  # number of positions set in both masks / in at least one mask
  intersection = (mask_x & mask_y).sum()
  union = (mask_x | mask_y).sum()

  return intersection/union

## Local Jaccard Score

In [4]:
# here x should be the segmentation and y the ground truth (as the curvature is evaluated using the ground truth)
def local_jaccard(x,y):
  """Windowed Jaccard score averaged over the ground-truth boundary, split by curvature class.

  Parameters
  ----------
  x : 3-D binary volume
      The segmentation. Any dtype is accepted; non-zero means foreground.
  y : 3-D binary volume
      The ground truth, same shape as x. The boundary and the curvature measure
      are taken from y only.
      Every axis must be longer than 3 voxels because the widest filter is
      reflect-padded by 3.

  Returns
  -------
  tuple of three numpy float32 scalars
      (negative curvature, low curvature, positive curvature) mean scores.
      A NaN means that no boundary voxel fell into that curvature class.

  Method
  ------
  * boundary: foreground voxels of y with at least one background voxel among
    their 26 neighbours.
  * score: inside a 19-voxel window (the 3x3x3 cube without its 8 corners),
    intersection / union of x and y, with an empty union scored as 1. The same
    is computed for the inverted masks and the voxel-wise minimum is kept.
  * curvature: sum over the three axes of a 7-tap [-1,0,0,0,0,0,1] difference
    along that axis, smoothed with [1,4,6,4,1] along the other two axes, divided
    by 288. Classes are < -0.2, between -0.2 and 0.2, and > 0.2.
  """
  # tf.logical_and / logical_or / logical_not only accept boolean tensors, so 0/1
  # integer or float masks are converted up front
  x = tf.cast(x,tf.bool)
  y = tf.cast(y,tf.bool)
  y_float = tf.cast(y,tf.float32)

  def filter_same(volume,kernel):
    # correlate a (D,H,W) float volume with a 3-D numpy kernel of odd side lengths;
    # reflect padding by the kernel radius keeps the output the same shape as the input
    pad = [[side//2,side//2] for side in kernel.shape]
    batched = tf.expand_dims(tf.expand_dims(tf.pad(volume,pad,'REFLECT'),0),-1)
    weights = tf.constant(kernel[...,np.newaxis,np.newaxis],dtype=tf.float32)
    return tf.nn.conv3d(batched,weights,[1,1,1,1,1],padding='VALID')[0,:,:,:,0]

  # 3x3x3 cube without its 8 corner voxels
  window = np.ones((3,3,3),np.float32)
  window[::2,::2,::2] = 0.

  def windowed_jaccard(a,b):
    # intersection / union of two boolean volumes inside the window around every voxel
    inter = tf.cast(tf.logical_and(a,b),tf.float32)
    union = tf.cast(tf.logical_or(a,b),tf.float32)
    ratio = filter_same(inter,window)/filter_same(union,window)
    # 0/0 (nothing of either mask in the window) counts as perfect agreement
    return tf.where(tf.math.is_nan(ratio),tf.ones_like(ratio),ratio)

  # find the boundary of the ground truth: the centre weighs 27 and each of the 26
  # neighbours 1, so a response above 26 means the centre is foreground and a
  # response of exactly 53 means that all neighbours are foreground too (interior)
  centre_weighted = np.ones((3,3,3),np.float32)
  centre_weighted[1,1,1] = 27.
  response = filter_same(y_float,centre_weighted)
  boundary = tf.where(tf.math.logical_and(response>26,response!=53),tf.ones_like(response),tf.zeros_like(response))

  # curvature measure of the ground truth; each kernel is the outer product of a
  # 7-tap difference along one axis and a 5-tap binomial along the other two
  smooth = np.array([1.,4.,6.,4.,1.],np.float32)
  diff = np.array([-1.,0.,0.,0.,0.,0.,1.],np.float32)
  curvature = tf.divide(tf.add(tf.add(filter_same(y_float,np.einsum('i,j,k->ijk',smooth,diff,smooth)),
                                      filter_same(y_float,np.einsum('i,j,k->ijk',smooth,smooth,diff))),
                               filter_same(y_float,np.einsum('i,j,k->ijk',diff,smooth,smooth))),288.)

  # the worse of the foreground score and the background (inverted masks) score
  score = tf.minimum(windowed_jaccard(x,y),
                     windowed_jaccard(tf.logical_not(x),tf.logical_not(y)))

  # find binary map of areas of curvature, restricted to the boundary
  neg_curv = tf.where(curvature<-0.2,tf.ones_like(curvature),tf.zeros_like(curvature))*boundary
  pos_curv = tf.where(curvature>0.2,tf.ones_like(curvature),tf.zeros_like(curvature))*boundary
  low_curv = tf.where(tf.logical_and(curvature>-0.2,curvature<0.2),tf.ones_like(curvature),tf.zeros_like(curvature))*boundary

  # sort and average the jaccard scores for each curvature
  neg_jaccard = tf.reduce_sum(neg_curv*score)/tf.reduce_sum(neg_curv)
  low_jaccard = tf.reduce_sum(low_curv*score)/tf.reduce_sum(low_curv)
  pos_jaccard = tf.reduce_sum(pos_curv*score)/tf.reduce_sum(pos_curv)

  # returns the local jaccards sorted in order [negative curvature, low curvature, positive curvature] # note that a NaN means there is no areas of one curvature
  return neg_jaccard.numpy(), low_jaccard.numpy(), pos_jaccard.numpy()

## Hausdorff Distance

In [5]:
def hausdorff(x,y):
  """Symmetric Hausdorff distance between the borders of two binary masks.

  The border of a mask is the set of foreground voxels whose Euclidean distance
  to the nearest background voxel is exactly 1. For every border voxel of one
  mask the distance to the nearest border voxel of the other mask is taken, and
  the largest such distance over both directions is returned. Distances are
  measured to the border itself, so border voxels lying outside the other mask
  are counted and two identical masks score 0.

  Parameters
  ----------
  x, y : array_like
      Masks of identical shape; every non-zero element counts as foreground.

  Returns
  -------
  numpy.float32
      The Hausdorff distance in voxels. 0 if neither mask has a border, inf if
      only one of them has a border.
  """
  x = np.asarray(x).astype(bool)
  y = np.asarray(y).astype(bool)

  # foreground voxels that directly touch the background
  x_border = distance_transform_edt(x) == 1
  y_border = distance_transform_edt(y) == 1

  x_has_border = x_border.any()
  y_has_border = y_border.any()

  if not (x_has_border or y_has_border):
    return np.float32(0.)

  # there is nothing to measure a distance to when only one border exists
  if not (x_has_border and y_has_border):
    return np.float32(np.inf)

  # distance of every voxel to the nearest border voxel (zero on the border itself)
  to_x_border = distance_transform_edt(~x_border)
  to_y_border = distance_transform_edt(~y_border)

  # return the maximum of maximums
  return np.float32(max(to_x_border[y_border].max(),to_y_border[x_border].max()))

## Mean Boundary Distance Error

In [6]:
# BDE from x -> y
def bde(x,y):
  """Mean boundary distance error from the border of x to the border of y.

  The border of a mask is the set of foreground voxels whose Euclidean distance
  to the nearest background voxel is exactly 1 (the same definition hausdorff
  uses). For every border voxel of x the distance to the nearest border voxel of
  y is taken, and the mean of those distances is returned. The measure is
  directed: bde(x, y) and bde(y, x) generally differ.

  Parameters
  ----------
  x, y : array_like
      Masks of identical shape; every non-zero element counts as foreground.

  Returns
  -------
  numpy.float32
      Mean distance in voxels. NaN if x has no border voxel (nothing to average),
      inf if x has a border but y has none.
  """
  x = np.asarray(x).astype(bool)
  y = np.asarray(y).astype(bool)

  # foreground voxels that directly touch the background
  x_border = distance_transform_edt(x) == 1
  y_border = distance_transform_edt(y) == 1

  if not x_border.any():
    return np.float32(np.nan)

  if not y_border.any():
    return np.float32(np.inf)

  # distance of every voxel to the nearest border voxel of y (zero on that border)
  to_y_border = distance_transform_edt(~y_border)

  # average over the border voxels of x
  return np.float32(to_y_border[x_border].mean())

## Evaluation Print

In [14]:
def evaluate(seg,gt):
    """Print every evaluation metric for a segmentation against its ground truth.

    seg is the segmentation and gt the ground truth. The metrics are computed and
    printed one after the other, in this order: Jaccard, local Jaccard, Hausdorff,
    boundary distance error seg -> gt, boundary distance error gt -> seg.
    Nothing is returned.
    """
    # each metric is wrapped in a lambda so that it is only computed when its line is printed
    report = (
        ('Jaccard:', lambda: jaccard(seg,gt)),
        ('Local Jaccard (-ve,low,+ve):', lambda: local_jaccard(seg,gt)),
        ('Hausdorff:', lambda: hausdorff(seg,gt)),
        ('BDE,SEG->GT:', lambda: bde(seg,gt)),
        ('BDE,GT->SEG:', lambda: bde(gt,seg)),
    )

    for label, compute in report:
        print(label, compute())

# Phansalkar Seed Selection

In [249]:
def phansalkar(x,phansalkar_size=21,bg_d_r=20,bg_e_r=5,fg_d_r=4,fg_e_r=10,p=2):
  """Create foreground/background seeds from a Phansalkar local threshold.

  A voxel passes the threshold when
      x >= mean * (1 + p * exp(-q * mean) + k * (sd / r - 1))
  where mean and sd are the mean and standard deviation of the cubic window
  around the voxel, and q, r, k are derived from the statistics of the whole
  volume. The thresholded volume is then eroded/dilated into conservative seeds.

  Parameters
  ----------
  x : 3-D array_like
      The image volume; it is converted to float32. Every axis must be longer
      than phansalkar_size//2 because the volume is reflect-padded by that much.
  phansalkar_size : int
      Edge length of the window; even values are increased to the next odd one.
  bg_e_r, bg_d_r : number
      Erosion, then dilation radius (in voxels) applied to the thresholded
      foreground; whatever is still outside afterwards becomes background seed.
  fg_d_r, fg_e_r : number
      Dilation, then erosion radius (in voxels) applied to the thresholded
      foreground; whatever survives becomes foreground seed.
  p : number
      Weight of the exponential term of the threshold.

  Returns
  -------
  numpy.ndarray (float32, same shape as x)
      0 for background seeds, 1 for foreground seeds and 0.5 everywhere else.

  Side effects
  ------------
  Shows two matplotlib figures (the thresholded volume and the final seeds), each
  of slice 70 along the last axis, or of the last slice if the volume is thinner.

  NOTE(review): sd used to be sqrt((sum of deviations in the window)^2 / N), i.e.
  the square was taken after the summation, which is not a standard deviation.
  It is now the real local standard deviation, so the seeds differ from those of
  earlier runs -- re-check any radii that were tuned against the old output.
  """
  # set up variables to find the mean and standard deviation
  # (forces the window to an odd edge length so that it is centred on a voxel)
  phansalkar_r = phansalkar_size//2
  phansalkar_size = phansalkar_r*2 + 1

  # conv3d needs the image and the filter to share a dtype, hence the cast
  volume = tf.expand_dims(tf.expand_dims(tf.cast(x,tf.float32),0),-1)
  window_pad = [[0,0]] + [[phansalkar_r,phansalkar_r]]*3 + [[0,0]]
  padded = tf.pad(volume,window_pad,mode='REFLECT')
  box = tf.ones((phansalkar_size,phansalkar_size,phansalkar_size,1,1))
  n_window = tf.math.reduce_sum(box)

  # calculate the local mean of each pixel (using a convolution)
  mean = tf.nn.conv3d(padded, box, [1, 1, 1, 1, 1], padding='VALID')/n_window

  # calculate the local standard deviation of each pixel as sqrt(E[v^2] - E[v]^2);
  # the global mean is removed first (the variance does not change under a shift)
  # to limit float32 cancellation, and tiny negative values are clamped to zero
  offset = tf.math.reduce_mean(volume)
  mean_sq = tf.nn.conv3d(tf.square(padded - offset), box, [1, 1, 1, 1, 1], padding='VALID')/n_window
  sd = tf.sqrt(tf.maximum(mean_sq - tf.square(mean - offset), 0.))

  del padded
  del mean_sq

  # calculate the mean of means, sd of means and the sd max to find parameters q,k,r
  mean_means = tf.math.reduce_sum(mean)/tf.cast(tf.size(mean),tf.float32)
  sd_means = tf.sqrt(tf.math.reduce_sum(tf.square(mean-mean_means))/tf.cast(tf.size(mean),tf.float32))
  sd_max = tf.math.reduce_max(sd)

  # find q,r,k
  q = tf.math.log(2.)/(2*sd_means)
  r = sd_max/(1-tf.exp(-q*sd_means))
  k = p*(tf.exp(-q*(mean_means+sd_means))+(tf.exp(-q*(mean_means+(2*sd_means)))/(1-(sd_max/r))))/2

  # signed margin between the image and its local threshold, as a numpy volume
  margin = (volume - ( mean * (1 + p * tf.exp(-q * mean) + k * ((sd / r) - 1))))[0,:,:,:,0].numpy()

  del mean
  del sd

  # the segmentation from phansalkar (a NaN margin counts as background)
  seed = (margin >= 0).astype(np.float32)

  del margin

  def show(volume):
    # slice 70 of the last axis, or the last available slice of a thinner volume
    plt.imshow(volume[:,:,min(70,volume.shape[2]-1)])
    plt.show()

  show(seed)

  # perform erosion and dilation through thresholded distance transforms
  # for background seeds: erosion, then dilation
  bg_region = distance_transform_edt(seed) >= bg_e_r
  bg_region = distance_transform_edt(~bg_region) < bg_d_r

  # for foreground seeds: dilation, then erosion
  fg_core = distance_transform_edt(1. - seed) < fg_d_r
  fg_core = distance_transform_edt(fg_core) >= fg_e_r

  # 0.5 marks voxels that belong to neither seed
  seed = np.full_like(seed,0.5)
  seed[~bg_region] = 0
  seed[fg_core] = 1

  show(seed)

  return seed




# Example of Synthetic Data Creation

In [31]:
def make_cell(scale_factor,rng=None):
  """Build a synthetic cell (a sphere with one protrusion) and a noisy image of it.

  Parameters
  ----------
  scale_factor : int
      Positive integer; the volumes have 100*scale_factor voxels along each axis.
  rng : numpy.random.Generator or numpy.random.RandomState, optional
      Source of the Poisson noise. With the default (None) the global
      numpy.random state is used; pass a seeded generator for a reproducible image.

  Returns
  -------
  gt : numpy.ndarray (int8)
      The binary ground truth.
  image : numpy.ndarray (float64)
      The ground truth blurred with a Gaussian (sigma 1), scaled by 500, with
      Poisson noise (lambda 1) added and normalised so that its maximum is 1.

  Side effects
  ------------
  Shows two matplotlib figures with the three central orthogonal slices of gt
  and of image.
  """
  sampler = np.random if rng is None else rng
  centre = 50*scale_factor

  def show_central_slices(volume):
    # one panel per axis, each through the centre of the volume
    plt.subplot(1,3,1)
    plt.imshow(volume[centre,:,:])

    plt.subplot(1,3,2)
    plt.imshow(volume[:,centre,:])

    plt.subplot(1,3,3)
    plt.imshow(volume[:,:,centre])
    plt.show()

  # make a protrusion by resizing a sphere: every 4th slice along axis 0 ...
  a = np.pad(ball((scale_factor*50)-1)[0::4],((scale_factor*37,scale_factor*38),(1,0),(1,0)))

  protrusion = np.zeros((100*scale_factor,100*scale_factor,100*scale_factor))

  # ... and every 4th slice along axis 2, placed in the middle of that axis
  for i in range(25*scale_factor):
    protrusion[:,:,37*scale_factor+i] = a[:,:,4*i]

  # make the gt image with a sphere and a protrusion
  sphere_pad = (20*scale_factor, 20*scale_factor-1)
  sphere = np.pad(ball(30*scale_factor), (sphere_pad, sphere_pad, sphere_pad))
  gt = (np.logical_or(sphere, protrusion)).astype(np.int8)

  show_central_slices(gt)

  # create the synthetic image
  # blur the image; gaussian_filter returns the dtype of its input, so the int8
  # mask has to be converted to float first or the blurred values would be
  # truncated back to integers
  image = nd.gaussian_filter(gt.astype(np.float64),1)

  # rescale
  image = image*500

  # add poisson noise
  image += sampler.poisson(1,size=image.shape)

  # normalise
  image = image/np.max(image)

  show_central_slices(image)

  return gt, image