# Sophisticated inference

This notebook demonstrates tree searching policies.

In [1]:
import jax.numpy as jnp
from jax import random as jr

key = jr.PRNGKey(0)

In [2]:
import networkx as nx
from pymdp.jax.envs import GraphEnv


def generate_connected_clusters(cluster_size=2, connections=2):
    edges = []
    connecting_node = 0
    while connecting_node < connections * cluster_size:
        edges += [(connecting_node, a) for a in range(connecting_node + 1, connecting_node + cluster_size + 1)]
        connecting_node = len(edges)
    graph = nx.Graph()
    graph.add_edges_from(edges)
    return graph, {
        "locations": [
            (f"hallway {i}" if len(list(graph.neighbors(loc))) > 1 else f"room {i}")
            for i, loc in enumerate(graph.nodes)
        ]
    }


graph, _ = generate_connected_clusters(cluster_size=3, connections=2)
env = GraphEnv(graph, object_locations=[3], agent_locations=[0])

In [3]:
from pymdp.jax.agent import Agent

A = [a.copy() for a in env.params["A"]]
B = [b.copy() for b in env.params["B"]]
A_dependencies = env.dependencies["A"]
B_dependencies = env.dependencies["B"]

C = [jnp.zeros(a.shape[:2]) for a in A]
C[1] = C[1].at[1].set(1.0)

D = [jnp.ones(b.shape[:2]) / b.shape[1] for b in B]

agent = Agent(A, B, C, D, None, None, None, A_dependencies=A_dependencies, B_dependencies=B_dependencies, policy_len=1)

In [4]:
keys = jr.split(key, 2)
key = keys[0]
obs, env = env.step(keys[1:])

In [5]:
print(obs)

[Array([[0]], dtype=int32), Array([[0]], dtype=int32)]


In [6]:
empirical_prior = agent.D

qs = agent.infer_states(
    observations=obs,
    past_actions=None,
    empirical_prior=empirical_prior,
    qs_hist=None,
)

In [7]:
print(qs)

[Array([[[1.0000000e+00, 1.2888965e-18, 1.2888965e-18, 1.2888965e-18,
         1.2888965e-18, 1.2888965e-18, 1.2888965e-18]]], dtype=float32), Array([[[3.1720716e-17, 1.4285715e-01, 1.4285715e-01, 1.4285715e-01,
         1.4285715e-01, 1.4285715e-01, 1.4285715e-01, 1.4285715e-01]]],      dtype=float32)]


In [8]:
print(qs[0].shape)

(1, 1, 7)


In [10]:
from pymdp.jax.planning import tree_search

tree_search(agent, qs, 3)

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (17 of them) had size 1, e.g. axis 0 of argument self.A[0] of type float32[1,7,7];
  * one axis had size 7: axis 0 of argument qs of type int32[7,1,2]