In [None]:
import tensorflow as tf
import keras
from keras import layers, activations, losses, optimizers

import time
import numpy as np
import matplotlib.pyplot as plt

# Load and Process Data

In [None]:
# Read the entire training corpus into memory as one string.
with open("../data/input.txt", mode="r", encoding="utf-8") as f:
    text = f.read()

Tokens are chars, so the vocab size is the number of unique chars

In [None]:
# Vocabulary: every distinct character in the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

# Two lookup tables: token id -> character and character -> token id.
int2char = dict(enumerate(chars))
char2int = {c: i for i, c in int2char.items()}


def encode(s: str) -> list[int]:
    """Map a string to its token ids, skipping characters not in the vocabulary."""
    token_ids = []
    for ch in s:
        if ch in char2int:
            token_ids.append(char2int[ch])
    return token_ids


def decode(y: list[int] | np.ndarray | tf.Tensor) -> str:
    """Map a sequence of token ids back to text, skipping ids not in the vocabulary."""
    pieces = []
    for token in y:
        token_id = int(token)
        if token_id in int2char:
            pieces.append(int2char[token_id])
    return "".join(pieces)

The input text is encoded as a NumPy array of integer token ids, then split into
validation (the first 10%) and training (the remaining 90%) sets

In [None]:
# Encode the whole corpus once as an int64 array of token ids.
full_data = np.array(encode(text), dtype=np.int64)

# Hold out the first tenth of the corpus for validation; train on the rest.
val_size = len(full_data) // 10
val_data, train_data = full_data[:val_size], full_data[val_size:]

### Convert data into blocks

$x_i = [d_i, d_{i + 1}, ..., d_{i + b - 1}]$

$y_i = [d_{i + 1}, d_{i + 2}, ..., d_{i + b}]$

In [None]:
def block_data(data, block_size):
    """Slice a 1-D token sequence into overlapping (input, target) blocks.

    Row ``i`` of ``x`` is ``data[i : i + block_size]`` and row ``i`` of ``y`` is
    the same window shifted one step to the right, ``data[i + 1 : i + 1 + block_size]``.

    Args:
        data: 1-D sequence (list or array) of token ids.
        block_size: Number of tokens per block; must be positive.

    Returns:
        A tuple ``(x, y)`` of writable arrays, each of shape
        ``(len(data) - block_size, block_size)`` with the dtype of ``data``.

    Raises:
        ValueError: If ``block_size`` is not positive or ``data`` has fewer than
            ``block_size + 1`` elements (no complete input/target pair exists).
    """
    data = np.asarray(data)
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if len(data) <= block_size:
        raise ValueError(
            f"need at least {block_size + 1} tokens to build one block, got {len(data)}"
        )

    # All len(data) - block_size + 1 windows, as a zero-copy strided view.
    windows = np.lib.stride_tricks.sliding_window_view(data, block_size)

    # Every window except the last is an input; its target is the next window.
    # Copy so the results are independent, writable arrays.
    x = windows[:-1].copy()
    y = windows[1:].copy()
    return x, y

### Generate random batches for dataset

In [None]:
def batch_iterate(x, y, batch_size):
    """Yield one shuffled pass over ``(x, y)`` as tensor mini-batches.

    A fresh random permutation of the row indices is drawn per call; the final
    batch may contain fewer than ``batch_size`` rows.
    """
    n_rows = y.shape[0]
    order = np.random.permutation(n_rows)
    for start in range(0, n_rows, batch_size):
        batch_idx = order[start:start + batch_size]
        batch_x = tf.convert_to_tensor(x[batch_idx])
        batch_y = tf.convert_to_tensor(y[batch_idx])
        yield batch_x, batch_y

# Model

In [None]:
class MLP(keras.Layer):
    """Position-wise feed-forward network of a transformer block.

    Expands the last dimension by 4x with a GELU activation, projects back to
    the input width with a linear layer, then applies dropout.

    Args:
        dropout: Dropout rate applied to the projected output.
        use_bias: Whether the two dense layers use a bias term.
    """

    def __init__(self, dropout, use_bias=True):
        super().__init__()
        self.dropout_rate = dropout
        self.use_bias = use_bias
        # Dropout needs no input shape, so create it here; this keeps
        # `self.dropout` a layer at all times instead of a float that build()
        # later replaces with a layer.
        self.dropout = layers.Dropout(dropout)

    def build(self, input_shape):
        embed_dim = input_shape[-1]
        self.c_fc = layers.Dense(
            4 * embed_dim,
            activation=activations.gelu,
            use_bias=self.use_bias,
        )
        # Linear output projection: this value is added to the residual stream
        # by the caller, so it must be able to take any sign. A GELU here would
        # bound the branch output from below.
        self.c_proj = layers.Dense(
            embed_dim,
            use_bias=self.use_bias,
        )

    def call(self, x):
        x = self.c_fc(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

In [None]:
class Block(keras.Layer):
    """Pre-LayerNorm transformer block: causal self-attention, then an MLP.

    Each sub-layer reads a layer-normalised copy of the input and its output is
    added back onto the un-normalised residual stream.

    Args:
        num_heads: Number of attention heads.
        dropout: Dropout rate used inside attention and the MLP.
        use_bias: Whether dense projections and layer norms use bias/offset terms.
    """

    def __init__(self, num_heads, dropout, use_bias=True):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.use_bias = use_bias

    def build(self, input_shape):
        embed_dim = input_shape[-1]
        self.ln_1 = layers.LayerNormalization(epsilon=1e-5, center=self.use_bias)
        # NOTE(review): Keras treats value_dim as a per-head size, so each head
        # carries full-width values here — confirm this is intended.
        self.attn = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=embed_dim // self.num_heads,
            value_dim=embed_dim,
            dropout=self.dropout,
            use_bias=self.use_bias,
            output_shape=input_shape[-1:],
        )
        # Same epsilon as ln_1 so both norms in the block behave alike.
        self.ln_2 = layers.LayerNormalization(epsilon=1e-5, center=self.use_bias)
        self.mlp = MLP(dropout=self.dropout, use_bias=self.use_bias)

    def call(self, x):
        # Normalise only the attention input; the residual stream `x` itself
        # must stay un-normalised so the skip connection carries it through.
        attn_in = self.ln_1(x)
        x = x + self.attn(attn_in, attn_in, use_causal_mask=True)
        x = x + self.mlp(self.ln_2(x))
        return x

In [None]:
class GenerativeTransformer(keras.Layer):
    """Decoder-only transformer mapping token ids to next-token logits.

    Token and learned position embeddings are summed, passed through dropout
    and ``num_layers`` transformer blocks, layer-normalised, and projected to
    ``vocab_size`` un-normalised scores (logits) per position.

    Args:
        vocab_size: Number of distinct tokens.
        block_size: Maximum sequence length (size of the position table).
        embedding_size: Width of the token/position embeddings.
        num_heads: Attention heads per block.
        num_layers: Number of stacked blocks.
        dropout: Dropout rate used throughout the model.
        use_bias: Whether blocks and the final layer norm use bias/offset terms.
    """

    def __init__(
        self, vocab_size, block_size, embedding_size,
        num_heads, num_layers, dropout, use_bias=True,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.embedding_size = embedding_size
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout_rate = dropout
        self.use_bias = use_bias

    def build(self, _):
        self.wte = layers.Embedding(self.vocab_size, self.embedding_size)
        self.wpe = layers.Embedding(self.block_size, self.embedding_size)
        self.dropout = layers.Dropout(self.dropout_rate)
        block_args = (self.num_heads, self.dropout_rate, self.use_bias)
        self.h = [Block(*block_args) for _ in range(self.num_layers)]
        self.ln_f = layers.LayerNormalization(epsilon=1e-5, center=self.use_bias)
        # Linear head: the outputs are logits, so no activation is applied.
        # Weights are not tied to `wte`.
        self.lm_head = layers.Dense(self.vocab_size, use_bias=False)

    def call(self, x_idx):
        _, T = x_idx.shape  # x_idx has shape (B, T)

        # The position table only has `block_size` rows; longer inputs would
        # index past its end.
        if T is not None and T > self.block_size:
            raise ValueError(
                f"cannot forward sequence of length {T}, "
                f"block size is only {self.block_size}"
            )

        pos = tf.range(0, T, dtype=tf.int64)

        tok_emb = self.wte(x_idx)  # shape (B, T, C)
        pos_emb = self.wpe(pos)  # shape (T, C)

        # (B, T, C) + (T, C) broadcasts the position embedding over the batch.
        x = self.dropout(tok_emb + pos_emb)
        for blk in self.h:
            x = blk(x)
        x = self.ln_f(x)
        return self.lm_head(x)
        

# Training

In [None]:
# How often (in training iterations) to log and to evaluate.
LOG_INTERVAL = 500
EVAL_INTERVAL = 2_500

# Data shape: tokens per sequence and sequences per batch.
BLOCK_SIZE = 32
BATCH_SIZE = 16

# Total number of training iterations.
MAX_ITERS = 10_000

# Learning-rate schedule: warm up to MAX_LR, then decay towards MIN_LR.
WARMUP_ITERS = 100
LR_DECAY_ITERS = 2_500
MAX_LR = 1e-4
MIN_LR = 1e-5

### Convert data to blocks

$x_i = [d_i, d_{i + 1}, ..., d_{i + b - 1}]$

$y_i = [d_{i + 1}, d_{i + 2}, ..., d_{i + b}]$

In [None]:
# Build (input, target) blocks separately for each split.
x_train, y_train = block_data(data=train_data, block_size=BLOCK_SIZE)
x_val, y_val = block_data(data=val_data, block_size=BLOCK_SIZE)

### Initialize model and optimizer

In [None]:
# The input length must match the block size used to build the data and the
# model's position table, so take it from the shared constant.
inputs = keras.Input(shape=(BLOCK_SIZE,), dtype=tf.int64)
outputs = GenerativeTransformer(
    vocab_size=vocab_size,
    block_size=BLOCK_SIZE,
    embedding_size=640,
    num_heads=4,
    num_layers=4,
    dropout=0.0,
    use_bias=True,
)(inputs)

model = keras.Model(inputs, outputs)


def lm_loss(y_true, y_pred):
    """Per-token sparse cross-entropy on un-normalised model outputs.

    The network's final layer does not apply a softmax, so its outputs must be
    treated as logits rather than probabilities. Returns one loss value per
    target token (no reduction).
    """
    return losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)


model.compile(loss=lm_loss)

# The training loop overwrites the learning rate on every step; start from the
# schedule's peak value rather than a separate literal.
optimizer = optimizers.AdamW(learning_rate=MAX_LR)

Estimate loss from tensors $x, y$

In [None]:
def evaluate_loss(x, y, max_iters=100):
    """Estimate the mean per-token loss of ``model`` on ``(x, y)``.

    Draws up to ``max_iters`` random batches of ``BATCH_SIZE`` sequences and
    averages the loss over every target token seen.

    Args:
        x: Array of input blocks.
        y: Array of target blocks aligned with ``x``.
        max_iters: Maximum number of batches to evaluate.

    Returns:
        The token-weighted mean loss as a float, or ``nan`` if no batch was
        evaluated.
    """
    loss_sum = 0.0
    token_count = 0
    for batch_num, (bx, by) in enumerate(batch_iterate(x, y, BATCH_SIZE)):
        if batch_num >= max_iters:
            break

        logits = model(bx, training=False)
        logits = tf.reshape(logits, (-1, logits.shape[-1]))
        targets = tf.reshape(by, (-1,))

        # Average over the whole batch instead of reading only the first
        # token's loss; reduce_mean also copes with an already-scalar loss.
        batch_loss = float(tf.reduce_mean(model.loss(targets, logits)))

        # Weight by the tokens in this batch so a short final batch counts
        # proportionally.
        n_tokens = int(targets.shape[0])
        loss_sum += batch_loss * n_tokens
        token_count += n_tokens

    if token_count == 0:
        return float("nan")
    return loss_sum / token_count

### Change learning rate over time

$\eta_i = \begin{cases}
    \frac{\eta_0 \cdot i}{N_{\text{warmup}}} & i < N_{\text{warmup}} \\
    \eta_{\text{min}} + \left(
        \frac{1}{2} + \frac{1}{2}\cos\left(
            \pi \frac{i - N_{\text{warmup}}}{N_{\text{decay}} - N_{\text{warmup}}}
        \right)
    \right)(\eta_0 - \eta_{\text{min}}) & N_{\text{warmup}} \leq i \leq N_{\text{decay}} \\
    \eta_{\text{min}} & N_{\text{decay}} < i
\end{cases}$

In [None]:
def get_lr(iter_num: int) -> float:
    """Learning rate for iteration ``iter_num``: linear warmup, then cosine decay.

    Rises linearly to ``MAX_LR`` over ``WARMUP_ITERS`` iterations, follows half
    a cosine down to ``MIN_LR`` until ``LR_DECAY_ITERS``, and stays at
    ``MIN_LR`` afterwards.
    """
    # Phase 1: linear warmup.
    in_warmup = iter_num < WARMUP_ITERS
    if in_warmup:
        return MAX_LR * iter_num / WARMUP_ITERS

    # Phase 3: past the decay horizon, hold the floor value.
    past_decay = iter_num > LR_DECAY_ITERS
    if past_decay:
        return MIN_LR

    # Phase 2: cosine decay; `progress` runs from 0 to 1 across the decay span.
    progress = (iter_num - WARMUP_ITERS) / (LR_DECAY_ITERS - WARMUP_ITERS)
    assert 0 <= progress <= 1
    cosine_weight = 0.5 * (1.0 + np.cos(np.pi * progress))
    return MIN_LR + cosine_weight * (MAX_LR - MIN_LR)

In [None]:
# Visualise the learning-rate schedule over the full training run.
lr_schedule = [get_lr(step) for step in range(1, MAX_ITERS + 1)]
plt.plot(lr_schedule)
plt.xlabel("Iteration")
plt.ylabel("Learning Rate")
plt.show()

### Crossentropy loss:

$l(x, y, \theta) = -\sum_i y_i \log(f(x_i, \theta))$

<br>

### Train Step with Adam Optimizer

$g_t = \nabla_{\theta_{t - 1}} l(x, y, \theta_{t - 1})$

$\alpha = \eta \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}$

$m_t = \beta_1 m_{t - 1} + (1 - \beta_1)g_t$

$v_t = \beta_2 v_{t - 1} + (1 - \beta_2)g_t^2$

$\theta_t = \theta_{t - 1} - \alpha \frac{m_t}{\sqrt{v_t} + \epsilon}$

In [None]:
@tf.function
def train_step(x, y):
    """Run one optimisation step on a batch and return its loss.

    Args:
        x: Batch of input token ids.
        y: Batch of target token ids aligned with ``x``.

    Returns:
        Whatever ``model.loss`` returns for the flattened batch.
    """
    with tf.GradientTape() as tape:
        # training=True so that dropout layers are active during optimisation.
        logits = model(x, training=True)
        logits = tf.reshape(logits, (-1, logits.shape[-1]))
        y = tf.reshape(y, (-1,))
        loss = model.loss(y, logits)
    # If `loss` is not a scalar, tape.gradient differentiates the sum of its
    # elements.
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply(grads, model.trainable_variables)
    return loss

In [None]:
i = 1
t0 = time.time()
best_val_loss = float('inf')

# Each pass of the outer loop is one shuffled epoch; training stops as soon as
# MAX_ITERS optimisation steps have been taken, even mid-epoch.
while i <= MAX_ITERS:
    for x, y in batch_iterate(x_train, y_train, batch_size=BATCH_SIZE):
        if i > MAX_ITERS:
            break

        optimizer.learning_rate = get_lr(i)
        loss = train_step(x, y)

        if i % LOG_INTERVAL == 0:
            t1 = time.time()
            dt = t1 - t0
            t0 = t1
            # Report the mean over the whole batch, not just the first token's loss.
            batch_loss = float(tf.reduce_mean(loss))
            print(f"[{i:4}] loss: {batch_loss:.3f}, time: {dt:.3f}s")

        if i % EVAL_INTERVAL == 0:
            train_loss = evaluate_loss(x_train, y_train)
            val_loss = evaluate_loss(x_val, y_val)
            print(f"    train loss: {train_loss:.3f}, val loss: {val_loss:.3f}")
            # Keep the best validation loss seen so far.
            best_val_loss = min(best_val_loss, val_loss)

        i += 1

# 