Goal: Build a character-level language model.

References:
- http://karpathy.github.io/2015/05/21/rnn-effectiveness/
  - https://github.com/karpathy/char-rnn
  - https://cs.stanford.edu/people/karpathy/char-rnn/
  - https://gist.github.com/karpathy/587454dc0146a6ae21fc
- https://www.tensorflow.org/text/tutorials/text_generation

In [None]:
import tensorflow as tf
import numpy as np
ks = tf.keras
print("TensorFlow version:", tf.__version__)

import urllib
import math

In [None]:
# Install the tensor2tensor library which contains useful functions for the attention mechanism.
!pip3 install tensor2tensor
from tensor2tensor.layers.common_attention import dot_product_attention

# Get the data

Using a Shakespeare dataset

In [None]:
# Karpathy's datasets used in his blog post,
# http://karpathy.github.io/2015/05/21/rnn-effectiveness/,
# and listed here: https://cs.stanford.edu/people/karpathy/char-rnn/.

# A bare `import urllib` does not guarantee that the `urllib.request` submodule
# is loaded, so import it explicitly before calling `urlopen`.
import urllib.request

TEXT_URL = {
    'shakespeare': 'https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt',
    'linux': 'https://cs.stanford.edu/people/karpathy/char-rnn/linux_input.txt',
    'tolstoy': 'https://cs.stanford.edu/people/karpathy/char-rnn/warpeace_input.txt',
}['shakespeare']  # Select a dataset

# Download the whole corpus into memory. `read()` returns a byte string.
with urllib.request.urlopen(TEXT_URL) as f:
  text = f.read()

print(f'Length of text: {len(text)} characters')

In [None]:
# Note that the text is stored as a byte string (the raw result of `f.read()`),
# so it has to be decoded before it can be displayed as text.
print(type(text))

In [None]:
# Look at a 500-byte sample of the data, decoded to a string for display
print(text[2100:2600].decode("utf-8"))

In [None]:
# This will be a character-level model. Serious language models use word pieces (https://paperswithcode.com/method/wordpiece).

# View the raw bytes as a numpy array of ASCII codes (one uint8 per character)
raw_seq = np.frombuffer(text, dtype=np.uint8)

# Token ID to ascii code conversion:
# the sorted distinct byte values that occur in the text (index == token ID).
token_to_ascii = np.unique(raw_seq)
VOCAB_SIZE = len(token_to_ascii)

# Ascii code to token ID conversion:
# a 256-entry lookup table where -1 marks byte values that never occur in the text.
ascii_to_token = np.full(256, -1, np.int_)
ascii_to_token[token_to_ascii] = np.arange(VOCAB_SIZE)

# Convert ascii array to token ID array
token_seq = ascii_to_token[raw_seq]

# Sanity checks: every byte must have mapped to a valid token ID in [0, VOCAB_SIZE)
print('vocab size:',VOCAB_SIZE)
print('seq:', token_seq[:20])
print(token_seq.shape)
print(token_seq.dtype)
print('\ntoken_to_char:', token_to_ascii)
print('any invalid?', np.any(token_seq == -1))
print('min:', np.min(token_seq),'  max:', np.max(token_seq))

In [None]:
BATCH_SIZE = 32  # 16
CONTEXT_SIZE = 100  #  500;  truncated sequence length
PAD_CHAR = token_to_ascii[0]  # ASCII code of the character used for padding
# `token_seq` holds token IDs, not ASCII codes, so pad with the token ID of PAD_CHAR.
PAD_TOKEN = ascii_to_token[PAD_CHAR]
# Number of tokens needed to round the length up to a multiple of BATCH_SIZE*CONTEXT_SIZE
PAD_LEN = -token_seq.size % (BATCH_SIZE*CONTEXT_SIZE)

# Build the padding with an explicit dtype: appending an empty Python list
# (the PAD_LEN == 0 case) would silently promote the result to float64.
padding = np.full(PAD_LEN, PAD_TOKEN, dtype=token_seq.dtype)
# Split the corpus into BATCH_SIZE contiguous rows that are processed in parallel
parallel_seq = np.concatenate((token_seq, padding)).reshape(BATCH_SIZE, -1)

# pad with beginning of sequences from next row
full_batches = 2  # How many full batches end of each row should bleed into start of next row
# Rolling by -1 along axis 0 lines up each row with the start of the row below it.
# The extra +1 column provides the target for the final input position.
parallel_seq = np.concatenate((parallel_seq, np.roll(parallel_seq[:,:CONTEXT_SIZE*full_batches+1],-1,0)),1)
print('shape:', parallel_seq.shape)

# The -1 reserves the last column, which is only ever used as a target
NUM_BATCHES = (parallel_seq.shape[1]-1) // CONTEXT_SIZE
print('num batches:', NUM_BATCHES)
print('assert',parallel_seq.size - NUM_BATCHES*BATCH_SIZE*CONTEXT_SIZE,'==',BATCH_SIZE)

In [None]:
def get_batch(batch_i, offset=0):
  """Return the `batch_i`-th window of `parallel_seq`, CONTEXT_SIZE columns wide.

  With offset == 0 the window holds the training inputs. With offset == 1 the
  window is shifted one step to the right and holds the matching next-token
  training targets.
  """
  start = batch_i * CONTEXT_SIZE + offset
  stop = start + CONTEXT_SIZE
  return parallel_seq[:, start:stop]

# Print the first and the last batch (as token IDs) as examples
print(get_batch(0))
print('')
print(get_batch(NUM_BATCHES-1))

In [None]:
# Human-readable render of the first training batch:
# map token IDs back to ASCII codes and decode each row to a string.
[row.tobytes().decode('utf8') for row in token_to_ascii[get_batch(0)]]

In [None]:
# Show the training targets for the above batch (the same rows shifted by one character)
[row.tobytes().decode('utf8') for row in token_to_ascii[get_batch(0, offset=1)]]

In [None]:
# Second-to-last training batch. Because the end of each row bleeds into the start
# of the next row, each line here is the line one row down in the first batch.
[row.tobytes().decode('utf8') for row in token_to_ascii[get_batch(NUM_BATCHES-2)]]

# Define the model

In [None]:
# TODO: implement `dot_product_attention` myself

In [None]:
# Define our model

# Recurrent cell class used to build the model's recurrent stack.
# Change the key in the trailing subscript to select a different cell type.
CELL_CLS = {
    'rnn': ks.layers.SimpleRNNCell,
    'lstm': ks.layers.LSTMCell,
    'gru': ks.layers.GRUCell,
}['lstm']

class Model(ks.Model):
  """Character-level recurrent language model with optional CNN and attention.

  Pipeline: one-hot tokens -> dense embedding -> (optional causal Conv1D) ->
  per timestep: (optional dot-product attention over current and earlier
  embeddings) -> stack of recurrent cells -> dense layer(s) producing logits
  over the vocabulary.
  """

  def __init__(self, use_attn=False, use_cnn=False):
    """Create the layers.

    Args:
      use_attn: if True, every timestep attends over the embeddings up to and
        including the current position.
      use_cnn: if True, a causal 1D convolution is applied to the embeddings.
    """
    super().__init__()
    self.embedding_size = 20
    self.input_embed = ks.layers.Dense(self.embedding_size)
    self.cells = [CELL_CLS(100)]  # , CELL_CLS(50)]
    self.output_stack = [ks.layers.Dense(VOCAB_SIZE)]
    self.conv1d = ks.layers.Conv1D(filters=self.embedding_size, kernel_size=4, padding='causal')  # 'causal' convolutions only depend on inputs to the left (and center) of the current position
    self.num_attn_heads = 10
    # One query vector per head, each the size of a key vector (self.embedding_size//2).
    # The parentheses matter: without them the expression is (heads * size) // 2,
    # which no longer matches the reshape in `call` for odd embedding sizes.
    self.query_embed = ks.layers.Dense(self.num_attn_heads * (self.embedding_size // 2))
    self.use_attn = use_attn
    self.use_cnn = use_cnn

  def call(self, x, s=None, more_context=None):
    """Run the model over a batch of token sequences.

    Args:
      x: integer token IDs with shape (batch_size, context_size). Both sizes
        may vary from call to call.
      s: recurrent state, one entry per cell in `self.cells`. If None, each
        cell's initial state is used.
      more_context: optional token IDs with shape (batch_size, extra_context_size).
        It is used to give the attention mechanism additional timesteps to look at.
        The context window for the attention is then the time-axis concatenation of
        `more_context` and `x`, i.e. `tf.concat((more_context, x), axis=1)`.
        It is ignored when `use_attn` is False.

    Returns:
      A pair `(logits, s)`: logits with shape (batch_size, context_size, VOCAB_SIZE)
      and the recurrent state after the last timestep.
    """
    # Expecting x.shape == (batch_size, context_size), where batch_size and context_size can be variable from run to run
    bs, cs = tf.unstack(tf.shape(x))
    x = tf.one_hot(x, VOCAB_SIZE)  # shape == (batch_size, context_size, VOCAB_SIZE), where VOCAB_SIZE is a global constant

    if s is None:
      s = [cell.get_initial_state(batch_size=bs, dtype=tf.float32) for cell in self.cells]
    else:
      s = list(s)  # Make a copy of the input list since we will modify it in place

    # Embed one-hot tokens
    e = self.input_embed(x)  # shape == (batch_size, context_size, embedding_size)
    # Note: ks.layers.Embedding does the same thing but more efficiently for large vocabularies

    if self.use_cnn:
      # 1D convolution across time puts neighbor information into each embedding in the sequence
      e = self.conv1d(e)

    if self.use_attn:
      if more_context is None:
        extra_cs = 0
        full_context = e
      else:
        more_context = tf.one_hot(more_context, VOCAB_SIZE)
        extra_cs = tf.shape(more_context)[1]  # size of time dim on more_context
        # `x` is already one-hot here, so both pieces have matching feature dims
        x_ = tf.concat((more_context, x), axis=1)
        full_context = self.input_embed(x_)
        if self.use_cnn:
          full_context = self.conv1d(full_context)
      # Split embedding dimension into two sectors: key and value.
      # That gives us a key and value pair for each timestep.
      num_features = tf.shape(full_context)[-1]
      k = full_context[:, :, :num_features//2]
      v = full_context[:, :, num_features//2:]
      k_size = tf.shape(k)[-1]  # should equal self.embedding_size//2

    # Recurrent cell stack
    outputs = []
    for t, h in enumerate(tf.unstack(e, axis=1)):
      if self.use_attn:
        # Query is computed from the current input and recurrent states.
        query_context = tf.concat(tf.nest.flatten([h, s]), axis=1)
        q = tf.reshape(self.query_embed(query_context), (bs, self.num_attn_heads, k_size))
        # https://github.com/tensorflow/tensor2tensor/blob/c8fe559e0b357389d8754474e1306b6ca9afc4f3/tensor2tensor/layers/common_attention.py#L1602
        # We slice `k` and `v` so that the future is not included
        attn_result = dot_product_attention(q, k[:, :extra_cs+t+1], v[:, :extra_cs+t+1], bias=None, make_image_summary=False)
        # attn_result shape is (batch_size, self.num_attn_heads, value_size), where
        # value_size is the feature size of `v` (equal to k_size for even embedding sizes)
        attn_result = tf.reshape(attn_result, (bs, -1))  # flatten last two dims
        h = tf.concat((h, attn_result), axis=-1)  # Concat current step input embedding with attention result
      # Feed the step input through the stacked cells, updating each cell's state in place
      for layer_i, cell in enumerate(self.cells):
        h, s[layer_i] = cell(h, s[layer_i])
      outputs.append(h)

    # Feed forward stack
    h = tf.stack(outputs, axis=1)  # stack along the time axis
    for layer in self.output_stack:
      h = layer(h)
    return h, s

# Training loop

In [None]:
learning_rate = 1e-3

# model = Model()
model = Model(use_attn=True, use_cnn=True)

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,  # predictions will be given as logits (log unnormalized probabilities) rather than probabilities
)

# Pass the learning rate explicitly so that changing `learning_rate` above takes effect
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# Use GPU if available.
# https://www.tensorflow.org/guide/gpu
GPUs = tf.config.list_physical_devices('GPU')
device = '/GPU:0' if GPUs else '/CPU:0'
print('device =', device)

In [None]:
@tf.function
def train_step(batch, labels, state=None):
  """Run one optimization step on `model` and return `(loss, logits, state_out)`.

  `batch` holds the input token IDs, `labels` the next-token targets, and
  `state` the recurrent state to start from (None selects the model's default).
  """
  with tf.GradientTape() as grad_tape:
    # training=True is only needed if there are layers with different
    # behavior during training versus inference (e.g. Dropout).
    logits, state_out = model(batch, state, training=True)
    loss = loss_object(labels, logits)

  # Look the variables up only after the forward pass above has run.
  variables = model.trainable_variables
  grads = grad_tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(grads, variables))

  return loss, logits, state_out


@tf.function
def accuracy(logits, target, normalize=True):
  """Compare the argmax prediction at every position with the target token.

  Args:
    logits: scores whose last axis ranges over the vocabulary.
    target: integer token IDs, with the shape of `logits` minus its last axis.
    normalize: if True return the fraction of correct predictions (float32),
      otherwise return the number of correct predictions.
  """
  predictions = tf.math.argmax(logits, axis=-1)
  # `argmax` yields int64 by default and `equal` requires matching dtypes,
  # so cast the targets (which may be e.g. int32) to the prediction dtype.
  target = tf.cast(target, predictions.dtype)
  corrects = tf.math.equal(predictions, target)
  if normalize:
    return tf.reduce_mean(tf.cast(corrects, tf.float32))
  return tf.math.count_nonzero(corrects)

In [None]:
# Start tensorboard (optional)
# This will embed a tensorboard front-end in the output of this cell, which will display training graphs in realtime.
# See https://colab.research.google.com/github/tensorflow/tensorboard/blob/master/docs/tensorboard_in_notebooks.ipynb
%load_ext tensorboard
%tensorboard --logdir logs

In [None]:
tb_writer = tf.summary.create_file_writer('logs')  # Tensorboard writer
# Counter used as the `step` of the logged summaries; incremented once per training batch
global_step = 0

In [None]:
num_epochs = 100
for epoch in range(num_epochs):
  # Performing truncated backprop through time (TBPTT).
  # States are carried over between batches, but gradients are not propagated beyond a batch.
  # At the start of each epoch the state is reset to its default (typically all zeros).
  state = None  # None tells the model to use the default state
  for batch_i in range(NUM_BATCHES):
    # Fetch the inputs and the next-token targets (inputs shifted by one step)
    batch = get_batch(batch_i)
    labels = get_batch(batch_i, offset=1)
    # Run the training step on the configured device
    with tf.device(device):
      loss_, logits_, state = train_step(batch, labels, state)

    global_step += 1

    # Every 10th batch: print progress and log the scalars to TensorBoard
    if batch_i % 10 == 0:
      loss_ = loss_.numpy()
      acc_ = accuracy(logits_, labels).numpy()
      print('  Step: %d out of %d | Train Loss: %.4f | Train Accuracy: %.2f' % (batch_i, NUM_BATCHES, loss_, acc_))
      with tb_writer.as_default():
        tf.summary.scalar('train_loss', loss_, step=global_step)
        tf.summary.scalar('train_accuracy', acc_, step=global_step)

  # Save model checkpoint
  # model.save(f'./training_checkpoints/ckpt_{epoch}')

  print('')
  print('Finished epoch')
  print('')

In [None]:
# Manually save the model.
# `epoch` is the loop variable left over from the training loop, so the checkpoint
# is named after the last epoch that was started.
model.save(f'./training_checkpoints/ckpt_{epoch}')

In [None]:
# Inspect predictions: compare the most likely next token at every position of the
# first row of batch 1000 with the actual targets.
# NOTE(review): the batch index 1000 is hard-coded and assumes NUM_BATCHES > 1000 — confirm for other datasets.
l, _ = model(get_batch(1000, offset=0))
p = tf.nn.softmax(l[0], axis=-1).numpy()
print(np.argmax(p, axis=1))
print('')
print(get_batch(1000, offset=1)[0])

## Load checkpoint

Reference: https://www.tensorflow.org/guide/keras/save_and_serialize

In [None]:
%ls training_checkpoints

In [None]:
# NOTE(review): the per-epoch save in the training loop is commented out and the manual
# save cell writes `ckpt_{epoch}`, so `ckpt_0` may not exist — confirm the path.
model_copy = ks.models.load_model('./training_checkpoints/ckpt_0')
model_copy.compile()

# Generate text

In [None]:
MODE = 'sample'  # 'argmax'
GENERATE_LENGTH = 1000
PROMPT = """
ROMEO:"""

# Sampling temperature.
# Lower temperature means peakier distribution.
# As the temp goes to 0, the distribution approaches one-hot (equivalent to taking the argmax)
temp = 1.0

# Convert the prompt to token IDs
prompt = np.frombuffer(bytes(PROMPT, 'utf-8'), dtype=np.uint8)
prompt = ascii_to_token[prompt]
if prompt.size == 0:
  raise ValueError('PROMPT must contain at least one character')
# Bytes that never occurred in the training text map to -1, which `tf.one_hot`
# would silently turn into an all-zero input vector.
if np.any(prompt == -1):
  raise ValueError('PROMPT contains characters that are not in the vocabulary')

# Process all but the last prompt token and keep the resulting recurrent state,
# so that generation continues from the prompt instead of from the default state.
if prompt.size > 1:
  _, state = model(prompt[None,:-1])
else:
  state = None  # Nothing to process beforehand; None selects the default state
generated = prompt[None,:]  # Last token in prompt is the input to the first generating step
for n in range(GENERATE_LENGTH):
  # Feed only the newest token through the recurrent stack; everything generated
  # so far is passed as extra context for the attention mechanism.
  logits, state = model(generated[:, -1:], state, more_context=generated[:, :-1])
  if MODE == 'sample':
    next_token = tf.random.categorical(logits[:, 0]/temp, num_samples=1)
  else:  # MODE == 'argmax'
    next_token = tf.math.argmax(logits, axis=-1)
  generated = np.concatenate((generated, next_token.numpy()), axis=1)

In [None]:
# Map the generated token IDs back to ASCII codes and decode them to text
print(token_to_ascii[generated].tobytes().decode('utf8'))

In [None]:
# TODO: implement beam search