In [1]:
# @formatter:off
%load_ext autoreload
%autoreload 2
# @formatter:on

In [2]:
import os
import random
import string

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import TextVectorization

In [10]:
# Model hyperparameters (M_*) and training settings (T_*).
M_VOCAB_SZ = 20000  # vocabulary size: token-embedding rows and width of the output logits
M_MAX_LEN = 80  # fixed input sequence length; also the number of position embeddings
M_ATT_HEADS = 2  # number of attention heads in the Transformer block
M_DIM_EMB = 256  # embedding width for tokens and positions
M_DIM_FFN = 256  # hidden width of the Transformer feed-forward sub-layer
M_WARMUP_STEPS = 2000  # warm-up length handed to WarmupScheduler
T_BATCH_SIZE = 256  # number of text lines per batch in create_dataset
T_EPOCHS = 5  # epochs passed to model.fit

# Text files that create_dataset reads line by line.
data_files = ["data_test/rap.txt"]

In [11]:
def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    """Build a per-batch causal mask of shape ``(batch_size, n_dest, n_src)``.

    Entry ``[b, i, j]`` is truthy when ``i >= j - n_src + n_dest``, i.e. when
    destination position ``i`` may look at source position ``j``. With
    ``n_dest == n_src`` this is the lower triangle (diagonal included).

    Args:
        batch_size: Scalar number of copies of the mask to stack.
        n_dest: Number of destination (query) positions.
        n_src: Number of source (key) positions.
        dtype: Type the boolean comparison result is cast to.

    Returns:
        The single ``(1, n_dest, n_src)`` mask tiled ``batch_size`` times
        along the first axis.
    """
    dest_idx = tf.range(n_dest)[:, None]
    src_idx = tf.range(n_src)
    allowed = dest_idx >= src_idx - n_src + n_dest
    single_mask = tf.reshape(tf.cast(allowed, dtype), [1, n_dest, n_src])
    # Tile multiples are [batch_size, 1, 1]: repeat only along the batch axis.
    tile_multiples = tf.concat(
        [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
    )
    return tf.tile(single_mask, tile_multiples)


class Transformer(layers.Layer):
    """Single causal self-attention block.

    Applies masked multi-head self-attention and a two-layer feed-forward
    network; each sub-layer is followed by dropout, a residual addition and
    layer normalization.

    Args:
        embedding_dim: Width of the input/output features; also passed to
            ``MultiHeadAttention`` as its second positional argument.
        num_att_heads: Number of attention heads.
        state_dims: Hidden width of the feed-forward network.
        dropout_rate: Rate used by both dropout layers.
        **kwargs: Standard ``Layer`` arguments (``name``, ``dtype``, ...).
    """

    def __init__(
        self, embedding_dim, num_att_heads, state_dims, dropout_rate=0.1, **kwargs
    ):
        # get_config() also emits the base Layer keys (e.g. "name"), so they
        # must be accepted here or from_config() / model reloading would fail
        # with an unexpected-keyword error.
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_att_heads = num_att_heads
        self.state_dims = state_dims
        self.dropout_rate = dropout_rate
        self.attention = layers.MultiHeadAttention(num_att_heads, embedding_dim)
        self.feed_forward = keras.Sequential(
            [layers.Dense(state_dims, activation="relu"), layers.Dense(embedding_dim)]
        )
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.drop1 = layers.Dropout(dropout_rate)
        self.drop2 = layers.Dropout(dropout_rate)

    def call(self, inputs):
        """Run the block on ``inputs`` and return a tensor of the same shape."""
        inp_shape = tf.shape(inputs)
        batch_sz, seq_len = inp_shape[0], inp_shape[1]
        # Each position may only attend to itself and earlier positions.
        causal_mask = causal_attention_mask(batch_sz, seq_len, seq_len, tf.bool)
        attention_out = self.attention(inputs, inputs, attention_mask=causal_mask)
        attention_out = self.drop1(attention_out)
        out1 = self.norm1(inputs + attention_out)
        feed_forward_out = self.feed_forward(out1)
        feed_forward_out = self.drop2(feed_forward_out)
        return self.norm2(out1 + feed_forward_out)

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "embedding_dim": self.embedding_dim,
                "num_att_heads": self.num_att_heads,
                "state_dims": self.state_dims,
                "dropout_rate": self.dropout_rate,
            }
        )
        return config


class TokenAndPositionEmbedding(layers.Layer):
    """Sum of a learned token embedding and a learned position embedding.

    Args:
        max_len: Number of rows in the position-embedding table.
        vocab_size: Number of rows in the token-embedding table.
        embed_dim: Width of both embeddings.
        **kwargs: Standard ``Layer`` arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, max_len, vocab_size, embed_dim, **kwargs):
        # get_config() also emits the base Layer keys (e.g. "name"), so they
        # must be accepted here or from_config() / model reloading would fail
        # with an unexpected-keyword error.
        super().__init__(**kwargs)
        self.max_len = max_len
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=max_len, output_dim=embed_dim)

    def call(self, x):
        """Embed token ids ``x`` and add the embedding of each position index."""
        # Positions 0..seq_len-1 are taken from the last axis of the input.
        seq_len = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=seq_len, delta=1)
        return self.token_emb(x) + self.pos_emb(positions)

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "max_len": self.max_len,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config

In [12]:
class WarmupScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Learning-rate schedule with linear warm-up and inverse-sqrt decay.

    ``lr(step) = rsqrt(embedding_dim) * min(rsqrt(step), step * warmup_steps**-1.5)``

    Args:
        embedding_dim: Model width used to scale the rate.
        warmup_steps: Step at which the two terms of the ``min`` cross over.
    """

    def __init__(self, embedding_dim, warmup_steps):
        super().__init__()
        # Raw constructor value, kept so get_config() stays serializable and
        # round-trips through from_config().
        self.embedding_dim = embedding_dim
        self.emb_dim = tf.cast(embedding_dim, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        """Return the learning rate for optimizer step ``step``."""
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps**-1.5)
        return tf.math.rsqrt(self.emb_dim) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        """Return the constructor arguments as a plain dict.

        The base-class ``get_config`` is abstract, so it is deliberately not
        called. Keys match the ``__init__`` parameter names so that
        ``from_config`` can rebuild the schedule.
        """
        return {
            "embedding_dim": self.embedding_dim,
            "warmup_steps": self.warmup_steps,
        }

In [13]:
def create_model():
    """Build and compile the single-block language model.

    The model maps ``(batch, M_MAX_LEN)`` int32 token ids to two outputs: the
    per-position vocabulary logits and the Transformer block's hidden states.
    Only the logits carry a loss (the second loss entry is ``None``).

    Returns:
        The compiled ``keras.Model``.
    """
    # TODO: Remove this if we're using tokenizer
    embed_layer = TokenAndPositionEmbedding(M_MAX_LEN, M_VOCAB_SZ, M_DIM_EMB)
    block = Transformer(M_DIM_EMB, M_ATT_HEADS, M_DIM_FFN)

    token_ids = layers.Input(shape=(M_MAX_LEN,), dtype=tf.int32)
    hidden = block(embed_layer(token_ids))
    logits = layers.Dense(M_VOCAB_SZ)(hidden)

    language_model = keras.Model(inputs=token_ids, outputs=[logits, hidden])

    crossentropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    schedule = WarmupScheduler(M_DIM_EMB, M_WARMUP_STEPS)
    adam = tf.keras.optimizers.Adam(
        schedule, beta_1=0.9, beta_2=0.98, epsilon=1e-9
    )
    language_model.compile(adam, loss=[crossentropy, None])
    return language_model

In [14]:
def create_dataset(file_pth, batch_sz, buf_sz=1000, shuffle=True):
    """Create a batched dataset of text lines from one or more files.

    Args:
        file_pth: A single path or an iterable of paths to text files.
        batch_sz: Number of lines per batch.
        buf_sz: Buffer size for the line-level shuffle.
        shuffle: When true, randomize both the file order and the line order.
            When false, files and lines keep their given order.

    Returns:
        A ``tf.data`` dataset yielding batches of raw text lines.
    """
    # Work on a private list: random.shuffle() reorders in place, which would
    # scramble the caller's list and fails outright on a single str path.
    if isinstance(file_pth, (str, bytes, os.PathLike)):
        files = [os.fspath(file_pth)]
    else:
        files = list(file_pth)
    if shuffle:
        random.shuffle(files)
    ds = tf.data.TextLineDataset(files)
    if shuffle:
        ds = ds.shuffle(buffer_size=buf_sz)
    return ds.batch(batch_sz)


def create_tokenizer(dataset, max_vocab_size, max_seq_len):
    """Fit a word-level ``TextVectorization`` layer on ``dataset``.

    Args:
        dataset: Text data the vectorizer is adapted on.
        max_vocab_size: Vocabulary budget; the layer is created with
            ``max_tokens = max_vocab_size - 1``.
        max_seq_len: Sequence length; the layer outputs ``max_seq_len + 1``
            ids per example.

    Returns:
        Tuple ``(vectorizer, vocabulary)`` where ``vocabulary`` is the list
        returned by ``vectorizer.get_vocabulary()``.
    """

    def _standardize(raw_text):
        # Word-level preprocessing: lowercase, then insert a space before
        # every punctuation character matched by the pattern.
        lowered = tf.strings.lower(raw_text)
        return tf.strings.regex_replace(
            lowered, f"([{string.punctuation}])", r" \1"
        )

    vectorizer = TextVectorization(
        standardize=_standardize,
        max_tokens=max_vocab_size - 1,
        output_mode="int",
        output_sequence_length=max_seq_len + 1,
    )
    vectorizer.adapt(dataset)
    return vectorizer, vectorizer.get_vocabulary()


# Read in the data and create the dataset
# (batches of T_BATCH_SIZE raw text lines from data_files)
dataset = create_dataset(data_files, T_BATCH_SIZE)
# Create the tokenizer
# (adapted on the raw-text batches; also returns the fitted vocabulary list)
tokenizer, vocab = create_tokenizer(dataset, M_VOCAB_SZ, M_MAX_LEN)


def create_sequences(txt):
    """Tokenize a batch of text lines into (input, target) id sequences.

    The module-level ``tokenizer`` is applied to ``txt`` after a trailing axis
    is added. The target is the input shifted left by one position: inputs
    drop the last id, targets drop the first.
    """
    token_ids = tokenizer(tf.expand_dims(txt, -1))
    inputs = token_ids[:, :-1]
    targets = token_ids[:, 1:]
    return inputs, targets


# Turn each batch of raw lines into (input, target) token-id pairs and prefetch.
dataset = dataset.map(create_sequences).prefetch(tf.data.AUTOTUNE)

In [15]:
# Print one batch to inspect the (input, target) pair produced by create_sequences.
for d in dataset.take(1):
    print(d)

(<tf.Tensor: shape=(256, 80), dtype=int64, numpy=
array([[ 114,   15,   46, ...,    0,    0,    0],
       [ 506, 5039, 5250, ...,    0,    0,    0],
       [  95,    4, 3806, ...,    0,    0,    0],
       ...,
       [  49,  610,   26, ...,    0,    0,    0],
       [ 166,   58,    4, ...,    0,    0,    0],
       [ 326,    6, 5727, ...,    0,    0,    0]])>, <tf.Tensor: shape=(256, 80), dtype=int64, numpy=
array([[  15,   46,  143, ...,    0,    0,    0],
       [5039, 5250,  491, ...,    0,    0,    0],
       [   4, 3806,  762, ...,    0,    0,    0],
       ...,
       [ 610,   26,  194, ...,    0,    0,    0],
       [  58,    4,    1, ...,    0,    0,    0],
       [   6, 5727, 1817, ...,    0,    0,    0]])>)


In [16]:
class TextGenerator(keras.callbacks.Callback):
    """Callback that prints text sampled from the model at the end of epochs.

    Starting from ``start_tokens``, the model is queried repeatedly; each new
    token is drawn from the ``top_k`` highest logits at the position of the
    last context token and appended to the context.

    Args:
        max_tokens: Number of tokens to generate per sample.
        start_tokens: Token ids of the prompt.
        index_to_word: Sequence mapping a token id to its word.
        top_k: Number of highest-scoring candidates to sample from.
        print_every: Generate text every ``print_every`` epochs.
    """

    def __init__(
        self, max_tokens, start_tokens, index_to_word, top_k=10, print_every=1
    ):
        # Let the base Callback set up its own attributes.
        super().__init__()
        self.max_tokens = max_tokens
        self.start_tokens = start_tokens
        self.index_to_word = index_to_word
        self.print_every = print_every
        self.k = top_k

    def sample_top_k(self, logits):
        """Sample one token id from the ``k`` largest entries of ``logits``."""
        logits, indices = tf.math.top_k(logits, k=self.k, sorted=True)
        indices = np.asarray(indices).astype("int32")
        preds = keras.activations.softmax(tf.expand_dims(logits, 0))[0]
        # np.random.choice checks that p sums to 1 with a tight tolerance;
        # float32 rounding can miss it, so renormalize in float64.
        preds = np.asarray(preds).astype("float64")
        preds /= preds.sum()
        return np.random.choice(indices, p=preds)

    def detokenize(self, number):
        """Return the word stored at index ``number``."""
        return self.index_to_word[number]

    def on_epoch_end(self, epoch, logs=None):
        """Generate ``max_tokens`` tokens after the prompt and print the text."""
        if (epoch + 1) % self.print_every != 0:
            return
        context = list(self.start_tokens)
        tokens_generated = []
        # Strict "<" so exactly max_tokens tokens are produced.
        while len(tokens_generated) < self.max_tokens:
            # Feed the most recent M_MAX_LEN tokens. Slicing from the front
            # would freeze the window once the context outgrows the model
            # input, so every later token would be sampled from the same logits.
            window = context[-M_MAX_LEN:]
            sample_index = len(window) - 1
            x = np.array([window + [0] * (M_MAX_LEN - len(window))])
            # verbose=0: avoid one progress bar per generated token.
            y, _ = self.model.predict(x, verbose=0)
            sample_token = self.sample_top_k(y[0][sample_index])
            tokens_generated.append(sample_token)
            context.append(sample_token)
        txt = " ".join(
            self.detokenize(tok) for tok in list(self.start_tokens) + tokens_generated
        )
        print(f"Generated:\n{txt}\n")

In [17]:
def create_generation_callback(start_prompt, vocabulary, gen_len=100):
    """Build a ``TextGenerator`` callback seeded with ``start_prompt``.

    The prompt is lowercased and split on whitespace; each word is replaced by
    its index in ``vocabulary``, and words that are not found map to index 1.

    Args:
        start_prompt: Prompt text.
        vocabulary: Sequence of words; a word's position is its token id.
        gen_len: Number of tokens passed to ``TextGenerator`` as its length.
    """
    # Tokenize starting prompt
    lookup = {}
    for idx, token in enumerate(vocabulary):
        lookup[token] = idx
    prompt_ids = [lookup.get(piece, 1) for piece in start_prompt.lower().split()]
    return TextGenerator(gen_len, prompt_ids, vocabulary)


def create_callbacks(base_dir, model, defaults: list = None):
    """Create the logging callbacks for ``model`` under ``base_dir/<model.name>``.

    Creates the model directory, prints the paths in use, and returns a new
    callback list containing any ``defaults`` followed by a TensorBoard and a
    CSV-history callback.

    Args:
        base_dir: Root directory for all model artifacts.
        model: Object whose ``name`` attribute names the sub-directory.
        defaults: Optional callbacks to put first. The list is copied, never
            modified.

    Returns:
        Tuple ``(callbacks, tb_file_writer)``: the callback list and the
        summary file writer created for the TensorBoard log directory.
    """
    dir_models = os.path.join(base_dir, model.name)
    os.makedirs(dir_models, exist_ok=True)
    print(f"# Model Dir: {dir_models}")

    path_csv = os.path.join(dir_models, "history.csv")
    print(" - History CSV:", path_csv)

    # Checkpointing is currently disabled (see the commented block below);
    # the path is only reported.
    path_ckp = os.path.join(dir_models, "checkpoints.h5")
    print(" - Checkpoint:", path_ckp)

    path_tb = os.path.join(dir_models, "logs")
    tb_file_writer = tf.summary.create_file_writer(path_tb)
    # Copy so that appending below never grows the caller's list.
    callbacks = [] if defaults is None else list(defaults)
    callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=path_tb, histogram_freq=1))
    callbacks.append(tf.keras.callbacks.CSVLogger(path_csv, separator=",", append=True))
    # callbacks.append(tf.keras.callbacks.ModelCheckpoint(path_ckp,
    #                                                     monitor='loss',
    #                                                     save_best_only=True,
    #                                                     mode='auto',
    #                                                     verbose=0))
    return callbacks, tb_file_writer

In [18]:
# Build the model, its logging callbacks under "logs/", and a text-generation
# callback seeded with a prompt; the generator is appended to the callback list.
model = create_model()
callbacks, tb_file_writer = create_callbacks("logs", model)
gen_callback = create_generation_callback("a day in the life", vocab)
callbacks.append(gen_callback)

# Model Dir: logs/model_1
 - History CSV: logs/model_1/history.csv
 - Checkpoint: logs/model_1/checkpoints.h5


In [14]:
# wandb.tensorboard.patch(root_logdir="logs")
# wandb.init(project='transformer')
# Train for T_EPOCHS epochs on the (input, target) batches with the callbacks built above.
model.fit(dataset, verbose=2, epochs=T_EPOCHS, callbacks=callbacks)

Epoch 1/5


2022-05-16 23:18:28.608855: E tensorflow/stream_executor/cuda/cuda_blas.cc:232] failed to create cublas handle: CUBLAS_STATUS_NOT_INITIALIZED
2022-05-16 23:18:28.608889: E tensorflow/stream_executor/cuda/cuda_blas.cc:234] Failure to initialize cublas may be due to OOM (cublas needs some free memory when you initialize it, and your deep-learning framework may have preallocated more than its fair share), or may be because this binary was not built with support for the GPU in your machine.
2022-05-16 23:18:28.608916: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at matmul_op_impl.h:442 : INTERNAL: Attempting to perform BLAS operation using StreamExecutor without BLAS support
2022-05-16 23:18:28.611674: E tensorflow/stream_executor/cuda/cuda_blas.cc:232] failed to create cublas handle: CUBLAS_STATUS_NOT_INITIALIZED
2022-05-16 23:18:28.611690: E tensorflow/stream_executor/cuda/cuda_blas.cc:234] Failure to initialize cublas may be due to OOM (cublas needs some free memory

InternalError: Graph execution error:

Detected at node 'model/transformer/multi_head_attention/key/einsum/Einsum' defined at (most recent call last):
    File "/opt/conda/lib/python3.9/runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/opt/conda/lib/python3.9/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/opt/conda/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/opt/conda/lib/python3.9/site-packages/traitlets/config/application.py", line 846, in launch_instance
      app.start()
    File "/opt/conda/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 712, in start
      self.io_loop.start()
    File "/opt/conda/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 199, in start
      self.asyncio_loop.run_forever()
    File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 596, in run_forever
      self._run_once()
    File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 1890, in _run_once
      handle._run()
    File "/opt/conda/lib/python3.9/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/opt/conda/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 504, in dispatch_queue
      await self.process_one()
    File "/opt/conda/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 493, in process_one
      await dispatch(*args)
    File "/opt/conda/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 400, in dispatch_shell
      await result
    File "/opt/conda/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 724, in execute_request
      reply_content = await reply_content
    File "/opt/conda/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 390, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "/opt/conda/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 528, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2863, in run_cell
      result = self._run_cell(
    File "/opt/conda/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2909, in _run_cell
      return runner(coro)
    File "/opt/conda/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/opt/conda/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3106, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/opt/conda/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3309, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/opt/conda/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3369, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_2023/405319210.py", line 3, in <cell line: 3>
      model.fit(dataset, verbose=2, epochs=T_EPOCHS, callbacks=callbacks)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/training.py", line 1384, in fit
      tmp_logs = self.train_function(iterator)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/training.py", line 1021, in train_function
      return step_function(self, iterator)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/training.py", line 1010, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/functional.py", line 451, in call
      return self._run_internal_graph(
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/tmp/ipykernel_2023/3326115625.py", line 30, in call
      attention_out = self.attention(inputs, inputs, attention_mask=causal_mask)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/layers/multi_head_attention.py", line 505, in call
      key = self._key_dense(key)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/layers/einsum_dense.py", line 187, in call
      ret = tf.einsum(self.equation, inputs, self.kernel)
Node: 'model/transformer/multi_head_attention/key/einsum/Einsum'
Attempting to perform BLAS operation using StreamExecutor without BLAS support
	 [[{{node model/transformer/multi_head_attention/key/einsum/Einsum}}]] [Op:__inference_train_function_2840]

In [None]:
# Save the model to "model_warmup.h5" in the working directory.
model.save("model_warmup.h5")

In [19]:
# Display the model object's repr as the cell output.
model

<keras.engine.functional.Functional at 0x7f9798641100>

In [20]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from official import nlp
from official.nlp.modeling.ops import beam_search, sampling_module

In [21]:
# Decoding settings, taken from the model/training constants where available.
params = {
    "num_heads": M_ATT_HEADS,
    "num_layers": 1,
    "batch_size": T_BATCH_SIZE,
    "n_dims": M_DIM_EMB,
    "max_decode_length": 100,
}

In [22]:
# Zero-filled key/value tensors per layer, each shaped
# (batch_size, max_decode_length, num_heads, n_dims / num_heads).
cache = {
    f"layer_{layer}": {
        kv_name: tf.zeros(
            [
                params["batch_size"],
                params["max_decode_length"],
                params["num_heads"],
                int(params["n_dims"] / params["num_heads"]),
            ],
            dtype=tf.float32,
        )
        for kv_name in ("k", "v")
    }
    for layer in range(params["num_layers"])
}
print("cache key shape for layer 0 :", cache["layer_0"]["k"].shape)

cache key shape for layer 0 : (256, 100, 2, 128)


In [23]:
def length_norm(length, dtype):
    """Return the length normalization factor ``((5 + length) / 6) ** 0.0``.

    ``length`` is cast to ``dtype`` first. The exponent is fixed at 0.0.
    """
    base = (5.0 + tf.cast(length, dtype)) / 6.0
    return tf.pow(base, 0.0)

In [24]:
def model_fn(ids):
    """Run the global ``model`` on ``ids`` and return its first output.

    ``model.predict`` is expected to return exactly two values; the second is
    discarded.
    """
    first_output, _second_output = model.predict(ids)
    return first_output

In [35]:
def _symbols_to_logits_fn():
    """Return the callback the sampling module uses to get logits.

    The returned ``symbols_to_logits_fn(ids, i, temp_cache)`` flattens ``ids``
    into a single row, fits it to the model input length ``M_MAX_LEN``, runs
    ``model_fn`` and returns ``(logits, temp_cache)`` with the cache untouched.
    """

    def symbols_to_logits_fn(ids, i, temp_cache):
        # Get the tensor into a numpy array (one row holding every id)
        ids = tf.reshape(ids, [1, -1]).numpy()
        # Keep only the most recent M_MAX_LEN ids: once the sequence outgrows
        # the model input, np.pad would otherwise get a negative width and raise.
        ids = ids[:, -M_MAX_LEN:]
        # Pad the array to the max sequence length
        n_ids = ids.shape[1]
        ids = np.pad(
            ids,
            [(0, 0), (0, M_MAX_LEN - n_ids)],
            mode="constant",
            constant_values=0,
        )
        print(ids)
        # Feed through the model
        logits = model_fn(ids)

        # Return the generated logits
        # NOTE(review): model_fn's output is returned for every position;
        # SamplingModule presumably expects next-token logits of shape
        # (batch, vocab) — confirm, and slice at index n_ids - 1 if so.
        print("-" * 80)
        print(logits[0])
        return logits, temp_cache

    return symbols_to_logits_fn

In [36]:
def tokenize_prompt(prompt, vocab, max_len):
    """Convert ``prompt`` into a list of vocabulary indices.

    The prompt is lowercased and split on whitespace. Each word maps to its
    position in ``vocab``; words not present map to 1.

    Args:
        prompt: Text to encode.
        vocab: Sequence of words; a word's position is its id.
        max_len: Accepted but not used by this function.
    """
    index_of = {}
    for position, word in enumerate(vocab):
        index_of[word] = position
    return [index_of.get(word, 1) for word in prompt.lower().split()]

In [37]:
def detokenize(vocab, number):
    """Return the entry of ``vocab`` stored at index ``number``."""
    word = vocab[number]
    return word

In [38]:
# Encode a short prompt as vocabulary indices and display the result.
prompt_enc = tokenize_prompt("i will always", vocab, M_MAX_LEN)
prompt_enc

[3, 150, 187]

In [39]:
# Configure sampling over the model's logits: top_k of 3, temperature 1.0,
# greedy decoding disabled, up to params["max_decode_length"] steps.
top_k_obj = sampling_module.SamplingModule(
    length_normalization_fn=length_norm,
    dtype=tf.float32,
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    vocab_size=M_VOCAB_SZ,
    max_decode_length=params["max_decode_length"],
    eos_id=0,  # id 0 is also the value symbols_to_logits_fn pads inputs with
    sample_temperature=tf.constant(1.0),
    top_k=tf.constant(3),
    padded_decode=False,
    enable_greedy=False,
)
# NOTE(review): initial_ids is a plain Python list of prompt ids here —
# confirm the shape/type that SamplingModule.generate expects.
ids, _ = top_k_obj.generate(initial_ids=prompt_enc, initial_cache={})

[[  3 150 187   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0]]


2022-05-17 00:25:30.502476: E tensorflow/stream_executor/cuda/cuda_blas.cc:232] failed to create cublas handle: CUBLAS_STATUS_NOT_INITIALIZED
2022-05-17 00:25:30.502513: E tensorflow/stream_executor/cuda/cuda_blas.cc:234] Failure to initialize cublas may be due to OOM (cublas needs some free memory when you initialize it, and your deep-learning framework may have preallocated more than its fair share), or may be because this binary was not built with support for the GPU in your machine.
2022-05-17 00:25:30.502538: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at matmul_op_impl.h:438 : INTERNAL: Attempting to perform BLAS operation using StreamExecutor without BLAS support
2022-05-17 00:25:30.539862: E tensorflow/stream_executor/cuda/cuda_blas.cc:232] failed to create cublas handle: CUBLAS_STATUS_NOT_INITIALIZED
2022-05-17 00:25:30.539925: E tensorflow/stream_executor/cuda/cuda_blas.cc:234] Failure to initialize cublas may be due to OOM (cublas needs some free memory

InternalError: Graph execution error:

Detected at node 'model_1/transformer_1/multi_head_attention_1/key/einsum/Einsum' defined at (most recent call last):
    File "/opt/conda/lib/python3.9/runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/opt/conda/lib/python3.9/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/opt/conda/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/opt/conda/lib/python3.9/site-packages/traitlets/config/application.py", line 846, in launch_instance
      app.start()
    File "/opt/conda/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 712, in start
      self.io_loop.start()
    File "/opt/conda/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 199, in start
      self.asyncio_loop.run_forever()
    File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 596, in run_forever
      self._run_once()
    File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 1890, in _run_once
      handle._run()
    File "/opt/conda/lib/python3.9/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/opt/conda/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 504, in dispatch_queue
      await self.process_one()
    File "/opt/conda/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 493, in process_one
      await dispatch(*args)
    File "/opt/conda/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 400, in dispatch_shell
      await result
    File "/opt/conda/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 724, in execute_request
      reply_content = await reply_content
    File "/opt/conda/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 390, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "/opt/conda/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 528, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2863, in run_cell
      result = self._run_cell(
    File "/opt/conda/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2909, in _run_cell
      return runner(coro)
    File "/opt/conda/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/opt/conda/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3106, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/opt/conda/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3309, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/opt/conda/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3369, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_12206/3575384082.py", line 13, in <cell line: 13>
      ids, _ = top_k_obj.generate(initial_ids=prompt_enc, initial_cache={})
    File "/opt/conda/lib/python3.9/site-packages/official/nlp/modeling/ops/decoding_module.py", line 167, in generate
      tf.while_loop(
    File "/opt/conda/lib/python3.9/site-packages/official/nlp/modeling/ops/decoding_module.py", line 146, in _generate_step
      topk_seq, topk_log_probs, topk_ids, new_cache = self._grow_alive_seq(
    File "/opt/conda/lib/python3.9/site-packages/official/nlp/modeling/ops/sampling_module.py", line 209, in _grow_alive_seq
      new_logits, new_cache = self.symbols_to_logits_fn(ids, i, alive_cache)
    File "/tmp/ipykernel_12206/376583282.py", line 18, in symbols_to_logits_fn
      logits = model_fn(ids)
    File "/tmp/ipykernel_12206/2936490992.py", line 2, in model_fn
      y, _ = model.predict(ids)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/training.py", line 2033, in predict
      tmp_batch_outputs = self.predict_function(iterator)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/training.py", line 1845, in predict_function
      return step_function(self, iterator)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/training.py", line 1834, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/training.py", line 1823, in run_step
      outputs = model.predict_step(data)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/training.py", line 1791, in predict_step
      return self(x, training=False)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/training.py", line 490, in __call__
      return super().__call__(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1014, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/functional.py", line 458, in call
      return self._run_internal_graph(
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/functional.py", line 596, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1014, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/tmp/ipykernel_12206/2577669967.py", line 35, in call
      attention_out = self.attention(inputs, inputs, attention_mask=causal_mask)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1014, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/layers/attention/multi_head_attention.py", line 504, in call
      key = self._key_dense(key)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1014, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/keras/layers/einsum_dense.py", line 187, in call
      ret = tf.einsum(self.equation, inputs, self.kernel)
Node: 'model_1/transformer_1/multi_head_attention_1/key/einsum/Einsum'
Attempting to perform BLAS operation using StreamExecutor without BLAS support
	 [[{{node model_1/transformer_1/multi_head_attention_1/key/einsum/Einsum}}]] [Op:__inference_predict_function_2991]

In [44]:
# Scratch check: convert a small integer constant tensor to a NumPy array.
tf.constant([1, 1, 23]).numpy()

array([ 1,  1, 23], dtype=int32)

In [17]:
# Scratch check of the zero-padding used in symbols_to_logits_fn: pad the ids
# up to the model's sequence length (M_MAX_LEN), not the vocabulary size, so
# the result does not depend on M_VOCAB_SZ having been rebound in the kernel.
l = len([1, 1, 23])
np.pad(
    np.array([1, 1, 23]), (0, M_MAX_LEN - l), mode="constant", constant_values=0
).shape

(80,)

In [9]:
# NOTE(review): this rebinds M_VOCAB_SZ (20000 in the config cell) to 80 for
# the rest of the kernel session — presumably leftover scratch state; confirm
# it is not meant to change the model's vocabulary size before re-running cells.
M_VOCAB_SZ = 80