In [1]:
# https://arxiv.org/pdf/2410.01201
# Were RNNs All We Needed?
%pylab inline
from tinygrad import Tensor, fetch, nn, TinyJit
from typing import Tuple
from tqdm import trange
base = "https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/"
X_train = Tensor(fetch(base+"tiny_shakespeare_train.bin"))[0x400:].bitcast('uint16').to(None)
X_test = Tensor(fetch(base+"tiny_shakespeare_val.bin"))[0x400:].bitcast('uint16').to(None)
import tiktoken
enc = tiktoken.get_encoding("gpt2")
print(X_train.max().item(), enc.decode(X_train[0:10].numpy()))

%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib
50256 Under the canopy.

<|endoftext|>Third Servingman:


In [2]:
def g(x:Tensor) -> Tensor:
  return (x>=0).detach().where(x+0.5, x.sigmoid())
def log_g(x:Tensor) -> Tensor:
  return (x>=0).detach().where((x.relu()+0.5).log(), (-x).softplus())

def logcumsumexp(x, dim): return x.log().cumsum(dim).exp()
def parallel_scan_log(log_coeffs, log_values):
  # log_coeffs: (batch_size, seq_len, input_size)
  # log_values: (batch_size, seq_len+1, input_size)
  a_star = log_coeffs.cumsum(-1).pad2d((0,0,1,0))
  # NOTE: typo in paper here
  log_h0_plus_b_star = logcumsumexp(log_values - a_star, dim=1)
  log_h = a_star + log_h0_plus_b_star
  return log_h.exp()

# TODO: we need shape types (einsum like) for Tensors
class MinGRU:
  def __init__(self, input_size, hidden_size):
    self.linear_z = nn.Linear(input_size, hidden_size)
    self.linear_h = nn.Linear(input_size, hidden_size)

  def single(self, x_t:Tensor, h_prev:Tensor) -> Tensor:
    # x_t: (batch_size, input_size)
    # h_prev: (batch_size, hidden_size)
    z_t = self.linear_z(x_t).sigmoid()
    h_tilde = g(self.linear_h(x_t))
    h_t = (1-z_t)*h_prev + z_t*h_tilde
    return h_t

  def __call__(self, x:Tensor, h_0:Tensor):
    # x: (batch_size, seq_len, input_size)
    # h_0: (batch_size, 1, hidden_size)
    k = self.linear_z(x)
    log_z = -(-k).softplus()
    log_coeffs = -k.softplus()
    log_h_0 = log_g(h_0)
    log_tilde_h = log_g(self.linear_h(x))
    h = parallel_scan_log(log_coeffs, Tensor.cat(log_h_0, log_z + log_tilde_h, dim=1))
    return h

tokens = 50257
hidden_size = 384
class Model:
  def __init__(self):
    self.embedding = nn.Embedding(tokens, 384)
    self.gru = MinGRU(384, hidden_size)
    self.mlp1 = nn.Linear(384, 384)
    self.mlp2 = nn.Linear(384, tokens)
    self.reset()

  def reset(self): self.h_prev = Tensor.zeros(1, hidden_size, requires_grad=False)
  def get_logits(self, x:Tensor) -> Tuple[Tensor, Tensor]:
    # TODO: fix bug in embedding shape
    h_t = self.gru.single(self.embedding(x)[:, 0], self.h_prev)
    logits = self.mlp2(self.mlp1(self.h_prev).relu())
    return logits, h_t
      
  def single(self, x:Tensor) -> Tensor:
    logits, self.h_prev = self.get_logits(x)
    return logits[0].exp().softmax(-1).multinomial()
      
  def __call__(self, x:Tensor):
    h_prev = Tensor.zeros(x.shape[0], 1, hidden_size)
    pc = self.gru(self.embedding(x), h_prev)[:, 1:]
    return self.mlp2(self.mlp1(pc).relu())

# TODO: it seems like there's an issue with the learning rate and a mean somewhere, this is 100x bigger than the paper
model = Model()
optim = nn.optim.AdamW(nn.state.get_parameters(model), lr=1e-1)
losses = []

In [None]:
batch_size, seq_length = 32, 200

@TinyJit
@Tensor.train()
def train_step() -> Tensor:
  sel = Tensor.randint((batch_size,), low=0, high=X_train.shape[0]-seq_length)
  X = X_train[sel.reshape(-1, 1)+Tensor.arange(seq_length+1).reshape(1, -1)]
  optim.zero_grad()
  loss = model(X[:, :-1]).sparse_categorical_crossentropy(X[:, 1:]).backward()
  optim.step()
  return loss, sel[:3]

for i in (t:=trange(100)):
  loss, sel = train_step()
  losses.append(loss.item())
  t.set_description(f"loss: {losses[-1]:.4f} {sel.tolist()}")

loss: 6.5902 [13650, 166037, 220854]:  85%|███████████████████████████████████████████████████▊         | 85/100 [02:08<00:22,  1.47s/it]

In [None]:
model.reset()
arr = enc.encode("hello")
for i in range(10): arr.append(model.single(Tensor([arr[-1]])).item())
print(arr, enc.decode(arr))

In [None]:
plot(losses)