# Named Distributions API

In this notebook we give some example uses of the named distributions API,
which is designed to make querying and constructing complicated A and B tensors
easier.

Distribution objects let you give semantically meaningful names to the axes
and indices of a tensor. They can be built interactively in code, or an entire
set of A and B tensors can be compiled from a structured model description.

Below is an example of how to build distributions in code for a model
consisting of a single observation modality "observations" with the possible
observations {A, B, C, D}, a hidden state factor "states" with the values
{A, B, C, D}, and a control factor "controls" with the values {up, down}.

In [1]:
import numpy as np
from jax import numpy as jnp

np.set_printoptions(precision=2, suppress=True)

from pymdp.jax.agent import Agent
from pymdp.jax.distribution import Distribution, compile_model

# Shared inputs for the examples below. Each gets extra leading singleton axes
# (presumably batch dimensions for the apply_batch=True agents — confirm in pymdp.jax).

# Previous action: control index 1 ("down" in the first example's labels).
action = jnp.broadcast_to(jnp.array([1]), (1, 1))

# Current observation: observation index 0 ("A" in the first example's labels).
observation = jnp.broadcast_to(jnp.array([0]), (1, 1))

# Initial belief over the 4 hidden states: all mass on state index 0.
qs_init = jnp.broadcast_to(jnp.array([1.0, 0.0, 0.0, 0.0]), (1, 1, 4))

# Two candidate policies, each a length-4 sequence: always control 0, or always control 1.
# The trailing singleton axis presumably indexes control factors (a single factor here).
policies = jnp.expand_dims(jnp.array([[0, 0, 0, 0], [1, 1, 1, 1]]), -1)

In [2]:
# Labels for every axis used below.
observations = ["A", "B", "C", "D"]
states = ["A", "B", "C", "D"]
controls = ["up", "down"]

# A: observation given state, indexed A[observation, state].
# Each state produces the observation with the same label.
A_data = np.zeros((len(observations), len(states)))
A = Distribution({"observations": observations}, {"states": states}, A_data)

A["A", "A"] = 1.0
A["B", "B"] = 1.0
A["C", "C"] = 1.0
A["D", "D"] = 1.0

# B: next state given previous state and control, indexed
# B[next_state, previous_state, control] (the axis order passed to Distribution).
B_data = np.zeros((len(states), len(states), len(controls)))
B = Distribution({"states": states}, {"states": states, "controls": controls}, B_data)

# "up" moves one state towards D; D stays at D.
B["B", "A", "up"] = 1.0
B["C", "B", "up"] = 1.0
B["D", "C", "up"] = 1.0
B["D", "D", "up"] = 1.0

# "down" moves one state towards A; A stays at A.
B["A", "A", "down"] = 1.0
B["A", "B", "down"] = 1.0
B["B", "C", "down"] = 1.0
B["C", "D", "down"] = 1.0

# Preference vector: its largest entry (index 3, "D") marks the goal.
C = jnp.array([0.0, 0.0, 0.0, 1.0])

agent = Agent([A], [B], [C], policies=policies, apply_batch=True)
print(f"goal state: {states[jnp.argmax(C)]}")

# State inference from the setup cell's `qs_init`, `action` and `observation`.
prior, _ = agent.infer_empirical_prior(action, [qs_init])
qs = agent.infer_states([observation], None, prior, None)
print(f"initial state: {states[jnp.argmax(qs[0])]}")

# Policy inference and action selection. The sampled action gets its own name so the
# setup cell's `action` is not overwritten. Otherwise re-running this cell, or running
# the next example, would start from this sampled action instead of the intended one.
q_pi, G = agent.infer_policies(qs)
sampled_action = agent.sample_action(q_pi)
print(f"action taken: {controls[sampled_action[0][0]]}")

goal state: D
initial state: A
action taken: up


### Using configs
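Alternatively, you can pass a model description to `compile_model` to generate
all of the A and B distributions, with the right shapes and axis names, in one
go. You then only need to fill in the entries you care about.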

In [3]:
model = {
    "observations": {
        "o1": {"elements": ["A", "B", "C", "D"], "depends_on": ["s1"]},
    },
    "controls": {"c1": {"elements": ["up", "down"]}},
    "states": {
        "s1": {"elements": ["A", "B", "C", "D"], "depends_on": ["s1"], "controlled_by": ["c1"]},
    },
}

# Read the labels from this cell's own model description, so the printout does not
# depend on variables left over from the previous example.
state_labels = model["states"]["s1"]["elements"]
control_labels = model["controls"]["c1"]["elements"]

As, Bs = compile_model(model)

# As[0]: observation o1 given state s1, indexed [o1, s1] (identity mapping).
As[0]["A", "A"] = 1.0
As[0]["B", "B"] = 1.0
As[0]["C", "C"] = 1.0
As[0]["D", "D"] = 1.0

# Bs[0]: next s1 given previous s1 and control c1, indexed [next, previous, control].
# "up" moves one state towards D; D stays at D.
Bs[0]["B", "A", "up"] = 1.0
Bs[0]["C", "B", "up"] = 1.0
Bs[0]["D", "C", "up"] = 1.0
Bs[0]["D", "D", "up"] = 1.0

# "down" moves one state towards A; A stays at A.
Bs[0]["A", "A", "down"] = 1.0
Bs[0]["A", "B", "down"] = 1.0
Bs[0]["B", "C", "down"] = 1.0
Bs[0]["C", "D", "down"] = 1.0

# Preference vector: its largest entry (index 3, "D") marks the goal.
Cs = [jnp.array([0.0, 0.0, 0.0, 1.0])]
agent = Agent(As, Bs, Cs, policies=policies, apply_batch=True)
print(f"goal state: {state_labels[jnp.argmax(Cs[0])]}")

# State inference from the setup cell's `qs_init`, `action` and `observation`.
prior, _ = agent.infer_empirical_prior(action, [qs_init])
qs = agent.infer_states([observation], None, prior, None)
print(f"initial state: {state_labels[jnp.argmax(qs[0])]}")

# Store the sampled action under its own name so `action` keeps its setup value
# and this cell gives the same result when re-run.
q_pi, G = agent.infer_policies(qs)
sampled_action = agent.sample_action(q_pi)
print(f"action taken: {control_labels[sampled_action[0][0]]}")

goal state: D
initial state: A
action taken: up


In [4]:
# A larger model: four observation modalities, three independently controlled
# hidden state factors.
model = {
    "observations": {
        "temperature": {"elements": ["low", "medium", "high", "very high"], "depends_on": ["operating_state"]},
        "humidity": {"elements": ["low", "medium", "high", "very high"], "depends_on": ["maintenance_state"]},
        "pressure": {"elements": ["low", "medium", "high", "very high"], "depends_on": ["power_state"]},
        "vibration": {
            "elements": ["none", "low", "medium", "high"],
            "depends_on": ["operating_state", "maintenance_state"],
        },
    },
    "controls": {
        "temperature_control": {"elements": ["off", "low", "medium", "high"]},
        "humidity_control": {"elements": ["off", "low", "medium", "high"]},
        "pressure_control": {"elements": ["off", "low", "medium", "high"]},
    },
    "states": {
        "operating_state": {
            "elements": ["idle", "running", "overload"],
            "depends_on": ["operating_state"],
            "controlled_by": ["temperature_control"],
        },
        "maintenance_state": {
            "elements": ["regular", "alert", "critical"],
            "depends_on": ["maintenance_state"],
            "controlled_by": ["humidity_control"],
        },
        "power_state": {
            "elements": ["low", "normal", "high"],
            "depends_on": ["power_state"],
            "controlled_by": ["pressure_control"],
        },
    },
}

As, Bs = compile_model(model)

# Likelihoods are indexed [observation, state]: each state column gets exactly one observation.

# temperature | operating_state
As[0]["low", "idle"] = 1.0
As[0]["medium", "running"] = 1.0
# NOTE(review): an overloaded system reporting "low" temperature looks surprising — confirm intent.
As[0]["low", "overload"] = 1.0

# humidity | maintenance_state
As[1]["low", "regular"] = 1.0
As[1]["low", "alert"] = 1.0
As[1]["high", "critical"] = 1.0

# pressure | power_state
As[2]["low", "low"] = 1.0
# power_state "normal" -> medium pressure. (Assigning "medium" to power_state "low"
# would give that column two entries and leave the "normal" column empty.)
As[2]["medium", "normal"] = 1.0
As[2]["high", "high"] = 1.0

# TODO: As[3] (vibration | operating_state, maintenance_state) is not filled in here.

# Transitions are indexed [next_state, previous_state, control]; every previous state
# should map to exactly one next state per control.
# NOTE(review): for controls low/medium/high only one previous state per control gets a
# transition below; the other columns are left as compiled (presumably zeros). Fill them
# in before running inference.

# operating_state under temperature_control
Bs[0]["running", "idle", "low"] = 1.0
Bs[0]["overload", "running", "medium"] = 1.0
Bs[0]["overload", "overload", "high"] = 1.0

# "off" steps one level down; idle stays idle. Adding running -> running here as well
# would give the "running" column two next states (sum 2).
Bs[0]["idle", "idle", "off"] = 1.0
Bs[0]["idle", "running", "off"] = 1.0
Bs[0]["running", "overload", "off"] = 1.0

# maintenance_state under humidity_control
Bs[1]["alert", "regular", "low"] = 1.0
Bs[1]["critical", "alert", "medium"] = 1.0
Bs[1]["critical", "critical", "high"] = 1.0

# "off" steps one level down; regular stays regular (one next state per column).
Bs[1]["regular", "regular", "off"] = 1.0
Bs[1]["regular", "alert", "off"] = 1.0
Bs[1]["alert", "critical", "off"] = 1.0

# power_state under pressure_control
Bs[2]["normal", "low", "low"] = 1.0
Bs[2]["high", "normal", "medium"] = 1.0
Bs[2]["high", "high", "high"] = 1.0

# "off" steps one level down; low stays low (one next state per column).
Bs[2]["low", "low", "off"] = 1.0
Bs[2]["low", "normal", "off"] = 1.0
Bs[2]["normal", "high", "off"] = 1.0

agent = Agent(As, Bs, apply_batch=True)