diff --git a/pymdp/learning.py b/pymdp/learning.py index bd694a6d..ec334f68 100644 --- a/pymdp/learning.py +++ b/pymdp/learning.py @@ -100,7 +100,7 @@ def update_state_likelihood_dirichlet( for factor in factors: dfdb = maths.spm_cross(qs[factor], qs_prev[factor]) - dfdb *= (B[factor][:, :, actions[factor]] > 0).astype("float") + dfdb *= (B[factor][:, :, int(actions[factor])] > 0).astype("float") qB[factor][:,:,int(actions[factor])] += (lr*dfdb) return qB