In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd 
import os
import time
from tqdm import tqdm, trange

In [2]:
scripts = pd.read_csv('seinfeld_data/scripts.csv')
scripts['line'] = scripts.Character + ': ' + scripts.Dialogue
corpus = scripts.line.astype(str).tolist()
corpus=''.join(corpus)

In [3]:
# The unique characters in the file
vocab = sorted(set(corpus))
print(f'{len(vocab)} unique characters')

94 unique characters


In [4]:
device = tf.device('device:GPU:0') 

chars = tf.strings.unicode_split(vocab, input_encoding='UTF-8')


Metal device set to: Apple M1


2022-12-21 12:39:37.330709: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-12-21 12:39:37.331047: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [5]:
ids_from_chars = tf.keras.layers.StringLookup(
    vocabulary=list(vocab), mask_token=None)

In [6]:
ids = ids_from_chars(chars)

chars_from_ids = tf.keras.layers.StringLookup(
    vocabulary=ids_from_chars.get_vocabulary(), invert=True, mask_token=None)
chars = chars_from_ids(ids)

In [7]:
def text_from_ids(ids):
  return tf.strings.reduce_join(chars_from_ids(ids), axis=-1)

In [8]:
all_ids = ids_from_chars(tf.strings.unicode_split(corpus, 'UTF-8'))
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)

In [9]:
seq_length = 100
sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)

def split_input_target(sequence):
    input_text = sequence[:-1]
    target_text = sequence[1:]
    return input_text, target_text

dataset = sequences.map(split_input_target)

for input_example, target_example in dataset.take(1):
    print("Input :", text_from_ids(input_example).numpy())
    print("Target:", text_from_ids(target_example).numpy())

Input : b'JERRY: Do you know what this is all about? Do you know, why were here? To be out, this is out...and '
Target: b'ERRY: Do you know what this is all about? Do you know, why were here? To be out, this is out...and o'


2022-12-21 12:39:38.131317: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [10]:
# Batch size
BATCH_SIZE = 64

# Buffer size to shuffle the dataset
# (TF data is designed to work with possibly infinite sequences,
# so it doesn't attempt to shuffle the entire sequence in memory. Instead,
# it maintains a buffer in which it shuffles elements).
BUFFER_SIZE = 10000

dataset = (
    dataset
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE))


In [11]:
# Length of the vocabulary in StringLookup Layer
vocab_size = len(ids_from_chars.get_vocabulary())

# The embedding dimension
embedding_dim = 256

# Number of RNN units
rnn_units = 1024

In [53]:
class MyModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units, timesteps, features):
      super().__init__(self)
      self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
      self.gru = tf.keras.layers.GRU(rnn_units,
                                 return_sequences=True,
                                 return_state=True)
      self.dense = tf.keras.layers.Dense(vocab_size)

      # Set the values of the timesteps and features attributes
      self.timesteps = timesteps
      self.features = features


  def call(self, inputs, states=None, return_state=False, training=False):
  # Reshape the input tensor to have the correct shape for the GRU layer
      x = tf.reshape(inputs, ( self.timesteps, self.features))

      x = self.embedding(x, training=training)
      if states is None:
        states = self.gru.get_initial_state(x)
      x, states = self.gru(x, initial_state=states, training=training)
      x = self.dense(x, training=training)

      if return_state:
        return x, states
      else:
        return x


In [54]:
model = MyModel(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    rnn_units=rnn_units, 
    timesteps=64, 
    features=100)

In [14]:
dataset.take(1)

<TakeDataset element_spec=(TensorSpec(shape=(64, 100), dtype=tf.int64, name=None), TensorSpec(shape=(64, 100), dtype=tf.int64, name=None))>

In [44]:
for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

AttributeError: Exception encountered when calling layer 'my_model_1' (type MyModel).

'MyModel' object has no attribute 'timesteps'

Call arguments received by layer 'my_model_1' (type MyModel):
  • inputs=tf.Tensor(shape=(64, 100), dtype=int64)
  • states=None
  • return_state=False
  • training=False

In [16]:
sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()

In [17]:
print("Input:\n", text_from_ids(input_example_batch[0]).numpy())
print()
print("Next Char Predictions:\n", text_from_ids(sampled_indices).numpy())

Input:
 b") I found it!ELAINE: He's got it.KRAMER: Oh...no.JERRY: All right, that's it. From now on no more ca"

Next Char Predictions:
 b'bwRwD\\PWQZ mn>-.wwC]\\s4\\[UNK]vm`%" 5~hFj:"8_=3%$ouLT1\\kImH?eCSq&~rcPZ%1OXYu;aK\\*[PZK~\'\xc2\x92wa[UNK]H\'m(II>R{opB.-'


In [18]:
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
example_batch_mean_loss = loss(target_example_batch, example_batch_predictions)
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("Mean loss:        ", example_batch_mean_loss)

Prediction shape:  (64, 100, 95)  # (batch_size, sequence_length, vocab_size)
Mean loss:         tf.Tensor(4.5540285, shape=(), dtype=float32)


In [19]:
tf.exp(example_batch_mean_loss).numpy()

95.01439

In [20]:
model.compile(optimizer=tf.keras.optimizers.legacy.Adam(), loss=loss)
# Directory where the checkpoints will be saved
checkpoint_dir = './seinfeld_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

In [21]:
EPOCHS = 40
history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])

Epoch 1/40


2022-12-21 12:39:39.638683: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2022-12-21 12:39:40.111000: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2022-12-21 12:39:40.168091: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40
Epoch 14/40
Epoch 15/40
Epoch 16/40
Epoch 17/40
Epoch 18/40
Epoch 19/40
Epoch 20/40
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40
Epoch 25/40
Epoch 26/40
Epoch 27/40
Epoch 28/40
Epoch 29/40
Epoch 30/40
Epoch 31/40
Epoch 32/40
Epoch 33/40
Epoch 34/40
Epoch 35/40
Epoch 36/40
Epoch 37/40
Epoch 38/40
Epoch 39/40
Epoch 40/40


In [88]:
class OneStep(tf.keras.Model):
  def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0, optimizer=tf.keras.optimizers.legacy.Adam()):
    super().__init__()
    self.temperature = temperature
    self.model = model
    self.chars_from_ids = chars_from_ids
    self.ids_from_chars = ids_from_chars

    # Create a mask to prevent "[UNK]" from being generated.
    skip_ids = self.ids_from_chars(['[UNK]'])[:, None]
    sparse_mask = tf.SparseTensor(
        # Put a -inf at each bad index.
        values=[-float('inf')]*len(skip_ids),
        indices=skip_ids,
        # Match the shape to the vocabulary
        dense_shape=[len(ids_from_chars.get_vocabulary())])
    self.prediction_mask = tf.sparse.to_dense(sparse_mask)

  @tf.function
  def generate_one_step(self, inputs, states=None):
    # Convert strings to token IDs.
    input_chars = tf.strings.unicode_split(inputs, 'UTF-8')
    input_ids = self.ids_from_chars(input_chars).to_tensor()
    
    # Run the model.
    # predicted_logits.shape is [batch, char, next_char_logits]
    predicted_logits, states = self.model(inputs=input_ids, states=states,
                                          return_state=True)
    print(predicted_logits)
    # Only use the last prediction.
    predicted_logits = predicted_logits[:, -1, :]
    #predicted_logits = predicted_logits/self.temperature
    # Apply the prediction mask: prevent "[UNK]" from being generated.
    predicted_logits = predicted_logits + self.prediction_mask
    print(predicted_logits)

    # Sample the output logits to generate token IDs.
    predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
    predicted_ids = tf.squeeze(predicted_ids, axis=-1)
    print(predicted_ids)

    # Convert from token ids to characters
    predicted_chars = self.chars_from_ids(predicted_ids)
    print(predicted_chars)

    # Return the characters and model state.
    return predicted_chars#, states

In [89]:
one_step_model = OneStep(model, chars_from_ids, ids_from_chars, optimizer=tf.keras.optimizers.legacy.Adam())

In [90]:
one_step_model.generate_one_step(next_char)

Tensor("my_model_3/dense_3/BiasAdd:0", shape=(64, 100, 95), dtype=float32)
Tensor("add:0", shape=(64, 95), dtype=float32)
Tensor("Squeeze:0", shape=(64,), dtype=int64)
Tensor("string_lookup_1/None_Lookup/LookupTableFindV2:0", shape=(64,), dtype=string)


2022-12-21 15:05:36.991304: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


InvalidArgumentError: Graph execution error:

Detected at node 'my_model_3/Reshape' defined at (most recent call last):
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/traitlets/config/application.py", line 1041, in launch_instance
      app.start()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 711, in start
      self.io_loop.start()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
      self._run_once()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
      handle._run()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
      reply_content = await reply_content
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 411, in do_execute
      res = shell.run_cell(
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 530, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2940, in run_cell
      result = self._run_cell(
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2995, in _run_cell
      return runner(coro)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3194, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3373, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/var/folders/3p/xpgnq5_d0cz13g99y3v4g03r0000gn/T/ipykernel_49523/1653303957.py", line 1, in <module>
      one_step_model.generate_one_step(next_char)
    File "/var/folders/3p/xpgnq5_d0cz13g99y3v4g03r0000gn/T/ipykernel_49523/1998464073.py", line 27, in generate_one_step
      predicted_logits, states = self.model(inputs=input_ids, states=states,
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/engine/training.py", line 561, in __call__
      return super().__call__(*args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1132, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/var/folders/3p/xpgnq5_d0cz13g99y3v4g03r0000gn/T/ipykernel_49523/1989881299.py", line 17, in call
      x = tf.reshape(inputs, ( self.timesteps, self.features))
Node: 'my_model_3/Reshape'
Detected at node 'my_model_3/Reshape' defined at (most recent call last):
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/traitlets/config/application.py", line 1041, in launch_instance
      app.start()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 711, in start
      self.io_loop.start()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
      self._run_once()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
      handle._run()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
      reply_content = await reply_content
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 411, in do_execute
      res = shell.run_cell(
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 530, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2940, in run_cell
      result = self._run_cell(
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2995, in _run_cell
      return runner(coro)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3194, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3373, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/var/folders/3p/xpgnq5_d0cz13g99y3v4g03r0000gn/T/ipykernel_49523/1653303957.py", line 1, in <module>
      one_step_model.generate_one_step(next_char)
    File "/var/folders/3p/xpgnq5_d0cz13g99y3v4g03r0000gn/T/ipykernel_49523/1998464073.py", line 27, in generate_one_step
      predicted_logits, states = self.model(inputs=input_ids, states=states,
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/engine/training.py", line 561, in __call__
      return super().__call__(*args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1132, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/var/folders/3p/xpgnq5_d0cz13g99y3v4g03r0000gn/T/ipykernel_49523/1989881299.py", line 17, in call
      x = tf.reshape(inputs, ( self.timesteps, self.features))
Node: 'my_model_3/Reshape'
2 root error(s) found.
  (0) INVALID_ARGUMENT:  Input to reshape is a tensor with 6 values, but the requested shape has 6400
	 [[{{node my_model_3/Reshape}}]]
	 [[add/_32]]
  (1) INVALID_ARGUMENT:  Input to reshape is a tensor with 6 values, but the requested shape has 6400
	 [[{{node my_model_3/Reshape}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_generate_one_step_74527]

In [83]:
start = time.time()
states = None
next_char = tf.constant(['JERRY:'])
result = [next_char]

for n in range(1000):
  next_char, states = one_step_model.generate_one_step(next_char, states=states)
  result.append(next_char)

result = tf.strings.join(result)
end = time.time()
print(result[0].numpy().decode('utf-8'), '\n\n' + '_'*80)
print('\nRun time:', end - start)

InvalidArgumentError: Graph execution error:

Detected at node 'my_model_3/Reshape' defined at (most recent call last):
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/traitlets/config/application.py", line 1041, in launch_instance
      app.start()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 711, in start
      self.io_loop.start()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
      self._run_once()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
      handle._run()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
      reply_content = await reply_content
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 411, in do_execute
      res = shell.run_cell(
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 530, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2940, in run_cell
      result = self._run_cell(
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2995, in _run_cell
      return runner(coro)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3194, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3373, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/var/folders/3p/xpgnq5_d0cz13g99y3v4g03r0000gn/T/ipykernel_49523/1653303957.py", line 1, in <module>
      one_step_model.generate_one_step(next_char)
    File "/var/folders/3p/xpgnq5_d0cz13g99y3v4g03r0000gn/T/ipykernel_49523/11284627.py", line 27, in generate_one_step
      predicted_logits, states = self.model(inputs=input_ids, states=states,
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/engine/training.py", line 561, in __call__
      return super().__call__(*args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1132, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/var/folders/3p/xpgnq5_d0cz13g99y3v4g03r0000gn/T/ipykernel_49523/1989881299.py", line 17, in call
      x = tf.reshape(inputs, ( self.timesteps, self.features))
Node: 'my_model_3/Reshape'
Detected at node 'my_model_3/Reshape' defined at (most recent call last):
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/traitlets/config/application.py", line 1041, in launch_instance
      app.start()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 711, in start
      self.io_loop.start()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
      self._run_once()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
      handle._run()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
      reply_content = await reply_content
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 411, in do_execute
      res = shell.run_cell(
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 530, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2940, in run_cell
      result = self._run_cell(
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2995, in _run_cell
      return runner(coro)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3194, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3373, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/var/folders/3p/xpgnq5_d0cz13g99y3v4g03r0000gn/T/ipykernel_49523/1653303957.py", line 1, in <module>
      one_step_model.generate_one_step(next_char)
    File "/var/folders/3p/xpgnq5_d0cz13g99y3v4g03r0000gn/T/ipykernel_49523/11284627.py", line 27, in generate_one_step
      predicted_logits, states = self.model(inputs=input_ids, states=states,
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/engine/training.py", line 561, in __call__
      return super().__call__(*args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1132, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/opt/anaconda3/envs/tensorflow/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/var/folders/3p/xpgnq5_d0cz13g99y3v4g03r0000gn/T/ipykernel_49523/1989881299.py", line 17, in call
      x = tf.reshape(inputs, ( self.timesteps, self.features))
Node: 'my_model_3/Reshape'
2 root error(s) found.
  (0) INVALID_ARGUMENT:  Input to reshape is a tensor with 6 values, but the requested shape has 6400
	 [[{{node my_model_3/Reshape}}]]
	 [[add/_32]]
  (1) INVALID_ARGUMENT:  Input to reshape is a tensor with 6 values, but the requested shape has 6400
	 [[{{node my_model_3/Reshape}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_generate_one_step_73563]