### Import Packages

In [None]:
import re
import os
import numpy as np
import pickle

import matplotlib.pyplot as plt
from skimage.io import imread
from skimage.util.montage import montage2d
from osgeo import gdal

import keras_metrics
from keras import models
from keras import backend as K
from keras.optimizers import Adam, RMSprop
from keras.layers import Input, GaussianNoise, Conv2D, Activation, concatenate, BatchNormalization, SpatialDropout2D, Cropping2D, ZeroPadding2D
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, ReduceLROnPlateau

import tensorflow as tf

from functions5 import batch_img_gen, dice_p_bce, dice_coef, jaccard

# Current user's home directory; every data and model path in this notebook is rooted here.
HOME = os.path.expanduser(path="~")

In [None]:
# Show the name of the active Keras backend (the notebook echoes the cell's last expression).
K.backend()
# Uncomment to query the TensorFlow backend's device list; this is a private
# (underscore-prefixed) Keras helper, presumably returning the visible GPUs -- verify.
#K.tensorflow_backend._get_available_gpus()

### Load Data

In [None]:
# Load the pickled training masks.
# NOTE(review): pickle.load can execute arbitrary code -- only load files you produced yourself.
filepath = '{}/new_project/data/pickles/mask_train.pkl'.format(HOME)
with open(filepath, mode='rb') as mask_file:
    mask_train = pickle.load(mask_file)

In [None]:
# Load the pickled training images.
# NOTE(review): pickle.load can execute arbitrary code -- only load files you produced yourself.
filepath = '{}/new_project/data/pickles/tif_train.pkl'.format(HOME)
with open(filepath, mode='rb') as image_file:
    tif_train = pickle.load(image_file)

### View Data

In [None]:
# Generator used only to preview a few image/mask pairs in the next cell; the first
# argument is presumably the batch size (the training cell passes BATCH_SIZE there).
# NOTE(review): despite its name this is not a held-out validation set -- it draws from
# tif_train / mask_train, and the training cell later rebinds `valid_gen`.
valid_gen = batch_img_gen(4, tif_train, mask_train)

In [None]:
# Pull one batch from the preview generator and report its shape, dtype and value range.
t_x, t_y = next(valid_gen)
for label, batch in (('x', t_x), ('y', t_y)):
    print(label, batch.shape, batch.dtype, batch.min(), batch.max())


def montage_rgb(x):
    """Tile a 4-D batch into one mosaic, channel by channel.

    Each slice along the last axis is montaged on its own with ``montage2d`` and
    the mosaics are re-stacked on a new last axis, so the result keeps the
    input's channels. Assumes ``x`` is laid out (N, H, W, C) -- TODO confirm.
    """
    per_channel = [montage2d(x[:, :, :, channel]) for channel in range(x.shape[3])]
    return np.stack(per_channel, -1)


fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (20, 10))
ax1.imshow(montage_rgb(t_x))
ax2.imshow(montage2d(t_y[:, :, :, 0]), cmap = 'Greys')

### Build Network

In [None]:
# Hyper-parameters read by the network-building and training cells:
#   CROP       -- border width (pixels) cropped off the prediction and zero-padded back
#   DROPOUT    -- SpatialDropout2D rate applied after the dilated-branch concatenation
#   NOISE      -- GaussianNoise level applied to the RGB input
#   BATCH_SIZE -- batch size passed to the training/validation generators
CROP, DROPOUT, NOISE, BATCH_SIZE = 16, 0.25, 0.1, 4

In [None]:
def _dilated_conv(filters, rate=1):
    """Return a bias-free, 'same'-padded 10x10 Conv2D with square dilation ``rate``.

    Every use below is followed (directly, or after concatenation and dropout)
    by BatchNormalization, which is presumably why the bias is omitted.
    """
    return Conv2D(filters, [10, 10], activation='linear', padding='same',
                  dilation_rate=(rate, rate), use_bias=False)


def _conv_bn_elu(tensor, filters):
    """Apply an undilated conv -> BatchNormalization -> ELU stack to ``tensor``."""
    tensor = _dilated_conv(filters)(tensor)
    tensor = BatchNormalization()(tensor)
    return Activation('elu')(tensor)


# Layers are created in exactly the original order so Keras' auto-generated
# layer names and the weight order expected by saved checkpoints stay the same.
in_layer = Input((650, 650, 3), name = 'RGB_Input')

# Stem: Gaussian noise on the raw input, then normalisation.
noisy_input = GaussianNoise(NOISE)(in_layer)
normed_input = BatchNormalization()(noisy_input)

# Trunk: three conv-BN-ELU stages with 8, 8 and 16 filters.
trunk = normed_input
for width in (8, 8, 16):
    trunk = _conv_bn_elu(trunk, width)

# Context module: seven parallel 16-filter branches with dilation 1, 2, 4, ..., 64,
# concatenated with the normalised input as a skip connection.
branches = [_dilated_conv(16, 2 ** power)(trunk) for power in range(7)]
merged = concatenate([normed_input] + branches)

# Head: dropout, normalise, activate, one more 32-filter conv stage.
head = SpatialDropout2D(DROPOUT)(merged)
head = BatchNormalization()(head)
head = Activation('elu')(head)
head = _conv_bn_elu(head, 32)

# Per-pixel sigmoid output; a CROP-pixel border is cropped off and replaced with
# zeros so the output keeps the input's spatial size.
probabilities = Conv2D(1, (1, 1), activation='sigmoid', padding='same')(head)
cropped = Cropping2D((CROP, CROP))(probabilities)
seg_output = ZeroPadding2D((CROP, CROP))(cropped)

model = models.Model(inputs=[in_layer], outputs=[seg_output])

model.summary()

In [None]:
# RMSprop with a very small starting learning rate; the loss and Dice metric come from functions5.
seg_optimizer = RMSprop(lr=1e-6, rho=0.9, epsilon=None, decay=0.0)
seg_metrics = [
    dice_coef,
    'binary_accuracy',
    keras_metrics.precision(),
    keras_metrics.recall(),
]
model.compile(optimizer=seg_optimizer, loss=dice_p_bce, metrics=seg_metrics)

In [None]:
# Destination for the best-so-far weights (weights only, no architecture).
weight_path = HOME + "/new_project/Models/{}_weights.best.hdf5".format('seg_model')

# Overwrite weight_path only when val_dice_coef reaches a new maximum.
checkpoint = ModelCheckpoint(weight_path, monitor='val_dice_coef', verbose=1,
                             save_best_only=True, mode='max', save_weights_only=True)

# Halve the learning rate after 3 epochs without a val_dice_coef improvement of at
# least 1e-4, then wait 2 cooldown epochs before monitoring again.
# min_lr must be strictly below the optimiser's starting rate. The model is compiled
# with RMSprop(lr=1e-6), and ReduceLROnPlateau only lowers the rate while it is still
# greater than min_lr, so min_lr=1e-6 made this callback a permanent no-op.
# NOTE: newer Keras versions call the `epsilon` argument `min_delta`.
reduceLR = ReduceLROnPlateau(monitor='val_dice_coef', factor=0.5,
                             patience=3,
                             verbose=1, mode='max', epsilon=0.0001, cooldown=2,
                             min_lr=1e-8)

# Stop once val_dice_coef has not improved for 15 consecutive epochs.
early = EarlyStopping(monitor="val_dice_coef",
                      mode="max",
                      patience=15)

callbacks_list = [checkpoint, early, reduceLR]

In [None]:
# Number of entries in tif_train; sizes steps_per_epoch / validation_steps in the training cell.
total_items = len(tif_train)

### Train Network

In [None]:
# Build both generators from the loaded arrays, matching the preview cell's
# batch_img_gen(4, tif_train, mask_train) call, instead of relying on the helper's
# defaults for its data arguments.
# NOTE(review): validation draws from the same tif_train / mask_train as training, so
# val_dice_coef (which drives checkpointing, LR reduction and early stopping) is not a
# held-out score. Load a separate validation split here if one exists.
valid_gen = batch_img_gen(BATCH_SIZE, tif_train, mask_train)
train_gen = batch_img_gen(BATCH_SIZE, tif_train, mask_train)

# Cap each epoch at 100 training / 50 validation batches. Keep at least one step so a
# dataset smaller than one batch does not produce a zero-step epoch.
train_steps = max(1, min(total_items // BATCH_SIZE, 100))
val_steps = max(1, min(total_items // BATCH_SIZE, 50))

loss_history = [model.fit_generator(train_gen,
                                    steps_per_epoch=train_steps,
                                    epochs=120,
                                    validation_data=valid_gen,
                                    validation_steps=val_steps,
                                    callbacks=callbacks_list,
                                    workers=1, use_multiprocessing=True)]

### Save Weights

In [None]:
# After training, uncomment to reload the best checkpoint (ModelCheckpoint above stores
# weights only) and export architecture + weights to 'model12_full.h5' in the working
# directory. The compiled model in this notebook is bound to `model`.
# model.load_weights(weight_path)
# model.save('model12_full.h5')