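# Training the Models

This notebook trains the 3-D U-Net from `unet3.py` to segment faults in 128×128×128 seismic volumes. It logs training and validation metrics to TensorBoard and plots the resulting learning curves.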

In [1]:
from numpy.random import seed
seed(12345)
import tensorflow as tf
from tensorflow.random import set_seed
set_seed(1234)
import os
import random
import numpy as np
import skimage
import matplotlib.pyplot as plt
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau, TensorBoard
from keras import backend as keras
from utils import DataGenerator
from unet3 import *

## Log

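First, we define a TensorBoard callback that writes training and validation metrics to separate subdirectories, so both curves can be plotted on the same chart.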

In [2]:
class TrainValTensorBoard(TensorBoard):
    """TensorBoard callback that logs validation metrics to their own run.

    Training metrics are written by the base ``TensorBoard`` callback to
    ``<log_dir>/training``. Validation metrics (keys starting with ``val_``)
    are written to ``<log_dir>/validation`` with the prefix stripped, so that
    e.g. ``loss`` and ``val_loss`` show up as two curves on one chart. The
    optimizer's current learning rate is added to the training logs as ``lr``.

    Args:
        log_dir: Root directory for both the training and validation runs.
        **kwargs: Forwarded unchanged to ``TensorBoard.__init__``.
    """

    _VAL_PREFIX = 'val_'

    def __init__(self, log_dir='./log1', **kwargs):
        # Make the original `TensorBoard` log to a subdirectory 'training'
        training_log_dir = os.path.join(log_dir, 'training')
        super().__init__(training_log_dir, **kwargs)
        # Log the validation metrics to a separate subdirectory
        self.val_log_dir = os.path.join(log_dir, 'validation')
        # Created in `set_model`, closed in `on_train_end`. A dedicated name
        # avoids clobbering any `writer` attribute the base class may own.
        self.val_writer = None

    def set_model(self, model):
        """Create the validation writer, then let the base class attach the model."""
        self.val_writer = tf.summary.create_file_writer(self.val_log_dir)
        super().set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        """Write ``val_*`` metrics to the validation run and the rest via the base class.

        Args:
            epoch: Epoch index, used as the summary step.
            logs: Metric dict supplied by Keras (may be None).
        """
        logs = logs or {}
        prefix = self._VAL_PREFIX
        # Strip only the leading prefix so names match the training metrics.
        val_logs = {k[len(prefix):]: v for k, v in logs.items() if k.startswith(prefix)}
        with self.val_writer.as_default():
            for name, value in val_logs.items():
                tf.summary.scalar(name, value, step=epoch)
        self.val_writer.flush()
        # Pass the remaining (training) logs to `TensorBoard.on_epoch_end`
        train_logs = {k: v for k, v in logs.items() if not k.startswith(prefix)}
        # `keras` is the Keras backend module here (imported as an alias).
        train_logs['lr'] = keras.eval(self.model.optimizer.lr)
        super().on_epoch_end(epoch, train_logs)

    def on_train_end(self, logs=None):
        """Close the validation writer after the base class finishes."""
        super().on_train_end(logs)
        if self.val_writer is not None:
            self.val_writer.close()
            self.val_writer = None

Next, we point the data generators at the training and validation seismic/fault directories.

In [None]:
# Generator settings shared by both splits; `dim` is the input volume size.
params = dict(
    batch_size=1,
    dim=(128, 128, 128),
    n_channels=1,
    shuffle=True,
)

# Seismic ('seis') and fault ('fault') directories for each split.
seismPathT = './data/train/seis/'
faultPathT = './data/train/fault/'
seismPathV = './data/validation/seis/'
faultPathV = './data/validation/fault/'

# IDs handed to the generators: 0-199 for training, 0-19 for validation.
train_ID = range(200)
valid_ID = range(20)

train_generator = DataGenerator(
    dpath=seismPathT,
    fpath=faultPathT,
    data_IDs=train_ID,
    **params
)
valid_generator = DataGenerator(
    dpath=seismPathV,
    fpath=faultPathV,
    data_IDs=valid_ID,
    **params
)
train_generator

Next, we build and compile the model. The network architecture itself is defined in `unet3.py`.

In [None]:
# Spatial dimensions are left as None so the network is not tied to a single
# volume size; the last axis is the single input channel.
model = unet(input_size=(None, None, None, 1))
# `learning_rate` replaces the deprecated `lr` keyword of Keras optimizers
# (newer Keras versions reject `lr` outright).
model.compile(optimizer=Adam(learning_rate=1e-4),
              loss='binary_crossentropy',
              metrics=['accuracy'])
model.summary()

Then we set up the training callbacks: a checkpoint that saves the model after every epoch, and the TensorBoard logger defined above.

In [None]:
# checkpoint: save the model after every epoch (save_best_only=False), so
# `monitor`/`mode` are only consulted if save_best_only is switched on.
filepath = 'check1/fseg-{epoch:02d}.hdf5'
# Make sure the checkpoint directory exists before the first save.
os.makedirs(os.path.dirname(filepath), exist_ok=True)
# With metrics=['accuracy'], Keras >= 2.3 logs the validation metric as
# 'val_accuracy' (not 'val_acc').
checkpoint = ModelCheckpoint(filepath, monitor='val_accuracy',
                             verbose=1, save_best_only=False, mode='max')
# NOTE: this name shadows the stdlib `logging` module within the notebook.
logging = TrainValTensorBoard()
# Optional: append ReduceLROnPlateau(monitor='val_loss', factor=0.2,
# patience=20, min_lr=1e-8) to callbacks_list to decay the learning rate.
callbacks_list = [checkpoint, logging]
print('Data prepared. Ready to train!')

Finally, we are ready to train.

In [None]:
# Fit the model for a fixed number of epochs, validating on `valid_generator`
# and applying the callbacks collected in `callbacks_list`.
history = model.fit(
    x=train_generator,
    validation_data=valid_generator,
    epochs=100,
    callbacks=callbacks_list,
    verbose=1,
)

# Save the final model next to the per-epoch checkpoints.
final_model_path = 'check1/fseg.hdf5'
model.save(final_model_path)
print('Model saved')

Let's see which metrics were recorded during training.

In [None]:
# Names of the metrics recorded by `model.fit` (keys of History.history).
recorded = history.history
recorded.keys()

In [None]:
# summarize history for accuracy
# Keras >= 2.3 records metrics=['accuracy'] as 'accuracy'; older versions used 'acc'.
acc_key = 'accuracy' if 'accuracy' in history.history else 'acc'
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(history.history[acc_key], label='train')
ax.plot(history.history['val_' + acc_key], label='validation')
# Axes use set_* methods; `ax.title` is a Text attribute, so calling it raises.
ax.set_title('Model accuracy', fontsize=20)
ax.set_xlabel('Epoch', fontsize=20)
ax.set_ylabel('Accuracy', fontsize=20)
ax.legend(loc='center right', fontsize=20)
ax.tick_params(axis='both', which='both', labelsize=18)
fig

In [3]:
# summarize history for loss
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(history.history['loss'], label='train')
# The second curve comes from `validation_data`, so label it as validation.
ax.plot(history.history['val_loss'], label='validation')
# Axes use set_* methods; `ax.title` is a Text attribute, so calling it raises.
ax.set_title('Model loss', fontsize=20)
ax.set_xlabel('Epoch', fontsize=20)
ax.set_ylabel('Loss', fontsize=20)
ax.legend(loc='center right', fontsize=20)
ax.tick_params(axis='both', which='both', labelsize=18)
fig

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, None, None,  0                                            
__________________________________________________________________________________________________
conv3d (Conv3D)                 (None, None, None, N 448         input_1[0][0]                    
__________________________________________________________________________________________________
conv3d_1 (Conv3D)               (None, None, None, N 6928        conv3d[0][0]                     
__________________________________________________________________________________________________
max_pooling3d (MaxPooling3D)    (None, None, None, N 0           conv3d_1[0][0]                   
______________________________________________________________________________________________

KeyboardInterrupt: 