In [1]:
from tinygrad import Tensor, nn, TinyJit, Device
import numpy as np

In [2]:
# import os
# os.environ['JIT'] = '2'
Device.DEFAULT = "GPU"

In [3]:
corpus = open("shakespeare.txt").read()

In [4]:
print(len(corpus), "chars")
print(corpus[:1000])

1115394 chars
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst 

In [5]:
vocab = sorted(list(set(corpus)))
vocab_size = len(vocab)
vocab_size, "".join(vocab)

(65, "\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")

In [6]:
new_line_char = "\n"

In [7]:
encode = lambda s: [vocab.index(c) for c in s]
decode = lambda l: "".join([vocab[i] for i in l])

decode(encode("hello"))

'hello'

In [8]:
data = Tensor(encode(corpus))
split = int(0.9 * len(data))
train_data = data[:split]
test_data = data[split:]

In [9]:
def get_batch(data: Tensor, batch_size, block_size):
  indices = Tensor.randint((batch_size,), high=len(data) - block_size).reshape(
    (batch_size, 1)
  ) + Tensor.arange(block_size)
  return data[indices], data[indices + 1]

In [10]:
x, y = get_batch(train_data, batch_size=4, block_size=8)
x.shape, y.shape

((4, 8), (4, 8))

In [11]:
print([decode(row) for row in x.numpy()])
print([decode(row) for row in y.numpy()])

[' Bolingb', 'ks. What', 'IUS:\nI t', ' house,\n']
['Bolingbr', "s. What'", 'US:\nI te', 'house,\nA']


In [12]:
import math
class CausalSelfAttention:
  def __init__(self, n_embd: int, n_head:int, block_size:int):
    assert n_embd % n_head == 0
    # key, query, value projections for all heads, but in a batch
    self.c_attn = nn.Linear(n_embd, 3 * n_embd)
    # output projection
    self.c_proj = nn.Linear(n_embd, n_embd)
    # regularization
    self.n_head = n_head
    self.n_embd = n_embd
    # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
    self.bias = Tensor.ones(1, 1, block_size, block_size).tril()
    self.bias.requires_grad = False

  def __call__(self, x:Tensor):
    B, T, C = x.shape
    qkv = self.c_attn(x)
    q, k, v = qkv.split(self.n_embd, dim=2)
    k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
    q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
    v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

    # manual implementation of attention
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
    att = att.softmax()
    y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    y = y.transpose(1, 2).view(B, T, C) # re-assemble all head outputs side by side
    # output projection
    y = self.c_proj(y)
    return y

In [13]:
class Attention:
  def __init__(self, embed_size: int, n_heads: int, head_size: int) -> None:
    self.n_heads = n_heads
    self.head_size = head_size
    bound = 1 / (self.head_size**0.5)
    self.queries = Tensor.uniform(
      n_heads, embed_size, self.head_size, low=-bound, high=bound
    )
    self.keys = Tensor.uniform(
      n_heads, embed_size, self.head_size, low=-bound, high=bound
    )
    self.values = Tensor.uniform(
      n_heads, embed_size, self.head_size, low=-bound, high=bound
    )
    self.proj = nn.Linear(n_heads * head_size, embed_size)

  def __call__(self, x: Tensor) -> Tensor:
    B, T, C = x.shape

    x = x.unsqueeze(1).expand((B, self.n_heads, T, C))

    Q = x @ self.queries  # (B, n_heads, T, head_size)
    K = x @ self.keys  # (B, n_heads, T, head_size)
    V = x @ self.values  # (B, n_heads, T, head_size)
    mask = Tensor.ones((T, T), requires_grad=False).tril()
    attented_embeds = Tensor.scaled_dot_product_attention(
      Q, K, V, attn_mask=mask
    )  # (B, n_heads, T, head_size)
    concatenated_embeds = attented_embeds.reshape(
      (B, T, self.n_heads * self.head_size)
    )  # (B, T, n_heads * head_size)
    out = self.proj(concatenated_embeds)  # (B, T, C)
    return out

In [14]:
class MLP:
  def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
    self.fc1 = nn.Linear(input_size, hidden_size)
    self.fc2 = nn.Linear(hidden_size, output_size)

  def __call__(self, x: Tensor) -> Tensor:
    return x.sequential([self.fc1, Tensor.gelu, self.fc2])


class TransformerBlock:
  def __init__(self, embed_size: int, n_heads: int, head_size: int) -> None:
  # def __init__(self, embed_size: int, n_heads: int, block_size: int) -> None:
    self.ln1 = nn.LayerNorm(embed_size)
    self.attn = Attention(embed_size, n_heads, head_size)
    # self.attn = CausalSelfAttention(embed_size, n_heads, block_size)
    self.ln2 = nn.LayerNorm(embed_size)
    self.mlp = MLP(embed_size, 4 * embed_size, embed_size)

  def __call__(self, x: Tensor) -> Tensor:
    x = x + self.attn(self.ln1(x))
    x = x + self.mlp(self.ln2(x))
    return x


class Transformer:
  def __init__(
    self,
    block_size: int,
    vocab_size: int,
    embed_size: int,
    n_layers: int,
    n_heads: int,
    head_size: int,
  ) -> None:
    self.block_size = block_size
    self.vocab_size = vocab_size
    self.token_embed = nn.Embedding(vocab_size, embed_size)
    self.pos_embed = nn.Embedding(block_size, embed_size)
    self.h = [TransformerBlock(embed_size, n_heads, head_size) for _ in range(n_layers)]
    # self.h = [TransformerBlock(embed_size, n_heads, block_size) for _ in range(n_layers)]
    self.lm_head = nn.Linear(embed_size, vocab_size)

  def __call__(self, x: Tensor) -> Tensor:
    assert len(x.shape) == 2 and x.shape[1] == self.block_size
    B, T = x.shape
    embed = self.token_embed(x) + self.pos_embed(Tensor.arange(T))
    logits = embed.sequential(self.h + [self.lm_head])
    return logits

  def loss(self, x: Tensor, y: Tensor) -> Tensor:
    logits = self(x)
    loss = logits.sparse_categorical_crossentropy(y)
    return logits, loss

  def generate(self, x: Tensor, n: int = 500) -> Tensor:
    assert len(x.shape) == 1 and x.shape[0] == self.block_size
    x = x.unsqueeze(0)
    for _ in range(n):
      logits = self(x[:, -self.block_size :])
      p = logits[:, -1].softmax().squeeze(0)
      next_token = np.random.choice(self.vocab_size, p=p.numpy())
      x = x.cat(Tensor([[next_token]]), dim=1)
    return x.squeeze(0)

In [15]:
block_size = 32  # 16
embed_size = 512
n_heads = 4  # 4
head_size = embed_size // n_heads
transformer = Transformer(
  block_size=block_size,
  vocab_size=vocab_size,
  embed_size=embed_size,
  n_layers=8,  # 4
  n_heads=n_heads,
  head_size=head_size,
)
sum(p.numel() for p in nn.state.get_parameters(transformer))

25289793

In [16]:
optim = nn.optim.AdamW(nn.state.get_parameters(transformer))
batch_size = 16  # 128


@TinyJit
@Tensor.train()
def train_step():
  optim.zero_grad()
  x_samples, y_samples = get_batch(train_data, batch_size, block_size)
  _, loss = transformer.loss(x_samples, y_samples)
  loss.backward()
  optim.step()
  return loss

In [17]:
losses = []
for step in range(1, 1001):
  loss = train_step().item()
  losses.append(loss)
  if step == 1 or step % 100 == 0:
    with Tensor.inference_mode():
      x_samples, y_samples = get_batch(test_data, batch_size, block_size)
      acc = (transformer(x_samples).argmax(axis=-1) == y_samples).mean().item()
      print(f"step {step}, loss {loss:.2f}, acc {acc*100.:.2f}%")

step 1, loss 4.34, acc 13.09%
step 100, loss 0.80, acc 85.35%
step 200, loss 0.13, acc 97.66%
step 300, loss 0.09, acc 98.24%
step 400, loss 0.08, acc 97.66%
step 500, loss 0.10, acc 97.46%
step 600, loss 0.09, acc 98.05%
step 700, loss 0.09, acc 97.27%
step 800, loss 0.10, acc 98.05%
step 900, loss 0.10, acc 97.46%
step 1000, loss 0.13, acc 97.66%


In [18]:
text = decode(transformer.generate(data[:block_size], n=1000).numpy())
print("\033[92m" + text[:block_size] + "\033[0m" + text[block_size:])

[92mFirst Citizen:
Before we proceed[0m,
Ahe .e my heo ? o me s y th, fhste?

Aw hef nare cher hald Ye s ic ded w your P the ro d verthervel cd I dovelf d the hithered d  coer der der hio hhin he kis des sts drow heced deecker be ste sed-l ded ces he
A?OA,ONOOKIO:
Asucl
A the IOcea uithe ke thowwowwh kaed eravetherthe,
WheUo fre I thas st he the s nof se dor certher

WIN:
W chivathe the he pe r ver ghougherald s theathe ce the nond dd,
 he vieele be herd b ath fooy frd fo  no
Wheds.
ARK:AR3OAA I I:
A in wutcer ditht td y herd he hest we
A mevert hu
AncHere ait lfd wr fso-f bkers,
B hin kerbe di s? se ceul ver Bal thir shers cer ther
Thel fid ce pe bey than heowhem tow bt bichcer fove fececthe kerrd cerde of  nove shel Jde
WONGA:
A mumy waow f m ul d w  nce ther  haot detI

YowI
:A s rond ha herle churd bee n in heu arp dr he sheng,
W heace py hhel her I cithe kel
A dnve
An they sis sovaRI:
A d beAHHELLNABO:: thia mer he l ch fo- fke wh.

ANHAABBABBBABOINBBBABBBBBBBBBABBOA:OOABAOBBBEB