# Complex action dependencies

In this notebook, we show some examples of how to specify and run agents with complex action dependencies. Complex action dependencies refer to situations where a state variable depends on multiple actions or on no action at all. In these cases the state transition tensors have shapes of the form: `[state_dim, *prev_state_dims, *prev_action_dims]`.

The general strategy for dealing with this is to flatten the `prev_action_dims` when initializing the agent, so that the new B tensors have shape `[state_dim, *prev_state_dims, math.prod(prev_action_dims)]`. If a state has no action dependency, the new B tensor has shape `[state_dim, *prev_state_dims, 1]`, where the trailing 1 stands for a dummy action. All computations are done on the flattened B tensors, and actions are sampled in the flattened action dimensions. After a flattened action is sampled, it can be converted back to the original action dimensions by calling `agent.decode_multi_actions`. Conversely, to flatten multi-dimensional actions (for example, from collected data), call `agent.encode_multi_actions`.

In [1]:
from pprint import pprint
import itertools
import numpy as np
from jax import numpy as jnp
from jax import tree_util as jtu

from pymdp.jax.agent import Agent
from pymdp.jax import distribution

## Multiple action dependencies
In this example, some states depend on multiple actions. 

In [2]:
# Model specification: one observation modality, two control factors and two
# state factors. "s1" is controlled jointly by "c1" and "c2"; "s2" by "c1" only.
model = {
    "observations": {
        "o1": {"elements": ["A", "B", "C", "D"], "depends_on": ["s1"]},
    },
    "controls": {
        "c1": {"elements": ["up", "down"]},
        "c2": {"elements": ["left", "right", "stay"]},
    },
    "states": {
        "s1": {
            "elements": ["A", "B", "C", "D"],
            "depends_on": ["s1"],
            "controlled_by": ["c1", "c2"],
        },
        "s2": {
            "elements": ["A", "B", "C", "D"],
            "depends_on": ["s2"],
            "controlled_by": ["c1"],
        },
    },
}

# Position of every control factor, in declaration order.
control_names = list(model["controls"])

# For each state factor, the indices of the control factors that drive it.
B_action_dependencies = [
    [control_names.index(name) for name in spec["controlled_by"]]
    for spec in model["states"].values()
]

# Number of available actions per control factor.
num_controls = [len(model["controls"][name]["elements"]) for name in control_names]

# Compile the model spec into label-indexable A and B tensors.
As, Bs = distribution.compile_model(model)

# initialize tensor values
# A: pair every observation label with the identically named state label.
for label in ["A", "B", "C", "D"]:
    As[0][label, label] = 1.0

# B: (next_state, prev_state) pairs forming the chain A -> B -> C -> D, with D
# mapping onto itself. The same transitions are used for every action.
transitions = [("B", "A"), ("C", "B"), ("D", "C"), ("D", "D")]

for i, state in enumerate(model["states"].keys()):
    # Every combination of action labels over the controls driving this state.
    controls = list(itertools.product(*[
        model["controls"][c]["elements"] for c in model["states"][state]["controlled_by"]
    ]))
    for control in controls:
        for next_state, prev_state in transitions:
            # Build the index as a parenthesised tuple: star-unpacking directly
            # inside a subscript (`x[*a, *b]`) is a SyntaxError before Python 3.11.
            Bs[i][(next_state, prev_state, *control)] = 1.0

# Build the agent; it receives the per-state control indices and the sizes of
# the original (unflattened) control factors.
agent = Agent(
    As, Bs,
    B_action_dependencies=B_action_dependencies,
    num_controls=num_controls,
)

# dummy history
# Use a seeded generator so the randomly drawn policy and observation (and all
# results printed from them) are reproducible under Restart Kernel -> Run All.
rng = np.random.default_rng(0)
action = agent.policies[int(rng.integers(0, len(agent.policies)))]
observation = [rng.integers(0, d, size=(1, 1)) for d in agent.num_obs]
# Add a leading axis to each element of the agent's D.
qs_hist = jtu.tree_map(lambda x: jnp.expand_dims(x, 0), agent.D)

# One perception step: prior from the previous action, then posterior over states.
prior, _ = agent.infer_empirical_prior(action, qs_hist)
qs = agent.infer_states(observation, None, prior, None)

# One action step: the action is sampled in the flattened action space, decoded
# to the original control dimensions, and re-encoded as a round-trip check.
q_pi, G = agent.infer_policies(qs)
action = agent.sample_action(q_pi)
action_multi = agent.decode_multi_actions(action)
action_reconstruct = agent.encode_multi_actions(action_multi)

# Dependency structure, control sizes and B shapes before/after flattening,
# plus a check that both sets of B tensors are normalized.
structure_report = [
    ("A_dependencies", agent.A_dependencies),
    ("B_dependencies", agent.B_dependencies),
    ("B_action_dependencies", agent.B_action_dependencies),
    ("original control dims", agent.num_controls_multi),
    ("flattened control dims", agent.num_controls),
    ("original B shapes", [b.data.shape for b in Bs]),
    ("flattened B shapes", [b.shape for b in agent.B]),
    ("B normalized", [jnp.isclose(b.data.sum(0), 1.).all() for b in Bs]),
    ("B flat normalized", [jnp.isclose(b.sum(1), 1.).all() for b in agent.B]),
]
for caption, payload in structure_report:
    print(caption, payload)

print("\n")

# Beliefs (rounded for display) and the sampled action in its three encodings.
inference_report = [
    ("prior", [p.round(2) for p in prior]),
    ("post", [p.round(2) for p in qs]),
    ("action", action),
    ("action_multi", action_multi),
    ("action_reconstruct", action_reconstruct),
]
for caption, payload in inference_report:
    print(caption)
    pprint(payload)

A_dependencies [[0]]
B_dependencies [[0], [1]]
B_action_dependencies [[0, 1], [0]]
original control dims [2, 3]
flattened control dims [6, 2]
original B shapes [(4, 4, 2, 3), (4, 4, 2)]
flattened B shapes [(1, 4, 4, 6), (1, 4, 4, 2)]
B normalized [Array(True, dtype=bool), Array(True, dtype=bool)]
B flat normalized [Array(True, dtype=bool), Array(True, dtype=bool)]


prior
[Array([[0.  , 0.25, 0.25, 0.5 ]], dtype=float32),
 Array([[0.  , 0.25, 0.25, 0.5 ]], dtype=float32)]
post
[Array([[[0.5 , 0.12, 0.12, 0.25]]], dtype=float32),
 Array([[[0.  , 0.25, 0.25, 0.5 ]]], dtype=float32)]
action
Array([[0, 0]], dtype=int32)
action_multi
Array([[0, 0]], dtype=int32)
action_reconstruct
Array([[0, 0]], dtype=int32)


## No action dependency

In this example, some states do not depend on any action.

In [3]:
# Model specification: as before, "s1" is controlled jointly by "c1" and "c2",
# but "s2" now has an empty "controlled_by" list, i.e. no action dependency.
model = {
    "observations": {
        "o1": {"elements": ["A", "B", "C", "D"], "depends_on": ["s1"]},
    },
    "controls": {
        "c1": {"elements": ["up", "down"]},
        "c2": {"elements": ["left", "right", "stay"]},
    },
    "states": {
        "s1": {
            "elements": ["A", "B", "C", "D"],
            "depends_on": ["s1"],
            "controlled_by": ["c1", "c2"],
        },
        "s2": {
            "elements": ["A", "B", "C", "D"],
            "depends_on": ["s2"],
            "controlled_by": [],
        },
    },
}

# Position of every control factor, in declaration order.
control_names = list(model["controls"])

# For each state factor, the indices of the control factors that drive it
# (an empty list for a state that no control influences).
B_action_dependencies = [
    [control_names.index(name) for name in spec["controlled_by"]]
    for spec in model["states"].values()
]

# Number of available actions per control factor.
num_controls = [len(model["controls"][name]["elements"]) for name in control_names]

# Compile the model spec into label-indexable A and B tensors.
As, Bs = distribution.compile_model(model)

# initialize tensor values
# A: pair every observation label with the identically named state label.
for label in ["A", "B", "C", "D"]:
    As[0][label, label] = 1.0

# B: (next_state, prev_state) pairs forming the chain A -> B -> C -> D, with D
# mapping onto itself. The same transitions are used for every action.
transitions = [("B", "A"), ("C", "B"), ("D", "C"), ("D", "D")]

for i, state in enumerate(model["states"].keys()):
    # Every combination of action labels over the controls driving this state.
    # For a state with no controls, itertools.product() of zero iterables yields
    # a single empty tuple, so its transitions are still assigned exactly once.
    controls = list(itertools.product(*[
        model["controls"][c]["elements"] for c in model["states"][state]["controlled_by"]
    ]))
    for control in controls:
        for next_state, prev_state in transitions:
            # Build the index as a parenthesised tuple: star-unpacking directly
            # inside a subscript (`x[*a, *b]`) is a SyntaxError before Python 3.11.
            Bs[i][(next_state, prev_state, *control)] = 1.0

# Build the agent; it receives the per-state control indices (empty for the
# uncontrolled state) and the sizes of the original control factors.
agent = Agent(
    As, Bs,
    B_action_dependencies=B_action_dependencies,
    num_controls=num_controls,
)

# dummy history
# Use a seeded generator so the randomly drawn policy and observation (and all
# results printed from them) are reproducible under Restart Kernel -> Run All.
rng = np.random.default_rng(0)
action = agent.policies[int(rng.integers(0, len(agent.policies)))]
observation = [rng.integers(0, d, size=(1, 1)) for d in agent.num_obs]
# Add a leading axis to each element of the agent's D.
qs_hist = jtu.tree_map(lambda x: jnp.expand_dims(x, 0), agent.D)

# One perception step: prior from the previous action, then posterior over states.
prior, _ = agent.infer_empirical_prior(action, qs_hist)
qs = agent.infer_states(observation, None, prior, None)

# One action step: the action is sampled in the flattened action space, decoded
# to the original control dimensions, and re-encoded as a round-trip check.
q_pi, G = agent.infer_policies(qs)
action = agent.sample_action(q_pi)
action_multi = agent.decode_multi_actions(action)
action_reconstruct = agent.encode_multi_actions(action_multi)

# Dependency structure, control sizes and B shapes before/after flattening,
# plus a check that both sets of B tensors are normalized.
structure_report = [
    ("A_dependencies", agent.A_dependencies),
    ("B_dependencies", agent.B_dependencies),
    ("B_action_dependencies", agent.B_action_dependencies),
    ("original control dims", agent.num_controls_multi),
    ("flattened control dims", agent.num_controls),
    ("original B shapes", [b.data.shape for b in Bs]),
    ("flattened B shapes", [b.shape for b in agent.B]),
    ("B normalized", [jnp.isclose(b.data.sum(0), 1.).all() for b in Bs]),
    ("B flat normalized", [jnp.isclose(b.sum(1), 1.).all() for b in agent.B]),
]
for caption, payload in structure_report:
    print(caption, payload)

print("\n")

# Beliefs (rounded for display) and the sampled action in its three encodings.
inference_report = [
    ("prior", [p.round(2) for p in prior]),
    ("post", [p.round(2) for p in qs]),
    ("action", action),
    ("action_multi", action_multi),
    ("action_reconstruct", action_reconstruct),
]
for caption, payload in inference_report:
    print(caption)
    pprint(payload)

A_dependencies [[0]]
B_dependencies [[0], [1]]
B_action_dependencies [[0, 1], []]
original control dims [2, 3]
flattened control dims [6, 1]
original B shapes [(4, 4, 2, 3), (4, 4)]
flattened B shapes [(1, 4, 4, 6), (1, 4, 4, 1)]
B normalized [Array(True, dtype=bool), Array(True, dtype=bool)]
B flat normalized [Array(True, dtype=bool), Array(True, dtype=bool)]


prior
[Array([[0.  , 0.25, 0.25, 0.5 ]], dtype=float32),
 Array([[0.  , 0.25, 0.25, 0.5 ]], dtype=float32)]
post
[Array([[[0., 0., 1., 0.]]], dtype=float32),
 Array([[[0.  , 0.25, 0.25, 0.5 ]]], dtype=float32)]
action
Array([[0, 0]], dtype=int32)
action_multi
Array([[0, 0]], dtype=int32)
action_reconstruct
Array([[0, 0]], dtype=int32)
