In [16]:
import nbimporter
from DataParser import generateDf, getXYfromDf
from WaveNetClassifier import WaveNet, WaveNetParams

In [17]:
import os
from glob import glob
from datetime import datetime
from pytz import timezone

from keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from keras.optimizers import Adam
from keras.metrics import top_k_categorical_accuracy

In [18]:
""" basic parameters """
# Root directory of the QuickDraw data. The default is the original
# external-volume location; set the QUICKDRAW_DIR environment variable to
# point the notebook at a different copy without editing this cell.
base_dir = os.environ.get('QUICKDRAW_DIR', '/Volumes/JS/QuickDraw')
# The file name must NOT start with '/': os.path.join throws away every
# earlier component when a later one is absolute, which silently produced
# '/test_simplified.csv' instead of a path inside base_dir.
test_path = os.path.join(base_dir, 'test_simplified.csv')
# glob yields files in arbitrary (filesystem) order; sort so every run hands
# the training files to the data loader in the same order.
all_train_paths = sorted(glob(os.path.join(base_dir, 'train_simplified', '*.csv')))
# Presumably the column names of the simplified CSV files -- not referenced
# by the cells visible here; verify against DataParser.
cols = ['countrycode', 'drawing', 'key_id', 'recognized', 'timestamp', 'word']

# WaveNet parameters: a WaveNetParams instance built with its constructor
# defaults. train() passes this object to WaveNet(...) and reads its
# `batch_size` attribute for model.fit.
wavenet_params = WaveNetParams()

In [19]:
def train(net_type, weight_path=None):
    """Build, compile and fit a stroke classifier, checkpointing the best weights.

    The data frames are produced by ``generateDf`` from ``all_train_paths``,
    converted to arrays with ``getXYfromDf``, and the model is trained with
    Adam on categorical cross-entropy while tracking accuracy and top-3
    accuracy. The best weights (lowest ``val_loss``) are written to
    ``./model/stroke_wn_<HH_MM_mm_dd>.h5`` (US/Eastern wall-clock time).

    Parameters
    ----------
    net_type : str
        Architecture to build. Only ``'wavenet'`` is supported.
    weight_path : str, optional
        Weights file to load into the model before training starts.
        ``None`` (default) trains from the model's initial weights.

    Returns
    -------
    The value returned by ``model.fit`` (the Keras training history).

    Raises
    ------
    ValueError
        If ``net_type`` names an unsupported architecture.
    """
    # Fail fast, before the expensive data load. Previously an unknown
    # net_type left `model` unbound and surfaced later as UnboundLocalError.
    if net_type != 'wavenet':
        raise ValueError("Unsupported net_type %r; expected 'wavenet'" % (net_type,))

    def top_3_accuracy(y_true, y_pred):
        # Keras reports the metric under this function's name.
        return top_k_categorical_accuracy(y_true, y_pred, 3)

    train_df, valid_df, test_df, word_encoder = generateDf(
        n_train=75, n_valid=7, n_test=5, n_strokes=196, path=all_train_paths)
    x_train, y_train = getXYfromDf(train_df, word_encoder)
    x_valid, y_valid = getXYfromDf(valid_df, word_encoder)
    # NOTE(review): the held-out split is converted but never used below.
    # Kept so the sequence of getXYfromDf calls is unchanged; drop it or add
    # a model.evaluate step once getXYfromDf is confirmed side-effect free.
    x_test, y_test = getXYfromDf(test_df, word_encoder)

    # Per-sample shapes: drop the leading (sample) axis.
    input_shape = x_train.shape[1:]
    output_shape = y_train.shape[1:]

    model = WaveNet(input_shape, output_shape, wavenet_params)
    if weight_path is not None:
        model.load_weights(weight_path)

    date = datetime.now(timezone('US/Eastern')).strftime('%H_%M_%m_%d')
    weight_save_path = './model/stroke_wn_%s' % date + '.h5'
    # The checkpoint callback writes into ./model; make sure it exists so the
    # first save does not fail on a fresh checkout.
    os.makedirs(os.path.dirname(weight_save_path), exist_ok=True)

    # Keep only the weights with the best validation loss seen so far.
    checkpoint = ModelCheckpoint(weight_save_path, monitor='val_loss',
                                 verbose=1, save_best_only=True, period=1)

    # Multiply the learning rate by 0.6 after one epoch without val_loss
    # improvement, never going below 1e-6.
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.6,
                                  patience=1, min_lr=1e-6, mode='auto')

    # Stop after five epochs without val_loss improvement.
    early_stop = EarlyStopping(monitor='val_loss', mode='min', patience=5)
    callback = [checkpoint, early_stop, reduce_lr]
    optimizer = Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)

    model.compile(optimizer=optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy', top_3_accuracy])

    # Return the history so callers can inspect the training curves instead
    # of having the result silently discarded.
    return model.fit(x_train, y_train,
                     validation_data=(x_valid, y_valid),
                     batch_size=wavenet_params.batch_size,
                     epochs=50,
                     callbacks=callback)

In [20]:
# Start WaveNet training from scratch (no pre-trained weights are loaded).
# The trailing ';' stops the notebook from echoing the call's return value.
train(net_type='wavenet');

  "outputs": [],


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
featureInput (InputLayer)       (None, 196, 3)       0                                            
__________________________________________________________________________________________________
dilated_conv_1 (Conv1D)         (None, 196, 64)      448         featureInput[0][0]               
__________________________________________________________________________________________________
dilated_conv_2_tanh (Conv1D)    (None, 196, 64)      8256        dilated_conv_1[0][0]             
__________________________________________________________________________________________________
dilated_conv_2_sigm (Conv1D)    (None, 196, 64)      8256        dilated_conv_1[0][0]             
__________________________________________________________________________________________________
gated_acti

Train on 25500 samples, validate on 2380 samples
Epoch 1/50


KeyboardInterrupt: 