In [None]:
import tensorflow as tf
from tensorflow import keras

In [None]:
# From hackathon4
class ResBlock(tf.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut connection.

    On the main path each convolution is preceded by batch normalization and
    a ReLU. The shortcut is the unmodified input when ``stride == 1``;
    otherwise it is a strided 1x1 convolution followed by batch
    normalization, so that it is downsampled the same way as the main path.
    """

    def __init__(self, filter_num, stride=1):
        """Create the layers of the block.

        Args:
            filter_num: Number of filters used by every convolution in the block.
            stride: Stride of the first 3x3 convolution and, when it is not 1,
                of the 1x1 shortcut convolution.
        """
        super().__init__()
        layers = tf.keras.layers
        self.stride = stride

        # Settings shared by the two 3x3 convolutions of the main path.
        conv3x3 = dict(filters=filter_num, kernel_size=(3, 3), padding="same")

        # Main path. Only the first convolution is strided, so it is the one
        # that downsamples when stride != 1.
        self.bn1 = layers.BatchNormalization()
        self.conv1 = layers.Conv2D(strides=stride, **conv3x3)
        self.bn2 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(**conv3x3)

        # Shortcut path. It is only created for strided blocks, where the
        # input has to be downsampled before it can be added to the main path.
        if stride != 1:
            self.down_conv = layers.Conv2D(filters=filter_num,
                                           kernel_size=(1, 1),
                                           strides=stride,
                                           padding="same")
            self.down_bn = layers.BatchNormalization()

    def __call__(self, x, is_training=False):
        """Apply the block to ``x``.

        Args:
            x: Input feature map.
            is_training: Forwarded as ``training`` to every batch
                normalization layer of the block.

        Returns:
            The sum of the main-path output and the shortcut.
        """
        if self.stride == 1:
            shortcut = x
        else:
            shortcut = self.down_bn(self.down_conv(x), training=is_training)

        # Batch norm -> ReLU -> convolution, twice.
        hidden = self.conv1(tf.nn.relu(self.bn1(x, training=is_training)))
        hidden = self.conv2(tf.nn.relu(self.bn2(hidden, training=is_training)))

        return hidden + shortcut

In [None]:
def build_stem(stem_type='B', input_height=750, input_width=1280, input_channels=3):
    """Build the input stem of the network as a standalone Keras model.

    Only the type ``'B'`` stem is implemented: two 3x3 convolutions with
    stride 2 (32 filters, then 64), each followed by a ReLU and batch
    normalization.

    Args:
        stem_type: Stem variant, one of ``'A'``, ``'B'`` or ``'C'``.
        input_height: Height of the input images.
        input_width: Width of the input images.
        input_channels: Number of channels of the input images.

    Returns:
        A ``keras.Model`` mapping an image batch to the stem feature map.

    Raises:
        ValueError: If ``stem_type`` is not ``'A'``, ``'B'`` or ``'C'``.
        NotImplementedError: If ``stem_type`` is ``'A'`` or ``'C'``. These
            names are accepted but no layers are defined for them yet.
    """
    VALID_TYPES = {'A', 'B', 'C'}
    if stem_type not in VALID_TYPES:
        raise ValueError("Must define type A, B, or C")
    if stem_type != 'B':
        # Fail with a clear message instead of hitting an unbound variable
        # when the model is assembled below.
        raise NotImplementedError(
            "Stem type %r is not implemented yet; only type 'B' is available" % stem_type)

    inputs = keras.Input(shape=(input_height, input_width, input_channels))
    x = inputs
    for filters in (32, 64):
        x = keras.layers.Conv2D(filters=filters, kernel_size=3, strides=2, padding='same')(x)
        # Use the ReLU layer rather than a raw TensorFlow op so that the
        # activation is a regular layer of the functional model.
        x = keras.layers.ReLU()(x)
        x = keras.layers.BatchNormalization()(x)

    # The model name is used as a TensorFlow name scope, which may not
    # contain whitespace, hence "stem_B" rather than "Stem B".
    return keras.Model(inputs, x, name="stem_" + stem_type)
    

In [None]:
def build_backbone(stem, stage_blocks=(4, 5, 6, 4), stage_stride=(2, 2, 2, 2), stage_filters=None):
    """Build the residual backbone as a standalone multi-output Keras model.

    The backbone input has the shape of ``stem``'s output, so the two models
    can be chained as ``backbone(stem(images))``. Each stage is a sequence of
    ``ResBlock`` modules; the output of the last block of every stage is
    returned.

    Args:
        stem: Keras model whose single output feeds the backbone. Only its
            output shape is used here.
        stage_blocks: Number of residual blocks in each stage.
        stage_stride: Stride of each stage. It is applied by the first block
            of the stage only; the remaining blocks use stride 1.
        stage_filters: Number of filters of the blocks in each stage.
            Defaults to 2 for every stage, the width this builder used
            before the parameter existed.

    Returns:
        A ``keras.Model`` named ``"backbone"`` with one output per stage.

    Raises:
        ValueError: If no stage is given, if a stage has fewer than one
            block, or if the per-stage sequences differ in length.
    """
    n_stages = len(stage_blocks)
    if n_stages == 0:
        raise ValueError("stage_blocks must describe at least one stage")
    if any(n_blocks < 1 for n_blocks in stage_blocks):
        raise ValueError("every stage needs at least one block")
    if len(stage_stride) != n_stages:
        raise ValueError("stage_stride must have one entry per stage")
    if stage_filters is None:
        stage_filters = [2] * n_stages
    elif len(stage_filters) != n_stages:
        raise ValueError("stage_filters must have one entry per stage")

    # keras.Input requires a shape; take it from the stem (minus the batch axis).
    inputs = keras.Input(shape=tuple(stem.output.shape[1:]))

    stages = []
    x = inputs
    for n_blocks, stride, filters in zip(stage_blocks, stage_stride, stage_filters):
        for block_index in range(n_blocks):
            # Downsample once per stage. Striding every block would halve the
            # feature map once per block instead of once per stage.
            # NOTE: ResBlock only projects its shortcut when stride != 1, so a
            # stage with stride 1 must keep the channel count of its input.
            block_stride = stride if block_index == 0 else 1
            # is_training=None lets Keras choose the batch-norm mode at run
            # time; ResBlock's default (False) would pin it to inference mode.
            x = ResBlock(filters, block_stride)(x, is_training=None)
        stages.append(x)

    return keras.Model(inputs, stages, name="backbone")

In [None]:
def build_decoder(decoder_type='B', num_stages=4, stage_channels=None):
    """Build the multi-scale decoder as a standalone Keras model.

    Only the type ``'B'`` decoder is implemented. It takes one feature map
    per backbone stage, ordered from the finest to the coarsest resolution,
    and:

    1. projects every stage with a 1x1 convolution;
    2. adds to every stage a projected, 2x-upsampled copy of the next
       coarser stage;
    3. refines every stage with a 3x3 convolution;
    4. starting from the coarsest stage, repeatedly upsamples by 2 and
       concatenates with the next finer stage, so the output holds all
       stages at the finest resolution.

    Every convolution is followed by a ReLU and batch normalization.

    Args:
        decoder_type: Decoder variant, one of ``'A'``, ``'B'`` or ``'C'``.
        num_stages: Number of input feature maps (1 to 4).
        stage_channels: Channel count of each input feature map. Convolutions
            need a known channel dimension. Defaults to 2 per stage, the
            block width the backbone code in this notebook uses.

    Returns:
        A ``keras.Model`` with ``num_stages`` inputs and a single output.

    Raises:
        ValueError: If ``decoder_type`` is not ``'A'``, ``'B'`` or ``'C'``,
            if ``num_stages`` is outside 1..4, or if ``stage_channels`` does
            not have ``num_stages`` entries.
        NotImplementedError: If ``decoder_type`` is ``'A'`` or ``'C'``.
    """
    VALID_TYPES = {'A', 'B', 'C'}
    if decoder_type not in VALID_TYPES:
        raise ValueError("Must define type A, B, or C")
    if decoder_type != 'B':
        # These types used to fall through and return None silently.
        raise NotImplementedError(
            "Decoder type %r is not implemented yet; only type 'B' is available" % decoder_type)

    filters1 = [64, 128, 128, 256]  # widths of the 1x1 projections, per stage
    filters2 = [32, 64, 96, 96]     # widths of the 3x3 refinements, per stage
    if not 1 <= num_stages <= len(filters1):
        raise ValueError("num_stages must be between 1 and %d" % len(filters1))
    if stage_channels is None:
        stage_channels = [2] * num_stages
    elif len(stage_channels) != num_stages:
        raise ValueError("stage_channels must have one entry per stage")

    def conv_relu_bn(tensor, filters, kernel_size):
        """Stride-1 'same' convolution followed by ReLU and batch norm."""
        tensor = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size,
                                     strides=1, padding='same')(tensor)
        tensor = keras.layers.ReLU()(tensor)
        return keras.layers.BatchNormalization()(tensor)

    # Spatial sizes stay unknown; with plain 2x upsampling each stage must be
    # exactly twice the size of the next coarser one at run time.
    inputs = [keras.Input(shape=(None, None, channels)) for channels in stage_channels]

    layer_out = [conv_relu_bn(tensor, filters1[i], 1) for i, tensor in enumerate(inputs)]

    # NOTE(review): this runs from fine to coarse, so each stage only receives
    # the next coarser stage's own projection, not an accumulated pyramid.
    # Kept as written - confirm this is intended.
    for i in range(1, len(layer_out)):
        x = conv_relu_bn(layer_out[i], filters1[i - 1], 1)
        x = keras.layers.UpSampling2D()(x)
        # Merge layers are instantiated first and then called on the tensors.
        layer_out[i - 1] = keras.layers.Add()([layer_out[i - 1], x])

    layer_out = [conv_relu_bn(tensor, filters2[i], 3) for i, tensor in enumerate(layer_out)]

    # Carry the merged tensor from one iteration to the next so that every
    # stage ends up in the output, not just the two finest ones.
    output = layer_out[-1]
    for i in range(len(layer_out) - 1, 0, -1):
        x = keras.layers.UpSampling2D()(output)
        output = keras.layers.Concatenate()([layer_out[i - 1], x])

    return keras.Model(inputs, output)
            

In [None]:
def build_head(input_channels=288, num_classes=2, mid_filters=256):
    """Build the segmentation head as a standalone Keras model.

    The head is a 3x3 convolution followed by a ReLU and batch
    normalization, then a 3x3 convolution that produces one channel per
    class. No activation is applied to the output.

    Args:
        input_channels: Channel count of the feature map fed to the head.
            The default, 288, is the sum of the type-B decoder refinement
            widths (32 + 64 + 96 + 96); pass the real channel count of the
            tensor you connect.
        num_classes: Number of output channels.
        mid_filters: Number of filters of the first convolution.

    Returns:
        A ``keras.Model`` named ``"head"``.
    """
    inputs = keras.Input(shape=(None, None, input_channels))
    x = keras.layers.Conv2D(filters=mid_filters, kernel_size=3, strides=1, padding='same')(inputs)
    x = keras.layers.ReLU()(x)
    x = keras.layers.BatchNormalization()(x)
    output = keras.layers.Conv2D(filters=num_classes, kernel_size=3, strides=1, padding='same')(x)
    return keras.Model(inputs, output, name="head")

In [None]:
# Uses the B-B-B sections of FFNet.
def build_model(input_height=750, input_width=1280, input_channels=3,
                stage_blocks=(4, 5, 6, 4), stage_stride=(2, 2, 2, 2),
                stage_filters=None, num_classes=2):
    """Build the full segmentation network as a single Keras model.

    The network is assembled from four parts:

    * Stem: two stride-2 3x3 convolutions (32, then 64 filters).
    * Backbone: one stage of ``ResBlock`` modules per entry of
      ``stage_blocks``; the output of every stage is kept.
    * Decoder: 1x1 projection of every stage, addition of the projected and
      upsampled next-coarser stage, 3x3 refinement, then upsampling and
      concatenation of all stages at the finest stage resolution.
    * Head: 3x3 convolution with 256 filters, then a 3x3 convolution with
      ``num_classes`` filters.

    Every convolution except the last one is followed by a ReLU and batch
    normalization. The output has the spatial size of the first backbone
    stage, not of the input image.

    Args:
        input_height: Height of the input images.
        input_width: Width of the input images.
        input_channels: Number of channels of the input images.
        stage_blocks: Number of residual blocks in each backbone stage
            (at most four stages).
        stage_stride: Stride of each stage, applied by its first block only.
            The decoder upsamples by 2 between stages, so it expects 2 here.
        stage_filters: Number of filters of the blocks in each stage.
            Defaults to 2 for every stage, the width this builder used
            before the parameter existed.
        num_classes: Number of output channels.

    Returns:
        The assembled ``keras.Model``.

    Raises:
        ValueError: If the number of stages is outside 1..4, if a stage has
            fewer than one block, or if the per-stage sequences differ in
            length.
    """
    filters1 = [64, 128, 128, 256]  # widths of the decoder 1x1 projections
    filters2 = [32, 64, 96, 96]     # widths of the decoder 3x3 refinements

    n_stages = len(stage_blocks)
    if not 1 <= n_stages <= len(filters1):
        raise ValueError("between 1 and %d stages are supported" % len(filters1))
    if any(n_blocks < 1 for n_blocks in stage_blocks):
        raise ValueError("every stage needs at least one block")
    if len(stage_stride) != n_stages:
        raise ValueError("stage_stride must have one entry per stage")
    if stage_filters is None:
        stage_filters = [2] * n_stages
    elif len(stage_filters) != n_stages:
        raise ValueError("stage_filters must have one entry per stage")

    def conv_relu_bn(tensor, filters, kernel_size, strides=1):
        """'same' convolution followed by ReLU and batch normalization."""
        tensor = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size,
                                     strides=strides, padding='same')(tensor)
        tensor = keras.layers.ReLU()(tensor)
        return keras.layers.BatchNormalization()(tensor)

    def upsample_like(tensor, reference):
        """Upsample ``tensor`` by 2 and crop it to ``reference``'s spatial size.

        A stride-2 'same' convolution rounds odd sizes up (47 -> 24), so
        doubling can overshoot by one row or column (24 -> 48). The surplus
        is cropped so the result can be added to or concatenated with
        ``reference``.
        """
        tensor = keras.layers.UpSampling2D()(tensor)
        surplus = []
        for axis in (1, 2):
            have, want = tensor.shape[axis], reference.shape[axis]
            known = have is not None and want is not None
            surplus.append(have - want if known and have > want else 0)
        if any(surplus):
            tensor = keras.layers.Cropping2D(
                cropping=((0, surplus[0]), (0, surplus[1])))(tensor)
        return tensor

    inputs = keras.Input(shape=(input_height, input_width, input_channels))

    # Input stem: two 3x3 convolutions with stride 2.
    x = conv_relu_bn(inputs, 32, 3, strides=2)
    x = conv_relu_bn(x, 64, 3, strides=2)

    # Backbone
    stages = []
    for n_blocks, stride, filters in zip(stage_blocks, stage_stride, stage_filters):
        for block_index in range(n_blocks):
            # Downsample once per stage. Striding every block would halve the
            # feature map once per block instead of once per stage.
            # NOTE: ResBlock only projects its shortcut when stride != 1, so a
            # stage with stride 1 must keep the channel count of its input.
            block_stride = stride if block_index == 0 else 1
            # is_training=None lets Keras choose the batch-norm mode at run
            # time; ResBlock's default (False) would pin it to inference mode.
            x = ResBlock(filters, block_stride)(x, is_training=None)
        stages.append(x)

    # Decoder
    layer_out = [conv_relu_bn(tensor, filters1[i], 1) for i, tensor in enumerate(stages)]

    # NOTE(review): this runs from fine to coarse, so each stage only receives
    # the next coarser stage's own projection, not an accumulated pyramid.
    # Kept as written - confirm this is intended.
    for i in range(1, len(layer_out)):
        x = conv_relu_bn(layer_out[i], filters1[i - 1], 1)
        x = upsample_like(x, layer_out[i - 1])
        # Merge layers are instantiated first and then called on the tensors.
        layer_out[i - 1] = keras.layers.Add()([layer_out[i - 1], x])

    layer_out = [conv_relu_bn(tensor, filters2[i], 3) for i, tensor in enumerate(layer_out)]

    # Carry the merged tensor from one iteration to the next so that every
    # stage ends up in the decoder output, not just the two finest ones.
    decoder = layer_out[-1]
    for i in range(len(layer_out) - 1, 0, -1):
        x = upsample_like(decoder, layer_out[i - 1])
        decoder = keras.layers.Concatenate()([layer_out[i - 1], x])

    # Segmentation head
    x = conv_relu_bn(decoder, 256, 3)
    output = keras.layers.Conv2D(filters=num_classes, kernel_size=3, strides=1, padding='same')(x)

    return keras.Model(inputs, output)

In [None]:
# Build the network with the default arguments and print the Keras
# layer-by-layer summary.
model = build_model()
model.summary()