# Text classification with an RNN

## Setup

In [13]:
import os

# Set the native log level before TensorFlow is imported so that it is already
# in place when the library initialises.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import absl.logging
# `tf` is used just below; importing it here keeps this cell runnable on a
# fresh kernel (it was previously only imported by the following cell).
import tensorflow as tf

# Only show errors from the Python-side TensorFlow and absl loggers.
tf.get_logger().setLevel('ERROR')
absl.logging.set_verbosity(absl.logging.ERROR)
absl.logging.set_stderrthreshold(absl.logging.ERROR)

In [14]:
import warnings
warnings.filterwarnings('ignore')

import tensorflow as tf

In [15]:
import numpy as np

import tensorflow_datasets as tfds

from tensorflow.keras.layers import LSTM, Dense, Embedding, Input, Layer, AdditiveAttention, Attention

import matplotlib.pyplot as plt
tfds.disable_progress_bar()

In [16]:
# Load IMDB reviews as (text, label) pairs together with the dataset metadata.
dataset, info = tfds.load(
    'imdb_reviews',
    with_info=True,
    as_supervised=True,
)

train_dataset = dataset['train']
test_dataset = dataset['test']

In [17]:
# Size of the shuffle buffer applied to the training split.
BUFFER_SIZE = 10000
# Number of reviews per batch for both the training and test pipelines.
BATCH_SIZE = 64

In [18]:
# Training data is shuffled before batching; both splits are prefetched so
# input preparation can overlap with model execution.
train_dataset = (
    train_dataset
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
test_dataset = (
    test_dataset
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

## Create the text encoder

The raw text loaded by `tfds` needs to be processed before it can be used in a model. The simplest way to process text for training is using the `TextVectorization` layer. This layer has many capabilities, but this tutorial sticks to the default behavior.

Create the layer, and pass the dataset's text to the layer's `.adapt` method:

In [19]:
# Upper bound on the vocabulary learned by the text encoder.
VOCAB_SIZE = 1000

encoder = tf.keras.layers.TextVectorization(max_tokens=VOCAB_SIZE)
# Build the vocabulary from the review text only; the labels are dropped.
encoder.adapt(train_dataset.map(lambda review, _label: review))

The `.adapt` method sets the layer's vocabulary, which can be inspected with `encoder.get_vocabulary()`. After the padding and unknown tokens, the entries are sorted by frequency.

# Dot Product Attention LSTM

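Scaled dot-product attention was introduced by Vaswani et al. in "Attention Is All You Need" (2017). Note that `tf.keras.layers.Attention`, used below, computes plain (Luong-style) dot-product scores: it omits the $1/\sqrt{d_k}$ factor unless `use_scale=True`, which adds a learned scalar instead.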

\begin{align}
    \operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V
\end{align}

In [20]:
def create_attention_lstm(embedding_dim=64, lstm_units=32, learning_rate=1e-4):
    """Build and compile a bidirectional-LSTM classifier with dot-product self-attention.

    The model takes raw strings, tokenises them with the notebook-level
    ``encoder`` and embeds them using ``VOCAB_SIZE`` as the embedding input
    dimension, so both must be defined before this function is called.

    Args:
        embedding_dim: Size of each token embedding vector.
        lstm_units: Units per LSTM direction (the bidirectional output is
            twice this size).
        learning_rate: Learning rate passed to the Adam optimizer.

    Returns:
        A compiled ``tf.keras.Model`` with a single sigmoid output, trained
        with binary cross-entropy and reporting accuracy.
    """
    inputs = tf.keras.layers.Input(shape=(1,), dtype=tf.string)
    token_ids = encoder(inputs)
    embedded = tf.keras.layers.Embedding(
        input_dim=VOCAB_SIZE, output_dim=embedding_dim, mask_zero=True)(token_ids)

    # Keep the full sequence of hidden states so attention can look at every step.
    states = tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(lstm_units, return_sequences=True))(embedded)

    # Self-attention: the LSTM states serve as both query and value. The layer
    # is named so its outputs (including the scores) can be fetched later with
    # `model.get_layer('attention')`.
    context, _scores = tf.keras.layers.Attention(name='attention')(
        [states, states], return_attention_scores=True)

    pooled = tf.keras.layers.GlobalAveragePooling1D()(context)

    out = tf.keras.layers.Dense(1, activation="sigmoid")(pooled)
    # Echo the output shape as a quick sanity check when the model is built.
    print(out.shape)

    model = tf.keras.Model(inputs=inputs, outputs=[out])
    model.compile(loss=tf.keras.losses.BinaryCrossentropy(),
                  optimizer=tf.keras.optimizers.Adam(learning_rate),
                  metrics=['accuracy'])

    return model
    

In [21]:
# Build the compiled model with its default hyperparameters and list its layers.
lstm_attention = create_attention_lstm()
lstm_attention.summary()

(None, 1)
Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_3 (InputLayer)        [(None, 1)]                  0         []                            
                                                                                                  
 text_vectorization_1 (Text  (None, None)                 0         ['input_3[0][0]']             
 Vectorization)                                                                                   
                                                                                                  
 embedding_2 (Embedding)     (None, None, 64)             64000     ['text_vectorization_1[0][0]']
                                                                                                  
 bidirectional_2 (Bidirecti  (None, None, 64)             24832     ['embedding_2[

In [22]:
# `train_dataset` is already batched by the input pipeline, so `batch_size`
# is not passed here: Keras expects it to be omitted for dataset inputs.
h = lstm_attention.fit(train_dataset, verbose=1, epochs=3)

Epoch 1/3


W0000 00:00:1714982686.261091   76463 op_level_cost_estimator.cc:699] Error in PredictCost() for the op: op: "Softmax" attr { key: "T" value { type: DT_FLOAT } } inputs { dtype: DT_FLOAT shape { unknown_rank: true } } device { type: "CPU" vendor: "GenuineIntel" model: "110" frequency: 2793 num_cores: 38 environment { key: "cpu_instruction_set" value: "AVX SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2" } environment { key: "eigen" value: "3.4.90" } l1_cache_size: 32768 l2_cache_size: 4194304 l3_cache_size: 16777216 memory_size: 268435456 } outputs { dtype: DT_FLOAT shape { unknown_rank: true } }


Epoch 2/3
Epoch 3/3


In [23]:
# The model was compiled with one loss and one metric (accuracy), so
# `evaluate` yields exactly these two values.
test_loss, test_acc = lstm_attention.evaluate(test_dataset)

print(f'Test Loss: {test_loss}')
print(f'Test Accuracy: {test_acc}')

W0000 00:00:1714985941.383123   76463 op_level_cost_estimator.cc:699] Error in PredictCost() for the op: op: "Softmax" attr { key: "T" value { type: DT_FLOAT } } inputs { dtype: DT_FLOAT shape { unknown_rank: true } } device { type: "CPU" vendor: "GenuineIntel" model: "110" frequency: 2793 num_cores: 38 environment { key: "cpu_instruction_set" value: "AVX SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2" } environment { key: "eigen" value: "3.4.90" } l1_cache_size: 32768 l2_cache_size: 4194304 l3_cache_size: 16777216 memory_size: 268435456 } outputs { dtype: DT_FLOAT shape { unknown_rank: true } }


Test Loss: 0.4088391065597534
Test Accuracy: 0.8254799842834473


### Extract attention weights

In [24]:
# Wrap the trained graph in a second model that also exposes the outputs of
# the layer named 'attention', alongside the classifier's prediction.
attention_layer = lstm_attention.get_layer('attention')
attention = tf.keras.Model(
    inputs=lstm_attention.inputs,
    outputs=[lstm_attention.outputs[0], attention_layer.output],
)
attention.summary()

Model: "model_3"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_3 (InputLayer)        [(None, 1)]                  0         []                            
                                                                                                  
 text_vectorization_1 (Text  (None, None)                 0         ['input_3[0][0]']             
 Vectorization)                                                                                   
                                                                                                  
 embedding_2 (Embedding)     (None, None, 64)             64000     ['text_vectorization_1[0][0]']
                                                                                                  
 bidirectional_2 (Bidirecti  (None, None, 64)             24832     ['embedding_2[0][0]']   

In [25]:
from IPython.display import HTML, display

In [26]:
def colorize_words(text, weights, alpha=0.5):
    """Render `text` as HTML with each word shaded by its attention weight.

    Args:
        text: Review text; it is split on whitespace into words.
        weights: Array of attention scores with at least two dimensions. It is
            min-max normalised as a whole and then averaged over axis 1, which
            must leave one value per token.
            NOTE(review): presumably shaped (1, query, key) as returned by the
            `attention` model for a single review -- confirm.
        alpha: Opacity written into the `rgba(...)` background of each word.

    Returns:
        A string of `<span>` elements, one per word that has a weight. Words
        without a weight are emitted as plain, HTML-escaped text.
    """
    import html
    import string

    words = text.split()
    if not words:
        return ""

    weights = np.asarray(weights)
    low, high = np.min(weights), np.max(weights)
    spread = high - low
    if spread > 0:
        weights = (weights - low) / spread
    else:
        # All weights equal (e.g. a single token): avoid dividing by zero.
        weights = np.zeros_like(weights)
    # reshape(-1) instead of squeeze so a single token still yields a 1-D array.
    token_weights = np.mean(weights, axis=1).reshape(-1)

    # Keep only the RGB channels of the colormap; opacity comes from `alpha`.
    rgb = np.asarray(plt.cm.Reds(token_weights))[:, :3]

    template = '<span style="background-color: rgba({}, {}, {}, {}); color: black;">{} </span>'

    pieces = []
    token_idx = 0
    for word in words:
        # Escape so markup inside the text (e.g. "<br") is shown, not parsed.
        safe_word = html.escape(word)
        # Assumes the weights come from a TextVectorization layer with its
        # default standardisation, which strips ASCII punctuation before
        # splitting: a punctuation-only word such as "/>" yields no token and
        # therefore has no weight. Skipping it keeps later words aligned.
        has_token = bool(word.strip(string.punctuation))
        if not has_token or token_idx >= len(rgb):
            pieces.append(safe_word + " ")
            continue
        red, green, blue = rgb[token_idx]
        pieces.append(template.format(
            int(red * 255), int(green * 255), int(blue * 255), alpha, safe_word))
        token_idx += 1

    return "".join(pieces)

def display_pred(sample_text, label):
    """Print the model score and true label for one review and return its HTML.

    Runs the notebook-level `attention` model on `sample_text`, prints the
    first predicted value followed by `label`, and returns the text coloured
    by `colorize_words` using the attention scores from the same forward pass.
    """
    batch = np.array([sample_text])
    # The model yields the prediction plus the attention layer's two outputs;
    # only the second of those (the scores) is needed for colouring.
    prediction, (_context, attention_scores) = attention.predict(batch)

    print(prediction[0][0])
    print(f'actual label {label}')

    return colorize_words(sample_text, attention_scores)

In [46]:
# Visualise the first review of each of ten training batches.
for text, label in train_dataset.take(10):
    sample = text[0].numpy().decode('utf-8')
    label = label[0].numpy()

    display(HTML(display_pred(sample, label)))

    print('\n')
    print('-' * 80)

0.45837817
actual label 1




--------------------------------------------------------------------------------
0.27770874
actual label 0




--------------------------------------------------------------------------------
0.047667842
actual label 0




--------------------------------------------------------------------------------
0.6430563
actual label 1




--------------------------------------------------------------------------------
0.44752303
actual label 0




--------------------------------------------------------------------------------
0.49723545
actual label 0




--------------------------------------------------------------------------------
0.05513845
actual label 0




--------------------------------------------------------------------------------
0.94696605
actual label 1




--------------------------------------------------------------------------------
0.9304744
actual label 1




--------------------------------------------------------------------------------
0.9575539
actual label 0




--------------------------------------------------------------------------------
