# RNN Training
This notebook trains a recurrent neural network (RNN) for text classification and generates predictions for the Kaggle competition [Quora Insincere Questions Classification](https://www.kaggle.com/c/quora-insincere-questions-classification).

Preprocessing uses the Keras tokenizer and pre-trained GloVe word embeddings. Keras with the TensorFlow backend is then used to train the network, which in its current form is a single GRU layer followed by a sigmoid output.

Make sure `train.csv` and `test.csv` are in the `data/` directory of this project, and that the GloVe vectors are at `embeddings/glove.840B.300d/glove.840B.300d.txt`.

In [3]:
import pandas as pd
import numpy as np
import os
from tqdm import tqdm_notebook as tqdm
from time import sleep

from sklearn.model_selection import train_test_split

from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.layers import Dense, Embedding, GRU, LSTM
from keras.models import Sequential

%load_ext autoreload
%autoreload 2


In [8]:
# Read in training and testing data
train_df = pd.read_csv('data/train.csv')
train_df.head()

|   | qid | question_text | target |
|---|-----|---------------|--------|
| 0 | 00002165364db923c7e6 | How did Quebec nationalists see their province... | 0 |
| 1 | 000032939017120e6e44 | Do you have an adopted dog, how would you enco... | 0 |
| 2 | 0000412ca6e4628ce2cf | Why does velocity affect time? Does velocity a... | 0 |
| 3 | 000042bf85aa498cd78e | How did Otto von Guericke used the Magdeburg h... | 0 |
| 4 | 0000455dfa3e01eae3af | Can I convert montra helicon D to a mountain b... | 0 |


In [9]:
# Check for rows containing null values (the output shows there are none)
train_df[train_df.isnull().any(axis=1)].shape

(0, 3)
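The cell above only inspects rows with missing values; it does not remove any. Had there been some, `DataFrame.dropna` is one way to drop them. A minimal sketch on a made-up frame (the column names mirror `train.csv`, the rows are invented):

```python
import numpy as np
import pandas as pd

# Toy frame with the same columns as train.csv; one question is missing
df = pd.DataFrame({
    'qid': ['a1', 'b2', 'c3'],
    'question_text': ['Why is the sky blue?', np.nan, 'How do magnets work?'],
    'target': [0, 1, 0],
})

print(df[df.isnull().any(axis=1)].shape)  # (1, 3): one row has a null

# Drop rows whose question text is missing and renumber the index
clean = df.dropna(subset=['question_text']).reset_index(drop=True)
print(clean.shape)  # (2, 3)
```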

In [10]:
# Extract the training data and corresponding labels
text = train_df['question_text'].values
labels = train_df['target'].values

# Split into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(text, labels,
                                                  test_size=0.2)
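A float `test_size` is converted to a row count by rounding up, with the remaining rows going to the training set. The sizes in the training log further down (1,044,897 training and 261,225 validation samples) add up to 1,306,122 rows, and the short calculation below reproduces that 80/20 split. `split_sizes` is a hypothetical helper written for illustration; it mirrors the rounding rule but is not part of scikit-learn.

```python
import math

def split_sizes(n_samples, test_size):
    """Hypothetical helper: (train, test) sizes for a float test_size,
    rounding the test size up as train_test_split does."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test

n_train, n_val = split_sizes(1044897 + 261225, test_size=0.2)
print(n_train, n_val)  # 1044897 261225
```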

In [11]:
embed_size = 300    # dimensionality of each word vector (matches glove.840B.300d)
max_words = 30000   # vocabulary cap: number of most frequent words the tokenizer should keep
maxlen = 100        # every question is padded or truncated to this many tokens

In [12]:
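# Tokenize the questions, keeping only the most frequent words
# (Keras keeps word indices 1 .. max_words - 1; index 0 is reserved for padding)
tokenizer = Tokenizer(num_words=max_words)
tokenizer.fit_on_texts(list(X_train))
X_train = tokenizer.texts_to_sequences(X_train)
X_val = tokenizer.texts_to_sequences(X_val)

# Pad (or truncate) every question to exactly maxlen tokens
X_train = pad_sequences(X_train, maxlen=maxlen)
X_val = pad_sequences(X_val, maxlen=maxlen)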
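Conceptually, `fit_on_texts` ranks words by frequency (index 1 is the most common word and 0 is reserved for padding), `texts_to_sequences` maps each word to its index and, when `num_words` is set, drops words whose index is `num_words` or higher, and `pad_sequences` by default pads and truncates at the front so that each sequence keeps its last `maxlen` tokens. The pure-Python sketch below imitates that behavior on toy sentences. It lowercases and splits on whitespace only, skipping the Keras punctuation filter, so treat it as an illustration rather than a reimplementation; all function names are hypothetical.

```python
from collections import Counter

def fit_word_index(texts):
    """Rank words by frequency; index 0 is left free for padding."""
    counts = Counter(w for t in texts for w in t.lower().split())
    ranked = sorted(counts, key=lambda w: -counts[w])  # stable sort keeps ties in first-seen order
    return {w: i + 1 for i, w in enumerate(ranked)}

def to_sequences(texts, word_index, num_words=None):
    """Map words to indices, dropping unknown words and, if num_words
    is given, any word whose index is >= num_words."""
    seqs = []
    for t in texts:
        idx = [word_index[w] for w in t.lower().split() if w in word_index]
        if num_words is not None:
            idx = [i for i in idx if i < num_words]
        seqs.append(idx)
    return seqs

def pad_pre(seqs, maxlen):
    """Left-pad with zeros and keep only the last maxlen tokens."""
    return [[0] * (maxlen - len(s)) + s[-maxlen:] for s in seqs]

toy_train = ['why is the sky blue', 'why is water wet', 'is the sky big']
wi = fit_word_index(toy_train)          # 'is' -> 1, 'why' -> 2, 'the' -> 3, 'sky' -> 4, ...
seqs = to_sequences(['why is the sky green'], wi, num_words=4)
print(pad_pre(seqs, maxlen=6))          # [[0, 0, 0, 2, 1, 3]]
```

Here "green" is unknown and "sky" (index 4) falls outside `num_words=4`, so only three tokens survive before padding.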

In [13]:
word_index = tokenizer.word_index
print('The word index consists of {} unique tokens.'.format(len(word_index)))

The word index consists of 196065 unique tokens.


In [4]:
# Build a dictionary mapping each GloVe token to its embed_size-dimensional vector
embedding_dict = {}
filename = os.path.join('./embeddings/', 'glove.840B.300d/glove.840B.300d.txt')
with tqdm(total=os.path.getsize(filename), unit='B', unit_scale=True) as pbar:
    with open(filename, encoding='utf8') as f:
        for line in f:
            pbar.update(len(line.encode('utf8')))  # advance the bar by bytes read
            values = line.rstrip().split(' ')
            # A few tokens contain spaces, so the vector is the last
            # embed_size fields and everything before it is the token
            token = ' '.join(values[:-embed_size])
            try:
                embedding_dict[token] = np.asarray(values[-embed_size:], dtype='float32')
            except ValueError:
                pass  # skip malformed lines
print('The embedding dictionary has {} items'.format(len(embedding_dict)))

KeyboardInterrupt: 
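The GloVe file holds far more tokens than the tokenizer vocabulary, yet `embedding_dict` stores every one of them even though only words in `word_index` are ever looked up. If memory is the constraint, one option is to keep only vocabulary words while parsing. `load_vocab_vectors` is a hypothetical helper, and the demo feeds it a tiny in-memory file with 3-dimensional vectors instead of the real 300-dimensional one.

```python
import io
import numpy as np

def load_vocab_vectors(lines, vocab, dim):
    """Parse GloVe-format lines ('token v1 ... v_dim'), keeping only
    tokens present in vocab. The last `dim` fields are the vector, so
    tokens that themselves contain spaces are handled too."""
    vectors = {}
    for line in lines:
        parts = line.rstrip().split(' ')
        token = ' '.join(parts[:-dim])
        if token not in vocab:
            continue
        try:
            vectors[token] = np.asarray(parts[-dim:], dtype='float32')
        except ValueError:
            pass  # skip malformed lines
    return vectors

fake_file = io.StringIO(
    'the 0.1 0.2 0.3\n'
    'sky 0.4 0.5 0.6\n'
    'zebra 0.7 0.8 0.9\n'
    '. . . 1.0 1.1 1.2\n'
)
toy_vocab = {'the': 1, 'sky': 2, '. . .': 3}
vecs = load_vocab_vectors(fake_file, toy_vocab, dim=3)
print(sorted(vecs))  # ['. . .', 'sky', 'the']
```

On the real data the call would look like `load_vocab_vectors(f, word_index, embed_size)` with `f` opened as in the cell above.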

In [47]:
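# Row i of embed_mat holds the GloVe vector for the word with index i in
# word_index; row 0 (padding) and words without a GloVe vector stay all zeros
embed_mat = np.zeros(shape=[len(word_index)+1, embed_size])
for word, idx in word_index.items():
    vector = embedding_dict.get(word)
    if vector is not None:
        embed_mat[idx] = vector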
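Rows for words without a GloVe vector remain all zeros, so it can be worth checking how much of the vocabulary the embeddings cover before training. A minimal sketch of that check; `embedding_coverage` is a hypothetical helper, and `toy_index` and `toy_vectors` are made-up stand-ins for `word_index` and `embedding_dict`.

```python
import numpy as np

def embedding_coverage(index, vectors):
    """Return the fraction of vocabulary words that have a pretrained
    vector, plus the list of words that are missing."""
    missing = [w for w in index if w not in vectors]
    return 1 - len(missing) / len(index), missing

toy_index = {'why': 1, 'is': 2, 'quora': 3, 'brexit': 4}
toy_vectors = {'why': np.ones(3), 'is': np.ones(3), 'quora': np.ones(3)}
coverage, missing = embedding_coverage(toy_index, toy_vectors)
print(coverage, missing)  # 0.75 ['brexit']
```

On the real data this would be `embedding_coverage(word_index, embedding_dict)`. Since the Keras tokenizer lowercases by default while the 840B GloVe vectors are cased, some misses may come from capitalization rather than truly unknown words.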

In [48]:
print(embed_mat.shape)


(196467, 300)
[   0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    3   17  109
    5  384  187  168  131 5450  302  455    5   25 3236   13  253  380
  109 1077]
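`np.zeros` defaults to `float64`, so the matrix whose shape is printed above uses 8 bytes per entry, while the GloVe vectors were parsed as `float32`. The arithmetic below uses that printed shape to show that passing `dtype='float32'` to `np.zeros` would halve the footprint without losing precision relative to the source vectors.

```python
import numpy as np

print(np.zeros((2, 2)).dtype)  # float64 is the default

rows, dim = 196467, 300  # shape printed above
for dtype in (np.float64, np.float32):
    mb = rows * dim * np.dtype(dtype).itemsize / 1e6
    print(np.dtype(dtype).name, round(mb, 1), 'MB')
# float64 471.5 MB
# float32 235.8 MB
```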


In [49]:
embed_layer = Embedding(len(word_index) + 1,
                        embed_size,
                        weights=[embed_mat],
                        input_length=maxlen,
                        trainable=False)
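A frozen `Embedding` layer initialized with `weights=[embed_mat]` behaves like a row lookup into the matrix: each integer in a padded sequence selects one row. The NumPy sketch below shows the resulting shapes with a tiny stand-in matrix (5 words, 4 dimensions) rather than the real 300-dimensional one.

```python
import numpy as np

toy_mat = np.arange(20, dtype='float32').reshape(5, 4)  # 5 words x 4 dims
toy_mat[0] = 0.0  # row 0 is the padding index

batch = np.array([[0, 0, 3, 1],   # two padded sequences of length 4
                  [0, 2, 4, 4]])
embedded = toy_mat[batch]         # fancy indexing acts as the embedding lookup
print(embedded.shape)             # (2, 4, 4): (batch, maxlen, embed_size)
```

With the real matrix and `batch_size=128`, each batch would come out as shape `(128, 100, 300)`.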

In [50]:
def create_rnn():
    """Build and compile the classifier: frozen GloVe embeddings,
    GRU(500), then a sigmoid unit giving P(question is insincere)."""
    model_rnn = Sequential()
    model_rnn.add(embed_layer)
    model_rnn.add(GRU(500, dropout=0.2, recurrent_dropout=0.2))
    model_rnn.add(Dense(1, activation='sigmoid'))
    model_rnn.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model_rnn
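For reference, a single GRU time step can be written out in NumPy. This follows one common formulation: the reset gate scales the previous state before the candidate's recurrent matrix, and the update gate interpolates between the old state and the candidate. Keras's exact gate ordering, bias layout, and `reset_after` option vary by version, so this is a sketch of the idea rather than a reproduction of `keras.layers.GRU`. Weights are random, and the sizes are shrunk from 300 inputs and 500 units to 4 and 3.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h, W, U, b):
    """One GRU step. W: (3, n_in, n_units), U: (3, n_units, n_units),
    b: (3, n_units); index 0 = update gate, 1 = reset gate, 2 = candidate."""
    z = sigmoid(x @ W[0] + h @ U[0] + b[0])             # update gate
    r = sigmoid(x @ W[1] + h @ U[1] + b[1])             # reset gate
    h_cand = np.tanh(x @ W[2] + (r * h) @ U[2] + b[2])  # candidate state
    return z * h + (1.0 - z) * h_cand                   # blend old and new state

rng = np.random.default_rng(0)
n_in, n_units, steps = 4, 3, 5
W = rng.normal(size=(3, n_in, n_units))
U = rng.normal(size=(3, n_units, n_units))
b = np.zeros((3, n_units))

h = np.zeros(n_units)
for x in rng.normal(size=(steps, n_in)):  # run over a short input sequence
    h = gru_step(x, h, W, U, b)
print(h.shape)  # (3,): the final state, which is what Dense(1) receives
```

Because each new state is a convex blend of the previous state and a `tanh` output, every entry stays strictly inside (-1, 1).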

In [51]:
lstm = create_rnn()
lstm.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=1, batch_size=128)

Train on 1044897 samples, validate on 261225 samples
Epoch 1/1
  11136/1044897 [..............................] - ETA: 5:39:26 - loss: 0.2045 - acc: 0.9301

KeyboardInterrupt: 
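The training log is enough to estimate the cost of a full epoch. With `batch_size=128`, one pass over the 1,044,897 training samples takes ⌈1,044,897 / 128⌉ batches, and the run was interrupted after 11,136 samples with an ETA of roughly 5 h 39 min remaining. The arithmetic below only restates figures from the log; it is not a new measurement.

```python
import math

n_train, batch_size = 1044897, 128
steps = math.ceil(n_train / batch_size)
print(steps)                 # 8164 batches per epoch
print(11136 // batch_size)   # 87 batches completed before the interrupt
```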

In [29]:
# Number of positive (insincere, target == 1) questions in the validation set
print(int(np.sum(y_val == 1)))
print(934*2)

113
1868
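The count above suggests the positive class is rare, in which case accuracy alone is a weak signal: a model that always predicts 0 already scores the fraction of negatives. A minimal sketch of that baseline in NumPy; `majority_baseline_accuracy` is a hypothetical helper and the labels (6 positives in 100) are invented rather than taken from `y_val`.

```python
import numpy as np

def majority_baseline_accuracy(y):
    """Accuracy of always predicting the most frequent class."""
    y = np.asarray(y)
    positives = np.count_nonzero(y == 1)
    return max(positives, len(y) - positives) / len(y)

toy_labels = np.array([1] * 6 + [0] * 94)
print(majority_baseline_accuracy(toy_labels))  # 0.94
```

Calling `majority_baseline_accuracy(y_val)` gives the number the accuracy in the training log should be compared against.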


# Predictions
The remainder of this notebook generates predictions for the test set and writes them to a submission CSV file.

In [None]:
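test_df = pd.read_csv('data/test.csv')
X_test = test_df['question_text'].values

# Apply the tokenizer fitted on the training data, then pad to the same length
X_test = tokenizer.texts_to_sequences(X_test)
X_test = pad_sequences(X_test, maxlen=maxlen)

# Predict probabilities and threshold at 0.5 to get binary labels
preds = lstm.predict(X_test, batch_size=1024)
submission = pd.DataFrame({'qid': test_df['qid'],
                           'prediction': (preds.ravel() > 0.5).astype(int)})
submission.to_csv('data/bmmidei_NB_Submission_1.csv', index=False)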
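Turning the sigmoid outputs into the 0/1 labels a submission needs requires a threshold. 0.5 is the obvious default, but since the competition scored submissions with F1 on the insincere class, a common alternative is to choose the threshold that maximizes F1 on the validation set. A NumPy sketch; `f1_at` and `best_threshold` are hypothetical helpers, and the labels and probabilities are invented.

```python
import numpy as np

def f1_at(y_true, probs, threshold):
    """F1 score of the positive class when predicting probs > threshold."""
    pred = probs > threshold
    tp = np.sum(pred & (y_true == 1))
    fp = np.sum(pred & (y_true == 0))
    fn = np.sum(~pred & (y_true == 1))
    return 0.0 if tp == 0 else 2 * tp / (2 * tp + fp + fn)

def best_threshold(y_true, probs, grid=np.arange(0.1, 0.91, 0.05)):
    """Return the (threshold, F1) pair with the highest F1 on the grid;
    ties go to the lowest threshold."""
    scores = [(t, f1_at(y_true, probs, t)) for t in grid]
    return max(scores, key=lambda s: s[1])

y_toy = np.array([0, 0, 0, 1, 1, 1])
p_toy = np.array([0.12, 0.23, 0.37, 0.33, 0.48, 0.62])
t, f1 = best_threshold(y_toy, p_toy)
print(round(t, 2), f1)  # best threshold 0.25 with F1 = 6/7 on this toy data
```

On the real data this would be `best_threshold(y_val, lstm.predict(X_val).ravel())`, and the chosen threshold would then be used to binarize the test predictions.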