## Translating NanoGPT (GPT2) to TensorFlow

#### Based on https://github.com/karpathy/nanoGPT/blob/master/model.py

In [None]:
import tensorflow as tf
import numpy as np
from  dataclasses import dataclass

In [2]:
class MyLayerNorm(tf.keras.layers.Layer):
    
    def __init__(self, bias=True, eps=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.bias = bias
        
    def build(self, input_shape):  
        self.weight = self.add_weight(name='weight',
                                      shape=input_shape[-1:], # [-1:] gives last elem but keeps dims
                                      initializer=tf.keras.initializers.Ones(),
                                      trainable=True)

        self.bias = self.add_weight(name='bias',
                                      shape=input_shape[-1:], # [-1:] gives last elem but keeps dims
                                      initializer=tf.keras.initializers.Zeros(),
                                      trainable=True) if self.bias else None

        super(MyLayerNorm, self).build(input_shape)
    
    def call(self, x):
        # Can also use tf.nn.moments(inputs, axes=-1, keepdims=True), 
        # but then additionally one needs to take the sqrt to get \sigma
        mean = tf.keras.backend.mean(x, axis=-1, keepdims=True)
        std = tf.keras.backend.std(x, axis=-1, keepdims=True)
        
        return self.weight * (x - mean) / (std + self.eps) + self.bias

In [3]:
@dataclass
class GPTConfig:
    block_size: int = 8 # 1024 for GPT2
    vocab_size: int = 10 # 50304 for GPT2
    n_layer: int = 2 # 12
    n_head: int = 2 # 12
    n_embd: int = 10 # 768
    dropout: float = 0.0
    bias: bool = True
    seed: int = 1337

In [4]:
class CausalSelfAttention(tf.keras.layers.Layer):
    
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "Embedding dimension must divide number of heads"
        # key, query, value computed at once and splitted later
        self.initializer_proj = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02 / tf.math.sqrt(2. * config.n_layer), seed=None)
        self.c_attn = tf.keras.layers.Dense(#config.n_embd,
                                            3 * config.n_embd,
                                            activation=None,
                                            use_bias=config.bias)
        # output projection
        self.c_proj = tf.keras.layers.Dense(#config.n_embd,
                                            config.n_embd,
                                            activation=None,
                                            kernel_initializer=self.initializer_proj,
                                            use_bias=config.bias)
        self.dropout = config.dropout
        self.attn_dropout = tf.keras.layers.Dropout(self.dropout)
        self.resid_dropout = tf.keras.layers.Dropout(self.dropout)

        self.mask = tf.experimental.numpy.tril(
            tf.ones([config.block_size, config.block_size]))[tf.newaxis, tf.newaxis, :, :]

    def forward(self, x):
        
        B, T, C = x.size() # batch, sequence and channel, which is the embedding dim

        q, k, v = self.c_attn(x).split(self.n_embd, axis=2)
        k = tf.transpose(tf.reshape(k, [B, T, self.n_head, C // self.n_head]),
                         perm=[0, 2, 1, 3])
        q = tf.transpose(tf.reshape(q, [B, T, self.n_head, C // self.n_head]),
                         perm=[0, 2, 1, 3])
        v = tf.transpose(tf.reshape(v, [B, T, self.n_head, C // self.n_head]),
                         perm=[0, 2, 1, 3])

        att = (q @ tf.transpose(k, perm=[0, 1, 3, 2])) * (1.0 / tf.math.sqrt(k.shape[-1]))

        mask = tf.experimental.numpy.tril(tf.ones([T, T]))[tf.newaxis, tf.newaxis, :, :]
        att = tf.where(mask != 0, att, tf.constant(-np.inf))
        att = tf.nn.softmax(att, axis = 3)
        att = self.attn_dropout(att)
        y = att @ v

        y = tf.reshape(tf.transpose(y, perm=[0, 2, 1, 3]), [B, T, C])

        return self.resid_dropout(self.c_proj(y))

In [5]:
class MLP(tf.keras.layers.Layer):

    def __init__(self, config):
        super().__init__()
        self.initializer_proj = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02 / tf.math.sqrt(2. * config.n_layer), seed=None)
        # Streching and shrinking in channel/embedding dimension,
        # like for large resnets
        self.c_fc = tf.keras.layers.Dense(4 * config.n_embd, activation=None, use_bias=config.bias)
        self.c_proj = tf.keras.layers.Dense(config.n_embd, activation=None, kernel_initializer=self.initializer_proj, use_bias=config.bias)
        self.gelu = tf.keras.activations.gelu
        self.dropout = tf.keras.layers.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return self.dropout(x)
        

In [6]:
class Block(tf.keras.layers.Layer):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = MyLayerNorm(bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = MyLayerNorm(bias=config.bias)
        self.mlp = MLP(config)
    
    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))

In [7]:
class GPT(tf.keras.models.Model):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.initializer_dense = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02, seed=self.config.seed)
        self.initializer_embed = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02, seed=self.config.seed)
        self.initializer_bias = tf.keras.initializers.Zeros()
        
        self.wte = tf.keras.layers.Embedding(self.config.vocab_size, self.config.n_embd, 
                                             embeddings_initializer=self.initializer_embed, name='wte')
        self.wpe = tf.keras.layers.Embedding(self.config.vocab_size, self.config.n_embd, 
                                             embeddings_initializer=self.initializer_embed, name='wpe')
        self.drop = tf.keras.layers.Dropout(self.config.dropout, name='drop')
        self.h = [Block(self.config) for _ in range(self.config.n_layer)]
        self.ln_f = MyLayerNorm(bias=self.config.bias, name='ln_f')

        self.lm_head = tf.keras.layers.Dense(self.config.vocab_size, use_bias=self.config.bias)

    def build(self, input_shape):
        self.wte.build(input_shape=[self.config.vocab_size, self.config.n_embd])
        self.lm_head.build(input_shape=[self.config.n_embd, self.config.vocab_size])
        self.wte.trainable_weights[0].assign(self.lm_head.trainable_weights[0])
        
    def call(self, idx, targets=None):
        b, t = idx.shape
        assert t <= self.config.block_size, f'sequence too long for the defined context of {self.config.block_size}'
        pos = tf.range(0, t, dtype=tf.int64)

        tok_emb = self.wte(idx)
        pos_emb = self.wpe(pos)
        x = self.drop(tok_emb + pos_emb)
        for block in self.h:
            x = block(x)
        x = self.ln_f(x)

        if targets is not None:
            ce = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)
            logits = self.lm_head(x)
            loss = ce(tf.reshape(targets, [-1]),
                      tf.reshape(logits, [-1, logits.shape[-1]]))
        else:
            logits = self.lm_head(x[:, -1, :])[:, tf.newaxis, :]
            loss = None

        return logits, loss

In [8]:
cfg = GPTConfig()

In [9]:
cfg

GPTConfig(block_size=8, vocab_size=10, n_layer=2, n_head=2, n_embd=10, dropout=0.0, bias=True, seed=1337)

In [None]:
txt = tf.constant(np.random.randint(0, 9, size=[2, 8]), dtype=tf.int64)

In [11]:
txt

<tf.Tensor: shape=(2, 8), dtype=int64, numpy=
array([[7, 1, 4, 6, 7, 6, 2, 7],
       [4, 4, 8, 8, 6, 2, 6, 7]])>

In [12]:
gpt = GPT(cfg)

In [13]:
gpt(txt)

(<tf.Tensor: shape=(2, 1, 10), dtype=float32, numpy=
 array([[[-0.46243486, -0.16105999,  0.5470927 ,  1.121534  ,
          -0.43017292,  0.14129391, -1.1424067 ,  1.7714827 ,
           1.0557643 , -0.9435491 ]],
 
        [[-0.46243486, -0.16105999,  0.5470927 ,  1.121534  ,
          -0.43017292,  0.14129391, -1.1424067 ,  1.7714827 ,
           1.0557643 , -0.9435491 ]]], dtype=float32)>,
 None)