# Training the Models

In [None]:
# Mount Google Drive at /content/drive (the `google.colab` package is only
# available inside a Google Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Seed the NumPy and TensorFlow random number generators before anything else
# draws from them.
# NOTE(review): Python's built-in `random` module is imported below but is
# never seeded in this cell.
from numpy.random import seed
seed(12345)
import tensorflow as tf
from tensorflow.random import set_seed
set_seed(1234)
import os
import random
import numpy as np
import skimage
import matplotlib.pyplot as plt
# The wildcard imports supply the layer, model and optimizer names
# (Input, Conv3D, Model, Sequential, Adam, ...) used further down.
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau, TensorBoard
from keras import utils

## Logging

First, we need to define some necessary logging utilities.

In [None]:
class TrainValTensorBoard(TensorBoard):
    """TensorBoard callback that separates training and validation metrics.

    The parent callback logs to `<log_dir>/training`; the `val_*` metrics are
    written by a dedicated summary writer to `<log_dir>/validation` with the
    `val_` prefix removed, so that matching training and validation scalars
    carry the same tag and can be overlaid in one chart.

    Args:
        log_dir: Root directory that receives both sub-directories.
        **kwargs: Forwarded unchanged to `TensorBoard.__init__`.
    """

    _VAL_PREFIX = 'val_'

    def __init__(self, log_dir='./log1', **kwargs):
        # Make the original `TensorBoard` log to a subdirectory 'training'
        training_log_dir = os.path.join(log_dir, 'training')
        super().__init__(training_log_dir, **kwargs)
        # Log the validation metrics to a separate subdirectory
        self.val_log_dir = os.path.join(log_dir, 'validation')
        # Created in `set_model`; kept as None until then so that
        # `on_train_end` is safe even if training never started.
        self.val_writer = None

    def set_model(self, model):
        """Create the validation writer, then let the parent attach `model`."""
        self.val_writer = tf.summary.create_file_writer(self.val_log_dir)
        super().set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        """Write `val_*` metrics with the validation writer; pass on the rest.

        The current learning rate is added to the training logs as `lr`.
        """
        logs = logs or {}
        prefix = self._VAL_PREFIX
        # Strip only the leading 'val_' so that metric names which happen to
        # contain 'val_' elsewhere are not mangled.
        val_logs = {k[len(prefix):]: v
                    for k, v in logs.items() if k.startswith(prefix)}
        with self.val_writer.as_default():
            for name, value in val_logs.items():
                tf.summary.scalar(name, value, step=epoch)
            self.val_writer.flush()

        # Pass the remaining logs to `TensorBoard.on_epoch_end`
        logs = {k: v for k, v in logs.items() if not k.startswith(prefix)}
        optimizer = self.model.optimizer
        # Newer optimizers expose `learning_rate`; older ones only `lr`.
        lr = getattr(optimizer, 'learning_rate', None)
        if lr is None:
            lr = optimizer.lr
        logs['lr'] = tf.keras.backend.get_value(lr)
        super().on_epoch_end(epoch, logs)

    def on_train_end(self, logs=None):
        """Finish the parent callback and close the validation writer."""
        super().on_train_end(logs)
        if self.val_writer is not None:
            self.val_writer.close()

# The Data

Now we need to tell the data generators where to find the training and validation data.

In [None]:
# Generator settings: one single-channel 128 x 128 x 128 volume per batch,
# with shuffling switched on.
params = dict(
    batch_size=1,
    dim=(128, 128, 128),
    n_channels=1,
    shuffle=True,
)

# Directories holding the seismic volumes and the matching fault labels.
seismPathT, faultPathT = './data/train/seis/', './data/train/fault/'
seismPathV, faultPathV = './data/validation/seis/', './data/validation/fault/'

# Sample IDs: 200 training volumes and 20 validation volumes.
train_ID, valid_ID = range(200), range(20)

### Data Normalization and Generation

In [None]:
class DataGenerator(utils.Sequence):
    """Keras `Sequence` that loads seismic/fault volume pairs from disk.

    For a sample ID `i` the seismic volume is read from
    `dpath + str(i) + '.dat'` and the fault labels from
    `fpath + str(i) + '.dat'`, both as raw single-precision floats that are
    reshaped to `dim`. The seismic volume is standardised (mean removed,
    divided by its standard deviation), both volumes are transposed, and each
    pair is emitted twice: as loaded and flipped along its first axis. A batch
    therefore holds `2 * batch_size` samples.

    Args:
        dpath: Directory prefix of the seismic files (used by plain string
            concatenation, so it must end with a path separator).
        fpath: Directory prefix of the fault-label files (same convention).
        data_IDs: Indexable collection of sample IDs.
        batch_size: Number of IDs read per batch.
        dim: Shape of one volume as stored on disk.
        n_channels: Size of the trailing channel axis.
        shuffle: Whether the sample order is reshuffled at every epoch end.
    """

    def __init__(self, dpath, fpath, data_IDs, batch_size=1, dim=(128,128,128),
                 n_channels=1, shuffle=True):
        """Store the settings and build the initial sample order."""
        super().__init__()
        self.dim   = dim
        self.dpath = dpath
        self.fpath = fpath
        self.batch_size = batch_size
        self.data_IDs   = data_IDs
        self.n_channels = n_channels
        self.shuffle    = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch."""
        return len(self.data_IDs) // self.batch_size

    def __getitem__(self, index):
        """Return the `(X, Y)` arrays of batch number `index`."""
        bsize = self.batch_size
        indexes = self.indexes[index*bsize:(index+1)*bsize]
        data_IDs_temp = [self.data_IDs[k] for k in indexes]
        return self.__data_generation(data_IDs_temp)

    def on_epoch_end(self):
        """Reset the sample order, reshuffling it when `shuffle` is set."""
        self.indexes = np.arange(len(self.data_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, data_IDs_temp):
        """Load, normalise and augment the volumes for the given IDs.

        Returns:
            `(X, Y)`, two float32 arrays of shape
            `(2 * len(data_IDs_temp), *reversed(dim), n_channels)`. Rows
            `2*i` and `2*i + 1` hold the loaded and the flipped version of
            the i-th ID.
        """
        # In seismic processing, the dimensions of a seismic array are often
        # arranged as a[n3][n2][n1], where n1 represents the vertical
        # dimension. This is why the arrays are transposed below; the emitted
        # samples therefore have the axes of `dim` in reverse order.
        sample_shape = (*tuple(self.dim)[::-1], self.n_channels)
        n_samples = 2 * len(data_IDs_temp)
        X = np.zeros((n_samples, *sample_shape), dtype=np.single)
        Y = np.zeros((n_samples, *sample_shape), dtype=np.single)

        for i, data_id in enumerate(data_IDs_temp):
            gx = np.fromfile(self.dpath+str(data_id)+'.dat', dtype=np.single)
            fx = np.fromfile(self.fpath+str(data_id)+'.dat', dtype=np.single)
            gx = np.reshape(gx, self.dim)
            fx = np.reshape(fx, self.dim)

            # Standardise the seismic volume; a constant volume has zero
            # spread and is only centred, to avoid dividing by zero.
            xs = np.std(gx)
            gx = gx - np.mean(gx)
            if xs > 0:
                gx = gx / xs

            gx = np.transpose(gx)
            fx = np.transpose(fx)

            X[2*i] = np.reshape(gx, sample_shape)
            Y[2*i] = np.reshape(fx, sample_shape)
            # Augmentation: the same pair flipped along its first axis.
            X[2*i + 1] = np.reshape(np.flipud(gx), sample_shape)
            Y[2*i + 1] = np.reshape(np.flipud(fx), sample_shape)
        return X, Y

In [None]:
# Build one generator per data split from the shared settings in `params`.
train_generator = DataGenerator(
    data_IDs=train_ID, dpath=seismPathT, fpath=faultPathT, **params)
valid_generator = DataGenerator(
    data_IDs=valid_ID, dpath=seismPathV, fpath=faultPathV, **params)
# Display the training generator as the cell output.
train_generator

## Model

Create the model (most of this code is in `unet3.py`).

In [None]:
def unet(pretrained_weights = None,input_size = (None,None,None,1)):
    """Build the 3-D U-Net and print its summary.

    The encoder has three levels of two 3x3x3 ReLU convolutions (16, 32 and
    64 filters), each followed by 2x2x2 max pooling, and a 128-filter
    bottleneck. The decoder mirrors it: each level upsamples by 2,
    concatenates the matching encoder output on the channel axis and applies
    two more convolutions. A final 1x1x1 sigmoid convolution produces one
    output channel.

    Args:
        pretrained_weights: Optional path of a weights file. When given, the
            weights are loaded into the freshly built model.
        input_size: Shape of one input sample, without the batch axis.

    Returns:
        The uncompiled Keras `Model`.
    """
    def conv_pair(filters, tensor):
        # Two stacked 3x3x3 'same' convolutions with ReLU activation.
        for _ in range(2):
            tensor = Conv3D(filters, (3,3,3), activation='relu',
                            padding='same')(tensor)
        return tensor

    def up_merge(tensor, skip):
        # Upsample by 2 and concatenate the skip connection on the channels.
        return concatenate([UpSampling3D(size=(2,2,2))(tensor), skip], axis=-1)

    inputs = Input(input_size)

    # Contracting path
    conv1 = conv_pair(16, inputs)
    pool1 = MaxPooling3D(pool_size=(2,2,2))(conv1)
    conv2 = conv_pair(32, pool1)
    pool2 = MaxPooling3D(pool_size=(2,2,2))(conv2)
    conv3 = conv_pair(64, pool2)
    pool3 = MaxPooling3D(pool_size=(2,2,2))(conv3)

    # Bottleneck
    conv4 = conv_pair(128, pool3)

    # Expanding path
    conv5 = conv_pair(64, up_merge(conv4, conv3))
    conv6 = conv_pair(32, up_merge(conv5, conv2))
    conv7 = conv_pair(16, up_merge(conv6, conv1))

    conv8 = Conv3D(1, (1,1,1), activation='sigmoid')(conv7)

    model = Model(inputs=[inputs], outputs=[conv8])
    model.summary()

    # Honour the `pretrained_weights` argument instead of ignoring it.
    if pretrained_weights:
        model.load_weights(pretrained_weights)
    return model

def cross_entropy_balanced(y_true, y_pred):
    """Class-balanced sigmoid cross-entropy on probability predictions.

    `beta` is the fraction of negative (zero) labels in `y_true`. Positive
    labels are weighted by `beta / (1 - beta)` and the mean cost is scaled by
    `1 - beta`. When `y_true` contains no positive label the loss is 0.

    Args:
        y_true: Ground-truth labels; cast to float32.
        y_pred: Predicted probabilities (e.g. the output of a sigmoid).

    Returns:
        A scalar loss tensor.
    """
    # Note: tf.nn.weighted_cross_entropy_with_logits expects y_pred to be
    # logits, while Keras hands over probabilities.
    # Transform y_pred back to logits.
    _epsilon = _to_tensor(tf.keras.backend.epsilon(), y_pred.dtype.base_dtype)
    y_pred   = tf.clip_by_value(y_pred, _epsilon, 1 - _epsilon)
    y_pred   = tf.math.log(y_pred / (1 - y_pred))

    y_true = tf.cast(y_true, tf.float32)

    count_neg = tf.reduce_sum(1. - y_true)
    count_pos = tf.reduce_sum(y_true)

    beta = count_neg / (count_neg + count_pos)

    # With no positive labels beta is 1 and a plain division would produce
    # inf; divide_no_nan returns 0 in that case, which keeps the cost (and
    # its gradient) finite.
    pos_weight = tf.math.divide_no_nan(beta, 1 - beta)

    cost = tf.nn.weighted_cross_entropy_with_logits(
        labels=y_true, logits=y_pred, pos_weight=pos_weight)

    cost = tf.reduce_mean(cost * (1 - beta))

    return tf.where(tf.equal(count_pos, 0.0), 0.0, cost)


def _to_tensor(x, dtype):
    """Return `x` as a tensor whose dtype is `dtype`.

    # Arguments
    x: The value to convert (numpy array, list, tensor).
    dtype: The dtype the result must have.
    # Returns
    The converted tensor; it is cast only when its dtype differs.
    """
    tensor = tf.convert_to_tensor(x)
    return tf.cast(tensor, dtype) if tensor.dtype != dtype else tensor

# Build the 3-D U-Net for single-channel volumes of arbitrary spatial size.
model = unet(input_size=(None, None, None,1))
# `learning_rate` is the current name of the optimizer argument (`lr` is
# deprecated). The accuracy metric is requested under the name 'acc' so that
# it is recorded as 'acc' / 'val_acc', the names used by the checkpoint
# monitor and the history plots.
model.compile(optimizer=Adam(learning_rate=1e-4), loss='binary_crossentropy',
                metrics=['acc'])
model.summary()

Next, we set up the training callbacks: model checkpointing and logging (the TensorBoard).

In [None]:
# checkpoint: save the model after every epoch; the epoch number is
# formatted into the file name.
filepath = 'check1/fseg-{epoch:02d}.hdf5'
# Make sure the target directory exists before the first checkpoint is written.
os.makedirs(os.path.dirname(filepath), exist_ok=True)
checkpoint = ModelCheckpoint(filepath, monitor='val_acc',
        verbose=1, save_best_only=False, mode='max')
logging = TrainValTensorBoard()
#reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
#                              patience=20, min_lr=1e-8)
callbacks_list = [checkpoint, logging]
print('Data prepared. Ready to train!')

## Fitting the Model

Finally we are ready to train.

In [None]:
# Scratch example: a tiny 2-D convolutional model, built only to show its
# summary. It is bound to its own name so that it does not replace the
# compiled 3-D U-Net in `model`, which is what the fit call below trains.
demo_model = Sequential()
demo_model.add(Conv2D(3, (4,4), input_shape=(100,100,4)))
demo_model.summary()

In [None]:
# Train for 100 epochs on the training generator, validating on the
# validation generator, then store the final model.
history = model.fit(
    train_generator,
    epochs=100,
    verbose=1,
    callbacks=callbacks_list,
    validation_data=valid_generator,
)
model.save('check1/fseg.hdf5')
print('Model saved')

# Training Results

Let's see what we have.

In [None]:
# List the names of all metrics recorded in the training history; these are
# the keys used for the plots below.
history.history.keys()

In [None]:
# summarize history for accuracy
# The accuracy metric is recorded as 'acc' or 'accuracy' depending on the
# Keras version and the name passed to `compile`; use whichever is present.
acc_key = 'acc' if 'acc' in history.history else 'accuracy'
fig = plt.figure(figsize=(10,6))
ax = fig.add_subplot(111)
ax.plot(history.history[acc_key])
ax.plot(history.history['val_' + acc_key])
# Axes objects use the set_* methods; title/xlabel/ylabel are pyplot functions.
ax.set_title('Model accuracy', fontsize=20)
ax.set_xlabel('Epoch', fontsize=20)
ax.set_ylabel('Accuracy', fontsize=20)
ax.legend(['train', 'test'], loc='center right', fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=18)
ax.tick_params(axis='both', which='minor', labelsize=18)
fig

In [None]:
# summarize history for loss
fig = plt.figure(figsize=(10,6))
ax = fig.add_subplot(111)
ax.plot(history.history['loss'])
ax.plot(history.history['val_loss'])
# Axes objects use the set_* methods; title/xlabel/ylabel are pyplot functions.
ax.set_title('Model loss',fontsize=20)
ax.set_ylabel('Loss',fontsize=20)
ax.set_xlabel('Epoch',fontsize=20)
ax.legend(['train', 'test'], loc='center right',fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=18)
ax.tick_params(axis='both', which='minor', labelsize=18)
fig