In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import MultiHeadAttention, LayerNormalization, Dropout, Layer
from tensorflow.keras.layers import Embedding, Input, GlobalAveragePooling1D, Dense
from tensorflow.keras.datasets import imdb
from tensorflow.keras.models import Sequential, Model
import numpy as np
import warnings
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)

In [None]:
class TransformerBlock(Layer):
    """Transformer encoder block: self-attention plus a position-wise feed-forward net.

    Each of the two sub-layers is followed by dropout, a residual connection
    and layer normalization (the "post-norm" arrangement used in ``call``).

    Args:
        embed_dim: Feature size of the block's inputs and outputs. The residual
            additions in ``call`` require the input's last dimension to match
            it. It is also passed as the per-head ``key_dim`` of the attention.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        rate: Dropout rate applied after attention and after the feed-forward net.
        **kwargs: Standard ``Layer`` arguments (``name``, ``dtype``, ...),
            forwarded to the base class.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Remember the constructor arguments so get_config() can rebuild the layer.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate

        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = Sequential(
            [Dense(ff_dim, activation="relu"), Dense(embed_dim)]
        )
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None):
        """Apply the encoder block to ``inputs``.

        Args:
            inputs: Sequence tensor whose last dimension is ``embed_dim``
                (assumed ``(batch, seq_len, embed_dim)`` — confirm at call site).
            training: Whether dropout is active. Defaults to ``None`` so the
                layer can also be invoked directly; Keras supplies the value
                during ``fit``/``evaluate``/``predict``.

        Returns:
            Tensor of the same shape as ``inputs``.
        """
        # Self-attention: the sequence attends to itself (query == value).
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "num_heads": self.num_heads,
                "ff_dim": self.ff_dim,
                "rate": self.rate,
            }
        )
        return config

In [None]:
class TokenAndPositionEmbedding(Layer):
    """Sum of a learned token embedding and a learned absolute-position embedding.

    Args:
        maxlen: Number of position embeddings; input sequences must not be
            longer than this, since positions ``0..seq_len-1`` are looked up
            in an embedding table of size ``maxlen``.
        vocab_size: Size of the token embedding table; token ids are expected
            to lie in ``[0, vocab_size)``.
        embed_dim: Width of both embeddings, and of the output's last dimension.
        **kwargs: Standard ``Layer`` arguments (``name``, ``dtype``, ...),
            forwarded to the base class.
    """

    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        # Remember the constructor arguments so get_config() can rebuild the layer.
        self.maxlen = maxlen
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

        self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        """Embed token ids ``x`` and add a position embedding per sequence slot.

        Args:
            x: Integer token ids (assumed ``(batch, seq_len)`` — confirm at
                call site) with ``seq_len <= maxlen``.

        Returns:
            ``token_embedding(x) + position_embedding``, with the position
            embeddings (one row per position) broadcast over the leading axes.
        """
        # Actual length of the last axis, read at run time (it may be dynamic).
        # Named seq_len to avoid shadowing the notebook-level `maxlen` constant.
        seq_len = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=seq_len, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super().get_config()
        config.update(
            {
                "maxlen": self.maxlen,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config

In [None]:
vocab_size = 20000  # Only consider the top 20k words
# Every review is later cut/padded to this many tokens by pad_sequences. Its
# default truncating="pre" drops tokens from the *start* of a long review, so
# it is the LAST 200 words of each review that are kept, not the first.
maxlen = 200

# imdb.load_data returns the (train, test) splits; the official test split is
# used here as the validation set.
(x_train, y_train), (x_val, y_val) = imdb.load_data(num_words=vocab_size)
print(len(x_train), "Training sequences")
print(len(x_val), "Validation sequences")

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
   16384/17464789 [..............................] - ETA: 5s

  720896/17464789 [>.............................] - ETA: 1s

 1916928/17464789 [==>...........................] - ETA: 0s

 2924544/17464789 [====>.........................] - ETA: 0s

 3596288/17464789 [=====>........................] - ETA: 0s

























25000 Training sequences
25000 Validation sequences


In [None]:
# Bring every review to exactly `maxlen` token ids: shorter ones are left-padded
# with 0, longer ones lose tokens from the front. These are pad_sequences'
# default padding/truncating modes, spelled out here so the choice is visible.
x_train, x_val = (
    tf.keras.preprocessing.sequence.pad_sequences(
        split, maxlen=maxlen, padding="pre", truncating="pre"
    )
    for split in (x_train, x_val)
)

In [None]:
embed_dim = 32  # Embedding size for each token
num_heads = 2  # Number of attention heads
ff_dim = 32  # Hidden layer size in feed forward network inside transformer
seed = 42  # Global TensorFlow seed so that re-running the notebook is repeatable

# Seed before any layer is created: weight initializers and dropout masks draw
# from the global TF seed, so training was previously not reproducible.
tf.random.set_seed(seed)

# The model consumes integer token ids (padded to `maxlen`); declare int32
# explicitly rather than relying on Input's float32 default.
inputs = Input(shape=(maxlen,), dtype="int32")
embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
x = embedding_layer(inputs)
transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim)
x = transformer_block(x)
# Average over the sequence axis -> one feature vector per review.
x = GlobalAveragePooling1D()(x)
x = Dropout(0.1)(x)
x = Dense(20, activation="relu")(x)
x = Dropout(0.1)(x)
# Two-way softmax; paired with sparse_categorical_crossentropy at compile time,
# so labels are expected as integer class ids (presumably 0/1 sentiment).
outputs = Dense(2, activation="softmax")(x)

model = Model(inputs=inputs, outputs=outputs)

2022-02-15 02:11:00.468018: W tensorflow/stream_executor/platform/default/dso_loader.cc:65] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/compat/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
2022-02-15 02:11:00.468077: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
2022-02-15 02:11:00.468097: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (nqiwrmsge5): /proc/driver/nvidia/version does not exist


In [None]:
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Train on the padded training split, reporting loss/accuracy on the held-out
# split at the end of each epoch; the History object is kept for inspection.
history = model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=2,
    validation_data=(x_val, y_val),
)

# Persist only the trained weights (HDF5, selected by the .h5 suffix).
model.save_weights("predict_class.h5")

2022-02-15 02:11:22.459668: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)


Epoch 1/2


  1/391 [..............................] - ETA: 15:19 - loss: 0.7387 - accuracy: 0.4844

  2/391 [..............................] - ETA: 2:35 - loss: 0.7188 - accuracy: 0.5312 

  3/391 [..............................] - ETA: 2:34 - loss: 0.7156 - accuracy: 0.5104

  4/391 [..............................] - ETA: 2:43 - loss: 0.7150 - accuracy: 0.5000

  5/391 [..............................] - ETA: 2:33 - loss: 0.7110 - accuracy: 0.4844

  6/391 [..............................] - ETA: 2:34 - loss: 0.7105 - accuracy: 0.4818

  7/391 [..............................] - ETA: 2:37 - loss: 0.7075 - accuracy: 0.4911

  8/391 [..............................] - ETA: 2:36 - loss: 0.7050 - accuracy: 0.4941

  9/391 [..............................] - ETA: 2:36 - loss: 0.7037 - accuracy: 0.4983

 10/391 [..............................] - ETA: 2:35 - loss: 0.7030 - accuracy: 0.4969

 11/391 [..............................] - ETA: 2:35 - loss: 0.7003 - accuracy: 0.5057

 12/391 [..............................] - ETA: 2:34 - loss: 0.6997 - accuracy: 0.5039

 13/391 [..............................] - ETA: 2:34 - loss: 0.6993 - accuracy: 0.5036

 14/391 [>.............................] - ETA: 2:35 - loss: 0.6985 - accuracy: 0.5056

 15/391 [>.............................] - ETA: 2:35 - loss: 0.6974 - accuracy: 0.5094

 16/391 [>.............................] - ETA: 2:34 - loss: 0.6977 - accuracy: 0.5078

 17/391 [>.............................] - ETA: 2:34 - loss: 0.6976 - accuracy: 0.5083

 18/391 [>.............................] - ETA: 2:33 - loss: 0.6966 - accuracy: 0.5113

 19/391 [>.............................] - ETA: 2:32 - loss: 0.6962 - accuracy: 0.5156

 20/391 [>.............................] - ETA: 2:34 - loss: 0.6961 - accuracy: 0.5156

 21/391 [>.............................] - ETA: 2:33 - loss: 0.6970 - accuracy: 0.5126

 22/391 [>.............................] - ETA: 2:32 - loss: 0.6975 - accuracy: 0.5099

 23/391 [>.............................] - ETA: 2:33 - loss: 0.6969 - accuracy: 0.5095

 24/391 [>.............................] - ETA: 2:32 - loss: 0.6971 - accuracy: 0.5085

 25/391 [>.............................] - ETA: 2:32 - loss: 0.6966 - accuracy: 0.5094

 26/391 [>.............................] - ETA: 2:31 - loss: 0.6965 - accuracy: 0.5096

 27/391 [=>............................] - ETA: 2:30 - loss: 0.6961 - accuracy: 0.5116

 28/391 [=>............................] - ETA: 2:30 - loss: 0.6957 - accuracy: 0.5134

 29/391 [=>............................] - ETA: 2:29 - loss: 0.6954 - accuracy: 0.5151

 30/391 [=>............................] - ETA: 2:29 - loss: 0.6953 - accuracy: 0.5167

 31/391 [=>............................] - ETA: 2:28 - loss: 0.6950 - accuracy: 0.5181

 32/391 [=>............................] - ETA: 2:29 - loss: 0.6947 - accuracy: 0.5205

 33/391 [=>............................] - ETA: 2:28 - loss: 0.6949 - accuracy: 0.5199

 34/391 [=>............................] - ETA: 2:27 - loss: 0.6945 - accuracy: 0.5239

 35/391 [=>............................] - ETA: 2:27 - loss: 0.6940 - accuracy: 0.5286

 36/391 [=>............................] - ETA: 2:27 - loss: 0.6937 - accuracy: 0.5295

 37/391 [=>............................] - ETA: 2:27 - loss: 0.6929 - accuracy: 0.5351

 38/391 [=>............................] - ETA: 2:26 - loss: 0.6930 - accuracy: 0.5337

 39/391 [=>............................] - ETA: 2:26 - loss: 0.6923 - accuracy: 0.5381

 40/391 [==>...........................] - ETA: 2:26 - loss: 0.6919 - accuracy: 0.5395

 41/391 [==>...........................] - ETA: 2:25 - loss: 0.6915 - accuracy: 0.5396

 42/391 [==>...........................] - ETA: 2:25 - loss: 0.6913 - accuracy: 0.5391

 43/391 [==>...........................] - ETA: 2:25 - loss: 0.6908 - accuracy: 0.5407

 44/391 [==>...........................] - ETA: 2:24 - loss: 0.6906 - accuracy: 0.5426

 45/391 [==>...........................] - ETA: 2:23 - loss: 0.6900 - accuracy: 0.5451

 46/391 [==>...........................] - ETA: 2:23 - loss: 0.6896 - accuracy: 0.5462

 47/391 [==>...........................] - ETA: 2:22 - loss: 0.6894 - accuracy: 0.5469

 48/391 [==>...........................] - ETA: 2:22 - loss: 0.6893 - accuracy: 0.5472

In [None]:
# Score the trained model on the held-out split; `results` holds one value per
# entry in model.metrics_names (loss first, then the compiled metrics).
results = model.evaluate(x_val, y_val, verbose=2)

for name, value in zip(model.metrics_names, results):
    print(f"{name}: {value:.3f}")