In [None]:
import pandas as pd
import keras
from keras.layers import LSTM, Input, Bidirectional, TimeDistributed, Dropout, Dense, Activation
from keras.models import Model
from keras.utils.np_utils import to_categorical
import tensorflow as tf
import numpy as np
import os
import fnmatch
from matplotlib import pyplot as plt
from IPython.display import clear_output

## Model

In [None]:
def get_model(input_shape, output_shape, lr=0.00001, path=None):
    """Build and compile the dense + bidirectional-LSTM sequence classifier.

    The network applies dropout to the input, runs it through three dense
    ReLU layers (4096, 2048 and 2048 units, with a second dropout after the
    first one), then a bidirectional LSTM that returns the full sequence,
    and finally a softmax layer wrapped in ``TimeDistributed``.

    Args:
        input_shape: Shape of a single sample, passed to ``Input(shape=...)``.
        output_shape: Number of units of the final softmax layer.
        lr: Learning rate given to ``tf.train.AdamOptimizer``.
        path: Optional path of a weights file to load into the compiled
            model. With ``None`` (the default) no weights are loaded.

    Returns:
        The compiled Keras ``Model``.
    """
    inputs = Input(shape=input_shape)
    X = Dropout(0.5)(inputs)
    X = Dense(4096, activation='relu')(X)
    X = Dropout(0.5)(X)
    X = Dense(2048, activation='relu')(X)
    X = Dense(2048, activation='relu')(X)
    # return_sequences=True keeps one output per timestep for the softmax below.
    X = Bidirectional(LSTM(1024, return_sequences=True))(X)
    outputs = TimeDistributed(Dense(output_shape, activation='softmax'))(X)

    model = Model(inputs=inputs, outputs=outputs)

    model.compile(
        optimizer=tf.train.AdamOptimizer(lr),
        loss='categorical_crossentropy',
        metrics=['categorical_accuracy'],
    )

    # Identity check: `!= None` goes through __ne__, which is not a reliable
    # "was a path given?" test for arbitrary objects.
    if path is not None:
        model.load_weights(path)

    return model

## Callbacks

In [None]:
class PlotLosses(keras.callbacks.Callback):
    """Keras callback that plots loss and accuracy curves after every epoch.

    After each epoch the cell output is cleared, a two-panel figure
    (loss on top, accuracy below) is drawn from the history recorded so far,
    and the metrics of every finished epoch are printed again underneath.
    """

    def on_train_begin(self, logs=None):
        """Reset all recorded history at the start of a training run."""
        self.i = 0              # number of epochs recorded so far
        self.x = []             # x-axis values (0-based epoch counter)
        self.losses = []
        self.acc = []
        self.val_losses = []
        self.val_acc = []

        self.logs = []          # raw logs dict of every epoch

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's metrics, then redraw the curves and the table.

        Args:
            epoch: Epoch index supplied by Keras (unused; ``self.i`` is used
                as the counter instead).
            logs: Metrics dict for the epoch. Missing keys are recorded as
                ``None``.
        """
        # Avoid a shared mutable default argument.
        if logs is None:
            logs = {}

        f, (ax1, ax2) = plt.subplots(2, sharex=True, sharey=False)
        self.ax1 = ax1
        self.ax2 = ax2

        self.logs.append(logs)
        self.x.append(self.i)
        self.losses.append(logs.get('loss'))
        self.acc.append(logs.get('categorical_accuracy'))
        self.val_losses.append(logs.get('val_loss'))
        self.val_acc.append(logs.get('val_categorical_accuracy'))
        self.i += 1

        clear_output(wait=True)
        self.ax1.plot(self.x, self.losses, label="loss")
        self.ax1.plot(self.x, self.val_losses, label="val loss")
        self.ax1.legend()
        self.ax2.plot(self.x, self.acc, label="accuracy")
        self.ax2.plot(self.x, self.val_acc, label="val accuracy")
        self.ax2.legend()
        plt.show()
        # A new figure is created every epoch; close it explicitly so figures
        # do not accumulate over a long run on backends that keep them open.
        plt.close(f)

        # The output was cleared above, so reprint the whole history.
        for i in range(self.i):
            print('Epoch ' + str(i+1))
            print('-----------------------')
            print('- Loss:', self.losses[i])
            print('- Accuracy:', self.acc[i])
            print('- Validation loss:', self.val_losses[i])
            print('- Validation accuracy:', self.val_acc[i])
            print(' ')
            
# Save the model weights after every epoch; the file name records the epoch
# number and the validation loss.
checkpoint = keras.callbacks.ModelCheckpoint(
    './convnet_weights/lstm/weights.{epoch:02d}-{val_loss:.2f}.hdf5',
    monitor='val_loss',
    verbose=0,
    save_best_only=False,
    save_weights_only=True,
    # A loss is better when it is lower. The mode is only consulted when
    # save_best_only=True, but 'max' would then keep the worst model.
    mode='min',
    period=1,
)

## Data

In [None]:
def generator(path, mode, num_batches):
    """Endlessly yield ``(features, classes)`` batches loaded from ``.npy`` files.

    Batch ``i`` is read from ``<path>features_<mode>_<i>.npy`` and
    ``<path>classes_<mode>_<i>.npy``. Batches are visited in a random order
    that is redrawn each time all ``num_batches`` batches have been served.

    Args:
        path: Prefix prepended to every file name (plain string concatenation).
        mode: Split name embedded in the file names, e.g. ``'train'``.
        num_batches: Number of batch files available for this split.

    Yields:
        Tuple ``(features, classes)`` of the arrays loaded for one batch.
    """
    def reshuffle():
        # Fresh random visiting order over batch ids 0 .. num_batches - 1.
        return np.random.permutation(list(range(num_batches)))

    position = 0
    order = reshuffle()

    while True:
        # Current pass exhausted: start over with a new random order.
        if position >= num_batches:
            position = 0
            order = reshuffle()

        # Both files of a batch share everything after the kind prefix.
        file_tail = mode + '_' + str(order[position]) + '.npy'
        batch_features = np.load(path + 'features_' + file_tail)
        batch_classes = np.load(path + 'classes_' + file_tail)

        position += 1
        yield batch_features, batch_classes
        
def get_metadata(path):
    """Derive model and training dimensions from the batch files under ``path``.

    Args:
        path: Prefix/directory holding the ``.npy`` batch files. It is both
            concatenated with file names and passed to ``os.listdir``.

    Returns:
        Tuple ``(input_shape, num_classes, steps_per_epoch, validation_steps)``:
        the shape of ``features_train_0.npy`` without its first axis, the size
        of the third axis of ``classes_train_0.npy``, and the number of
        directory entries matching ``*features_train_*`` and
        ``*features_dev_*`` respectively.
    """
    first_features = np.load(path + 'features_train_0.npy')
    first_classes = np.load(path + 'classes_train_0.npy')

    def count_files(pattern):
        # Number of directory entries whose name matches the glob pattern.
        return len(fnmatch.filter(os.listdir(path), pattern))

    return (
        first_features.shape[1:],
        first_classes.shape[2],
        count_files('*features_train_*'),
        count_files('*features_dev_*'),
    )

## Training

In [None]:
# Read the first training batch and count the batch files on disk to get the
# dimensions needed to build the model and size the training loop.
input_shape, num_classes, steps_per_epoch, validation_steps = get_metadata('./features/lstm/2_steps/')

In [None]:
# Show the derived dimensions so they can be checked before training starts.
print('Input shape:', input_shape)
print('Number of classes:', num_classes)
print('Steps per epoch:', steps_per_epoch) 
print('Validation steps:', validation_steps)

In [None]:
# Build the network and load weights from a previously saved checkpoint.
# NOTE(review): the weights path is hard-coded; presumably this cell fails on
# a fresh checkout where that file does not exist — confirm before Run All.
model = get_model(
    input_shape, 
    num_classes, 
    lr=0.00001,
    path='./convnet_weights/lstm/weights.23-1.23.hdf5',
)
# Live plotting callback used during training.
plot_losses = PlotLosses()

In [None]:
# Print the layer-by-layer summary of the model built above.
model.summary()

In [None]:
# Train for 200 epochs from the on-disk batch generators, validating on the
# 'dev' split after each epoch, with live plots and per-epoch checkpoints.
# NOTE(review): the model was loaded from what looks like an epoch-23
# checkpoint, but no initial_epoch is passed, so epoch numbering (and the
# checkpoint file names) start again from 1 — confirm this is intended.
model.fit_generator(
    generator('./features/lstm/2_steps/', 'train', steps_per_epoch),
    steps_per_epoch=steps_per_epoch,
    epochs=200,
    validation_data=generator('./features/lstm/2_steps/', 'dev', validation_steps),
    validation_steps=validation_steps,
    callbacks=[plot_losses, checkpoint],
)