In [1]:
import numpy as np
import jax.numpy as jnp
import jax.tree_util as jtu
import jax.experimental.sparse as jsparse
from jax import nn, vmap, jit, block_until_ready
from functools import partial

from pymdp.utils import init_A_and_D_from_spec, get_sample_obs, generate_agent_specs_from_parameter_sets

# Hybrid
from pymdp.utils import apply_padding_batched
from pymdp.maths import compute_log_likelihoods_padded, deconstruct_lls

# Hybrid block
from pymdp.utils import build_block_diag_A, preprocess_A_for_block_diag, prepare_obs_for_block_diag, concatenate_observations_block_diag
from pymdp.maths import compute_log_likelihoods_block_diag, deconstruct_log_likelihoods_block_diag

from pymdp.algos import run_factorized_fpi_hybrid # For hybrid and hybrid block

# End2end padded
from pymdp.utils import apply_A_end2end_padding_batched, apply_obs_end2end_padding_batched
from pymdp.maths import compute_log_likelihood_per_modality_end2end2_padded
from pymdp.algos import run_factorized_fpi_end2end_padded

In [2]:
# Coordinated parameter sets: within a set, one size is shared by all four
# dimension settings, so the tuples are generated from (size, label) pairs.
# Tuple layout:
# (num_factors, num_modalities, state_dim_upper_limit, obs_dim_upper_limit, dim_sampling_type, label)
parameter_sets = [
    (size, size, size, size, 'uniform', label)
    for size, label in (
        (5, 'low'),
        (10, 'medium'),
        (25, 'high'),
        # (125, 'extreme'),  # Uncomment to include extreme cases
    )
]

# Build the agent specs in memory only; no file is written.
specs = generate_agent_specs_from_parameter_sets(
    parameter_sets,
    num_agents_per_set=1,
    output_file=None,
)

# Index 1 selects the second spec of the 'arbitrary dependencies' group; the
# metadata displayed below identifies it as the 'medium' parameter set.
spec = specs['arbitrary dependencies'][1]
spec

{'num_factors': 10,
 'num_modalities': 10,
 'num_states': [3, 6, 2, 9, 7, 4, 7, 5, 7, 5],
 'num_obs': [4, 2, 4, 9, 7, 8, 4, 5, 7, 4],
 'A_dependencies': [[1],
  [4],
  [5, 6],
  [0, 8],
  [0, 2, 7],
  [3, 5],
  [6],
  [2, 4, 7, 9],
  [0, 2, 7, 8],
  [7]],
 'metadata': {'num_factors': 'medium',
  'num_modalities': 'medium',
  'state_dim_upper_limit': 'medium',
  'obs_dim_upper_limit': 'medium',
  'dim_sampling_type': 'uniform'}}

In [3]:
# Experiment configuration shared by every inference variant in this notebook.
num_iter = 8             # forwarded as `num_iter` to each inference routine
batch_size = 4           # forwarded to the A/D initialiser and the observation sampler
A_sparsity_level = None  # E.g., 0.8 for 80% sparsity

# Initialise A and D for the selected spec.
A, D = init_A_and_D_from_spec(
    spec['num_obs'],
    spec['num_states'],
    spec['A_dependencies'],
    A_sparsity_level=A_sparsity_level,
    batch_size=batch_size,
)

# Sample observations, then one-hot encode each modality against that
# modality's own number of outcomes.
obs = get_sample_obs(spec['num_obs'], batch_size=batch_size)
o_vec = [
    nn.one_hot(obs_m, num_obs_m)
    for obs_m, num_obs_m in zip(obs, spec['num_obs'])
]

# place where this happens is important!
# o_vec = jtu.tree_map(lambda x: x[-1], o_vec)



### Original method imported directly from PyMDP

In [4]:
from pymdp.inference import update_posterior_states

# Reference routine: bind the settings that are shared by the whole batch
# first, then vectorise the resulting single-sample function with vmap.
_reference_fpi = partial(
    update_posterior_states,
    A_dependencies=spec['A_dependencies'],
    num_iter=num_iter,
    method='fpi',
)
infer_states_orig_pymdp = vmap(_reference_fpi)

In [5]:
# Run the batched reference routine.
# NOTE(review): the two positional None slots are presumably the transition
# model (B) and the past actions, neither of which is used for this single
# inference step — confirm against update_posterior_states' signature.
qs = infer_states_orig_pymdp(A, None, o_vec, None, D)

# Display the shape of every posterior first, then the posteriors themselves.
posterior_shapes = [q.shape for q in qs]
posterior_shapes, qs

([(4, 1, 3),
  (4, 1, 6),
  (4, 1, 2),
  (4, 1, 9),
  (4, 1, 7),
  (4, 1, 4),
  (4, 1, 7),
  (4, 1, 5),
  (4, 1, 7),
  (4, 1, 5)],
 [Array([[[0.0911239 , 0.70235616, 0.20651992]],
  
         [[0.20210403, 0.22807154, 0.56982446]],
  
         [[0.3876841 , 0.3879837 , 0.2243322 ]],
  
         [[0.3751645 , 0.27490953, 0.34992597]]], dtype=float32),
  Array([[[0.15824087, 0.183674  , 0.26761115, 0.09957846, 0.20257226,
           0.08832329]],
  
         [[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
           0.31559473]],
  
         [[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
           0.22736534]],
  
         [[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
           0.12064715]]], dtype=float32),
  Array([[[0.5558248 , 0.44417518]],
  
         [[0.549685  , 0.450315  ]],
  
         [[0.38312683, 0.6168732 ]],
  
         [[0.5388057 , 0.46119428]]], dtype=float32),
  Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16

### Hybrid method

In [6]:
def infer_states_hybrid(obs_padded, A_padded, D, A_shapes, A_dependencies, num_iter):
    """Hybrid inference: padded log-likelihood computation followed by batched FPI.

    Args:
        obs_padded: Padded observations, passed to `compute_log_likelihoods_padded`.
        A_padded: Padded A tensor, passed to `compute_log_likelihoods_padded`.
        D: Second argument of the vmapped `run_factorized_fpi_hybrid` call
            (mapped over its leading axis).
        A_shapes: Original shapes of the A tensors, used by `deconstruct_lls`
            to unpack the padded log-likelihoods.
        A_dependencies: Forwarded unchanged to `run_factorized_fpi_hybrid`.
        num_iter: Forwarded unchanged to `run_factorized_fpi_hybrid`.

    Returns:
        Whatever the vmapped `run_factorized_fpi_hybrid` returns for the batch.
    """
    # Log-likelihoods are computed once on the padded tensors and then unpacked.
    unpacked_lls = deconstruct_lls(
        compute_log_likelihoods_padded(obs_padded, A_padded),
        A_shapes,
    )

    # A_dependencies and num_iter are shared by the whole batch, so they are
    # bound before vmap; only the log-likelihoods and D are mapped.
    fpi_single = partial(
        run_factorized_fpi_hybrid,
        A_dependencies=A_dependencies,
        num_iter=num_iter,
    )
    batched_fpi = vmap(fpi_single)
    return batched_fpi(unpacked_lls, D)

In [7]:
# Record the original shapes of the A tensors; infer_states_hybrid needs them
# to unpack the padded log-likelihoods again.
A_shapes = [A_m.shape for A_m in A]
A_padded = apply_padding_batched(A)

if A_sparsity_level is not None:
    # Sparse (BCOO) storage, keeping the leading dimension as a batch dimension.
    A_padded = jsparse.BCOO.fromdense(A_padded, n_batch=1)

# obs preprocessing: drop the singleton axis 1 of every one-hot observation,
# then pad them with the same helper that was used for A.
obs_padded = apply_padding_batched(
    jtu.tree_map(partial(jnp.squeeze, axis=1), o_vec)
)

In [8]:
# Non-jitted run of the hybrid method.
qs = infer_states_hybrid(
    obs_padded,
    A_padded,
    D,
    A_shapes=A_shapes,
    A_dependencies=spec['A_dependencies'],
    num_iter=num_iter,
)

# Display the shape of every posterior first, then the posteriors themselves.
posterior_shapes = [q.shape for q in qs]
posterior_shapes, qs

([(4, 1, 3),
  (4, 1, 6),
  (4, 1, 2),
  (4, 1, 9),
  (4, 1, 7),
  (4, 1, 4),
  (4, 1, 7),
  (4, 1, 5),
  (4, 1, 7),
  (4, 1, 5)],
 [Array([[[0.0911239 , 0.70235616, 0.20651992]],
  
         [[0.20210403, 0.22807154, 0.56982446]],
  
         [[0.3876841 , 0.3879837 , 0.2243322 ]],
  
         [[0.3751645 , 0.27490953, 0.34992597]]], dtype=float32),
  Array([[[0.15824087, 0.183674  , 0.26761115, 0.09957846, 0.20257226,
           0.08832329]],
  
         [[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
           0.31559473]],
  
         [[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
           0.22736534]],
  
         [[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
           0.12064715]]], dtype=float32),
  Array([[[0.5558248 , 0.44417518]],
  
         [[0.549685  , 0.450315  ]],
  
         [[0.38312683, 0.6168732 ]],
  
         [[0.5388057 , 0.46119428]]], dtype=float32),
  Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16

In [9]:
# JIT-compiled variant of the hybrid method.
# `apply_padding_batched` needs no bound arguments, so it is handed to `jit`
# directly; an argument-less `partial` adds nothing and only creates a fresh
# callable object every time this cell is re-run.
apply_padding_batched_jit = jit(apply_padding_batched)

# A_shapes, A_dependencies and num_iter are bound up front, leaving
# (obs_padded, A_padded, D) as the only arguments of the compiled function.
infer_states_hybrid_jit = jit(
    partial(
        infer_states_hybrid,
        A_shapes=A_shapes,
        A_dependencies=spec['A_dependencies'],
        num_iter=num_iter,
    )
)

# Observations are squeezed (singleton axis 1) and padded before inference.
obs_padded = apply_padding_batched_jit(jtu.tree_map(lambda x: jnp.squeeze(x, 1), o_vec))

qs = infer_states_hybrid_jit(obs_padded, A_padded, D)
[q.shape for q in qs], qs

([(4, 1, 3),
  (4, 1, 6),
  (4, 1, 2),
  (4, 1, 9),
  (4, 1, 7),
  (4, 1, 4),
  (4, 1, 7),
  (4, 1, 5),
  (4, 1, 7),
  (4, 1, 5)],
 [Array([[[0.0911239 , 0.70235616, 0.20651992]],
  
         [[0.20210403, 0.22807154, 0.56982446]],
  
         [[0.3876841 , 0.3879837 , 0.2243322 ]],
  
         [[0.3751645 , 0.27490953, 0.34992597]]], dtype=float32),
  Array([[[0.15824087, 0.183674  , 0.26761115, 0.09957846, 0.20257226,
           0.08832329]],
  
         [[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
           0.31559473]],
  
         [[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
           0.22736534]],
  
         [[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
           0.12064715]]], dtype=float32),
  Array([[[0.5558248 , 0.44417518]],
  
         [[0.549685  , 0.450315  ]],
  
         [[0.38312683, 0.6168732 ]],
  
         [[0.5388057 , 0.46119428]]], dtype=float32),
  Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16

### Hybrid Block method

In [10]:
# Hybrid block inference
def infer_states_hybrid_block(obs, A_big, D, state_shapes, cuts, A_dependencies, num_iter, use_einsum=False):
    """Hybrid inference using the block diagonal approach for the log-likelihoods.

    Args:
        obs: Observation input for `compute_log_likelihoods_block_diag`.
        A_big: Block diagonal A for `compute_log_likelihoods_block_diag`.
        D: Second argument of the vmapped `run_factorized_fpi_hybrid` call
            (mapped over its leading axis).
        state_shapes: Forwarded to `compute_log_likelihoods_block_diag`.
        cuts: Forwarded to `compute_log_likelihoods_block_diag`.
        A_dependencies: Forwarded unchanged to `run_factorized_fpi_hybrid`.
        num_iter: Forwarded unchanged to `run_factorized_fpi_hybrid`.
        use_einsum: Forwarded to `compute_log_likelihoods_block_diag`.

    Returns:
        Whatever the vmapped `run_factorized_fpi_hybrid` returns for the batch.
    """
    # Note the argument order: the helper takes A first and the observations second.
    block_lls = compute_log_likelihoods_block_diag(
        A_big, obs, state_shapes, cuts, use_einsum=use_einsum
    )

    # Shared settings are bound before vmap; only block_lls and D are mapped.
    fpi_single = partial(
        run_factorized_fpi_hybrid,
        A_dependencies=A_dependencies,
        num_iter=num_iter,
    )
    batched_fpi = vmap(fpi_single)
    return batched_fpi(block_lls, D)

In [11]:
# The block diagonal preprocessing works on a re-ordered copy of A: axis 1 of
# every tensor is moved to the last position, and A itself is left untouched.
A_moveaxis = [jnp.moveaxis(A_m, source=1, destination=-1) for A_m in A]

# Preprocessing returns the big A together with the state shapes and cuts that
# the inference call needs later.
A_big, state_shapes, cuts = preprocess_A_for_block_diag(A_moveaxis)

if A_sparsity_level is not None:
    # Sparse (BCOO) storage, keeping the leading dimension as a batch dimension.
    A_big = jsparse.BCOO.fromdense(A_big, n_batch=1)

# Drop the singleton axis 1 of every one-hot observation, then concatenate.
obs_tmp = jtu.tree_map(partial(jnp.squeeze, axis=1), o_vec)
obs_big = concatenate_observations_block_diag(obs_tmp)

In [12]:
# Non-jitted run of the hybrid block method (einsum path disabled).
qs = infer_states_hybrid_block(
    obs_big,
    A_big,
    D,
    state_shapes=state_shapes,
    cuts=cuts,
    A_dependencies=spec['A_dependencies'],
    num_iter=num_iter,
    use_einsum=False,
)

# Display the shape of every posterior first, then the posteriors themselves.
posterior_shapes = [q.shape for q in qs]
posterior_shapes, qs

([(4, 1, 3),
  (4, 1, 6),
  (4, 1, 2),
  (4, 1, 9),
  (4, 1, 7),
  (4, 1, 4),
  (4, 1, 7),
  (4, 1, 5),
  (4, 1, 7),
  (4, 1, 5)],
 [Array([[[0.0911239 , 0.70235616, 0.20651992]],
  
         [[0.20210403, 0.22807154, 0.56982446]],
  
         [[0.3876841 , 0.3879837 , 0.2243322 ]],
  
         [[0.3751645 , 0.27490953, 0.34992597]]], dtype=float32),
  Array([[[0.15824087, 0.183674  , 0.26761115, 0.09957846, 0.20257226,
           0.08832329]],
  
         [[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
           0.31559473]],
  
         [[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
           0.22736534]],
  
         [[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
           0.12064715]]], dtype=float32),
  Array([[[0.5558248 , 0.44417518]],
  
         [[0.549685  , 0.450315  ]],
  
         [[0.38312683, 0.6168732 ]],
  
         [[0.5388057 , 0.46119428]]], dtype=float32),
  Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16

In [13]:
# JIT-compiled variant of the hybrid block method.
use_einsum = False

# No arguments need binding for the observation helper, so `jit` wraps it
# directly; an argument-less `partial` adds nothing and only creates a fresh
# callable object every time this cell is re-run.
concatenate_observations_block_diag_jit = jit(concatenate_observations_block_diag)  # just for obs

# state_shapes, cuts, A_dependencies, num_iter and use_einsum are bound up
# front, leaving (obs_big, A_big, D) as the only arguments of the compiled
# function.
infer_states_hybrid_block_jit = jit(
    partial(
        infer_states_hybrid_block,
        state_shapes=state_shapes,
        cuts=cuts,
        A_dependencies=spec['A_dependencies'],
        num_iter=num_iter,
        use_einsum=use_einsum,
    )
)

# Concatenate the (already squeezed) observations before running inference.
obs_big = concatenate_observations_block_diag_jit(obs_tmp)

qs = infer_states_hybrid_block_jit(obs_big, A_big, D)
[q.shape for q in qs], qs

([(4, 1, 3),
  (4, 1, 6),
  (4, 1, 2),
  (4, 1, 9),
  (4, 1, 7),
  (4, 1, 4),
  (4, 1, 7),
  (4, 1, 5),
  (4, 1, 7),
  (4, 1, 5)],
 [Array([[[0.0911239 , 0.70235616, 0.20651992]],
  
         [[0.20210403, 0.22807154, 0.56982446]],
  
         [[0.3876841 , 0.3879837 , 0.2243322 ]],
  
         [[0.3751645 , 0.27490953, 0.34992597]]], dtype=float32),
  Array([[[0.15824087, 0.183674  , 0.26761115, 0.09957846, 0.20257226,
           0.08832329]],
  
         [[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
           0.31559473]],
  
         [[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
           0.22736534]],
  
         [[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
           0.12064715]]], dtype=float32),
  Array([[[0.5558248 , 0.44417518]],
  
         [[0.549685  , 0.450315  ]],
  
         [[0.38312683, 0.6168732 ]],
  
         [[0.5388057 , 0.46119428]]], dtype=float32),
  Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16

### End2End padded method

In [14]:
def infer_states_end2end_padded(A_padded, obs_padded, D, A_dependencies, max_obs_dim, max_state_dim, num_iter, sparsity=None):
    """End2end padded inference: padded log-likelihoods fed straight into padded FPI.

    Args:
        A_padded: Padded A tensor for the log-likelihood helper.
        obs_padded: Padded observations for the log-likelihood helper.
        D: Passed to `run_factorized_fpi_end2end_padded` as its second argument.
        A_dependencies: Passed through to `run_factorized_fpi_end2end_padded`.
        max_obs_dim: Passed through to `run_factorized_fpi_end2end_padded`.
        max_state_dim: Passed through to `run_factorized_fpi_end2end_padded`.
        num_iter: Passed through to `run_factorized_fpi_end2end_padded`.
        sparsity: Forwarded as the `sparsity` keyword of
            `compute_log_likelihood_per_modality_end2end2_padded`.

    Returns:
        The result of `run_factorized_fpi_end2end_padded`.
    """
    # Note the argument order: the helper takes the observations first, A second.
    padded_lls = compute_log_likelihood_per_modality_end2end2_padded(
        obs_padded, A_padded, sparsity=sparsity
    )

    # Unlike the hybrid variants there is no vmap here; the padded
    # log-likelihoods go to the FPI routine in a single call.
    posteriors = run_factorized_fpi_end2end_padded(
        padded_lls, D, A_dependencies, max_obs_dim, max_state_dim, num_iter
    )
    return posteriors

In [15]:
# End2end padding of A. This rebinds A_padded, which until now held the
# hybrid method's padded tensor.
A_padded = apply_A_end2end_padding_batched(A)

if A_sparsity_level is not None:
    # NOTE(review): the hybrid and hybrid block cells pass n_batch=1 to
    # fromdense; here the default is used — confirm this is intended.
    A_padded = jsparse.BCOO.fromdense(A_padded)

# Axis 2 of the padded tensor gives the observation size; the largest of the
# remaining trailing axes gives the state size.
padded_dims = A_padded.shape
max_obs_dim = padded_dims[2]
max_state_dim = max(padded_dims[3:])

# obs preprocessing: drop the singleton axis 1 of every one-hot observation,
# then pad the observations up to max_obs_dim.
obs_padded = apply_obs_end2end_padding_batched(
    jtu.tree_map(partial(jnp.squeeze, axis=1), o_vec),
    max_obs_dim,
)

In [16]:
# Non-jitted run of the end2end padded method.
# NOTE(review): sparsity='ll_only' is passed regardless of A_sparsity_level —
# confirm that is the intended setting when A is dense.
qs = infer_states_end2end_padded(
    A_padded,
    obs_padded,
    D,
    A_dependencies=spec['A_dependencies'],
    max_obs_dim=max_obs_dim,
    max_state_dim=max_state_dim,
    num_iter=num_iter,
    sparsity='ll_only',
)

# Display the shape of every posterior first, then the posteriors themselves.
posterior_shapes = [q.shape for q in qs]
posterior_shapes, qs

([(4, 1, 3),
  (4, 1, 6),
  (4, 1, 2),
  (4, 1, 9),
  (4, 1, 7),
  (4, 1, 4),
  (4, 1, 7),
  (4, 1, 5),
  (4, 1, 7),
  (4, 1, 5)],
 [Array([[[0.09112392, 0.70235604, 0.20652005]],
  
         [[0.202104  , 0.22807139, 0.56982464]],
  
         [[0.3876841 , 0.3879837 , 0.2243322 ]],
  
         [[0.37516478, 0.27490932, 0.34992588]]], dtype=float32),
  Array([[[0.15824087, 0.183674  , 0.26761115, 0.09957846, 0.20257226,
           0.08832329]],
  
         [[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
           0.31559473]],
  
         [[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
           0.22736534]],
  
         [[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
           0.12064715]]], dtype=float32),
  Array([[[0.55582505, 0.44417495]],
  
         [[0.549685  , 0.450315  ]],
  
         [[0.3831266 , 0.6168734 ]],
  
         [[0.5388056 , 0.4611944 ]]], dtype=float32),
  Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16

In [17]:
# JIT-compiled variant of the end2end padded method.
# max_obs_dim is bound first, so the compiled padding helper takes only the
# observations.
apply_obs_padding_batched_jit = jit(
    partial(apply_obs_end2end_padding_batched, max_obs_dim=max_obs_dim)
)

# Everything except (A_padded, obs_padded, D) is fixed ahead of compilation.
end2end_bound_kwargs = dict(
    A_dependencies=spec['A_dependencies'],
    max_obs_dim=max_obs_dim,
    max_state_dim=max_state_dim,
    num_iter=num_iter,
    sparsity='ll_only',
)
infer_states_partially_padded_jit = jit(
    partial(infer_states_end2end_padded, **end2end_bound_kwargs)
)

# Observations are squeezed (singleton axis 1) and padded before inference.
obs_padded = apply_obs_padding_batched_jit(
    jtu.tree_map(partial(jnp.squeeze, axis=1), o_vec)
)

qs = infer_states_partially_padded_jit(A_padded, obs_padded, D)

# Display the shape of every posterior first, then the posteriors themselves.
posterior_shapes = [q.shape for q in qs]
posterior_shapes, qs

([(4, 1, 3),
  (4, 1, 6),
  (4, 1, 2),
  (4, 1, 9),
  (4, 1, 7),
  (4, 1, 4),
  (4, 1, 7),
  (4, 1, 5),
  (4, 1, 7),
  (4, 1, 5)],
 [Array([[[0.09112392, 0.70235604, 0.20652005]],
  
         [[0.202104  , 0.22807139, 0.56982464]],
  
         [[0.3876841 , 0.3879837 , 0.2243322 ]],
  
         [[0.37516478, 0.27490932, 0.34992588]]], dtype=float32),
  Array([[[0.15824087, 0.183674  , 0.26761115, 0.09957846, 0.20257226,
           0.08832329]],
  
         [[0.23064205, 0.02032886, 0.13066185, 0.01698621, 0.2857863 ,
           0.31559473]],
  
         [[0.19148783, 0.11801518, 0.15054731, 0.13707514, 0.17550918,
           0.22736534]],
  
         [[0.16028336, 0.27957812, 0.16871531, 0.12887648, 0.1418996 ,
           0.12064715]]], dtype=float32),
  Array([[[0.55582505, 0.44417495]],
  
         [[0.549685  , 0.450315  ]],
  
         [[0.3831266 , 0.6168734 ]],
  
         [[0.5388056 , 0.4611944 ]]], dtype=float32),
  Array([[[0.07716177, 0.10925393, 0.07744848, 0.1075874 , 0.16