<a href="https://colab.research.google.com/github/lisatwyw/unet_variants/blob/main/tf_U_Net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [296]:
import tensorflow as tf
import numpy as np

from tensorflow.keras.layers import concatenate,add,Dense,BatchNormalization,Concatenate,Input,Dropout,Maximum,Activation,Dense,Flatten,Add,Multiply,Lambda
from tensorflow.keras.layers import MaxPooling3D,Conv3D,UpSampling3D,Conv3DTranspose
from tensorflow.keras.layers import MaxPooling2D,Conv2D,UpSampling2D,Conv2DTranspose
from tensorflow.keras.models import Model

# trivial test


o = Conv2D( 8, (1,1) )( Input( (224, 224, 3) ) )

# Shared components

In [297]:
def stack_bn_act( x, NF, KS, strides=(1,1) ):
  for i, k in enumerate( KS ):
    x = Conv2D( NF, k, padding='same', strides=strides )(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
  return x

def conv_block( x, NF, KS, strides=(1,1) ):

  o = Conv2D( NF, kernel_size= KS[0], strides=strides, padding='same' )(x)
  o = BatchNormalization()(o)
  o = Activation( 'relu' )(o)

  o = stack_bn_act(o, NF, KS[1:])
  return o

def deconv_block( x, NF, ks=2, strides=(1,1) ):
  o = Conv2DTranspose( NF, ks, strides=strides, padding='same' )(x)
  o = BatchNormalization()(o)
  o = Activation( 'relu' )(o)
  return o

# Admin variables and hyperparameters

In [298]:
AC = 'sigmoid'
BS = 8
NF = 16
NDIM=2
ks =(3)

ks_s=[ks]*len(nfilters)

vol = np.zeros( (65,224*3,224*3) )

IR=3
NZ,NX,NY = vol[:,::IR,::IR].shape

n_slices = 3
inp = Input( (NX, NY, n_slices) )

# Basic U-Net

In [299]:


o1 = conv_block( inp, NF,    ks_s ); p1 = MaxPooling2D()( o1 )
o2 = conv_block( p1,  NF*2,  ks_s ); p2 = MaxPooling2D()( o2 )
o3 = conv_block( p2,  NF*4,  ks_s ); p3 = MaxPooling2D()( o3 )
o4 = conv_block( p3,  NF*8,  ks_s ); p4 = MaxPooling2D()( o4 )
o5 = conv_block( p4,  NF*16, ks_s ); p5 = MaxPooling2D()( o5 )

o6 = conv_block( Concatenate()( [deconv_block( o5, NF*8, strides=(2,2) ), o4]), NF*8, ks_s )
o7 = conv_block( Concatenate()( [deconv_block( o6, NF*4, strides=(2,2) ), o3]), NF*4, ks_s )
o8 = conv_block( Concatenate()( [deconv_block( o7, NF*2, strides=(2,2) ), o2]), NF*2, ks_s )
o9 = conv_block( Concatenate()( [deconv_block( o8, NF*1, strides=(2,2) ), o1]), NF*1, ks_s )

out = Activation( AC )( Conv2D( 2, 1 )( o9 ) )
model = Model(inp, outputs=out )

In [300]:
model.summary()

Model: "model_35"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_114 (InputLayer)         [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_3393 (Conv2D)           (None, 224, 224, 16  448         ['input_114[0][0]']              
                                )                                                                 
                                                                                                  
 batch_normalization_3533 (Batc  (None, 224, 224, 16  64         ['conv2d_3393[0][0]']            
 hNormalization)                )                                                          

# Attention U-Net

In [301]:
def attention_block( input_block, gate, ks=(1,1) ):
  x = Conv2D( NF, ks )(input_block)
  x = BatchNormalization()(x)

  g = Conv2D( NF, ks )(gate)
  g = BatchNormalization()(g)

  att_map = Add()( [g, x] )
  att_map = Activation( 'relu' )(att_map)

  att_map = Conv2D( 1, ks )(att_map)
  att_map = Activation( 'sigmoid')(att_map)
  x = Multiply()( [input_block, att_map ] )

  return x

o1 = conv_block( inp, NF,    ks_s ); p1 = MaxPooling2D()( o1 )
o2 = conv_block( p1,  NF*2,  ks_s ); p2 = MaxPooling2D()( o2 )
o3 = conv_block( p2,  NF*4,  ks_s ); p3 = MaxPooling2D()( o3 )
o4 = conv_block( p3,  NF*8,  ks_s ); p4 = MaxPooling2D()( o4 )
o5 = conv_block( p4,  NF*16, ks_s ); p5 = MaxPooling2D()( o5 )

c6=attention_block( deconv_block( o5, NF*8, strides=(2,2) ), o4)
c7=attention_block( deconv_block( o6, NF*4, strides=(2,2) ), o3)
c8=attention_block( deconv_block( o7, NF*2, strides=(2,2) ), o2)
c9=attention_block( deconv_block( o8, NF*1, strides=(2,2) ), o1)

o6 = conv_block( Concatenate()( [c6, o4] ), NF*8, ks_s )
o7 = conv_block( Concatenate()( [c7, o3] ), NF*4, ks_s )
o8 = conv_block( Concatenate()( [c8, o2] ), NF*2, ks_s )
o9 = conv_block( Concatenate()( [c9, o1] ), NF*1, ks_s )

out = Activation( AC )( Conv2D( 3, 1 )( o9 ) )
att_model = Model(inp, outputs=out )


In [302]:
att_model.summary()

Model: "model_36"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_114 (InputLayer)         [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_3393 (Conv2D)           (None, 224, 224, 16  448         ['input_114[0][0]']              
                                )                                                                 
                                                                                                  
 batch_normalization_3533 (Batc  (None, 224, 224, 16  64         ['conv2d_3393[0][0]']            
 hNormalization)                )                                                          

# [Not finished] W-Net

In [303]:
def wnet_encoder( inp, NB=5, DO=.2 ):

  for i in range(NB):
    if i==0:
      x = Conv2D( NF, ks )(inp)
    else:
      x = Conv2D( NF, ks )(x)
    x = BatchNormalization()(x)
    x = Conv2D( NF, ks )(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D()(x)

    if i== 2:
      x = Dropout(DO)(x)
  x = Conv2D( NF, ks )(x)

  return x


o= wnet_encoder( inp, )
out = Activation( AC )( Conv2D( 3, ks )( o9 ) )
wnet_model = Model( inp, out)

wnet_model.summary()

Model: "model_37"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_114 (InputLayer)         [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_3393 (Conv2D)           (None, 224, 224, 16  448         ['input_114[0][0]']              
                                )                                                                 
                                                                                                  
 batch_normalization_3533 (Batc  (None, 224, 224, 16  64         ['conv2d_3393[0][0]']            
 hNormalization)                )                                                          

 # Residual U-Net

In [304]:
def residual_block( x, NF, KS, strides=(1,1) ):

  x = Conv2D( NF, KS[0], padding='same', strides=strides )(x)
  x = BatchNormalization()(x)
  x = Activation( 'relu' )(x)

  o = stack_bn_act( x, NF, KS[1:], strides=(1,1) )    # one less
  x = Add()( [o, x] )
  return x

NC=2
ks_s = [3]*NC
o1 = residual_block( inp, NF*1,  ks_s ); p1= MaxPooling2D( (2,2) )( o1 )
o2 = residual_block(  p1, NF*2,  ks_s ); p2= MaxPooling2D( (2,2) )( o2 )
o3 = residual_block(  p2, NF*4,  ks_s ); p3= MaxPooling2D( (2,2) )( o3 )
o4 = residual_block(  p3, NF*8,  ks_s ); p4= MaxPooling2D( (2,2) )( o4 )
o5 = residual_block(  p4, NF*16, ks_s );

o6 = concatenate( [deconv_block( o5, NF*8, 2, strides=(2,2) ), o4]); o6 = residual_block( o6, NF*8, ks_s );
o7 = concatenate( [deconv_block( o6, NF*4, 2, strides=(2,2) ), o3]); o7 = residual_block( o7, NF*4, ks_s );
o8 = concatenate( [deconv_block( o7, NF*2, 2, strides=(2,2) ), o2]); o8 = residual_block( o8, NF*2, ks_s );
o9 = concatenate( [deconv_block( o8, NF*1, 2, strides=(2,2) ), o1]); o9 = residual_block( o9, NF*1, ks_s );


out = Activation( AC )( Conv2D( n_slices, 1 )( o9 ) )
model = Model(inp, outputs=out )


In [305]:
model.summary()

Model: "model_38"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_114 (InputLayer)         [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_3509 (Conv2D)           (None, 224, 224, 16  448         ['input_114[0][0]']              
                                )                                                                 
                                                                                                  
 batch_normalization_3649 (Batc  (None, 224, 224, 16  64         ['conv2d_3509[0][0]']            
 hNormalization)                )                                                          

# [to be cont'd] SE U-Net

In [307]:

def SE_block( input_block ):
    x = GlobalAveragePoolingND()(input_block) # squeeze-step
    x = Dense( n_outputs//2, activation='relu')(x)
    x = Dense( n_outputs, activation='sigmoid')(x)
    x = Multiply()( [input_block, x ] )   # excite-step
    return x
