# Map-based Memory Validation

The idea of a location map is to remember sensory data. For our simple navigator, sensory data is like smell: our trees distributed through the environment diffuse scents indicating their identity. At a given point, the sense of smell detects a combination of data from nearby trees.

## A simple chemical sensor

The 'scent' $z$ of a tree should be detected with a strength that decays with distance from the tree. I model this decay via a diffusion process. The model for diffusion (the physical process, not the algorithm) is based on the heat equation, which for a quantity $u(x,t)$ is $$\frac{\partial u}{\partial t} = \frac{1}{2} \sum_i \frac{\partial^2 u}{\partial x_i^2} = \frac{1}{2} \Delta u,$$ where $\Delta u$ is called the _Laplacian_. The heat equation is solved by
$$u(x, t) = \frac{c}{\sqrt{t^d}} \exp\left(-\frac{1}{2t} \|x - \mu\|^2\right),$$ 
where $d$ is the dimension of the space, $c$ is any constant, and $\mu$ is the initial point where diffusion starts.
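As a quick sanity check, the closed-form solution can be differentiated numerically. The sketch below is not part of the navigator; `heat_solution` and the sample point are illustrative names and values chosen here. It uses autograd to confirm that $\partial u / \partial t$ matches $\frac{1}{2}\Delta u$ for the Gaussian solution, which is the scaling that goes with the exponent $-\|x - \mu\|^2 / (2t)$.

```python
import torch

def heat_solution(x, t, mu, c=1.0):
    """u(x, t) = c * t^(-d/2) * exp(-||x - mu||^2 / (2t))."""
    d = x.shape[-1]
    return c * t.pow(-d / 2) * torch.exp(-(x - mu).pow(2).sum(-1) / (2 * t))

# An arbitrary sample point (illustrative values), in float64 for accuracy
x = torch.tensor([0.3, -0.7], dtype=torch.float64, requires_grad=True)
t = torch.tensor(1.5, dtype=torch.float64, requires_grad=True)
mu = torch.zeros(2, dtype=torch.float64)

u = heat_solution(x, t, mu)

# Time derivative and spatial gradient, keeping the graph for second derivatives
(du_dt,) = torch.autograd.grad(u, t, create_graph=True)
(grad_x,) = torch.autograd.grad(u, x, create_graph=True)

# Laplacian: sum of the unmixed second partial derivatives
laplacian = sum(
    torch.autograd.grad(grad_x[i], x, retain_graph=True)[0][i]
    for i in range(x.shape[-1])
)

# Should vanish up to floating-point error
residual = (du_dt - 0.5 * laplacian).item()
print(f"du/dt - 0.5 * Laplacian = {residual:.2e}")
```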

Here $u(x,t)$ is the diffused quantity. For clarity, $x$ is the location of the sensor, and $\mu$ is the location of the source tree. As $t \to 0$, the solution concentrates into a point mass at $\mu$. Our interest is to choose a virtual time $t$ that controls the spread of the tree's scent such that (a) trees can be sensed from a distance, and (b) trees are still distinct [note that $u(x, t) \to 0$ as $t \to \infty$ for fixed $c$]. For simplicity, we set $c = \sqrt{t^d}$ so that $u(\mu, t) = 1$ for all $t$; that is, a sensor located at the source reads that tree's scent at full strength.

To choose $t$, we set a maximum distance $M$ and a target sensory value $0 < a < 1$ such that the diffused scent has magnitude $a$ at distance $M$. That is, $u(x, t_M) = a$ when $\|x - \mu\| = M$, which reduces to $$u(x, t_M) = \exp\left(-\frac{M^2}{2t_M}\right) = a \quad\quad\implies\quad\quad t_M = \frac{M^2}{2 \log({1}/{a})}.$$
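To make the calibration concrete, here is a minimal sketch with illustrative values $M = 10$ and $a = 0.1$ (both chosen here, not prescribed by the model). The helper name `virtual_time` is hypothetical. It also shows a useful consequence of the formula: at half the maximum distance the scent has magnitude $a^{1/4}$, since the exponent scales with the squared distance.

```python
import math

def virtual_time(M: float, a: float) -> float:
    """Virtual time t_M at which the normalized scent equals a at distance M."""
    return M**2 / (2.0 * math.log(1.0 / a))

M, a = 10.0, 0.1  # illustrative values
t_M = virtual_time(M, a)

# Normalized scent exp(-r^2 / (2 t_M)) at distances M and M / 2
scent_at_M = math.exp(-M**2 / (2.0 * t_M))
scent_at_half_M = math.exp(-(M / 2) ** 2 / (2.0 * t_M))

print(f"t_M = {t_M:.4f}")
print(f"scent at M   = {scent_at_M:.4f}")       # equals a
print(f"scent at M/2 = {scent_at_half_M:.4f}")  # equals a ** 0.25
```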

Finally, given a set of trees at locations $\{\ell_i\}$ with scent embeddings $z_i$, a sensor at position $x$ will read the sum of the diffused scents
$$
z = \sum_i z_i \exp\left(-\frac{1}{2t_M} \| x - \ell_i\|^2\right) = KZ
$$
for the kernel matrix $K_{ji} = u_i(x_j, t_M)$, where $u_i$ is the diffused scent of tree $i$ and $x_j$ is the position of sensor $j$, and scent matrix $Z$ with $z_i$ as the $i^{th}$ row.
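Before adding any batching, the product $KZ$ can be sketched for a toy layout. Everything below is an assumption made for illustration: two trees, one-hot scent embeddings, and two sensors, with rows of $K$ indexing sensors and columns indexing trees.

```python
import math
import torch

# Toy layout (illustrative): two trees on the x-axis with one-hot scents
tree_locations = torch.tensor([[0.0, 0.0], [4.0, 0.0]])    # (num_sources, dim)
tree_scents = torch.eye(2)                                  # (num_sources, embed_dim)
sensor_positions = torch.tensor([[0.0, 0.0], [2.0, 0.0]])  # (num_sensors, dim)

# Calibrate so the scent has magnitude a at distance M
M, a = 4.0, 0.1
t_M = M**2 / (2.0 * math.log(1.0 / a))

# Kernel matrix: rows index sensors, columns index trees
sq_dist = torch.cdist(sensor_positions, tree_locations).pow(2)
K = torch.exp(-sq_dist / (2.0 * t_M))

# Each sensor reads the kernel-weighted sum of the tree scents
z = K @ tree_scents
print(z)
```

The first sensor sits on tree 0, so it reads that tree at full strength and the other tree (at distance $M$) at strength $a$. The second sensor sits halfway between the trees and reads both at strength $a^{1/4}$.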

This sensor is implemented by the following function, with some wrapping to handle batch sizes and the possibility of multiple sensors.

In [1]:
import math
import torch

def sense(self, source_locations: torch.Tensor, source_embeddings: torch.Tensor, sensor_positions: torch.Tensor, 
          M: float, a: float=0.1) -> torch.Tensor:
    # we expect source_locations to be a tensor with at least 2 dims, of shape (..., num_sources, dim)
    # we expect sensor_positions to have shape (..., num_sensors, dim), OR shape (..., dim) for a single sensor
    # we expect source_embeddings to be a tensor with at least 2 dims, of shape (..., num_sources, embed_dim)

    # special case for a single sensor
    single_sensor = sensor_positions.ndim < source_locations.ndim
    if single_sensor:
        sensor_positions = sensor_positions[..., None, :]
    
    assert source_locations.ndim == sensor_positions.ndim == source_embeddings.ndim
   
    # compute the virtual time
    t_M = M**2 / 2 / math.log(1/a)

    # compute the distance between the sensor and the source trees
    # This will yield a tensor of shape (..., num_sensors, num_sources)
    distances = torch.cdist(sensor_positions, source_locations)

    # now compute the kernel for each sensor-source pair as shape (..., num_sensors, num_sources)
    kernel = torch.exp(-0.5 * (distances).pow(2) / t_M)

    # now compute the embedding for each sensor as shape (..., num_sensors, embed_dim)
    # torch.matmul handles the unbatched (2D) case and any number of leading batch dimensions
    embedding = torch.matmul(kernel, source_embeddings)

    if single_sensor:
        embedding = embedding.squeeze(-2)

    return embedding

## A Memory to Remember Senses by Location

The purpose of the memory is to remember what would be sensed in a given location that was visited. In a continuous setting, we will never visit the exact same location twice, so we do not want a memory _per se_. Instead, we want an interpolator that can predict the expected sensory value well.

Suppose, then, that our agent has visited a sequence of locations $\ell_t$, observing $z_t$ at each step, yielding a sequence of pairs $\{(\ell_t, z_t)\}$. Given a new location $\ell$, we want to estimate $\hat{z} = f(\ell)$ provided that for all $t$, $z_t \approx f(\ell_t)$. But _this is just a regression_! Our "memory" is not really a memory; it is a regression model trained from the dataset of visited points.

Our memory, then, is a regression function $f$ trained on the visited points. However, we need a model that can be rapidly trained, because the memory needs to be immediately available from timestep to timestep. As a first approach, we can simply interpolate with an attention kernel.

In our case, our location estimates $\ell_t$ are generated by the agent and come with error, which we model as a Gaussian with diagonal covariance matrix (_i.e._, independent variation in each location dimension). Thus to each $\ell_t \in \mathbb{R}^d$ we associate a vector of deviations $\sigma_t \in \mathbb{R}^d$, and we want to regress $\hat{z} = f(\ell, \sigma)$. We can compute a location affinity kernel $k(\ell, \ell_t)$ between the query $(\ell, \sigma)$ and each visited point $(\ell_t, \sigma_t)$ based on the $\sigma$-scaled distance as 
$$
\log k(\ell, \ell_t) = \quad-\frac{1}{2}\left\|\frac{\ell - \ell_t}{\sqrt{\sigma^2 + \sigma_t^2}}\right\|^2 
\quad-\frac{1}{2}\sum_i \log \left( \sigma_i^2 + \sigma_{t,i}^2 \right)
\quad-\frac{d}{2}\log 2\pi
$$
where logs make the relationships easier to see. Vector division is componentwise, and $k(\ell, \ell_t)$ is just the density function of a Gaussian $\mathcal{N}\left(\ell_t, \textrm{diag}\left(\sigma^2 + \sigma_t^2\right)\right)$ -- the variance combines the measurement error on both $\ell$ and $\ell_t$ and represents the variance of the difference $\ell - \ell_t$.
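One way to check the kernel is to compare it against a library implementation of the Gaussian density. The sketch below writes out the standard log-density of a diagonal Gaussian, in which the log-variance term carries a factor of $\frac{1}{2}$, and compares it with `torch.distributions`. The function name `log_affinity` and the sample values are illustrative assumptions.

```python
import math
import torch
from torch.distributions import Independent, Normal

def log_affinity(loc, sigma, loc_t, sigma_t):
    """Log-density of N(loc_t, diag(sigma^2 + sigma_t^2)) evaluated at loc."""
    variance = sigma**2 + sigma_t**2
    return (
        -0.5 * ((loc - loc_t).pow(2) / variance).sum(-1)
        - 0.5 * torch.log(variance).sum(-1)
        - 0.5 * loc.shape[-1] * math.log(2 * math.pi)
    )

# Illustrative query and memory entries in two dimensions
loc = torch.tensor([0.5, -1.0], dtype=torch.float64)
sigma = torch.tensor([0.2, 0.3], dtype=torch.float64)
loc_t = torch.tensor([0.0, 0.0], dtype=torch.float64)
sigma_t = torch.tensor([0.4, 0.1], dtype=torch.float64)

manual = log_affinity(loc, sigma, loc_t, sigma_t)

# Reference: independent Normals with the combined standard deviation
combined_std = torch.sqrt(sigma**2 + sigma_t**2)
reference = Independent(Normal(loc_t, combined_std), 1).log_prob(loc)

print(manual.item(), reference.item())
```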

Next, we can take a softmax over $\log k$ to get a set of affinity weights $w_t$ that will weight our dataset examples according to their closeness to the query point $\ell$, accounting for measurement error:
$$
w_t = \textrm{softmax} \left(\log k(\ell, \ell_t)\right) = \frac{k(\ell, \ell_t)}{\sum_s k(\ell, \ell_s)}
$$

From here, we can regress directly on the dataset to obtain the sensor estimate $\hat{z}$ by
$$
\hat{z} = \sum_t w_t z_t,
$$
which estimates the sensor output as a weighted average of the past sensor values.
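As a small worked example of the weighting, consider three remembered points on a line and a query close to the first one. The values below are illustrative, and the sketch assumes a single shared combined variance so that the normalising terms of $\log k$ are identical for every memory and cancel in the softmax.

```python
import torch

# Three remembered locations (1D) and the values sensed there (illustrative)
memory_locations = torch.tensor([[0.0], [1.0], [5.0]])                    # (num_keys, dim)
memory_values = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])       # (num_keys, embed_dim)

query = torch.tensor([0.25])  # query location, shape (dim,)
variance = 0.5                # assumed shared combined variance

# Unnormalized log kernel: only the squared-distance term matters here
log_k = -0.5 * (query - memory_locations).pow(2).sum(-1) / variance

# Affinity weights over the memories, then the weighted average of values
w = torch.softmax(log_k, dim=-1)
z_hat = w @ memory_values

print(w)
print(z_hat)
```

The nearest memory dominates, the distant one contributes almost nothing, and the estimate is a convex combination of the stored values.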

In [None]:
import math

def read_memory(self, query_location: torch.Tensor, query_deviation: torch.Tensor, 
                memory_locations: torch.Tensor, memory_deviation: torch.Tensor, 
                memory_values: torch.Tensor) -> torch.Tensor:
    # we expect query_location to be a tensor of shape (..., num_queries, dim) (but num_queries can be 1 or missing)
    # we expect query_deviation to be a tensor of shape (..., num_queries, dim) (but num_queries can be 1 or missing)
    # we expect memory_locations to be a tensor of shape (..., num_keys, dim)
    # we expect memory_deviation to be a tensor of shape (..., num_keys, dim)
    # we expect memory_values to be a tensor of shape (..., num_keys, embed_dim)

    single_query = query_location.ndim < memory_locations.ndim
    if single_query:
        # add a singleton num_queries axis
        query_location = query_location[..., None, :]
        query_deviation = query_deviation[..., None, :]
    
    assert query_location.ndim == query_deviation.ndim == memory_locations.ndim == memory_deviation.ndim == memory_values.ndim

    # broadcast queries against keys: (..., num_queries, 1, dim) vs (..., 1, num_keys, dim)
    diff = query_location[..., :, None, :] - memory_locations[..., None, :, :]
    # compute the combined variance, which has shape (..., num_queries, num_keys, dim)
    variance = query_deviation[..., :, None, :]**2 + memory_deviation[..., None, :, :]**2
    # Gaussian log-density, summed over dim to shape (..., num_queries, num_keys)
    log_k = (
        - 0.5 * (diff.pow(2) / variance).sum(dim=-1) 
        - 0.5 * torch.log(variance).sum(dim=-1)
        - 0.5 * math.log(2 * math.pi) * variance.shape[-1]
    )

    # the location affinity weights have shape (..., num_queries, num_keys)
    w = torch.softmax(log_k, dim=-1)

    # weighted average of the stored values, shape (..., num_queries, embed_dim)
    hat_z = torch.matmul(w, memory_values)

    if single_query:
        hat_z = hat_z.squeeze(-2)

    return hat_z

Now, you might notice that this kernel looks very similar to dot product attention, and then you might ask whether we could recast it to make use of efficient tools for handling long-context attention, such as flash attention. The answer is that _you could_, but you would be changing the topology of the location space in so doing, and you would have to work that change all the way through the math. We might do that later. For now, the clarity of keeping our space as $\mathbb{R}^d$ is preferable.