In [1]:
from tinygrad import Tensor, nn, TinyJit
import numpy as np

In [2]:
corpus = open("shakespeare.txt").read()

In [3]:
print(len(corpus), "chars")
print(corpus[:1000])

1115394 chars
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst 

In [4]:
vocab = sorted(list(set(corpus)))
vocab_size = len(vocab)
vocab_size, "".join(vocab)

(65, "\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")

In [5]:
new_line_char = "\n"

In [6]:
encode = lambda s: [vocab.index(c) for c in s]
decode = lambda l: "".join([vocab[i] for i in l])

decode(encode("hello"))

'hello'

In [7]:
data = Tensor(encode(corpus))
split = int(0.9 * len(data))
train_data = data[:split]
test_data = data[split:]

In [8]:
def get_batch(data: Tensor, batch_size, block_size):
  indices = Tensor.randint((batch_size,), high=len(data) - block_size).reshape(
    (batch_size, 1)
  ) + Tensor.arange(block_size)
  return data[indices], data[indices + 1]

In [9]:
x, y = get_batch(train_data, batch_size=4, block_size=8)
x.shape, y.shape

((4, 8), (4, 8))

In [10]:
print([decode(row) for row in x.numpy()])
print([decode(row) for row in y.numpy()])

['ther boo', 'thy nobl', "'ll swea", 'ht in ju']
['her book', 'hy noble', 'll swear', 't in jus']


In [11]:
class Bigram:
  def __init__(self, vocab_size: int):
    assert vocab_size >= 1
    self.vocab_size = vocab_size
    self.embed = nn.Embedding(vocab_size, vocab_size)

  def __call__(self, x: Tensor) -> Tensor:
    assert len(x.shape) == 1
    return self.embed(x.reshape((-1, 1))).squeeze(1)

  def loss(self, logits: Tensor, y: Tensor) -> Tensor:
    assert (
      len(logits.shape) == 2
      and len(y.shape) == 1
      and logits.shape[0] == y.shape[0]
      and logits.shape[1] == self.vocab_size
    )
    return logits.sparse_categorical_crossentropy(y)

  def generate(self, x: Tensor, max_len=50):
    with Tensor.inference_mode():
      for _ in range(max_len):
        prev_x = x[-1:]
        p = self(prev_x).squeeze().softmax().numpy()
        next_x = Tensor([np.random.choice(vocab_size, p=p)])
        x = x.cat(next_x)
    return x

In [12]:
bigram = Bigram(vocab_size)

In [13]:
optim = nn.optim.AdamW(nn.state.get_parameters(bigram))
batch_size = 128


@TinyJit
@Tensor.train()
def train_step():
  optim.zero_grad()
  x_samples, y_samples = get_batch(train_data, batch_size, block_size=1)
  loss = bigram.loss(bigram(x_samples.reshape(-1)), y_samples.reshape(-1)).backward()
  optim.step()
  return loss

In [14]:
losses = []
for step in range(1, 10001):
  loss = train_step().item()
  losses.append(loss)
  if step == 1 or step % 1000 == 0:
    with Tensor.inference_mode():
      acc = (bigram(test_data[:-1]).argmax(axis=1) == test_data[1:]).mean().item()
      print(f"step {step}, loss {loss:.2f}, acc {acc*100.:.2f}%")

step 1, loss 4.19, acc 1.53%
step 1000, loss 3.40, acc 24.57%
step 2000, loss 3.03, acc 26.38%
step 3000, loss 2.70, acc 26.54%
step 4000, loss 2.72, acc 27.20%
step 5000, loss 2.52, acc 27.05%
step 6000, loss 2.62, acc 27.08%
step 7000, loss 2.54, acc 27.06%
step 8000, loss 2.62, acc 27.06%
step 9000, loss 2.41, acc 27.01%
step 10000, loss 2.32, acc 27.01%


In [18]:
print(
  decode(bigram.generate(Tensor([vocab.index(new_line_char)]), max_len=500).numpy())
)


Thy d hatlises acearJUn weanegD aXHenossbA wareh ias:
IZigre thy ber ugm askem m'simucjceres agery, antsaco avor.
MELYo;
Fr thaveELI waplat ke buthefost pd acoghes heithind,

che f Ye canco groues! aien stirsou ts tevenend themeresoru? Gupgsor:
AUKE hat, tS:
OLand be INCound, s ke dond to'd.
RIULayend nor seFath cos upp'stheae onve caker?
se t wous t.
HI t iooravathect wo,
Inggicul y BRDYod
I:

qul pitonarreicMan'suaumy nt th, thangar oome rta we I's t teathe cinr desha
Moake, a r. hos;
I iongha


In [16]:
B, T, C = 1000, 100, 200
x = Tensor.randint((B, T, C))


def using_cumsum():
  a = x.cumsum(axis=1)
  b = a / Tensor.arange(1, T + 1).reshape((T, 1))


def using_matmul():
  a = Tensor.ones((T, T)).tril() @ x
  b = a / Tensor.arange(1, T + 1).reshape((T, 1))


def using_softmax():
  a = Tensor.ones((T, T)).tril().where(0, float("-inf")).softmax()
  b = a @ x

In [17]:
import timeit

(
  np.mean(timeit.repeat(using_cumsum, repeat=10000, number=1)) * 1000,
  np.mean(timeit.repeat(using_matmul, repeat=10000, number=1)) * 1000,
  np.mean(timeit.repeat(using_softmax, repeat=10000, number=1)) * 1000,
)

(np.float64(0.38381759660478565),
 np.float64(0.4886868394010889),
 np.float64(0.44910617659606944))