In [1]:
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
# Load the IMDB reviews dataset (subword-encoded config) as (text, label)
# pairs, together with the dataset metadata object.
dataset, info = tfds.load(
    name='imdb_reviews/subwords8k',
    as_supervised=True,
    with_info=True,
)
train_dataset = dataset['train']
test_dataset = dataset['test']

In [3]:
# Encoder attached to the dataset's 'text' feature; its vocab_size is used
# below to size the model's embedding layer.
tokenizer = info.features['text'].encoder

In [4]:
BUFFER = 10000  # shuffle buffer size (number of examples)
BATCH = 64      # examples per padded batch

# Build both pipelines from the untouched splits in `dataset` rather than from
# the current value of train_dataset/test_dataset, so that re-running this cell
# does not shuffle and batch data that has already been batched.
train_dataset = dataset['train'].shuffle(BUFFER)
# `Dataset.output_shapes` is deprecated; tf.compat.v1.data.get_output_shapes
# is the replacement and returns the same per-component shape structure
# that padded_batch expects as its padded_shapes argument.
train_dataset = train_dataset.padded_batch(
    BATCH, tf.compat.v1.data.get_output_shapes(train_dataset))

test_dataset = dataset['test']
test_dataset = test_dataset.padded_batch(
    BATCH, tf.compat.v1.data.get_output_shapes(test_dataset))

In [5]:
# Binary text classifier assembled layer by layer:
# embedding -> 1-D convolution -> global average pooling -> dense head
# ending in a single sigmoid unit.
mod = tf.keras.Sequential()
mod.add(tf.keras.layers.Embedding(
    input_dim=tokenizer.vocab_size, output_dim=64))
mod.add(tf.keras.layers.Conv1D(
    filters=128, kernel_size=5, activation='relu'))
mod.add(tf.keras.layers.GlobalAveragePooling1D())
mod.add(tf.keras.layers.Dense(units=64, activation='relu'))
mod.add(tf.keras.layers.Dense(units=1, activation='sigmoid'))

In [6]:
# Print the layer-by-layer summary (output shapes and parameter counts).
mod.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        (None, None, 64)          523840    
_________________________________________________________________
conv1d (Conv1D)              (None, None, 128)         41088     
_________________________________________________________________
global_average_pooling1d (Gl (None, 128)               0         
_________________________________________________________________
dense (Dense)                (None, 64)                8256      
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 65        
Total params: 573,249
Trainable params: 573,249
Non-trainable params: 0
_________________________________________________________________


In [8]:
# Configure training: Adam optimizer, binary cross-entropy loss (matching the
# single sigmoid output), and accuracy reported as the tracked metric.
mod.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

In [9]:
EPX = 10  # number of training epochs

# Train on the batched training split and evaluate on the batched test split
# each epoch; the returned object is kept for plotting the metric curves.
history = mod.fit(
    x=train_dataset,
    epochs=EPX,
    validation_data=test_dataset,
)

Epoch 1/10
    301/Unknown - 109s 361ms/step - loss: 0.4787 - accuracy: 0.7623

KeyboardInterrupt: 

In [None]:
def plot_graphs(history, string):
    """Plot a training metric per epoch and, if recorded, its validation twin.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch value sequences (here, the value returned by
            ``mod.fit``).
        string: Name of the metric to plot, e.g. ``'accuracy'`` or ``'loss'``.

    Raises:
        KeyError: If ``string`` is not a key of ``history.history``.
    """
    metrics = history.history
    val_key = 'val_' + string

    plt.plot(metrics[string])
    labels = [string]
    # The 'val_' series may be missing (e.g. when no validation data was
    # supplied to fit); plot only what was recorded instead of raising.
    if val_key in metrics:
        plt.plot(metrics[val_key])
        labels.append(val_key)

    plt.xlabel('Epochs')
    plt.ylabel(string)
    plt.legend(labels)
    plt.show()

In [None]:
# Training vs. validation accuracy across epochs.
plot_graphs(history=history, string='accuracy')

In [None]:
# Training vs. validation loss across epochs.
plot_graphs(history=history, string='loss')