# Named Distributions API

In this notebook we'll give some example uses of the named distribution api
designed for easier querying and construction of complicated A and B tensors.

The distribution objects allow for giving semantically sensible names to axes
and indices within a tensor. These can be made interactively in code or an 
entire set of A and B tensors can be compiled from a structured model
description.

Below is an example of how to build a distribution from code for a model
conisting of a single observation modality "observation" consiting of the
possible observations {A, B, C, D}. A hidden state "state" consisting of the
values {A, B, C, D} and controls "control" {up, down}.

In [2]:
import numpy as np
import jax
from jax import numpy as jnp

from pymdp.jax import Distribution
from pymdp.jax.agent import Agent

np.set_printoptions(precision=2, suppress=True)


def get_task_info():
    policies = jnp.expand_dims(jnp.array([[0, 0, 0, 0], [1, 1, 1, 1]]), -1)
    C = jnp.zeros((1, 4))
    C = C.at[0, 3].set(1.0)
    action = jnp.array([[1]])
    qs = [jnp.zeros((1, 1, 4))]
    qs[0] = qs[0].at[0, 0, 0].set(1.0)
    observation = [jnp.array([[0]])]
    return policies, C, action, qs, observation

observations = ["A", "B", "C", "D"]
states = ["A", "B", "C", "D"]
controls = ["up", "down"]

data = np.zeros((len(observations), len(states)))
A = Distribution(data, {"observations": observations}, {"states": states})

A["A", "A"] = 1.0 
A["B", "B"] = 1.0
A["C", "C"] = 1.0
A["D", "D"] = 1.0


## Transition
# Similarily we can use the distributions to build a 
data = np.zeros((len(states), len(states), len(controls)))
B = Distribution(data, {"states": states}, {"states": states, "controls": controls})

B["B", "A", "up"] = 1.0
B["C", "B", "up"] = 1.0
B["D", "C", "up"] = 1.0
B["D", "D", "up"] = 1.0

B["A", "A", "down"] = 1.0
B["A", "B", "down"] = 1.0
B["B", "C", "down"] = 1.0
B["C", "D", "down"] = 1.0

policies = jnp.expand_dims(jnp.array([[0, 0, 0, 0], [1, 1, 1, 1]]), -1)

C = jnp.zeros((1, 4))
C = C.at[0, 0].set(1.0)

agent = Agent([A], [B], [C], policies=policies)

action = jnp.array([[0]])
qs = [jnp.zeros((1, 1, 4))]
qs[0] = qs[0].at[0, 0, 0].set(1.0)

observation = [jnp.array([[0]])]
prior, _ = agent.infer_empirical_prior(action, qs)

agent = Agent([A], [B], [C], policies=policies)
prior, _ = agent.infer_empirical_prior(action, qs)
qs = agent.infer_states(observation, None, prior, None)

q_pi, G = agent.infer_policies(qs)
action = agent.sample_action(q_pi)

[Array([[[0., 0., 0., 1.]]], dtype=float32)]
[[0. 1.]]
[[1]]


### Using configs
Alternatively you can use a model description to just generate the shape of the
A's and the B's in one go. 

In [13]:
from pymdp.jax import distribution

model = {
    "observations": {
        "o1": {"elements": ["A", "B", "C", "D"], "depends_on": ["s1"]},
    },
    "controls": {"c1": {"elements": ["up", "down"]}},
    "states": {
        "s1": {"elements": ["A", "B", "C", "D"], "depends_on_states": ["s1"], "depends_on_control": ["c1"]},
    },
}

As, Bs = distribution.compile_model(model)
print(Bs[0].data.shape)

As[0]["A", "A"] = 1.0
As[0]["B", "B"] = 1.0
As[0]["C", "C"] = 1.0
As[0]["D", "D"] = 1.0

Bs[0]["B", "A", "up"] = 1.0
Bs[0]["C", "B", "up"] = 1.0
Bs[0]["D", "C", "up"] = 1.0
Bs[0]["D", "D", "up"] = 1.0

Bs[0]["A", "A", "down"] = 1.0
Bs[0]["A", "B", "down"] = 1.0
Bs[0]["B", "C", "down"] = 1.0
Bs[0]["C", "D", "down"] = 1.0

policies, Cs, action, qs, observation = get_task_info()

agent = Agent(As, Bs, Cs, policies=policies)
prior, _ = agent.infer_empirical_prior(action, qs)
qs = agent.infer_states(observation, None, prior, None)

q_pi, G = agent.infer_policies(qs)
action = agent.sample_action(q_pi)
print(action)

(4, 4)


IndexError: list index out of range