In [1]:
from tinygrad import Tensor, nn, TinyJit, Device
import numpy as np

# import os
# os.environ['JIT'] = '2'
Device.DEFAULT = "GPU"

In [2]:
corpus = open("shakespeare.txt").read()

In [3]:
print(len(corpus), "chars")
print(corpus[:1000])

1115394 chars
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst 

In [4]:
vocab = sorted(list(set(corpus)))
vocab_size = len(vocab)
vocab_size, "".join(vocab)

(65, "\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")

In [5]:
new_line_char = "\n"

In [6]:
encode = lambda s: [vocab.index(c) for c in s]
decode = lambda l: "".join([vocab[i] for i in l])

decode(encode("hello"))

'hello'

In [7]:
data = Tensor(encode(corpus))
split = int(0.9 * len(data))
train_data = data[:split]
test_data = data[split:]

In [8]:
def get_batch(data: Tensor, batch_size, block_size):
  indices = Tensor.randint((batch_size,), high=len(data) - block_size).reshape(
    (batch_size, 1)
  ) + Tensor.arange(block_size)
  return data[indices], data[indices + 1]

In [9]:
x, y = get_batch(train_data, batch_size=4, block_size=8)
x.shape, y.shape

((4, 8), (4, 8))

In [10]:
print([decode(row) for row in x.numpy()])
print([decode(row) for row in y.numpy()])

['YORK:\nLi', 'at doth ', 'and my s', 'els have']
['ORK:\nLit', 't doth b', 'nd my su', 'ls have ']


In [11]:
class Attention:
  def __init__(self, embed_size: int, n_heads: int, head_size: int) -> None:
    self.n_heads = n_heads
    self.head_size = head_size
    bound = 1 / (self.head_size**0.5)
    self.queries = Tensor.uniform(
      n_heads, embed_size, self.head_size, low=-bound, high=bound
    )
    self.keys = Tensor.uniform(
      n_heads, embed_size, self.head_size, low=-bound, high=bound
    )
    self.values = Tensor.uniform(
      n_heads, embed_size, self.head_size, low=-bound, high=bound
    )

  def __call__(self, x: Tensor) -> Tensor:
    B, T, C = x.shape

    x = x.unsqueeze(1).expand((B, self.n_heads, T, C))

    Q = x @ self.queries  # (B, n_heads, T, head_size)
    K = x @ self.keys  # (B, n_heads, T, head_size)
    V = x @ self.values  # (B, n_heads, T, head_size)
    attented_embeds = Tensor.scaled_dot_product_attention(
      Q, K, V, attn_mask=Tensor.ones((T, T), requires_grad=False).tril()
    )  # noqa: F401, (B, n_heads, T, head_size)
    concatenated_embeds = attented_embeds.reshape((B, T, self.n_heads * self.head_size))  # noqa: F401, (B, T, n_heads * head_size)
    return concatenated_embeds


class TransformerBlock:
  def __init__(self, embed_size: int, n_heads: int, head_size: int) -> None:
    self.attn = Attention(embed_size, n_heads, head_size)
    self.out_proj = nn.Linear(n_heads * head_size, embed_size)

  def __call__(self, x: Tensor) -> Tensor:
    return x.sequential([self.attn, self.out_proj, Tensor.gelu])


class Transformer:
  def __init__(
    self,
    block_size: int,
    vocab_size: int,
    embed_size: int,
    n_layers: int,
    n_heads: int,
    head_size: int,
  ) -> None:
    self.block_size = block_size
    self.vocab_size = vocab_size
    self.token_embed = nn.Embedding(vocab_size, embed_size)
    self.pos_embed = nn.Embedding(block_size, embed_size)
    self.h = [TransformerBlock(embed_size, n_heads, head_size) for _ in range(n_layers)]
    self.lm_head = nn.Linear(embed_size, vocab_size)

  def __call__(self, x: Tensor) -> Tensor:
    assert len(x.shape) == 2 and x.shape[1] == self.block_size
    B, T = x.shape
    embed = self.token_embed(x) + self.pos_embed(Tensor.arange(T))
    logits = embed.sequential(self.h + [self.lm_head])
    return logits

  def loss(self, x: Tensor, y: Tensor) -> Tensor:
    logits = self(x)
    loss = logits.sparse_categorical_crossentropy(y)
    return logits, loss

  def generate(self, x: Tensor, n: int = 500) -> Tensor:
    assert len(x.shape) == 1 and x.shape[0] == self.block_size
    x = x.unsqueeze(0)
    for _ in range(n):
      logits = self(x[:, -self.block_size :])
      p = logits[:, -1].softmax().squeeze(0)
      next_token = np.random.choice(self.vocab_size, p=p.numpy())
      x = x.cat(Tensor([[next_token]]), dim=1)
    return x.squeeze(0)

In [12]:
block_size = 16 # 32
embed_size = 512
n_heads = 4
head_size = embed_size // n_heads
transformer = Transformer(
  block_size=block_size,
  vocab_size=vocab_size,
  embed_size=embed_size,
  n_layers=2, # 4
  n_heads=n_heads,
  head_size=head_size,
)
sum(p.numel() for p in nn.state.get_parameters(transformer))

2172993

In [13]:
optim = nn.optim.AdamW(nn.state.get_parameters(transformer))
batch_size = 128


@TinyJit
@Tensor.train()
def train_step():
  optim.zero_grad()
  x_samples, y_samples = get_batch(train_data, batch_size, block_size)
  _, loss = transformer.loss(x_samples, y_samples)
  loss.backward()
  optim.step()
  return loss

In [14]:
losses = []
for step in range(1, 2501):
  loss = train_step().item()
  losses.append(loss)
  if step == 1 or step % 250 == 0:
    with Tensor.inference_mode():
      x_samples, y_samples = get_batch(test_data, batch_size, block_size)
      acc = (transformer(x_samples).argmax(axis=-1) == y_samples).mean().item()
      print(f"step {step}, loss {loss:.2f}, acc {acc*100.:.2f}%")

step 1, loss 4.17, acc 3.66%
step 250, loss 1.40, acc 46.53%
step 500, loss 1.10, acc 48.14%
step 750, loss 0.99, acc 56.05%
step 1000, loss 0.83, acc 60.16%
step 1250, loss 0.80, acc 59.91%
step 1500, loss 0.79, acc 59.86%
step 1750, loss 0.79, acc 59.91%
step 2000, loss 0.78, acc 59.81%
step 2250, loss 0.77, acc 60.35%
step 2500, loss 0.77, acc 60.35%


In [15]:
text = decode(transformer.generate(data[:block_size]).numpy())
print("\033[92m" + text[:block_size] + "\033[0m" + text[block_size:])

[92mFirst Citizen:
B[0mut isinin far my ving has bune aremy bagond sos tame! the of nd, mand's afout qurad bearlle youcantem! I his rasiases and that thadsof win nor soonmeruproin tor, ang wh m this sagess svand's.
Then you pow thand.

PARTUS:
HTRY say me dearscous sad; ied, forathem wous.

MESN RY:
Fit your abtiexs slnge meads notlat thou see makiln thys his tips

FORE:
Cof berk titer otat ruse that ponthourstoru d:
gourveadoth mout king cond
besty hourinstses soveprimaststread tise my thowseveve with apawh to gughll
