In [1]:
import warnings
import numpy as np
import gc
import json
import h5py
import threading
import time

import tensorflow as tf
from keras.utils import Sequence
from keras.utils.data_utils import get_file
from keras.preprocessing import sequence
from keras.preprocessing.sequence import _remove_long_seq
from keras.models import Sequential
from keras.layers import Embedding, Input, Dense, Dropout, CuDNNLSTM, Bidirectional
from keras import metrics
from keras import backend as K
from keras.callbacks import ModelCheckpoint, EarlyStopping
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split
from sklearn.utils.multiclass import unique_labels
import matplotlib.pyplot as plt


# TensorFlow 1.x session setup: reserve a fixed 95% share of GPU memory up
# front (no incremental growth) and register the session with Keras.
gpu_options = tf.GPUOptions(
    allow_growth=False,
    per_process_gpu_memory_fraction=0.95,
)
config = tf.ConfigProto(
    gpu_options=gpu_options,
    allow_soft_placement=True,
    intra_op_parallelism_threads=16,
    inter_op_parallelism_threads=16,
)
session = tf.Session(config=config)
K.set_session(session)


Using TensorFlow backend.


In [2]:
def get_word_index():
    """Load and return the parsed contents of ``data/ams_word_index.json``.

    The file is opened in a ``with`` block so the handle is released even
    when the JSON is malformed and ``json.load`` raises.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the file does not contain valid JSON.
    """
    with open('data/ams_word_index.json') as f:
        return json.load(f)


def load_vocab():
    """Return the decoded JSON stored in ``data/ams_word_index.json``."""
    with open('data/ams_word_index.json') as vocab_file:
        vocab = json.load(vocab_file)
    return vocab

    
def load_glove():
    """Parse ``data/glove.model.txt`` into a ``{token: vector}`` dictionary.

    Each non-empty line is split on whitespace: the first field is the key
    and the remaining fields are converted to a ``float32`` numpy array.
    Blank (or whitespace-only) lines are skipped instead of raising
    ``IndexError``.

    Returns:
        dict mapping each token string to its ``float32`` numpy vector.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a field after the token is not a valid float.
    """
    glove = {}
    with open('data/glove.model.txt') as glove_data:
        for line in glove_data:
            items = line.split()
            if not items:
                # Nothing on this line (e.g. a stray blank line) -- no token to index.
                continue
            glove[items[0]] = np.asarray(items[1:], dtype='float32')
    return glove


def build_embedding_layer(with_input=False, maxlen=480, vocab_dim=300, mask_zero=True):
    """Create a frozen Keras ``Embedding`` initialised from the GloVe vectors.

    The weight matrix has one row per vocabulary entry plus an all-zero row 0;
    every ``(word, index)`` pair from ``load_vocab()`` gets the vector that
    ``load_glove()`` holds for that word.

    Args:
        with_input: when true, also create an ``Input`` of shape ``(maxlen,)``
            and apply the embedding to it.
        maxlen: value used for ``input_length`` (and the ``Input`` shape).
        vocab_dim: width of each embedding vector (``output_dim``).
        mask_zero: forwarded to ``Embedding``.

    Returns:
        The ``Embedding`` layer when ``with_input`` is false; otherwise the
        tuple ``(embedded_tensor, input_tensor)``.

    Raises:
        KeyError: if a vocabulary word has no entry in the GloVe dictionary.
    """
    print("-- loading word embeddings, this may take a little while...")
    vocab = load_vocab()
    vectors = load_glove()
    # One extra row so that index 0 stays free (reserved for masking).
    n_symbols = len(vocab) + 1
    print("-- known dictionary items: ", n_symbols)
    embedding_weights = np.zeros((n_symbols, vocab_dim))
    for token, row in vocab.items():
        embedding_weights[row, :] = vectors[token]
    print("-- embeddings ")

    # Both return shapes use the very same layer configuration.
    layer_kwargs = dict(
        input_dim=n_symbols,
        output_dim=vocab_dim,
        input_length=maxlen,
        mask_zero=mask_zero,
        trainable=False,
        weights=[embedding_weights],
    )
    if with_input:
        # define inputs here
        input_1 = Input(shape=(maxlen,), dtype='int32')
        return (Embedding(**layer_kwargs)(input_1), input_1)
    return Embedding(**layer_kwargs)


In [3]:
# Module-level lock created alongside the HDF5 batch loader.
# NOTE(review): nothing in the visible notebook acquires it -- presumably meant
# to serialise HDF5 reads; confirm before relying on it.
data_lock = threading.Lock()

# Input data is obtained via batch loading from HDF5
class DataGenerator(Sequence):
    """Keras ``Sequence`` serving ``(X, y)`` batches read row by row from ``x_hf`` / ``y_hf``.

    The first ``total_len`` rows are split by position: the leading 90% are
    served in training mode and the trailing 10% in validation mode, so the
    two modes never share a row.
    """

    def __init__(self, batch_size=32, mode="train", dim=480, n_channels=1,
                 n_classes=49, x_hf=None, y_hf=None, shuffle=True,
                 total_len=10_000):
        """Set up the row split and the initial batch order.

        Args:
            batch_size: number of rows per batch.
            mode: ``"train"`` (or ``"training"``) for the leading 90% of rows,
                ``"validation"`` (or ``"validate"``) for the trailing 10%.
            dim: length of one sample row in ``x_hf``.
            n_channels: accepted for interface compatibility; not used.
            n_classes: number of target classes (stored, not used for batching).
            x_hf: indexable sample store; ``x_hf[i]`` yields one row of length ``dim``.
            y_hf: indexable label store; ``y_hf[i]`` yields one integer label.
            shuffle: reshuffle the batch order at the end of every epoch.
            total_len: number of leading rows to use; ``None`` means all of
                ``y_hf`` (``y_hf.shape[0]``).

        Raises:
            ValueError: if ``mode`` is not one of the recognised strings.
        """
        self.x_hf = x_hf
        self.y_hf = y_hf
        if total_len is None:
            total_len = self.y_hf.shape[0]
        self.total_len = total_len
        self.validation_len = int(np.floor(0.1 * self.total_len))
        self.training_len = self.total_len - self.validation_len
        if mode in ("validation", "validate"):
            self.data_len = self.validation_len
            # Held-out rows start right after the last training row.
            self.list_IDs = np.arange(self.training_len, self.total_len)
        elif mode in ("train", "training"):
            self.data_len = self.training_len
            self.list_IDs = np.arange(0, self.data_len)
        else:
            # A misspelled mode must not silently fall back to training data.
            raise ValueError(
                "unknown mode %r, expected 'train' or 'validation'" % (mode,))

        self.dim = dim
        self.batch_size = batch_size
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(self.data_len)
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __len__(self):
        'Denotes the number of batches per epoch'
        # Floor division: a trailing partial batch is dropped.
        return int(np.floor(self.data_len / self.batch_size))

    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples'  # X : (batch_size, dim)
        # Initialization
        X = np.empty((self.batch_size, self.dim))
        y = np.empty((self.batch_size), dtype=int)

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            # Store sample
            X[i,] = self.x_hf[ID]
            # Store class
            y[i] = self.y_hf[ID]

        return X, y

    def __getitem__(self, index):
        'Generate one batch of data'
        # Positions (within this split) that make up batch number `index`
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Translate split-relative positions into absolute row IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        X, y = self.__data_generation(list_IDs_temp)

        return X, y

# Parameters
batch_size = 128
# Read-only HDF5 handle; it stays open because the generators and the later
# evaluation cells read from its datasets.
data_hf = h5py.File("data/full_ams.hdf5", 'r')
generator_params = {
    'batch_size': batch_size,
    'n_classes': 49,
    'n_channels': 1,
    'shuffle': True,
    'x_hf': data_hf['x_train'],
    'y_hf': data_hf['y_train']
}
# Generators
training_generator = DataGenerator(**generator_params)
# "validation" is the mode string DataGenerator compares against to select
# the held-out split.
validation_generator = DataGenerator(mode="validation", **generator_params)




In [None]:
# Class weights and word embeddings.
# The weights take about 10 minutes to recompute, so they are memoized here for the v4 49-class variant.
# NOTE(review): the snippet below originally read `training_generator.hf[...]`, but the generator
# exposes `x_hf` / `y_hf`; the labels are presumably `data_hf["y_train"]` -- confirm before re-running.
# class_weights = compute_class_weight('balanced', np.unique(data_hf["y_train"]), data_hf["y_train"])
# Plain list of 49 floats; presumably entry i is the weight of class id i -- TODO confirm.
class_weights = [2.0847070797220815e-01,1.3245738287918261e+00,5.9278789584799433e+03,
 5.3721403061224491e+03,7.2654786271044491e+00,4.5239076262083781e+03,
 6.5991742723960982e+01,2.3946355262772623e+00,6.6118649921507063e+02,
 7.5508411281303633e-01,5.4401420821493154e+01,4.7866706519997315e+00,
 2.8556227540850227e+02,9.8797982641332396e+01,4.9199366305655895e-01,
 9.0956872907893319e+02,3.1291818621247236e-01,9.3251147163503312e+00,
 1.8421398392190138e+00,7.2805560645399947e-01,5.3058175862937765e+02,
 4.2977122448979593e+04,1.5628044526901669e+04,1.3863587886767609e+03,
 1.3223729984301413e+04,1.2114763199148582e+01,2.1488561224489797e+04,
 3.1209331800342460e-01,5.2093481756338897e+03,1.3730710047597313e+02,
 1.6271293120197022e-01,4.2159233322522649e+00,1.2937122952733171e+01,
 4.8153638598296460e+01,4.2977122448979593e+04,1.1444543625319112e+01,
 1.9050142929512230e+01,9.0956872907893319e+02,7.0761706510215845e+00,
 1.0108708218868796e-01,2.5918991666214608e-01,7.8886054421768703e+00,
 3.3626447938485293e-01,8.9561793958612079e-01,2.7727175773535220e+02,
 1.3122785480604455e+03,3.1097773117930242e+01,1.8288137212331742e+03,
 1.6688265185272180e-01]

# 08.2018 (subformula lexemes)
# Analysis of the arXiv dataset indicates that:
#   a maxlen of 960 is needed to fit 99.2% of the data
#   a maxlen of 480 fits 96.03%, and a maxlen of 300 covers 90.0% of paragraphs
maxlen = 480
# mask_zero=False: presumably because the CuDNNLSTM layers stacked on top of
# this embedding do not accept a mask -- TODO confirm.
embedding_layer = build_embedding_layer(maxlen=maxlen, mask_zero=False)

-- loading word embeddings, this may take a little while...


In [None]:
# BiLSTM Model Setup
n_classes = 49
layer_size = 128  # ~maxlen // 4

print("-- setting up model layout...")
use_dropout = True

# Stack: embedding -> BiLSTM(layer_size) -> BiLSTM(layer_size // 2)
#        -> LSTM(layer_size // 2) -> softmax over n_classes,
# with an optional Dropout after each of the first four layers.
model = Sequential()
model.add(embedding_layer)
if use_dropout:
    model.add(Dropout(0.2))

model.add(Bidirectional(CuDNNLSTM(layer_size, return_sequences=True)))
if use_dropout:
    model.add(Dropout(0.2))

model.add(Bidirectional(CuDNNLSTM(layer_size // 2, return_sequences=True)))
if use_dropout:
    model.add(Dropout(0.1))

# Final recurrent layer returns only its last output (no return_sequences),
# which feeds the classifier.
model.add(CuDNNLSTM(layer_size // 2))
if use_dropout:
    model.add(Dropout(0.1))

model.add(Dense(n_classes, activation='softmax'))

# Sparse loss/metric: targets are integer class ids, not one-hot vectors.
model.compile(loss='sparse_categorical_crossentropy',
              optimizer="adam",
              weighted_metrics=[metrics.sparse_categorical_accuracy])

# Print model summary. summary() prints the table itself and returns None,
# so wrapping it in print() would add a stray "None" line.
model.summary()

In [None]:
# Base name shared by every file this run writes.
model_file = "confusion_bilstm%d_batch%d_cat%d_gpu" % (layer_size, batch_size, n_classes)

# Both callbacks watch the same validation metric.
monitored_metric = 'val_weighted_sparse_categorical_accuracy'

# 1) keep only the best model seen so far, written at epoch end
checkpoint = ModelCheckpoint(
    model_file+"-checkpoint.h5",
    monitor=monitored_metric,
    mode='max',
    save_best_only=True,
    verbose=1,
)

# 2) stop once the metric has not improved by at least min_delta for `patience` epochs
earlystop = EarlyStopping(
    monitor=monitored_metric,
    mode='auto',
    min_delta=0.001,
    patience=3,
    verbose=0,
)


In [None]:
# Perform training
print('-- training model...')
# TODO: How can we make this work with more workers? HDF5 is not thread-safe for reads...
# maybe use the unpacked .txt files and map them through the dictionary each time? Unsure...
# On a single CPU worker: 2.5 hours per epoch for the full data, with the main BiLSTM model.

# Keras documents `class_weight` as a dictionary {class index: weight}; the
# memoized weights are a plain list ordered by class id, so convert them here.
class_weight_by_id = (class_weights if isinstance(class_weights, dict)
                      else dict(enumerate(class_weights)))

model.fit_generator(
    generator=training_generator,
    validation_data=validation_generator,
    workers=1,
    use_multiprocessing=True,
    class_weight=class_weight_by_id,
    epochs=5,
    verbose=1,
    callbacks=[checkpoint, earlystop])


print("-- saving model to disk : %s " % model_file)
model.save(model_file+'_notebook.h5')

In [None]:
# Predict a class id for every test sample, then report per-class metrics.
print("Per-class test measures:")
y_pred = model.predict_classes(
    data_hf['x_test'],
    batch_size=batch_size,
    verbose=1,
)
test_report = classification_report(data_hf['y_test'], y_pred)
print(test_report)

## Per-class test measures:
Training on 10,000 paragraphs for 50 epochs yielded a `val_weighted_sparse_categorical_accuracy` of `0.8613`.

Evaluating on the full test set of about 2.1 million paragraphs yields:
```
              precision    recall  f1-score   support

           0       0.80      0.85      0.82    206154
           1       0.96      0.94      0.95     32445
           2       0.00      0.00      0.00         7
           3       0.00      0.00      0.00         8
           4       0.00      0.00      0.00      5915
           5       0.00      0.00      0.00         9
           6       0.00      0.00      0.00       651
           7       0.03      0.01      0.02     17947
           8       0.00      0.00      0.00        65
           9       0.39      0.35      0.37     56916
          10       0.00      0.00      0.00       789
          11       0.00      0.00      0.00      8978
          12       0.00      0.00      0.00       150
          13       0.00      0.00      0.00       435
          14       0.00      0.00      0.00     87353
          15       0.00      0.00      0.00        47
          16       0.58      0.59      0.58    137342
          17       0.00      0.00      0.00      4608
          18       0.00      0.00      0.00     23330
          19       0.17      0.21      0.19     59029
          20       0.00      0.00      0.00        80
          21       0.00      0.00      0.00         1
          22       0.00      0.00      0.00         2
          23       0.00      0.00      0.00        30
          24       0.00      0.00      0.00         3
          25       0.00      0.00      0.00      3547
          26       0.00      0.00      0.00         1
          27       0.61      0.75      0.67    137706
          28       0.00      0.00      0.00         8
          29       0.00      0.00      0.00       313
          30       0.37      0.43      0.40    264128
          31       0.00      0.00      0.00     10193
          32       0.00      0.00      0.00      3322
          33       0.00      0.00      0.00       892
          35       0.00      0.00      0.00      3755
          36       0.00      0.00      0.00      2255
          37       0.00      0.00      0.00        47
          38       0.02      0.00      0.00      6073
          39       0.80      0.83      0.81    425149
          40       0.19      0.29      0.23    165813
          41       0.00      0.00      0.00      5448
          42       0.45      0.47      0.46    127807
          43       0.56      0.48      0.52     47985
          44       0.00      0.00      0.00       154
          45       0.00      0.00      0.00        32
          46       0.00      0.00      0.00      1382
          47       0.00      0.00      0.00        23
          48       0.35      0.33      0.34    257528

   micro avg       0.52      0.52      0.52   2105855
   macro avg       0.13      0.14      0.13   2105855
weighted avg       0.49      0.52      0.51   2105855

```

In [None]:
# Human-readable label names, kept in alphabetical order.
# Presumably position i names integer class id i -- TODO confirm against the
# label encoding used to build the HDF5 file.
class_names = sorted((
    "abstract", "acknowledgement", "affirmation", "answer", "assumption",
    "bound",
    "case", "claim", "comment", "conclusion", "condition", "conjecture",
    "constraint", "convention", "corollary", "criterion",
    "definition", "demonstration", "discussion",
    "example", "exercise", "expansion", "expectation", "experiment",
    "explanation",
    "fact",
    "hint",
    "introduction", "issue",
    "keywords",
    "lemma",
    "method",
    "notation", "note", "notice",
    "observation", "overview",
    "principle", "problem", "proof", "proposition",
    "question",
    "remark", "result", "rule",
    "solution", "step", "summary",
    "theorem",
))

def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues,
                          figsize=(80, 60)):
    """
    Print and plot the confusion matrix of ``y_pred`` against ``y_true``.

    The matrix always has one row and one column per entry of ``classes``
    (labels ``0 .. len(classes) - 1``), so the tick labels stay aligned even
    when some class never occurs in ``y_true`` or ``y_pred``. Samples whose
    label falls outside that range are not counted.

    Args:
        y_true: integer class ids of the ground truth (any array-like).
        y_pred: integer class ids predicted by the model (any array-like).
        classes: sequence of display names; position ``i`` labels class id ``i``.
        normalize: when true, divide each row by its total so a row shows
            fractions of the true class; rows without samples stay all zero.
        title: figure title; a default depending on ``normalize`` is used
            when empty.
        cmap: matplotlib colormap for the heat map.
        figsize: figure size in inches, forwarded to ``plt.subplots``.

    Returns:
        The matplotlib ``Axes`` the matrix was drawn on.
    """
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'

    # Fix the label set explicitly: without it the matrix only covers labels
    # present in the data and no longer matches `classes`.
    cm = confusion_matrix(np.asarray(y_true), np.asarray(y_pred),
                          labels=np.arange(len(classes)))
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Divide only where a row has samples; empty rows remain 0 instead of NaN.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype='float'),
                       where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")
    return ax


np.set_printoptions(precision=2)

# Draw the raw-count matrix first, then the row-normalized one.
for normalized, heading in (
        (False, 'Confusion matrix, without normalization'),
        (True, 'Normalized confusion matrix')):
    plot_confusion_matrix(data_hf['y_test'], y_pred, classes=class_names,
                          normalize=normalized, title=heading)

plt.show()