In [1]:
import tensorflow as tf
from tensorflow import keras
from keras import layers
from keras import models
from keras import optimizers
from keras import losses
from keras import metrics
from keras import regularizers
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Transformer block
class TransformerBlock(layers.Layer):
    """Self-attention followed by a position-wise feed-forward network.

    Each of the two sub-layers is followed by dropout, a residual addition
    and layer normalization (normalization is applied after the residual).

    Args:
        embed_dim: Used as the per-head `key_dim` of the attention layer and
            as the output width of the feed-forward network. The last axis of
            the input is expected to have this size so the second residual
            addition lines up.
        num_heads: Number of attention heads.
        ff_dim: Width of the hidden ReLU layer of the feed-forward network.
        rate: Dropout rate used after attention and after the feed-forward
            network.
        **kwargs: Forwarded to `layers.Layer` (e.g. `name`, `dtype`).
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Remember the constructor arguments so get_config() can report them.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate

        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = keras.Sequential(
            [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs, training=None):
        """Apply the block to `inputs`.

        Args:
            inputs: Tensor used as both query and value of the attention layer.
            training: Forwarded to the two dropout layers. Defaults to None so
                the block can also be called without an explicit flag.

        Returns:
            The layer-normalized output of the second residual connection.
        """
        # Self-attention: the same tensor is passed as query and value.
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "num_heads": self.num_heads,
                "ff_dim": self.ff_dim,
                "rate": self.rate,
            }
        )
        return config
    

In [3]:
# Encoder block
class EncoderBlock(layers.Layer):
    """A `TransformerBlock` followed by one additional dropout layer.

    Args:
        embed_dim: Forwarded to `TransformerBlock`.
        num_heads: Forwarded to `TransformerBlock`.
        ff_dim: Forwarded to `TransformerBlock`.
        rate: Dropout rate; forwarded to `TransformerBlock` and also used for
            the dropout applied to its output.
        **kwargs: Forwarded to `layers.Layer` (e.g. `name`, `dtype`).
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Remember the constructor arguments so get_config() can report them.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate

        self.transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim, rate)
        self.dropout = layers.Dropout(rate)

    def call(self, inputs, training=None):
        """Run the transformer block, then dropout.

        Args:
            inputs: Tensor passed to the wrapped `TransformerBlock`.
            training: Forwarded to the transformer block and to the dropout
                layer. Defaults to None so the block can also be called
                without an explicit flag.

        Returns:
            The transformer block output after dropout.
        """
        attn_output = self.transformer_block(inputs, training=training)
        return self.dropout(attn_output, training=training)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "num_heads": self.num_heads,
                "ff_dim": self.ff_dim,
                "rate": self.rate,
            }
        )
        return config

In [4]:
# Load the IMDB dataset; `num_words` is set to `vocab_size`.
vocab_size = 20000
maxlen = 200
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(
    num_words=vocab_size
)
# Bring every sequence in both splits to length `maxlen`
# (pad_sequences' default padding/truncation settings are used).
x_train, x_val = (
    keras.preprocessing.sequence.pad_sequences(sequences, maxlen=maxlen)
    for sequences in (x_train, x_val)
)

In [5]:
# Create model: embedding -> 2 encoder blocks -> pooled classification head
embed_dim = 32  # Embedding size for each token
num_heads = 2  # Number of attention heads
ff_dim = 32  # Hidden layer size in feed forward network inside transformer

inputs = layers.Input(shape=(maxlen,))
embedding_layer = layers.Embedding(vocab_size, embed_dim)(inputs)

# Stack the encoder blocks on top of the token embeddings.
x = embedding_layer
for _ in range(2):
    x = EncoderBlock(embed_dim, num_heads, ff_dim)(x)

# Classification head: average over the sequence axis, then a small dense
# network with dropout around it. The layers are applied in the order listed.
for head_layer in (
    layers.GlobalAveragePooling1D(),
    layers.Dropout(0.1),
    layers.Dense(20, activation="relu"),
    layers.Dropout(0.1),
):
    x = head_layer(x)

# Two-way softmax output.
outputs = layers.Dense(2, activation="softmax")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

In [6]:
# Print the layer-by-layer overview of the model (output shapes and parameter counts).
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 200)]             0         
                                                                 
 embedding (Embedding)       (None, 200, 32)           640000    
                                                                 
 encoder_block (EncoderBlock  (None, 200, 32)          10656     
 )                                                               
                                                                 
 encoder_block_1 (EncoderBlo  (None, 200, 32)          10656     
 ck)                                                             
                                                                 
 global_average_pooling1d (G  (None, 32)               0         
 lobalAveragePooling1D)                                          
                                                             

In [7]:
# Train model
# Adam optimizer with sparse categorical cross-entropy; accuracy is tracked.
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Fit for 2 epochs; (x_val, y_val) is evaluated at the end of each epoch.
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=32,
    epochs=2,
    validation_data=(x_val, y_val),
)

Epoch 1/2
Epoch 2/2


In [9]:
# Model evaluation.
# NOTE: (x_val, y_val) is the same data that was passed as `validation_data`
# to `model.fit`, so these numbers are not from a separate held-out set.
test_scores = model.evaluate(x=x_val, y=y_val, verbose=2)
test_loss, test_accuracy = test_scores[0], test_scores[1]
print("Test loss:", test_loss)
print("Test accuracy:", test_accuracy)

782/782 - 35s - loss: 0.3204 - accuracy: 0.8672 - 35s/epoch - 44ms/step
Test loss: 0.3203639090061188
Test accuracy: 0.8672000169754028
