# In [2]:
import torch
from tqdm import tqdm as notebook_tqdm
from tiny_recursive_model import TinyRecursiveModel, MLPMixer1D, Trainer

# Build the tiny recursive model over a 256-token vocabulary.
# NOTE(review): the backbone's `dim` presumably has to equal the outer `dim`,
# and `seq_len` is presumably the fixed sequence length the mixer accepts —
# the mock data and the inference query in this notebook are length 256 to
# match. Keep these values in sync when changing any of them.
trm = TinyRecursiveModel(
    dim=16,
    num_tokens=256,
    network=MLPMixer1D(dim=16, depth=2, seq_len=256),
)

# mock dataset

from torch.utils.data import Dataset
class MockDataset(Dataset):
    """Synthetic (input, target) token-sequence pairs for smoke-testing.

    Each item is a pair of 1-D integer tensors of length ``seq_len`` whose
    values are drawn uniformly from ``[0, num_tokens)``. The targets are drawn
    independently of the inputs, so there is nothing learnable here; the
    dataset only exercises the training / inference plumbing.

    Args:
        length: number of items reported by ``len()`` (default 16).
        seq_len: length of each input / target sequence (default 256).
        num_tokens: exclusive upper bound on sampled token ids (default 256).
        seed: if ``None`` (default), every access draws fresh samples from the
            global torch RNG, exactly as the original mock did. If an int,
            item ``idx`` is drawn from a private generator seeded with
            ``seed + idx``, so repeated accesses return identical tensors.

    Raises:
        IndexError: from ``__getitem__`` when ``idx`` is outside
            ``[-length, length)``.
    """

    def __init__(self, length=16, seq_len=256, num_tokens=256, seed=None):
        super().__init__()
        self.length = length
        self.seq_len = seq_len
        self.num_tokens = num_tokens
        self.seed = seed

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Bounds-check so the dataset behaves like a finite sequence.
        # torch's map-style Dataset has no __iter__, so iter(ds) / list(ds)
        # call __getitem__(0), (1), ... until IndexError. Without this check
        # that iteration never terminates.
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(
                f"MockDataset index out of range (len={self.length})"
            )

        # A per-item generator makes samples reproducible when a seed is
        # given. With seed=None we pass generator=None and fall back to the
        # global RNG, as before.
        generator = None
        if self.seed is not None:
            generator = torch.Generator().manual_seed(self.seed + int(idx))

        size = (self.seq_len,)
        inp = torch.randint(0, self.num_tokens, size, generator=generator)
        out = torch.randint(0, self.num_tokens, size, generator=generator)
        return inp, out

mock_dataset = MockDataset()

# --- training ---------------------------------------------------------------
# One epoch on CPU. With 16 mock items and batch_size=16 this is presumably a
# single batch per epoch — confirm against Trainer's DataLoader settings.
trainer = Trainer(trm, mock_dataset, epochs=1, batch_size=16, cpu=True)
trainer()

# --- inference --------------------------------------------------------------
# Run prediction on a random (1, 256) token tensor, allowing at most 12 deep
# refinement steps. NOTE(review): halt_prob_thres is presumably the early-exit
# threshold on the model's halt probability — verify in TinyRecursiveModel.
pred_answer, exit_indices = trm.predict(
    torch.randint(0, 256, (1, 256)),
    max_deep_refinement_steps=12,
    halt_prob_thres=0.1,
)

# --- persist ----------------------------------------------------------------
# Save the trained weights so this model can be added to a collection of
# specialized networks available for tool calls.
torch.save(trm.state_dict(), 'saved-trm.pt')


# Output:
# [1 (1 / 12)] loss: 5.711 | halt loss: 0.744
# complete
