# Setup

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from matplotlib import pyplot as plt

# Data Preprocessing

In [None]:
# download tiny_shakespeare dataset
dataset_dict = tfds.load(name='tiny_shakespeare')

In [None]:
dataset_dict

In [4]:
# Pull the three splits out of the dict returned by tfds.load.
train_data, validation_data, test_data = (
    dataset_dict[split_name] for split_name in ('train', 'validation', 'test')
)

In [5]:
# Each split contains a single example of string type; split that string
# into a sequence of Unicode code points (one element per character).
def _split_text_into_chars(example):
  return tf.strings.unicode_split(example['text'], 'UTF-8')

train_dataset = train_data.map(_split_text_into_chars)
validation_dataset = validation_data.map(_split_text_into_chars)
test_dataset = test_data.map(_split_text_into_chars)

In [6]:
# Vocabulary: the sorted distinct elements of the first example in the
# training dataset (only the training split is used to build it).
vocabulary = sorted(set(next(iter(train_dataset)).numpy()))

In [7]:
# Two-way mapping between vocabulary tokens and their integer ids.
ids_to_tokens = dict(enumerate(vocabulary))
tokens_to_ids = {token: token_id for token_id, token in ids_to_tokens.items()}

# Static lookup table so tokenization can be expressed with TensorFlow ops
# (it is called from Dataset.map). Tokens not in the table map to -1.
keys_tensor = tf.constant(list(tokens_to_ids.keys()))
vals_tensor = tf.constant(list(tokens_to_ids.values()))
tokens_to_ids_lookup = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), default_value=-1
)
# Backward-compatible alias for the original, misspelled name.
tokens_to_ids_loopup = tokens_to_ids_lookup

def decode(line):
  """Convert an iterable of token ids back into a Python string.

  Each id is looked up in `ids_to_tokens` and the resulting bytes are
  decoded as UTF-8. Raises KeyError for an id outside the vocabulary.
  """
  return ''.join(ids_to_tokens[token_id].decode('UTF-8') for token_id in line)

def tokenize(line):
  """Map a tensor of tokens to their integer ids (-1 for unknown tokens)."""
  return tokens_to_ids_lookup.lookup(line)

In [8]:
# configuration to store model parameters

# Hyperparameters shared by every model defined in this notebook.
MODEL_CONFIG = dict(
    vocab_size=len(vocabulary),
    batch_size=32,
    context_window=16,
    d_model=128,
    epochs=10,
    n_heads=8,
    n_layers=4,
)

In [9]:
def process_dataset(dataset, model_config, is_train=False):
  """Tokenize a character dataset and batch it into (input, target) windows.

  Args:
    dataset: dataset of token sequences accepted by `tokenize`.
    model_config: dict providing 'context_window' and 'batch_size'.
    is_train: accepted for call-site symmetry; currently not used.

  Returns:
    A dataset of (inputs, targets) pairs where targets are the inputs
    shifted by one position. Elements are grouped into windows of
    'context_window' items and then into batches of 'batch_size' windows;
    incomplete windows and batches are dropped.
  """
  def to_input_target(ids):
    # Target at position t is the token at position t + 1.
    return ids[:-1], ids[1:]

  return (
      dataset
      .map(tokenize)
      .map(to_input_target)
      .unbatch()
      .batch(model_config['context_window'], drop_remainder=True)
      .batch(model_config['batch_size'], drop_remainder=True)
  )

# Build the batched (input, target) datasets for each split with the shared
# MODEL_CONFIG. Only the training call passes is_train=True.
processed_train_dataset = process_dataset(train_dataset, MODEL_CONFIG, True)
processed_validation_dataset = process_dataset(validation_dataset, MODEL_CONFIG)
processed_test_dataset = process_dataset(test_dataset, MODEL_CONFIG)

In [None]:
# [OPTIONAL] store training results in google drive.
# `model_dir` is read by the TensorBoard and checkpoint cells further down, so
# it must be defined even when the notebook is not running on Colab; in that
# case fall back to a local directory instead of failing on the import.
try:
  from google.colab import drive
except ImportError:
  model_dir = './llama'
else:
  drive.mount('/content/drive')
  model_dir = '/content/drive/MyDrive/Colab Notebooks/llama'

In [None]:
'''
# helper function to check for lingering tensorboards
from tensorboard import notebook
notebook.list()
'''

# BaseModel

In [None]:
# TensorBoard setup for the BaseModel run.
%load_ext tensorboard
import datetime
# Timestamped sub-directory under model_dir so each run logs to its own folder.
log_dir = f"{model_dir}/logs_base/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback_base = tf.keras.callbacks.TensorBoard(log_dir=log_dir, update_freq="batch")

# Open the TensorBoard UI inline, pointed at this run's log directory.
%tensorboard --logdir "$log_dir"

In [11]:
def custom_loss(y_true, y_pred):
    """Sparse categorical cross-entropy computed on raw logits.

    The models in this notebook return unnormalized logits rather than
    probabilities, hence from_logits=True.
    """
    per_token_loss = tf.keras.losses.sparse_categorical_crossentropy(
        y_true, y_pred, from_logits=True
    )
    return per_token_loss

In [None]:
class BaseModel(tf.keras.Model):
  """Baseline model: token embedding followed by a two-layer MLP.

  Each position is processed independently; the output is a tensor of
  logits over the vocabulary for every input position.
  """

  def __init__(self, config):
    """Build the layers from config['vocab_size'] and config['d_model']."""
    super().__init__()
    print(config)
    self.config = config
    vocab_size, d_model = config['vocab_size'], config['d_model']
    self.embedding = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=d_model)
    feed_forward_layers = [
        tf.keras.layers.Dense(d_model),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(vocab_size),
    ]
    self.dense = tf.keras.models.Sequential(feed_forward_layers)

  def call(self, inputs):
    """Return vocabulary logits for the given token ids."""
    return self.dense(self.embedding(inputs))

# Checkpoint callback for the baseline run: monitors val_loss (mode='min')
# with save_best_only=True and writes under {model_dir}/tmp/checkpoint_base.
checkpoint_callback_base = tf.keras.callbacks.ModelCheckpoint(
    filepath=f'{model_dir}/tmp/checkpoint_base',
    save_best_only=True,
    monitor='val_loss',
    mode='min',
    save_weights_only=False,
    save_freq='epoch'  # save the model after each epoch
)

# NOTE: `model` is a notebook-global name that later sections rebind to
# other architectures.
model = BaseModel(config=MODEL_CONFIG)

# custom_loss treats the model outputs as logits.
model.compile(optimizer='adam',
              loss=custom_loss,
              metrics=['accuracy'])
model.fit(processed_train_dataset, validation_data=processed_validation_dataset, epochs=MODEL_CONFIG['epochs'], callbacks=[checkpoint_callback_base, tensorboard_callback_base])

In [12]:
def generate_text(model, config, sentence_count=5, max_new_tokens=50):
    """Sample text from `model` one token at a time.

    Args:
        model: callable mapping a (batch, time) tensor of token ids to
            logits of shape (batch, time, vocab).
        config: dict providing 'context_window'.
        sentence_count: number of sequences generated in parallel.
        max_new_tokens: number of tokens appended to every sequence.

    Returns:
        A list of `sentence_count` decoded strings. Every sequence starts
        from token id 0, which is included in the decoded output.
    """
    tokens = tf.zeros([sentence_count, 1], dtype=tf.int64)
    for _step in range(max_new_tokens):
        # Only the most recent `context_window` tokens are fed to the model.
        logits = model(tokens[:, -config['context_window']:])
        # Logits for the final position drive the sampling of the next token.
        last_step_logits = logits[:, -1, :]
        sampled = tf.random.categorical(last_step_logits, num_samples=1, dtype=tf.int64)
        tokens = tf.concat([tokens, sampled], axis=-1)
    return [decode(row) for row in tokens.numpy()]

In [None]:
texts = generate_text(model, MODEL_CONFIG)
texts

# RMSNorm

In [13]:
class RMSNormLayer(tf.keras.layers.Layer):
    """RMS normalization over the last two axes of a 3-D input.

    Every batch element is divided by its root-mean-square (Frobenius norm
    divided by the square root of the number of elements) and multiplied by
    a learned `scale`; an optional learned `bias` is added afterwards.

    Args:
        layer_shape: shape of the scale/bias weights; the models in this
            notebook pass (context_window, d_model).
        eps: constant added to the RMS before dividing.
        bias: when True, a trainable bias initialised to zeros is added.
    """

    def __init__(self, layer_shape, eps=1e-8, bias=False):
        super().__init__()
        # The weight name is passed by keyword (as SwiGLU does) so the call
        # does not depend on the positional order of add_weight's parameters.
        self.scale = self.add_weight(name="scale", shape=layer_shape, initializer="ones", trainable=True)
        self.eps = eps
        if bias:
            self.bias = self.add_weight(name="bias", shape=layer_shape, initializer="zeros", trainable=True)
        else:
            self.bias = None

    def call(self, x):
        """Normalize `x` and apply the learned scale (and bias, if any)."""
        seq_len = x.shape[1]
        # Elements per batch item, taken from the shape so that an empty
        # batch is handled, and cast to the input dtype.
        n_elements = tf.cast(tf.reduce_prod(tf.shape(x)[1:]), x.dtype)
        # Frobenius norm / sqrt(n) is the root-mean-square of each batch item.
        rms = tf.norm(x, ord='fro', axis=[1, 2]) * tf.math.pow(n_elements, -0.5)
        normalized = x / (tf.expand_dims(tf.expand_dims(rms, -1), -1) + self.eps)
        # Weights are sliced to the actual sequence length so that inputs
        # shorter than the configured window (e.g. during generation) work.
        scaled = tf.expand_dims(self.scale[:seq_len, :], 0) * normalized

        if self.bias is None:
            return scaled
        # The bias must be sliced the same way as the scale, otherwise
        # shorter sequences fail to broadcast.
        return scaled + tf.expand_dims(self.bias[:seq_len, :], 0)

In [21]:
# test RMSNormLayer

# use different configuration values to make sure
# the layer is able to handle all
# Small, deliberately "odd" sizes for the RMSNormLayer tests below.
# NOTE: `config` is a notebook-global name that the RoPE test cell rebinds.
config = {
    'batch_size': 5,
    'context_window': 11,
    'd_model': 13,
}

def test_output_shape():
  """The layer must return a tensor with the same shape as its input."""
  batch = tf.random.normal((config['batch_size'], config['context_window'], config['d_model']))
  rms_layer = RMSNormLayer((config['context_window'], config['d_model']))
  output = rms_layer(batch)
  tf.debugging.assert_shapes([(output, (config['batch_size'], config['context_window'], config['d_model']))])
test_output_shape()

def test_layer_initialization():
  """The scale weights start at one (checked through their mean)."""
  rms_layer = RMSNormLayer((config['context_window'], config['d_model']))
  tf.debugging.assert_near(tf.reduce_mean(rms_layer.scale), 1.0, atol=1e-5)
test_layer_initialization()

def test_output_variance():
    """For standard-normal input the output variance is close to 1."""
    batch = tf.random.normal((config['batch_size'], config['context_window'], config['d_model']))
    rms_layer = RMSNormLayer((config['context_window'], config['d_model']))
    output = rms_layer(batch)
    variance = tf.math.reduce_variance(output)
    tf.debugging.assert_near(variance, 1.0, atol=1e-1)
test_output_variance()

def test_frobenius_norm_calculation():
    """The per-item RMS computed via tf.norm has one value per batch element."""
    batch = tf.random.normal((config['batch_size'], config['context_window'], config['d_model']))
    rms = tf.norm(batch, axis=(1, 2)) * (tf.cast(tf.size(batch[0]), tf.float32) ** -0.5)
    assert rms.shape == (config['batch_size'],), "RMS shape is incorrect"
test_frobenius_norm_calculation()

def test_tf_norm_equivalence():
    """tf.norm on a vector equals sqrt(sum of squares)."""
    tf.debugging.assert_near(
        tf.norm(tf.range(5, dtype=tf.float32)),
        tf.math.sqrt(tf.reduce_sum(tf.square(tf.range(5, dtype=tf.float32))))
    )
test_tf_norm_equivalence()

def test_normalized_tensor_norm():
    """A vector divided by its RMS has norm sqrt(number of elements)."""
    rms_single = tf.norm(tf.range(5, dtype=tf.float32)) * (tf.cast(tf.size(tf.range(5, dtype=tf.float32)), tf.float32) ** -0.5)
    tf.debugging.assert_near(
        tf.norm(tf.range(5, dtype=tf.float32) / rms_single),
        tf.math.sqrt(tf.constant(5, dtype=tf.float32))
    )
test_normalized_tensor_norm()

def test_ff_rms_calculation():
    """Same shape check as above, with the element count taken from tf.shape."""
    batch = tf.random.normal((config['batch_size'], config['context_window'], config['d_model']))
    ff_rms = tf.norm(batch, axis=(1, 2)) * (tf.cast(tf.math.reduce_prod(tf.shape(batch)[1:]), tf.float32) ** -0.5)
    assert ff_rms.shape == (config['batch_size'],), "FF RMS shape is incorrect"
test_ff_rms_calculation()

def test_per_batch_item_normalization():
    """After dividing each batch item by its RMS, its squared norm equals
    the number of elements in the item (context_window * d_model).
    """
    batch = tf.random.normal((config['batch_size'], config['context_window'], config['d_model']))
    ff_rms = tf.norm(batch, axis=(1, 2)) * (tf.cast(tf.math.reduce_prod(tf.shape(batch)[1:]), tf.float32) ** -0.5)

    # Normalize one batch item at a time, as an explicit reference.
    ffx = tf.zeros_like(batch)
    for i in range(tf.shape(batch)[0]):
        ffx = tf.tensor_scatter_nd_update(ffx, [[i]], [batch[i] / ff_rms[i]])

    # Derived from config rather than hard-coded, so the test stays valid
    # when the test sizes change.
    elements_per_item = config['context_window'] * config['d_model']
    tf.debugging.assert_near(
        tf.square(tf.norm(ffx, axis=(1, 2))),
        tf.constant(elements_per_item, dtype=tf.float32)
    )
test_per_batch_item_normalization()

def test_rmsnorm_layer_output():
    """RMSNormLayer output must match a per-item reference normalization."""
    input_shape = (config['batch_size'], config['context_window'], config['d_model'])
    batch = tf.random.normal(input_shape)
    layer_output = RMSNormLayer(input_shape[1:])(batch)

    elements_per_item = tf.cast(tf.math.reduce_prod(tf.shape(batch)[1:]), tf.float32)
    ff_rms = tf.norm(batch, axis=(1, 2)) * (elements_per_item ** -0.5)
    # Reference result: every batch item divided by its own RMS.
    expected = tf.stack([batch[i] / ff_rms[i] for i in range(tf.shape(batch)[0])])

    tf.debugging.assert_near(expected, layer_output)
test_rmsnorm_layer_output()

In [None]:
# TensorBoard setup for the RMSModel run.
%load_ext tensorboard
import datetime
# Timestamped sub-directory under model_dir so each run logs to its own folder.
log_dir = f"{model_dir}/rms_logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback_rms = tf.keras.callbacks.TensorBoard(log_dir=log_dir, update_freq="batch")

# Open the TensorBoard UI inline, pointed at this run's log directory.
%tensorboard --logdir "$log_dir"

In [None]:
class RMSModel(tf.keras.Model):
  """Embedding -> RMSNormLayer -> two-layer MLP producing vocabulary logits."""

  def __init__(self, config):
    """Build the layers from 'vocab_size', 'd_model' and 'context_window'."""
    super().__init__()
    self.config = config
    d_model = config['d_model']
    vocab_size = config['vocab_size']
    self.embedding = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=d_model)
    # Normalization applied to the embeddings before the MLP.
    self.rms = RMSNormLayer((config['context_window'], d_model))
    self.dense = tf.keras.models.Sequential([
        tf.keras.layers.Dense(units=d_model),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(units=vocab_size),
    ])

  def call(self, inputs):
    """Return vocabulary logits for the given token ids."""
    embedded = self.embedding(inputs)
    normalized = self.rms(embedded)
    return self.dense(normalized)

# Checkpoint callback for the RMSModel run: monitors val_loss (mode='min')
# with save_best_only=True and writes under {model_dir}/tmp/checkpoint_rms.
checkpoint_callback_rms = tf.keras.callbacks.ModelCheckpoint(
    filepath=f'{model_dir}/tmp/checkpoint_rms',
    save_best_only=True,
    monitor='val_loss',
    mode='min',
    save_weights_only=False,
    save_freq='epoch'  # save the model after each epoch
)

# Rebinds the notebook-global `model` to the RMSModel and trains it.
model = RMSModel(config=MODEL_CONFIG)
model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])
model.fit(processed_train_dataset,  validation_data=processed_validation_dataset, epochs=MODEL_CONFIG['epochs'], callbacks=[tensorboard_callback_rms, checkpoint_callback_rms])

In [None]:
texts = generate_text(model, MODEL_CONFIG)
texts

# RoPE

In [26]:
def get_rotary_matrix(context_window, embedding_dim):
  """Build one block-diagonal rotation matrix per position.

  Returns a tensor of shape (context_window, embedding_dim, embedding_dim).
  For each position and each pair of dimensions (2i, 2i+1) the 2x2 block
  is [[cos, -sin], [sin, cos]] of the angle position * theta_i; all other
  entries are zero. If embedding_dim is odd the last row/column stays zero.

  NOTE(review): the exponent uses (i - 1) while i starts at 0, so the first
  pair gets theta = 10000 ** (2 / embedding_dim) rather than 1. RoFormer
  states theta_i = 10000 ** (-2 (i - 1) / d) for a 1-based i, which would be
  -2 * i / d here. test_rotary_matrix_values repeats the same formula, so
  the two have to be changed together.
  """
  matrix = tf.zeros((context_window, embedding_dim, embedding_dim))

  for position in range(context_window):
    for i in range(embedding_dim // 2):
      # Rotation frequency for this pair and the resulting angle.
      theta = 10000. ** (-2. * (i - 1) / embedding_dim)
      m_theta = position * theta

      # Each call returns an updated copy of the whole tensor.
      matrix = tf.tensor_scatter_nd_update(matrix, [[position, 2 * i, 2 * i]], [tf.math.cos(m_theta)])
      matrix = tf.tensor_scatter_nd_update(matrix, [[position, 2 * i, 2 * i + 1]], [-tf.math.sin(m_theta)])
      matrix = tf.tensor_scatter_nd_update(matrix, [[position, 2 * i + 1, 2 * i]], [tf.math.sin(m_theta)])
      matrix = tf.tensor_scatter_nd_update(matrix, [[position, 2 * i + 1, 2 * i + 1]], [tf.math.cos(m_theta)])

  return matrix

In [28]:
# Sizes used by the rotary-matrix tests (rebinds the notebook-global `config`).
config = {
  'd_model': 128,
  'context_window': 16,
}

def test_rotary_matrix_shape():
  """The result has one (d_model, d_model) matrix per position."""
  matrix = get_rotary_matrix(config['context_window'], config['d_model'])
  tf.debugging.assert_shapes([(matrix, (config['context_window'], config['d_model'], config['d_model']))])
test_rotary_matrix_shape()

def test_rotary_matrix_values():
  """Every 2x2 block holds the expected cos/sin values.

  The expected values are computed with the same formula as
  get_rotary_matrix and compared with exact equality.
  """
  matrix = get_rotary_matrix(config['context_window'], config['d_model'])
  for position in range(config['context_window']):
    for i in range(config['d_model'] // 2):
      theta = 10000. ** (-2. * (i - 1) / config['d_model'])
      m_theta = position * theta
      assert matrix[position, 2 * i, 2 * i] == tf.math.cos(m_theta)
      assert matrix[position, 2 * i, 2 * i + 1] == -tf.math.sin(m_theta)
      assert matrix[position, 2 * i + 1, 2 * i] == tf.math.sin(m_theta)
      assert matrix[position, 2 * i + 1, 2 * i + 1] == tf.math.cos(m_theta)
test_rotary_matrix_values()

def test_rotary_matrix_multiplication_properties():
  """Dot products of rotated vectors depend only on the position offset:
  <R_m x, R_n y> == <x, R_(n-m) y>.
  """
  rotary_matrix = get_rotary_matrix(config['context_window'], config['d_model'])

  # Random tensors for x and y
  x = tf.random.normal((config['d_model'],))
  y = tf.random.normal((config['d_model'],))

  # Two arbitrary positions inside the context window.
  m = 3
  n = 13

  # Matrix-vector multiplication
  x_m = tf.linalg.matvec(rotary_matrix[m, :, :], x)
  x_n = tf.linalg.matvec(rotary_matrix[n, :, :], y)

  tf.debugging.assert_near(
    tf.tensordot(x_m, x_n, axes=1),
    tf.tensordot(x, tf.linalg.matvec(rotary_matrix[n-m, :, :], y), axes=1)
  )
test_rotary_matrix_multiplication_properties()

In [123]:
def generate_square_subsequent_mask(size):
  """Return a (size, size) lower-triangular matrix of ones (diagonal included)."""
  all_ones = tf.ones((size, size))
  # num_lower=-1 keeps the whole lower triangle; num_upper=0 drops everything
  # above the main diagonal.
  return tf.linalg.band_part(all_ones, -1, 0)

In [None]:
generate_square_subsequent_mask(5)

In [133]:
class RoPEAttentionLayer(tf.keras.layers.Layer):
  """Causal multi-head self-attention with rotary position embeddings.

  The input is projected to queries, keys and values. Queries and keys are
  rotated by a position-dependent matrix, values are left unrotated, and
  the three are passed to a Keras MultiHeadAttention layer with a causal
  mask.

  Args:
    config: dict providing 'd_model', 'n_heads' and 'context_window'.
  """

  def __init__(self, config):
    super().__init__()
    self.config = config
    self.w_q = tf.keras.layers.Dense(config['d_model'], use_bias=False)
    self.w_k = tf.keras.layers.Dense(config['d_model'], use_bias=False)
    self.w_v = tf.keras.layers.Dense(config['d_model'], use_bias=False)

    self.multihead = tf.keras.layers.MultiHeadAttention(config['n_heads'], config['d_model'], dropout=0.1)

    # Precomputed once; sliced to the actual sequence length in call().
    self.rotary_matrix = self.get_rotary_matrix(config['context_window'], config['d_model'])

  def get_rotary_matrix(self, context_window, embedding_dim):
    """Build one block-diagonal rotation matrix per position.

    Returns a tensor of shape (context_window, embedding_dim, embedding_dim).
    For position m and dimension pair (2i, 2i+1) the 2x2 block is
    [[cos, -sin], [sin, cos]] of m * theta_i, where
    theta_i = 10000 ** (-2 * i / embedding_dim) and i starts at 0, so the
    first pair has theta = 1. If embedding_dim is odd the last row/column
    stays zero.
    """
    matrix = tf.zeros((context_window, embedding_dim, embedding_dim))
    n_pairs = embedding_dim // 2
    if context_window == 0 or n_pairs == 0:
      return matrix

    # Collect every angle and the (position, even, odd) indices it belongs
    # to, so the matrix can be filled with four scatter updates instead of
    # one full-tensor copy per entry.
    angles, cells = [], []
    for position in range(context_window):
      for i in range(n_pairs):
        angles.append(position * 10000. ** (-2. * i / embedding_dim))
        cells.append((position, 2 * i, 2 * i + 1))

    angles = tf.constant(angles, dtype=tf.float32)
    cos, sin = tf.math.cos(angles), tf.math.sin(angles)
    # In each cell tuple, index 1 is the even dimension and 2 the odd one.
    for row_sel, col_sel, values in ((1, 1, cos), (1, 2, -sin), (2, 1, sin), (2, 2, cos)):
      indices = [[cell[0], cell[row_sel], cell[col_sel]] for cell in cells]
      matrix = tf.tensor_scatter_nd_update(matrix, indices, values)

    return matrix

  def _rotate(self, projected, context_window):
    """Multiply the vector at every position by that position's rotary matrix."""
    # (batch, time, dim) -> (time, batch, dim) so the batched matmul pairs
    # each time step with its own (dim, dim) matrix.
    per_position = tf.transpose(projected, perm=[1, 0, 2])
    rotated = tf.linalg.matmul(per_position, self.rotary_matrix[:context_window, ...])
    return tf.transpose(rotated, perm=[1, 0, 2])

  def call(self, x):
    """Apply causal self-attention to `x` (batch, time, d_model)."""
    _, context_window, _ = x.shape
    q_out = self._rotate(self.w_q(x), context_window)
    k_out = self._rotate(self.w_k(x), context_window)
    # Values carry content only; the rotation applies to queries and keys.
    v = self.w_v(x)

    # MultiHeadAttention's positional order is (query, value, key), so key
    # is passed by keyword to keep keys and values from being swapped.
    # use_causal_mask=True already restricts each position to itself and
    # earlier positions, so no explicit attention_mask is passed.
    activations = self.multihead(
        q_out, v,
        key=k_out,
        return_attention_scores=False,
        use_causal_mask=True
    )
    return activations


In [None]:
# TensorBoard setup for the RoPEModel run. The extension load is left
# commented out here; earlier cells already load it.
#%load_ext tensorboard
import datetime
# Timestamped sub-directory under model_dir so each run logs to its own folder.
log_dir = f"{model_dir}/logs_rope/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback_rope = tf.keras.callbacks.TensorBoard(log_dir=log_dir, update_freq="batch")

# Open the TensorBoard UI inline, pointed at this run's log directory.
%tensorboard --logdir "$log_dir"

In [None]:
class RoPEModel(tf.keras.Model):
  """Embedding -> RMSNorm -> RoPE attention -> RMSNorm -> ReLU MLP -> logits.

  Both normalization steps use the same RMSNormLayer instance (and hence
  the same scale weights), and each residual connection adds the sub-layer
  output to the normalized tensor.
  """

  def __init__(self, config):
    """Build the layers from the values in `config`."""
    super().__init__()
    self.config = config
    d_model = config['d_model']
    self.embedding = tf.keras.layers.Embedding(input_dim=config['vocab_size'], output_dim=d_model)
    self.rms = RMSNormLayer((config['context_window'], d_model))
    self.rope_attention = RoPEAttentionLayer(config)
    feed_forward_layers = [
        tf.keras.layers.Dense(units=d_model),
        tf.keras.layers.ReLU(),
    ]
    self.dense = tf.keras.models.Sequential(feed_forward_layers)
    self.dense_last = tf.keras.layers.Dense(units=config['vocab_size'])

  def call(self, x):
    """Return vocabulary logits for the given token ids."""
    embedded = self.embedding(x)

    attention_input = self.rms(embedded)
    attended = attention_input + self.rope_attention(attention_input)

    feed_forward_input = self.rms(attended)
    hidden = feed_forward_input + self.dense(feed_forward_input)
    return self.dense_last(hidden)

# Checkpoint callback for the RoPEModel run: monitors val_loss (mode='min')
# with save_best_only=True and writes under {model_dir}/tmp/checkpoint_rope.
checkpoint_callback_rope = tf.keras.callbacks.ModelCheckpoint(
    filepath=f'{model_dir}/tmp/checkpoint_rope',
    save_best_only=True,
    monitor='val_loss',
    mode='min',
    save_weights_only=False,
    save_freq='epoch'  # save the model after each epoch
)

# Rebinds the notebook-global `model` to the RoPEModel and trains it.
model = RoPEModel(config=MODEL_CONFIG)
model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])
model.fit(processed_train_dataset, validation_data=processed_validation_dataset, epochs=MODEL_CONFIG['epochs'], callbacks=[tensorboard_callback_rope, checkpoint_callback_rope])

In [None]:
texts = generate_text(model, MODEL_CONFIG)
texts

# SwiGLU

In [137]:
class SwiGLU(tf.keras.layers.Layer):
  """Swish-gated linear unit: swish_beta(W_g x) * (W x).

  Args:
    size: number of output units of both dense projections.
  """

  def __init__(self, size):
    super().__init__()
    self.linear_gate = tf.keras.layers.Dense(units=size)
    self.linear = tf.keras.layers.Dense(units=size)
    # Learned slope of the swish activation.
    self.beta = self.add_weight(name='beta', shape=(1,), initializer='random_normal', trainable=True)

  def call(self, x):
    """Return the gated projection of `x`."""
    # The gate projection is evaluated once and reused on both sides of the
    # swish instead of running the dense layer twice.
    gate = self.linear_gate(x)
    swish_gate = gate * tf.sigmoid(self.beta * gate)
    swi_glu = swish_gate * self.linear(x)
    return swi_glu

In [None]:
# TensorBoard setup for the SwiGLUModel run. The extension load is left
# commented out here; earlier cells already load it.
#%load_ext tensorboard
import datetime
# Timestamped sub-directory under model_dir so each run logs to its own folder.
log_dir = f"{model_dir}/logs_swiglu/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback_swiglu = tf.keras.callbacks.TensorBoard(log_dir=log_dir, update_freq="batch")

# Open the TensorBoard UI inline, pointed at this run's log directory.
%tensorboard --logdir "$log_dir"

In [None]:
class SwiGLUModel(tf.keras.Model):
  """Embedding -> RMSNorm -> RoPE attention -> RMSNorm -> SwiGLU MLP -> logits.

  Both normalization steps use the same RMSNormLayer instance (and hence
  the same scale weights), and each residual connection adds the sub-layer
  output to the normalized tensor.
  """

  def __init__(self, config):
    """Build the layers from the values in `config`."""
    super().__init__()
    self.config = config
    d_model = config['d_model']
    self.embedding = tf.keras.layers.Embedding(input_dim=config['vocab_size'], output_dim=d_model)
    self.rms = RMSNormLayer((config['context_window'], d_model))
    self.rope_attention = RoPEAttentionLayer(config)
    feed_forward_layers = [
        tf.keras.layers.Dense(units=d_model),
        SwiGLU(d_model),
    ]
    self.dense = tf.keras.models.Sequential(feed_forward_layers)
    self.dense_last = tf.keras.layers.Dense(units=config['vocab_size'])

  def call(self, x):
    """Return vocabulary logits for the given token ids."""
    embedded = self.embedding(x)

    attention_input = self.rms(embedded)
    attended = attention_input + self.rope_attention(attention_input)

    feed_forward_input = self.rms(attended)
    hidden = feed_forward_input + self.dense(feed_forward_input)
    return self.dense_last(hidden)

# Checkpoint callback for the SwiGLUModel run: monitors val_loss (mode='min')
# with save_best_only=True and writes under {model_dir}/tmp/checkpoint_swiglu.
checkpoint_callback_swiglu = tf.keras.callbacks.ModelCheckpoint(
    filepath=f'{model_dir}/tmp/checkpoint_swiglu',
    save_best_only=True,
    monitor='val_loss',
    mode='min',
    save_weights_only=False,
    save_freq='epoch'  # save the model after each epoch
)

# Rebinds the notebook-global `model` to the SwiGLUModel and trains it.
model = SwiGLUModel(config=MODEL_CONFIG)
model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])
model.fit(processed_train_dataset, validation_data=processed_validation_dataset, epochs=MODEL_CONFIG['epochs'], callbacks=[tensorboard_callback_swiglu, checkpoint_callback_swiglu])

In [None]:
# TensorBoard setup for the BabyLLaMA run.
%load_ext tensorboard
import datetime
# Timestamped sub-directory under model_dir so each run logs to its own folder.
log_dir = f"{model_dir}/logs_baby/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback_baby = tf.keras.callbacks.TensorBoard(log_dir=log_dir, update_freq="batch")

# Open the TensorBoard UI inline, pointed at this run's log directory.
%tensorboard --logdir "$log_dir"

# Putting Everything Together: BabyLLaMA

In [None]:
class AttentionBlock(tf.keras.layers.Layer):
  """One block: RMSNorm -> RoPE attention -> RMSNorm -> SwiGLU MLP.

  Both normalization steps use the same RMSNormLayer instance (and hence
  the same scale weights), and each residual connection adds the sub-layer
  output to the normalized tensor.
  """

  def __init__(self, config):
    """Build the sub-layers from 'context_window' and 'd_model'."""
    super().__init__()
    d_model = config['d_model']
    self.rms = RMSNormLayer((config['context_window'], d_model))
    self.rope_attention = RoPEAttentionLayer(config)
    feed_forward_layers = [
        tf.keras.layers.Dense(units=d_model),
        SwiGLU(d_model),
    ]
    self.dense = tf.keras.models.Sequential(feed_forward_layers)

  def call(self, x):
    """Return the block output; same shape as `x`."""
    attention_input = self.rms(x)
    attended = attention_input + self.rope_attention(attention_input)

    feed_forward_input = self.rms(attended)
    return feed_forward_input + self.dense(feed_forward_input)

class BabyLLaMA(tf.keras.Model):
  """Embedding, a stack of `n_layers` AttentionBlocks, and a SwiGLU output head."""

  def __init__(self, config):
    """Build the layers from the values in `config`."""
    super().__init__()
    self.config = config
    d_model = config['d_model']
    self.embedding = tf.keras.layers.Embedding(input_dim=config['vocab_size'], output_dim=d_model)
    blocks = [AttentionBlock(config) for _ in range(config['n_layers'])]
    self.attention_blocks = tf.keras.models.Sequential(blocks)
    # Output head: projects the final hidden states to vocabulary logits.
    head_layers = [
        tf.keras.layers.Dense(units=d_model),
        SwiGLU(d_model),
        tf.keras.layers.Dense(units=config['vocab_size']),
    ]
    self.dense = tf.keras.models.Sequential(head_layers)

  def call(self, x):
    """Return vocabulary logits for the given token ids."""
    embedded = self.embedding(x)
    hidden = self.attention_blocks(embedded)
    return self.dense(hidden)

# Checkpoint callback for the BabyLLaMA run: monitors val_loss (mode='min')
# with save_best_only=True and writes under {model_dir}/tmp/checkpoint_baby.
checkpoint_callback_baby = tf.keras.callbacks.ModelCheckpoint(
    filepath=f'{model_dir}/tmp/checkpoint_baby',
    save_best_only=True,
    monitor='val_loss',
    mode='min',
    save_weights_only=False,
    save_freq='epoch'  # save the model after each epoch
)

# Rebinds the notebook-global `model` to BabyLLaMA and trains it.
model = BabyLLaMA(config=MODEL_CONFIG)
model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])
model.fit(processed_train_dataset, validation_data=processed_validation_dataset, epochs=MODEL_CONFIG['epochs'], callbacks=[tensorboard_callback_baby, checkpoint_callback_baby])

In [None]:
texts = generate_text(model, MODEL_CONFIG, 5, 500)
texts

In [None]:
# TensorBoard setup for the final training run.
%load_ext tensorboard
import datetime
# Timestamped sub-directory under model_dir so each run logs to its own folder.
log_dir = f"{model_dir}/logs_final/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback_final = tf.keras.callbacks.TensorBoard(log_dir=log_dir, update_freq="batch")

# Open the TensorBoard UI inline, pointed at this run's log directory.
%tensorboard --logdir "$log_dir"

In [None]:
# Learning-rate schedule: cosine decay from initial_learning_rate down to
# end_learning_rate over decay_steps optimizer steps.
initial_learning_rate = 1e-3
decay_steps = 300
end_learning_rate = 1e-5

# CosineDecay's `alpha` is the minimum learning rate expressed as a FRACTION
# of the initial rate, so the absolute end rate is converted to a ratio.
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate, decay_steps, alpha=end_learning_rate / initial_learning_rate)

custom_optimizer = tf.keras.optimizers.Adam(
    learning_rate=lr_schedule,
    beta_1=0.9,
    beta_2=0.95,
    epsilon=1e-9,
    weight_decay=0.1
)

# Checkpoint callback for the final run: monitors val_loss (mode='min')
# with save_best_only=True and writes under {model_dir}/tmp/checkpoint_final.
checkpoint_callback_final = tf.keras.callbacks.ModelCheckpoint(
    filepath=f'{model_dir}/tmp/checkpoint_final',
    save_best_only=True,
    monitor='val_loss',
    mode='min',
    save_weights_only=False,
    save_freq='epoch'  # save the model after each epoch
)

# Re-compile the current `model` with the optimizer configured above (the
# string 'adam' would silently ignore the schedule and hyperparameters) and
# continue training.
model.compile(optimizer=custom_optimizer, loss=custom_loss, metrics=['accuracy'])
model.fit(processed_train_dataset, validation_data=processed_validation_dataset, epochs=MODEL_CONFIG['epochs'], callbacks=[tensorboard_callback_final, checkpoint_callback_final])


In [None]:
model.load_weights(f'{model_dir}/tmp/checkpoint_final')

In [None]:
texts = generate_text(model, MODEL_CONFIG, 5, 500)
texts