# 6. Transformer models for text processing

In [1]:
import keras
from keras import layers
from keras import ops
import numpy as np
import tensorflow as tf

In [2]:
batch_size = 32

train_ds = keras.utils.text_dataset_from_directory( 
    '../../aclImdb/train/', 
    validation_split=0.2, 
    subset="training", 
    seed=123,
    batch_size=batch_size)

val_ds = keras.utils.text_dataset_from_directory( 
    '../../aclImdb/train/', 
    validation_split=0.2, 
    subset="validation", 
    seed=123, # same seed as above!
    batch_size=batch_size)

test_ds = keras.utils.text_dataset_from_directory( 
    '../../aclImdb/test/', 
    batch_size=batch_size)

Found 25000 files belonging to 2 classes.
Using 20000 files for training.
Found 25000 files belonging to 2 classes.
Using 5000 files for validation.
Found 25000 files belonging to 2 classes.


In [3]:
vocab_size = 10000 # maximum number of tokens in vocabulary 
sequence_length = 250 # maximum length of sequences

vectorization_layer = layers.TextVectorization( 
    max_tokens=vocab_size, 
    output_mode='int',
    output_sequence_length=sequence_length,
)

train_texts = train_ds.map(lambda x, y: x) 
vectorization_layer.adapt(train_texts)

train_ds_int = train_ds.map(lambda x, y: (vectorization_layer(x), y)) 
val_ds_int = val_ds.map(lambda x, y: (vectorization_layer(x), y))
test_ds_int = test_ds.map(lambda x, y: (vectorization_layer(x), y))

In [4]:
for (x, y) in train_ds_int.take(1):
    print(x.shape, y.shape)
    break

(32, 250) (32,)


In [5]:
class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        if mask is not None:
            padding_mask = ops.cast(mask[:, None, :], dtype="int32")
        else:
            padding_mask = None

        attention_output = self.attention(
            query=inputs, value=inputs, key=inputs, attention_mask=padding_mask
        )
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "dense_dim": self.dense_dim,
                "num_heads": self.num_heads,
            }
        )
        return config


class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, inputs):
        length = ops.shape(inputs)[-1]
        positions = ops.arange(0, length, 1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        if mask is None:
            return None
        else:
            return ops.not_equal(inputs, 0)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config

In [9]:
embed_dim = 32
dense_dim = 32
num_heads = 2

inputs = keras.Input(shape=(None,), dtype="int64")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(inputs)
x = TransformerEncoder(embed_dim, dense_dim, num_heads)(x)
x = layers.GlobalAveragePooling1D()(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation='sigmoid')(x)

model = keras.Model(inputs=inputs, outputs=outputs)

model.compile(optimizer='rmsprop', 
              loss='binary_crossentropy', 
              metrics=['accuracy'])

model.summary()

In [10]:
history = model.fit(
    train_ds_int, 
    batch_size=batch_size, 
    epochs=2, 
    validation_data=(val_ds_int)
)


Epoch 1/2
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 26ms/step - accuracy: 0.6188 - loss: 0.6566 - val_accuracy: 0.8178 - val_loss: 0.4402
Epoch 2/2
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 26ms/step - accuracy: 0.8255 - loss: 0.4028 - val_accuracy: 0.8604 - val_loss: 0.3395


In [11]:
model.evaluate(test_ds_int)

[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 13ms/step - accuracy: 0.8464 - loss: 0.3573


[0.3544078469276428, 0.8489999771118164]

In [12]:
def reweight_distribution(original_distribution, temperature=0.5):
    distribution = np.log(original_distribution) / temperature
    distribution = np.exp(distribution)
    return distribution / np.sum(distribution)

In [15]:
dataset = keras.utils.text_dataset_from_directory(
    directory="../../aclImdb/train/", label_mode=None, batch_size=256)
dataset = dataset.map(lambda x: tf.strings.regex_replace(x, "<br />", " "))

Found 25000 files.


In [40]:
sequence_length = 100
vocab_size = 10000
text_vectorization = layers.TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length,
)
text_vectorization.adapt(dataset)

In [41]:
def prepare_lm_dataset(text_batch):
    vectorized_sequences = text_vectorization(text_batch)
    x = vectorized_sequences[:, :-1]
    y = vectorized_sequences[:, 1:]
    return x, y

lm_dataset = dataset.map(prepare_lm_dataset)

In [18]:
class TransformerDecoder(layers.Layer):
    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(latent_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, mask=None):
        causal_mask = self.get_causal_attention_mask(inputs)
        if mask is not None:
            padding_mask = ops.cast(mask[:, None, :], dtype="int32")
            padding_mask = ops.minimum(padding_mask, causal_mask)
        else:
            padding_mask = None

        attention_output_1 = self.attention_1(
            query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(out_2 + proj_output)

    def get_causal_attention_mask(self, inputs):
        input_shape = ops.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = ops.arange(sequence_length)[:, None]
        j = ops.arange(sequence_length)
        mask = ops.cast(i >= j, dtype="int32")
        mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = ops.concatenate(
            [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
            axis=0,
        )
        return ops.tile(mask, mult)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "latent_dim": self.latent_dim,
                "num_heads": self.num_heads,
            }
        )
        return config

In [43]:
embed_dim = 32
latent_dim = 64
num_heads = 2

inputs = keras.Input(shape=(None,), dtype="int64")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(inputs)
x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, x)
outputs = layers.Dense(vocab_size, activation="softmax")(x)

model = keras.Model(inputs, outputs)

model.compile(loss="sparse_categorical_crossentropy", optimizer="rmsprop")

In [44]:
model.summary()

In [49]:
model.fit(lm_dataset, epochs=10)

Epoch 1/10
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m67s[0m 679ms/step - loss: 5.1510
Epoch 2/10
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m86s[0m 874ms/step - loss: 5.1249
Epoch 3/10
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m105s[0m 1s/step - loss: 5.0941
Epoch 4/10
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m89s[0m 904ms/step - loss: 5.0771
Epoch 5/10
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m88s[0m 892ms/step - loss: 5.0589
Epoch 6/10
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m85s[0m 863ms/step - loss: 5.0407
Epoch 7/10
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m86s[0m 875ms/step - loss: 5.0257
Epoch 8/10
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m83s[0m 844ms/step - loss: 5.0102
Epoch 9/10
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m83s[0m 842ms/step - loss: 5.0005
Epoch 10/10
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m84s[0m 850ms/s

<keras.src.callbacks.history.History at 0x199e3eef2f0>

In [46]:
tokens_index = dict(enumerate(text_vectorization.get_vocabulary()))

def sample_next(predictions, temperature=1.0):
    predictions = np.asarray(predictions).astype("float64")
    predictions = np.log(predictions) / temperature
    exp_preds = np.exp(predictions)
    predictions = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, predictions, 1)
    return np.argmax(probas)

In [65]:
temperature = 1.5

sentence = "in my opinion"
generate_length = 50
for i in range(generate_length):
    tokenized_sentence = text_vectorization([sentence])
    predictions = model(tokenized_sentence)
    next_token = sample_next(predictions[0, i, :])
    sampled_token = tokens_index[next_token]
    sentence += " " + sampled_token
print(sentence)

in my opinion the kind the top of characters time game never  the find video anything  about [UNK] veteran im glad ive seen there are fears [UNK]                        


In [66]:
sentence = ""
for (x, y) in lm_dataset.take(1):
    for i in x[0]:
        sampled_token = tokens_index[int(i)]
        sentence += " " + sampled_token
    break
    
print(sentence)

 de palmas technique had hit its high maturity by the time of this film which is a wonderful showcase of his classic techniques though unfortunately as with many of the films written by de palma himself the story serves the [UNK] more than the interests of putting forth an emotionally compelling tale the story opens with a crazy scene in which angie dickinson [UNK] in a shower while she looks at her husband she is then grabbed and raped while he husband stands [UNK] [UNK] the whole thing is revealed to be [UNK] fantasy as he husband is [UNK]
