In [254]:
import pandas as pd
import numpy as np
from helpers import get_data, TrainValTensorBoard

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [184]:
from keras.layers import Input, Conv2D, Activation, BatchNormalization, add, \
                         MaxPool2D, GlobalAveragePooling2D, Dense
from keras import Model

In [282]:
salt_df = pd.read_csv('../salt-identifier.csv')
salt_df.head()

Unnamed: 0,id,salt
0,575d24d81d,0
1,a266a2a9df,1
2,75efad62c1,1
3,34e51dba6a,1
4,4875705fb0,1


In [283]:
xtrain, xval, ytrain, yval, \
dtrain, dval, idtrain, idval = get_data('../tfmodel')

In [284]:
idtrain = idtrain.astype(str)
idval = idval.astype(str)

ytrain = np.array([salt_df[salt_df.id == i]['salt'].values[0] for i in idtrain])
yval = np.array([salt_df[salt_df.id == i]['salt'].values[0] for i in idval])

ytrain.shape, xtrain.shape, yval.shape, xval.shape

((3446,), (3446, 224, 224, 1), (383,), (383, 224, 224, 1))

In [286]:
def init_block(input_img):
    x = Conv2D(64, (7, 7), strides=(2, 2), padding='same')(input_img)
    x = MaxPool2D((3, 3), strides=(2, 2), padding='same')(x)
    return x

In [287]:
def encoder(input_img):
    e1 = encode_block(input_img, [64, 64], )
    e2 = encode_block(e1, [64, 128])
    e3 = encode_block(e2, [128, 256])
    e4 = encode_block(e3, [256, 512])
    return e4

In [288]:
def encode_block(input_tensor, filters, ksize=(3, 3)):
    f_in, f_out = filters
    
    x = Conv2D(f_out, (1, 1), strides=(2, 2), padding='same')(input_tensor)
    x = BatchNormalization(axis=3)(x)
    x = Activation('relu')(x)
    
    x = Conv2D(f_out, ksize, strides=(1, 1), padding='same')(x)
    x = BatchNormalization(axis=3)(x)
    
    shortcut = Conv2D(f_out, (1, 1), strides=(2, 2), padding='same')(input_tensor)
    shortcut = BatchNormalization(axis=3)(shortcut)
    
    x = add([x, shortcut])
    ec1 = Activation('relu')(x)
    
    x = Conv2D(f_out, ksize, strides=(1, 1), padding='same')(ec1)
    x = BatchNormalization(axis=3)(x)
    x = Activation('relu')(x)
    
    x = Conv2D(f_out, ksize, strides=(1, 1), padding='same')(x)
    x = BatchNormalization(axis=3)(x)
    
    x = add([x, ec1])
    x = Activation('relu')(x)
    
    return x

In [289]:
H, W, C = 224, 224, 1

# salt / no salt
classes = 1

input_img = Input((H, W, C))

x = init_block(input_img)
x = encoder(x)

x = GlobalAveragePooling2D()(x)
x = Dense(classes, activation='softmax')(x)

In [290]:
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, \
                            EarlyStopping
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from keras.losses import binary_crossentropy

In [291]:
lr = 1e-2
BATCH_SIZE = 32
EPOCHS = 1

tf_model = Model(input_img, x)
tf_model.compile(loss=binary_crossentropy, optimizer=Adam(lr), metrics=['accuracy'])

In [292]:
datagen = ImageDataGenerator(zoom_range=0.01, 
                             vertical_flip=True,
                             horizontal_flip=True)
datagen.fit(xtrain)

In [293]:
# define callbacks
lr_plat = ReduceLROnPlateau(monitor='val_dice_coef',
                               factor=0.2,
                               patience=5,
                               verbose=1,
                               min_delta=1e-4,
                               mode='max')
early_stop = EarlyStopping(monitor='val_dice_coef',
                           patience=10,
                           verbose=1,
                           min_delta=1e-4,
                           mode='max')
m_checkpoint = ModelCheckpoint(monitor='val_dice_coef',
                             filepath='tfmodel_weights.hdf5',
                             save_best_only=True,
                             mode='max')
tb = TrainValTensorBoard(write_graph=False)
callbacks = [lr_plat, early_stop, m_checkpoint, tb]

In [294]:
tf_model.fit_generator(generator=datagen.flow(xtrain, ytrain, batch_size=BATCH_SIZE),
                    steps_per_epoch=np.ceil(float(len(xtrain)) / float(BATCH_SIZE)),
                    epochs=EPOCHS,
                    verbose=1,
                    callbacks=callbacks,
                    validation_data=(xval, yval), 
                    validation_steps=np.ceil(float(len(xval)) / float(BATCH_SIZE)))

Epoch 1/1
  7/108 [>.............................] - ETA: 6:09 - loss: 6.5478 - acc: 0.5893

KeyboardInterrupt: 