In [None]:
import numpy as np
import pandas as pd
import keras
from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers import Dense, Embedding, SpatialDropout1D
from keras.layers import LSTM
from keras.datasets import imdb, reuters
from keras.utils import to_categorical
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Hyperparameters shared by the cells below:
#   max_words   - vocabulary cap (dataset loader's num_words, Embedding input size)
#   line_length - target sequence length used when padding the articles
#   batch_size  - mini-batch size used when fitting the model
max_words, line_length, batch_size = 20_000, 200, 32

In [None]:
# Load the Reuters newswire train/test splits, limiting the vocabulary through
# num_words=max_words.
train_split, test_split = reuters.load_data(num_words=max_words)
x_train, y_train = train_split
x_test, y_test = test_split

In [None]:
# how many training articles (and matching labels) were loaded
np.shape(x_train), np.shape(y_train)

In [None]:
# distinct category labels that occur in the training set
category_ids = np.unique(y_train)
category_ids

In [None]:
# Vocabulary lookups: word -> id as provided by the dataset, and the inverse
# id -> word table built from it.
word_to_id = reuters.get_word_index()
id_to_word = dict(zip(word_to_id.values(), word_to_id.keys()))

In [None]:
# Register a readable placeholder token under id 0 for use when decoding articles.
id_to_word.update({0: '<START>'})

In [None]:
def article(index, sequences=None, vocab=None):
    """Decode one encoded article back into a space-separated string.

    The encoded ids are shifted by 3 relative to the word index (the code
    subtracts 3 before the lookup), which leaves the low ids reserved.
    Per the Keras ``reuters.load_data`` defaults these are 0 = padding,
    1 = start-of-sequence and 2 = out-of-vocabulary. The previous
    ``max(0, idx - 3)`` clamp rendered all of them as ``'<START>'``; here they
    are decoded individually and padding is dropped.

    Args:
        index: Position of the article inside ``sequences``.
        sequences: Encoded articles to read from. Defaults to the notebook's
            global ``x_train``.
        vocab: Mapping from word-index id to word. Defaults to the notebook's
            global ``id_to_word``.

    Returns:
        The decoded text, with ``'<START>'`` for id 1 and ``'<UNK>'`` for
        out-of-vocabulary or otherwise unknown ids.
    """
    if sequences is None:
        sequences = x_train
    if vocab is None:
        vocab = id_to_word

    offset = 3  # shift between encoded ids and word-index ids
    reserved = {1: '<START>', 2: '<UNK>'}

    words = []
    for idx in sequences[index]:
        idx = int(idx)
        if idx == 0:
            # padding carries no content (appears once the data has been padded)
            continue
        if idx <= offset:
            words.append(reserved.get(idx, '<UNK>'))
        else:
            words.append(vocab.get(idx - offset, '<UNK>'))
    return ' '.join(words)

In [None]:
# Print the first few decoded training articles, separated by blank lines.
preview_count = 10
for sample_idx in range(preview_count):
    decoded = article(sample_idx)
    print(decoded, '\n\n\n')

In [None]:
# Bring every article to exactly line_length tokens so all rows have the same length.
x_train, x_test = (
    sequence.pad_sequences(split, maxlen=line_length)
    for split in (x_train, x_test)
)

In [None]:
# One-hot encode the labels.
# - Only encode while the labels are still 1-D integer ids, so re-running this
#   cell does not one-hot encode an already one-hot matrix.
# - Use a single class count for both splits; letting to_categorical infer it
#   per split would give the two matrices different widths whenever one split
#   lacks the highest label.
if np.ndim(y_train) == 1:
    num_classes = int(max(np.max(y_train), np.max(y_test))) + 1
    y_train = to_categorical(y_train, num_classes=num_classes)
    y_test = to_categorical(y_test, num_classes=num_classes)

In [None]:
# shapes of the padded inputs and the encoded labels
np.shape(x_train), np.shape(y_train)

In [None]:
# Text classifier: token embedding -> single LSTM layer -> softmax over the classes.
model = Sequential([
    # one 128-dimensional vector per vocabulary id
    Embedding(max_words, 128),
    # dropout is applied to both the inputs and the recurrent state
    LSTM(128, dropout=0.5, recurrent_dropout=0.5),
    # 46 output units, one per class -- presumably the number of Reuters
    # categories; confirm against np.unique(y_train)
    Dense(46, activation='softmax'),
])
model.summary()

In [None]:
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# ------------------------------------------------------------
# Raise `epochs` below to train longer and reach higher accuracy.
# ------------------------------------------------------------
# NOTE: the test split doubles as the validation data here, so the validation
# figures are not an independent hold-out estimate.
history = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=2,
    validation_data=(x_test, y_test),
)

In [None]:
# Plot the training loss (left axis) and the train/validation accuracy
# (right axis) for every epoch recorded in `history`.
history_metrics = history.history

# Keras releases differ in the name under which accuracy is logged:
# older ones use 'acc'/'val_acc', newer ones report metrics=['accuracy'] as
# 'accuracy'/'val_accuracy'. Pick whichever this run produced.
acc_key = 'acc' if 'acc' in history_metrics else 'accuracy'
val_acc_key = 'val_' + acc_key

fig, ax1 = plt.subplots(1, 1, figsize=(12, 7))
ax1.plot(history.epoch, history_metrics['loss'], marker='^', color='purple', label='loss')
ax1.set_xlabel('epochs')
ax1.set_ylabel('loss', color='purple')
ax1.tick_params('y', colors='purple')

# Both accuracy curves share one secondary axis; plotting through the axes
# object (rather than plt.plot) avoids depending on the implicit current axes.
ax2 = ax1.twinx()
ax2.plot(history.epoch, history_metrics[acc_key], marker='+', color='green', label='train')
ax2.plot(history.epoch, history_metrics[val_acc_key], marker='*', color='red', label='validation')
ax2.set_ylim(0, 1)
ax2.set_ylabel('accuracy')

# Kept as an alias so any later code that refers to ax3 still works.
ax3 = ax2

fig.suptitle('classifying reuters articles');
fig.legend();