U^2-Net (U²-Net) building blocks in TensorFlow/Keras — https://github.com/NathanUA/U-2-Net/blob/master/model/u2net.py

Based on the reference implementation at https://github.com/NathanUA/U-2-Net/blob/master/model/u2net.py

In [None]:
import tensorflow as tf

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    BatchNormalization,
    Conv2D,
    Conv2DTranspose,
    MaxPooling2D,
    Dropout,
    SpatialDropout2D,
    UpSampling2D,
    Input,
    concatenate,
    multiply,
    add,
    Activation,
    GlobalAveragePooling2D,
    Dense,
    Multiply,
    Input,
)
from tensorflow.keras import backend as K

In [None]:
def rebnconv_block(inputs, filters=3, dirate=1, afunc='relu', kernel_initializer="he_normal", padding="same"):
    """Apply a 3x3 (optionally dilated) convolution, batch norm, then an activation.

    This is the basic "REBNCONV" unit that every RSU block below is built from.

    Args:
        inputs: Keras tensor to transform.
        filters: Number of output channels produced by the convolution.
        dirate: Dilation rate, used for both spatial axes.
        afunc: Activation identifier handed to `Activation` (e.g. 'relu').
        kernel_initializer: Initializer for the convolution kernel.
        padding: Convolution padding mode. With the default "same" (and the
            Conv2D default stride of 1) the spatial size is preserved.

    Returns:
        The activated output tensor.
    """
    dilation = (dirate, dirate)
    conv_out = Conv2D(
        filters=filters,
        kernel_size=3,
        padding=padding,
        dilation_rate=dilation,
        kernel_initializer=kernel_initializer,
    )(inputs)
    normalized = BatchNormalization()(conv_out)
    return Activation(afunc)(normalized)

In [None]:
def rsu4f(inputs, in_ch=3, mid_ch=12, out_ch=3):
    """Residual U-block that uses dilation instead of pooling (RSU-4F).

    Every stage keeps the input's spatial resolution. The receptive field is
    grown with dilation rates 1, 2, 4 and 8 instead of down/up-sampling.

    Args:
        inputs: Keras tensor fed into the block. Channels-last is assumed,
            since `concatenate` joins along its default axis -1.
        in_ch: Unused. It is kept for signature parity with the reference
            implementation, because Keras infers the input channel count.
        mid_ch: Number of filters in the inner encoder/decoder convolutions.
        out_ch: Number of filters of the input projection and of the block's
            output.

    Returns:
        Keras tensor with `out_ch` channels: the decoder output added
        element-wise to the input projection (the residual connection).
    """
    # Input projection; also serves as the residual branch.
    rebnconvin = rebnconv_block(inputs, out_ch)

    # Encoder: increasing dilation instead of pooling.
    rebnconv1 = rebnconv_block(rebnconvin, mid_ch, dirate=1)
    rebnconv2 = rebnconv_block(rebnconv1, mid_ch, dirate=2)
    rebnconv3 = rebnconv_block(rebnconv2, mid_ch, dirate=4)

    # Bottleneck.
    rebnconv4 = rebnconv_block(rebnconv3, mid_ch, dirate=8)

    # Decoder: fuse with the matching encoder stage at each step.
    rebnconv3d = rebnconv_block(concatenate([rebnconv4, rebnconv3]), mid_ch, dirate=4)
    rebnconv2d = rebnconv_block(concatenate([rebnconv3d, rebnconv2]), mid_ch, dirate=2)
    # Final stage projects to out_ch so it can be summed with rebnconvin.
    rebnconv1d = rebnconv_block(concatenate([rebnconv2d, rebnconv1]), out_ch, dirate=1)

    # Residual connection back to the input projection.
    return add([rebnconv1d, rebnconvin])

In [None]:
def rsu4(inputs, in_ch=3, mid_ch=12, out_ch=3):
    """Residual U-block of height 4 (RSU-4).

    The block has a small encoder/decoder with two 2x2 max-pool / bilinear
    upsample stages and a dilated bottleneck. The decoder output is added to
    the input projection.

    Args:
        inputs: Keras tensor fed into the block. Channels-last is assumed,
            since `concatenate` joins along its default axis -1.
        in_ch: Unused. It is kept for signature parity with the reference
            implementation, because Keras infers the input channel count.
        mid_ch: Number of filters in the inner encoder/decoder convolutions.
        out_ch: Number of filters of the input projection and of the block's
            output.

    Returns:
        Keras tensor with `out_ch` channels: the decoder output added
        element-wise to the input projection (the residual connection).

    Note:
        MaxPooling2D uses its default "valid" padding (floor) and UpSampling2D
        doubles the size. Height and width must therefore be divisible by
        2**2, or the upsampled maps will not line up with their skip tensors
        in `concatenate`.
    """
    # Input projection; also serves as the residual branch.
    rebnconvin = rebnconv_block(inputs, out_ch)

    # Encoder.
    rebnconv1 = rebnconv_block(rebnconvin, mid_ch, dirate=1)
    pool1 = MaxPooling2D((2, 2))(rebnconv1)
    rebnconv2 = rebnconv_block(pool1, mid_ch, dirate=1)
    pool2 = MaxPooling2D((2, 2))(rebnconv2)
    rebnconv3 = rebnconv_block(pool2, mid_ch, dirate=1)

    # Dilated bottleneck (no further pooling).
    rebnconv4 = rebnconv_block(rebnconv3, mid_ch, dirate=2)

    # Decoder: fuse with the matching encoder stage, then upsample.
    rebnconv3d = rebnconv_block(concatenate([rebnconv4, rebnconv3]), mid_ch, dirate=1)
    rebnconv3dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv3d)
    rebnconv2d = rebnconv_block(concatenate([rebnconv3dup, rebnconv2]), mid_ch, dirate=1)
    rebnconv2dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv2d)
    # Final stage projects to out_ch so it can be summed with rebnconvin.
    rebnconv1d = rebnconv_block(concatenate([rebnconv2dup, rebnconv1]), out_ch, dirate=1)

    # Residual connection back to the input projection.
    return add([rebnconv1d, rebnconvin])

In [None]:
def rsu5(inputs, in_ch=3, mid_ch=12, out_ch=3):
    """Residual U-block of height 5 (RSU-5).

    The block has a small encoder/decoder with three 2x2 max-pool / bilinear
    upsample stages and a dilated bottleneck. The decoder output is added to
    the input projection.

    Args:
        inputs: Keras tensor fed into the block. Channels-last is assumed,
            since `concatenate` joins along its default axis -1.
        in_ch: Unused. It is kept for signature parity with the reference
            implementation, because Keras infers the input channel count.
        mid_ch: Number of filters in the inner encoder/decoder convolutions.
        out_ch: Number of filters of the input projection and of the block's
            output.

    Returns:
        Keras tensor with `out_ch` channels: the decoder output added
        element-wise to the input projection (the residual connection).

    Note:
        MaxPooling2D uses its default "valid" padding (floor) and UpSampling2D
        doubles the size. Height and width must therefore be divisible by
        2**3, or the upsampled maps will not line up with their skip tensors
        in `concatenate`.
    """
    # Input projection; also serves as the residual branch.
    rebnconvin = rebnconv_block(inputs, out_ch)

    # Encoder.
    rebnconv1 = rebnconv_block(rebnconvin, mid_ch, dirate=1)
    pool1 = MaxPooling2D((2, 2))(rebnconv1)
    rebnconv2 = rebnconv_block(pool1, mid_ch, dirate=1)
    pool2 = MaxPooling2D((2, 2))(rebnconv2)
    rebnconv3 = rebnconv_block(pool2, mid_ch, dirate=1)
    pool3 = MaxPooling2D((2, 2))(rebnconv3)
    rebnconv4 = rebnconv_block(pool3, mid_ch, dirate=1)

    # Dilated bottleneck (no further pooling).
    rebnconv5 = rebnconv_block(rebnconv4, mid_ch, dirate=2)

    # Decoder: fuse with the matching encoder stage, then upsample.
    rebnconv4d = rebnconv_block(concatenate([rebnconv5, rebnconv4]), mid_ch, dirate=1)
    rebnconv4dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv4d)
    rebnconv3d = rebnconv_block(concatenate([rebnconv4dup, rebnconv3]), mid_ch, dirate=1)
    rebnconv3dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv3d)
    rebnconv2d = rebnconv_block(concatenate([rebnconv3dup, rebnconv2]), mid_ch, dirate=1)
    rebnconv2dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv2d)
    # Final stage projects to out_ch so it can be summed with rebnconvin.
    rebnconv1d = rebnconv_block(concatenate([rebnconv2dup, rebnconv1]), out_ch, dirate=1)

    # Residual connection back to the input projection.
    return add([rebnconv1d, rebnconvin])

In [None]:
def rsu6(inputs, in_ch=3, mid_ch=12, out_ch=3):
    """Residual U-block of height 6 (RSU-6).

    The block has a small encoder/decoder with four 2x2 max-pool / bilinear
    upsample stages and a dilated bottleneck. The decoder output is added to
    the input projection.

    Args:
        inputs: Keras tensor fed into the block. Channels-last is assumed,
            since `concatenate` joins along its default axis -1.
        in_ch: Unused. It is kept for signature parity with the reference
            implementation, because Keras infers the input channel count.
        mid_ch: Number of filters in the inner encoder/decoder convolutions.
        out_ch: Number of filters of the input projection and of the block's
            output.

    Returns:
        Keras tensor with `out_ch` channels: the decoder output added
        element-wise to the input projection (the residual connection).

    Note:
        MaxPooling2D uses its default "valid" padding (floor) and UpSampling2D
        doubles the size. Height and width must therefore be divisible by
        2**4, or the upsampled maps will not line up with their skip tensors
        in `concatenate`.
    """
    # Input projection; also serves as the residual branch.
    rebnconvin = rebnconv_block(inputs, out_ch)

    # Encoder.
    rebnconv1 = rebnconv_block(rebnconvin, mid_ch, dirate=1)
    pool1 = MaxPooling2D((2, 2))(rebnconv1)
    rebnconv2 = rebnconv_block(pool1, mid_ch, dirate=1)
    pool2 = MaxPooling2D((2, 2))(rebnconv2)
    rebnconv3 = rebnconv_block(pool2, mid_ch, dirate=1)
    pool3 = MaxPooling2D((2, 2))(rebnconv3)
    rebnconv4 = rebnconv_block(pool3, mid_ch, dirate=1)
    pool4 = MaxPooling2D((2, 2))(rebnconv4)
    rebnconv5 = rebnconv_block(pool4, mid_ch, dirate=1)

    # Dilated bottleneck (no further pooling).
    rebnconv6 = rebnconv_block(rebnconv5, mid_ch, dirate=2)

    # Decoder: fuse with the matching encoder stage, then upsample.
    rebnconv5d = rebnconv_block(concatenate([rebnconv6, rebnconv5]), mid_ch, dirate=1)
    rebnconv5dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv5d)
    rebnconv4d = rebnconv_block(concatenate([rebnconv5dup, rebnconv4]), mid_ch, dirate=1)
    rebnconv4dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv4d)
    rebnconv3d = rebnconv_block(concatenate([rebnconv4dup, rebnconv3]), mid_ch, dirate=1)
    rebnconv3dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv3d)
    rebnconv2d = rebnconv_block(concatenate([rebnconv3dup, rebnconv2]), mid_ch, dirate=1)
    rebnconv2dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv2d)
    # Final stage projects to out_ch so it can be summed with rebnconvin.
    rebnconv1d = rebnconv_block(concatenate([rebnconv2dup, rebnconv1]), out_ch, dirate=1)

    # Residual connection back to the input projection.
    return add([rebnconv1d, rebnconvin])

In [None]:
def rsu7(inputs, in_ch=3, mid_ch=12, out_ch=3):
    """Residual U-block of height 7 (RSU-7).

    The block has a small encoder/decoder with five 2x2 max-pool / bilinear
    upsample stages and a dilated bottleneck. The decoder output is added to
    the input projection.

    Args:
        inputs: Keras tensor fed into the block. Channels-last is assumed,
            since `concatenate` joins along its default axis -1.
        in_ch: Unused. It is kept for signature parity with the reference
            implementation, because Keras infers the input channel count.
        mid_ch: Number of filters in the inner encoder/decoder convolutions.
        out_ch: Number of filters of the input projection and of the block's
            output.

    Returns:
        Keras tensor with `out_ch` channels: the decoder output added
        element-wise to the input projection (the residual connection).

    Note:
        MaxPooling2D uses its default "valid" padding (floor) and UpSampling2D
        doubles the size. Height and width must therefore be divisible by
        2**5, or the upsampled maps will not line up with their skip tensors
        in `concatenate`.
    """
    # Input projection; also serves as the residual branch.
    rebnconvin = rebnconv_block(inputs, out_ch)

    # Encoder.
    rebnconv1 = rebnconv_block(rebnconvin, mid_ch, dirate=1)
    pool1 = MaxPooling2D((2, 2))(rebnconv1)
    rebnconv2 = rebnconv_block(pool1, mid_ch, dirate=1)
    pool2 = MaxPooling2D((2, 2))(rebnconv2)
    rebnconv3 = rebnconv_block(pool2, mid_ch, dirate=1)
    pool3 = MaxPooling2D((2, 2))(rebnconv3)
    rebnconv4 = rebnconv_block(pool3, mid_ch, dirate=1)
    pool4 = MaxPooling2D((2, 2))(rebnconv4)
    rebnconv5 = rebnconv_block(pool4, mid_ch, dirate=1)
    pool5 = MaxPooling2D((2, 2))(rebnconv5)
    rebnconv6 = rebnconv_block(pool5, mid_ch, dirate=1)

    # Dilated bottleneck (no further pooling).
    rebnconv7 = rebnconv_block(rebnconv6, mid_ch, dirate=2)

    # Decoder: fuse with the matching encoder stage, then upsample.
    rebnconv6d = rebnconv_block(concatenate([rebnconv7, rebnconv6]), mid_ch, dirate=1)
    rebnconv6dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv6d)
    rebnconv5d = rebnconv_block(concatenate([rebnconv6dup, rebnconv5]), mid_ch, dirate=1)
    rebnconv5dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv5d)
    rebnconv4d = rebnconv_block(concatenate([rebnconv5dup, rebnconv4]), mid_ch, dirate=1)
    rebnconv4dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv4d)
    rebnconv3d = rebnconv_block(concatenate([rebnconv4dup, rebnconv3]), mid_ch, dirate=1)
    rebnconv3dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv3d)
    rebnconv2d = rebnconv_block(concatenate([rebnconv3dup, rebnconv2]), mid_ch, dirate=1)
    rebnconv2dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv2d)
    # Final stage projects to out_ch so it can be summed with rebnconvin.
    rebnconv1d = rebnconv_block(concatenate([rebnconv2dup, rebnconv1]), out_ch, dirate=1)

    # Residual connection back to the input projection.
    return add([rebnconv1d, rebnconvin])

In [None]:
# Smoke test: build a model from a single RSU-7 block on a 128x128x1 input and
# print its layer summary. 128 is divisible by 2**5, which rsu7's five pooling
# stages require. The input is named `inputs` rather than `min` so the Python
# builtin `min` is not shadowed for the rest of the notebook.
inputs = Input((128, 128, 1))
block = rsu7(inputs)

model = Model(inputs=[inputs], outputs=[block])

model.summary()