In [None]:
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

%matplotlib inline

In [2]:
# Load the IMDB reviews pre-tokenized with the 8k subword vocabulary, as
# (text, label) pairs, together with the dataset metadata (which holds the encoder).
dataset, info = tfds.load('imdb_reviews/subwords8k',
                          as_supervised=True,
                          with_info=True)
train = dataset['train']
test = dataset['test']

In [3]:
# The subword encoder attached to the 'text' feature; its vocab_size sizes the Embedding.
tokenizer = (
    info.features['text']
    .encoder
)

In [4]:
# Shuffle-buffer size and mini-batch size for the tf.data input pipeline.
BUFFER, BATCH = 10_000, 64

In [5]:
# Build the batched pipelines from the raw splits rather than from `train`/`test`
# themselves, so re-running this cell never re-shuffles/re-batches data that is
# already batched (the original reassignment was not idempotent).
# `padded_batch` without explicit shapes (TF >= 2.2) pads every dimension to the
# longest element in each batch; `Dataset.output_shapes` does not exist on TF2
# datasets and raised AttributeError.
train = dataset['train'].shuffle(BUFFER).padded_batch(BATCH)
test = dataset['test'].padded_batch(BATCH)

In [6]:
# Subword embeddings -> two stacked bidirectional LSTMs -> dense head ending in a
# single sigmoid unit for binary sentiment classification.
# mask_zero=True makes the Embedding emit a mask so both LSTMs skip the zero
# padding appended by padded_batch instead of treating it as real tokens
# (TFDS text encoders reserve id 0 for padding). Parameter count is unchanged.
mod = tf.keras.Sequential([
    tf.keras.layers.Embedding(tokenizer.vocab_size, 64, mask_zero=True),
    # return_sequences=True so the second recurrent layer receives every timestep.
    tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(64, return_sequences=True)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')])
mod.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        (None, None, 64)          523840    
_________________________________________________________________
bidirectional (Bidirectional (None, None, 128)         66048     
_________________________________________________________________
bidirectional_1 (Bidirection (None, 64)                41216     
_________________________________________________________________
dense (Dense)                (None, 64)                4160      
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 65        
Total params: 635,329
Trainable params: 635,329
Non-trainable params: 0
_________________________________________________________________


In [8]:
# Binary cross-entropy matches the single sigmoid output unit.
mod.compile(optimizer='adam',
            loss='binary_crossentropy',
            metrics=['accuracy'])

In [9]:
# Number of training epochs passed to fit().
EPX: int = 10

In [None]:
# Train for EPX epochs; validation metrics (val_*) are computed on the test split after each epoch.
history = mod.fit(
    train,
    validation_data=test,
    epochs=EPX,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
 50/391 [==>...........................] - ETA: 37:08 - loss: 0.2586 - accuracy: 0.9053

In [None]:
def plot_graphs(history, string, ax=None):
    """Plot a training metric per epoch and, if recorded, its validation counterpart.

    Args:
        history: the object returned by ``Model.fit``; its ``.history`` dict maps
            metric names to lists of per-epoch values.
        string: metric name to plot, e.g. ``'accuracy'`` or ``'loss'``.
        ax: optional matplotlib Axes to draw on. When omitted, a new figure and
            axes are created, so repeated calls do not overlay each other.

    Returns:
        The Axes the curves were drawn on.

    Raises:
        KeyError: if ``string`` was not recorded in ``history.history``.
    """
    if ax is None:
        _, ax = plt.subplots()
    metrics = history.history
    val_key = 'val_' + string
    ax.plot(metrics[string], label=string)
    # `val_*` keys exist only when fit() was given validation data.
    if val_key in metrics:
        ax.plot(metrics[val_key], label=val_key)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(string)
    ax.set_title(string + ' per epoch')
    ax.legend()
    return ax

In [None]:
# Training vs. validation accuracy across epochs.
plot_graphs(history, string='accuracy');

In [None]:
# Training vs. validation loss across epochs.
plot_graphs(history, string='loss');