# Setup Imports

In [None]:
!pip install requests tqdm regex tensorflow_datasets

In [1]:
import tensorflow as tf
import os
import json
import sys
import requests
from tqdm import tqdm
import regex as re
from functools import lru_cache
import numpy as np
import tensorflow_datasets as tfds

tf.__version__

'2.0.0-beta0'

# Download Weights and Encoding Files
Copied from https://github.com/openai/gpt-2

In [None]:
# Download the GPT-2 345M checkpoint and BPE vocabulary files into models/345M.
model_name = "345M"
subdir = os.path.join('models', model_name)
os.makedirs(subdir, exist_ok=True)

for filename in ['checkpoint','encoder.json', 'model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta', 'vocab.bpe']:

    # Build the URL with explicit "/" separators: os.path.join would produce
    # "\" on Windows and yield an invalid URL.
    url = "https://storage.googleapis.com/gpt-2/models/" + model_name + "/" + filename

    # The context manager releases the streamed connection when we are done.
    with requests.get(url, stream=True) as r:
        # Fail loudly on 4xx/5xx instead of saving an error page as a weight file.
        r.raise_for_status()
        file_size = int(r.headers["content-length"])
        # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
        chunk_size = 1000

        with open(os.path.join(subdir, filename), 'wb') as f:
            with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
                    # Advance by the bytes actually received; the last chunk is
                    # normally shorter than chunk_size.
                    pbar.update(len(chunk))

# Setup Training Data

In [None]:
# wget -O does not create missing parent directories, so make ./data first.
!mkdir -p ./data
!wget -O ./data/training.txt https://raw.githubusercontent.com/micheletufano/NeuralCodeTranslator/master/dataset/bug-fixes/medium/train/fixed.txt # http://groups.inf.ed.ac.uk/cup/javaGithub/java_projects.tar.gz
!ls ./data/

In [2]:
# Line-oriented dataset: every line of data/training.txt becomes one example
trn_exmpls = tf.data.TextLineDataset(
    os.path.join("data", "training.txt")
)

In [4]:
# Examine the first 5 method examples (each is a scalar string tensor)
for method in trn_exmpls.take(5):
    print(method)

tf.Tensor(b'public static TYPE_1 init ( java.lang.String name , java.util.Date date ) { TYPE_1 VAR_1 = new TYPE_1 ( ) ; VAR_1 . METHOD_1 ( name ) ; java.util.Calendar VAR_2 = null ; if ( date != null ) { VAR_2 = java.util.Calendar.getInstance ( ) ; VAR_2 . METHOD_2 ( date ) ; } VAR_1 . METHOD_3 ( VAR_2 ) ; return VAR_1 ; }', shape=(), dtype=string)
tf.Tensor(b'public TYPE_1 METHOD_1 ( java.lang.String name ) { if ( name . equals ( STRING_3 ) ) return new TYPE_3 ( STRING_4 , true ) ; if ( name . equals ( STRING_5 ) ) return new TYPE_4 ( ) ; return super . METHOD_1 ( name ) ; }', shape=(), dtype=string)
tf.Tensor(b'private boolean METHOD_1 ( TYPE_1 VAR_1 ) { boolean VAR_2 = ( VAR_3 . compareTo ( VAR_1 . METHOD_2 ( ) ) ) < 0 ; VAR_2 = VAR_2 || ( ! ( VAR_1 . METHOD_3 ( ) . METHOD_4 ( ) . equals ( VAR_4 ) ) ) ; return VAR_2 ; }', shape=(), dtype=string)
tf.Tensor(b'public void METHOD_1 ( TYPE_1 VAR_1 , boolean VAR_2 ) { if ( VAR_2 ) { VAR_3 . METHOD_2 ( 1 , CHAR_1 ) ; VAR_4 . METHOD_3 ( VAR

In [5]:
# Build a subword tokenizer from the training corpus, targeting a vocabulary
# of 2**13 entries. The generator feeds each example as raw bytes
# (method.numpy()), which means the whole dataset is iterated once here.
# NOTE(review): SubwordTextEncoder learns a subword vocabulary; whether that is
# strictly Byte Pair Encoding should be confirmed against the tfds docs.
tokenizer_method = tfds.features.text.SubwordTextEncoder.build_from_corpus(
    (method.numpy() for method in trn_exmpls), target_vocab_size=2**13)

In [6]:
# Round-trip a sample string through the tokenizer to check it is lossless
sample_string = 'This is out of vocabz~ if _ java.'

tokenized_string = tokenizer_method.encode(sample_string)
print('Tokenized string is {}'.format(tokenized_string))

original_string = tokenizer_method.decode(tokenized_string)
print('The original string: {}'.format(original_string))

# Show every id next to the text it decodes to. Words missing from the
# vocabulary are broken down into smaller pieces.
for ts in tokenized_string:
    piece = tokenizer_method.decode([ts])
    print('{} ----> {}'.format(ts, piece))

Tokenized string is [3014, 3034, 3035, 3045, 2962, 3035, 3045, 2962, 1846, 3041, 1669, 3048, 3041, 3029, 3027, 3028, 3052, 3056, 2962, 15, 2962, 3025, 2962, 7, 2976]
The original string: This is out of vocabz~ if _ java.
3014 ----> T
3034 ----> h
3035 ----> i
3045 ----> s
2962 ---->  
3035 ----> i
3045 ----> s
2962 ---->  
1846 ----> out 
3041 ----> o
1669 ----> f 
3048 ----> v
3041 ----> o
3029 ----> c
3027 ----> a
3028 ----> b
3052 ----> z
3056 ----> ~
2962 ---->  
15 ----> if
2962 ---->  
3025 ----> _
2962 ---->  
7 ----> java
2976 ----> .


In [7]:
def split_into_inpt_trgt(method):
    """Turn one example into an (input, target) pair for next-token prediction.

    Args:
        method: Object exposing ``.numpy()`` that returns the raw text of one
            example (in this notebook, an eager string tensor).

    Returns:
        A tuple ``(input_ids, target_ids)``: the token ids without the last
        token, and the same ids shifted left by one (without the first token).
    """
    # Encode once and slice twice; tokenizing the same text twice doubled the
    # cost of the slowest step in the input pipeline.
    token_ids = tokenizer_method.encode(method.numpy())
    return token_ids[:-1], token_ids[1:]

In [8]:
def tf_split_into_inpt_trgt(method):
    """Wrap `split_into_inpt_trgt` in `tf.py_function` for use in `Dataset.map`.

    Returns the two int64 outputs (input ids, target ids) of the wrapped call.
    """
    output_types = [tf.int64, tf.int64]
    return tf.py_function(split_into_inpt_trgt, [method], output_types)

### Following the TensorFlow text generation tutorial (https://www.tensorflow.org/beta/tutorials/text/text_generation#top_of_page) to get the data into the format a language model expects

In [31]:
# Tokenize every example into an (input, target) pair, then cache and prefetch.
# No shuffling or batching happens here; that is done in a later cell.
trn_ds = trn_exmpls.map(tf_split_into_inpt_trgt)
trn_ds = trn_ds.cache()
trn_ds = trn_ds.prefetch(tf.data.experimental.AUTOTUNE)
trn_ds

<PrefetchDataset shapes: (<unknown>, <unknown>), types: (tf.int64, tf.int64)>

In [32]:
# Examine the dataset
for inpt_exmpl, outpt_exmpl in trn_ds.take(1):
    print ('Input data: ', repr(''.join(tokenizer_method.decode(inpt_exmpl.numpy()))))
    print ('Target data:', repr(''.join(tokenizer_method.decode(outpt_exmpl.numpy()))))

Input data:  'public static TYPE\\&undsc1 init ( java.lang.String name , java.util.Date date ) { TYPE\\&undsc1 VAR\\&undsc1 = new TYPE\\&undsc1 ( ) ; VAR\\&undsc1 . METHOD\\&undsc1 ( name ) ; java.util.Calendar VAR\\&undsc2 = null ; if ( date != null ) { VAR\\&undsc2 = java.util.Calendar.getInstance ( ) ; VAR\\&undsc2 . METHOD\\&undsc2 ( date ) ; } VAR\\&undsc1 . METHOD\\&undsc3 ( VAR\\&undsc2 ) ; return VAR\\&undsc1'
Target data: 'static TYPE\\&undsc1 init ( java.lang.String name , java.util.Date date ) { TYPE\\&undsc1 VAR\\&undsc1 = new TYPE\\&undsc1 ( ) ; VAR\\&undsc1 . METHOD\\&undsc1 ( name ) ; java.util.Calendar VAR\\&undsc2 = null ; if ( date != null ) { VAR\\&undsc2 = java.util.Calendar.getInstance ( ) ; VAR\\&undsc2 . METHOD\\&undsc2 ( date ) ; } VAR\\&undsc1 . METHOD\\&undsc3 ( VAR\\&undsc2 ) ; return VAR\\&undsc1 ; }'


In [33]:
# Examine the dataset
for i, (inpt_tkn, trgt_tkn) in enumerate(zip(inpt_exmpl[:5], outpt_exmpl[:5])):
    print("Step {:4d}".format(i))
    print("  input: {} ({:s})".format(inpt_tkn, tokenizer_method.decode([inpt_tkn.numpy()])))
    print("  expected output: {} ({:s})".format(trgt_tkn, tokenizer_method.decode([trgt_tkn.numpy()])))

Step    0
  input: 17 (public )
  expected output: 77 (static )
Step    1
  input: 77 (static )
  expected output: 18 (TYPE\&undsc1 )
Step    2
  input: 18 (TYPE\&undsc1 )
  expected output: 282 (init)
Step    3
  input: 282 (init)
  expected output: 1 ( ( )
Step    4
  input: 1 ( ( )
  expected output: 7 (java)


In [34]:
# Shuffle buffer size and number of sequences per batch
BUFFER_SIZE = 20000
BATCH_SIZE = 64

# Shuffle, then pad each batch's sequences to a common length ([-1] lets the
# padded length vary per batch); incomplete final batches are dropped.
# NOTE(review): this rebinds trn_ds, so re-running the cell applies
# shuffle/padded_batch to the already-batched dataset. Re-run the pipeline
# cell above first.
trn_ds = trn_ds.shuffle(BUFFER_SIZE).padded_batch(BATCH_SIZE, padded_shapes=([-1], [-1]), drop_remainder=True)
inpt_batch, trgt_batch = next(iter(trn_ds))
inpt_batch, trgt_batch

(<tf.Tensor: id=793582, shape=(64, 91), dtype=int64, numpy=
 array([[ 17,  26,  11, ...,   0,   0,   0],
        [ 17,  26,  11, ...,   0,   0,   0],
        [ 17,  26,  11, ...,   0,   0,   0],
        ...,
        [ 50,  26, 282, ...,   0,   0,   0],
        [ 17,  26,  11, ...,   0,   0,   0],
        [ 17,  66,  11, ...,   0,   0,   0]])>,
 <tf.Tensor: id=793583, shape=(64, 91), dtype=int64, numpy=
 array([[ 26,  11,  42, ...,   0,   0,   0],
        [ 26,  11,  42, ...,   0,   0,   0],
        [ 26,  11,   1, ...,   0,   0,   0],
        ...,
        [ 26, 282,   1, ...,   0,   0,   0],
        [ 26,  11,  42, ...,   0,   0,   0],
        [ 66,  11,  42, ...,   0,   0,   0]])>)

In [35]:
# Vocabulary size as reported by the subword tokenizer (tokens, not characters)
vocab_size = tokenizer_method.vocab_size

# The embedding dimension
embedding_dim = 256

# Number of RNN units
rnn_units = 1024

In [36]:
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
  """Assemble the Embedding -> stateful LSTM -> Dense language model.

  Args:
    vocab_size: Number of token ids; used for both the embedding table and
      the width of the output logits.
    embedding_dim: Size of each token embedding.
    rnn_units: Number of LSTM units.
    batch_size: Fixed batch size baked into the input shape (the LSTM is
      stateful, so the batch dimension cannot vary).

  Returns:
    An uncompiled `tf.keras.Sequential` model.
  """
  embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,
                                        batch_input_shape=[batch_size, None])
  recurrent = tf.keras.layers.LSTM(rnn_units,
                                   return_sequences=True,
                                   stateful=True,
                                   recurrent_initializer='glorot_uniform')
  to_logits = tf.keras.layers.Dense(vocab_size)
  return tf.keras.Sequential([embedding, recurrent, to_logits])

In [37]:
# Training model: batch size fixed to BATCH_SIZE because the LSTM is stateful
model = build_model(vocab_size=vocab_size,
                    embedding_dim=embedding_dim,
                    rnn_units=rnn_units,
                    batch_size=BATCH_SIZE)

In [38]:
# Sanity check: run one batch through the untrained model and look at the shape
input_example_batch, target_example_batch = next(iter(trn_ds))
example_batch_predictions = model(input_example_batch)
print(
    example_batch_predictions.shape,
    "# (batch_size, sequence_length, vocab_size)",
)

(64, 91, 3186) # (batch_size, sequence_length, vocab_size)


In [39]:
# Print the layer-by-layer summary of the training model
model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_1 (Embedding)      (64, None, 256)           815616    
_________________________________________________________________
lstm_1 (LSTM)                (64, None, 1024)          5246976   
_________________________________________________________________
dense_1 (Dense)              (64, None, 3186)          3265650   
Total params: 9,328,242
Trainable params: 9,328,242
Non-trainable params: 0
_________________________________________________________________


In [40]:
# Draw one token id per timestep from the logits of the first sequence in the
# batch, then drop the trailing num_samples axis.
first_seq_logits = example_batch_predictions[0]
sampled_indices = tf.squeeze(
    tf.random.categorical(first_seq_logits, num_samples=1), axis=-1
).numpy()

In [41]:
# Compare the first input sequence with what the untrained model samples
input_text = tokenizer_method.decode(input_example_batch[0].numpy())
sampled_text = tokenizer_method.decode(sampled_indices)

print("Input: \n", repr(input_text))
print()
print("Next Char Predictions: \n", repr(sampled_text))

Input: 
 'public void METHOD\\&undsc1 ( ) { if ( ( TYPE\\&undsc2 . METHOD\\&undsc2 ( ) ) == null ) return ; TYPE\\&undsc1 VAR\\&undsc1 = ( ( TYPE\\&undsc1 ) ( TYPE\\&undsc2 . METHOD\\&undsc2 ( ) . METHOD\\&undsc3 ( VAR\\&undsc2 ) ) ) ; if ( VAR\\&undsc1 == null ) return ; VAR\\&undsc1 . METHOD\\&undsc4 ( ) ; TYPE\\&undsc2 . METHOD\\&undsc2 ( ) . METHOD\\&undsc5 ( VAR\\&undsc2'

Next Char Predictions: 
 ' ) ) ) !=  ( ) ) & - boolean  ) >= ( ( ( fail ( ) ) /  ) , ( ) - >  ) ) ) && ( ( ! (  ( ) ) & (  ) ] ) ==  ) ) + ( ( ( - testng] ] ; } -- (  ) } ) ,  [ ] ) {  ( ) ) ; ) {  ) ) ; } } } ) ) < (  ) ) ) ) * (  ) ] ) {  ) ) > (  > > ( ) ;  ) < < (  ) ) ) ) , ( (  ? !  ++ ) ] ) ;  ) ++ ) ] = ( (  ) ++ ; } { } } } \x1c ) ) ] ; } } ) ) ; } } } } }  : } } } > ) (  ; } } } } ) ) ) ) >= (  ] ) != ( ( (  } ) ; } } ) ;  ( ) ) ; } } } ) -- ; } } " ; } } params ( ( ( !  ) )  ( ( ( ( "\\ ) ] = -  = -  ] ) < < (  } ; 1 == ( ( ( throw INT\\&undsc15 ( ) ) != ( ( ( � ) ) & ( ! (  ( ) ] =  > ( ) ) ; } } } ?

In [42]:
def loss(labels, logits):
    """Sparse categorical cross-entropy between labels and raw logits.

    NOTE(review): padded positions are not masked here, so padding appears to
    count towards the loss. Confirm whether that is intended.
    """
    return tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True
    )

# Evaluate the loss on the example batch as a sanity check
example_batch_loss = loss(target_example_batch, example_batch_predictions)
mean_example_loss = example_batch_loss.numpy().mean()
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("scalar_loss:      ", mean_example_loss)

Prediction shape:  (64, 91, 3186)  # (batch_size, sequence_length, vocab_size)
scalar_loss:       8.064807


In [43]:
# Attach the Adam optimizer and the custom `loss` function defined above
model.compile(optimizer='adam', loss=loss)

In [51]:
# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

In [58]:
# Train for EPOCHS passes over trn_ds, saving weights through the checkpoint callback.
# NOTE(review): `model` must be the compiled training model. The RuntimeError
# recorded below, together with the out-of-order execution counts, suggests this
# cell was last run after `model` had been rebound to the uncompiled batch-size-1
# inference model. Confirm by re-running the cells in order.
EPOCHS=1
history = model.fit(trn_ds, epochs=EPOCHS, callbacks=[checkpoint_callback])

RuntimeError: You must compile your model before training/testing. Use `model.compile(optimizer, loss)`.

In [53]:
# Rebuild the network with batch size 1 for generation and load the most
# recent training weights into it.
latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
if latest_ckpt is None:
    # latest_checkpoint returns None when nothing has been saved yet. Fail with
    # a clear message before `model` is replaced, not inside load_weights.
    raise FileNotFoundError(
        "No checkpoint found in {!r}; train the model before loading weights.".format(checkpoint_dir))

model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)

model.load_weights(latest_ckpt)

model.build(tf.TensorShape([1, None]))

In [54]:
# Print the layer-by-layer summary of the batch-size-1 inference model
model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_2 (Embedding)      (1, None, 256)            815616    
_________________________________________________________________
lstm_2 (LSTM)                (1, None, 1024)           5246976   
_________________________________________________________________
dense_2 (Dense)              (1, None, 3186)           3265650   
Total params: 9,328,242
Trainable params: 9,328,242
Non-trainable params: 0
_________________________________________________________________


In [55]:
def generate_text(model, start_string, num_generate=1000, temperature=1.0):
  """Sample a continuation of `start_string` from a trained stateful model.

  Args:
    model: Callable model built with batch size 1 that maps token ids of
      shape (1, seq_len) to logits of shape (1, seq_len, vocab_size) and
      exposes `reset_states()`.
    start_string: Prompt text; it must encode to at least one token.
    num_generate: Number of tokens (not characters) to sample.
    temperature: Divisor applied to the logits before sampling. Low values
      give more predictable text, high values more surprising text.

  Returns:
    `start_string` followed by the decoded sampled tokens.

  Raises:
    ValueError: If `start_string` encodes to no tokens or `temperature`
      is not positive.
  """
  if temperature <= 0:
    raise ValueError("temperature must be positive, got {!r}".format(temperature))

  # Converting our start string to token ids (vectorizing)
  start_ids = tokenizer_method.encode(start_string)
  if len(start_ids) == 0:
    # With no prompt tokens there is nothing to condition the first step on.
    raise ValueError("start_string must encode to at least one token")
  input_eval = tf.expand_dims(start_ids, 0)

  # Decoded text of every sampled token, in order
  text_generated = []

  # Here batch size == 1; clear any state left from a previous call
  model.reset_states()
  for _ in range(num_generate):
      predictions = model(input_eval)
      # remove the batch dimension
      predictions = tf.squeeze(predictions, 0)

      # sample from the categorical distribution over the vocabulary and keep
      # only the sample for the last timestep
      predictions = predictions / temperature
      predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy()

      # We pass the predicted token as the next input to the model;
      # the stateful LSTM carries the previous hidden state
      input_eval = tf.expand_dims([predicted_id], 0)

      text_generated.append(tokenizer_method.decode([predicted_id]))

  return (start_string + ''.join(text_generated))

In [57]:
# Sample a continuation starting from the prompt "public "
print(generate_text(model, start_string=u"public "))

public FLOAT\&undsc2 ) ) ) ) ) > ( y < INT\&undsc1 , 0 , start , length , INT\&undsc2 } ; } }METHOD\&undsc2 ( ) ) > ( number ) ) . equals ( event . length ) ; ( VAR\&undsc1 ) ++ ; return end ) ; }b . METHOD\&undsc4 ( ( ( VAR\&undsc2 ) - count ) | size ) ) ; } } }n ) ) ; }VAR\&undsc3 ; }this . result = METHOD\&undsc6 ( ) ) ) ; }else ; }this . METHOD\&undsc5 ( VAR\&undsc3 , ( x ] = "b ) ++ ; } } return false ; }VAR\&undsc10 ; }this . METHOD\&undsc6 ( y ) ; this . METHOD\&undsc5 ( this . METHOD\&undsc6 ( ) ) ; }this . position ) ++ ; this . METHOD\&undsc1 ( ) ; }VAR\&undsc1 . setEnabled ( this ) ; } catch ( java.lang.Exception VAR\&undsc2 : this . VAR\&undsc1 ) { this . METHOD\&undsc1 ( ) . METHOD\&undsc1 ( ) ; return ; } } } }this . VAR\&undsc1 ) - 1 ; } }this . VAR\&undsc2 += "\else { this . delete ( ) ; } finally { this . METHOD\&undsc2 ( ) ; this . METHOD\&undsc7 ( ) . METHOD\&undsc8 ( true ) ; } this . METHOD\&undsc6 ( ( ( VAR\&undsc2 . METHOD\&undsc1 ( ) ) + 1 ) ) ) ) ) ; } } }this 

# Encoder
Copied directly from https://github.com/openai/gpt-2, with a few extra comments added for clarity

In [None]:
@lru_cache()
def bytes_to_unicode():
    """
    Build the byte -> unicode-character table used by the reversible BPE.

    Bytes in the ranges "!".."~", "¡".."¬" and "®".."ÿ" map to the character
    with the same code point. Every other byte value (whitespace and control
    characters, which the BPE code cannot handle) is assigned the next unused
    code point starting at 256. All 256 byte values therefore get a distinct
    character, which avoids UNKs without spending a large share of the BPE
    vocabulary on raw unicode coverage.

    Returns a dict mapping int byte values to single-character strings.
    """
    kept_as_is = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    already_mapped = set(kept_as_is)

    byte_values = list(kept_as_is)
    code_points = list(kept_as_is)
    shift = 0
    for byte in range(2**8):
        if byte in already_mapped:
            continue
        # Remap this byte to a fresh code point above the single-byte range
        byte_values.append(byte)
        code_points.append(2**8 + shift)
        shift += 1

    return {byte: chr(point) for byte, point in zip(byte_values, code_points)}

In [None]:
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being variable-length
    strings); a plain string works too, with each character as a symbol.
    A word with fewer than two symbols has no pairs, so an empty set is
    returned (an empty word used to raise IndexError).
    """
    # Pair each symbol with its successor; zip stops at the shorter sequence,
    # which also covers empty and single-symbol words.
    return set(zip(word, word[1:]))

In [None]:
class Encoder:
    """Byte-level BPE tokenizer: text -> token ids and back.

    Args:
        encoder: Dict mapping BPE token strings to integer ids.
        bpe_merges: Sequence of symbol pairs in merge-priority order
            (earlier pairs are merged first).
        errors: Error handling mode passed to ``bytes.decode`` in `decode`.
    """

    def __init__(self, encoder, bpe_merges, errors='replace'):
        self.encoder = encoder
        self.decoder = {v:k for k,v in self.encoder.items()}
        self.errors = errors # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
        # Lower rank == higher merge priority
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        # Memoizes bpe() results per token
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        """Apply the learned merges to one token.

        Returns the merged symbols joined by single spaces, or `token`
        unchanged when it has no symbol pairs.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Pick the adjacent pair with the best (lowest) merge rank
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                # No remaining pair is a known merge
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the tail and stop.
                    # Only ValueError is caught so that unrelated errors
                    # (including KeyboardInterrupt) are no longer swallowed.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Convert text into a list of BPE token ids."""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            # Map the token's UTF-8 bytes to their printable stand-in characters
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Convert a sequence of token ids back into text."""
        text = ''.join([self.decoder[token] for token in tokens])
        # Undo the byte -> character mapping, then decode the bytes as UTF-8
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

In [None]:
def get_encoder(model_name, models_dir):
    """Load the vocabulary and merge list for a model and build an `Encoder`.

    Reads `<models_dir>/<model_name>/encoder.json` (token -> id mapping) and
    `<models_dir>/<model_name>/vocab.bpe` (one merge pair per line).

    Args:
        model_name: Name of the model sub-directory, e.g. "345M".
        models_dir: Directory that contains the model sub-directories.

    Returns:
        An `Encoder` built from the two files.
    """
    model_dir = os.path.join(models_dir, model_name)
    # Read both files as UTF-8 explicitly so the result does not depend on the
    # platform's default encoding.
    with open(os.path.join(model_dir, 'encoder.json'), 'r', encoding="utf-8") as f:
        encoder = json.load(f)
    with open(os.path.join(model_dir, 'vocab.bpe'), 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    # Skip the first line and the empty string after the final newline
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )

# Model
Adapted from https://github.com/openai/gpt-2 for TensorFlow 2.0 and made more readable, drawing on https://github.com/graykode/gpt-2-Pytorch and the tutorial at https://www.tensorflow.org/alpha/tutorials/text/transformer

In [None]:
# Configuration values for the 345M GPT-2 model
num_vocab = 50257
num_ctx = 1024
num_embd = 1024
num_heads, num_layers = 16, 24

# The MLP's dense layers are sized from the embedding width
num_state = num_embd

In [None]:
def gelu(x):
    """Gaussian Error Linear Unit, tanh approximation (https://arxiv.org/abs/1606.08415)."""
    cubic_term = x+0.044715 * tf.pow(x, 3)
    squashed = tf.tanh(np.sqrt(2 / np.pi) * cubic_term)
    return 0.5 * x * (1 + squashed)

In [None]:
def scaled_dot_product_attention(q, k, v, mask=None):
  """Calculate the attention weights and the attended values.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type (padding or look ahead)
  but it must be broadcastable for addition.

  Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable
          to (..., seq_len_q, seq_len_k), with 1 at positions to hide.
          Defaults to None (no masking).

  Returns:
    output: shape (..., seq_len_q, depth_v)
    attention_weights: shape (..., seq_len_q, seq_len_k)
  """

  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  # scale matmul_qk by sqrt(depth)
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # add the mask to the scaled tensor: masked positions get a large negative
  # logit, so softmax gives them ~0 weight.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights

In [None]:
class MultiHeadAttention(tf.keras.layers.Layer):
  """Multi-head attention built on `scaled_dot_product_attention`.

  Queries, keys and values are each projected with their own dense layer,
  split into `num_heads` heads, attended, merged again and passed through a
  final dense layer.
  """

  def __init__(self, num_state, num_heads):
    super().__init__()
    # The model width must divide evenly across the heads
    assert num_state % num_heads == 0

    self.num_heads = num_heads
    self.d_model = num_state
    self.depth = num_state // num_heads  # width of a single head

    self.wq = tf.keras.layers.Dense(num_state)
    self.wk = tf.keras.layers.Dense(num_state)
    self.wv = tf.keras.layers.Dense(num_state)

    self.dense = tf.keras.layers.Dense(num_state)

  def split_heads(self, x, batch_size):
    """Reshape (batch_size, seq_len, num_state) into
    (batch_size, num_heads, seq_len, depth).
    """
    per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(per_head, perm=[0, 2, 1, 3])

  def call(self, v, k, q, mask):
    """Attend `q` over `k`/`v`. Note the argument order is (v, k, q, mask).

    Returns:
      output: (batch_size, seq_len_q, d_model)
      present: `k` and `v` after head splitting, stacked along axis 1
    """
    batch_size = tf.shape(q)[0]

    # Project, then split into heads: (batch_size, num_heads, seq_len, depth)
    q = self.split_heads(self.wq(q), batch_size)
    k = self.split_heads(self.wk(k), batch_size)
    v = self.split_heads(self.wv(v), batch_size)

    present = tf.stack([k, v], axis=1)

    # context: (batch_size, num_heads, seq_len_q, depth); the attention
    # weights are not needed by callers of this layer
    context, _ = scaled_dot_product_attention(q, k, v, mask)

    # Move heads next to depth and merge them back into d_model
    context = tf.transpose(context, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
    merged = tf.reshape(context, (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

    output = self.dense(merged)  # (batch_size, seq_len_q, d_model)

    return output, present
  
# Testing data is flowing correctly: self-attention over a random input
temp_mha = MultiHeadAttention(num_state, num_heads)
y = tf.random.uniform((1, 60, num_state))  # (batch_size, seq_len, num_state)
# NOTE: the layer returns (output, present); `attn` therefore holds the stacked
# keys/values, not the attention weights.
out, attn = temp_mha(y, k=y, q=y, mask=None)
out.shape, attn.shape

In [None]:
class MLP(tf.keras.layers.Layer):
  """Position-wise feed-forward block: expand with GELU, then project back.

  Args:
    num_state: Width of the hidden (expansion) layer; callers in this
      notebook pass 4 * num_state.
    num_embd: Width of the output projection. It must match the input width
      so the result can be added back as a residual.
  """
  def __init__(self, num_state, num_embd):
    super(MLP, self).__init__()
    # The widths were previously swapped (hidden=num_embd, output=num_state),
    # so the block returned 4x-wide activations that could not be added to
    # its input in DecoderBlock.
    self.dense_1 = tf.keras.layers.Dense(num_state, activation=gelu)
    self.dense_2 = tf.keras.layers.Dense(num_embd)

  def call(self, x):
    """Map (..., num_embd) activations to (..., num_embd) through the hidden layer."""
    x = self.dense_1(x)  # (..., num_state)
    x = self.dense_2(x)  # (..., num_embd)

    return x

# Testing data is flowing correctly: push random activations through an MLP
sample_mlp_layer = MLP(4 * num_state, num_embd)
sample_mlp_layer_out = sample_mlp_layer(tf.random.uniform((64, 40, num_state)))
sample_mlp_layer_out.shape  # expected: (batch_size, input_seq_len, num_embd)

In [None]:
class DecoderBlock(tf.keras.layers.Layer):
  """One pre-layer-norm transformer decoder block: self-attention, then MLP.

  Args:
    num_state: Model width used by the attention layer; the MLP hidden layer
      is 4 * num_state wide.
    num_heads: Number of attention heads.
    num_embd: Output width of the MLP.
  """
  def __init__(self, num_state, num_heads, num_embd):
    super(DecoderBlock, self).__init__()

    # Block layout is as follows
    self.layer_norm_1 = tf.keras.layers.experimental.LayerNormalization(epsilon=1e-5)
    self.attn_layer = MultiHeadAttention(num_state, num_heads)
    self.layer_norm_2 = tf.keras.layers.experimental.LayerNormalization(epsilon=1e-5)
    self.mlp = MLP(4 * num_state, num_embd)

  def call(self, x, past=None):
    """Run the block on `x`.

    Args:
      x: Input activations.
      past: Currently unused.

    Returns:
      (x, present): the block output and the attention layer's `present`.
    """
    # Each residual is added to the un-normalized stream. Overwriting `x` with
    # its layer-normed value first would drop the skip path around each norm.
    # NOTE(review): attention runs with mask=None, so there is no look-ahead
    # masking. Confirm whether a causal mask is intended here.
    normed = self.layer_norm_1(x) # Layer Norm
    attn, present = self.attn_layer(normed, normed, normed, mask=None) # Attend
    x = x + attn # Residual Connection
    normed = self.layer_norm_2(x) # Layer Norm
    m = self.mlp(normed) # Feedforward
    x = x + m # Residual Connection

    return x, present
  
# Testing data is flowing correctly: run random activations through one block
sample_decoder_block = DecoderBlock(num_state, num_heads, num_embd)
# The second return value (`present`) is not needed for the shape check
sample_decoder_block_output, _ = sample_decoder_block(tf.random.uniform((64, 50, num_state)))
sample_decoder_block_output.shape  # expected: (batch_size, target_seq_len, num_state)

In [None]:
class GPT(tf.keras.Model):
  """Decoder-only transformer: token + position embeddings followed by a
  stack of `DecoderBlock`s and a final layer norm (work in progress).

  Args:
    num_vocab: Number of token ids (rows of the token embedding table).
    num_ctx: Maximum number of positions (rows of the position table).
    num_embd: Embedding width.
    num_heads: Attention heads per block.
    num_layers: Number of decoder blocks.
    num_state: Model width handed to each `DecoderBlock`.

  All arguments default to the notebook's module-level hyperparameters.
  """
  def __init__(self, num_vocab=num_vocab, num_ctx=num_ctx, num_embd=num_embd,
               num_heads=num_heads, num_layers=num_layers, num_state=num_state):
    # Keras must initialize the model before sub-layers are assigned
    super(GPT, self).__init__()
    self.num_layers = num_layers

    self.wrd_tkn_embd = tf.keras.layers.Embedding(num_vocab, num_embd)
    # Positions index into a table of num_ctx rows, not the vocabulary
    self.wrd_pos_enc = tf.keras.layers.Embedding(num_ctx, num_embd)

    self.dec_blks = [DecoderBlock(num_state, num_heads, num_embd) for _ in range(num_layers)]
    self.layer_norm = tf.keras.layers.experimental.LayerNormalization(epsilon=1e-5)
    self.flatn = tf.keras.layers.Flatten()

  def call(self, x, pos_x, pasts=None): # Past = enc_output since this is just a decoder, not encoder -> decoder architecture
    """Run the model.

    Args:
      x: Token ids.
      pos_x: Position ids, same shape as `x`.
      pasts: Optional tensor of cached per-layer state, unstacked along
        axis 1 into one entry per block. `DecoderBlock` currently ignores it.

    Returns:
      The flattened, layer-normed output of the final block.
    """
    if pasts is None:
      pasts = [None] * self.num_layers
    else:
      pasts = tf.unstack(pasts, axis=1)

    x = self.wrd_tkn_embd(x)
    x = x + self.wrd_pos_enc(pos_x)

    # NOTE(review): each block also returns its `present`, which is discarded
    # here. Return it once past-state caching is implemented.
    for block, past in zip(self.dec_blks, pasts):
      x, _present = block(x, past)

    x = self.layer_norm(x)
    # NOTE(review): Flatten merges the sequence and embedding axes. Confirm
    # this is wanted rather than a projection to vocabulary logits.
    x = self.flatn(x)

    return x

# Reloading Checkpoint

In [None]:
# Directory managed by the CheckpointManager below
checkpoint_path = "./checkpoints/train"

# NOTE(review): `transformer` and `optimizer` are not defined anywhere in the
# visible part of this notebook, so this cell raises NameError on a fresh
# kernel. Confirm which model/optimizer objects are meant to be tracked.
ckpt = tf.train.Checkpoint(transformer=transformer,
                           optimizer=optimizer)

# Keep only the five most recent checkpoints
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')