In [1]:
import numpy as np
import os
import PIL
import PIL.Image as Image
import pathlib
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
import sys

# In Weizmann dataset, there are totally 100 images. 
# Image width is fixed at 300 but image height varies from 125 to 465.
# Here we will resize all images to a fixed size 
img_width = 64
img_height = 64
batch_size = 8

data_dir = pathlib.Path('/home/cly/python_projects/image_segmentation/1obj')
filenames = list(data_dir.glob('*/src_bw/*'))
image_count = len(filenames)

# Each tuple consists of the file path of an image and its three segmentation references. 
def file_paths(f):
  segs = list(f.parents[1].glob('human_seg/*.png'))
  return (str(f), str(segs[0]), str(segs[1]), str(segs[2]))

filename_tuples = [file_paths(f) for f in filenames]

list_ds = tf.data.Dataset.from_tensor_slices(filename_tuples)

for f in list_ds.take(2):
  print(f.numpy())

# split test, validation, and training data
val_size = int(image_count * 0.3)
test_size = int(val_size * 0.5)

train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)
val_ds = val_ds.skip(test_size)
test_ds = val_ds.take(val_size - test_size)

train_ds = train_ds.shuffle(image_count, reshuffle_each_iteration=False)

print(tf.data.experimental.cardinality(train_ds).numpy())
print(tf.data.experimental.cardinality(val_ds).numpy())

def initializeState(img_height, img_width):
  center1 = int(img_height/2)
  center2 = int(img_width/2)
  r = min(img_height, img_width)/3
  r2 = r**2
  a = np.zeros((img_height, img_width, 1))
  for i in range(img_height):
    for j in range(img_width):
      a[i, j, 0] = float((i-center1)**2 + (j-center2)**2<r2)
  return a  
    
def decode_img(img):
  # convert the compressed string to a 3D uint8 tensor of shape [height, width, channels]
  img = tf.image.decode_png(img, channels=1)
  # resize the image to the desired size
  return tf.image.resize(img, [img_height, img_width], method='nearest')

def decode_seg(seg):
  # segmentation is an RGB image with object marked with red color (255, 0, 0)
  seg = tf.image.decode_png(seg, channels=3)
  # convert the image to a mask
  mask = tf.cast(seg[:,:,0:1]==255, dtype=tf.float32)*tf.cast(seg[:,:,1:2]==0, dtype=tf.float32)*tf.cast(seg[:,:,2:]==0, dtype=tf.float32)
  return tf.image.resize(mask, [img_height, img_width], method='nearest')

def process_path(paths):
  img_path = paths[0]
  seg_path = paths[1]
  if tf.random.uniform(())>0.66:
    seg_path = paths[2]
  elif tf.random.uniform(())>0.5:
    seg_path = paths[3]
  # load the raw data from the file as a string
  img_string = tf.io.read_file(img_path)
  seg_string = tf.io.read_file(seg_path)
  img = decode_img(img_string)
  mask = decode_seg(seg_string)
  img = tf.image.convert_image_dtype(img, tf.float32)
  return img, mask

def augment(paths):
  img, mask = process_path(paths)
  # Cast and normalize the image to [0,1]
  img = tf.image.random_brightness(img, max_delta=0.5) # Random brightness
  img = tf.image.random_contrast(img, 0.2, 0.5) # Random contrast
  
  # Filp the image and the mask together
  # It seems there is no random rotation in tensorflow?
  if tf.random.uniform(()) > 0.5:
    img = tf.image.flip_up_down(img)
    mask = tf.image.flip_up_down(mask)
  if tf.random.uniform(()) > 0.5:
    img = tf.image.flip_left_right(img)
    mask = tf.image.flip_left_right(mask)
  return img, mask

AUTOTUNE = tf.data.experimental.AUTOTUNE

train_ds = train_ds.map(augment, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)

for image, mask in train_ds.take(1):
  print("Image shape: ", image.numpy().shape)
  print("Mask shape: ", mask.numpy().shape)
  print("Image", np.max(image.numpy()))
  print(np.min(image.numpy()))

def configure_for_performance(ds):
  ds = ds.cache()
  ds = ds.shuffle(buffer_size=1000)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(buffer_size=AUTOTUNE)
  return ds


train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)

# Visualize data
image_batch, mask_batch = next(iter(train_ds))

plt.figure(figsize=(10, 10))
for i in range(4):
  ax = plt.subplot(4, 2, 2*i + 1)
  plt.imshow(image_batch[i][:,:,0].numpy(), cmap='gray')
  plt.axis("off")
  mask = mask_batch[i][:,:,0].numpy().astype("uint8")
  ax = plt.subplot(4, 2, 2*i+2)
  plt.imshow(mask, cmap='gray')
  plt.axis("off")


[b'/home/cly/python_projects/image_segmentation/1obj/114591144943/src_bw/114591144943.png'
 b'/home/cly/python_projects/image_segmentation/1obj/114591144943/human_seg/114591144943_11.png'
 b'/home/cly/python_projects/image_segmentation/1obj/114591144943/human_seg/114591144943_10.png'
 b'/home/cly/python_projects/image_segmentation/1obj/114591144943/human_seg/114591144943_12.png']
[b'/home/cly/python_projects/image_segmentation/1obj/bbmf_lancaster_july_06/src_bw/bbmf_lancaster_july_06.png'
 b'/home/cly/python_projects/image_segmentation/1obj/bbmf_lancaster_july_06/human_seg/bbmf_lancaster_july_06_14.png'
 b'/home/cly/python_projects/image_segmentation/1obj/bbmf_lancaster_july_06/human_seg/bbmf_lancaster_july_06_13.png'
 b'/home/cly/python_projects/image_segmentation/1obj/bbmf_lancaster_july_06/human_seg/bbmf_lancaster_july_06_15.png']
70
15
Image shape:  (64, 64, 1)
Mask shape:  (64, 64, 1)
Image 0.7048426
0.490635


In [6]:
class fidelityTerm(keras.layers.Layer):
  def __init__(self, img_size, threshold=0.5):
    super(fidelityTerm, self).__init__()
    self.Ug = self.add_weight(shape=(img_size, img_size), initializer=tf.random_normal_initializer(mean=1), trainable=True)
    self.Wg = self.add_weight(shape=(img_size, img_size), initializer=tf.random_normal_initializer(mean=1), trainable=True)
    self.flatten = keras.layers.Flatten()
    self.threshold = threshold
    
  def call(self, inputs):
    img = inputs[0]
    state = inputs[1]
    # c1 = tf.stop_gradient(tf.divide(tf.reduce_sum(img*state, axis=1), tf.reduce_sum(state, axis=1)))
    # c2 = tf.stop_gradient(tf.divide(tf.reduce_sum(img*(1-state), axis=1), tf.reduce_sum(1-state, axis=1)))
    c1 = tf.stop_gradient(tf.divide(tf.reduce_sum(tf.cast(state>self.threshold, tf.float32)*img, axis=1), tf.reduce_sum(tf.cast(state>0.5, tf.float32), axis=1)+sys.float_info.epsilon))
    c2 = tf.stop_gradient(tf.divide(tf.reduce_sum(tf.cast(state<self.threshold, tf.float32)*img, axis=1), tf.reduce_sum(tf.cast(state<0.5, tf.float32), axis=1)+sys.float_info.epsilon))
    c1 = self.flatten(c1)
    c2 = self.flatten(c2)
    return tf.matmul(tf.math.squared_difference(img, c1), self.Ug) - tf.matmul(tf.math.squared_difference(img, c2), self.Wg)

class myModelBlock(keras.layers.Layer):
  def __init__(self, img_height, img_width):
    super(myModelBlock, self).__init__()
    self.fidelity = fidelityTerm(img_height*img_width)
    self.reshape2D = keras.layers.Reshape((img_height, img_width, 1))
    self.conv2D = keras.layers.Conv2D(1, 3, padding='same')
    self.flatten = keras.layers.Flatten()
    self.dense = keras.layers.Dense(img_height*img_width)
    self.activation = keras.layers.Activation(activation=tf.keras.activations.tanh)
  
  def call(self, inputs):
    state = inputs[1]
    x = self.fidelity(inputs)
    y = self.reshape2D(state)
    y = self.conv2D(1-2*y)
    y = self.flatten(y)
    y = self.dense(-x-y)
    return [inputs[0], 0.5+0.5*self.activation(y)]

class myModel(keras.Model):
  def __init__(self, img_height, img_width, state0):
    super(myModel, self).__init__()
    self.state0 = state0
    self.flatten = keras.layers.Flatten()
    self.block1 = myModelBlock(img_height, img_width)
    self.block2 = myModelBlock(img_height, img_width)
    self.block3 = myModelBlock(img_height, img_width)
    self.reshape = keras.layers.Reshape((img_height, img_width, 1))
    
  def call(self, inputs):
    x = self.flatten(inputs)
    x = self.block1([x, tf.reshape(state0, [1, -1])])
    x = self.block2(x)
    x = self.block3(x)[1]
    return self.reshape(x)
    
    

In [None]:
class fidelityTerm_Heaviside(keras.layers.Layer):
  def __init__(self):
    
class levelsetBlock(keras.layers.Layer):
  def __init__(self):
    super(levelsetBlock, self).__init__()
  def call(self, inputs):
    

In [7]:
state0 = initializeState(img_height, img_width)
model = myModel(img_height, img_width, state0)
optimizer = keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer, loss=keras.losses.BinaryCrossentropy())
model.fit(train_ds, validation_data=val_ds, epochs=5)

Epoch 1/5


To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7fc2286a22e8>