In [None]:
%matplotlib widget
import time
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
# Vocabulary size: load_data keeps only the NUM_WORDS most frequent words.
NUM_WORDS = 10000
(train_data, train_lbl), (test_data, test_lbl) = keras.datasets.imdb.load_data(num_words=NUM_WORDS)
# Reverse lookup table: integer word id -> word string.
index = {v: k for k, v in keras.datasets.imdb.get_word_index().items()}


def decode_review(seq):
    """Return the text of an encoded review, e.g. ``decode_review(train_data[0])``.

    load_data shifts word ids by 3 (its default ``index_from``), so 3 is
    subtracted before the lookup; ids with no word are shown as "?".
    """
    return " ".join(index.get(i - 3, "?") for i in seq)

In [None]:
# Seed the global RNGs before building the model so weight initialisation is
# repeatable across Restart & Run All (GPU kernels may still add some
# nondeterminism).
np.random.seed(42)
tf.random.set_seed(42)

# Two 16-unit ReLU hidden layers and a single sigmoid unit for binary sentiment.
mdl = keras.Sequential([
    layers.Dense(16, activation="relu"),
    layers.Dense(16, activation="relu"),
    # Forcing float32 on the output layer keeps the probabilities at full
    # precision under a mixed-precision policy; no such policy is set in the
    # visible cells, so here it is effectively a no-op.
    layers.Dense(1, activation="sigmoid", dtype="float32"),
])
mdl.compile(optimizer="rmsprop", loss="binary_crossentropy", metrics=["accuracy"])

In [None]:
def vec_seq(seqs, dims=10000, dtype=float):
    """Multi-hot encode a collection of integer sequences.

    Args:
        seqs: sized collection of sequences of integer token ids, each id in
            ``[0, dims)``.
        dims: width of each output row (vocabulary size).
        dtype: element type of the result. The default (``float``, i.e.
            float64) matches the previous behaviour; ``"float32"`` halves the
            memory footprint.

    Returns:
        ndarray of shape ``(len(seqs), dims)`` where ``res[i, j] == 1.0`` iff
        token ``j`` occurs in ``seqs[i]`` and ``0.0`` otherwise.

    Raises:
        IndexError: if a token id is ``>= dims``.
    """
    res = np.zeros((len(seqs), dims), dtype=dtype)
    for i, seq in enumerate(seqs):
        # Fancy indexing sets every listed column of row i in one vectorised
        # assignment (duplicates are harmless), replacing the per-token loop.
        res[i, list(seq)] = 1.
    return res

# Multi-hot encode both splits and convert the labels to float32 arrays.
x_train, x_test = vec_seq(train_data), vec_seq(test_data)
y_train = np.array(train_lbl, dtype="float32")
y_test = np.array(test_lbl, dtype="float32")

# The first 10000 training reviews become the validation set; the remainder
# is what the model is actually fitted on.
x_val, part_x_train = x_train[:10000], x_train[10000:]
y_val, part_y_train = y_train[:10000], y_train[10000:]

In [None]:
# NOTE: fit() continues from the current weights of `mdl`. Re-run the model
# cell first to train from scratch; re-running only this cell keeps training.
start = time.perf_counter()  # monotonic clock, unaffected by wall-clock changes
history = mdl.fit(
    part_x_train, part_y_train,
    epochs=20, batch_size=512, verbose=0,
    validation_data=(x_val, y_val),
)
print("Done in {:.1f}s".format(time.perf_counter() - start))
# 1-based epoch numbers, used as the x-axis of the learning-curve plots.
epochs = range(1, len(history.history["loss"]) + 1)

In [None]:
# evaluate() yields [loss, accuracy], following the loss and metrics given to compile().
res = mdl.evaluate(x_test, y_test)
print(f"Loss: {res[0]:.3f}  Acc: {res[1]:.3f}")

In [None]:
def _plot_history(hist, metric, short, title, ylabel, num):
    """Draw training (dots) vs validation (line) curves for one History metric.

    Args:
        hist: the ``History.history`` dict returned by ``fit``.
        metric: key of the training series; the validation series is read
            from ``"val_" + metric``.
        short: word used in the legend labels (e.g. "loss", "acc").
        title: figure title.
        ylabel: y-axis label.
        num: matplotlib figure number to create or reuse.

    Returns:
        The matplotlib Figure that was drawn.
    """
    # clear=True wipes an existing figure with this number, so re-running the
    # cell redraws instead of stacking duplicate lines and legend entries.
    fig, ax = plt.subplots(num=num, clear=True)
    xs = range(1, len(hist[metric]) + 1)  # 1-based epoch numbers
    ax.plot(xs, hist[metric], "bo", label=f"Training {short}")
    ax.plot(xs, hist["val_" + metric], "b", label=f"Validation {short}")
    ax.set_title(title)
    ax.set_xlabel("Epochs")
    ax.set_ylabel(ylabel)
    ax.legend()
    plt.show()
    return fig


plot1 = _plot_history(history.history, "loss", "loss",
                      "Training and validation loss", "Loss", num=1)
plot2 = _plot_history(history.history, "accuracy", "acc",
                      "Training and validation accuracy", "Accuracy", num=2)