## Translating NanoGPT (GPT2) to TensorFlow

#### Based on https://github.com/karpathy/nanoGPT/blob/master/model.py

In [None]:
import tensorflow as tf
tf.keras.mixed_precision.set_global_policy('mixed_float16')
import numpy as np
from  dataclasses import dataclass
from tensorflow.experimental import numpy as tnp
import tensorflow_probability as tfp
import os

## Distribute Strategy

In [None]:
strategy = tf.distribute.MirroredStrategy()

## The Model

In [3]:
class MyLayerNorm(tf.keras.layers.Layer):
    def __init__(self, bias=True, eps=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.bias = bias
       
    def build(self, input_shape):  
        self.weight = self.add_weight(name='weight',
                                      shape=input_shape[-1:], # [-1:] gives last elem but keeps dims
                                      initializer=tf.keras.initializers.Ones(),
                                      trainable=True)

        self.bias = self.add_weight(name='bias',
                                      shape=input_shape[-1:], # [-1:] gives last elem but keeps dims
                                      initializer=tf.keras.initializers.Zeros(),
                                      trainable=True) if self.bias else None

        super(MyLayerNorm, self).build(input_shape)
    @tf.function(jit_compile=True)
    def call(self, x):
        # Can also use tf.nn.moments(inputs, axes=-1, keepdims=True), 
        # but then additionally one needs to take the sqrt to get \sigma
        mean = tf.keras.backend.mean(x, axis=-1, keepdims=True)
        std = tf.keras.backend.std(x, axis=-1, keepdims=True)
        
        return self.weight * (x - mean) / (std + self.eps) + self.bias

In [4]:
class GPTConfig():
    def __init__(self, 
                 block_size:int=8, 
                 vocab_size:int=39,
                 n_layer:int=2,
                 n_head:int=2,
                 n_embd:int=10,
                 dropout:float=0.0,
                 bias:bool=False,
                 seed:int=1337):

        self.block_size=block_size
        self.vocab_size=vocab_size
        self.n_layer=n_layer
        self.n_head=n_head
        self.n_embd=n_embd
        self.dropout=dropout
        self.bias=bias
        self.seed=seed

In [5]:
class CausalSelfAttention(tf.keras.layers.Layer):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "Embedding dimension must divide number of heads"
        # key, query, value computed at once and splitted later
        self.initializer_proj = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02 / tf.math.sqrt(2. * config.n_layer), seed=None)
        self.c_attn = tf.keras.layers.Dense(#config.n_embd,
                                            3 * config.n_embd,
                                            activation=None,
                                            use_bias=config.bias)
        # output projection
        self.c_proj = tf.keras.layers.Dense(#config.n_embd,
                                            config.n_embd,
                                            activation=None,
                                            kernel_initializer=self.initializer_proj,
                                            use_bias=config.bias)
        self.dropout = config.dropout
        self.attn_dropout = tf.keras.layers.Dropout(self.dropout)
        self.resid_dropout = tf.keras.layers.Dropout(self.dropout)

        self.mask = tf.experimental.numpy.tril(
            tf.ones([config.block_size, config.block_size]))[tf.newaxis, tf.newaxis, :, :]
    
    @tf.function(jit_compile=True)
    def forward(self, x):
        
        B, T, C = x.size() # batch, sequence and channel, which is the embedding dim

        q, k, v = self.c_attn(x).split(self.n_embd, axis=2)
        k = tf.transpose(tf.reshape(k, [B, T, self.n_head, C // self.n_head]),
                         perm=[0, 2, 1, 3])
        q = tf.transpose(tf.reshape(q, [B, T, self.n_head, C // self.n_head]),
                         perm=[0, 2, 1, 3])
        v = tf.transpose(tf.reshape(v, [B, T, self.n_head, C // self.n_head]),
                         perm=[0, 2, 1, 3])

        att = (q @ tf.transpose(k, perm=[0, 1, 3, 2])) * (1.0 / tf.math.sqrt(k.shape[-1]))

        mask = tf.experimental.numpy.tril(tf.ones([T, T]))[tf.newaxis, tf.newaxis, :, :]
        att = tf.where(mask != 0, att, tf.constant(-np.inf))
        att = tf.nn.softmax(att, axis = 3)
        att = self.attn_dropout(att)
        y = att @ v

        y = tf.reshape(tf.transpose(y, perm=[0, 2, 1, 3]), [B, T, C])

        return self.resid_dropout(self.c_proj(y))

In [6]:
class MLP(tf.keras.layers.Layer):
    def __init__(self, config):
        super().__init__()
        self.initializer_proj = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02 / tf.math.sqrt(2. * config.n_layer), seed=None)
        # Streching and shrinking in channel/embedding dimension,
        # like for large resnets
        self.c_fc = tf.keras.layers.Dense(4 * config.n_embd, activation=None, use_bias=config.bias)
        self.c_proj = tf.keras.layers.Dense(config.n_embd, activation=None, kernel_initializer=self.initializer_proj, use_bias=config.bias)
        self.gelu = tf.keras.activations.gelu
        self.dropout = tf.keras.layers.Dropout(config.dropout)
    
    @tf.function(jit_compile=True)
    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return self.dropout(x)
        

In [7]:
class Block(tf.keras.layers.Layer):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = MyLayerNorm(bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = MyLayerNorm(bias=config.bias)
        self.mlp = MLP(config)
    
    @tf.function(jit_compile=True)
    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))

In [8]:
class GPT(tf.keras.models.Model):
    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.initializer_dense = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02, seed=self.config.seed)
        self.initializer_embed = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02, seed=self.config.seed)
        self.initializer_bias = tf.keras.initializers.Zeros()
        
        self.wte = tf.keras.layers.Embedding(self.config.vocab_size, 
                                             self.config.n_embd, 
                                             embeddings_initializer=self.initializer_embed, 
                                             name='wte')
        
        self.wpe = tf.keras.layers.Embedding(self.config.block_size, 
                                             self.config.n_embd, 
                                             embeddings_initializer=self.initializer_embed, 
                                             name='wpe')
        
        self.drop = tf.keras.layers.Dropout(self.config.dropout, name='drop')
        
        self.h = [Block(self.config) for _ in range(self.config.n_layer)]
        
        self.ln_f = MyLayerNorm(bias=self.config.bias, name='ln_f')
        
        # Crucial for mixed precision: final model output should be of dtype='float32'
        self.lm_head = tf.keras.layers.Dense(self.config.vocab_size, use_bias=self.config.bias, dtype='float32')

    def build(self, input_shape):
        self.wte.build(input_shape=[self.config.vocab_size])
        self.lm_head.build(input_shape=[self.config.n_embd])
        self.wte.trainable_weights[0].assign(tf.transpose(self.lm_head.trainable_weights[0]))
    
    @tf.function(jit_compile=True)
    def call(self, idx):
        b, t = idx.shape
        #assert t <= self.config.block_size, f'sequence too long for the defined context of {self.config.block_size}'
        pos = tf.range(0, t, dtype=tf.int64)

        tok_emb = self.wte(idx)
        pos_emb = self.wpe(pos)
        x = self.drop(tok_emb + pos_emb)
        for block in self.h:
            x = block(x)
        x = self.ln_f(x)
        return self.lm_head(x)

    def crop_block_size(self, block_size):
        assert block_size < self.config.block_size
        self.config.block_size = block_size
        self.wpe.weights[0] = tf.Variable(self.wpe.weights[0][:block_size], trainable=True)
        for block in self.h:
            #if hasattr(block.attn, 'bias'):
            if len(block.attn.weights) == 2:
                block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
                  
    def get_num_params(self, non_embedding=True):
        n_params = self.count_params()
        if non_embedding:
            n_params -= self.wpe.count_params()
        return n_params
        
    def estimate_mfu(self, fwdbwd_per_iter, dt):
        N = self.get_num_params()
        L, H, Q, T = self.config.n_layer, self.config.n_head, self.config.n_embd / self.config.n_head, self.config.block_size
        flops_per_token = 6 * N + 12 * L * H * Q * T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        flops_achieved = flops_per_iter * (1.0/dt)
        flops_promised = 312e12 # A100 at bfloat16
        mfu = flops_achieved / flops_promised
        return mfu
        
    @tf.function(jit_compile=True)
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for i in range(max_new_tokens):
            idx_cond = idx if idx.shape[1] <= self.config.block_size else idx[:, -self.config.block_size:]
            logits = self(idx_cond, training=False)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = tf.math.top_k(logits, min(top_k, logits.shape[-1]))
                # Returned top_k values are sorted, so [-1] is the smallest of top_k
                # and we cut all below that value
                logits[logits < v[:, -1]] = -float('Inf')
            probs = tf.keras.activations.softmax(logits, axis=-1)
            idx_dist = tfp.distributions.Multinomial(total_count=1, probs=probs)
            idx_next = idx_dist.sample(1)
            idx_next = tf.reshape(tf.math.argmax(idx_next, axis=-1), shape=(-1, 1))
            idx = tf.concat([idx, idx_next], axis=1)
        return idx

## Auxilliary routines

In [9]:
class MyLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self,
                 learning_rate: float = 6e-4,
                 warmup_iters: int = 10,                
                 min_lr: float = 6e-5,
                 lr_decay_iters: int= 100):
    
        self.learning_rate = learning_rate
        self.warmup_iters = warmup_iters
        self.min_lr = min_lr
        self.lr_decay_iters = lr_decay_iters

    def warmup(self, step):
        def res():            
            return self.learning_rate * float(step) / self.warmup_iters
        return res
        
    def late(self):
        return self.min_lr
    
    def middle(self, step):
        def res():
            decay_ratio = (float(step) - self.warmup_iters) / (self.lr_decay_iters - self.warmup_iters)
            #assert 0 <= decay_ratio <= 1
            coeff = 0.5 * (1.0 + tf.math.cos(tnp.pi * decay_ratio))        
            return self.min_lr + coeff * (self.learning_rate - self.min_lr)
        return res
        
    @tf.function(jit_compile=True)
    def __call__(self, step):
        lr = tf.case([(tf.less(step, self.warmup_iters), self.warmup(step)),
                   (tf.greater(step, self.lr_decay_iters), self.late)],
                   default=self.middle(step), exclusive=True)
        return lr

## Dataset

In [10]:
shakespear_url = "https://homl.info/shakespeare"
filepath = tf.keras.utils.get_file('shakespear.txt', shakespear_url)

In [11]:
with open(filepath, 'r', encoding='utf-8') as f:
    shakespear_txt = f.read()

In [12]:
print(shakespear_txt[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [13]:
text_vec_layer = tf.keras.layers.TextVectorization(split='character',
                                                  standardize='lower')

In [14]:
text_vec_layer.adapt([shakespear_txt])

In [15]:
text_vec_layer.get_vocabulary()[:10]

['', '[UNK]', ' ', 'e', 't', 'o', 'a', 'i', 'h', 's']

In [16]:
encoded = text_vec_layer([shakespear_txt])[0]

In [17]:
encoded

<tf.Tensor: shape=(1115394,), dtype=int64, numpy=array([21,  7, 10, ..., 22, 28, 12])>

In [18]:
# Removing code 0 and 1 reserved for padding and unknown characters 
# (codes start at 2 before that removal so now 0 and 1 will be some chars)
encoded -= 2

n_tokens = text_vec_layer.vocabulary_size() - 2
dataset_size = len(encoded)

In [19]:
n_tokens

39

In [20]:
ds = tf.data.Dataset.from_tensor_slices(encoded)

In [21]:
def to_dataset(sequence, length, shuffle=False, seed=None, batch_size=512):
    ds = tf.data.Dataset.from_tensor_slices(sequence)
    ds = ds.window(length + 1, shift=1, drop_remainder=True)
    ds = ds.flat_map(lambda window_ds: window_ds.batch(length+1))
    if shuffle:
        ds = ds.shuffle(buffer_size=100_000, seed=seed)
    ds = ds.batch(batch_size)
    return ds.map(lambda window: (window[:, :-1], window[:, 1:])).prefetch(1)

In [22]:
length = 32
tf.random.set_seed(1337)

In [23]:
train_set = to_dataset(encoded[:1_060_000], length=length, shuffle=True, seed=1337)
valid_set = to_dataset(encoded[1_060_000:], length=length)

In [24]:
#for sample in train_set.rebatch(1).take(1):
#    print(sample[0])
#    print(sample[1])

## Training

In [25]:
#buffer_size = 1024
batch_per_replica = 512
global_batch_size = batch_per_replica * strategy.num_replicas_in_sync

n_layer = 8
n_head = 8
n_embd = 32
block_size = length
bias = True
vocab_size = n_tokens
dropout = 0.1

iter_num = 0
best_val_loss = 1e9
max_iters = 250

weight_decay = 1e-1

warmup_iters = 50
learning_rate = 6e-4
lr_decay_iters = 250 # == max_iters
min_lr = 6e-5

eval_iters = 20
eval_interval = 10

grad_clip = 1.0

eval_only = False
eval_interval=1
step = 0

always_save_checkpoint = True
restore = False

seed = 1337

In [26]:
model_args = dict(n_layer=n_layer, 
                  n_head=n_head, 
                  n_embd=n_embd,
                  block_size=block_size,
                  bias=bias,
                  vocab_size=vocab_size,
                  dropout=dropout,
                  seed=seed)

In [27]:
model_args

{'n_layer': 8,
 'n_head': 8,
 'n_embd': 32,
 'block_size': 32,
 'bias': True,
 'vocab_size': 39,
 'dropout': 0.1,
 'seed': 1337}

In [28]:
checkpoint_directory = "./checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

In [29]:
log_dir = "./tensorboard"

In [30]:
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

In [31]:
def build_model(model_config):
    return GPT(model_config)

In [32]:
gptconf = GPTConfig(**model_args)

In [33]:
train_set_dist = strategy.experimental_distribute_dataset(train_set)
valid_set_dist = strategy.experimental_distribute_dataset(valid_set.take(eval_iters))

In [None]:
with strategy.scope():
    
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True,
        reduction=tf.keras.losses.Reduction.NONE)
    
    def compute_loss(Y, logits, model_losses):
        per_example_loss = loss_object(tf.reshape(Y, [-1]),
                      tf.reshape(logits, [-1, logits.shape[-1]]))
        loss = tf.nn.compute_average_loss(per_example_loss)
        if model_losses:
            loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))
        return loss

    val_loss = tf.keras.metrics.Mean(name='test_loss')

    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy')
    val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='val_accuracy')

    model = build_model(gptconf)

    tb_callback.set_model(model)

    callbacks = tf.keras.callbacks.CallbackList([
        tb_callback
    ])
    
    optimizer = tf.keras.optimizers.AdamW(learning_rate=MyLRSchedule(learning_rate, warmup_iters, min_lr, lr_decay_iters))
    optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
    
    ckpt = tf.train.Checkpoint(step=tf.Variable(0),
                               optimizer=optimizer,
                               model=model,
                               model_args=model_args,
                               best_val_loss=tf.Variable(best_val_loss),
                               train_accuracy=train_accuracy,
                               val_accuracy=val_accuracy,
                               val_loss=val_loss)

    manager = tf.train.CheckpointManager(ckpt, checkpoint_directory, max_to_keep=3)

    if restore:
        ckpt.restore(manager.latest_checkpoint)
        #ckpt.restore(manager.checkpoints[-2])
        step = int(ckpt.step.value().numpy())
        best_val_loss = float(ckpt.best_val_loss.value().numpy())
        model = ckpt.model
        optimizer = ckpt.optimizer
        model_args=ckpt.model_args
        train_accuracy=ckpt.train_accuracy
        val_accuracy=ckpt.val_accuracy
        val_loss=ckpt.val_loss


In [35]:
@tf.function
def train_step(sample):
    X, Y = sample
    with tf.GradientTape() as tape:
        logits = model(X, training=True)
        loss = compute_loss(Y, logits, model.losses)
        scaled_loss = optimizer.get_scaled_loss(loss)
    
    scaled_gradients = tape.gradient(scaled_loss, model.trainable_variables)
    gradients = optimizer.get_unscaled_gradients(scaled_gradients)
    gradients = [tf.clip_by_value(gradient,
                                  clip_value_min=-grad_clip,
                                  clip_value_max=grad_clip) for gradient in gradients]
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_accuracy.update_state(tf.reshape(Y, [-1]),
                      tf.reshape(logits, [-1, logits.shape[-1]]))
    return loss

In [36]:
@tf.function
def val_step(sample):
    X, Y = sample
    logits = model(X, training=False)
    v_loss = loss_object(tf.reshape(Y, [-1]),
                      tf.reshape(logits, [-1, logits.shape[-1]]))
    val_loss.update_state(v_loss)
    val_accuracy(tf.reshape(Y, [-1]),
                      tf.reshape(logits, [-1, logits.shape[-1]]))

In [37]:
def distributed_train_epoch(dataset, step):
    total_loss = 0.0
    num_batches = 0
    for sample in dataset:
        callbacks.on_train_batch_begin(step)
        per_replica_losses = strategy.run(train_step, args=(sample,))
        callbacks.on_train_batch_end(step)
        total_loss += strategy.reduce(
            tf.distribute.ReduceOp.SUM,
            per_replica_losses,
            axis=None)
        num_batches += 1
    return total_loss / float(num_batches)

In [38]:
def distributed_val_epoch(dataset, step):
    for sample in dataset:
        callbacks.on_test_batch_begin(step)
        strategy.run(val_step, args=(sample,))
        callbacks.on_test_batch_end(step)

## Train Loop

In [39]:
callbacks.on_train_begin()

while True:
    callbacks.on_epoch_begin(step)
    # Train Epoch            
    train_loss = distributed_train_epoch(train_set_dist, step)
    
    # Val Epoch
    distributed_val_epoch(valid_set_dist, step)
    callbacks.on_epoch_end(step)
    out_format = ("Epoch {}\nLoss: {}, Accuracy: {}\nVal Loss: {}, Val Accuracy: {}")
    print(out_format.format(step, train_loss, train_accuracy.result() * 100, val_loss.result(), val_accuracy.result() * 100 ))

    # Checkpointing
    if step > 0:
        if val_loss.result() < best_val_loss or always_save_checkpoint:
            best_val_loss = val_loss.result()
            print(f'Saving checkpoint to {checkpoint_prefix}')
            ckpt.step.assign(step)
            ckpt.best_val_loss.assign(best_val_loss)
            manager.save()
    
    train_accuracy.reset_states()
    val_accuracy.reset_states()
    val_loss.reset_states()
    
    step += 1
    if step > max_iters:
        break
        
callbacks.on_train_end()




2024-03-24 22:06:10.067625: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator case/Assert/AssertGuard/Assert





INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Redu

KeyboardInterrupt: 

## Model for inference

In [40]:
tf.config.run_functions_eagerly(True)
inference_model = GPT(gptconf)

In [41]:
restore_chkpt = tf.train.Checkpoint(model=inference_model)

In [42]:
restore_chkpt.restore(manager.checkpoints[-1])

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x774b185b7640>

## Inference

In [43]:
def decode(gen_txt, text_vec_layer):
    out = []
    gen_txt = gen_txt + 2
    reverse = text_vec_layer.get_vocabulary()
    decoder = lambda x: str(reverse[x])
    gen_shape = gen.shape
    gen_flat = tf.reshape(gen_txt, -1)
    res = tf.map_fn(decoder, gen_flat, fn_output_signature='string')
    res = tf.reshape(res, gen_shape)
    res = res.numpy()
    for sentence in res:
        out.append(b"".join(sentence).decode())
    return out

In [44]:
shakespear_txt[:32]

'First Citizen:\nBefore we proceed'

In [45]:
txt = tf.reshape(text_vec_layer(shakespear_txt[:32]), shape=(-1, 32)) - 2

In [46]:
gen = inference_model.generate(txt, 32)

In [47]:
decode(gen, text_vec_layer)

['first citizen:\nbefore we proceedddvvvvvvdddddddddddddddddddddddd']