In [65]:
from keras import layers
from keras.models import Model
from keras.optimizers import Adam
from keras.metrics import top_k_categorical_accuracy
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

In [66]:
import os
import datetime
from glob import glob
from sklearn.metrics import confusion_matrix, classification_report

import nbimporter
from DataParser import generateDf, getXYfromDf

In [1]:
class WaveNetParams:
    """Bundle of WaveNet hyper-parameters, kept in one class so it is easy to import.

    Attributes:
        input_shape: per-sample shape handed to the model's input layer.
        output_shape: one-element tuple; its first entry is the channel count
            of the last convolutions, i.e. the length of the prediction vector.
        n_filters: filter count of the dilated, skip and first post-skip
            convolutions.
        kernel_size: kernel width of the dilated convolutions; also the base
            of the per-block dilation rate (``kernel_size ** block_index``).
        dilation_depth: number of residual blocks that are stacked.
        pool_size_1: kernel width of the first convolution after the skip sum.
        pool_size_2: kernel width of the following two convolutions and the
            window of the average-pooling layer.
        batch_size: mini-batch size used while fitting.
        activation: activation applied to the globally pooled output.
    """

    def __init__(self):
        """Populate every hyper-parameter with its default value."""
        # Data-facing shapes.
        self.input_shape = (196, 3)
        self.output_shape = (340,)
        # Dilated convolution stack.
        self.n_filters = 192
        self.kernel_size = 2
        self.dilation_depth = 8
        # Post-processing head.
        self.pool_size_1 = 4
        self.pool_size_2 = 8
        # Training / output.
        self.batch_size = 512
        self.activation = 'softmax'

In [3]:
def residual_block(x, i, wavenet_params):
    """Apply one gated, dilated residual block to ``x``.

    Two causal dilated convolutions (one with tanh, one with sigmoid
    activation) read the same input and are multiplied element-wise. A
    kernel-size-1 convolution of that product is the skip output, and the skip
    output added back onto ``x`` is the residual output.

    Args:
        x: input tensor. It is summed with the skip output, so it is expected
            to already carry ``wavenet_params.n_filters`` channels.
        i: block index; the dilation rate is ``kernel_size ** i``. The dilated
            convolutions are named after the dilation rate, the remaining
            layers after ``i``.
        wavenet_params: object providing ``n_filters`` and ``kernel_size``.

    Returns:
        Tuple ``(residual, skip)``: the tensor to feed into the next block and
        the tensor to collect for the skip-connection sum.
    """
    dilation = wavenet_params.kernel_size ** i

    def gated_branch(activation, name_template):
        # Both branches share every setting except activation and layer name.
        conv = layers.Conv1D(
            filters=wavenet_params.n_filters,
            kernel_size=wavenet_params.kernel_size,
            dilation_rate=dilation,
            padding='causal',
            name=name_template % dilation,
            activation=activation,
        )
        return conv(x)

    filter_branch = gated_branch('tanh', 'dilated_conv_%d_tanh')
    gate_branch = gated_branch('sigmoid', 'dilated_conv_%d_sigm')

    gated = layers.Multiply(name='gated_activation_%d' % i)([filter_branch, gate_branch])
    skip_out = layers.Conv1D(wavenet_params.n_filters, 1, name='skip_%d' % i)(gated)
    residual_out = layers.Add(name='residual_block_%d' % i)([skip_out, x])
    return residual_out, skip_out

In [4]:
def WaveNet(wavenet_params):
    """Build the WaveNet-style classifier described by ``wavenet_params``.

    Architecture, in order: an initial causal convolution, a stack of
    ``dilation_depth`` gated residual blocks whose skip outputs are summed,
    a LeakyReLU, three ReLU convolutions, average pooling, a final
    convolution, global average pooling and the output activation.

    Args:
        wavenet_params: object providing ``input_shape``, ``output_shape``,
            ``n_filters``, ``kernel_size``, ``dilation_depth``,
            ``pool_size_1``, ``pool_size_2`` and ``activation``.

    Returns:
        The uncompiled Keras ``Model``. Its summary is printed as a side effect.

    Raises:
        ValueError: if ``dilation_depth`` is smaller than 1, because there
            would be no skip connection to build the head on.
    """
    if wavenet_params.dilation_depth < 1:
        raise ValueError('dilation_depth must be >= 1, got %r'
                         % (wavenet_params.dilation_depth,))

    stroke_input = layers.Input(shape=wavenet_params.input_shape, name='featureInput')

    x = layers.Conv1D(wavenet_params.n_filters, wavenet_params.kernel_size, dilation_rate=1,
                      padding='causal', name='dilated_conv_1')(stroke_input)

    skip_connections = []
    for i in range(1, wavenet_params.dilation_depth + 1):
        x, skip = residual_block(x, i, wavenet_params)
        skip_connections.append(skip)

    # A merge layer needs at least two inputs; with a single block its skip
    # output already is the "sum".
    if len(skip_connections) > 1:
        x = layers.Add(name='skip_connections')(skip_connections)
    else:
        x = skip_connections[0]
    x = layers.LeakyReLU(alpha=0.1)(x)

    x = layers.Conv1D(wavenet_params.n_filters, wavenet_params.pool_size_1, strides=1,
                      padding='same', name='conv_5ms', activation='relu')(x)

    x = layers.Conv1D(wavenet_params.output_shape[0], wavenet_params.pool_size_2, padding='same',
                      activation='relu', name='conv_500ms')(x)

    x = layers.Conv1D(wavenet_params.output_shape[0], wavenet_params.pool_size_2, padding='same',
                      activation='relu', name='conv_500ms_target_shape')(x)

    x = layers.AveragePooling1D(wavenet_params.pool_size_2, padding='same', name='downsample_to_2Hz')(x)

    # Kernel width scales with the sequence length; floor division keeps it an
    # int and the clamp keeps it valid (>= 1) for short sequences.
    final_kernel_size = max(
        1,
        wavenet_params.input_shape[0] // (wavenet_params.pool_size_1 * wavenet_params.pool_size_2),
    )
    x = layers.Conv1D(wavenet_params.output_shape[0], final_kernel_size,
                      padding='same', name='final_conv')(x)

    x = layers.GlobalAveragePooling1D(name='final_pooling')(x)

    x = layers.Activation(wavenet_params.activation, name='final_activation')(x)

    # `inputs`/`outputs` are the documented keywords; the singular forms are
    # Keras 1 legacy spellings.
    model = Model(inputs=stroke_input, outputs=x)
    # summary() prints by itself and returns None, so it must not be wrapped in print().
    model.summary()
    return model

In [5]:
def top_3_accuracy(x, y):
    """Top-3 categorical accuracy metric.

    ``x`` and ``y`` are forwarded unchanged as the first and second argument
    of ``top_k_categorical_accuracy`` (documented by Keras as ``y_true`` and
    ``y_pred``), with ``k`` fixed to 3.
    """
    top_k = 3
    return top_k_categorical_accuracy(x, y, k=top_k)

In [6]:
def train(model, x_train, y_train, x_valid, y_valid, batch_size=None):
    """Compile ``model`` and fit it, checkpointing the best weights.

    The model is compiled with Adam (learning rate 1e-4) and categorical
    cross-entropy, and reports accuracy and top-3 accuracy. Training runs for
    at most 50 epochs with three callbacks that all monitor ``val_loss``:
    best-only checkpointing, early stopping and learning-rate reduction.

    Args:
        model: Keras model to train; it is compiled in place.
        x_train, y_train: training inputs and targets.
        x_valid, y_valid: validation inputs and targets.
        batch_size: mini-batch size. When ``None`` the default from
            ``WaveNetParams`` is used, so the function no longer depends on a
            ``wavenet_params`` variable happening to exist at module level.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    if batch_size is None:
        batch_size = WaveNetParams().batch_size

    date = datetime.datetime.today().strftime('%H_%M_%m_%d')
    weight_save_path = './model/stroke_wn_%s' % date + '.h5'
    # The checkpoint callback does not create missing directories itself.
    os.makedirs(os.path.dirname(weight_save_path), exist_ok=True)

    checkpoint = ModelCheckpoint(weight_save_path, monitor='val_loss',
                                 verbose=1, save_best_only=True, period=1)
    # NOTE(review): ReduceLROnPlateau and EarlyStopping share patience=5, so
    # training will usually stop at about the same epoch the learning rate
    # would first be lowered - consider a shorter patience for reduce_lr.
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.6,
                                  patience=5, min_lr=1e-6, mode='auto')
    early_stop = EarlyStopping(monitor='val_loss', mode='min', patience=5)
    callback = [checkpoint, early_stop, reduce_lr]
    optimizer = Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)

    model.compile(optimizer=optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy', top_3_accuracy])
    return model.fit(x_train, y_train,
                     validation_data=(x_valid, y_valid),
                     batch_size=batch_size,
                     epochs=50,
                     callbacks=callback)

In [7]:
def run(weight_path=None, base_dir='/Volumes/JS/QuickDraw'):
    """Load the data, build the WaveNet model and train it.

    Args:
        weight_path: optional path to previously saved weights; when given
            they are loaded into the model before training starts.
        base_dir: dataset root directory. The training CSV files are looked
            up in ``<base_dir>/train_simplified/*.csv``.

    Raises:
        FileNotFoundError: if no training CSV file is found under ``base_dir``.
    """
    all_train_paths = glob(os.path.join(base_dir, 'train_simplified', '*.csv'))
    if not all_train_paths:
        raise FileNotFoundError(
            'no training CSV files found in %s'
            % os.path.join(base_dir, 'train_simplified'))

    wavenet_params = WaveNetParams()

    # The sequence length requested from the parser is taken from the model
    # parameters so the two cannot drift apart.
    train_df, valid_df, _test_df, word_encoder = generateDf(
        n_train=75, n_valid=7, n_test=5,
        n_strokes=wavenet_params.input_shape[0], path=all_train_paths)
    x_train, y_train = getXYfromDf(train_df, word_encoder)
    x_valid, y_valid = getXYfromDf(valid_df, word_encoder)

    # Size the network from the arrays that were actually loaded instead of
    # trusting the hard-coded defaults.
    wavenet_params.input_shape = tuple(x_train.shape[1:])
    wavenet_params.output_shape = tuple(y_train.shape[1:])

    model = WaveNet(wavenet_params)
    if weight_path is not None:
        model.load_weights(weight_path)
    train(model, x_train, y_train, x_valid, y_valid)

In [None]:
# Entry point: load the data, build the model and start training with the default arguments.
run()



__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
featureInput (InputLayer)       (None, 196, 3)       0                                            
__________________________________________________________________________________________________
dilated_conv_1 (Conv1D)         (None, 196, 64)      448         featureInput[0][0]               
__________________________________________________________________________________________________
dilated_conv_2_tanh (Conv1D)    (None, 196, 64)      8256        dilated_conv_1[0][0]             
__________________________________________________________________________________________________
dilated_conv_2_sigm (Conv1D)    (None, 196, 64)      8256        dilated_conv_1[0][0]             
__________________________________________________________________________________________________
gated_acti

Train on 25500 samples, validate on 2380 samples
Epoch 1/50
