# Named Distributions API

In this notebook we give some example uses of the named distribution API,
designed to make querying and constructing complicated A and B tensors easier.

Distribution objects let you give semantically meaningful names to the axes of
a tensor and to the indices along each axis. They can be built interactively in
code, or an entire set of A and B tensors can be compiled from a structured
model description.

Below is an example of how to build distributions in code for a model
consisting of a single observation modality, "observations", with the possible
observations {A, B, C, D}; a single hidden-state factor, "states", with the
values {A, B, C, D}; and a control factor, "controls", with the actions {up, down}.

In [1]:
import numpy as np
import jax.tree_util as jtu
from jax import numpy as jnp

from pymdp.jax.agent import Agent
from pymdp.jax.distribution import Distribution, compile_model

np.set_printoptions(precision=2, suppress=True)

observations = ["A", "B", "C", "D"]
states = ["A", "B", "C", "D"]
controls = ["up", "down"]

# Likelihood A[observation, state]: every state emits the observation that
# carries the same label (an identity mapping written with named indices).
A = Distribution(
    {"observations": observations},
    {"states": states},
    np.zeros((len(observations), len(states))),
)
for obs_label, state_label in zip(observations, states):
    A[obs_label, state_label] = 1.0

# Transitions B[next_state, previous_state, control].
# "up" moves one step towards "D" (staying put at "D");
# "down" moves one step towards "A" (staying put at "A").
B = Distribution(
    {"states": states},
    {"states": states, "controls": controls},
    np.zeros((len(states), len(states), len(controls))),
)
last_idx = len(states) - 1
for idx, prev_state in enumerate(states):
    B[states[min(idx + 1, last_idx)], prev_state, "up"] = 1.0
for idx, prev_state in enumerate(states):
    B[states[max(idx - 1, 0)], prev_state, "down"] = 1.0

# Preferences over observations: the outcome the agent seeks is "D".
C = Distribution({"observations": observations})
C["D"] = 1.0

# Prior over the initial state: all mass on "A".
D = Distribution({"states": states})
D["A"] = 1.0


Now we can use these A, B, C and D tensors to create an agent and infer states and actions.

In [3]:
agent = Agent([A], [B], [C], [D])
# C holds preferences over *observations*, so label its argmax with the
# observation names (they coincide with the state names in this toy model).
print(f"goal state: {observations[jnp.argmax(agent.C[0])]}")

# Infer the current state given a previous action and a new observation.
# The previous action gets its own name because `action` is rebound below to
# the action the agent samples.
prev_action = jnp.broadcast_to(jnp.array([1]), (1, 1))  # index 1 -> "down"
observation = jnp.broadcast_to(jnp.array([0]), (1, 1))  # index 0 -> "A"

# qs needs a time dimension for update_empirical_prior, so expand dims of D
qs_init = jtu.tree_map(lambda x: jnp.expand_dims(x, 0), agent.D)
prior, _ = agent.update_empirical_prior(prev_action, qs_init)
qs = agent.infer_states([observation], prior)
print(f"initial state: {states[jnp.argmax(qs[0])]}")

q_pi, G = agent.infer_policies(qs)
action = agent.sample_action(q_pi)
print(f"action taken: {controls[action[0][0]]}")

goal state: D
initial state: A
action taken: up


### Using configs
Alternatively, you can use a model description to generate all the A and B
tensors, with the right shapes and named axes, in one go, and then fill in their entries.

In [5]:
model_description = {
    "observations": {
        "o1": {"elements": ["A", "B", "C", "D"], "depends_on": ["s1"]},
    },
    "controls": {"c1": {"elements": ["up", "down"]}},
    "states": {
        "s1": {"elements": ["A", "B", "C", "D"], "depends_on": ["s1"], "controlled_by": ["c1"]},
    },
}

model = compile_model(model_description)

# Same identity likelihood and up/down dynamics as the hand-built example,
# now addressed through the compiled model's modality / factor keys.
model.A["o1"]["A", "A"] = 1.0
model.A["o1"]["B", "B"] = 1.0
model.A["o1"]["C", "C"] = 1.0
model.A["o1"]["D", "D"] = 1.0

model.B["s1"]["B", "A", "up"] = 1.0
model.B["s1"]["C", "B", "up"] = 1.0
model.B["s1"]["D", "C", "up"] = 1.0
model.B["s1"]["D", "D", "up"] = 1.0

model.B["s1"]["A", "A", "down"] = 1.0
model.B["s1"]["A", "B", "down"] = 1.0
model.B["s1"]["B", "C", "down"] = 1.0
model.B["s1"]["C", "D", "down"] = 1.0

model.C["o1"]["D"] = 1.0
agent = Agent(**model, apply_batch=True)

# Take labels from this model's description instead of globals from earlier
# cells, so the printout always matches the model being run.
o1_labels = model_description["observations"]["o1"]["elements"]
s1_labels = model_description["states"]["s1"]["elements"]
c1_labels = model_description["controls"]["c1"]["elements"]

# C is a preference over observations, so label it with observation names.
print(f"goal state: {o1_labels[jnp.argmax(agent.C[0])]}")

# Re-declare the inputs explicitly. The previous example rebinds `action` to
# the *sampled* action, so reusing that name here would silently change the
# previous action this example conditions on.
prev_action = jnp.broadcast_to(jnp.array([1]), (1, 1))  # index 1 -> "down"
observation = jnp.broadcast_to(jnp.array([0]), (1, 1))  # index 0 -> "A"

# `qs_init` is the time-expanded initial belief built in the previous example
# (all prior mass on state "A").
prior, _ = agent.update_empirical_prior(prev_action, qs_init)
qs = agent.infer_states([observation], prior)
print(f"initial state: {s1_labels[jnp.argmax(qs[0])]}")

q_pi, G = agent.infer_policies(qs)
action = agent.sample_action(q_pi)
print(f"action taken: {c1_labels[action[0][0]]}")

goal state: D
initial state: A


action taken: up


In [6]:
model_description = {
    "observations": {
        "temperature": {"elements": ["low", "medium", "high", "very high"], "depends_on": ["operating_state"]},
        "humidity": {"elements": ["low", "medium", "high", "very high"], "depends_on": ["maintenance_state"]},
        "pressure": {"elements": ["low", "medium", "high", "very high"], "depends_on": ["power_state"]},
        "vibration": {
            "elements": ["none", "low", "medium", "high"],
            "depends_on": ["operating_state", "maintenance_state"],
        },
    },
    "controls": {
        "temperature_control": {"elements": ["off", "low", "medium", "high"]},
        "humidity_control": {"elements": ["off", "low", "medium", "high"]},
        "pressure_control": {"elements": ["off", "low", "medium", "high"]},
    },
    "states": {
        "operating_state": {
            "elements": ["idle", "running", "overload"],
            "depends_on": ["operating_state"],
            "controlled_by": ["temperature_control"],
        },
        "maintenance_state": {
            "elements": ["regular", "alert", "critical"],
            "depends_on": ["maintenance_state"],
            "controlled_by": ["humidity_control"],
        },
        "power_state": {
            "elements": ["low", "normal", "high"],
            "depends_on": ["power_state"],
            "controlled_by": ["pressure_control"],
        },
    },
}

model = compile_model(model_description)

# Likelihoods A[observation, state(s)]: one observation per state column.
model.A["temperature"]["low", "idle"] = 1.0
model.A["temperature"]["medium", "running"] = 1.0
model.A["temperature"]["low", "overload"] = 1.0

model.A["humidity"]["low", "regular"] = 1.0
model.A["humidity"]["low", "alert"] = 1.0
model.A["humidity"]["high", "critical"] = 1.0

model.A["pressure"]["low", "low"] = 1.0
# "normal" power maps to "medium" pressure. This was previously written as
# ["medium", "low"], which put two entries in the "low" column and left the
# "normal" column empty.
model.A["pressure"]["medium", "normal"] = 1.0
model.A["pressure"]["high", "high"] = 1.0

# NOTE(review): only 4 of the 3 x 3 (operating_state, maintenance_state)
# combinations are given a vibration outcome. The remaining columns stay at
# whatever compile_model initialised them to; confirm that is intended.
model.A["vibration"]["low", "idle", "regular"] = 1.0
model.A["vibration"]["medium", "running", "regular"] = 1.0
model.A["vibration"]["high", "running", "critical"] = 1.0
model.A["vibration"]["high", "overload", "alert"] = 1.0

# Transitions B[next_state, previous_state, control].
# NOTE(review): several (previous_state, control) columns are never set, and in
# each factor below one "off" column receives two 1.0 entries. Confirm whether
# those should be split (e.g. 0.5 / 0.5) or normalised downstream.
model.B["operating_state"]["overload", "running", "medium"] = 1.0
model.B["operating_state"]["overload", "overload", "high"] = 1.0
model.B["operating_state"]["idle", "idle", "off"] = 1.0
# NOTE(review): shares the ("running", "off") column with the last entry below.
model.B["operating_state"]["idle", "running", "off"] = 1.0
model.B["operating_state"]["running", "idle", "low"] = 1.0
model.B["operating_state"]["running", "overload", "off"] = 1.0
model.B["operating_state"]["running", "running", "off"] = 1.0

model.B["maintenance_state"]["alert", "regular", "low"] = 1.0
model.B["maintenance_state"]["alert", "critical", "off"] = 1.0
# NOTE(review): shares the ("alert", "off") column with ["regular", "alert", "off"].
model.B["maintenance_state"]["alert", "alert", "off"] = 1.0
model.B["maintenance_state"]["critical", "alert", "medium"] = 1.0
model.B["maintenance_state"]["critical", "critical", "high"] = 1.0
model.B["maintenance_state"]["regular", "regular", "off"] = 1.0
model.B["maintenance_state"]["regular", "alert", "off"] = 1.0

model.B["power_state"]["low", "low", "off"] = 1.0
# NOTE(review): shares the ("normal", "off") column with ["normal", "normal", "off"].
model.B["power_state"]["low", "normal", "off"] = 1.0
model.B["power_state"]["normal", "high", "off"] = 1.0
model.B["power_state"]["normal", "normal", "off"] = 1.0
model.B["power_state"]["normal", "low", "low"] = 1.0
model.B["power_state"]["high", "normal", "medium"] = 1.0
model.B["power_state"]["high", "high", "high"] = 1.0

agent = Agent(**model, apply_batch=True)