In [1]:
# We set the backend to TensorFlow. The code works with
# both `tensorflow` and `torch`. It does not work with JAX
# due to the behavior of `jax.numpy.tile` in a jit scope
# (used in `causal_attention_mask()`): `tile` in JAX does
# not support a dynamic `reps` argument.
# You can make the code work in JAX by wrapping the
# body of the `causal_attention_mask` function in
# a context manager that prevents jit compilation:
# `with jax.ensure_compile_time_eval():`.
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers
from keras import ops
from keras.layers import TextVectorization
import numpy as np
import os
import string
import random
import tensorflow
import tensorflow.data as tf_data
import tensorflow.strings as tf_strings

2025-11-13 09:32:21.884374: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763022742.015589   10277 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763022742.049984   10277 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1763022742.355707   10277 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1763022742.355717   10277 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1763022742.355719   10277 computation_placer.cc:177] computation placer alr

In [2]:
def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    """
    Build a batched causal mask for the self-attention score matrix.

    Entry (i, j) is 1 when `i >= j - n_src + n_dest` and 0 otherwise, so the
    ones form a lower triangle anchored at the lower right corner. This keeps
    a position from attending to positions that come after it.

    Arguments:
        batch_size: Scalar, number of copies of the mask along the first axis.
        n_dest: Integer, number of destination (query) positions.
        n_src: Integer, number of source (key) positions.
        dtype: Data type the comparison result is cast to.

    Returns:
        A tensor of shape `(batch_size, n_dest, n_src)`.
    """
    dest_positions = ops.arange(n_dest)[:, None]
    src_positions = ops.arange(n_src)
    visible = ops.cast(dest_positions >= src_positions - n_src + n_dest, dtype)
    single_mask = ops.reshape(visible, [1, n_dest, n_src])
    # Repeat only along the batch axis; the two mask axes are tiled once.
    unit_repeats = ops.convert_to_tensor([1, 1])
    repeats = ops.concatenate([ops.expand_dims(batch_size, -1), unit_repeats], 0)
    return ops.tile(single_mask, repeats)


class TransformerBlock(layers.Layer):
    """A single causal Transformer block.

    Applies masked multi-head self-attention followed by a two-layer
    feed-forward network. Each sub-layer output goes through dropout, is added
    back to the sub-layer input, and is then layer-normalized.

    Arguments:
        embed_dim: Integer, size of the last axis of the block's input and
            output; also used as the `key_dim` of every attention head.
        num_heads: Integer, number of attention heads.
        ff_dim: Integer, width of the hidden feed-forward layer.
        rate: Float, dropout rate used after both sub-layers.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        feed_forward_layers = [
            layers.Dense(ff_dim, activation="relu"),
            layers.Dense(embed_dim),
        ]
        self.ffn = keras.Sequential(feed_forward_layers)
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs):
        """Run the block on `inputs` and return a tensor of the same shape."""
        shape = ops.shape(inputs)
        # Axis 0 is the batch size, axis 1 the sequence length; queries and
        # keys are the same sequence, so the mask is square.
        mask = causal_attention_mask(shape[0], shape[1], shape[1], "bool")
        attended = self.dropout1(self.att(inputs, inputs, attention_mask=mask))
        residual = self.layernorm1(inputs + attended)
        transformed = self.dropout2(self.ffn(residual))
        return self.layernorm2(residual + transformed)

In [3]:
class TokenAndPositionEmbedding(layers.Layer):
    """Sum of a learned token embedding and a learned position embedding.

    Arguments:
        maxlen: Integer, number of distinct positions the position table holds.
        vocab_size: Integer, number of distinct token ids the token table holds.
        embed_dim: Integer, size of each embedding vector.
    """

    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = layers.Embedding(vocab_size, embed_dim)
        self.pos_emb = layers.Embedding(maxlen, embed_dim)

    def call(self, x):
        """Embed the token ids in `x` and add the embedding of each position."""
        # Positions 0..seq_len-1 are taken from the last axis of the input.
        seq_len = ops.shape(x)[-1]
        position_vectors = self.pos_emb(ops.arange(0, seq_len, 1))
        token_vectors = self.token_emb(x)
        return token_vectors + position_vectors

In [4]:
# Model hyperparameters. These module-level names are read by `create_model()`
# and by the data-preparation and text-generation cells further down.
vocab_size = 20000  # Only consider the top 20k words
maxlen = 80  # Max sequence size (number of tokens fed to the model)
embed_dim = 256  # Embedding size for each token
num_heads = 2  # Number of attention heads
feed_forward_dim = 256  # Hidden layer size in feed forward network inside transformer


def create_model():
    """Build and compile the miniature GPT language model.

    The model maps a `(batch, maxlen)` tensor of int32 token ids to two
    outputs: next-token logits over the vocabulary and the transformer block's
    output. Only the logits are trained on; the second output has no loss.

    Returns:
        A compiled `keras.Model`.
    """
    token_ids = layers.Input(shape=(maxlen,), dtype="int32")
    embedded = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)(token_ids)
    hidden = TransformerBlock(embed_dim, num_heads, feed_forward_dim)(embedded)
    logits = layers.Dense(vocab_size)(hidden)
    model = keras.Model(inputs=token_ids, outputs=[logits, hidden])
    # No loss and optimization based on word embeddings from transformer block
    model.compile(
        optimizer="adam",
        loss=[keras.losses.SparseCategoricalCrossentropy(from_logits=True), None],
    )
    return model

In [5]:
# Download the IMDB review archive into the working directory and unpack it.
# The data-loading cell below reads the extracted `aclImdb/` folder.
!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xf aclImdb_v1.tar.gz

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 80.2M  100 80.2M    0     0  6582k      0  0:00:12  0:00:12 --:--:-- 8836k


In [6]:
batch_size = 128

# The dataset contains each review in a separate text file
# The text files are present in four different folders
# Create a list of all files
directories = [
    "aclImdb/train/pos",
    "aclImdb/train/neg",
    "aclImdb/test/pos",
    "aclImdb/test/neg",
]
# A comprehension keeps the loop variables out of the notebook's global
# namespace; in particular it avoids shadowing the builtin `dir`.
filenames = [
    os.path.join(directory, name)
    for directory in directories
    for name in os.listdir(directory)
]

print(f"{len(filenames)} files")

# Create a dataset from text files: shuffle the file order, read the files
# line by line, shuffle again within a 256-element buffer, then batch.
random.shuffle(filenames)
text_ds = tf_data.TextLineDataset(filenames)
text_ds = text_ds.shuffle(buffer_size=256)
text_ds = text_ds.batch(batch_size)


def custom_standardization(input_string):
    """Lowercase text, remove html line-break tags and split off punctuation.

    Arguments:
        input_string: String tensor holding the raw text.

    Returns:
        The lowercased text with every `<br />` replaced by a space and a
        space inserted in front of every `string.punctuation` character.
    """
    import re  # local import so this cell does not depend on another one

    lowercased = tf_strings.lower(input_string)
    stripped_html = tf_strings.regex_replace(lowercased, "<br />", " ")
    # Escape the punctuation before placing it in a character class. Left
    # unescaped, the `\]` inside `string.punctuation` is read as an escaped
    # `]`, so the backslash itself is never matched.
    punctuation_class = f"([{re.escape(string.punctuation)}])"
    return tf_strings.regex_replace(stripped_html, punctuation_class, r" \1")


# Create a vectorization layer and adapt it to the text.
# `output_sequence_length` is `maxlen + 1` because `prepare_lm_inputs_labels`
# drops one token from each end to build the input and the shifted target,
# leaving `maxlen` tokens in each.
vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size - 1,
    output_mode="int",
    output_sequence_length=maxlen + 1,
)
vectorize_layer.adapt(text_ds)
vocab = vectorize_layer.get_vocabulary()  # To get words back from token indices


def prepare_lm_inputs_labels(text):
    """
    Turn a batch of raw strings into (input, target) token sequences.

    The target is the input shifted left by one token, so the label at
    position (i) is the word at position (i+1): the model sees every word up
    to position (i) and has to predict the next one.
    """
    token_ids = vectorize_layer(tensorflow.expand_dims(text, -1))
    # Inputs drop the last token, targets drop the first.
    inputs, targets = token_ids[:, :-1], token_ids[:, 1:]
    return inputs, targets


# Tokenize every batch into (input, target) pairs and prefetch batches,
# letting tf.data choose the parallelism and buffer size.
text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf_data.AUTOTUNE)
text_ds = text_ds.prefetch(tf_data.AUTOTUNE)

50000 files


I0000 00:00:1763022824.425552   10277 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 8982 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 2080 Ti, pci bus id: 0000:08:00.0, compute capability: 7.5
2025-11-13 09:33:48.269728: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [7]:
class TextGenerator(keras.callbacks.Callback):
    """A callback to generate text from a trained model.
    1. Feed some starting prompt to the model
    2. Predict probabilities for the next token
    3. Sample the next token and add it to the next input

    The model input length is read from the module-level `maxlen`.

    Arguments:
        max_tokens: Integer, the number of tokens to be generated after prompt.
        start_tokens: List of integers, the token indices for the starting prompt.
        index_to_word: List of strings, obtained from the TextVectorization layer.
        top_k: Integer, sample from the `top_k` token predictions.
        print_every: Integer, print after this many epochs.
    """

    def __init__(
        self, max_tokens, start_tokens, index_to_word, top_k=10, print_every=1
    ):
        # Let the base class set up its own state before adding ours.
        super().__init__()
        self.max_tokens = max_tokens
        self.start_tokens = start_tokens
        self.index_to_word = index_to_word
        self.print_every = print_every
        self.k = top_k

    def sample_from(self, logits):
        """Sample one token index from the `top_k` highest logits.

        The selected logits are turned into probabilities with a softmax and
        one of the corresponding indices is drawn with those probabilities.
        """
        logits, indices = ops.top_k(logits, k=self.k, sorted=True)
        indices = np.asarray(indices).astype("int32")
        preds = keras.activations.softmax(ops.expand_dims(logits, 0))[0]
        preds = np.asarray(preds).astype("float32")
        return np.random.choice(indices, p=preds)

    def detokenize(self, number):
        """Return the word stored at index `number` of `index_to_word`."""
        return self.index_to_word[number]

    def on_epoch_end(self, epoch, logs=None):
        """Generate and print `max_tokens` tokens every `print_every` epochs."""
        if (epoch + 1) % self.print_every != 0:
            return
        # Work on a copy so the stored prompt is not extended between epochs.
        context = list(self.start_tokens)
        tokens_generated = []
        # Strict `<`: produce exactly `max_tokens` tokens, not one more.
        while len(tokens_generated) < self.max_tokens:
            if len(context) > maxlen:
                # Keep the most recent tokens so that freshly generated
                # tokens stay part of the context the model sees.
                window = context[-maxlen:]
            else:
                # Right-pad with 0 up to the fixed input length.
                window = context + [0] * (maxlen - len(context))
            # Read the prediction made at the last real (non-padding) position.
            sample_index = min(len(context), maxlen) - 1
            y, _ = self.model.predict(np.array([window]), verbose=0)
            sample_token = self.sample_from(y[0][sample_index])
            tokens_generated.append(sample_token)
            context.append(sample_token)
        txt = " ".join(
            [
                self.detokenize(token)
                for token in list(self.start_tokens) + tokens_generated
            ]
        )
        print(f"generated text:\n{txt}\n")


# Tokenize starting prompt
# Map each vocabulary word to its index (the inverse of `vocab`).
word_to_index = {word: index for index, word in enumerate(vocab)}

start_prompt = "this movie is"
# Words missing from the vocabulary fall back to index 1, the `[UNK]` slot of
# a TextVectorization vocabulary in "int" mode.
start_tokens = [word_to_index.get(word, 1) for word in start_prompt.split()]
num_tokens_generated = 40
text_gen_callback = TextGenerator(num_tokens_generated, start_tokens, vocab)

In [8]:
# Build a fresh model and train it for 25 epochs. `verbose=2` logs one line
# per epoch, and the callback prints generated sample text during training.
model = create_model()

model.fit(text_ds, verbose=2, epochs=25, callbacks=[text_gen_callback])

Epoch 1/25


I0000 00:00:1763022858.167795   10574 service.cc:152] XLA service 0x7bd39005aec0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1763022858.167809   10574 service.cc:160]   StreamExecutor device (0): NVIDIA GeForce RTX 2080 Ti, Compute Capability 7.5
2025-11-13 09:34:18.241687: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-11-13 09:34:18.378374: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:39] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert
I0000 00:00:1763022858.535221   10574 cuda_dnn.cc:529] Loaded cuDNN version 91500
I0000 00:00:1763022862.549076   10574 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
2025-11-13 09:34:47.599776: W tensorflow/compiler/tf2xla/

generated text:
this movie is great . a very funny movie , a bit good and very well done . it is very bad . . . the acting is very well directed and well as the acting is good , and the story is very

391/391 - 37s - 95ms/step - loss: 5.4462
Epoch 2/25


2025-11-13 09:35:22.993898: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2025-11-13 09:35:22.993919: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 6102953936996661739
2025-11-13 09:35:22.993929: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 15631052121725713559


generated text:
this movie is so awful in the end , so i think about how the story is that it is . it 's very funny ! the acting is great , with the script , and [UNK] . .     

391/391 - 31s - 79ms/step - loss: 4.6935
Epoch 3/25


2025-11-13 09:35:55.737593: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 6102953936996661739
2025-11-13 09:35:55.737612: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 15631052121725713559


generated text:
this movie is so funny . i can 't imagine how many others i can do with that , and the acting is awful , and a good movie that 's not worth seeing . the movie . the only thing i thought i

391/391 - 33s - 84ms/step - loss: 4.4507
Epoch 4/25


2025-11-13 09:36:28.831622: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2025-11-13 09:36:28.831642: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 6102953936996661739
2025-11-13 09:36:28.831650: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 15631052121725713559


generated text:
this movie is the only reason for a movie that it is not a movie . a lot , it is an extremely well acted and it is a comedy . the script is good . it is not a [UNK] but a comedy

391/391 - 33s - 85ms/step - loss: 4.2956
Epoch 5/25


2025-11-13 09:37:02.730414: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 6102953936996661739
2025-11-13 09:37:02.730435: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 15631052121725713559


generated text:
this movie is one of the worst movies i 've ever seen . i am very surprised that i 'm writing this movie . a complete mess of the storyline is very boring . the plot is weak . but the characters are so

391/391 - 34s - 87ms/step - loss: 4.1773
Epoch 6/25
generated text:
this movie is a great movie . if you enjoy watching a lot of good action movies , but i 'm sure that i 've never heard of the movie , it 's just plain stupid and stupid . i 've always enjoyed it

391/391 - 34s - 87ms/step - loss: 4.0801
Epoch 7/25
generated text:
this movie is a good story and is very well done . . the acting is not the best of the story and the acting , is a great script . the story is great , the characters are very good but the acting

391/391 - 34s - 87ms/step - loss: 3.9968
Epoch 8/25


2025-11-13 09:38:44.697736: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


generated text:
this movie is not a very good movie . it 's not a good story line . the movie doesn 't make it very good . the plot is simple but it doesn 't have much potential . not to be very good .

391/391 - 34s - 87ms/step - loss: 3.9242
Epoch 9/25


2025-11-13 09:39:19.048901: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 6102953936996661739
2025-11-13 09:39:19.048918: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 15631052121725713559


generated text:
this movie is so bad it 's just a bad movie . it has a good premise that a good movie is a good story . the movie is about a guy trying to kill a terrorist group of people in his family 's

391/391 - 34s - 88ms/step - loss: 3.8590
Epoch 10/25


2025-11-13 09:39:53.505571: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 6102953936996661739
2025-11-13 09:39:53.505589: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 15631052121725713559


generated text:
this movie is a very realistic film . it shows how great actors can a great job , great acting , directing , great music , great music and music , the song , dance is lush and the music is wonderful , and

391/391 - 34s - 88ms/step - loss: 3.8012
Epoch 11/25
generated text:
this movie is a good example of the great movies of all time and for that reason for the first 30 seconds of film and then i have watched it again with me and it was a very entertaining movie . if not a

391/391 - 35s - 88ms/step - loss: 3.7483
Epoch 12/25


2025-11-13 09:41:02.616790: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 6102953936996661739
2025-11-13 09:41:02.616808: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 15631052121725713559


generated text:
this movie is the perfect illustration is just about this one . the worst thing in the movie i have seen in the movie for years . the acting is awful . the worst script ever , the acting is terrible . . terrible

391/391 - 35s - 88ms/step - loss: 3.7003
Epoch 13/25


2025-11-13 09:41:37.099965: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 6102953936996661739
2025-11-13 09:41:37.099987: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 15631052121725713559


generated text:
this movie is a very good example of what a great movie , not even if it is a bad movie , and the plot . the script is awful . . the acting was so bad , and so bad that it is

391/391 - 34s - 88ms/step - loss: 3.6558
Epoch 14/25
generated text:
this movie is based on the real events that it 's a real event . it 's also very interesting to note to note that the history of [UNK] ' was not that great historians saw it as well as [UNK] ' , but

391/391 - 34s - 88ms/step - loss: 3.6154
Epoch 15/25


2025-11-13 09:42:46.291730: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 6102953936996661739
2025-11-13 09:42:46.291749: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 15631052121725713559


generated text:
this movie is so bad it , it 's a terrible movie . there are so many holes you could see , but the movie just didn 't make the movie for it . the movie wasn 't it too long . it was

391/391 - 35s - 89ms/step - loss: 3.5774
Epoch 16/25


2025-11-13 09:43:21.141062: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2025-11-13 09:43:21.141084: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 6102953936996661739
2025-11-13 09:43:21.141093: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 15631052121725713559


generated text:
this movie is a good example of the movie , that was just awful . the story is ridiculous , the plot is ridiculous , and the acting is the only one of the worst i 've ever seen ! canÂ´t do i believe

391/391 - 35s - 89ms/step - loss: 3.5431
Epoch 17/25


2025-11-13 09:43:56.360759: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 6102953936996661739
2025-11-13 09:43:56.360778: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 15631052121725713559


generated text:
this movie is a bit more than a light hearted attempt by [UNK] a film which has been done in a very long time in the movie , but it doesn 't have it . the plot is so thin it 's the same

391/391 - 35s - 90ms/step - loss: 3.5109
Epoch 18/25


2025-11-13 09:44:31.622267: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 6102953936996661739


generated text:
this movie is a very well structured and it takes place very little . this may be a movie , but it is also very hard to swallow . this is it not the fault . it shows , i can see what it

391/391 - 35s - 90ms/step - loss: 3.4805
Epoch 19/25
generated text:
this movie is just a bad movie . a good story of a good story and the execution is nothing original . the story is flat out the characters , the movie is simply boring . the acting is the best part of the

391/391 - 35s - 89ms/step - loss: 3.4528
Epoch 20/25
generated text:
this movie is the first movie i 've ever seen . you have a life on me . i think i 'll watch it if it is an interesting movie , but this was a little bit too slow , and it was very

391/391 - 35s - 90ms/step - loss: 3.4269
Epoch 21/25


2025-11-13 09:46:16.330373: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 6102953936996661739
2025-11-13 09:46:16.330390: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 15631052121725713559


generated text:
this movie is a very funny movie . i am not sure how bad it was . the movie was made , it is , the story itself , the movie is not worth it . it 's a comedy that is funny .

391/391 - 35s - 89ms/step - loss: 3.4019
Epoch 22/25
generated text:
this movie is so funny . i don 't remember much about it and i love it when i saw it on a few times . it 's like the other movie , i 'm still not going to see the show , it

391/391 - 35s - 91ms/step - loss: 3.3797
Epoch 23/25
generated text:
this movie is the best movie i have ever seen . it is a complete waste of time and time . there is a story that makes for the movie more of a 1 . if it were a bit more entertaining than the

391/391 - 35s - 90ms/step - loss: 3.3573
Epoch 24/25


2025-11-13 09:48:02.193409: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 6102953936996661739
2025-11-13 09:48:02.193428: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 15631052121725713559


generated text:
this movie is about a [UNK] who wants to know , but she is in [UNK] by a friend of mine ) that is one of the worst movies i 've seen . the plot is ridiculous , and the direction , and cinematography

391/391 - 35s - 90ms/step - loss: 3.3369
Epoch 25/25


2025-11-13 09:48:37.261044: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 6102953936996661739


generated text:
this movie is a complete waste of money in the movie . there is no plot , no action , the plot is ridiculous ! there is no action , and no . the acting in the movie is very bad . the acting

391/391 - 35s - 90ms/step - loss: 3.3176


<keras.src.callbacks.history.History at 0x7bd54db0de70>