## Translating NanoGPT (GPT2) to TensorFlow

#### Based on https://github.com/karpathy/nanoGPT/blob/master/model.py

In [17]:
import tensorflow as tf
import numpy as np
from  dataclasses import dataclass
from tensorflow.experimental import numpy as tnp
import tensorflow_probability as tfp

## The Model

In [2]:
class MyLayerNorm(tf.keras.layers.Layer):
    
    def __init__(self, bias=True, eps=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.bias = bias
        
    def build(self, input_shape):  
        self.weight = self.add_weight(name='weight',
                                      shape=input_shape[-1:], # [-1:] gives last elem but keeps dims
                                      initializer=tf.keras.initializers.Ones(),
                                      trainable=True)

        self.bias = self.add_weight(name='bias',
                                      shape=input_shape[-1:], # [-1:] gives last elem but keeps dims
                                      initializer=tf.keras.initializers.Zeros(),
                                      trainable=True) if self.bias else None

        super(MyLayerNorm, self).build(input_shape)
    
    def call(self, x):
        # Can also use tf.nn.moments(inputs, axes=-1, keepdims=True), 
        # but then additionally one needs to take the sqrt to get \sigma
        mean = tf.keras.backend.mean(x, axis=-1, keepdims=True)
        std = tf.keras.backend.std(x, axis=-1, keepdims=True)
        
        return self.weight * (x - mean) / (std + self.eps) + self.bias

In [3]:
@dataclass
class GPTConfig:
    block_size: int = 8 # 1024 for GPT2
    vocab_size: int = 20 # 50304 for GPT2
    n_layer: int = 2 # 12
    n_head: int = 2 # 12
    n_embd: int = 10 # 768
    dropout: float = 0.0
    bias: bool = True
    seed: int = 1337

In [4]:
class CausalSelfAttention(tf.keras.layers.Layer):
    
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "Embedding dimension must divide number of heads"
        # key, query, value computed at once and splitted later
        self.initializer_proj = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02 / tf.math.sqrt(2. * config.n_layer), seed=None)
        self.c_attn = tf.keras.layers.Dense(#config.n_embd,
                                            3 * config.n_embd,
                                            activation=None,
                                            use_bias=config.bias)
        # output projection
        self.c_proj = tf.keras.layers.Dense(#config.n_embd,
                                            config.n_embd,
                                            activation=None,
                                            kernel_initializer=self.initializer_proj,
                                            use_bias=config.bias)
        self.dropout = config.dropout
        self.attn_dropout = tf.keras.layers.Dropout(self.dropout)
        self.resid_dropout = tf.keras.layers.Dropout(self.dropout)

        self.mask = tf.experimental.numpy.tril(
            tf.ones([config.block_size, config.block_size]))[tf.newaxis, tf.newaxis, :, :]

    def forward(self, x):
        
        B, T, C = x.size() # batch, sequence and channel, which is the embedding dim

        q, k, v = self.c_attn(x).split(self.n_embd, axis=2)
        k = tf.transpose(tf.reshape(k, [B, T, self.n_head, C // self.n_head]),
                         perm=[0, 2, 1, 3])
        q = tf.transpose(tf.reshape(q, [B, T, self.n_head, C // self.n_head]),
                         perm=[0, 2, 1, 3])
        v = tf.transpose(tf.reshape(v, [B, T, self.n_head, C // self.n_head]),
                         perm=[0, 2, 1, 3])

        att = (q @ tf.transpose(k, perm=[0, 1, 3, 2])) * (1.0 / tf.math.sqrt(k.shape[-1]))

        mask = tf.experimental.numpy.tril(tf.ones([T, T]))[tf.newaxis, tf.newaxis, :, :]
        att = tf.where(mask != 0, att, tf.constant(-np.inf))
        att = tf.nn.softmax(att, axis = 3)
        att = self.attn_dropout(att)
        y = att @ v

        y = tf.reshape(tf.transpose(y, perm=[0, 2, 1, 3]), [B, T, C])

        return self.resid_dropout(self.c_proj(y))

In [5]:
class MLP(tf.keras.layers.Layer):

    def __init__(self, config):
        super().__init__()
        self.initializer_proj = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02 / tf.math.sqrt(2. * config.n_layer), seed=None)
        # Streching and shrinking in channel/embedding dimension,
        # like for large resnets
        self.c_fc = tf.keras.layers.Dense(4 * config.n_embd, activation=None, use_bias=config.bias)
        self.c_proj = tf.keras.layers.Dense(config.n_embd, activation=None, kernel_initializer=self.initializer_proj, use_bias=config.bias)
        self.gelu = tf.keras.activations.gelu
        self.dropout = tf.keras.layers.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return self.dropout(x)
        

In [6]:
class Block(tf.keras.layers.Layer):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = MyLayerNorm(bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = MyLayerNorm(bias=config.bias)
        self.mlp = MLP(config)
    
    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))

In [83]:
class GPT(tf.keras.models.Model):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.initializer_dense = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02, seed=self.config.seed)
        self.initializer_embed = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02, seed=self.config.seed)
        self.initializer_bias = tf.keras.initializers.Zeros()
        
        self.wte = tf.keras.layers.Embedding(self.config.vocab_size, self.config.n_embd, 
                                             embeddings_initializer=self.initializer_embed, name='wte')
        self.wpe = tf.keras.layers.Embedding(self.config.vocab_size, self.config.n_embd, 
                                             embeddings_initializer=self.initializer_embed, name='wpe')
        self.drop = tf.keras.layers.Dropout(self.config.dropout, name='drop')
        self.h = [Block(self.config) for _ in range(self.config.n_layer)]
        self.ln_f = MyLayerNorm(bias=self.config.bias, name='ln_f')

        self.lm_head = tf.keras.layers.Dense(self.config.vocab_size, use_bias=self.config.bias)

    def build(self, input_shape):
        self.wte.build(input_shape=[self.config.vocab_size])
        self.lm_head.build(input_shape=[self.config.n_embd])
        self.wte.trainable_weights[0].assign(tf.transpose(self.lm_head.trainable_weights[0]))
        
    def call(self, idx, targets=None):
        b, t = idx.shape
        assert t <= self.config.block_size, f'sequence too long for the defined context of {self.config.block_size}'
        pos = tf.range(0, t, dtype=tf.int64)

        tok_emb = self.wte(idx)
        pos_emb = self.wpe(pos)
        x = self.drop(tok_emb + pos_emb)
        for block in self.h:
            x = block(x)
        x = self.ln_f(x)

        if targets is not None:
            ce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
            logits = self.lm_head(x)
            
            #print(tf.reshape(targets, [-1]).shape)
            #print(tf.reshape(logits, [-1, logits.shape[-1]]).shape)
            
            loss = ce(tf.reshape(targets, [-1]),
                      tf.reshape(logits, [-1, logits.shape[-1]]))
        else:
            logits = self.lm_head(x[:, -1, :])[:, tf.newaxis, :]
            loss = None

        return logits, loss

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        N = self.count_params()
        L, H, Q, T = self.config.n_layer, self.config.n_head, self.config.n_embd / self.config.n_head, self.config.block_size
        flops_per_token = 6 * N + 12 * L * H * Q * T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        flops_achieved = flops_per_iter * (1.0/dt)
        flops_promised = 312e12 # A100 at bfloat16
        mfu = flops_achieved / flops_promised
        return mfu
    
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for i in range(max_new_tokens):
            idx_cond = idx if idx.shape[1] <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = tf.math.top_k(logits, min(top_k, logits.shape[-1]))
                # Returned top_k values are sorted, so [-1] is the smallest of top_k
                # and we cut all below that value
                logits[logits < v[:, -1]] = -float('Inf')
            probs = tf.keras.activations.softmax(logits, axis=-1)
            idx_dist = tfp.distributions.Multinomial(total_count=1, probs=probs)
            idx_next = idx_dist.sample(1)
            idx_next = tf.reshape(tf.math.argmax(idx_next, axis=-1), shape=(-1, 1))
            print(idx_next)
            idx = tf.concat([idx, idx_next], axis=1)
        return idx

In [24]:
cfg = GPTConfig()

In [25]:
cfg

GPTConfig(block_size=8, vocab_size=20, n_layer=2, n_head=2, n_embd=10, dropout=0.0, bias=True, seed=1337)

In [None]:
txt = tf.constant(np.random.randint(0, 9, size=[2, 8]), dtype=tf.int64)

In [12]:
txt

<tf.Tensor: shape=(2, 8), dtype=int64, numpy=
array([[2, 3, 4, 6, 5, 8, 7, 6],
       [0, 0, 2, 7, 8, 1, 2, 6]])>

In [84]:
gpt = GPT(cfg)

In [86]:
gpt(txt, txt)

(<tf.Tensor: shape=(2, 8, 20), dtype=float32, numpy=
 array([[[-1.1011575 ,  0.01307986,  2.3074875 , -0.5072701 ,
          -0.23140283, -0.7444831 , -1.1220703 ,  0.88971376,
           0.4679233 ,  1.2061884 ,  0.81449175, -0.58697015,
           0.4837832 ,  0.5646569 ,  0.61891294, -0.6671307 ,
          -2.0079443 , -1.8840489 ,  1.2537247 , -1.7491484 ],
         [ 0.14379741, -1.1388445 , -0.33680928,  2.6758158 ,
          -0.10954601,  1.1145775 , -0.88609207, -0.6991405 ,
          -1.0007828 , -0.6171318 , -1.852392  , -0.83712107,
           0.9389086 , -1.7580435 ,  0.7752962 ,  0.8872563 ,
           0.6011041 ,  0.49998727, -1.1544275 ,  0.46716666],
         [ 0.3890845 , -1.0326089 , -0.2573085 , -0.17484933,
           2.423233  , -1.1124059 , -0.889344  , -0.42216268,
           0.28954422,  0.63077796,  0.92211014,  0.7719299 ,
           0.37363315, -0.7288001 , -0.67543805,  0.36637235,
           0.66849226, -0.3997035 , -0.22817367,  0.75288165],
         [ 0.2

In [74]:
gpt.generate(txt, 3)

tf.Tensor(
[[6]
 [6]], shape=(2, 1), dtype=int64)
tf.Tensor(
[[ 6]
 [10]], shape=(2, 1), dtype=int64)
tf.Tensor(
[[6]
 [0]], shape=(2, 1), dtype=int64)


<tf.Tensor: shape=(2, 11), dtype=int64, numpy=
array([[ 2,  3,  4,  6,  5,  8,  7,  6,  6,  6,  6],
       [ 0,  0,  2,  7,  8,  1,  2,  6,  6, 10,  0]])>

In [78]:
gpt.summary()

Model: "gpt_13"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 wte (Embedding)             multiple                  200       
                                                                 
 wpe (Embedding)             multiple                  200       
                                                                 
 drop (Dropout)              multiple                  0         
                                                                 
 block_26 (Block)            multiple                  0         
                                                                 
 block_27 (Block)            multiple                  0         
                                                                 
 ln_f (MyLayerNorm)          multiple                  20        
                                                                 
 dense_125 (Dense)           multiple                  220  

In [79]:
gpt.count_params()

640

In [87]:
gpt.estimate_mfu(2,1)

2.9538461538461536e-10

## Training

In [14]:
shakespear_url = "https://homl.info/shakespeare"
filepath = tf.keras.utils.get_file('shakespear.txt', shakespear_url)

In [15]:
with open(filepath, 'r', encoding='utf-8') as f:
    shakespear_txt = f.read()

In [16]:
print(shakespear_txt[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [17]:
text_vec_layer = tf.keras.layers.TextVectorization(split='character',
                                                  standardize='lower')

In [18]:
text_vec_layer.adapt([shakespear_txt])

In [19]:
text_vec_layer.get_vocabulary()[:10]

['', '[UNK]', ' ', 'e', 't', 'o', 'a', 'i', 'h', 's']

In [20]:
encoded = text_vec_layer([shakespear_txt])[0]

In [21]:
encoded

<tf.Tensor: shape=(1115394,), dtype=int64, numpy=array([21,  7, 10, ..., 22, 28, 12])>

In [22]:
# Removing code 0 and 1 reserved for padding and unknown characters 
# (codes start at 2 before that removal so now 0 and 1 will be some chars)
encoded -= 2

n_tokens = text_vec_layer.vocabulary_size() - 2
dataset_size = len(encoded)

In [23]:
n_tokens

39

In [24]:
ds = tf.data.Dataset.from_tensor_slices(encoded)

In [25]:
def to_dataset(sequence, length, shuffle=False, seed=None, batch_size=32):
    ds = tf.data.Dataset.from_tensor_slices(sequence)
    ds = ds.window(length + 1, shift=1, drop_remainder=True)
    ds = ds.flat_map(lambda window_ds: window_ds.batch(length+1))
    if shuffle:
        ds = ds.shuffle(buffer_size=100_000, seed=seed)
    ds = ds.batch(batch_size)
    return ds.map(lambda window: (window[:, :-1], window[:, 1:])).prefetch(1)

In [26]:
length = 8
tf.random.set_seed(42)

In [27]:
train_set = to_dataset(encoded[:1_060_000], length=length, shuffle=True, seed=1337)
valid_set = to_dataset(encoded[1_060_000:], length=length)

In [28]:
for sample in train_set.rebatch(1).take(1):
    print(sample[0])
    print(sample[1])

tf.Tensor([[ 3 14  1 26 10 10 25  5]], shape=(1, 8), dtype=int64)
tf.Tensor([[14  1 26 10 10 25  5  8]], shape=(1, 8), dtype=int64)


In [29]:
n_layer = 2
n_head = 2 
n_embd = 32
block_size = length
bias = True
vocab_size = n_tokens
dropout = tf.constant(0.0, dtype=tf.float32)

iter_num = 0
best_val_loss = 1e9
max_iters = 100

init_from = 'scratch'

weight_decay = tf.constant(1e-1, dtype=tf.float32)

warmup_iters = 10
learning_rate = tf.constant(6e-4, dtype=tf.float32)
lr_decay_iters = 100 # == max_iters
min_lr = tf.constant(6e-5, dtype=tf.float32)

eval_iters = 20
eval_interval = 10

grad_clip = tf.constant(1.0, dtype=tf.float32)

In [30]:
model_args = dict(n_layer=n_layer, 
                  n_head=n_head, 
                  n_embd=n_embd,
                  block_size=block_size,
                  bias=bias,
                  vocab_size=vocab_size,
                  dropout=dropout)

In [32]:
model_args

{'n_layer': 2,
 'n_head': 2,
 'n_embd': 32,
 'block_size': 8,
 'bias': True,
 'vocab_size': 39,
 'dropout': <tf.Tensor: shape=(), dtype=float32, numpy=0.0>}

In [33]:
if init_from == 'scratch':
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)

In [34]:
for X, _ in train_set.take(1):
    out = model(X)

In [35]:
out

(<tf.Tensor: shape=(32, 1, 39), dtype=float32, numpy=
 array([[[ 5.2112837 ,  2.0908573 , -1.7094367 , ..., -1.2486523 ,
           0.69483757, -0.00625577]],
 
        [[-0.26086673, -0.7746036 , -0.5939345 , ...,  0.542379  ,
           0.046754  , -0.31683245]],
 
        [[ 5.2112837 ,  2.0908573 , -1.7094367 , ..., -1.2486523 ,
           0.69483757, -0.00625577]],
 
        ...,
 
        [[-0.26086673, -0.7746036 , -0.5939345 , ...,  0.542379  ,
           0.046754  , -0.31683245]],
 
        [[-0.26086673, -0.7746036 , -0.5939345 , ...,  0.542379  ,
           0.046754  , -0.31683245]],
 
        [[ 5.2112837 ,  2.0908573 , -1.7094367 , ..., -1.2486523 ,
           0.69483757, -0.00625577]]], dtype=float32)>,
 None)

In [36]:
@dataclass
class MyLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    learning_rate: tf.float32 = learning_rate
    warmup_iters: int = warmup_iters
    min_lr: tf.float32 = min_lr 
    lr_decay_iters: int= lr_decay_iters
    
    def __call__(self, step):
      if step < self.warmup_iters:
          return self.learning_rate * float(step) / self.warmup_iters
      if step > self.lr_decay_iters:
          return self.min_lr
      decay_ratio = (float(step) - self.warmup_iters) / (self.lr_decay_iters - self.warmup_iters)
      assert 0 <= decay_ratio <= 1
      coeff = 0.5 * (1.0 + tf.math.cos(tnp.pi * decay_ratio))
      return self.min_lr + coeff * tf.cast((self.learning_rate - self.min_lr), dtype=tf.float32)

In [37]:
optimizer = tf.keras.optimizers.AdamW(learning_rate=MyLRSchedule(learning_rate, warmup_iters, min_lr, lr_decay_iters))

In [38]:
def estimate_loss():
    out = {'train': None, 'val': None}
    
    losses = tf.zeros(eval_iters)
    k = tf.Variable(0)
    for X, Y in train_set.take(eval_iters):
        logits, loss = model(X, Y, training=False)
        tf.tensor_scatter_nd_update(losses, [[k]], [loss])
        k.assign_add(1)
    out['train'] = tf.reduce_mean(loss)

    losses = tf.zeros(eval_iters)
    k = tf.Variable(0)
    for X, Y in valid_set.take(eval_iters):
        logits, loss = model(X, Y, training=False)
        tf.tensor_scatter_nd_update(losses, [[k]], [loss])
        k.assign_add(1)
    out['val'] = tf.reduce_mean(loss)

    return out

In [39]:
estimate_loss()

{'train': <tf.Tensor: shape=(), dtype=float32, numpy=5.441515>,
 'val': <tf.Tensor: shape=(), dtype=float32, numpy=5.13965>}

In [40]:
eval_only = False
eval_interval=1
iter_num = 0

In [42]:
while True:
    if iter_num % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    if iter_num == 0 and eval_only:
        break

    for X, Y in train_set.take(1):
        with tf.GradientTape() as tape:
            logits, main_loss = model(X, Y, training=True)
            loss = tf.add_n([main_loss] + model.losses)
        gradients = tape.gradient(loss, model.trainable_variables)
        gradients = [tf.clip_by_value(g, 
                                      clip_value_min=-grad_clip, 
                                      clip_value_max=grad_clip) for g in gradients]
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))    
    
    iter_num += 1
    if iter_num > 20:
        break


step 5: train loss 5.1932, val loss 5.1211
step 6: train loss 5.2641, val loss 5.1120
step 7: train loss 5.3812, val loss 5.1011
step 8: train loss 5.3169, val loss 5.0886
step 9: train loss 5.3550, val loss 5.0741
step 10: train loss 5.2777, val loss 5.0579
step 11: train loss 5.2650, val loss 5.0400
step 12: train loss 5.2888, val loss 5.0226
step 13: train loss 5.2873, val loss 5.0054
step 14: train loss 5.1705, val loss 4.9881
step 15: train loss 5.1544, val loss 4.9709
step 16: train loss 5.1350, val loss 4.9537
step 17: train loss 5.1135, val loss 4.9369
step 18: train loss 5.0683, val loss 4.9201
step 19: train loss 5.0741, val loss 4.9034
step 20: train loss 5.0603, val loss 4.8868
