In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Hyperparameters
batch_size = 32  # number of sequences sampled per batch by get_batch
block_size = 8  # number of consecutive tokens in each sampled sequence
max_iters = 3000  # number of optimisation steps run by the training loop
eval_interval = 300  # the training loop estimates and prints the loss every this many steps
learning_rate = 1e-2  # learning rate handed to the Adam optimizer
eval_iters = 200  # number of random batches averaged per split by estimate_loss

# Set TensorFlow's global random seed; the tf.random calls in this notebook
# pass no op-level seed of their own.
tf.random.set_seed(1337)

In [2]:
# Read the entire training corpus into memory as a single string.
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [3]:
# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables between characters and their integer ids (positions in `chars`).
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}


def encode(s):
    """Return the list of integer ids for the characters of string `s`."""
    return [stoi[c] for c in s]


def decode(l):
    """Return the string spelled by the sequence of integer ids `l`."""
    return ''.join(itos[i] for i in l)


# Encode the whole corpus once; the first 90% of the tokens are used for
# training and the remainder for validation.
data = tf.constant(encode(text), dtype=tf.int64)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [4]:
def get_batch(split):
    """Sample a random batch of input/target sequences.

    Args:
        split: 'train' selects `train_data`; any other value selects
            `val_data`.

    Returns:
        A pair `(x, y)` of int64 tensors, each of shape
        (batch_size, block_size). `y` covers the same windows as `x` shifted
        one position to the right, so `y[b, t]` is the token that follows
        `x[b, t]` in the source data.
    """
    # A distinct local name keeps the module-level `data` tensor unshadowed.
    source = train_data if split == 'train' else val_data
    # Window start positions; the upper bound leaves room for the shifted targets.
    starts = tf.random.uniform((batch_size,), maxval=len(source) - block_size, dtype=tf.int64)
    # (batch_size, block_size) grid of absolute positions, built once and
    # reused for both the inputs and the shifted targets.
    positions = starts[:, tf.newaxis] + tf.range(block_size, dtype=tf.int64)
    x = tf.gather(source, positions)
    y = tf.gather(source, positions + 1)
    # tf.gather already returns tensors with the dtype of `source`, so no
    # further conversion is required.
    return x, y

In [5]:
@tf.function
def estimate_loss():
    """Estimate the mean loss on the training and validation splits.

    For each split, draws `eval_iters` random batches with `get_batch`, runs
    the module-level `model` on each one and averages the per-batch losses.

    Returns:
        dict mapping 'train' and 'val' to scalar loss tensors.
    """
    # `model.trainable` is deliberately not toggled here: Python attribute
    # assignments inside a tf.function execute only while the function is
    # being traced, not on later calls, and `trainable` does not alter the
    # forward pass. No gradients are taken in this function, so the weights
    # cannot change during evaluation.
    out = {}
    for split in ['train', 'val']:
        # `range` is a Python iterable, so this loop is unrolled at trace time.
        batch_losses = []
        for _ in range(eval_iters):
            X, Y = get_batch(split)
            _logits, loss = model(X, Y)
            batch_losses.append(loss)
        out[split] = tf.reduce_mean(tf.stack(batch_losses))
    return out

In [6]:
class BigramLanguageModel(keras.Model):
    """Bigram model: each token id directly looks up the next-token logits.

    The only layer is an embedding table with `vocab_size` rows of width
    `vocab_size`; the row for a token holds the unnormalised scores for the
    token that follows it.
    """

    def __init__(self, vocab_size):
        """Create the embedding table.

        Args:
            vocab_size: number of distinct tokens; used both as the number of
                rows and as the width of the table.
        """
        super().__init__()
        self.token_embedding_table = layers.Embedding(vocab_size, vocab_size)

    def call(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer tensor of token ids, shape (B, T).
            targets: optional integer tensor of next-token ids, shape (B, T).

        Returns:
            A pair `(logits, loss)`. Without `targets`, `logits` has shape
            (B, T, C) and `loss` is None. With `targets`, `logits` is
            flattened to (B*T, C) and `loss` is the mean sparse softmax
            cross-entropy over all B*T positions.
        """
        logits = self.token_embedding_table(idx)
        if targets is None:
            return logits, None

        # Flatten with -1 instead of an explicit B * T so the reshape also
        # works when the batch or time dimension is not statically known.
        C = logits.shape[-1]
        logits = tf.reshape(logits, (-1, C))
        targets = tf.reshape(targets, (-1,))
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=logits)
        )
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Extend `idx` by sampling `max_new_tokens` tokens one at a time.

        Args:
            idx: int64 tensor of token ids, shape (B, T), used as the context.
            max_new_tokens: number of tokens to append to every row.

        Returns:
            int64 tensor of shape (B, T + max_new_tokens): the context
            followed by the sampled tokens.
        """
        for _ in tf.range(max_new_tokens):
            logits, _unused_loss = self(idx)
            # Only the last position predicts the token to append.
            last_logits = logits[:, -1, :]
            # tf.random.categorical takes unnormalised log-probabilities, so
            # the logits are passed directly; a softmax followed by a log
            # would describe the same distribution but can underflow to -inf.
            idx_next = tf.random.categorical(last_logits, num_samples=1, dtype=tf.int64)
            idx = tf.concat([idx, idx_next], axis=1)
        return idx


model = BigramLanguageModel(vocab_size)

In [7]:
# Adam optimizer used to update the model's trainable variables.
optimizer = tf.optimizers.Adam(learning_rate)

# The loop variable is named `step` rather than `iter` so that the built-in
# `iter()` is not shadowed in the notebook namespace; it is a plain Python int.
for step in range(max_iters):

    # Periodically report the estimated train/val loss.
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # Sample a fresh training batch.
    xb, yb = get_batch('train')

    # Forward pass recorded on the tape so the loss can be differentiated.
    with tf.GradientTape() as tape:
        logits, loss = model(xb, yb)

    # Backward pass and parameter update.
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

step 0: train loss 4.1738, val loss 4.1731
step 300: train loss 2.6584, val loss 2.6725
step 600: train loss 2.5153, val loss 2.5308
step 900: train loss 2.4783, val loss 2.5034
step 1200: train loss 2.4696, val loss 2.5020
step 1500: train loss 2.4745, val loss 2.4959
step 1800: train loss 2.4649, val loss 2.4975
step 2100: train loss 2.4665, val loss 2.4937
step 2400: train loss 2.4672, val loss 2.4989
step 2700: train loss 2.4664, val loss 2.4839


In [8]:
# Generate from the model
# Start from a single token with id 0, sample 500 more, then decode the ids.
context = tf.zeros((1, 1), dtype=tf.int64)
generated_ids = model.generate(context, max_new_tokens=500)
generated_text = decode(generated_ids.numpy()[0])
print(generated_text)


Whar lse win havetsert I t?
FRat ef
OLor t I is amusse mze wQUK:
Southat out n f s ur's; y t are he, he hens ond fous ll t.

An kiongequt thil
Tbe?
Han nd y?

s sscouse s; wis y.



Sea we f the tit, s l, r n:
I ss rdeven l cre asst,
jughalllet aveanthe.
IUSoreanand ary bee h,
Whous,
QUSkiedwindin, remy! n h.
Gichomounepatou dshony wonceergis thom
ANEYo toms l ntesthig, lurt loumplapt is Sed pak'
Tat miore k. icchirobow yollersee
Anoururet s tha Te me ton f tod my wisthinis akikees
RWourel.
And 
