In [1]:
import json
import numpy as np
import jax.numpy as jnp
import jax.tree_util as jtu
import jax.experimental.sparse as jsparse
from jax import nn, vmap, jit, block_until_ready
from functools import partial

from pymdp.utils import init_agent_from_spec, get_sample_obs

# Hybrid
from pymdp.utils import apply_padding_batched
from pymdp.maths import compute_log_likelihoods_padded, deconstruct_lls

# Hybrid block
from pymdp.utils import build_block_diag_A, preprocess_A_for_block_diag, prepare_obs_for_block_diag, concatenate_observations_block_diag
from pymdp.maths import compute_log_likelihoods_block_diag, deconstruct_log_likelihoods_block_diag

from pymdp.algos import run_factorized_fpi_hybrid # For hybrid and hybrid block

# End2end padded
from pymdp.utils import apply_A_end2end_padding_batched, apply_obs_end2end_padding_batched
from pymdp.maths import compute_log_likelihood_per_modality_end2end2_padded
from pymdp.algos import run_factorized_fpi_end2end_padded

In [None]:
with open('random_agent_specs/agent_specs_latest_lighter_deps_filtered.json') as f:
    specs = json.load(f)

spec = specs['arbitrary dependencies'][0]
spec

{'num_factors': 5,
 'num_modalities': 5,
 'num_states': [4, 3, 4, 2, 4],
 'num_obs': [2, 2, 3, 4, 3],
 'A_dependencies': [[2], [3, 4], [0, 1, 2, 3], [0], [4]],
 'metadata': {'num_factors': 'low',
  'num_modalities': 'low',
  'state_dim_upper_limit': 'low',
  'obs_dim_upper_limit': 'low',
  'dim_sampling_type': 'uniform'}}

In [None]:
num_iter = 8
batch_size = 4
A_sparsity_level = None # E.g., 0.8 for 80% sparsity

A, D = init_agent_from_spec(
    spec['num_obs'],
    spec['num_states'],
    spec['A_dependencies'],
    A_sparsity_level=A_sparsity_level,
    batch_size=batch_size
)

obs = get_sample_obs(spec['num_obs'], batch_size=batch_size)
o_vec = [nn.one_hot(o, spec['num_obs'][m]) for m, o in enumerate(obs)]

# place where this happens is important!
# o_vec = jtu.tree_map(lambda x: x[-1], o_vec)



### Original method imported directly from PyMDP

In [4]:
from pymdp.inference import update_posterior_states

infer_states_orig_pymdp = vmap(
    partial(
        update_posterior_states,
        A_dependencies=spec['A_dependencies'],
        num_iter=num_iter,
        method='fpi'
    )
)

In [5]:
qs = infer_states_orig_pymdp(A, None, o_vec, None, D)
[q.shape for q in qs], qs

([(4, 1, 4), (4, 1, 3), (4, 1, 4), (4, 1, 2), (4, 1, 4)],
 [Array([[[0.453592  , 0.2891406 , 0.06482263, 0.19244476]],
  
         [[0.23495829, 0.08601432, 0.14110027, 0.5379271 ]],
  
         [[0.19091347, 0.13984956, 0.3721627 , 0.29707432]],
  
         [[0.3885364 , 0.22233072, 0.0110043 , 0.3781286 ]]], dtype=float32),
  Array([[[0.2603369 , 0.38066563, 0.35899743]],
  
         [[0.25421855, 0.38015884, 0.36562258]],
  
         [[0.41699758, 0.3009654 , 0.282037  ]],
  
         [[0.39347005, 0.37286237, 0.23366755]]], dtype=float32),
  Array([[[0.28605786, 0.13507262, 0.34174025, 0.23712924]],
  
         [[0.3640707 , 0.3019745 , 0.15490013, 0.17905463]],
  
         [[0.12120205, 0.14439411, 0.46568182, 0.2687221 ]],
  
         [[0.36508068, 0.22946823, 0.13614362, 0.26930743]]], dtype=float32),
  Array([[[0.46730146, 0.5326986 ]],
  
         [[0.5473364 , 0.45266357]],
  
         [[0.5823111 , 0.41768897]],
  
         [[0.54273903, 0.45726097]]], dtype=float32),
  Arra

### Hybrid method

In [6]:
def infer_states_hybrid(obs_padded, A_padded, D, A_shapes, A_dependencies, num_iter):
    lls_padded = compute_log_likelihoods_padded(obs_padded, A_padded)
    log_likelihoods = deconstruct_lls(lls_padded, A_shapes)
    return vmap(partial(run_factorized_fpi_hybrid, A_dependencies=A_dependencies, num_iter=num_iter))(log_likelihoods, D)

In [7]:
A_padded = apply_padding_batched(A)
A_shapes = [a.shape for a in A]

if A_sparsity_level is not None:
    A_padded = jsparse.BCOO.fromdense(A_padded, n_batch=1)

# obs preprocessing
obs_padded = apply_padding_batched(jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec))

In [8]:
qs = infer_states_hybrid(obs_padded, A_padded, D, A_shapes, A_dependencies=spec['A_dependencies'], num_iter=num_iter)
[q.shape for q in qs], qs

([(4, 1, 4), (4, 1, 3), (4, 1, 4), (4, 1, 2), (4, 1, 4)],
 [Array([[[0.453592  , 0.2891406 , 0.06482263, 0.19244476]],
  
         [[0.23495829, 0.08601432, 0.14110027, 0.5379271 ]],
  
         [[0.19091347, 0.13984956, 0.3721627 , 0.29707432]],
  
         [[0.3885364 , 0.22233072, 0.0110043 , 0.3781286 ]]], dtype=float32),
  Array([[[0.2603369 , 0.38066563, 0.35899743]],
  
         [[0.25421855, 0.38015884, 0.36562258]],
  
         [[0.41699758, 0.3009654 , 0.282037  ]],
  
         [[0.39347005, 0.37286237, 0.23366755]]], dtype=float32),
  Array([[[0.28605786, 0.13507262, 0.34174025, 0.23712924]],
  
         [[0.3640707 , 0.3019745 , 0.15490013, 0.17905463]],
  
         [[0.12120205, 0.14439411, 0.46568182, 0.2687221 ]],
  
         [[0.36508068, 0.22946823, 0.13614362, 0.26930743]]], dtype=float32),
  Array([[[0.46730146, 0.5326986 ]],
  
         [[0.5473364 , 0.45266357]],
  
         [[0.5823111 , 0.41768897]],
  
         [[0.54273903, 0.45726097]]], dtype=float32),
  Arra

In [9]:
# JIT
apply_padding_batched_jit = jit(partial(apply_padding_batched))
infer_states_hybrid_jit = jit(partial(infer_states_hybrid, A_shapes=A_shapes, A_dependencies=spec['A_dependencies'], num_iter=num_iter))
obs_padded = apply_padding_batched_jit(jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec))

qs = infer_states_hybrid_jit(obs_padded, A_padded, D)
[q.shape for q in qs], qs

([(4, 1, 4), (4, 1, 3), (4, 1, 4), (4, 1, 2), (4, 1, 4)],
 [Array([[[0.453592  , 0.2891406 , 0.06482263, 0.19244476]],
  
         [[0.23495829, 0.08601432, 0.14110027, 0.5379271 ]],
  
         [[0.19091347, 0.13984956, 0.3721627 , 0.29707432]],
  
         [[0.3885364 , 0.22233072, 0.0110043 , 0.3781286 ]]], dtype=float32),
  Array([[[0.2603369 , 0.38066563, 0.35899743]],
  
         [[0.25421855, 0.38015884, 0.36562258]],
  
         [[0.41699758, 0.3009654 , 0.282037  ]],
  
         [[0.39347005, 0.37286237, 0.23366755]]], dtype=float32),
  Array([[[0.28605786, 0.13507262, 0.34174025, 0.23712924]],
  
         [[0.3640707 , 0.3019745 , 0.15490013, 0.17905463]],
  
         [[0.12120205, 0.14439411, 0.46568182, 0.2687221 ]],
  
         [[0.36508068, 0.22946823, 0.13614362, 0.26930743]]], dtype=float32),
  Array([[[0.46730146, 0.5326986 ]],
  
         [[0.5473364 , 0.45266357]],
  
         [[0.5823111 , 0.41768897]],
  
         [[0.54273903, 0.45726097]]], dtype=float32),
  Arra

### Hybrid Block method

In [10]:
# Infer states hybrid block
def infer_states_hybrid_block(obs, A_big, D, state_shapes, cuts, A_dependencies, num_iter, use_einsum=False):
    """Hybrid inference using block diagonal approach for log-likelihood computation."""
    log_likelihoods = compute_log_likelihoods_block_diag(A_big, obs, state_shapes, cuts, use_einsum=use_einsum)
    return vmap(partial(run_factorized_fpi_hybrid, A_dependencies=A_dependencies, num_iter=num_iter))(log_likelihoods, D)

In [11]:
A = [jnp.moveaxis(a, 1, -1) for a in A]
# Preprocess A matrices for block diagonal approach
A_big, state_shapes, cuts = preprocess_A_for_block_diag(A)

if A_sparsity_level is not None:
    A_big = jsparse.BCOO.fromdense(A_big, n_batch=1)

obs_tmp = jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec)
obs_big = concatenate_observations_block_diag(obs_tmp)

In [12]:
qs = infer_states_hybrid_block(obs_big, A_big, D, 
    state_shapes=state_shapes, cuts=cuts, A_dependencies=spec['A_dependencies'], 
    num_iter=num_iter, use_einsum=False
)

[q.shape for q in qs], qs

([(4, 1, 4), (4, 1, 3), (4, 1, 4), (4, 1, 2), (4, 1, 4)],
 [Array([[[0.453592  , 0.2891406 , 0.06482263, 0.19244476]],
  
         [[0.23495829, 0.08601432, 0.14110027, 0.5379271 ]],
  
         [[0.19091347, 0.13984956, 0.3721627 , 0.29707432]],
  
         [[0.3885364 , 0.22233072, 0.0110043 , 0.3781286 ]]], dtype=float32),
  Array([[[0.2603369 , 0.38066563, 0.35899743]],
  
         [[0.25421855, 0.38015884, 0.36562258]],
  
         [[0.41699758, 0.3009654 , 0.282037  ]],
  
         [[0.39347005, 0.37286237, 0.23366755]]], dtype=float32),
  Array([[[0.28605786, 0.13507262, 0.34174025, 0.23712924]],
  
         [[0.3640707 , 0.3019745 , 0.15490013, 0.17905463]],
  
         [[0.12120205, 0.14439411, 0.46568182, 0.2687221 ]],
  
         [[0.36508068, 0.22946823, 0.13614362, 0.26930743]]], dtype=float32),
  Array([[[0.46730146, 0.5326986 ]],
  
         [[0.5473364 , 0.45266357]],
  
         [[0.5823111 , 0.41768897]],
  
         [[0.54273903, 0.45726097]]], dtype=float32),
  Arra

In [13]:
# JIT
use_einsum=False
concatenate_observations_block_diag_jit = jit(partial(concatenate_observations_block_diag)) # just for obs
infer_states_hybrid_block_jit = jit(partial(infer_states_hybrid_block, state_shapes=state_shapes, cuts=cuts, A_dependencies=spec['A_dependencies'], num_iter=num_iter, use_einsum=use_einsum))
obs_big = concatenate_observations_block_diag_jit(obs_tmp) # add padding of obs before running the infer states

qs = infer_states_hybrid_block_jit(obs_big, A_big, D)
[q.shape for q in qs], qs

([(4, 1, 4), (4, 1, 3), (4, 1, 4), (4, 1, 2), (4, 1, 4)],
 [Array([[[0.453592  , 0.2891406 , 0.06482263, 0.19244476]],
  
         [[0.23495829, 0.08601432, 0.14110027, 0.5379271 ]],
  
         [[0.19091347, 0.13984956, 0.3721627 , 0.29707432]],
  
         [[0.3885364 , 0.22233072, 0.0110043 , 0.3781286 ]]], dtype=float32),
  Array([[[0.2603369 , 0.38066563, 0.35899743]],
  
         [[0.25421855, 0.38015884, 0.36562258]],
  
         [[0.41699758, 0.3009654 , 0.282037  ]],
  
         [[0.39347005, 0.37286237, 0.23366755]]], dtype=float32),
  Array([[[0.28605786, 0.13507262, 0.34174025, 0.23712924]],
  
         [[0.3640707 , 0.3019745 , 0.15490013, 0.17905463]],
  
         [[0.12120205, 0.14439411, 0.46568182, 0.2687221 ]],
  
         [[0.36508068, 0.22946823, 0.13614362, 0.26930743]]], dtype=float32),
  Array([[[0.46730146, 0.5326986 ]],
  
         [[0.5473364 , 0.45266357]],
  
         [[0.5823111 , 0.41768897]],
  
         [[0.54273903, 0.45726097]]], dtype=float32),
  Arra

### End2End padded method

In [14]:
def infer_states_end2end_padded(A_padded, obs_padded, D, A_dependencies, max_obs_dim, max_state_dim, num_iter, sparsity=None):
    lls_padded = compute_log_likelihood_per_modality_end2end2_padded(obs_padded, A_padded, sparsity=sparsity)
    return run_factorized_fpi_end2end_padded(lls_padded, D, A_dependencies, max_obs_dim, max_state_dim, num_iter)

In [15]:
A_padded = apply_A_end2end_padding_batched(A)

if A_sparsity_level is not None:
    A_padded = jsparse.BCOO.fromdense(A_padded)

max_obs_dim = A_padded.shape[2]
max_state_dim = max(A_padded.shape[3:])

# obs preprocessing
obs_padded = apply_obs_end2end_padding_batched(jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec), max_obs_dim)

In [16]:
qs = infer_states_end2end_padded(A_padded, obs_padded, D, spec['A_dependencies'], max_obs_dim, max_state_dim, num_iter, sparsity='ll_only')

[q.shape for q in qs], qs

([(4, 1, 4), (4, 1, 3), (4, 1, 4), (4, 1, 2), (4, 1, 4)],
 [Array([[[3.27321410e-01, 5.66468358e-01, 1.06210284e-01, 3.18200288e-16]],
  
         [[9.80049148e-02, 2.49357328e-01, 6.52637780e-01, 8.94657057e-16]],
  
         [[1.39479563e-01, 7.13981926e-01, 1.46538526e-01, 1.27322260e-16]],
  
         [[5.90120912e-01, 3.80194724e-01, 2.96843629e-02, 2.64849394e-15]]],      dtype=float32),
  Array([[[0.3094568 , 0.3524493 , 0.33809394]],
  
         [[0.37040284, 0.4546238 , 0.17497338]],
  
         [[0.29984298, 0.28486982, 0.41528726]],
  
         [[0.3355383 , 0.39806393, 0.26639774]]], dtype=float32),
  Array([[[6.3361436e-01, 3.6638564e-01, 1.7301976e-31, 1.7301976e-31]],
  
         [[3.7444645e-01, 6.2555355e-01, 1.8134088e-31, 1.8134088e-31]],
  
         [[2.3076543e-01, 7.6923460e-01, 1.5773009e-31, 1.5773009e-31]],
  
         [[6.2786764e-01, 3.7213239e-01, 1.7959364e-31, 1.7959364e-31]]],      dtype=float32),
  Array([[[0.4980732 , 0.50192684]],
  
         [[0.59721

In [18]:
# JIT
apply_obs_padding_batched_jit = jit(partial(apply_obs_end2end_padding_batched, max_obs_dim=max_obs_dim))
infer_states_partially_padded_jit = jit(partial(infer_states_end2end_padded, A_dependencies=spec['A_dependencies'], max_obs_dim=max_obs_dim, max_state_dim=max_state_dim, num_iter=num_iter, sparsity='ll_only'))
obs_padded = apply_obs_padding_batched_jit(jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec))

qs = infer_states_partially_padded_jit(A_padded, obs_padded, D)
[q.shape for q in qs], qs

([(4, 1, 4), (4, 1, 3), (4, 1, 4), (4, 1, 2), (4, 1, 4)],
 [Array([[[3.27321410e-01, 5.66468358e-01, 1.06210284e-01, 3.18200288e-16]],
  
         [[9.80049148e-02, 2.49357328e-01, 6.52637780e-01, 8.94657057e-16]],
  
         [[1.39479563e-01, 7.13981926e-01, 1.46538526e-01, 1.27322260e-16]],
  
         [[5.90120912e-01, 3.80194724e-01, 2.96843629e-02, 2.64849394e-15]]],      dtype=float32),
  Array([[[0.3094568 , 0.3524493 , 0.33809394]],
  
         [[0.37040284, 0.4546238 , 0.17497338]],
  
         [[0.29984298, 0.28486982, 0.41528726]],
  
         [[0.3355383 , 0.39806393, 0.26639774]]], dtype=float32),
  Array([[[6.3361436e-01, 3.6638564e-01, 1.7301976e-31, 1.7301976e-31]],
  
         [[3.7444645e-01, 6.2555355e-01, 1.8134088e-31, 1.8134088e-31]],
  
         [[2.3076543e-01, 7.6923460e-01, 1.5773009e-31, 1.5773009e-31]],
  
         [[6.2786764e-01, 3.7213239e-01, 1.7959364e-31, 1.7959364e-31]]],      dtype=float32),
  Array([[[0.4980732 , 0.50192684]],
  
         [[0.59721