In [12]:
import tensorflow as tf
import numpy as np
from collections.abc import Iterable

In [22]:
def resampling_layer(resampling_type,
            filter_channels, 
            kernel_size, 
            resampling_factor = 2,
            name=None,
            initializer_mean = 0.0,
            initializer_std = 0.02,
            apply_batchnorm = True,
            apply_dropout = False,
            dropout_rate = 0.5):
    """ Build a 2D convolutional block that changes spatial resolution

    The block is a Keras Sequential made of, in order:
    (transpose) convolution -> optional batch normalization
    -> optional dropout -> activation.

    # Input parameters:
    resampling_type:    'downsample' (strided Conv2D followed by LeakyReLU) or
                        'upsample' (strided Conv2DTranspose followed by ReLU)
    filter_channels:    Number of filters / "depth" of the block output
    kernel_size:        Spatial size of the convolutional kernel
                        (e.g. 3 means each filter sees a 3x3 neighborhood)

    # Keyword arguments:
    resampling_factor:  Stride of the (transpose) convolution
    name:               Name given to the Sequential block
    initializer_mean:   Mean of the normal distribution used to initialize the kernel
    initializer_std:    Standard deviation of that distribution
    apply_batchnorm:    If true, add a BatchNormalization layer after the convolution;
                        the convolution is then created without a bias term
    apply_dropout:      If true, add a Dropout layer before the activation
    dropout_rate:       Rate passed to the Dropout layer

    # Outputs:
    tf.keras.Sequential containing the layers described above

    # Raises:
    ValueError if resampling_type is neither 'downsample' nor 'upsample'

    # Notes
    - Based on TF example pix2pix: https://www.tensorflow.org/tutorials/generative/pix2pix
    """
    # Reject unknown modes up front, before any layer object is created
    if resampling_type not in ('downsample', 'upsample'):
        raise ValueError(f"{resampling_type} is not a valid resampling type.")
    downsampling = resampling_type == 'downsample'

    # The two modes differ only in which convolution / activation class is used
    keras_layers = tf.keras.layers
    conv_class = keras_layers.Conv2D if downsampling else keras_layers.Conv2DTranspose
    activation_class = keras_layers.LeakyReLU if downsampling else keras_layers.ReLU

    # Kernel weights are drawn from a normal distribution
    kernel_init = tf.random_normal_initializer(
        mean=initializer_mean, stddev=initializer_std)

    # Assemble the block layer by layer
    block = tf.keras.Sequential(name=name)
    block.add(
        conv_class(
            filter_channels,
            kernel_size,
            strides=resampling_factor,
            padding='same',
            kernel_initializer=kernel_init,
            use_bias=not apply_batchnorm))   # bias is dropped whenever batchnorm follows

    if apply_batchnorm:
        block.add(keras_layers.BatchNormalization())

    if apply_dropout:
        block.add(keras_layers.Dropout(dropout_rate))

    block.add(activation_class())

    return block

In [45]:
def _expand_layer_flags(flags, n_layers, arg_name):
    """ Return `flags` as a list holding one entry per resampling layer

    A non-iterable value is repeated n_layers times. An iterable is turned
    into a list and must contain exactly n_layers entries, otherwise a
    ValueError naming `arg_name` is raised.
    """
    if not isinstance(flags, Iterable):
        return [flags] * n_layers

    flags = list(flags)
    if len(flags) != n_layers:
        raise ValueError(
            f"{arg_name} must be a scalar or contain {n_layers} values "
            f"(one per downsampling / upsampling layer), got {len(flags)}.")
    return flags


def unet(input_channels, output_channels, first_layer_channels, depth, 
         model_name=None, flip_aug=True, trans_aug=False, 
         apply_batchnorm = True, apply_dropout = False):
    """ Simple encoder-decoder U-Net architecture
    
    # Arguments:
    input_channels:         Number of channels in input image
    output_channels:        Number of classes (including background) to segment between
    first_layer_channels:   Number of channels in the initial (full resolution) convolution
                            Each consecutive downsampling layer doubles the number of channels
                            In upsampling, each layer halves the number of channels
    depth:                  Number of downsampling layers (the same number of 
                            upsampling layers is created)
    
    # Keyword arguments:
    model_name:         Name of model
    flip_aug:           If true, a RandomFlip augmentation layer is included
                        before the first downsampling layer
    trans_aug:          If true, a RandomTranslation augmentation layer with 
                        height and width factor of 20% is included
                        before the first downsampling layer
    apply_batchnorm:    If (boolean) scalar, indicate whether to use batch normalization
                        in all downsampling / upsampling layers
                        If sequence of booleans (length 2*depth: first the downsampling 
                        layers, then the upsampling layers), indicate use of batch 
                        normalization for each layer
    apply_dropout:      If (boolean) scalar, indicate whether to use dropout (rate 0.5)
                        in all downsampling / upsampling layers.
                        If sequence of booleans (length 2*depth: first the downsampling 
                        layers, then the upsampling layers), indicate use of dropout
                        for each layer
                        
    # Outputs:
    model:              Keras U-Net model
    
    # Raises:
    ValueError if apply_batchnorm or apply_dropout is a sequence whose length 
    differs from 2*depth
    
    # Notes:
    - Based on TF tutorial: https://www.tensorflow.org/tutorials/images/segmentation
    - Every resampling layer uses stride 2 with 'same' padding, so image height and 
      width should be divisible by 2**depth for the skip connections to line up
    - NOTE(review): the augmentation layers only transform the image fed through the 
      model; confirm that the training labels are handled consistently

    """
    resamp_kernel_size = 4
    n_resamp_layers = 2 * depth
    
    # One boolean per resampling layer (downsampling layers first, then upsampling).
    # Done before any layer is created so that bad arguments fail early.
    apply_batchnorm = _expand_layer_flags(apply_batchnorm, n_resamp_layers, 'apply_batchnorm')
    apply_dropout = _expand_layer_flags(apply_dropout, n_resamp_layers, 'apply_dropout')
    
    # Define input
    # None signals variable image height and width: (Ny, Nx, input_channels)
    inputs = tf.keras.layers.Input(shape=[None, None, input_channels],name='input_image')
    x = inputs    # x used as temporary variable for data flowing between layers
        
    # Add augmentation layer(s)
    if flip_aug or trans_aug:
        aug_layer = tf.keras.Sequential(name='augmentation')
        if flip_aug:
            aug_layer.add(tf.keras.layers.RandomFlip())
        if trans_aug:
            aug_layer.add(tf.keras.layers.RandomTranslation(height_factor=0.2,width_factor=0.2))
        x = aug_layer(x)

    # Add initial convolution layer with same resolution as input image
    x = tf.keras.layers.Conv2D(first_layer_channels,kernel_size=3,padding='same', name = 'initial_convolution')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)

    # Define downsampling layers
    # Layer i (0-based) has first_layer_channels * 2**(i+1) channels and is named 
    # after its cumulative downsampling factor 2**(i+1)
    down_stack = []
    nchannels_downsamp = [first_layer_channels*(2**(i+1)) for i in range(depth)]
    names_downsamp = [f'downsamp_factor_{2**(i+1)}' for i in range(depth)]
    for channels, name, batchnorm, dropout in zip(nchannels_downsamp,names_downsamp,apply_batchnorm[0:depth],apply_dropout[0:depth]):
        down_stack.append(resampling_layer('downsample',
                                           channels,
                                           resamp_kernel_size,
                                           name = name,
                                           apply_batchnorm=batchnorm,
                                           apply_dropout=dropout))

    # Define upsampling layers
    # Count i down from depth-1 to 0 (note the explicit step of -1; without it the 
    # range is empty): channels first_layer_channels * 2**i, name gives the 
    # downsampling factor 2**i remaining after the layer
    up_stack = []
    nchannels_upsamp = [first_layer_channels*(2**i) for i in range(depth-1,-1,-1)]
    names_upsamp = [f'upsamp_factor_{2**i}' for i in range(depth-1,-1,-1)]
    for channels, name, batchnorm, dropout in zip(nchannels_upsamp,names_upsamp,apply_batchnorm[depth:], apply_dropout[depth:]):
        up_stack.append(resampling_layer('upsample',
                                         channels,
                                         resamp_kernel_size,
                                         name = name,
                                         apply_batchnorm=batchnorm,
                                         apply_dropout=dropout))    

    # Downsampling through the model
    skips = [x]                   # Add output from first layer (before downsampling) to skips list
    for down in down_stack:
        x = down(x)               # Run input x through layer, then set x equal to output
        skips.append(x)           # Add layer output to skips list

    skips = reversed(skips[:-1])  # Reverse list, and don't include skip for last layer ("bottom of U") 

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)                                     # Run input x through layer, then set x to output
        x = tf.keras.layers.Concatenate()([x, skip])  # Stack layer output together with skip connection (downsampling layer output with same resolution)
    
    # Final layer: per-pixel softmax over the output_channels classes
    # (output_channels is the number of filters; 3 is the kernel size)
    last = tf.keras.layers.Conv2D(output_channels, 
                                  kernel_size = 3,
                                  padding='same',
                                  activation='softmax',
                                  name='classification')    
    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x,name=model_name)

In [73]:
# Step-by-step U-Net construction at notebook top level, built from the same
# pieces as unet(), so that intermediate variables (e.g. nchannels_upsamp, model)
# remain available for inspection in the following cells.

# Example configuration
input_channels = 10
output_channels = 4
first_layer_channels = 64
depth = 2
model_name='my_unet' 
flip_aug=True 
trans_aug=False 
apply_batchnorm = True 
apply_dropout = False

# Kernel size used by every downsampling / upsampling layer
resamp_kernel_size = 4

# Expand scalar batchnorm / dropout flags to one boolean per resampling layer
# (depth downsampling layers followed by depth upsampling layers)
if not isinstance(apply_batchnorm,Iterable):
    apply_batchnorm = [apply_batchnorm for _ in range(depth*2)]

if not isinstance(apply_dropout,Iterable):
    apply_dropout = [apply_dropout for _ in range(depth*2)]


# Define input
inputs = tf.keras.layers.Input(shape=[None, None, input_channels],name='input_image')   # Using None to signal variable image height and width (Ny,Nx,input_channels)
x = inputs    # x used as temporary variable for data flowing between layers

# Add augmentation layer(s)
if flip_aug or trans_aug:
    aug_layer = tf.keras.Sequential(name='augmentation')
    if flip_aug:
        aug_layer.add(tf.keras.layers.RandomFlip())
    if trans_aug:
        aug_layer.add(tf.keras.layers.RandomTranslation(height_factor=0.2,width_factor=0.2))
    x = aug_layer(x)

# Add initial convolution layer with same resolution as input image
# (ReLU is applied through the activation argument; no batch normalization here)
x = tf.keras.layers.Conv2D(first_layer_channels,kernel_size=3,padding='same', name = 'initial_convolution',activation='relu')(x)


# Define downsampling layers
# With the configuration above: channels [128, 256], names downsamp_factor_2 / downsamp_factor_4
down_stack = []
nchannels_downsamp = [first_layer_channels*(2**(i+1)) for i in range(depth)]
names_downsamp = [f'downsamp_factor_{(2**(i+1))}' for i in range(depth)]  
for channels, name, batchnorm, dropout in zip(nchannels_downsamp,names_downsamp,apply_batchnorm[0:depth],apply_dropout[0:depth]):
    down_stack.append(resampling_layer('downsample',
                                       channels,
                                       resamp_kernel_size,
                                       name = name,
                                       apply_batchnorm=batchnorm,
                                       apply_dropout=dropout))

# Define upsampling layers
# i counts down from depth-1 to 0; with the configuration above: channels [128, 64],
# names upsamp_factor_2 / upsamp_factor_1
up_stack = []
nchannels_upsamp = [first_layer_channels*(2**i) for i in range(depth-1,-1,-1)]
names_upsamp = [f'upsamp_factor_{2**i}' for i in range(depth-1,-1,-1)]   
for channels, name, batchnorm, dropout in zip(nchannels_upsamp,names_upsamp,apply_batchnorm[depth:], apply_dropout[depth:]):
    up_stack.append(resampling_layer('upsample',
                                     channels,
                                     resamp_kernel_size,
                                     name = name,
                                     apply_batchnorm=batchnorm,
                                     apply_dropout=dropout))    

# Downsampling through the model
skips = [x]                   # Add output from first layer (before downsampling) to skips list
for down in down_stack:
    x = down(x)               # Run input x through layer, then set x equal to output
    skips.append(x)           # Add layer output to skips list

# reversed() returns a one-shot iterator; it is consumed by the zip in the loop below
skips = reversed(skips[:-1])  # Reverse list, and don't include skip for last layer ("bottom of U") 

# Upsampling and establishing the skip connections
for up, skip in zip(up_stack, skips):
    x = up(x)                                     # Run input x through layer, then set x to output
    x = tf.keras.layers.Concatenate()([x, skip])  # Stack layer output together with skip connection (downsampling layer output with same resolution)

# Final layer: softmax over output_channels filters, 3x3 kernel
last = tf.keras.layers.Conv2D(output_channels, 
                              kernel_size = 3,
                              padding='same',
                              activation='softmax',
                              name='classification')    
x = last(x)

# Wrap the layer graph from inputs to the classification output in a Keras model
model =  tf.keras.Model(inputs=inputs, outputs=x,name=model_name)

In [68]:
# Inspect the channel counts computed for the upsampling layers
nchannels_upsamp

[128, 64]

In [69]:
# Scratch check of layer-name generation using exponent i+2
# (the model-building cell uses exponent i+1 for its downsampling names)
[f'downsamp_factor_{(2**(i+2))}' for i in range(depth)]   

['downsamp_factor_4', 'downsamp_factor_8']

In [74]:
# Print the layer-by-layer summary (output shapes, parameter counts, connections)
model.summary()

Model: "my_unet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_image (InputLayer)       [(None, None, None,  0           []                               
                                 10)]                                                             
                                                                                                  
 augmentation (Sequential)      (None, None, None,   0           ['input_image[0][0]']            
                                10)                                                               
                                                                                                  
 initial_convolution (Conv2D)   (None, None, None,   5824        ['augmentation[0][0]']           
                                64)                                                         