In [1]:
from typing import Optional, Tuple

from tinygrad import Tensor
import tinygrad.nn as nn

from ntm.models import LSTM

In [None]:
# Experiment configuration shared by the model and the training loop in this cell.
# Width of the LSTM hidden state (forwarded as `hidden_size`).
hidden_size = 256
# Feature width of every time step: forwarded to the LSTM as `input_size` and
# used as the last dimension of each randomly generated training sequence.
input_size = 8
# Bounds on the randomly drawn sequence length; the training loop samples with
# low=min_len and high=max_len+1.
max_len = 4
min_len = 1
# Three-layer LSTM from the project-local `ntm.models` module.
model = LSTM(input_size=input_size, hidden_size=hidden_size, n_layers=3)

def forward_pass(x: Tensor) -> Tensor:
  """Feed a sequence to the model, then read back one output per input step.

  The pass has two phases that share the same recurrent state:

  1. Input phase: every element along the first axis of `x` is given to the
     module-level `model`; its outputs are discarded and only the state is kept.
  2. Output phase: the model is stepped `x.shape[0]` more times on an all-zeros
     input, and the output of each step is collected.

  Args:
    x: Sequence tensor whose first axis is the time axis.

  Returns:
    The output-phase results stacked along a new leading axis, one entry per
    input step.

  Raises:
    ValueError: If `x` has no time steps (nothing to feed or to predict).
  """
  n_steps = x.shape[0]
  if n_steps == 0:
    raise ValueError("forward_pass requires a sequence with at least one time step")

  hc = None
  # Input the sequence
  for x_i in x:
    _, hc = model(x_i, hc)

  # The output phase always feeds the same all-zeros step, so build it once
  # from the last input element instead of re-creating it on every iteration.
  blank_step = x_i.zeros_like()

  # Gather the predictions
  outputs = []
  for _ in range(n_steps):
    y_i, hc = model(blank_step, hc)
    outputs.append(y_i)
  return Tensor.stack(*outputs, dim=0)

def training_step(x):
  """Perform a single optimisation step on one sequence and return its loss.

  The output of `forward_pass(x)` is scored against `x` itself with
  `binary_crossentropy_logits`, i.e. the input sequence is also the target.
  Gradients are cleared, back-propagated and applied through the module-level
  `optimizer`, all inside a `Tensor.train()` context.

  Args:
    x: The sequence tensor used both as model input and as target.

  Returns:
    The loss tensor for this step.
  """
  with Tensor.train():
    logits = forward_pass(x)
    loss = logits.binary_crossentropy_logits(x)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

# Adam, with the library's default hyperparameters, over every parameter that
# nn.state collects from the model.
trainable_params = nn.state.get_parameters(model)
optimizer = nn.optim.Adam(trainable_params)

for _ in range(100):
  # A fresh sequence length is drawn for every step; `high` is exclusive in
  # Tensor.randint, so the +1 keeps max_len itself reachable.
  seq_len = Tensor.randint(low=min_len, high=max_len+1).numpy().item()
  # Random 0/1 tensor: seq_len time steps, a singleton middle axis, and
  # input_size values per step.
  sample_shape = (seq_len, 1, input_size)
  x = Tensor.randint(*sample_shape, low=0, high=2)
  loss = training_step(x)
  print(loss.numpy())

0.68392295
0.6876003
0.6962129
0.6852082
0.6890796
0.6991979
0.69248724
0.6829794
0.69878805
0.68077886
0.69751924
0.6972185
0.6910206
0.70567983
0.6732774
0.70693517
0.70131236
0.687511
0.6993559
0.6903388
0.68158615
0.69182384
0.68551916
0.69145334
0.6950525
0.7000911
0.6828847
0.6899609
0.68365616
0.68913096
0.6958917
0.6879776
0.6880743
0.6897537
0.68254286
0.68647
0.69978476
0.6813303
0.6847129
0.70226675
0.6972097
0.6889583
0.69685644
0.65376073
0.6944856
0.6808703
0.6753951
0.69685644
0.68312067
0.73156047
0.6668258
0.69805276
0.69371414
0.71088326
0.6806292
0.6451657
0.7244389
0.6711885
0.7146112
0.68803746
0.6943719
0.6816721
0.7030916
0.6801763
0.6832842
0.67175996
0.6853652
0.68910766
0.69515663
0.690769
0.6956691
0.69255424
0.67198247
0.68541086
0.69701576
0.69230837
0.677212
0.680216
0.67325497
0.67846775
0.6574165
0.688899
0.70163625
0.6710381
0.6791639
0.70744157
0.63206637
0.6843507
0.6692172
0.68245864
0.62461305
0.69439375
0.67607635
0.6828108
0.6079571
