In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from tinygrad import Tensor, nn, TinyJit
import numpy as np

In [3]:
def build_dataset():
  data = []
  for i in range(100):
    for j in range(100):
      s = i + j
      data.append([i // 10, i % 10, j // 10, j % 10, s // 100, (s // 10) % 10, s % 10])
  np.random.shuffle(data)
  data = Tensor(data)
  X_train = data[:8000, :-1]
  Y_train = data[:8000, 1:]
  X_test = data[8000:, :-1]
  Y_test = data[8000:, 1:]
  return X_train, Y_train, X_test, Y_test

In [4]:
X_train, Y_train, X_test, Y_test = build_dataset()

X_train, Y_train, X_test, Y_test

(<Tensor <LB METAL (8000, 6) int ShapeTracker(views=(View(shape=(8000, 6), strides=(7, 1), offset=0, mask=None, contiguous=False),))> on METAL with grad None>,
 <Tensor <LB METAL (8000, 6) int ShapeTracker(views=(View(shape=(8000, 6), strides=(7, 1), offset=1, mask=None, contiguous=False),))> on METAL with grad None>,
 <Tensor <LB METAL (2000, 6) int ShapeTracker(views=(View(shape=(2000, 6), strides=(7, 1), offset=56000, mask=None, contiguous=False),))> on METAL with grad None>,
 <Tensor <LB METAL (2000, 6) int ShapeTracker(views=(View(shape=(2000, 6), strides=(7, 1), offset=56001, mask=None, contiguous=False),))> on METAL with grad None>)

In [5]:
class Attention:
  def __init__(self, embed_size, n_heads, head_size) -> None:
    self.n_heads = n_heads
    self.head_size = head_size
    bound = 1 / (self.head_size**0.5)
    self.queries = Tensor.uniform(
      n_heads, embed_size, self.head_size, low=-bound, high=bound
    )
    self.keys = Tensor.uniform(
      n_heads, embed_size, self.head_size, low=-bound, high=bound
    )
    self.values = Tensor.uniform(
      n_heads, embed_size, self.head_size, low=-bound, high=bound
    )

  def __call__(self, x: Tensor) -> Tensor:
    B, T, C = x.shape

    x = x.unsqueeze(1).expand((B, self.n_heads, T, C))

    Q = x @ self.queries  # (B, n_heads, T, head_size)
    K = x @ self.keys  # (B, n_heads, T, head_size)
    dot_attn = Q @ K.transpose(-2, -1)  # (B, n_heads, T, T)
    scaled_dot_attn: Tensor = dot_attn / (self.head_size**0.5)  # (B, n_heads, T, T)
    mask = Tensor.ones((T, T), requires_grad=False).tril()  # (T, T)
    masked_scaled_dot_attn = scaled_dot_attn.masked_fill(mask == 0, float("-inf"))  # noqa: F401, (B, n_heads, T, T)
    attn_scores = masked_scaled_dot_attn.softmax()  # (B, n_heads, T, T)

    V = x @ self.values  # (B, n_heads, T, head_size)
    attented_embeds = attn_scores @ V  # (B, n_heads, T, head_size)
    concatenated_embeds = attented_embeds.reshape((B, T, self.n_heads * self.head_size))  # noqa: F401, (B, T, n_heads * head_size)
    return concatenated_embeds


class TransformerBlock:
  def __init__(self, embed_size: int, n_heads: int, head_size: int) -> None:
    self.attn = Attention(embed_size, n_heads, head_size)
    self.out_proj = nn.Linear(n_heads * head_size, embed_size)

  def __call__(self, x: Tensor) -> Tensor:
    return x.sequential([self.attn, self.out_proj, Tensor.gelu])


class Transformer:
  def __init__(self, vocab_size, embed_size, n_layers, n_heads, head_size) -> None:
    self.token_embed = nn.Embedding(vocab_size, embed_size)
    self.h = [TransformerBlock(embed_size, n_heads, head_size) for _ in range(n_layers)]
    self.lm_head = nn.Linear(embed_size, vocab_size)

  def forward(self, x: Tensor) -> Tensor:
    logits = x.sequential([self.token_embed, *self.h, self.lm_head])
    return logits

  def loss(self, x: Tensor, y: Tensor) -> Tensor:
    logits = self.forward(x)
    loss = logits.sparse_categorical_crossentropy(y)
    return logits, loss

  def __call__(self, x: Tensor) -> Tensor:
    logits = self.forward(x)
    return logits[:, -1, :]

In [6]:
embed_size = 512
n_heads = 4
head_size = embed_size // n_heads
model = Transformer(
  vocab_size=10, embed_size=embed_size, n_layers=2, n_heads=n_heads, head_size=head_size
)
sum(p.numel() for p in nn.state.get_parameters(model))

2108426

In [7]:
optim = nn.optim.AdamW(nn.state.get_parameters(model))
batch_size = 128


@TinyJit
@Tensor.train()
def train_step():
  optim.zero_grad()
  samples = Tensor.randint(batch_size, high=X_train.shape[0])
  X_samples, Y_samples = X_train[samples], Y_train[samples]
  _, loss = model.loss(X_samples, Y_samples)
  loss.backward()
  optim.step()
  return loss

In [8]:
for step in range(1, 1001):
  loss = train_step()
  if step == 1 or step % 250 == 0:
    with Tensor.inference_mode():
      acc = (model(X_test).argmax(axis=-1) == Y_test[:, -1]).mean().item()
      print(f"step {step}, loss {loss.item():.2f}, acc {acc*100.:.2f}%")

step 1, loss 2.30, acc 9.85%
step 250, loss 0.01, acc 98.70%
step 500, loss 0.02, acc 99.50%
step 750, loss 0.00, acc 100.00%
step 1000, loss 0.00, acc 100.00%


In [10]:
predictions = model(X_test).argmax(axis=-1)
true_labels = Y_test[:, -1]
incorrect_mask = predictions != true_labels
incorrect_indices = np.where(incorrect_mask.numpy())[0].tolist()

for index in incorrect_indices:
  print(
    f"Example: {X_test[index].tolist()}, Prediction: {predictions[index].item()}, True: {true_labels[index].item()}"
  )

Example: [0, 3, 0, 3, 0, 0], Prediction: 3, True: 6
Example: [0, 1, 0, 9, 0, 1], Prediction: 9, True: 0
Example: [6, 0, 6, 0, 1, 2], Prediction: 6, True: 0
Example: [0, 1, 0, 1, 0, 0], Prediction: 1, True: 2
Example: [0, 5, 0, 0, 0, 0], Prediction: 0, True: 5
