In [1]:
#| default_exp context

# Context

In [2]:
#| export

from jax import numpy as jnp
from simple_pytree import Pytree

from jaxcmr.math import normalize_magnitude
from jaxcmr.typing import Array, Float, Float_


class TemporalContext(Pytree):
    """Temporal context representation for memory search models.

    The context state starts as a one-hot vector on unit 0 and drifts toward
    each integrated input, being renormalized after every update.

    Attributes:
        size: the number of units in the context representation.
        zeros: an all-zero vector of length `size`, used to build one-hot states.
        state: the current state of the context.
        initial_state: the initial state of the context (one-hot on unit 0).
        next_outlist_unit: the index of the next out-of-list context unit.
    """

    def __init__(self, item_count: int, size: int):
        """Create a new temporal context model.

        Args:
            item_count: the number of items in the context model.
            size: the size of the context representation.
        """
        self.size = size
        self.zeros = jnp.zeros(size)
        # Both the current and the initial state start as one-hot on unit 0.
        self.state = self.zeros.at[0].set(1)
        self.initial_state = self.zeros.at[0].set(1)
        # NOTE(review): when built via `init`, size == item_count + 1, so this
        # index equals `size` (one past the last unit) -- confirm how callers
        # use it.
        self.next_outlist_unit = item_count + 1

    def integrate(
        self,
        context_input: Float[Array, " context_feature_units"],
        drift_rate: Float_,
    ) -> "TemporalContext":
        """Returns context after integrating input representation, preserving unit length.

        Applies the CMR context-update rule ``c_i = rho * c_{i-1} + beta * c_in``,
        where ``c_in`` is the normalized input, ``beta`` is the drift rate, and
        ``rho`` is the scalar that keeps ``c_i`` at unit length when
        ``c_{i-1}`` is unit length.

        Args:
            context_input: the input representation to be integrated into the contextual state.
            drift_rate: the drift rate parameter (``beta``).

        Returns:
            A new TemporalContext whose `state` is the updated context; all other
            fields are carried over unchanged.
        """
        context_input = normalize_magnitude(context_input)
        # rho must be a scalar: it depends on the dot product between the
        # current state and the input, not on their elementwise product.
        # Solving ||rho * c + beta * c_in|| == 1 for rho gives the expression below.
        similarity = jnp.dot(self.state, context_input)
        rho = jnp.sqrt(
            1 + jnp.square(drift_rate) * (jnp.square(similarity) - 1)
        ) - (drift_rate * similarity)
        # The final normalization guards against floating-point drift
        # (and against a state that is not exactly unit length).
        return self.replace(
            state=normalize_magnitude((rho * self.state) + (drift_rate * context_input))
        )

    @classmethod
    def init(cls, item_count: int) -> "TemporalContext":
        """Initialize a new context model with ``item_count + 1`` context units.

        Args:
            item_count: the number of items in the context model.
        """
        return cls(item_count, item_count + 1)


## Tests

Check that `TemporalContext` starts in the expected one-hot state and that `integrate` blends in new input while keeping the state at unit length.

In [3]:
# Set up the test parameters
drift_rate = 0.3
item_count = 10
size = item_count + 2
context = TemporalContext(item_count, size)

# initial state should be 1.0 at the first element, and 0.0 elsewhere
assert context.state[0] == 1.0
assert jnp.all(context.state[1:] == 0.0)

# integrate a one-hot input on the last context unit (orthogonal to the state)
context_input = jnp.zeros(size).at[-1].set(1)
new_context = context.integrate(context_input, drift_rate)

# last element is now non-zero; rest are still 0.0, except for the first element
assert new_context.state[-1] > 0.0
assert jnp.all(new_context.state[1:-1] == 0.0)
assert new_context.state[0] > 0.0

# for an input orthogonal to the state, rho = sqrt(1 - drift_rate**2)
assert jnp.isclose(new_context.state[0], jnp.sqrt(1 - drift_rate**2), atol=1e-6)
assert jnp.isclose(new_context.state[-1], drift_rate, atol=1e-6)

# final state vector is unit length
assert jnp.isclose(jnp.linalg.norm(new_context.state), 1.0, atol=1e-6)

# integrating leaves the initial state untouched
assert jnp.all(new_context.initial_state == context.initial_state)

# integrate returns a new context; the original context is unchanged
assert jnp.all(context.state == context.initial_state)