Usable implementation of the Emergent Symbol Binding Network (ESBN), in PyTorch. The authors propose that the main recurrent neural network interact with the input image representations only indirectly, through a set of memory keys and values.
The input image representations are cast as memory values and explicitly bound to memory keys generated by the network. Before generating each new key, the network retrieves a sum of all previously stored keys, weighted by the similarity of the incoming representation to each memory value in storage.
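The read and write steps described above can be sketched as follows. This is a minimal illustration, not the code inside `esbn_pytorch`: it assumes dot-product similarity followed by a softmax to weight the stored keys, and `retrieve_key` and `write` are hypothetical helper names. The paper's exact formulation may include further terms not shown here.

```python
import torch
import torch.nn.functional as F

def retrieve_key(z, mem_keys, mem_values):
    """Similarity-weighted sum of stored keys for an incoming value z.

    Assumed shapes (for illustration only):
      z:          (b, value_dim)     incoming image representation
      mem_keys:   (b, t, key_dim)    keys written at previous steps
      mem_values: (b, t, value_dim)  values written at previous steps
    """
    if mem_values.shape[1] == 0:
        # Empty memory: nothing to retrieve, so return an all-zero key.
        return z.new_zeros((z.shape[0], mem_keys.shape[-1]))
    # Assumed similarity: dot product between z and every stored value -> (b, t)
    sim = torch.einsum('bd,btd->bt', z, mem_values)
    weights = F.softmax(sim, dim=-1)
    # Weighted sum of the stored keys -> (b, key_dim)
    return torch.einsum('bt,btk->bk', weights, mem_keys)

def write(mem_keys, mem_values, new_key, z):
    """Bind a network-generated key to the incoming value by appending both."""
    mem_keys = torch.cat((mem_keys, new_key.unsqueeze(1)), dim=1)
    mem_values = torch.cat((mem_values, z.unsqueeze(1)), dim=1)
    return mem_keys, mem_values
```

With a single stored entry the softmax weight is 1, so reading back with the same value returns exactly the key that was bound to it; with several entries, the result is a convex combination of the stored keys.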
On the paper's abstract rule-learning tasks, this decoupling (indirection) of sensory processing from abstract processing lets the network outperform the other architectures it is compared against, including transformers.
```python
import torch
from esbn_pytorch import ESBN

model = ESBN(
    value_dim = 64,
    key_dim = 64,
    hidden_dim = 512,
    output_dim = 4
)

# (n, b, c, h, w): a sequence of n images for each of b batch elements
images = torch.randn(3, 2, 3, 32, 32)
model(images) # (3, 2, 4), i.e. (n, b, o)
```
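To make the `(n, b, c, h, w)` to `(n, b, o)` contract concrete, here is a hypothetical, heavily simplified module that only mirrors the constructor arguments above. The class name `ToyESBN`, the tiny convolutional encoder, the `nn.LSTMCell` controller, and the softmax-over-dot-product retrieval are all assumptions for illustration, not the code inside `esbn_pytorch`. The point it shows is the indirection: at each step the controller receives only the key retrieved from memory, never the image encoding itself.

```python
import torch
from torch import nn

class ToyESBN(nn.Module):
    """Hypothetical, simplified ESBN-style model (not the esbn_pytorch code)."""

    def __init__(self, value_dim=64, key_dim=64, hidden_dim=512, output_dim=4):
        super().__init__()
        self.value_dim, self.key_dim = value_dim, key_dim
        # Assumed tiny image encoder: (b, 3, h, w) -> (b, value_dim)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(16, value_dim),
        )
        # The controller's input is a retrieved key, not the image encoding.
        self.controller = nn.LSTMCell(key_dim, hidden_dim)
        self.to_key = nn.Linear(hidden_dim, key_dim)
        self.to_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, images):  # images: (n, b, c, h, w)
        n, b = images.shape[:2]
        h = images.new_zeros((b, self.controller.hidden_size))
        c = images.new_zeros((b, self.controller.hidden_size))
        mem_k = images.new_zeros((b, 0, self.key_dim))
        mem_v = images.new_zeros((b, 0, self.value_dim))
        outputs = []
        for t in range(n):
            z = self.encoder(images[t])  # (b, value_dim), used only as a memory value
            if t == 0:
                retrieved = z.new_zeros((b, self.key_dim))  # memory is empty
            else:
                sim = torch.einsum('bd,btd->bt', z, mem_v)
                weights = torch.softmax(sim, dim=-1)
                retrieved = torch.einsum('bt,btk->bk', weights, mem_k)
            h, c = self.controller(retrieved, (h, c))
            outputs.append(self.to_out(h))  # (b, output_dim)
            # Bind a newly generated key to the current image encoding.
            mem_k = torch.cat((mem_k, self.to_key(h).unsqueeze(1)), dim=1)
            mem_v = torch.cat((mem_v, z.unsqueeze(1)), dim=1)
        return torch.stack(outputs)  # (n, b, output_dim)
```

Called with the same arguments and input shape as the usage example, `ToyESBN(value_dim = 64, key_dim = 64, hidden_dim = 512, output_dim = 4)(torch.randn(3, 2, 3, 32, 32))` returns a tensor of shape `(3, 2, 4)`.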
```bibtex
@misc{webb2020emergent,
    title         = {Emergent Symbols through Binding in External Memory},
    author        = {Taylor W. Webb and Ishan Sinha and Jonathan D. Cohen},
    year          = {2020},
    eprint        = {2012.14601},
    archivePrefix = {arXiv},
    primaryClass  = {cs.AI}
}
```