# Resonator Networks

Resonator networks are a recent development in Vector Symbolic Architectures
which allow for efficient factorization of composite symbols into their 
component parts.

In [17]:
from __future__ import annotations
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

from typing import Callable, Dict, List, Any
import math
from functools import reduce

# Vocabulary

A *Vocabulary* is a collection of atomic vector symbols as well as operations
defined over them.

In [267]:
def init_normal(dim: int) -> np.ndarray:
    sd = 1.0 / math.sqrt(dim)
    v = np.random.normal(scale=sd, size=dim)
    v /= np.linalg.norm(v)
    return v


def cconv(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    from numpy.fft import fft, ifft

    return ifft(fft(x) * fft(y)).real


def spose(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    return x + y


class Vocabulary:
    """
    `Vocabulary`

    Codebook of atomic vectors of size `dim` and associated functions.

    Specify a `vector_gen` function to give a probability distribution
    sample, must be a function from `dim -> np.ndarray`

    For other functions, `bind` defaults to circular convolution,
    and `superpose` to element-wise addition.

    For similarity, pass argumen to `sim`; defaults to `np.dot`.

    For any other functions, pass dictionary to `funcs`.
    """

    dim: int
    symbols: Dict[str, np.ndarray]
    bind: Callable[[np.ndarray, np.ndarray], np.ndarray]
    superpose: Callable[[np.ndarray, np.ndarray], np.ndarray]
    funcs: dict[str, Callable[[Any], Any]]
    sim: Callable[[np.ndarray, np.ndarray], float]
    vector_gen: Callable[[int], np.ndarray] = init_normal

    def __init__(
        self,
        /,
        dim: int,
        symbols: Dict[str, np.ndarray] | str,
        vector_gen: Callable[[int], np.ndarray] = init_normal,
        bind: Callable[[np.ndarray, np.ndarray], np.ndarray] = cconv,
        superpose: Callable[[np.ndarray, np.ndarray], np.ndarray] = spose,
        sim: Callable[[np.ndarray, np.ndarray], float] = np.dot,
        funcs: Dict[str, Callable[[Any], Any]] = {},
    ) -> None:
        self.dim = dim
        self.bind = bind
        self.superpose = superpose
        self.funcs = funcs
        self.sim = sim
        self.vector_gen = vector_gen

        for key, value in self.funcs.items():
            self.key = value


        if isinstance(symbols, dict):
            self.symbols = symbols
        elif isinstance(symbols, str):
            symbol_names = symbols.split(sep=";")
            symbols = {}
            for name in symbol_names:
                symbols[name] = vector_gen(dim)
            self.symbols = symbols
        else:
            raise TypeError(symbols)

    def __getitem__(self, key: str) -> np.ndarray:
        return self.symbols[key]

    def __setitem__(self, key: str, item: np.ndarray) -> None:
        self.symbols[key] = item

In [268]:
dim = 10
symbols = ["blue", "red", "ball"]
test_vocab = Vocabulary(dim=dim, symbols=";".join(symbols))
print(list(test_vocab.symbols.keys()))

['blue', 'red', 'ball']


In [269]:
blue_ball = test_vocab.bind(test_vocab["blue"], test_vocab["ball"])
blue_ball

array([-0.65537026, -0.38525277, -0.37674104, -0.55765922, -0.48475549,
       -0.51773565, -0.39827577, -0.89296592, -0.46270542, -0.48700147])

In [21]:
test_vocab.sim(test_vocab["blue"], test_vocab["blue"])

np.float64(0.9999999999999998)

# Multiply-Add-Permute Codes

Here, we implement Multiply-Add-Permute Codes (Gayler 1998) as a vocabulary.

In [251]:
def map_gen(dim: int) -> np.ndarray:
    """
    `map_gen`

    MAP Code vector generator samples random vectors with elements
    in set [-1, 1].
    """

    op = np.vectorize(lambda _: np.random.choice(np.array([1, -1])))
    v = np.zeros(dim)
    return op(v)


map_bind = np.multiply


def _perm(x: np.ndarray) -> np.ndarray:
    v = np.zeros(x.size)
    for i in range(v.size):
        v[i] = x[i - 1 % v.size]
    return v


def map_perm(x: np.ndarray, ntimes: int) -> np.ndarray:
    res = x
    for _ in range(ntimes):
        res = _perm(res)

    return res


def map_sim(x: np.ndarray, y: np.ndarray) -> float:
    # return abs((x @ y) / x.size)
    # return abs(x / np.linalg.norm(x)) @ (y / np.linalg.norm(y))
    return np.dot(x, y.T) / x.size


dim = 600
symbols = ["a", "b", "c", "d", "e", "f", "g", "left", "right"]
map = Vocabulary(
    dim=dim,
    symbols=";".join(symbols),
    vector_gen=map_gen,
    bind=map_bind,
    sim=map_sim,
    funcs={"perm": map_perm},
)

In [252]:
map.sim(map["a"], map["a"])

np.float64(1.0)

In [253]:
map.funcs["perm"](map["a"], 1)

array([-1.,  1.,  1.,  1.,  1., -1.,  1.,  1.,  1.,  1., -1.,  1.,  1.,
        1.,  1., -1.,  1.,  1., -1.,  1.,  1., -1.,  1.,  1., -1., -1.,
       -1., -1., -1.,  1.,  1.,  1., -1., -1.,  1., -1.,  1., -1.,  1.,
       -1., -1.,  1., -1.,  1.,  1., -1.,  1.,  1.,  1., -1., -1., -1.,
       -1., -1., -1.,  1., -1.,  1.,  1., -1.,  1.,  1., -1., -1., -1.,
        1., -1., -1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1., -1., -1.,
       -1., -1.,  1.,  1.,  1.,  1.,  1.,  1., -1.,  1.,  1.,  1., -1.,
       -1., -1., -1., -1.,  1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1.,
        1., -1.,  1.,  1.,  1.,  1., -1., -1.,  1.,  1.,  1., -1.,  1.,
        1.,  1., -1.,  1., -1., -1., -1., -1.,  1.,  1.,  1., -1.,  1.,
       -1., -1., -1.,  1.,  1., -1.,  1., -1., -1.,  1.,  1., -1., -1.,
        1., -1.,  1.,  1.,  1., -1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1., -1.,  1.,  1.,  1., -1.,  1., -1., -1.,  1.,  1., -1.,
       -1., -1., -1., -1.,  1., -1.,  1., -1., -1., -1., -1.,  1

# Tree Representation

Here, we are going to reproduce the tree factoring task that was used for 
expounding on Resonator networks.

In [254]:
tree = map.superpose(
    map.bind(
        map["a"],
        map.bind(
            map["left"],
            map.bind(
                map.funcs["perm"](map["left"], 1),
                map.funcs["perm"](map["left"], 2),
            ),
        ),
    ),
    map.superpose(
        map.bind(
            map["b"],
            map.bind(
                map["left"],
                map.bind(
                    map.funcs["perm"](map["right"], 1),
                    map.funcs["perm"](map["left"], 2),
                ),
            ),
        ),
        map.superpose(
            map.bind(
                map["c"],
                map.bind(
                    map["right"],
                    map.bind(
                        map.funcs["perm"](map["right"], 1),
                        map.funcs["perm"](map["left"], 2),
                    ),
                ),
            ),
            map.superpose(
                map.bind(
                    map["d"],
                    map.bind(
                        map["right"],
                        map.bind(
                            map.funcs["perm"](map["right"], 1),
                            map.bind(
                                map.funcs["perm"](map["right"], 2),
                                map.funcs["perm"](map["left"], 3),
                            ),
                        ),
                    ),
                ),
                map.superpose(
                    map.bind(
                        map["e"],
                        map.bind(
                            map["right"],
                            map.bind(
                                map.funcs["perm"](map["right"], 1),
                                map.bind(
                                    map.funcs["perm"](map["right"], 2),
                                    map.funcs["perm"](map["right"], 3),
                                ),
                            ),
                        ),
                    ),
                    map.superpose(
                        map.bind(
                            map["f"],
                            map.bind(
                                map["left"],
                                map.bind(
                                    map.funcs["perm"](map["right"], 1),
                                    map.bind(
                                        map.funcs["perm"](map["right"], 2),
                                        map.bind(
                                            map.funcs["perm"](map["left"], 3),
                                            map.funcs["perm"](map["left"], 4),
                                        ),
                                    ),
                                ),
                            ),
                        ),
                        map.bind(
                            map["g"],
                            map.bind(
                                map["left"],
                                map.bind(
                                    map.funcs["perm"](map["right"], 1),
                                    map.bind(
                                        map.funcs["perm"](map["right"], 2),
                                        map.bind(
                                            map.funcs["perm"](map["left"], 3),
                                            map.funcs["perm"](map["right"], 4),
                                        ),
                                    ),
                                ),
                            ),
                        ),
                    ),
                ),
            ),
        ),
    ),
)

In [255]:
b_p_noise = map.bind(
    tree,
    map.bind(
        map["left"],
        map.bind(
            map.funcs["perm"](map["right"], 1),
            map.funcs["perm"](map["left"], 2),
        ),
    ),
)
b_p_noise

array([-7., -1.,  1.,  5.,  5., -1., -3.,  1.,  3., -1., -1.,  5., -3.,
        1., -1.,  5.,  3., -1., -3., -1., -1., -3.,  1.,  3.,  1., -1.,
        3., -1.,  3.,  1., -3., -5., -1.,  1.,  1., -5.,  1.,  3., -1.,
       -1., -3., -7.,  3., -1., -1.,  1.,  1.,  3.,  3.,  3.,  1.,  1.,
        1., -5., -5.,  1., -3., -1., -3.,  7., -1., -1., -1., -1., -1.,
        5., -5.,  1., -1.,  7.,  3., -5., -5., -1.,  5.,  1.,  3.,  1.,
        3., -3.,  3.,  3., -1., -1.,  1., -1., -1.,  3.,  3.,  3.,  1.,
       -1., -1.,  5.,  3.,  5., -3., -1., -1., -1., -3.,  5.,  1., -3.,
        3.,  1., -3.,  3., -3., -1.,  3., -1., -1., -1., -3., -1., -3.,
       -1.,  5.,  5.,  5., -1.,  1., -1.,  1.,  1., -1., -3.,  1., -5.,
        1., -1.,  1., -3., -3.,  3.,  5.,  5.,  3.,  1.,  1., -1.,  3.,
        1., -3.,  1., -1.,  3.,  3., -3., -1.,  3., -1.,  5.,  3., -1.,
        3., -3.,  1., -3.,  1., -5., -1.,  1., -3.,  1.,  1.,  1., -3.,
        3., -3.,  1.,  5.,  3., -1., -3., -1., -1.,  1., -1.,  1

In [256]:
map["b"]

array([-1, -1,  1,  1,  1, -1, -1,  1,  1, -1, -1,  1, -1, -1,  1,  1, -1,
       -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1, -1,  1,
       -1, -1,  1, -1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1,  1,  1,
        1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1,
        1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1,  1,  1, -1, -1, -1,
       -1, -1,  1,  1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1,  1, -1, -1,
       -1, -1, -1, -1,  1, -1,  1, -1,  1,  1,  1, -1, -1, -1, -1,  1,  1,
        1,  1, -1,  1,  1,  1, -1, -1, -1, -1,  1,  1,  1, -1, -1,  1,  1,
        1, -1,  1, -1, -1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1,  1,  1,
        1, -1,  1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1,  1,  1, -1,  1,
       -1,  1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1, -1, -1,  1,  1, -1,
       -1, -1, -1, -1,  1,  1,  1,  1,  1,  1,  1,  1, -1,  1, -1, -1,  1,
       -1,  1,  1,  1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1,  1,  1,  1,
       -1,  1,  1,  1,  1

In [257]:
map.sim(map["b"], b_p_noise)

np.float64(1.06)

# The Resonator Network

Given the tree example, the motivation for this problem not figuring out, given
some location within the tree, what the symbol is at that location.

Rather, the question is: given some label in the tree, give me directions
to find that label.

For example
$$
\text{tree} \odot c = \text{right} \odot \rho (\text{right}) \odot 
    \rho^2 (\text{right}) + \textit{noise}
$$
But how do we efficiently compare the factors on the right hand side without
exhaustively enumerating all traversals of the tree and compute similarity?

To do this, we need to establish first a maximum depth of the search, and this
determines the number of factors to be estimated.

For the tree above, we need at most $5$ factors, as this is as deep as the
deepest node. Each factor determines whether to go $\text{left}$, 
$\text{right}$, or $\text{stop}$. To indicate $\text{stop}$, we use $I$.

## A More Formal Definition

Let us have codebooks $X = [x_1, x_2, \ldots, x_D]$, 
$Y = [y_1, y_2, \ldots, y_D]$, and $Z = [z_1, z_2, \ldots, z_D]$. Given some
$s := x_{i^*} \odot y_{j^*} \odot z_{k^*}$, and codebooks $X, Y$ and $Z$,
the goal is to find $x_{i^*}, y_{j^*}, z_{k^*}$.

Let $\hat x, \hat y$, and $\hat z$ represent the estimate for each factor.
These vectors can be initialized to the superposition of all possible factors,
$$
\begin{align*}
\hat x (0) &= \sum^D_i x_i,\\
\hat y (0) &= \sum^D_j y_j, \\ 
\hat z (0) &= \sum^D_r z_r
\end{align*}
$$

A factor can be inferred on the basis of the other two; for example,
$$
\hat z (1) = s \odot \hat x (0) \odot \hat y (0)
$$
Because the binding of $\hat x (0) \odot \hat y (0)$ is the superposition
of all possible codes in the codebook, it represents, for example if $D = 100$,
then $D^2 = 10,000$.

The results of inference can be improved using clean-up memory which helps
reduce noise, and the process is built on cross-talk. We can, through
iterative application, arrive at good enough estimates.
$$
\begin{align*}
\hat x (t + 1) &= g (X X^T (s \odot \hat y (t) \odot \hat z (t))), \\
\hat y (t + 1) &= g (Y Y^T (s \odot \hat x (t) \odot \hat z (t))), \\
\hat z (t + 1) &= g (Z Z^T (s \odot \hat x (t) \odot \hat y (t)))
\end{align*}
$$
where $g$ is a function preventing run-away feedback, holding the values
of each vector at $\pm 1$.

The clean-up memory for $\hat x$ which is the matrix multiplication $XX^T$
with threshold function $g$, then this operation is equivalent to outer-product
Hebbian learning (Hopfield, 1982); except here, rather than directing feeding
the network back to itself, the result of the clean-up is sent to other 
parts of the network.


In [263]:
def resonator_tree(
    label: np.ndarray,
    tree: np.ndarray,
    max_depth: int,
    base_codebook: np.ndarray,
    iterations: int = 20,
) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
    """
    `resonator_tree`

    Given a label symbol `label`, tree `tree`, and maximum depth `max_depth`,
    find the location of that symbol in the tree.

    The process proceeds by `iterations` number of times, defaulting to `20`.
    """
    # codebook matrixes for each factor
    codebooks = []
    for i in range(max_depth):
        codebook_i = []
        for code in base_codebook:
            codebook_i.append(map.funcs["perm"](code, i))
        codebooks.append(np.array(codebook_i).T)

    # initialize each factor as the sum of all possible combinations
    factors = []
    for codebook in codebooks:
        init_est = np.sum(codebook, axis=1)
        factors.append(init_est)

    def g(vector: np.ndarray) -> np.ndarray:
        """
        `g`

        Stop runaway feedback of vector by holding value at `+/- 1`.
        """
        tmp = np.sign(vector)
        tmp[tmp == 0] = 1
        return tmp

    # run the neural network
    runs = [factors]
    s = map.bind(tree, label)
    for i in range(iterations):
        factors = runs[i]
        update_factors = []

        for j, factor in enumerate(factors):
            codebook = codebooks[j]
            other_factors = np.array(
                [ofact for ofact in factors if not np.array_equal(ofact, factor)]
            )
            fbinds = reduce(
                lambda x, y: map.bind(x, y),
                other_factors,
                np.ones(map.dim),
            )
            fs = map.bind(s, fbinds)
            cleanup = np.dot(codebook, codebook.T)
            # print(cleanup.shape)
            # fst = np.dot(codebook.T, fs)
            res = g(np.dot(cleanup, fs))
            update_factors.append(res)

        # Add them to the running similarities of iterations, and update factors
        runs.append(update_factors)

    return runs, codebooks

In [264]:
# first we add to our map codes a vector for `stop`
iterations = 20
map["stop"] = np.ones(map.dim)

# then we define the resonator network
label = map["c"]
max_depth = 5
base_codebook = np.array([map["left"], map["right"], map["stop"]])
runs, codebooks = resonator_tree(
    iterations=iterations,
    label=label,
    max_depth=5,
    tree=tree,
    base_codebook=base_codebook,
)

# Plotting Results

Here well plot the first, tenth, and final iteration of the resonator network
to show how the network arrives at the result.

In [265]:
import pandas as pd

In [266]:
similarities = {}
for i in range(len(runs[0])):
    similarities[f"left{i}"] = []
    similarities[f"right{i}"] = []
    similarities[f"stop{i}"] = []

for run in runs:
    for i, factor in enumerate(run):
        sims = []
        codebook = codebooks[i].T
        left, right, stop = tuple(codebook)
        similarities[f"left{i}"].append(map.sim(factor, left))
        similarities[f"right{i}"].append(map.sim(factor, right))
        similarities[f"stop{i}"].append(map.sim(factor, stop))

# print(np.dot(codebook, runs[19][2]) / codebook.size)

df = pd.DataFrame(similarities)
df

Unnamed: 0,left0,right0,stop0,left1,right1,stop1,left2,right2,stop2,left3,right3,stop3,left4,right4,stop4
0,1.053333,1.066667,1.026667,1.053333,1.066667,1.026667,1.053333,1.066667,1.026667,1.053333,1.066667,1.026667,1.053333,1.066667,1.026667
1,-0.49,0.463333,-0.516667,-0.49,0.463333,-0.516667,0.463333,-0.49,-0.53,-0.53,-0.516667,0.463333,-0.53,-0.516667,0.463333
2,0.516667,0.53,0.49,-0.463333,0.49,0.53,1.0,0.046667,0.006667,1.0,0.046667,0.006667,-1.0,-0.046667,-0.006667
3,-0.006667,-0.02,-1.0,-0.006667,-0.02,-1.0,0.046667,1.0,0.02,0.046667,1.0,0.02,0.006667,0.02,1.0
4,-1.0,-0.046667,-0.006667,-1.0,-0.046667,-0.006667,1.0,0.046667,0.006667,0.006667,0.02,1.0,-0.046667,-1.0,-0.02
5,0.006667,0.02,1.0,0.046667,1.0,0.02,1.0,0.046667,0.006667,0.006667,0.02,1.0,1.0,0.046667,0.006667
6,1.0,0.046667,0.006667,0.006667,0.02,1.0,-0.006667,-0.02,-1.0,-1.0,-0.046667,-0.006667,0.006667,0.02,1.0
7,-1.0,-0.046667,-0.006667,-0.006667,-0.02,-1.0,-0.463333,0.49,0.53,0.49,-0.463333,0.516667,-0.006667,-0.02,-1.0
8,-0.463333,0.49,0.53,0.49,-0.463333,0.516667,0.006667,0.02,1.0,1.0,0.046667,0.006667,-1.0,-0.046667,-0.006667
9,0.516667,0.53,0.49,-0.046667,-1.0,-0.02,0.516667,0.53,0.49,0.463333,-0.49,-0.53,-0.516667,-0.53,-0.49


It looks like, as expected, we're getting Hopfield like dynamics. But, the correct answer here is,
$$
\text{tree} \odot \text{c} = \text{right} \odot \rho (\text{right}) \odot \rho^2 (\text{left}) + \textit{noise}
$$

In [262]:
# IT WORKS

SyntaxError: invalid syntax (3287335677.py, line 1)