# Transformer Encoder
Data: https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

### Import Data

In [4]:
batch_size = 32
train_ds = keras.utils.text_dataset_from_directory('data/aclImdb/train/', batch_size=batch_size)
test_ds = keras.utils.text_dataset_from_directory('data/aclImdb/test/', batch_size=batch_size)

Found 65564 files belonging to 3 classes.
Found 25000 files belonging to 2 classes.


In [6]:
for inputs, targets in train_ds:
    print(inputs.shape, inputs.dtype)
    print(targets.shape, targets.dtype)
    break

(32,) <dtype: 'string'>
(32,) <dtype: 'int32'>


In [7]:
# tokenize data
max_length = 600  # max words to use
max_tokens = 20000  # tokens to use for embedding
text_vectorisation = layers.TextVectorization(
    max_tokens=max_tokens,
    output_mode='int',
    output_sequence_length=max_length,
)

# get only raw text without labels
text_only = train_ds.map(lambda x, y: x)
text_vectorisation.adapt(text_only)

int_train_ds = train_ds.map(lambda x, y: (text_vectorisation(x), y), num_parallel_calls=4)
int_test_ds = test_ds.map(lambda x, y: (text_vectorisation(x), y), num_parallel_calls=4)

### Define Transformer

In [2]:
class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim)
        self.dense_proj = keras.Sequential(
            [layers.Dense(dense_dim, activation='relu'),
             layers.Dense(embed_dim)])
        self.norm_1 = layers.LayerNormalization()
        self.norm_2 = layers.LayerNormalization()
    
    def call(self, inputs, mask=None):
        if mask is not None:
            mask = mask[:, tf.newaxis, :]
        attention_out = self.attention(
            inputs, inputs, attention_mask=mask)
        # add residual connection to attention output
        proj_in = self.norm_1(inputs + attention_out)
        proj_out = self.dense_proj(proj_in)
        # add residual connection to projection layer
        return self.norm_2(proj_in + proj_out)
    
    def get_config(self):
        # enable saving custom layer
        config = super().get_config()
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads:': self.num_heads,
            'dense_dim': self.dense_dim,
        })
        return config

In [9]:
vocab_size = 20000
embed_dim = 256
num_heads = 2
dense_dim = 32

inputs = keras.Input(shape=(None,), dtype='int64')
x = layers.Embedding(vocab_size, embed_dim)(inputs)
x = TransformerEncoder(embed_dim, dense_dim, num_heads)(x)
x = layers.GlobalMaxPooling1D()(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation='sigmoid')(x)
model = keras.Model(inputs, outputs)

model.compile(
    optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
model.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, None)]            0         
                                                                 
 embedding_1 (Embedding)     (None, None, 256)         5120000   
                                                                 
 transformer_encoder_1 (Tran  (None, None, 256)        543776    
 sformerEncoder)                                                 
                                                                 
 global_max_pooling1d_1 (Glo  (None, 256)              0         
 balMaxPooling1D)                                                
                                                                 
 dropout_1 (Dropout)         (None, 256)               0         
                                                                 
 dense_5 (Dense)             (None, 1)                 257 

In [None]:
callbacks = [keras.callbacks.ModelCheckpoint('transformer_encoder.keras', save_best_only=True)]
model.fit(int_train_ds, validation_data=int_test_ds, epochs=10, callbacks=callbacks)

In [None]:
model = keras.models.load_model('transformer_encoder.keras', custom_objects={'TransformerEncoder': TransformerEncoder})
eval = model.evaluate(int_test_ds)
print('Accuracy:', eval[1])