## Translating NanoGPT (GPT2) to TensorFlow

#### Based on https://github.com/karpathy/nanoGPT/blob/master/model.py

In [1]:
import tensorflow as tf
import numpy as np
from  dataclasses import dataclass
from tensorflow.experimental import numpy as tnp
import tensorflow_probability as tfp
import os

2024-03-19 18:07:49.903543: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-19 18:07:49.928749: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-19 18:07:49.928768: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-19 18:07:49.929459: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-19 18:07:49.933685: I tensorflow/core/platform/cpu_feature_guar

## The Model

In [2]:
class MyLayerNorm(tf.keras.layers.Layer):
    
    def __init__(self, bias=True, eps=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.bias = bias
        
    def build(self, input_shape):  
        self.weight = self.add_weight(name='weight',
                                      shape=input_shape[-1:], # [-1:] gives last elem but keeps dims
                                      initializer=tf.keras.initializers.Ones(),
                                      trainable=True)

        self.bias = self.add_weight(name='bias',
                                      shape=input_shape[-1:], # [-1:] gives last elem but keeps dims
                                      initializer=tf.keras.initializers.Zeros(),
                                      trainable=True) if self.bias else None

        super(MyLayerNorm, self).build(input_shape)
    
    def call(self, x):
        # Can also use tf.nn.moments(inputs, axes=-1, keepdims=True), 
        # but then additionally one needs to take the sqrt to get \sigma
        mean = tf.keras.backend.mean(x, axis=-1, keepdims=True)
        std = tf.keras.backend.std(x, axis=-1, keepdims=True)
        
        return self.weight * (x - mean) / (std + self.eps) + self.bias

In [3]:
@dataclass
class GPTConfig:
    block_size: int = 8 # 1024 for GPT2
    vocab_size: int = 20 # 50304 for GPT2
    n_layer: int = 2 # 12
    n_head: int = 2 # 12
    n_embd: int = 10 # 768
    dropout: float = 0.0
    bias: bool = True
    seed: int = 1337

In [4]:
class CausalSelfAttention(tf.keras.layers.Layer):
    
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "Embedding dimension must divide number of heads"
        # key, query, value computed at once and splitted later
        self.initializer_proj = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02 / tf.math.sqrt(2. * config.n_layer), seed=None)
        self.c_attn = tf.keras.layers.Dense(#config.n_embd,
                                            3 * config.n_embd,
                                            activation=None,
                                            use_bias=config.bias)
        # output projection
        self.c_proj = tf.keras.layers.Dense(#config.n_embd,
                                            config.n_embd,
                                            activation=None,
                                            kernel_initializer=self.initializer_proj,
                                            use_bias=config.bias)
        self.dropout = config.dropout
        self.attn_dropout = tf.keras.layers.Dropout(self.dropout)
        self.resid_dropout = tf.keras.layers.Dropout(self.dropout)

        self.mask = tf.experimental.numpy.tril(
            tf.ones([config.block_size, config.block_size]))[tf.newaxis, tf.newaxis, :, :]

    def forward(self, x):
        
        B, T, C = x.size() # batch, sequence and channel, which is the embedding dim

        q, k, v = self.c_attn(x).split(self.n_embd, axis=2)
        k = tf.transpose(tf.reshape(k, [B, T, self.n_head, C // self.n_head]),
                         perm=[0, 2, 1, 3])
        q = tf.transpose(tf.reshape(q, [B, T, self.n_head, C // self.n_head]),
                         perm=[0, 2, 1, 3])
        v = tf.transpose(tf.reshape(v, [B, T, self.n_head, C // self.n_head]),
                         perm=[0, 2, 1, 3])

        att = (q @ tf.transpose(k, perm=[0, 1, 3, 2])) * (1.0 / tf.math.sqrt(k.shape[-1]))

        mask = tf.experimental.numpy.tril(tf.ones([T, T]))[tf.newaxis, tf.newaxis, :, :]
        att = tf.where(mask != 0, att, tf.constant(-np.inf))
        att = tf.nn.softmax(att, axis = 3)
        att = self.attn_dropout(att)
        y = att @ v

        y = tf.reshape(tf.transpose(y, perm=[0, 2, 1, 3]), [B, T, C])

        return self.resid_dropout(self.c_proj(y))

In [5]:
class MLP(tf.keras.layers.Layer):

    def __init__(self, config):
        super().__init__()
        self.initializer_proj = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02 / tf.math.sqrt(2. * config.n_layer), seed=None)
        # Streching and shrinking in channel/embedding dimension,
        # like for large resnets
        self.c_fc = tf.keras.layers.Dense(4 * config.n_embd, activation=None, use_bias=config.bias)
        self.c_proj = tf.keras.layers.Dense(config.n_embd, activation=None, kernel_initializer=self.initializer_proj, use_bias=config.bias)
        self.gelu = tf.keras.activations.gelu
        self.dropout = tf.keras.layers.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return self.dropout(x)
        

In [6]:
class Block(tf.keras.layers.Layer):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = MyLayerNorm(bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = MyLayerNorm(bias=config.bias)
        self.mlp = MLP(config)
    
    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))

In [7]:
class GPT(tf.keras.models.Model):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.initializer_dense = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02, seed=self.config.seed)
        self.initializer_embed = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02, seed=self.config.seed)
        self.initializer_bias = tf.keras.initializers.Zeros()
        
        self.wte = tf.keras.layers.Embedding(self.config.vocab_size, self.config.n_embd, 
                                             embeddings_initializer=self.initializer_embed, name='wte')
        self.wpe = tf.keras.layers.Embedding(self.config.block_size, self.config.n_embd, 
                                             embeddings_initializer=self.initializer_embed, name='wpe')
        self.drop = tf.keras.layers.Dropout(self.config.dropout, name='drop')
        self.h = [Block(self.config) for _ in range(self.config.n_layer)]
        self.ln_f = MyLayerNorm(bias=self.config.bias, name='ln_f')

        self.lm_head = tf.keras.layers.Dense(self.config.vocab_size, use_bias=self.config.bias)

    def build(self, input_shape):
        self.wte.build(input_shape=[self.config.vocab_size])
        self.lm_head.build(input_shape=[self.config.n_embd])
        self.wte.trainable_weights[0].assign(tf.transpose(self.lm_head.trainable_weights[0]))
        
    def call(self, idx, targets=None):
        b, t = idx.shape
        assert t <= self.config.block_size, f'sequence too long for the defined context of {self.config.block_size}'
        pos = tf.range(0, t, dtype=tf.int64)

        tok_emb = self.wte(idx)
        pos_emb = self.wpe(pos)
        x = self.drop(tok_emb + pos_emb)
        for block in self.h:
            x = block(x)
        x = self.ln_f(x)

        if targets is not None:
            ce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
            logits = self.lm_head(x)
            
            #print(tf.reshape(targets, [-1]).shape)
            #print(tf.reshape(logits, [-1, logits.shape[-1]]).shape)
            
            loss = ce(tf.reshape(targets, [-1]),
                      tf.reshape(logits, [-1, logits.shape[-1]]))
        else:
            logits = self.lm_head(x[:, -1, :])[:, tf.newaxis, :]
            loss = None

        return logits, loss

    def crop_block_size(self, block_size):
        assert block_size < self.config.block_size
        self.config.block_size = block_size
        self.wpe.weights[0] = tf.Variable(self.wpe.weights[0][:block_size], trainable=True)
        for block in self.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
        
    def get_num_params(self, non_embedding=True):
        n_params = self.count_params()
        if non_embedding:
            n_params -= self.wpe.count_params()
        return n_params
    
    def estimate_mfu(self, fwdbwd_per_iter, dt):
        N = self.get_num_params()
        L, H, Q, T = self.config.n_layer, self.config.n_head, self.config.n_embd / self.config.n_head, self.config.block_size
        flops_per_token = 6 * N + 12 * L * H * Q * T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        flops_achieved = flops_per_iter * (1.0/dt)
        flops_promised = 312e12 # A100 at bfloat16
        mfu = flops_achieved / flops_promised
        return mfu
    
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for i in range(max_new_tokens):
            idx_cond = idx if idx.shape[1] <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = tf.math.top_k(logits, min(top_k, logits.shape[-1]))
                # Returned top_k values are sorted, so [-1] is the smallest of top_k
                # and we cut all below that value
                logits[logits < v[:, -1]] = -float('Inf')
            probs = tf.keras.activations.softmax(logits, axis=-1)
            idx_dist = tfp.distributions.Multinomial(total_count=1, probs=probs)
            idx_next = idx_dist.sample(1)
            idx_next = tf.reshape(tf.math.argmax(idx_next, axis=-1), shape=(-1, 1))
            idx = tf.concat([idx, idx_next], axis=1)
        return idx

In [163]:
cfg = GPTConfig()

In [164]:
cfg

GPTConfig(block_size=8, vocab_size=20, n_layer=2, n_head=2, n_embd=10, dropout=0.0, bias=True, seed=1337)

In [165]:
txt = tf.constant(np.random.randint(0, 9, size=[2, 8]), dtype=tf.int64)

In [166]:
txt

<tf.Tensor: shape=(2, 8), dtype=int64, numpy=
array([[3, 1, 8, 3, 0, 3, 6, 8],
       [2, 7, 0, 7, 6, 5, 8, 2]])>

In [167]:
gpt = GPT(cfg)

In [169]:
gpt(txt, txt)

(<tf.Tensor: shape=(2, 8, 20), dtype=float32, numpy=
 array([[[ 0.4274856 , -0.06014805,  1.8340145 ,  2.919419  ,
           0.44632468,  0.31492844,  0.6288445 , -1.4740782 ,
           0.18079159, -0.26295385, -0.5100614 , -0.05964876,
          -0.48101842, -1.2206202 , -0.51501125, -0.0133401 ,
          -0.11798257, -0.46437418,  0.9749199 , -0.73925036],
         [ 0.02297247,  2.6380482 ,  0.16860771, -0.00409502,
           1.4361117 ,  0.6396625 ,  0.1880045 ,  0.7027596 ,
          -0.12826179,  0.6724812 , -1.3839304 ,  0.05561219,
           1.4390004 , -0.08324537, -0.39451855,  0.41270292,
           1.794532  ,  0.42497164, -0.21578485,  1.0269561 ],
         [ 0.42997706, -0.1618574 ,  0.39848232,  0.39548135,
           0.1931145 ,  0.21241042, -0.05936687, -0.7839963 ,
           2.0015707 , -0.17732549, -0.28062558,  0.29340756,
          -0.6812506 , -0.94440085,  0.8995163 , -1.3410667 ,
          -0.16206568,  0.49769974, -0.7805122 , -0.2991427 ],
         [ 0.4

In [170]:
gpt.generate(txt, 3)

tf.Tensor(
[[ 8]
 [17]], shape=(2, 1), dtype=int64)
tf.Tensor(
[[6]
 [2]], shape=(2, 1), dtype=int64)
tf.Tensor(
[[15]
 [ 3]], shape=(2, 1), dtype=int64)


<tf.Tensor: shape=(2, 11), dtype=int64, numpy=
array([[ 3,  1,  8,  3,  0,  3,  6,  8,  8,  6, 15],
       [ 2,  7,  0,  7,  6,  5,  8,  2, 17,  2,  3]])>

In [155]:
gpt.summary()

Model: "gpt_19"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 wte (Embedding)             multiple                  200       
                                                                 
 wpe (Embedding)             multiple                  80        
                                                                 
 drop (Dropout)              multiple                  0         
                                                                 
 block_36 (Block)            multiple                  0         
                                                                 
 block_37 (Block)            multiple                  0         
                                                                 
 ln_f (MyLayerNorm)          multiple                  20        
                                                                 
 dense_170 (Dense)           multiple                  220  

In [156]:
gpt.wpe.count_params()

80

In [157]:
gpt.count_params()

520

In [158]:
gpt.estimate_mfu(2,1)

2.3384615384615383e-10

In [143]:
gpt.config.block_size

8

In [171]:
gpt.crop_block_size(4)

In [175]:
gpt.config

GPTConfig(block_size=4, vocab_size=20, n_layer=2, n_head=2, n_embd=10, dropout=0.0, bias=True, seed=1337)

In [173]:
gpt(txt[:, :4])

(<tf.Tensor: shape=(2, 1, 20), dtype=float32, numpy=
 array([[[ 0.44027796, -0.03563954,  1.8611115 ,  2.918157  ,
           0.4875484 ,  0.36393893,  0.67396784, -1.5018653 ,
           0.24394247, -0.22710907, -0.5447492 , -0.10886243,
          -0.49926877, -1.2556354 , -0.4496538 ,  0.01027814,
          -0.03563245, -0.3974866 ,  1.0165888 , -0.7735279 ]],
 
        [[-0.9206353 ,  0.72270334,  0.06139433, -1.6357415 ,
          -0.06042978,  0.8799873 , -0.6267076 ,  2.5394278 ,
          -0.57317775,  0.6112908 ,  0.8381661 , -0.09762476,
           1.1142485 ,  1.4172473 ,  0.16688019,  0.33162895,
           1.3914782 ,  1.4068255 ,  0.28691912,  0.83587766]]],
       dtype=float32)>,
 None)

## Auxilliary routines

In [61]:
@dataclass
class MyLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    learning_rate: tf.float32 = learning_rate
    warmup_iters: int = warmup_iters
    min_lr: tf.float32 = min_lr 
    lr_decay_iters: int= lr_decay_iters
    
    def __call__(self, step):
      if step < self.warmup_iters:
          return self.learning_rate * float(step) / self.warmup_iters
      if step > self.lr_decay_iters:
          return self.min_lr
      decay_ratio = (float(step) - self.warmup_iters) / (self.lr_decay_iters - self.warmup_iters)
      assert 0 <= decay_ratio <= 1
      coeff = 0.5 * (1.0 + tf.math.cos(tnp.pi * decay_ratio))
      return self.min_lr + coeff * tf.cast((self.learning_rate - self.min_lr), dtype=tf.float32)

In [None]:
def estimate_loss():
    out = {'train': None, 'val': None}
    
    losses = tf.zeros(eval_iters)
    k = tf.Variable(0, trainable=False)
    for X, Y in train_set.take(eval_iters):
        logits, loss = model(X, Y, training=False)
        tf.tensor_scatter_nd_update(losses, [[k]], [loss])
        k.assign_add(1)
    out['train'] = tf.reduce_mean(loss)

    losses = tf.zeros(eval_iters)
    k = tf.Variable(0, trainable=False)
    for X, Y in valid_set.take(eval_iters):
        logits, loss = model(X, Y, training=False)
        tf.tensor_scatter_nd_update(losses, [[k]], [loss])
        k.assign_add(1)
    out['val'] = tf.reduce_mean(loss)

    return out

## Training

In [8]:
shakespear_url = "https://homl.info/shakespeare"
filepath = tf.keras.utils.get_file('shakespear.txt', shakespear_url)

In [9]:
with open(filepath, 'r', encoding='utf-8') as f:
    shakespear_txt = f.read()

In [10]:
print(shakespear_txt[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [11]:
text_vec_layer = tf.keras.layers.TextVectorization(split='character',
                                                  standardize='lower')

2024-03-19 18:07:57.173868: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-03-19 18:07:57.202271: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-03-19 18:07:57.202425: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-

In [12]:
text_vec_layer.adapt([shakespear_txt])

In [13]:
text_vec_layer.get_vocabulary()[:10]

['', '[UNK]', ' ', 'e', 't', 'o', 'a', 'i', 'h', 's']

In [14]:
encoded = text_vec_layer([shakespear_txt])[0]

In [15]:
encoded

<tf.Tensor: shape=(1115394,), dtype=int64, numpy=array([21,  7, 10, ..., 22, 28, 12])>

In [16]:
# Removing code 0 and 1 reserved for padding and unknown characters 
# (codes start at 2 before that removal so now 0 and 1 will be some chars)
encoded -= 2

n_tokens = text_vec_layer.vocabulary_size() - 2
dataset_size = len(encoded)

In [17]:
n_tokens

39

In [18]:
ds = tf.data.Dataset.from_tensor_slices(encoded)

In [19]:
def to_dataset(sequence, length, shuffle=False, seed=None, batch_size=32):
    ds = tf.data.Dataset.from_tensor_slices(sequence)
    ds = ds.window(length + 1, shift=1, drop_remainder=True)
    ds = ds.flat_map(lambda window_ds: window_ds.batch(length+1))
    if shuffle:
        ds = ds.shuffle(buffer_size=100_000, seed=seed)
    ds = ds.batch(batch_size)
    return ds.map(lambda window: (window[:, :-1], window[:, 1:])).prefetch(1)

In [20]:
length = 8
tf.random.set_seed(42)

In [21]:
train_set = to_dataset(encoded[:1_060_000], length=length, shuffle=True, seed=1337)
valid_set = to_dataset(encoded[1_060_000:], length=length)

In [22]:
for sample in train_set.rebatch(1).take(1):
    print(sample[0])
    print(sample[1])

tf.Tensor([[ 3 14  1 26 10 10 25  5]], shape=(1, 8), dtype=int64)
tf.Tensor([[14  1 26 10 10 25  5  8]], shape=(1, 8), dtype=int64)


In [62]:
n_layer = 2
n_head = 2 
n_embd = 32
block_size = length
bias = True
vocab_size = n_tokens
dropout = tf.constant(0.0, dtype=tf.float32)

iter_num = 0
best_val_loss = 1e9
max_iters = 100

init_from = 'scratch'

weight_decay = tf.constant(1e-1, dtype=tf.float32)

warmup_iters = 10
learning_rate = tf.constant(6e-4, dtype=tf.float32)
lr_decay_iters = 100 # == max_iters
min_lr = tf.constant(6e-5, dtype=tf.float32)

eval_iters = 20
eval_interval = 10

grad_clip = tf.constant(1.0, dtype=tf.float32)

eval_only = False
eval_interval=1
step = 0

always_save_checkpoint = True
restore = False

In [35]:
model_args = dict(n_layer=n_layer, 
                  n_head=n_head, 
                  n_embd=n_embd,
                  block_size=block_size,
                  bias=bias,
                  vocab_size=vocab_size,
                  dropout=dropout)

In [36]:
model_args

{'n_layer': 2,
 'n_head': 2,
 'n_embd': 32,
 'block_size': 8,
 'bias': True,
 'vocab_size': 39,
 'dropout': <tf.Tensor: shape=(), dtype=float32, numpy=0.0>}

In [None]:
checkpoint_directory = "./checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

In [53]:
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
optimizer = tf.keras.optimizers.AdamW(learning_rate=MyLRSchedule(learning_rate, warmup_iters, min_lr, lr_decay_iters))

ckpt = tf.train.Checkpoint(step=tf.Variable(0),
                           optimizer=optimizer, 
                           model=model,
                           model_args=model_args,
                           best_val_loss=tf.Variable(best_val_loss))
                           #config=GPTConfig)

manager = tf.train.CheckpointManager(ckpt, checkpoint_directory, max_to_keep=3)

if restore:
    ckpt.restore(manager.latest_checkpoint)
    step = int(ckpt.step)
    best_val_loss = float(ckpt.best_val_loss)
    model = ckpt.model
    optimizer = ckpt.optimizer
    model_args=ckpt.model_args

In [195]:
for X, _ in train_set.take(1):
    out = model(X)

In [196]:
out

(<tf.Tensor: shape=(32, 1, 39), dtype=float32, numpy=
 array([[[ 5.4931684e+00,  3.7205371e-01,  6.9388658e-01, ...,
           1.4949246e-01,  2.6834235e-03,  7.0552784e-01]],
 
        [[ 4.1647804e-01,  4.7048670e-01, -6.5505439e-01, ...,
          -8.7842530e-01, -6.4945616e-02,  4.7188056e-01]],
 
        [[ 5.4931684e+00,  3.7205371e-01,  6.9388658e-01, ...,
           1.4949246e-01,  2.6834235e-03,  7.0552784e-01]],
 
        ...,
 
        [[ 4.1647804e-01,  4.7048670e-01, -6.5505439e-01, ...,
          -8.7842530e-01, -6.4945616e-02,  4.7188056e-01]],
 
        [[ 4.1647804e-01,  4.7048670e-01, -6.5505439e-01, ...,
          -8.7842530e-01, -6.4945616e-02,  4.7188056e-01]],
 
        [[ 5.4931684e+00,  3.7205371e-01,  6.9388658e-01, ...,
           1.4949246e-01,  2.6834235e-03,  7.0552784e-01]]], dtype=float32)>,
 None)

In [59]:
while True:
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # Checkpointing
    if losses['val'] < best_val_loss or always_save_checkpoint:
        best_val_loss = losses['val']
        if step > 0:
            print(f'Saving checkpoint to {checkpoint_prefix}')
            ckpt.step.assign_add(1)
            ckpt.best_val_loss.assign(best_val_loss)
            manager.save()
           
    if step == 0 and eval_only:
        break

    for X, Y in train_set.take(4):
        with tf.GradientTape() as tape:
            logits, main_loss = model(X, Y, training=True)
            loss = tf.add_n([main_loss] + model.losses)
        gradients = tape.gradient(loss, model.trainable_variables)
        gradients = [tf.clip_by_value(g, 
                                      clip_value_min=-grad_clip, 
                                      clip_value_max=grad_clip) for g in gradients]
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))    
    
    step += 1
    if step > 20:
        break
    

step 7: train loss 4.6512, val loss 4.6152
Saving checkpoint to ./checkpoints/ckpt
step 8: train loss 4.7170, val loss 4.6048
Saving checkpoint to ./checkpoints/ckpt


KeyboardInterrupt: 

In [204]:
model.generate(txt, 3)

tf.Tensor(
[[20]
 [18]], shape=(2, 1), dtype=int64)
tf.Tensor(
[[20]
 [18]], shape=(2, 1), dtype=int64)
tf.Tensor(
[[34]
 [18]], shape=(2, 1), dtype=int64)


<tf.Tensor: shape=(2, 11), dtype=int64, numpy=
array([[ 3,  1,  8,  3,  0,  3,  6,  8, 20, 20, 34],
       [ 2,  7,  0,  7,  6,  5,  8,  2, 18, 18, 18]])>