In [1]:
from tinygrad import Tensor, nn, TinyJit, Device
import numpy as np
import math

In [2]:
# Load the raw training corpus. The context manager closes the file handle
# promptly, and an explicit encoding avoids depending on the platform locale.
with open("shakespeare.txt", encoding="utf-8") as corpus_file:
  corpus = corpus_file.read()

In [3]:
# Sanity check: total corpus size, then the opening passage.
print(f"{len(corpus)} chars")
print(corpus[:1000])

1115394 chars
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst 

In [4]:
# Character-level vocabulary: every distinct character in the corpus, in sorted order.
vocab = sorted(set(corpus))
vocab_size = len(vocab)
(vocab_size, "".join(vocab))

(65, "\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")

In [5]:
new_line_char: str = "\n"  # a single line-feed character (not referenced by the other visible cells)

In [6]:
# Character -> index lookup table. A dict gives O(1) lookups per character,
# instead of the linear scan `vocab.index` performs for every character.
char_to_index = {ch: i for i, ch in enumerate(vocab)}


def encode(s: str) -> list[int]:
  """Convert a string into a list of vocabulary indices.

  Raises ValueError if `s` contains a character that is not in `vocab`.
  """
  try:
    return [char_to_index[c] for c in s]
  except KeyError as err:
    # Keep the ValueError that `vocab.index` used to raise, with a clearer message.
    raise ValueError(f"{err.args[0]!r} is not in the vocabulary") from None


def decode(l) -> str:
  """Convert an iterable of vocabulary indices (ints or numpy ints) back into a string."""
  return "".join(vocab[i] for i in l)


decode(encode("hello"))

'hello'

In [7]:
# Encode the whole corpus once; the first 90% trains the model, the last 10% is held out.
data = Tensor(encode(corpus))
split = int(len(data) * 0.9)
train_data, test_data = data[:split], data[split:]

In [8]:
def get_batch(data: Tensor, batch_size, block_size):
  """Sample `batch_size` random windows of length `block_size` from the token tensor `data`.

  Returns `(inputs, targets)`, where `targets` holds, at every position, the
  token that immediately follows the corresponding input token.
  """
  # One random start offset per row, reshaped into a column for broadcasting.
  starts = Tensor.randint((batch_size,), high=len(data) - block_size)
  # (batch_size, 1) + (block_size,) -> (batch_size, block_size) grid of indices.
  windows = starts.reshape((batch_size, 1)) + Tensor.arange(block_size)
  inputs, targets = data[windows], data[windows + 1]
  return inputs, targets

In [9]:
# A tiny batch for eyeballing the input/target alignment.
x, y = get_batch(train_data, block_size=8, batch_size=4)
(x.shape, y.shape)

((4, 8), (4, 8))

In [10]:
# Decode each row; every target row should be its input row shifted by one character.
print(list(map(decode, x.numpy())))
print(list(map(decode, y.numpy())))

[' come do', 'h\nBut th', 'UMBERLAN', 'll money']
['come dow', '\nBut the', 'MBERLAND', 'l money;']


In [11]:
class Attention:
  """Multi-head causal self-attention.

  The query/key/value weights for all heads are stacked into tensors of shape
  (n_heads, embed_size, head_size). The concatenated head outputs are mapped
  back to embed_size by the `proj` linear layer.
  """

  def __init__(self, embed_size: int, n_heads: int, head_size: int) -> None:
    self.n_heads = n_heads
    self.head_size = head_size
    # NOTE(review): bound is 1/sqrt(head_size) rather than 1/sqrt(fan_in = embed_size);
    # confirm this is intended.
    bound = 1 / math.sqrt(head_size)

    def stacked_head_weights() -> Tensor:
      return Tensor.uniform(n_heads, embed_size, head_size, low=-bound, high=bound)

    # Created in the order queries, keys, values so RNG consumption is unchanged.
    self.queries = stacked_head_weights()
    self.keys = stacked_head_weights()
    self.values = stacked_head_weights()
    self.proj = nn.Linear(n_heads * head_size, embed_size)

  def __call__(self, x: Tensor) -> Tensor:
    B, T, C = x.shape
    # Broadcast the input across heads: (B, T, C) -> (B, n_heads, T, C).
    per_head = x.unsqueeze(1).expand((B, self.n_heads, T, C))

    # (B, n_heads, T, C) @ (n_heads, C, head_size) -> (B, n_heads, T, head_size)
    q = per_head @ self.queries
    k = per_head @ self.keys
    v = per_head @ self.values

    # Scaled dot-product scores, (B, n_heads, T, T).
    scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_size)
    # Lower-triangular mask: position t may only attend to positions <= t.
    causal = Tensor.ones((T, T), requires_grad=False).tril()
    weights = scores.masked_fill(causal == 0, float("-inf")).softmax()

    # Merge heads: (B, n_heads, T, head_size) -> (B, T, n_heads * head_size) -> (B, T, C).
    merged = (weights @ v).transpose(1, 2).reshape((B, T, -1))
    return self.proj(merged)


class MLP:
  """Two-layer feed-forward network with a GELU non-linearity between the layers."""

  def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
    self.fc1 = nn.Linear(input_size, hidden_size)
    self.fc2 = nn.Linear(hidden_size, output_size)

  def __call__(self, x: Tensor) -> Tensor:
    hidden = self.fc1(x).gelu()
    return self.fc2(hidden)


class TransformerBlock:
  """Pre-norm transformer block.

  Attention and the MLP each see a LayerNorm-ed input, and each output is
  added back onto its input as a residual connection.
  """

  def __init__(self, embed_size: int, n_heads: int, head_size: int) -> None:
    self.ln1 = nn.LayerNorm(embed_size)
    self.attn = Attention(embed_size, n_heads, head_size)
    self.ln2 = nn.LayerNorm(embed_size)
    self.mlp = MLP(embed_size, 4 * embed_size, embed_size)

  def __call__(self, x: Tensor) -> Tensor:
    attended = x + self.attn(self.ln1(x))
    return attended + self.mlp(self.ln2(attended))


class Transformer:
  """Decoder-only, character-level transformer language model.

  Token embeddings plus learned positional embeddings are passed through
  `n_layers` TransformerBlocks and a final LayerNorm. `lm_head` then projects
  the result to vocabulary logits.
  """

  def __init__(
    self,
    block_size: int,
    vocab_size: int,
    embed_size: int,
    n_layers: int,
    n_heads: int,
    head_size: int,
  ) -> None:
    self.block_size = block_size  # maximum context length the model accepts
    self.vocab_size = vocab_size
    self.token_embed = nn.Embedding(vocab_size, embed_size)
    # One learned embedding per position, so a context can be at most block_size long.
    self.pos_embed = nn.Embedding(block_size, embed_size)
    self.h = [TransformerBlock(embed_size, n_heads, head_size) for _ in range(n_layers)]
    self.ln = nn.LayerNorm(embed_size)
    self.lm_head = nn.Linear(embed_size, vocab_size, bias=False)

  def __call__(self, x: Tensor) -> Tensor:
    """Map token indices `x` of shape (B, T) to logits of shape (B, T, vocab_size).

    Any context length 1 <= T <= block_size is accepted. Positions 0..T-1 of
    the positional embedding table are used.
    """
    assert len(x.shape) == 2 and 1 <= x.shape[1] <= self.block_size, (
      f"expected a (B, T) input with 1 <= T <= {self.block_size}, got shape {x.shape}"
    )
    _, T = x.shape
    embed = self.token_embed(x) + self.pos_embed(Tensor.arange(T))
    logits = embed.sequential(self.h + [self.ln, self.lm_head])
    return logits

  def loss(self, x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
    """Return `(logits, loss)`.

    `loss` is the sparse categorical cross-entropy of the target indices `y`
    under the logits.
    """
    logits = self(x)
    loss = logits.sparse_categorical_crossentropy(y)
    return logits, loss

  def generate(self, x: Tensor, n: int = 500) -> Tensor:
    """Autoregressively sample `n` tokens that continue the 1-D prompt `x`.

    The model is fed at most the last `block_size` tokens at each step. The
    result is a 1-D tensor (same dtype as `x`) holding the prompt followed by
    the `n` sampled tokens. Sampling uses numpy's global RNG.
    """
    assert len(x.shape) == 1 and x.shape[0] >= 1, f"expected a non-empty 1-D prompt, got shape {x.shape}"
    # Keep the running sequence as plain Python ints instead of repeatedly
    # `cat`-ing onto a Tensor, which may grow a lazy graph on every step.
    # Each step builds a fresh, bounded-size context tensor instead.
    tokens = [int(t) for t in x.numpy()]
    with Tensor.inference_mode():
      for _ in range(n):
        context = Tensor([tokens[-self.block_size :]], dtype=x.dtype)
        logits = self(context)
        p = logits[:, -1].softmax().squeeze(0)
        tokens.append(int(np.random.choice(self.vocab_size, p=p.numpy())))
    return Tensor(tokens, dtype=x.dtype)

In [12]:
# Model hyperparameters.
block_size = 128  # context length, in characters
embed_size = 256
n_layers = 8
n_heads = 8
head_size = embed_size // n_heads

# Seed both RNGs used from here on: tinygrad's (weight init, batch sampling)
# and numpy's (token sampling in generate), so Restart & Run All is reproducible.
SEED = 1337
Tensor.manual_seed(SEED)
np.random.seed(SEED)

transformer = Transformer(
  block_size, vocab_size, embed_size, n_layers, n_heads, head_size
)
sum(p.numel() for p in nn.state.get_parameters(transformer))

6378496

In [13]:
optim = nn.optim.AdamW(nn.state.get_parameters(transformer))
batch_size = 64


@TinyJit
@Tensor.train()
def train_step():
  """Take one AdamW step on a fresh random training batch and return the loss tensor.

  Runs in training mode (Tensor.train) and is compiled by TinyJit after its first calls.
  """
  optim.zero_grad()
  inputs, targets = get_batch(train_data, batch_size, block_size)
  loss = transformer.loss(inputs, targets)[1]
  loss.backward()
  optim.step()
  return loss

In [14]:
# Train for 2500 steps, recording every step's loss. Report held-out
# next-token accuracy on the first step and every 100th step after that.
losses = []
for step in range(1, 2500 + 1):
  loss = train_step().item()
  losses.append(loss)
  if step != 1 and step % 100 != 0:
    continue
  # Accuracy is estimated from one random batch of the held-out split.
  with Tensor.inference_mode():
    x_samples, y_samples = get_batch(test_data, batch_size, block_size)
    acc = (transformer(x_samples).argmax(axis=-1) == y_samples).mean().item()
    print(f"step {step}, loss {loss:.2f}, acc {acc*100.:.2f}%")

step 1, loss 4.21, acc 15.15%
step 100, loss 2.54, acc 26.04%
step 200, loss 2.42, acc 28.61%
step 300, loss 2.21, acc 34.20%
step 400, loss 1.88, acc 42.11%
step 500, loss 1.74, acc 43.99%
step 600, loss 1.62, acc 46.55%
step 700, loss 1.53, acc 49.46%
step 800, loss 1.50, acc 49.50%
step 900, loss 1.45, acc 50.18%
step 1000, loss 1.43, acc 51.09%
step 1100, loss 1.36, acc 52.84%
step 1200, loss 1.31, acc 52.09%
step 1300, loss 1.33, acc 53.33%
step 1400, loss 1.32, acc 53.75%
step 1500, loss 1.27, acc 53.75%
step 1600, loss 1.23, acc 53.82%
step 1700, loss 1.23, acc 54.99%
step 1800, loss 1.25, acc 54.77%
step 1900, loss 1.18, acc 53.77%
step 2000, loss 1.21, acc 54.21%
step 2100, loss 1.19, acc 54.54%
step 2200, loss 1.17, acc 54.59%
step 2300, loss 1.18, acc 55.77%
step 2400, loss 1.13, acc 55.18%
step 2500, loss 1.14, acc 54.68%


In [15]:
# Prime the model with the first block_size characters of the corpus and sample 1000 more.
text = decode(transformer.generate(data[:block_size], n=1000).numpy())
# Show the prompt in green (ANSI code 92), then reset the colour for the sampled continuation.
print(f"\033[92m{text[:block_size]}\033[0m{text[block_size:]}")

[92mFirst Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to [0mbe grieved
I merciful, mouth, and shall toke of reputation.

ROMEO:
No, be Ever, bear him to be married
That break and affecting her just.

GLOUCESTER:
Bless I stay, when Gloucester wings and be so,
For reputy but at an Ireland.

KING RICHARD III:
Away, then, be my merits! Upot live tears?

Clown:
My successive can yield as these highness of me;
My worship for Claudio. Hark! 'tis Cup rot,
What they may listore Claudio, when the sulder?

Second Citizen:
Why, my lord, I protetly day!

First Citizen:
Not stand not proves, come, nor own rember eyes
In march, for thy shoulders them for sleep,
A little lambs-ating the old light,
When he was speaking, mine honour, comes hither. Notive:

JOHN OF CARGARD II:
Why not knigh me in Rome,
Repent may complant them he strike
With successful plate confess; swift
Is seem at our hard it, if he cannot
Had kill'd the arms