# Kymonet notebook
This notebook implements the training of a UNET on 29 labelled kymograph images. The input images do not have the same size and are therfeore padded to match the network input size. 

## Batch generator for padding the images

In [1]:
# import librairies
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import img_to_array,load_img
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2,preprocess_input
import matplotlib.pyplot as plt
import numpy as np
import numpy as np
from cv2 import resize
from os import path, listdir

In [2]:
#data generator class; yields batches of data for training/testing
class ImageGenerator():

    def __init__(self, raw_directory, binary_directory, batch_size=16, shuffle=False, max_dimension=None):        
        
        self.directories = raw_directory
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.max_dimension = max_dimension
        
        self.image_paths = []
        self.class_labels = []
        
        
        #create list of image file paths and class target labels
        for id, imgpath in enumerate(listdir(raw_directory)):
            
            self.image_paths.append(path.join(raw_directory,imgpath))
            
        self.image_paths = np.array(self.image_paths)

        # todo get the masks and store them in class labels
        #create list of target labels
        for id, imgpath in enumerate(listdir(binary_directory)):
           
            self.class_labels.append(path.join(binary_directory,imgpath))
            
        self.class_labels= np.array(self.class_labels)
        
        #index array for shuffling data
        self.idx = np.arange(len(self.image_paths))
        
    
    def __len__(self):
        
        #number of batches in an epoch
        return int(np.ceil(len(self.image_paths)/float(self.batch_size)))
    
    
    def _load_image(self,img_path):
        
        #load image from path and convert to array
        img = load_img(img_path, color_mode='rgb', interpolation='nearest')
      
        img = img_to_array(img)
        
        #downsample image if above allowed size if specified
        max_dim = max(img.shape) 
        if self.max_dimension:
            if max_dim > self.max_dimension:
                new_dim = tuple(d*self.max_dimension//max_dim for d in img.shape[1::-1])
                img = resize(img, new_dim)
            
        #scale image values
        img = preprocess_input(img)

        return img
    
    
    def _pad_images(self,img,shape):
        #pad images to match largest image in batch
        img = np.pad(img,(*[((shape[i]-img.shape[i])//2,
                    ((shape[i]-img.shape[i])//2) + ((shape[i]-img.shape[i])%2)) for i in range(2)],
                          (0,0)),mode='constant',constant_values=0.)
        
        return img


    def __call__(self):
        #shuffle index
        if self.shuffle:
            np.random.shuffle(self.idx)
        
        #generate batches
        for batch in range(len(self)):

            batch_image_paths = self.image_paths[self.idx[batch*self.batch_size:(batch+1)*self.batch_size]]
            batch_class_labels = self.class_labels[self.idx[batch*self.batch_size:(batch+1)*self.batch_size]]

            batch_images = [self._load_image(image_path) for image_path in batch_image_paths]
            
            max_resolution = tuple(max([img.shape[i] for img in batch_images]) for i in range(2))
            batch_images = np.array([self._pad_images(image,max_resolution) for image in batch_images])
            print(batch_class_labels)
            yield [tf.convert_to_tensor(batch_images), tf.convert_to_tensor(batch_class_labels)]

## Build of the UNET network

In [3]:
def double_conv_block(x, n_filters):

    # Conv2D then ReLU activation
    x = layers.Conv2D(n_filters, 3, padding = "same", activation = "relu", kernel_initializer = "he_normal")(x)
    # Conv2D then ReLU activation
    x = layers.Conv2D(n_filters, 3, padding = "same", activation = "relu", kernel_initializer = "he_normal")(x)

    return x

def downsample_block(x, n_filters):
    f = double_conv_block(x, n_filters)
    p = layers.MaxPool2D(2)(f)
    p = layers.Dropout(0.3)(p)

    return f, p

def upsample_block(x, conv_features, n_filters):
    # upsample
    x = layers.Conv2DTranspose(n_filters, 3, strides=2, padding="same")(x)
    
    # Crop conv_features to match the shape of x
    target_height, target_width = x.shape[1], x.shape[2]
    conv_features = layers.Cropping2D(cropping=((0, conv_features.shape[1] - target_height),
                                                 (0, conv_features.shape[2] - target_width)))(conv_features)
    
    # concatenate
    x = layers.concatenate([x, conv_features])
    
    # dropout
    x = layers.Dropout(0.3)(x)
    
    # Conv2D twice with ReLU activation
    x = double_conv_block(x, n_filters)
    return x



In [4]:
def build_unet_model():

    # inputs
    inputs = layers.Input(shape=(700,700,1))

    # encoder: contracting path - downsample
    # 1 - downsample
    f1, p1 = downsample_block(inputs, 64)
    # 2 - downsample
    f2, p2 = downsample_block(p1, 128)
    # 3 - downsample
    f3, p3 = downsample_block(p2, 256)
    # 4 - downsample
    f4, p4 = downsample_block(p3, 512)

    # 5 - bottleneck
    bottleneck = double_conv_block(p4, 1024)

    # decoder: expanding path - upsample
    # 6 - upsample
    u6 = upsample_block(bottleneck, f4, 512)
    # 7 - upsample
    u7 = upsample_block(u6, f3, 256)
    # 8 - upsample
    u8 = upsample_block(u7, f2, 128)
    # 9 - upsample
    u9 = upsample_block(u8, f1, 64)

    # outputs
    outputs = layers.Conv2D(3, 1, padding="same", activation = "softmax")(u9)

    # unet model with Keras Functional API
    unet_model = tf.keras.Model(inputs, outputs, name="U-Net")

    return unet_model

In [5]:
unet_model = build_unet_model()

2023-11-22 13:29:33.757979: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Pro
2023-11-22 13:29:33.758005: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2023-11-22 13:29:33.758009: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2023-11-22 13:29:33.758254: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-11-22 13:29:33.758277: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [6]:
unet_model.summary()

Model: "U-Net"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 700, 700, 1)]        0         []                            
                                                                                                  
 conv2d (Conv2D)             (None, 700, 700, 64)         640       ['input_1[0][0]']             
                                                                                                  
 conv2d_1 (Conv2D)           (None, 700, 700, 64)         36928     ['conv2d[0][0]']              
                                                                                                  
 max_pooling2d (MaxPooling2  (None, 350, 350, 64)         0         ['conv2d_1[0][0]']            
 D)                                                                                           

In [7]:
unet_model.compile(optimizer=tf.keras.optimizers.legacy.Adam(),
                   loss="sparse_categorical_crossentropy",
                   metrics="accuracy")

# Training of the model

In [8]:
def display(display_list):
  plt.figure(figsize=(15, 15))

  title = ["Input Image", "True Mask", "Predicted Mask"]

  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.utils.array_to_img(display_list[i]))
    plt.axis("off")
  plt.show()

In [9]:
#initialize our generators; specifying data directories, batch size, and dimension threshold

train_image_directory = 'Raw kymos/kymographs2/raw images'
test_image_directory = 'Raw kymos/kymographs2/normal/binary_masks'

n_classes = 10
batch_size = 16
max_dimension = 700

#create generators for training and generating

train_generator = ImageGenerator(train_image_directory,test_image_directory, batch_size=batch_size, shuffle=False, max_dimension=max_dimension)
#test_generator = ImageGenerator(test_image_directory, batch_size=batch_size, max_dimension=max_dimension)

#convert generators into tf.data.Dataset objects for optimization with keras model fit method

train_dataset = tf.data.Dataset.from_generator(train_generator,(tf.float32, tf.float32),(tf.TensorShape([None, 700, 700, 1]), tf.TensorShape([None, 700, 700, 1])))
print(train_dataset)
#test_dataset = tf.data.Dataset.from_generator(test_generator,(tf.float32, tf.int32),(tf.TensorShape([None, 700, 700, 1]), tf.TensorShape([None])))


<_FlatMapDataset element_spec=(TensorSpec(shape=(None, 700, 700, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None, 700, 700, 1), dtype=tf.float32, name=None))>


In [10]:
sample_batch = next(iter(train_dataset))
random_index = np.random.choice(sample_batch[0].shape[0])
sample_image, sample_mask = sample_batch[0][random_index], sample_batch[1][random_index]
display([sample_image, sample_mask])

['Raw kymos/kymographs2/normal/binary_masks/WT_C_15_denoised_cropped_kymo.tiff'
 'Raw kymos/kymographs2/normal/binary_masks/WT_C_08_raw_cropped_kymo.tiff'
 'Raw kymos/kymographs2/normal/binary_masks/WT_C_03_denoised_cropped_kymo.tiff'
 'Raw kymos/kymographs2/normal/binary_masks/.DS_Store'
 'Raw kymos/kymographs2/normal/binary_masks/WT_C_14_raw_cropped_kymo.tiff'
 'Raw kymos/kymographs2/normal/binary_masks/WT_C_08_denoised_cropped_kymo.tiff'
 'Raw kymos/kymographs2/normal/binary_masks/WT_C_02_raw_cropped_kymo.tiff'
 'Raw kymos/kymographs2/normal/binary_masks/WT_C_04_denoised_cropped_kymo.tiff'
 'Raw kymos/kymographs2/normal/binary_masks/WT_C_12_denoised_cropped_kymo.tiff'
 'Raw kymos/kymographs2/normal/binary_masks/WT_C_12_raw_cropped_kymo.tiff'
 'Raw kymos/kymographs2/normal/binary_masks/WT_C_04_raw_cropped_kymo.tiff'
 'Raw kymos/kymographs2/normal/binary_masks/WT_C_09_denoised_cropped_kymo.tiff'
 'Raw kymos/kymographs2/normal/binary_masks/WT_C_05_denoised_cropped_kymo.tiff'
 'Raw kymo

2023-11-22 13:29:34.331126: W tensorflow/core/framework/op_kernel.cc:1827] INVALID_ARGUMENT: TypeError: `generator` yielded an element that did not match the expected structure. The expected structure was (tf.float32, tf.float32), but the yielded element was [<tf.Tensor: shape=(16, 299, 700, 3), dtype=float32, numpy=
array([[[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        ...,

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
     

InvalidArgumentError: {{function_node __wrapped__IteratorGetNext_output_types_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} TypeError: `generator` yielded an element that did not match the expected structure. The expected structure was (tf.float32, tf.float32), but the yielded element was [<tf.Tensor: shape=(16, 299, 700, 3), dtype=float32, numpy=
array([[[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        ...,

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]],


       [[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        ...,

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]],


       [[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        ...,

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]],


       ...,


       [[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        ...,

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]],


       [[[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.],
         ...,
         [1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.],
         ...,
         [1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.],
         ...,
         [1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.]],

        ...,

        [[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.],
         ...,
         [1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.],
         ...,
         [1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.],
         ...,
         [1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.]]],


       [[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        ...,

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]]], dtype=float32)>, <tf.Tensor: shape=(16,), dtype=string, numpy=
array([b'Raw kymos/kymographs2/normal/binary_masks/WT_C_15_denoised_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_08_raw_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_03_denoised_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/.DS_Store',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_14_raw_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_08_denoised_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_02_raw_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_04_denoised_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_12_denoised_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_12_raw_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_04_raw_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_09_denoised_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_05_denoised_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_13_denoised_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_10_raw_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_06_raw_cropped_kymo.tiff'],
      dtype=object)>].
Traceback (most recent call last):

  File "/Users/quillan/anaconda3/envs/lab/lib/python3.11/site-packages/tensorflow/python/data/ops/from_generator_op.py", line 204, in generator_py_func
    flattened_values = nest.flatten_up_to(output_types, values)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/Users/quillan/anaconda3/envs/lab/lib/python3.11/site-packages/tensorflow/python/data/util/nest.py", line 237, in flatten_up_to
    return nest_util.flatten_up_to(
           ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/Users/quillan/anaconda3/envs/lab/lib/python3.11/site-packages/tensorflow/python/util/nest_util.py", line 1644, in flatten_up_to
    return _tf_data_flatten_up_to(shallow_tree, input_tree)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/Users/quillan/anaconda3/envs/lab/lib/python3.11/site-packages/tensorflow/python/util/nest_util.py", line 1673, in _tf_data_flatten_up_to
    _tf_data_assert_shallow_structure(shallow_tree, input_tree)

  File "/Users/quillan/anaconda3/envs/lab/lib/python3.11/site-packages/tensorflow/python/util/nest_util.py", line 1517, in _tf_data_assert_shallow_structure
    raise TypeError(

TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: 'list'.


The above exception was the direct cause of the following exception:


Traceback (most recent call last):

  File "/Users/quillan/anaconda3/envs/lab/lib/python3.11/site-packages/tensorflow/python/ops/script_ops.py", line 270, in __call__
    ret = func(*args)
          ^^^^^^^^^^^

  File "/Users/quillan/anaconda3/envs/lab/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^

  File "/Users/quillan/anaconda3/envs/lab/lib/python3.11/site-packages/tensorflow/python/data/ops/from_generator_op.py", line 206, in generator_py_func
    raise TypeError(

TypeError: `generator` yielded an element that did not match the expected structure. The expected structure was (tf.float32, tf.float32), but the yielded element was [<tf.Tensor: shape=(16, 299, 700, 3), dtype=float32, numpy=
array([[[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        ...,

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]],


       [[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        ...,

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]],


       [[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        ...,

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]],


       ...,


       [[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        ...,

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]],


       [[[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.],
         ...,
         [1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.],
         ...,
         [1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.],
         ...,
         [1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.]],

        ...,

        [[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.],
         ...,
         [1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.],
         ...,
         [1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.],
         ...,
         [1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.]]],


       [[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        ...,

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]]], dtype=float32)>, <tf.Tensor: shape=(16,), dtype=string, numpy=
array([b'Raw kymos/kymographs2/normal/binary_masks/WT_C_15_denoised_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_08_raw_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_03_denoised_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/.DS_Store',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_14_raw_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_08_denoised_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_02_raw_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_04_denoised_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_12_denoised_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_12_raw_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_04_raw_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_09_denoised_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_05_denoised_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_13_denoised_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_10_raw_cropped_kymo.tiff',
       b'Raw kymos/kymographs2/normal/binary_masks/WT_C_06_raw_cropped_kymo.tiff'],
      dtype=object)>].


	 [[{{node PyFunc}}]] [Op:IteratorGetNext] name: 

In [None]:
#train and evaluate model
model_history = unet_model.fit(train_dataset,epochs=10,verbose=1,workers=2,max_queue_size=20)

# References
- https://colab.research.google.com/github/margaretmz/image-segmentation/blob/main/unet_pet_segmentation.ipynb#scrollTo=_L_TF4djF8FY
- https://medium.com/mindboard/image-classification-with-variable-input-resolution-in-keras-cbfbe576126f