In [1]:
from tinygrad import Tensor, nn, TinyJit, Device
import numpy as np
import math

# Disabled JIT setting kept for reference (presumably a workaround; see
# https://github.com/tinygrad/tinygrad/issues/5408). Uncomment both lines to enable it.
# import os
# os.environ['JIT'] = '2'
# Make "GPU" the default tinygrad device for every tensor created below.
Device.DEFAULT = "GPU"

In [2]:
# Read the whole training corpus into memory. The context manager closes the
# file handle as soon as the text is loaded, and the explicit encoding keeps
# the result independent of the platform's default locale.
with open("shakespeare.txt", encoding="utf-8") as corpus_file:
  corpus = corpus_file.read()

In [3]:
# Sanity check: total corpus length and its first 1000 characters.
print(len(corpus), "chars")
print(corpus[:1000])

1115394 chars
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst 

In [4]:
# Character-level vocabulary: every distinct character of the corpus, in sorted order.
# A token id is simply a character's position in this list.
vocab = sorted(set(corpus))
vocab_size = len(vocab)
vocab_size, "".join(vocab)

(65, "\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")

In [5]:
# Newline character as a named constant.
# NOTE(review): not referenced by any other visible cell - candidate for removal.
new_line_char = "\n"

In [6]:
# Character -> token id lookup table, built once from `vocab`.
# A dict lookup is O(1) per character, whereas `vocab.index(c)` scans the list
# for every character of the corpus. Re-run this cell if `vocab` is rebuilt.
_char_to_id = {ch: i for i, ch in enumerate(vocab)}


def encode(s):
  """Convert a string into a list of integer token ids (one per character).

  Raises:
    ValueError: if `s` contains a character that is not in `vocab`.
  """
  try:
    return [_char_to_id[c] for c in s]
  except KeyError as err:
    # keep the exception type of the previous list.index-based implementation
    raise ValueError(f"{err.args[0]!r} is not in vocab") from None


def decode(l):
  """Convert an iterable of integer token ids back into a string."""
  return "".join([vocab[i] for i in l])


# round-trip check: should display the input string unchanged
decode(encode("hello"))

'hello'

In [7]:
# Encode the whole corpus as a single 1-D tensor of token ids.
data = Tensor(encode(corpus))
# The first 90% of the tokens are used for training, the rest is held out.
split = int(0.9 * len(data))
train_data = data[:split]
test_data = data[split:]

In [8]:
def get_batch(data: Tensor, batch_size, block_size):
  """Sample a random batch of (input, target) windows from a token tensor.

  Args:
    data: tensor of token ids to sample from (indexed along its first axis).
    batch_size: number of windows to draw.
    block_size: number of tokens per window.

  Returns:
    A pair `(inputs, targets)`, both indexed by a (batch_size, block_size)
    grid of positions; `targets` reads the position one past each input token.
  """
  # one random start offset per window, shaped as a column so it broadcasts
  starts = Tensor.randint((batch_size,), high=len(data) - block_size).reshape((batch_size, 1))
  # start + [0, 1, ..., block_size - 1] -> grid of positions, one row per window
  positions = starts + Tensor.arange(block_size)
  inputs = data[positions]
  targets = data[positions + 1]
  return inputs, targets

In [9]:
# Draw a small demo batch (4 windows of 8 tokens) and display both shapes.
x, y = get_batch(train_data, batch_size=4, block_size=8)
x.shape, y.shape

((4, 8), (4, 8))

In [10]:
# Decode every row back to text: each target row is its input row shifted by one character.
print([decode(row) for row in x.numpy()])
print([decode(row) for row in y.numpy()])

['\nHath ye', 'ich Capu', '-shrunk.', 'oo near,']
['Hath yet', 'ch Capul', 'shrunk.\n', 'o near, ']


In [11]:
class Attention:
  """Causal multi-head self-attention.

  A single linear layer produces the queries, keys and values of all heads at
  once; a second linear layer projects the re-assembled head outputs.
  """

  def __init__(self, n_embd: int, n_head: int):
    # every head gets an equal slice of the embedding
    assert n_embd % n_head == 0
    # fused query / key / value projection for all heads
    self.c_attn = nn.Linear(n_embd, 3 * n_embd)
    # projection applied to the concatenated head outputs
    self.c_proj = nn.Linear(n_embd, n_embd)
    self.n_head = n_head
    self.n_embd = n_embd

  def __call__(self, x: Tensor):
    """Apply masked self-attention to `x` of shape (B, T, C); the result has the same shape."""
    B, T, C = x.shape
    head_size = C // self.n_head

    def to_heads(t: Tensor) -> Tensor:
      # (B, T, C) -> (B, nh, T, hs)
      return t.view(B, T, self.n_head, head_size).transpose(1, 2)

    q, k, v = (to_heads(t) for t in self.c_attn(x).split(self.n_embd, dim=2))

    # scaled dot-product scores, (B, nh, T, T)
    scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_size))
    # lower-triangular mask: a position may only attend to itself and earlier positions
    causal = Tensor.ones(T, T).tril()
    scores = scores.masked_fill(causal == 0, float("-inf"))
    weights = scores.softmax()

    # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    out = weights @ v
    # put the heads back side by side: (B, nh, T, hs) -> (B, T, C)
    out = out.transpose(1, 2).view(B, T, C)
    return self.c_proj(out)


class MLP:
  """Two linear layers with a ReLU non-linearity in between."""

  def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
    self.fc1 = nn.Linear(input_size, hidden_size)
    self.fc2 = nn.Linear(hidden_size, output_size)

  def __call__(self, x: Tensor) -> Tensor:
    """Return fc2(relu(fc1(x)))."""
    hidden = self.fc1(x).relu()
    return self.fc2(hidden)


class TransformerBlock:
  """Pre-norm transformer block: attention and MLP sub-layers, each with a residual connection."""

  def __init__(self, embed_size: int, n_heads: int) -> None:
    self.ln1 = nn.LayerNorm(embed_size)
    self.attn = Attention(embed_size, n_heads)
    self.ln2 = nn.LayerNorm(embed_size)
    # the MLP widens to 4x the embedding size and projects back
    self.mlp = MLP(embed_size, 4 * embed_size, embed_size)

  def __call__(self, x: Tensor) -> Tensor:
    """Apply both sub-layers; layer norm is applied to each sub-layer's input."""
    attn_out = self.attn(self.ln1(x))
    x = x + attn_out
    mlp_out = self.mlp(self.ln2(x))
    return x + mlp_out


class Transformer:
  """Decoder-only language model over integer token ids.

  Token and position embeddings are summed, passed through `n_layers`
  TransformerBlocks and a final LayerNorm, then projected to vocabulary
  logits by a bias-free linear head.
  """

  def __init__(
    self,
    block_size: int,
    vocab_size: int,
    embed_size: int,
    n_layers: int,
    n_heads: int,
  ) -> None:
    """
    Args:
      block_size: maximum sequence length (number of position embeddings).
      vocab_size: number of distinct token ids.
      embed_size: width of the embeddings and of every block.
      n_layers: number of TransformerBlocks.
      n_heads: attention heads per block.
    """
    self.block_size = block_size
    self.vocab_size = vocab_size
    self.token_embed = nn.Embedding(vocab_size, embed_size)
    self.pos_embed = nn.Embedding(block_size, embed_size)
    self.h = [TransformerBlock(embed_size, n_heads) for _ in range(n_layers)]
    self.ln_f = nn.LayerNorm(embed_size)
    self.lm_head = nn.Linear(embed_size, vocab_size, bias=False)

  def __call__(self, x: Tensor) -> Tensor:
    """Return logits of shape (B, T, vocab_size) for token ids `x` of shape (B, T).

    T may be any length from 1 up to `block_size`; only the first T position
    embeddings are used.
    """
    assert len(x.shape) == 2, "expected a (batch, time) tensor of token ids"
    assert 1 <= x.shape[1] <= self.block_size, "sequence length must be in [1, block_size]"
    B, T = x.shape
    # the (T, C) position embeddings broadcast over the batch dimension
    embed = self.token_embed(x) + self.pos_embed(Tensor.arange(T))
    logits = self.lm_head(self.ln_f(embed.sequential(self.h)))
    return logits

  def loss(self, x: Tensor, y: Tensor) -> "tuple[Tensor, Tensor]":
    """Run the model on `x` and score it against the target token ids `y`.

    Returns:
      A `(logits, loss)` pair, where `loss` is the sparse categorical
      cross-entropy of the logits with respect to `y`.
    """
    logits = self(x)
    loss = logits.sparse_categorical_crossentropy(y)
    return logits, loss

  def generate(self, x: Tensor, n: int = 500) -> Tensor:
    """Extend the 1-D prompt `x` by `n` sampled tokens.

    The prompt may have any non-zero length: at every step only the last
    `block_size` tokens are fed to the model. The next token is drawn from the
    softmax of the final position's logits using NumPy's global RNG.

    Returns:
      A 1-D tensor holding the prompt followed by the `n` generated tokens.
    """
    assert len(x.shape) == 1 and x.shape[0] >= 1, "prompt must be a non-empty 1-D tensor"
    x = x.unsqueeze(0)  # add a batch dimension of size 1
    for _ in range(n):
      # crop the context to the most recent block_size tokens
      logits = self(x[:, -self.block_size :])
      p = logits[:, -1].softmax().squeeze(0)
      next_token = np.random.choice(self.vocab_size, p=p.numpy())
      x = x.cat(Tensor([[next_token]]), dim=1)
    return x.squeeze(0)

In [12]:
# Model hyperparameters (block_size is also used when sampling batches below).
block_size = 32  # tokens of context per sequence
embed_size = 128  # embedding width
n_heads = 4  # attention heads per block
transformer = Transformer(
  block_size=block_size,
  vocab_size=vocab_size,
  embed_size=embed_size,
  n_layers=4,
  n_heads=n_heads,
)
# Total number of elements across all model parameters.
sum(p.numel() for p in nn.state.get_parameters(transformer))

814080

In [13]:
# AdamW over every parameter of the model, with the optimizer's default settings.
optim = nn.optim.AdamW(nn.state.get_parameters(transformer))
batch_size = 128  # sequences per training / evaluation batch


@TinyJit
@Tensor.train()
def train_step() -> Tensor:
  """Run one optimisation step on a freshly sampled training batch and return its loss.

  Reads the notebook globals `optim`, `transformer`, `train_data`,
  `batch_size` and `block_size`.
  """
  optim.zero_grad()
  # the batch is sampled inside the jitted function
  x_samples, y_samples = get_batch(train_data, batch_size, block_size)
  _, loss = transformer.loss(x_samples, y_samples)  # the logits are not needed here
  loss.backward()
  optim.step()
  return loss

In [14]:
# Train for 2000 steps; report on the first step and then every 100 steps.
losses = []
for step in range(1, 2001):
  loss = train_step().item()
  losses.append(loss)
  if step != 1 and step % 100 != 0:
    continue  # not a reporting step
  # accuracy of the argmax next-token prediction on one random held-out batch
  with Tensor.inference_mode():
    x_samples, y_samples = get_batch(test_data, batch_size, block_size)
    predictions = transformer(x_samples).argmax(axis=-1)
    acc = (predictions == y_samples).mean().item()
    print(f"step {step}, loss {loss:.2f}, acc {acc*100.:.2f}%")

step 1, loss 4.31, acc 15.06%
step 100, loss 2.32, acc 32.30%
step 200, loss 2.02, acc 39.21%
step 300, loss 1.83, acc 42.92%
step 400, loss 1.69, acc 45.21%
step 500, loss 1.66, acc 44.24%
step 600, loss 1.63, acc 46.88%
step 700, loss 1.57, acc 47.88%
step 800, loss 1.53, acc 48.12%
step 900, loss 1.51, acc 48.88%
step 1000, loss 1.50, acc 52.22%
step 1100, loss 1.46, acc 51.15%
step 1200, loss 1.45, acc 49.88%
step 1300, loss 1.49, acc 51.54%
step 1400, loss 1.46, acc 51.05%
step 1500, loss 1.44, acc 50.32%
step 1600, loss 1.52, acc 51.51%
step 1700, loss 1.38, acc 50.81%
step 1800, loss 1.45, acc 51.27%
step 1900, loss 1.45, acc 51.73%
step 2000, loss 1.45, acc 51.54%


In [15]:
# Use the first block_size tokens of the corpus as the prompt and sample 1000 more.
text = decode(transformer.generate(data[:block_size], 1000).numpy())
# Wrap the prompt in ANSI colour escape codes so it stands out from the generated continuation.
print("\033[92m" + text[:block_size] + "\033[0m" + text[block_size:])

[92mFirst Citizen:
Before we proceed[0m itself; the that good see,
That's envish you, leave you pierce thine her bloody cometry.
Thou shalt ones grace; and skull what woo more sapneeling by and they do no worrup in these feelss ensurnes to keep thete alse up.
What I may so as Tybalt thou Hereford, 'twixe tells allsion to her of warrant:
Come, let to me my son, I am now;
And crow she wave as lips; and to the king,
Which said I coung stange all: hence his marriage, never will, could all these
Must be sudden shall staff once hath foolish'd.

HENRY BOLINGS:
Calary as your paunt great frot, you seek your prosition of your edwar:
So all, pupon me ear by your true,
The professial tyrange of Will'st;
The hands, that stay Marcius that gases hath Roman;
For me leave you? herefore I never 'tis beg and
Diving that amazes we are pale judge!
If it is namedhes and that will
Here are of the own creature out upon our spirite
To him with youtestle lap the widow,
In wead set me to love a shepherd;
And 