In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from tinygrad import Tensor, nn, TinyJit
import numpy as np
import random
import math

In [3]:
def build_dataset():
  data = []
  for i in range(100):
    for j in range(100):
      s = i + j
      data.append([i // 10, i % 10, j // 10, j % 10, s // 100, (s // 10) % 10, s % 10])
  random.shuffle(data)
  data = Tensor(data)
  X_train = data[:8000, :-1]
  Y_train = data[:8000, 1:]
  X_test = data[8000:, :-1]
  Y_test = data[8000:, 1:]
  return X_train, Y_train, X_test, Y_test

In [4]:
X_train, Y_train, X_test, Y_test = build_dataset()

X_train, Y_train, X_test, Y_test

(<Tensor <LB METAL (8000, 6) int ShapeTracker(views=(View(shape=(8000, 6), strides=(7, 1), offset=0, mask=None, contiguous=False),))> on METAL with grad None>,
 <Tensor <LB METAL (8000, 6) int ShapeTracker(views=(View(shape=(8000, 6), strides=(7, 1), offset=1, mask=None, contiguous=False),))> on METAL with grad None>,
 <Tensor <LB METAL (2000, 6) int ShapeTracker(views=(View(shape=(2000, 6), strides=(7, 1), offset=56000, mask=None, contiguous=False),))> on METAL with grad None>,
 <Tensor <LB METAL (2000, 6) int ShapeTracker(views=(View(shape=(2000, 6), strides=(7, 1), offset=56001, mask=None, contiguous=False),))> on METAL with grad None>)

In [5]:
class Attention:
  def __init__(self, n_heads, embed_size) -> None:
    self.n_heads = n_heads
    self.internal_embed_size = embed_size * 2
    bound = 1 / math.sqrt(self.internal_embed_size)
    self.keys = Tensor.uniform(
      n_heads, embed_size, self.internal_embed_size, low=-bound, high=bound
    )
    self.queries = Tensor.uniform(
      n_heads, embed_size, self.internal_embed_size, low=-bound, high=bound
    )
    self.values = Tensor.uniform(
      n_heads, embed_size, self.internal_embed_size, low=-bound, high=bound
    )
    self.linear = Tensor.uniform(
      n_heads * self.internal_embed_size, embed_size, low=-bound, high=bound
    )

  def __call__(self, x: Tensor) -> Tensor:
    B, T, C = x.shape

    x = x.unsqueeze(1).expand((B, self.n_heads, T, C))

    K = x @ self.keys
    Q = x @ self.queries
    V = x @ self.values

    # dot_attn = Q @ K.transpose(-2, -1)
    # scaled_dot_attn = dot_attn / math.sqrt(self.attn_embed_size)
    # masked_scaled_dot_attn = scaled_dot_attn#.tril().where(scaled_dot_attn, float("-inf"))
    # attn_scores = masked_scaled_dot_attn.softmax()

    # ret = attn_scores @ V
    # return ret

    ret = Tensor.scaled_dot_product_attention(
      K, Q, V, attn_mask=Tensor.ones((T, T)).tril()
    )
    ret = ret.reshape((B, T, self.n_heads * self.internal_embed_size))
    ret = ret @ self.linear
    ret = ret.gelu()

    return ret


class Transformer:
  def __init__(self, vocab_size, embed_size, n_layers, n_heads) -> None:
    self.token_embed = nn.Embedding(vocab_size, embed_size)
    self.h = [Attention(n_heads, embed_size) for _ in range(n_layers)]
    self.linear = nn.Linear(embed_size, vocab_size)

  def __call__(self, x: Tensor, y: Tensor = None) -> Tensor:
    logits = x.sequential([self.token_embed, *self.h, self.linear])
    logits = logits[:, -1 if y is None else Tensor.arange(y.shape[1]), :]
    loss = y is None or logits.sparse_categorical_crossentropy(y)
    return logits, loss

In [6]:
model = Transformer(vocab_size=10, embed_size=128, n_layers=2, n_heads=2)
sum(p.numel() for p in nn.state.get_parameters(model))

526858

In [7]:
optim = nn.optim.AdamW(nn.state.get_parameters(model))
batch_size = 128


@TinyJit
@Tensor.train()
def train_step():
  optim.zero_grad()
  samples = Tensor.randint(batch_size, high=X_train.shape[0])
  X_samples, Y_samples = X_train[samples], Y_train[samples]
  _, loss = model(X_samples, Y_samples)
  loss.backward()
  optim.step()
  return loss

In [8]:
for step in range(1, 3001):
  loss = train_step()
  if step == 1 or step % 250 == 0:
    with Tensor.inference_mode():
      Y_hat, _ = model(X_test)
      acc = (Y_hat.argmax(axis=-1) == Y_test[:, -1]).mean().item()
      print(f"step {step}, loss {loss.item():.2f}, acc {acc*100.:.2f}%")

step 1, loss 2.30, acc 10.45%
step 250, loss 0.66, acc 18.20%
step 500, loss 0.44, acc 33.05%
step 750, loss 0.22, acc 61.00%
step 1000, loss 0.06, acc 95.20%
step 1250, loss 0.03, acc 97.30%
step 1500, loss 0.01, acc 98.55%
step 1750, loss 0.02, acc 98.40%
step 2000, loss 0.02, acc 97.35%
step 2250, loss 0.00, acc 99.40%
step 2500, loss 0.01, acc 98.85%
step 2750, loss 0.02, acc 98.65%
step 3000, loss 0.02, acc 99.35%


In [9]:
predictions, _ = model(X_test)
true_labels = Y_test[:, -1]
incorrect_mask = predictions.argmax(axis=-1) != true_labels
incorrect_indices = np.where(incorrect_mask.numpy())[0].tolist()

for index in incorrect_indices:
  print(
    f"Example: {X_test[index].tolist()}, Prediction: {predictions[index].argmax().item()}, True: {true_labels[index].item()}"
  )

Example: [8, 9, 1, 0, 0, 9], Prediction: 8, True: 9
Example: [2, 4, 0, 3, 0, 2], Prediction: 8, True: 7
Example: [2, 8, 0, 4, 0, 3], Prediction: 0, True: 2
Example: [6, 7, 0, 8, 0, 7], Prediction: 4, True: 5
Example: [0, 6, 0, 0, 0, 0], Prediction: 2, True: 6
Example: [2, 0, 0, 6, 0, 2], Prediction: 2, True: 6
Example: [1, 0, 8, 9, 0, 9], Prediction: 8, True: 9
Example: [4, 8, 0, 8, 0, 5], Prediction: 2, True: 6
Example: [3, 8, 3, 8, 0, 7], Prediction: 5, True: 6
Example: [8, 5, 0, 5, 0, 9], Prediction: 3, True: 0
Example: [5, 2, 0, 6, 0, 5], Prediction: 7, True: 8
Example: [5, 7, 0, 6, 0, 6], Prediction: 2, True: 3
Example: [4, 6, 0, 3, 0, 4], Prediction: 0, True: 9
