## Translating NanoGPT (GPT2) to TensorFlow

#### Based on https://github.com/karpathy/nanoGPT/blob/master/model.py

In [None]:
import tensorflow as tf
import numpy as np
from  dataclasses import dataclass
from tensorflow.experimental import numpy as tnp
import tensorflow_probability as tfp
import os

## Distribute Strategy

In [None]:
strategy = tf.distribute.MirroredStrategy()

## The Model

In [3]:
class MyLayerNorm(tf.keras.layers.Layer):
    def __init__(self, bias=True, eps=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.bias = bias
       
    def build(self, input_shape):  
        self.weight = self.add_weight(name='weight',
                                      shape=input_shape[-1:], # [-1:] gives last elem but keeps dims
                                      initializer=tf.keras.initializers.Ones(),
                                      trainable=True)

        self.bias = self.add_weight(name='bias',
                                      shape=input_shape[-1:], # [-1:] gives last elem but keeps dims
                                      initializer=tf.keras.initializers.Zeros(),
                                      trainable=True) if self.bias else None

        super(MyLayerNorm, self).build(input_shape)
    @tf.function
    def call(self, x):
        # Can also use tf.nn.moments(inputs, axes=-1, keepdims=True), 
        # but then additionally one needs to take the sqrt to get \sigma
        mean = tf.keras.backend.mean(x, axis=-1, keepdims=True)
        std = tf.keras.backend.std(x, axis=-1, keepdims=True)
        
        return self.weight * (x - mean) / (std + self.eps) + self.bias

In [4]:
class GPTConfig():
    def __init__(self, 
                 block_size:int=8, 
                 vocab_size:int=39,
                 n_layer:int=2,
                 n_head:int=2,
                 n_embd:int=10,
                 dropout:float=0.0,
                 bias:bool=False,
                 seed:int=1337):

        self.block_size=block_size
        self.vocab_size=vocab_size
        self.n_layer=n_layer
        self.n_head=n_head
        self.n_embd=n_embd
        self.dropout=dropout
        self.bias=bias
        self.seed=seed

In [5]:
class CausalSelfAttention(tf.keras.layers.Layer):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "Embedding dimension must divide number of heads"
        # key, query, value computed at once and splitted later
        self.initializer_proj = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02 / tf.math.sqrt(2. * config.n_layer), seed=None)
        self.c_attn = tf.keras.layers.Dense(#config.n_embd,
                                            3 * config.n_embd,
                                            activation=None,
                                            use_bias=config.bias)
        # output projection
        self.c_proj = tf.keras.layers.Dense(#config.n_embd,
                                            config.n_embd,
                                            activation=None,
                                            kernel_initializer=self.initializer_proj,
                                            use_bias=config.bias)
        self.dropout = config.dropout
        self.attn_dropout = tf.keras.layers.Dropout(self.dropout)
        self.resid_dropout = tf.keras.layers.Dropout(self.dropout)

        self.mask = tf.experimental.numpy.tril(
            tf.ones([config.block_size, config.block_size]))[tf.newaxis, tf.newaxis, :, :]
    
    @tf.function
    def forward(self, x):
        
        B, T, C = x.size() # batch, sequence and channel, which is the embedding dim

        q, k, v = self.c_attn(x).split(self.n_embd, axis=2)
        k = tf.transpose(tf.reshape(k, [B, T, self.n_head, C // self.n_head]),
                         perm=[0, 2, 1, 3])
        q = tf.transpose(tf.reshape(q, [B, T, self.n_head, C // self.n_head]),
                         perm=[0, 2, 1, 3])
        v = tf.transpose(tf.reshape(v, [B, T, self.n_head, C // self.n_head]),
                         perm=[0, 2, 1, 3])

        att = (q @ tf.transpose(k, perm=[0, 1, 3, 2])) * (1.0 / tf.math.sqrt(k.shape[-1]))

        mask = tf.experimental.numpy.tril(tf.ones([T, T]))[tf.newaxis, tf.newaxis, :, :]
        att = tf.where(mask != 0, att, tf.constant(-np.inf))
        att = tf.nn.softmax(att, axis = 3)
        att = self.attn_dropout(att)
        y = att @ v

        y = tf.reshape(tf.transpose(y, perm=[0, 2, 1, 3]), [B, T, C])

        return self.resid_dropout(self.c_proj(y))

In [6]:
class MLP(tf.keras.layers.Layer):
    def __init__(self, config):
        super().__init__()
        self.initializer_proj = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02 / tf.math.sqrt(2. * config.n_layer), seed=None)
        # Streching and shrinking in channel/embedding dimension,
        # like for large resnets
        self.c_fc = tf.keras.layers.Dense(4 * config.n_embd, activation=None, use_bias=config.bias)
        self.c_proj = tf.keras.layers.Dense(config.n_embd, activation=None, kernel_initializer=self.initializer_proj, use_bias=config.bias)
        self.gelu = tf.keras.activations.gelu
        self.dropout = tf.keras.layers.Dropout(config.dropout)
    
    @tf.function
    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return self.dropout(x)
        

In [7]:
class Block(tf.keras.layers.Layer):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = MyLayerNorm(bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = MyLayerNorm(bias=config.bias)
        self.mlp = MLP(config)
    
    @tf.function
    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))

In [8]:
class GPT(tf.keras.models.Model):
    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.initializer_dense = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02, seed=self.config.seed)
        self.initializer_embed = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02, seed=self.config.seed)
        self.initializer_bias = tf.keras.initializers.Zeros()
        
        self.wte = tf.keras.layers.Embedding(self.config.vocab_size, 
                                             self.config.n_embd, 
                                             embeddings_initializer=self.initializer_embed, 
                                             name='wte')
        
        self.wpe = tf.keras.layers.Embedding(self.config.block_size, 
                                             self.config.n_embd, 
                                             embeddings_initializer=self.initializer_embed, 
                                             name='wpe')
        
        self.drop = tf.keras.layers.Dropout(self.config.dropout, name='drop')
        
        self.h = [Block(self.config) for _ in range(self.config.n_layer)]
        
        self.ln_f = MyLayerNorm(bias=self.config.bias, name='ln_f')

        self.lm_head = tf.keras.layers.Dense(self.config.vocab_size, use_bias=self.config.bias)

    def build(self, input_shape):
        self.wte.build(input_shape=[self.config.vocab_size])
        self.lm_head.build(input_shape=[self.config.n_embd])
        self.wte.trainable_weights[0].assign(tf.transpose(self.lm_head.trainable_weights[0]))
    
    @tf.function    
    def call(self, idx):
        b, t = idx.shape
        #assert t <= self.config.block_size, f'sequence too long for the defined context of {self.config.block_size}'
        pos = tf.range(0, t, dtype=tf.int64)

        tok_emb = self.wte(idx)
        pos_emb = self.wpe(pos)
        x = self.drop(tok_emb + pos_emb)
        for block in self.h:
            x = block(x)
        x = self.ln_f(x)
        return self.lm_head(x)

    def crop_block_size(self, block_size):
        assert block_size < self.config.block_size
        self.config.block_size = block_size
        self.wpe.weights[0] = tf.Variable(self.wpe.weights[0][:block_size], trainable=True)
        for block in self.h:
            #if hasattr(block.attn, 'bias'):
            if len(block.attn.weights) == 2:
                block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
                  
    def get_num_params(self, non_embedding=True):
        n_params = self.count_params()
        if non_embedding:
            n_params -= self.wpe.count_params()
        return n_params
        
    def estimate_mfu(self, fwdbwd_per_iter, dt):
        N = self.get_num_params()
        L, H, Q, T = self.config.n_layer, self.config.n_head, self.config.n_embd / self.config.n_head, self.config.block_size
        flops_per_token = 6 * N + 12 * L * H * Q * T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        flops_achieved = flops_per_iter * (1.0/dt)
        flops_promised = 312e12 # A100 at bfloat16
        mfu = flops_achieved / flops_promised
        return mfu
        
    @tf.function
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for i in range(max_new_tokens):
            idx_cond = idx if idx.shape[1] <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = tf.math.top_k(logits, min(top_k, logits.shape[-1]))
                # Returned top_k values are sorted, so [-1] is the smallest of top_k
                # and we cut all below that value
                logits[logits < v[:, -1]] = -float('Inf')
            probs = tf.keras.activations.softmax(logits, axis=-1)
            idx_dist = tfp.distributions.Multinomial(total_count=1, probs=probs)
            idx_next = idx_dist.sample(1)
            idx_next = tf.reshape(tf.math.argmax(idx_next, axis=-1), shape=(-1, 1))
            idx = tf.concat([idx, idx_next], axis=1)
        return idx

## Auxilliary routines

In [10]:
class MyLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self,
                 learning_rate: tf.float32 = 6e-4,
                 warmup_iters: int = 10,
                 min_lr: tf.float32 = 6e-5, 
                 lr_decay_iters: int= 100):
    
        self.learning_rate = learning_rate
        self.warmup_iters = warmup_iters
        self.min_lr = min_lr
        self.lr_decay_iters = lr_decay_iters

    def warmup(self, step):
        def res():
            return self.learning_rate * tf.cast(step, tf.float32) / self.warmup_iters
        return res
        
    def late(self):
        return self.min_lr
    
    def middle(self, step):
        def res():
            decay_ratio = (tf.cast(step, tf.float32) - self.warmup_iters) / (self.lr_decay_iters - self.warmup_iters)
            #assert 0 <= decay_ratio <= 1
            coeff = 0.5 * (1.0 + tf.math.cos(tnp.pi * decay_ratio))
            return self.min_lr + coeff * tf.cast((self.learning_rate - self.min_lr), dtype=tf.float32)
        return res
        
    @tf.function
    def __call__(self, step):
        lr = tf.case([(tf.less(step, self.warmup_iters), self.warmup(step)),
                   (tf.greater(step, self.lr_decay_iters), self.late)],
                   default=self.middle(step), exclusive=True)
        return lr

In [11]:
def estimate_loss():
    out = {'train': None, 'val': None}
    
    losses = tf.zeros(eval_iters)
    k = tf.Variable(0, trainable=False)
    for X, Y in train_set.take(eval_iters):
        logits, loss = model(X, Y, training=False)
        tf.tensor_scatter_nd_update(losses, [[k]], [loss])
        k.assign_add(1)
    out['train'] = tf.reduce_mean(loss)

    losses = tf.zeros(eval_iters)
    k = tf.Variable(0, trainable=False)
    for X, Y in valid_set.take(eval_iters):
        logits, loss = model(X, Y, training=False)
        tf.tensor_scatter_nd_update(losses, [[k]], [loss])
        k.assign_add(1)
    out['val'] = tf.reduce_mean(loss)

    return out

## Dataset

In [12]:
shakespear_url = "https://homl.info/shakespeare"
filepath = tf.keras.utils.get_file('shakespear.txt', shakespear_url)

In [13]:
with open(filepath, 'r', encoding='utf-8') as f:
    shakespear_txt = f.read()

In [14]:
print(shakespear_txt[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [15]:
text_vec_layer = tf.keras.layers.TextVectorization(split='character',
                                                  standardize='lower')

In [16]:
text_vec_layer.adapt([shakespear_txt])

In [17]:
text_vec_layer.get_vocabulary()[:10]

['', '[UNK]', ' ', 'e', 't', 'o', 'a', 'i', 'h', 's']

In [18]:
encoded = text_vec_layer([shakespear_txt])[0]

In [19]:
encoded

<tf.Tensor: shape=(1115394,), dtype=int64, numpy=array([21,  7, 10, ..., 22, 28, 12])>

In [20]:
# Removing code 0 and 1 reserved for padding and unknown characters 
# (codes start at 2 before that removal so now 0 and 1 will be some chars)
encoded -= 2

n_tokens = text_vec_layer.vocabulary_size() - 2
dataset_size = len(encoded)

In [21]:
n_tokens

39

In [22]:
ds = tf.data.Dataset.from_tensor_slices(encoded)

In [23]:
def to_dataset(sequence, length, shuffle=False, seed=None, batch_size=128):
    ds = tf.data.Dataset.from_tensor_slices(sequence)
    ds = ds.window(length + 1, shift=1, drop_remainder=True)
    ds = ds.flat_map(lambda window_ds: window_ds.batch(length+1))
    if shuffle:
        ds = ds.shuffle(buffer_size=100_000, seed=seed)
    ds = ds.batch(batch_size)
    return ds.map(lambda window: (window[:, :-1], window[:, 1:])).prefetch(1)

In [24]:
length = 32
tf.random.set_seed(1337)

In [25]:
train_set = to_dataset(encoded[:1_060_000], length=length, shuffle=True, seed=1337)
valid_set = to_dataset(encoded[1_060_000:], length=length)

In [26]:
#for sample in train_set.rebatch(1).take(1):
#    print(sample[0])
#    print(sample[1])

In [27]:
train_set_dist = strategy.experimental_distribute_dataset(train_set)
valid_set_dist = strategy.experimental_distribute_dataset(valid_set)

## Training

In [28]:
#buffer_size = 1024
batch_per_replica = 64
global_batch_size = batch_per_replica * strategy.num_replicas_in_sync

n_layer = 4
n_head = 4
n_embd = 32
block_size = length
bias = True
vocab_size = n_tokens
#dropout = tf.constant(0.0, dtype=tf.float32)
dropout = 0.0

iter_num = 0
best_val_loss = 1e9
max_iters = 150

weight_decay = tf.constant(1e-1, dtype=tf.float32)

warmup_iters = 10
learning_rate = tf.constant(6e-4, dtype=tf.float32)
lr_decay_iters = 200 # == max_iters
min_lr = tf.constant(6e-5, dtype=tf.float32)

eval_iters = 20
eval_interval = 10

grad_clip = tf.constant(1.0, dtype=tf.float32)

eval_only = False
eval_interval=1
step = 0

always_save_checkpoint = True
restore = False

seed = 1337

In [30]:
model_args = dict(n_layer=n_layer, 
                  n_head=n_head, 
                  n_embd=n_embd,
                  block_size=block_size,
                  bias=bias,
                  vocab_size=vocab_size,
                  dropout=dropout,
                  seed=seed)

In [31]:
model_args

{'n_layer': 4,
 'n_head': 4,
 'n_embd': 32,
 'block_size': 32,
 'bias': True,
 'vocab_size': 39,
 'dropout': 0.0,
 'seed': 1337}

In [32]:
checkpoint_directory = "./checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

In [33]:
def build_model(model_config):
    return GPT(model_config)

In [34]:
gptconf = GPTConfig(**model_args)

In [35]:
with strategy.scope():
    
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True,
        reduction=tf.keras.losses.Reduction.NONE)
    
    def compute_loss(Y, logits, model_losses):
        per_example_loss = loss_object(tf.reshape(Y, [-1]),
                      tf.reshape(logits, [-1, logits.shape[-1]]))
        loss = tf.nn.compute_average_loss(per_example_loss)
        if model_losses:
            loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))
        return loss

    val_loss = tf.keras.metrics.Mean(name='test_loss')

    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy')
    val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='val_accuracy')

    model = build_model(gptconf)

    optimizer = tf.keras.optimizers.AdamW(learning_rate=MyLRSchedule(learning_rate, warmup_iters, min_lr, lr_decay_iters))
    #optimizer = tf.keras.optimizers.AdamW(learning_rate=0.0001)

    ckpt = tf.train.Checkpoint(step=tf.Variable(0),
                           optimizer=optimizer, 
                           model=model,
                           model_args=model_args,
                           best_val_loss=tf.Variable(best_val_loss))
                           #config=GPTConfig)



In [36]:
#manager = tf.train.CheckpointManager(ckpt, checkpoint_directory, max_to_keep=3)

#if restore:
#    ckpt.restore(manager.latest_checkpoint)
#    step = int(ckpt.step)
#    best_val_loss = float(ckpt.best_val_loss)
#    model = ckpt.model
#    optimizer = ckpt.optimizer
#    model_args=ckpt.model_args

In [37]:
@tf.function
def train_step(sample):
    X, Y = sample

    with tf.GradientTape() as tape:
        logits = model(X, training=True)
        loss = compute_loss(Y, logits, model.losses)
    
    gradients = tape.gradient(loss, model.trainable_variables)
    gradients = [tf.clip_by_value(g,
                                  clip_value_min=-grad_clip,
                                  clip_value_max=grad_clip) for g in gradients]
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_accuracy.update_state(tf.reshape(Y, [-1]),
                      tf.reshape(logits, [-1, logits.shape[-1]]))
    return loss
    

In [38]:
@tf.function
def val_step(sample):
    X, Y = sample
    logits = model(X, training=False)
    v_loss = loss_object(tf.reshape(Y, [-1]),
                      tf.reshape(logits, [-1, logits.shape[-1]]))
    val_loss.update_state(v_loss)
    val_accuracy(tf.reshape(Y, [-1]),
                      tf.reshape(logits, [-1, logits.shape[-1]]))

In [39]:
#@tf.function
def distributed_train_epoch(dataset):
    total_loss = 0.0
    num_batches = 0
    for sample in dataset:
        per_replica_losses = strategy.run(train_step, args=(sample,))
        total_loss += strategy.reduce(
            tf.distribute.ReduceOp.SUM,
            per_replica_losses,
            axis=None)
        num_batches += 1
    return total_loss / tf.cast(num_batches, dtype=tf.float32)

In [40]:
while True:

    train_loss = distributed_train_epoch(train_set_dist)

    out_format = ("Epoch {}, Loss: {}, Accuracy: {}")
    print(out_format.format(step, train_loss, train_accuracy.result() * 100 ))
    
    step += 1
    if step > max_iters:
        break

2024-03-22 23:47:50.828593: I external/local_xla/xla/service/service.cc:168] XLA service 0x7b43b2a7e480 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-03-22 23:47:50.828616: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce RTX 3060 Laptop GPU, Compute Capability 8.6
2024-03-22 23:47:50.831907: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-03-22 23:47:50.851524: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8906
I0000 00:00:1711147670.894180  112435 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 0, Loss: 2.625222682952881, Accuracy: 23.015405654907227
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1, Loss: 2.4381868839263916, Accuracy: 24.725814819335938
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 2, Loss: 2.433342218399048,

KeyboardInterrupt: 

In [None]:
while True:
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # Checkpointing
    if losses['val'] < best_val_loss or always_save_checkpoint:
        best_val_loss = losses['val']
        if step > 0:
            print(f'Saving checkpoint to {checkpoint_prefix}')
            ckpt.step.assign_add(1)
            ckpt.best_val_loss.assign(best_val_loss)
            manager.save()
           
    if step == 0 and eval_only:
        break

    for X, Y in train_set:
        with tf.GradientTape() as tape:
            logits, main_loss = model(X, Y, training=True)
            loss = tf.add_n([main_loss] + model.losses)
        gradients = tape.gradient(loss, model.trainable_variables)
        gradients = [tf.clip_by_value(g, 
                                      clip_value_min=-grad_clip, 
                                      clip_value_max=grad_clip) for g in gradients]
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))    
    
    step += 1
    if step > max_iters:
        break
    

## Inference

In [31]:
def decode(gen_txt, text_vec_layer):
    out = []
    gen_txt = gen_txt + 2
    reverse = text_vec_layer.get_vocabulary()
    decoder = lambda x: str(reverse[x])
    gen_shape = gen.shape
    gen_flat = tf.reshape(gen_txt, -1)
    res = tf.map_fn(decoder, gen_flat, fn_output_signature='string')
    res = tf.reshape(res, gen_shape)
    res = res.numpy()
    for sentence in res:
        out.append(b"".join(sentence).decode())
    return out

In [32]:
shakespear_txt[:32]

'First Citizen:\nBefore we proceed'

In [33]:
txt = tf.reshape(text_vec_layer(shakespear_txt[:32]), shape=(-1, 32)) - 2

In [34]:
gen = model.generate(txt, 32)

In [35]:
decode(gen, text_vec_layer)

["first citizen:\nbefore we proceedi i'lanqu shatheipe tifadouteare"]