In [1]:
from tinygrad import Tensor, nn, TinyJit, Device
import numpy as np

In [2]:
# Set tinygrad's default device to the backend named "GPU".
# NOTE(review): the backend name is hard-coded, so this notebook presumably
# only runs where that backend is available — confirm before sharing.
Device.DEFAULT = "GPU"

In [3]:
# Load the whole training corpus into memory as a single string.
# The context manager closes the file handle deterministically (a bare
# open(...).read() leaves that to the garbage collector), and the explicit
# encoding keeps the decoded text independent of the platform's locale default.
with open("shakespeare.txt", encoding="utf-8") as corpus_file:
  corpus = corpus_file.read()

In [4]:
# Sanity-check the load: total character count, followed by the first 1000
# characters so the layout of the text can be inspected.
print(len(corpus), "chars")
print(corpus[:1000])

1115394 chars
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst 

In [5]:
# Character-level vocabulary: every distinct character of the corpus, sorted
# so that a character's position in the list (its token id) is deterministic.
vocab = sorted(set(corpus))
vocab_size = len(vocab)
# Bare last expression: display the vocabulary size and the characters themselves.
vocab_size, "".join(vocab)

(65, "\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")

In [6]:
# Named constant for the newline character.
# NOTE(review): no other cell visible in this notebook references it —
# confirm it is still needed.
new_line_char = "\n"

In [7]:
# Reverse lookup table built once from `vocab`, so encoding a character is a
# dict lookup instead of a linear `vocab.index` scan per character (which adds
# up when the entire corpus is encoded). Re-run this cell if `vocab` changes.
_char_to_id = {ch: i for i, ch in enumerate(vocab)}

def encode(s):
  """Map a string to the list of integer token ids (positions in `vocab`).

  Raises:
    ValueError: if `s` contains a character that is not in `vocab`
      (the same exception type the previous `vocab.index` lookup raised).
  """
  try:
    return [_char_to_id[ch] for ch in s]
  except KeyError as err:
    raise ValueError(f"{err.args[0]!r} is not in vocab") from None

def decode(l):
  """Map an iterable of token ids back to the string they spell."""
  return "".join(vocab[i] for i in l)

# Round-trip check: the cell should display the input string unchanged.
decode(encode("hello"))

'hello'

In [8]:
# Encode the full corpus once, then hold out the final 10% of the tokens for
# evaluation. The split is contiguous: the first 90% trains, the tail tests.
data = Tensor(encode(corpus))
split = int(len(data) * 0.9)
train_data, test_data = data[:split], data[split:]

In [9]:
def get_batch(data: Tensor, batch_size, block_size):
  """Sample `batch_size` random windows of `block_size` consecutive tokens.

  Args:
    data: 1-D tensor of token ids to sample from.
    batch_size: number of windows to draw.
    block_size: number of tokens in each window.

  Returns:
    A pair `(inputs, targets)`, each of shape `(batch_size, block_size)`.
    `targets` is the same window advanced by one position, i.e. the
    next-token label for every input position.
  """
  # Start offsets are drawn below `len(data) - block_size`, which keeps the
  # one-step-shifted target window inside `data` as well.
  n_starts = len(data) - block_size
  starts = Tensor.randint((batch_size,), high=n_starts).reshape((batch_size, 1))
  # Broadcasting a column of starts against a row of offsets produces one row
  # of consecutive positions per sample.
  positions = starts + Tensor.arange(block_size)
  inputs = data[positions]
  targets = data[positions + 1]
  return inputs, targets

In [10]:
# Draw a tiny batch to check the sampler; the bare last expression displays
# the shapes of the input and target tensors.
x, y = get_batch(train_data, batch_size=4, block_size=8)
x.shape, y.shape

((4, 8), (4, 8))

In [11]:
# Decode every row back to text. Each target row should read as the matching
# input window advanced by one character.
print([decode(row) for row in x.numpy()])
print([decode(row) for row in y.numpy()])

['ather ha', 'e in arm', 'if mysel', 't seeing']
['ther had', ' in arms', 'f myself', ' seeing,']


In [12]:
class Attention:
  """A single head of causal (masked) scaled dot-product self-attention."""

  def __init__(self, embed_size: int, head_size: int) -> None:
    """
    Args:
      embed_size: size of the last axis of the input.
      head_size: size of the query/key/value projections, and of the output.
    """
    self.head_size = head_size
    # Three bias-free projections of the same input.
    self.query = nn.Linear(embed_size, head_size, bias=False)
    self.key = nn.Linear(embed_size, head_size, bias=False)
    self.value = nn.Linear(embed_size, head_size, bias=False)

  def __call__(self, x: Tensor) -> Tensor:
    """Attend over the time axis of `x`.

    Args:
      x: 3-D tensor, unpacked as (batch, time, channels).

    Returns:
      Tensor of shape (batch, time, head_size).
    """
    # The unpacking also rejects inputs that are not 3-D; only the time
    # dimension is needed below.
    _, seq_len, _ = x.shape

    queries, keys = self.query(x), self.key(x)
    # Pairwise query/key similarity, divided by sqrt(head_size).
    scores: Tensor = (queries @ keys.transpose(-2, -1)) / (self.head_size**0.5)

    # Lower-triangular mask: position t may attend to positions <= t only.
    # Everything above the diagonal becomes -inf before the softmax.
    causal = Tensor.ones((seq_len, seq_len), requires_grad=False).tril()
    # softmax() is called with its default axis
    # (NOTE(review): assumed to be the last, i.e. key, axis — confirm in tinygrad docs).
    weights = scores.masked_fill(causal == 0, float("-inf")).softmax()

    return weights @ self.value(x)

class Transformer:
  """Character-level language model: token embedding -> one attention head -> linear head.

  Only the three layers created in `__init__` are applied. No positional
  embedding is added in this class.
  """

  def __init__(self, block_size: int, vocab_size: int, embed_size: int) -> None:
    """
    Args:
      block_size: context length; `generate` feeds at most this many trailing
        tokens to the model.
      vocab_size: number of distinct token ids (embedding rows and logit columns).
      embed_size: width of the token embedding and of the attention head.
    """
    self.block_size = block_size
    self.vocab_size = vocab_size
    self.token_embed = nn.Embedding(vocab_size, embed_size)
    # A single head whose head size equals the embedding width.
    self.attn = Attention(embed_size, embed_size)
    self.lm_head = nn.Linear(embed_size, vocab_size)

  def __call__(self, x: Tensor) -> Tensor:
    """Return logits of shape (batch, time, vocab_size) for token ids `x` of shape (batch, time)."""
    logits = x.sequential([self.token_embed, self.attn, self.lm_head])
    return logits

  def loss(self, x: Tensor, y: Tensor) -> "tuple[Tensor, Tensor]":
    """Run the model on `x` and score the result against the target ids `y`.

    Returns:
      A `(logits, loss)` pair, where `loss` is the sparse categorical
      cross-entropy between the logits and `y`.
    """
    logits = self(x)
    loss = logits.sparse_categorical_crossentropy(y)
    return logits, loss

  def generate(self, x: Tensor, n: int = 500) -> Tensor:
    """Sample `n` tokens autoregressively, continuing the prompt `x`.

    Args:
      x: 1-D tensor holding exactly `block_size` prompt token ids.
      n: number of tokens to sample.

    Returns:
      1-D tensor containing the prompt followed by the `n` sampled tokens.

    Raises:
      AssertionError: if `x` is not 1-D with exactly `block_size` elements.
    """
    assert len(x.shape) == 1 and x.shape[0] == self.block_size, \
      "prompt must be a 1-D tensor of exactly block_size token ids"
    x = x.unsqueeze(0)  # add a batch axis of size 1
    for _ in range(n):
      # Only the trailing block_size tokens are fed back into the model.
      logits = self(x[:, -self.block_size:])
      # Distribution over the next token, taken from the last time step.
      p = logits[:, -1, :].softmax().squeeze(0)
      # np.random.choice returns a NumPy integer scalar. Convert it to a
      # built-in int so the appended tensor is built from a plain Python int
      # rather than a NumPy scalar type.
      next_token = int(np.random.choice(self.vocab_size, p=p.numpy()))
      x = x.cat(Tensor([[next_token]]), dim=1)
    return x.squeeze(0)

In [13]:
# Model hyper-parameters.
block_size = 256
embed_size = 512
# NOTE(review): n_heads and head_size are computed here but are not passed to
# Transformer below, which builds one Attention head of width embed_size.
# Confirm whether multi-head attention was intended.
n_heads = 4
head_size = embed_size // n_heads
transformer = Transformer(block_size=block_size, vocab_size=vocab_size, embed_size=embed_size)
# Total element count over all tensors returned by nn.state.get_parameters.
sum(p.numel() for p in nn.state.get_parameters(transformer))

853057

In [14]:
# AdamW over every model parameter; no hyper-parameters are passed, so the
# optimizer's defaults apply.
optim = nn.optim.AdamW(nn.state.get_parameters(transformer))
# Number of sequences per batch (used for both training and evaluation batches).
batch_size = 128


# Wrapped by tinygrad's TinyJit and run inside the Tensor.train() context.
# NOTE(review): the exact JIT capture/replay semantics are defined by tinygrad —
# see its documentation before changing the statement order in this function.
@TinyJit
@Tensor.train()
def train_step() -> Tensor:
  """Run one optimization step on a freshly sampled training batch.

  Reads the notebook globals `optim`, `transformer`, `train_data`,
  `batch_size` and `block_size`.

  Returns:
    The loss tensor computed on the sampled batch.
  """
  optim.zero_grad()
  x_samples, y_samples = get_batch(train_data, batch_size, block_size)
  # The logits are not needed during training; only the loss is used.
  _, loss = transformer.loss(x_samples, y_samples)
  loss.backward()
  optim.step()
  return loss

In [15]:
# Train for 3000 steps, recording the training loss of every step.
losses = []
for step in range(1, 3001):
  loss = train_step().item()
  losses.append(loss)
  # Report on the first step and then every 500 steps.
  if step == 1 or step % 500 == 0:
    with Tensor.inference_mode():
      # Accuracy is measured on a single random batch from the held-out data:
      # the fraction of positions where the arg-max logit equals the target.
      x_samples, y_samples = get_batch(test_data, batch_size, block_size)
      acc = (transformer(x_samples).argmax(axis=-1) == y_samples).mean().item()
      # `loss` here is the training loss of the current step, not a test loss.
      print(f"step {step}, loss {loss:.2f}, acc {acc*100.:.2f}%")

step 1, loss 4.17, acc 9.31%
step 500, loss 2.51, acc 26.10%
step 1000, loss 2.47, acc 26.66%
step 1500, loss 2.47, acc 26.96%
step 2000, loss 2.48, acc 27.00%
step 2500, loss 2.45, acc 26.72%
step 3000, loss 2.43, acc 26.90%


In [16]:
# Use the first block_size tokens of the corpus as the prompt and let the
# model continue it (generate's default of 500 new tokens).
text = decode(transformer.generate(data[:block_size]).numpy())
# Wrap the prompt in ANSI colour escape codes ("\033[92m" ... "\033[0m") so it
# can be told apart from the generated continuation that follows it.
print("\033[92m" + text[:block_size] + "\033[0m" + text[block_size:])

[92mFirst Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
[0m
Clende no hat all ad ide y d gie,
Dot un's br; iar.
RDoussiro ice.
MEed Er GOUSdein my Lat is beeioderr,
WANoas ke, VO: he nup; y, n Qt IUENI'smoundengat yofilo.
Or aipof trkeshis t yorey wasseeal towir'sanslsorsomas hmalie, t t ou ck ceref gid Vis t be ancedang ak to sallemotondeis gothe's, O:
AN y su than,
S:
GS:
ABELUDRUCht pre t t
Iour t welur mmy?
Weng mf RGRtith o this hounshil ft.
A,
e I wh whor st ary othore s. sin leranan the flen bequlath;
A mano myove thend prdistorusinththou hes g p
