# Train UNet model

In [None]:
import pickle

from keras.models import Input, Model
from keras.layers import Conv2D, Concatenate, MaxPooling2D, Conv2DTranspose
from keras.layers import UpSampling2D, Dropout, BatchNormalization

First, we define the UNet architecture using the Keras framework. This code snippet is based on the [unet-keras](https://github.com/pietz/unet-keras) repository.

In [None]:
def conv_block(m, dim, acti, bn, res, do=0):
    """Apply two 3x3 'same'-padded convolutions with optional extras.

    Layer order is: Conv2D -> [BatchNorm] -> [Dropout] -> Conv2D -> [BatchNorm].

    Args:
        m: Input Keras tensor.
        dim: Number of filters used by both convolutions.
        acti: Activation passed to both Conv2D layers.
        bn: If truthy, insert BatchNormalization after each convolution.
        res: If truthy, concatenate the block input ``m`` with the block
            output. This is a concatenation skip, not an additive residual.
        do: Dropout rate applied between the two convolutions; a falsy
            value (the default 0) skips the Dropout layer entirely.

    Returns:
        The resulting Keras tensor.
    """
    out = m
    for conv_index in range(2):
        out = Conv2D(dim, 3, activation=acti, padding='same')(out)
        if bn:
            out = BatchNormalization()(out)
        # Dropout sits only between the first and second convolution.
        if conv_index == 0 and do:
            out = Dropout(do)(out)
    if res:
        return Concatenate()([m, out])
    return out

def level_block(m, dim, depth, inc, acti, do, bn, mp, up, res):
    """Recursively build one encoder/decoder level of the UNet.

    At ``depth > 0`` the level convolves its input and keeps the result as a
    skip tensor. It then downsamples and recurses one level deeper with
    ``int(inc * dim)`` filters. After that it upsamples the deeper result,
    concatenates it with the skip tensor, and convolves again.
    At ``depth <= 0`` (the bottleneck) it applies a single conv block. That
    is the only place the dropout rate ``do`` is used.

    Args:
        m: Input Keras tensor.
        dim: Number of filters at this level.
        depth: Number of levels remaining below this one.
        inc: Filter-count multiplier applied when going one level deeper.
        acti: Activation for the convolutions.
        do: Dropout rate, used only at the bottleneck.
        bn: Whether conv blocks use BatchNormalization.
        mp: Downsample with MaxPooling2D if truthy. Otherwise use a strided
            3x3 Conv2D, which has no activation argument.
        up: Upsample with UpSampling2D followed by a 2x2 Conv2D if truthy.
            Otherwise use a strided 3x3 Conv2DTranspose.
        res: Whether conv blocks concatenate their input to their output.

    Returns:
        The output Keras tensor of this level.
    """
    if depth <= 0:
        # Bottleneck: the only conv block that receives the dropout rate.
        return conv_block(m, dim, acti, bn, res, do)

    # Encoder side: features kept for the skip connection.
    skip = conv_block(m, dim, acti, bn, res)
    if mp:
        down = MaxPooling2D()(skip)
    else:
        down = Conv2D(dim, 3, strides=2, padding='same')(skip)

    deeper = level_block(down, int(inc*dim), depth-1, inc, acti, do, bn, mp, up, res)

    # Decoder side: bring the deeper features back to this level's resolution.
    if up:
        upsampled = UpSampling2D()(deeper)
        upsampled = Conv2D(dim, 2, activation=acti, padding='same')(upsampled)
    else:
        upsampled = Conv2DTranspose(dim, 3, strides=2, activation=acti, padding='same')(deeper)

    merged = Concatenate()([skip, upsampled])
    return conv_block(merged, dim, acti, bn, res)

def UNet(img_shape, out_ch=1, start_ch=64, depth=4, inc_rate=2., activation='relu',
         dropout=0.5, batchnorm=False, maxpool=True, upconv=True, residual=False,
         out_activation='sigmoid'):
    """Build a UNet model.

    Args:
        img_shape: Input shape without the batch axis, passed to ``Input``.
            Each spatial dimension should be divisible by ``2 ** depth``.
            Every level halves the resolution and the decoder doubles it
            again, so other sizes make the skip-connection Concatenate
            layers see mismatched shapes.
        out_ch: Number of output channels (filters of the final 1x1 conv).
        start_ch: Number of filters at the top level.
        depth: Number of downsampling levels.
        inc_rate: Filter multiplier applied at each deeper level.
        activation: Activation for all internal convolutions.
        dropout: Dropout rate applied only at the bottleneck.
        batchnorm: Whether conv blocks use BatchNormalization.
        maxpool: Downsample with MaxPooling2D (True) or a strided Conv2D.
        upconv: Upsample with UpSampling2D + Conv2D (True) or Conv2DTranspose.
        residual: Whether conv blocks concatenate their input to their output.
        out_activation: Activation of the final 1x1 convolution. Defaults to
            ``'sigmoid'``, the previously hard-coded behaviour, which suits
            independent per-channel (binary / multi-label) outputs. Pass
            ``'softmax'`` for mutually exclusive classes trained with
            ``categorical_crossentropy``.

    Returns:
        A Keras ``Model`` mapping the input tensor to the per-pixel output.
    """
    i = Input(shape=img_shape)
    o = level_block(i, start_ch, depth, inc_rate, activation, dropout, batchnorm, maxpool, upconv, residual)
    o = Conv2D(out_ch, 1, activation=out_activation)(o)
    return Model(inputs=i, outputs=o)

Next, we load the training set (input images and their label tensors) from pickle files:

In [None]:
# Load the pickled training inputs and targets.
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only load these files if they come from a trusted source.
with open('train_images.pkl', mode='rb') as pkl_file:
    input_images = pickle.load(pkl_file)

with open('train_labels.pkl', mode='rb') as pkl_file:
    output_tensors = pickle.load(pkl_file)

Now we build the network with input and output shapes that match this data, and train it:

In [None]:
# Input shape drops the leading (presumably per-sample) axis of the images;
# there is one output channel per entry along the labels' last axis.
# NOTE(review): UNet's final layer uses a sigmoid activation, while the
# compile step below uses categorical_crossentropy. If the labels are one-hot
# (mutually exclusive classes), a softmax output is the usual pairing; confirm.
n_output_channels = output_tensors.shape[-1]
model = UNet(
    input_images.shape[1:],
    out_ch=n_output_channels,
    start_ch=64,
    depth=4,
    residual=True,
)

In [None]:
# Train with Adam on categorical cross-entropy, tracking accuracy.
# NOTE(review): no validation data is passed, so fit() reports training
# metrics only.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['acc'],
)
model.fit(
    input_images,
    output_tensors,
    epochs=500,
    batch_size=4,
)

In [None]:
# Persist the trained network twice: once as a complete model file and once
# as a weights-only file.
model_path = 'unet_segmentation.h5'
weights_path = 'unet_segmentation_weights.h5'
model.save(model_path)
model.save_weights(weights_path)