In [1]:
using Pkg
Pkg.activate(".") # Activate the project environment located in the current working directory
# Pkg.instantiate() # Uncomment on a fresh checkout to install the project's dependencies

[32m[1m  Activating[22m[39m project at `c:\Simulations\EpistemicMessagePassing\src\Part2\Rx`


In [2]:
using RxInfer, LinearAlgebra

# T-maze layout
# [2| |3]
#   | |
#   |1|
#   |4|

In [3]:
include("goal_observation.jl")

# Define the generative model
#
# Two-slice T-maze model: an initial state z_0 is propagated through two Transition
# nodes, and each resulting state z[k] is linked to the goal statistics c[k] through a
# GoalObservation node that shares the observation matrix A across both slices.
#
# Arguments:
#   A_s - statistics of the MatrixDirichlet prior placed on A
#   D_s - probability vector of the Categorical prior placed on z_0
#   t   - not referenced in the model body (kept so existing callers keep working)
# Data variables:
#   u[k] - control/transition matrix used for slice k
#   c[k] - goal prior statistics attached to slice k
# NOTE(review): GoalObservation, BetheMeta and BethePipeline are presumably provided by
# goal_observation.jl — confirm there.
@model function t_maze(A_s, D_s, t)
    u = datavar(Matrix{Int64}, 2)   # Transition matrices of the policy under evaluation
    z = randomvar(2)                # Latent state of each of the two slices
    c = datavar(Vector{Float64}, 2) # Goal prior statistics, one per slice

    z_0 ~ Categorical(D_s)   # Prior over the initial state
    A ~ MatrixDirichlet(A_s) # Prior over the shared observation matrix

    z_prev = z_0
    for k in 1:2
        z[k] ~ Transition(z_prev, u[k])
        c[k] ~ GoalObservation(z[k], A) where {
            meta = BetheMeta(), pipeline = BethePipeline()
        }
        z_prev = z[k] # The next slice transitions out of this one
    end
end

# Define constraints on the variational density
# Structured factorization: the posterior over the whole state sequence (z_0, z) is kept
# as one joint factor, while the observation matrix A gets its own separate factor.
@constraints function structured()
    q(z_0, z, A) = q(z_0, z)q(A)
end

structured (generic function with 1 method)

In [4]:
# Define experimental setting
α = 0.9; c = 2.0 # Reward probability and utility
S = 100 # Number of trials
seed = 1234 # Randomizer seed
# NOTE(review): S and seed are not used in the cells visible here; presumably they are
# read by the included files or by later cells — confirm before removing.

# Support code. Included after the settings above and in this order, because these files
# may reference each other and the globals defined above — TODO confirm.
include("helpers.jl")
include("environment.jl")
include("agent.jl")

# B and C are used by `infer` as the per-slice transition matrices and goal statistics;
# A_0 and D_0 parameterize the priors of the `t_maze` model.
(A, B, C, D) = constructABCD(α, c)
(A_0, D_0) = constructPriors() # Construct prior statistics for A and D
# The trailing semicolon below suppresses the notebook cell's output display.
;

In [5]:
"""
    infer(pol; iterations = 50)

Run variational inference on the two-slice `t_maze` model for the policy `pol`.

Element `k` of `pol` is an index into the global vector of transition matrices `B`.
The matrix it selects is fed to time slice `k` as `u[k]`. Both slices receive the goal
prior statistics `C`. Free energy is tracked for every iteration.

# Arguments
- `pol`: a 2-element collection of indices into `B`, e.g. the tuple `(2, 4)`.

# Keywords
- `iterations::Integer = 50`: number of message-passing iterations.

# Returns
The result of `inference` (posteriors for `A`, `z_0`, `z` and the free-energy trace).

# Throws
- `ArgumentError` if `pol` does not contain exactly 2 entries. `t_maze` declares
  exactly two slices, so any other length cannot be bound to its data variables.

Relies on the globals `B`, `C`, `A_0`, `D_0` and on the helper `asym`, all of which are
defined by earlier cells or the files they include.
"""
function infer(pol; iterations::Integer = 50)
    # Fail early with a clear message instead of deep inside the inference engine.
    length(pol) == 2 || throw(ArgumentError(
        "policy must contain exactly 2 actions (one per time slice), got $(length(pol))"))

    # One transition matrix per time slice, selected from the global vector B.
    u = B[collect(pol)]
    data = (u = u,
            c = [C, C]) # Same goal prior statistics in both slices

    # Size the initial marginals from the selected matrices instead of a hard-coded 8.
    # In `Transition(out, in, M)`, M maps `in` to `out`. So z_0 (the input of the first
    # slice) matches the columns of u[1], and z[k] (the output of slice k) matches the
    # rows of u[k].
    initmarginals = (A   = MatrixDirichlet(asym(A_0)),
                     z_0 = Categorical(asym(size(first(u), 2))),
                     z   = [Categorical(asym(size(u_k, 1))) for u_k in u])

    # The third argument of t_maze is not used by the model body.
    return inference(model         = t_maze(A_0, D_0, 1),
                     constraints   = structured(),
                     data          = data,
                     initmarginals = initmarginals,
                     iterations    = iterations,
                     free_energy   = true)
end


infer (generic function with 1 method)

In [6]:
# Evaluate the policy that uses transition matrix B[2] in slice 1 and B[4] in slice 2
infer((2,4))

Inference results:
  Posteriors       | available for (A, z_0, z)
  Free Energy:     | Real[13.7652, 10.6717, 9.85415, 9.21749, 8.9837, 8.98158, 8.98158, 8.98158, 8.98158, 8.98158  …  8.98158, 8.98158, 8.98158, 8.98158, 8.98158, 8.98158, 8.98158, 8.98158, 8.98158, 8.98158]
