# T-Maze Interactive Simulation

This notebook runs the action-perception loop for a discrete state-space model (SSM) constrained by the generalized free energy (GFE).

In [1]:
using Pkg
Pkg.activate("../..")
# Pkg.instantiate()

[32m[1m  Activating[22m[39m project at `c:\Simulations\EpistemicMessagePassing`


In [2]:
using LinearAlgebra
using ForneyLab
using Plots
using ForwardDiff: hessian
# using ProgressMeter

# T-maze layout
# [2| |3]
#   | |
#   |1|
#   |4|

include("factor_nodes/discrete_observation.jl")
include("update_rules/discrete_observation.jl")
;

In [3]:
using StatsFuns: gammainvcdf
import ForneyLab: ruleSPEqualityFnFactor, sample

# function sample(dist::Distribution{Univariate, Gamma})
#     (dist.params[:a] < 0.01) && return 0.0
#     gammainvcdf(dist.params[:a], 1/dist.params[:b], rand())
# end

# Custom update that outputs a Function message as the result of a Dirichlet-Function message product (instead of a SampleList)
# Sum-product rules for an equality node that receives one Function message and one
# Dirichlet message; the `Nothing` argument marks the outbound interface. Every method
# hands the Function distribution first and the Dirichlet distribution second to
# `prodDirFn!` and wraps the resulting distribution in a Message.

# Function on interface 1, Dirichlet on interface 2
function ruleSPEqualityFnFactor(
    msg_1::Message{<:Function},
    msg_2::Message{<:Dirichlet},
    msg_3::Nothing,
)
    return Message(prodDirFn!(msg_1.dist, msg_2.dist))
end

# Function on interface 1, Dirichlet on interface 3
function ruleSPEqualityFnFactor(
    msg_1::Message{<:Function},
    msg_2::Nothing,
    msg_3::Message{<:Dirichlet},
)
    return Message(prodDirFn!(msg_1.dist, msg_3.dist))
end

# Function on interface 2, Dirichlet on interface 3
function ruleSPEqualityFnFactor(
    msg_1::Nothing,
    msg_2::Message{<:Function},
    msg_3::Message{<:Dirichlet},
)
    return Message(prodDirFn!(msg_2.dist, msg_3.dist))
end

# Dirichlet on interface 1, Function on interface 2
function ruleSPEqualityFnFactor(
    msg_1::Message{<:Dirichlet},
    msg_2::Message{<:Function},
    msg_3::Nothing,
)
    return Message(prodDirFn!(msg_2.dist, msg_1.dist))
end

# Dirichlet on interface 1, Function on interface 3
function ruleSPEqualityFnFactor(
    msg_1::Message{<:Dirichlet},
    msg_2::Nothing,
    msg_3::Message{<:Function},
)
    return Message(prodDirFn!(msg_3.dist, msg_1.dist))
end

# Dirichlet on interface 2, Function on interface 3
function ruleSPEqualityFnFactor(
    msg_1::Nothing,
    msg_2::Message{<:Dirichlet},
    msg_3::Message{<:Function},
)
    return Message(prodDirFn!(msg_3.dist, msg_2.dist))
end

# Product of a matrix-variate Function distribution and a matrix-variate Dirichlet.
# Returns a new Function distribution whose log-pdf is the sum of the Dirichlet log-pdf
# and the log-pdf stored under `:log_pdf` in `dist_fn`. Neither argument is modified.
function prodDirFn!(
    dist_fn::Distribution{MatrixVariate, Function},
    dist_dir::Distribution{MatrixVariate, Dirichlet},
)
    # The `:log_pdf` entry is looked up each time the closure is evaluated.
    summed_log_pdf = (A) -> logPdf(dist_dir, A) + dist_fn.params[:log_pdf](A)
    return Distribution(MatrixVariate, Function, log_pdf=summed_log_pdf)
end
;

In [4]:
import ForneyLab: sampleWeightsAndEntropy

# Edit: set to 50 samples
# Importance-sampling helper: draw samples from the proposal `x`, weight them by the
# pdf of `y`, and estimate the entropy of the product.
# NOTE(review): this extends the function imported from ForneyLab; presumably it
# replaces the library method with the same positional signature -- confirm.
#
# Arguments
#   x         -- proposal distribution (sampled from, and evaluated via `logPdf`)
#   y         -- integrand distribution (evaluated via `logPdf` only)
#   n_samples -- keyword, number of samples to draw; defaults to 50 so that the
#                two-argument call behaves exactly as before
#
# Returns the tuple `(samples, weights, w_raw, logproposal, logintegrand, entropy)`.
#
# Throws `ArgumentError` when `n_samples` is not positive.
function sampleWeightsAndEntropy(x::Distribution, y::Distribution; n_samples::Integer=50)
    n_samples > 0 || throw(ArgumentError("n_samples must be positive, got $n_samples"))

    samples = sample(x, n_samples)

    # Apply log-pdf functions to the samples
    log_samples_x = logPdf.([x], samples)
    log_samples_y = logPdf.([y], samples)

    # Extract the sample weights
    w_raw = exp.(log_samples_y) # Unnormalized weights
    w_sum = sum(w_raw)
    weights = w_raw./w_sum # Normalize the raw weights

    # Compute the separate contributions to the entropy
    H_y = log(w_sum) - log(n_samples)
    H_x = -sum( weights.*(log_samples_x + log_samples_y) )
    entropy = H_x + H_y

    # Inform next step about the proposal and integrand to be used in entropy calculation in smoothing
    logproposal = (samples) -> logPdf.([x], samples)
    logintegrand = (samples) -> logPdf.([y], samples)

    return (samples, weights, w_raw, logproposal, logintegrand, entropy)
end
;

# Algorithm for $t=1$

In [5]:
# Build the factor graph for t=1 and generate its message passing algorithm.
# In this cell both time slices attach a DiscreteObservation node to the state;
# no `y` variable is connected.
# NOTE(review): DiscreteObservation is presumably the custom node from the included
# factor_nodes/discrete_observation.jl -- confirm.
fg_t1 = FactorGraph()

# Per-slice containers for controls, states and observations.
# `y` is allocated here but never referenced in this cell.
u = Vector{Variable}(undef, 2)
x = Vector{Variable}(undef, 2)
y = Vector{Variable}(undef, 2)

# Slice k=0
# Initial state prior and observation-matrix prior; their parameters are fed in at
# run time through the :D_t_min and :A_s placeholders.
@RV x_0 ~ Categorical(placeholder(:D_t_min, dims=(8,)))
@RV A ~ Dirichlet(placeholder(:A_s, dims=(16,8)))

# Slice k=1
# The control u[1] is an (8,8) placeholder that serves as the transition matrix.
@RV u[1]
@RV x[1] ~ Transition(x_0, u[1])
placeholder(u[1], :u, index=1, dims=(8,8))
DiscreteObservation(x[1], 
                    A,
                    placeholder(:C, dims=(16,), var_id=:C_1),
                    n_factors=8)
# Slice k=2
@RV u[2]
@RV x[2] ~ Transition(x[1], u[2])
placeholder(u[2], :u, index=2, dims=(8,8))
DiscreteObservation(x[2], 
                    A,
                    placeholder(:C, dims=(16,), var_id=:C_2),
                    n_factors=8)
# Algorithm
# Posterior factorizes into a joint over the states (id :X) and a factor for A (id :A).
q_t1 = PosteriorFactorization([x_0; x], A, ids=[:X, :A])
algo_t1 = messagePassingAlgorithm(q_t1, id=:t1, free_energy=true)
code_t1 = algorithmSourceCode(algo_t1, free_energy=true)
# Parse and evaluate the generated source in the current session.
eval(Meta.parse(code_t1))
;

# Algorithm for $t=2$

In [6]:
# Build the factor graph for t=2 and generate its message passing algorithm.
# Slice k=1 now carries an observation variable y[1] fed by the :y placeholder;
# slice k=2 still uses a DiscreteObservation node.
# NOTE(review): DiscreteObservation is presumably the custom node from the included
# factor_nodes/discrete_observation.jl -- confirm.
fg_t2 = FactorGraph()

# Per-slice containers for controls, states and observations (only y[1] is used here).
u = Vector{Variable}(undef, 2)
x = Vector{Variable}(undef, 2)
y = Vector{Variable}(undef, 2)

# Slice k=0
# Initial state prior and observation-matrix prior; their parameters are fed in at
# run time through the :D_t_min and :A_s placeholders.
@RV x_0 ~ Categorical(placeholder(:D_t_min, dims=(8,)))
@RV A ~ Dirichlet(placeholder(:A_s, dims=(16,8)))

# Slice k=1
@RV u[1]
@RV x[1] ~ Transition(x_0, u[1])
placeholder(u[1], :u, index=1, dims=(8,8))
# Observation model: y[1] depends on x[1] through A; its value comes from :y, index 1.
@RV y[1] ~ Transition(x[1], A)
placeholder(y[1], :y, index=1, dims=(16,))

# Slice k=2
@RV u[2]
@RV x[2] ~ Transition(x[1], u[2])
placeholder(u[2], :u, index=2, dims=(8,8))
DiscreteObservation(x[2], 
                    A,
                    placeholder(:C, dims=(16,), var_id=:C_2),
                    n_factors=8)
# Algorithm
# Posterior factorizes into a joint over the states (id :X) and a factor for A (id :A).
q_t2 = PosteriorFactorization([x_0; x], A, ids=[:X, :A])
algo_t2 = messagePassingAlgorithm(q_t2, id=:t2, free_energy=true)
code_t2 = algorithmSourceCode(algo_t2, free_energy=true)
# Parse and evaluate the generated source in the current session.
eval(Meta.parse(code_t2))
;

# Algorithm for $t=3$ (Learning)

In [7]:
# Build the factor graph for t=3 (learning) and generate its message passing algorithm.
# Both slices carry observation variables y[1] and y[2] fed by the :y placeholder;
# no DiscreteObservation node is used in this cell.
fg_t3 = FactorGraph()

# Per-slice containers for controls, states and observations.
u = Vector{Variable}(undef, 2)
x = Vector{Variable}(undef, 2)
y = Vector{Variable}(undef, 2)

# Slice k=0
# Initial state prior and observation-matrix prior; their parameters are fed in at
# run time through the :D_t_min and :A_s placeholders.
@RV x_0 ~ Categorical(placeholder(:D_t_min, dims=(8,)))
@RV A ~ Dirichlet(placeholder(:A_s, dims=(16,8)))

# Slice k=1
@RV u[1]
@RV x[1] ~ Transition(x_0, u[1])
placeholder(u[1], :u, index=1, dims=(8,8))
# Observation model: y[1] depends on x[1] through A; its value comes from :y, index 1.
@RV y[1] ~ Transition(x[1], A)
placeholder(y[1], :y, index=1, dims=(16,))

# Slice k=2
@RV u[2]
@RV x[2] ~ Transition(x[1], u[2])
placeholder(u[2], :u, index=2, dims=(8,8))
# Observation model: y[2] depends on x[2] through A; its value comes from :y, index 2.
@RV y[2] ~ Transition(x[2], A)
placeholder(y[2], :y, index=2, dims=(16,))

# Algorithm
# Posterior factorizes into a joint over the states (id :X) and a factor for A (id :A).
q_t3 = PosteriorFactorization([x_0; x], A, ids=[:X, :A])
algo_t3 = messagePassingAlgorithm(q_t3, id=:t3, free_energy=true)
code_t3 = algorithmSourceCode(algo_t3, free_energy=true)
# Parse and evaluate the generated source in the current session.
eval(Meta.parse(code_t3))
;

## Action-Perception Loop

In [8]:
# Run S independent sessions of the two-step experiment and collect, per session,
# the result of the learning step in `As`.
α = 0.9; c = 2.0 # Reward probability and utility
S = 20 # Number of simulations

# NOTE(review): constructABCD, constructAPrior, initializeWorld and initializeAgent are
# presumably defined in these included files -- confirm.
include("environment.jl")
include("agent_2.jl")
include("helpers.jl")

(A, B, C, D) = constructABCD(α, c)
A_0 = constructAPrior() # Construct prior statistics for A

(reset, execute, observe) = initializeWorld(A, B, C, D) # Let there be a world
(infer, act) = initializeAgent(A_0, B, C, D) # Let there be a constrained agent

# Step through the experimental protocol
As = Vector{Matrix}(undef, S) # One entry per session, filled by the learning step below
for s = 1:S
    # Per-session buffers: inference results, chosen actions and observations
    G_ts = Vector{Any}(undef, 2)
    a = Vector{Int64}(undef, 2)
    o = Vector{Vector}(undef, 2)
    reset() # Reset world
    for t = 1:2
        # Infer, pick an action from the inference result, apply it to the world,
        # then read back the observation. r_t is returned by observe() but unused here.
            G_ts[t] = infer(t, a, o)
               a[t] = act(G_ts[t])
                      execute(a[t])
        (o[t], r_t) = observe()
    end
    # Report the rounded inference results for both steps and the actions taken
    println("Session $s")
    println(round.(G_ts[1], digits=2))
    println(round.(G_ts[2], digits=2))
    println(a)
    As[s] = infer(3, a, o) # Learn
end
;

Session 1
[13.75 14.42 14.55 14.35; 14.35 144.27 144.27 144.27; 14.47 144.27 144.27 144.27; 14.31 14.03 14.18 14.15]
[18.28, 17.48, 17.64, 17.97]
[1, 2]
Session 2
[13.3 13.63 14.16 14.21; 13.52 144.27 144.27 144.27; 14.43 144.27 144.27 144.27; 14.23 13.44 14.25 14.13]
[13.07, 11.46, 12.32, 12.58]
[1, 2]
Session 3
[13.57 13.06 14.31 14.11; 13.01 144.27 144.27 144.27; 14.18 144.27 144.27 144.27; 13.93 12.94 14.16 13.95]
[10.89, 9.03, 10.04, 10.72]
[4, 2]
Session 4
[13.71 13.3 14.2 13.92; 13.61 144.27 144.27 144.27; 14.01 144.27 144.27 144.27; 14.07 13.09 14.08 13.79]
[10.35, 10.41, 10.46, 10.91]
[4, 1]
Session 5
[14.57 13.91 14.58 14.33; 13.81 144.27 144.27 144.27; 14.52 144.27 144.27 144.27; 14.4 13.21 14.04 13.76]
[10.08, 9.02, 9.61, 10.25]
[4, 2]
Session 6
[14.87 14.13 14.58 14.25; 14.19 144.27 144.27 144.27; 14.49 144.27 144.27 144.27; 14.23 13.6 13.95 13.55]
[9.88, 9.52, 9.34, 9.9]
[4, 3]
Session 7
[14.68 14.19 13.92 14.05; 14.25 144.27 144.27 144.27; 14.03 144.27 144.27 144.27; 14.

In [10]:
import Statistics: mean

# Divide every column of `A` by its column sum, so that each column sums to one.
# A column whose sum is zero yields NaN entries.
#
# This is deliberately a separately named helper: adding a `mean(::Matrix)` method to
# `Statistics.mean` would be type piracy and would change the result of every
# `mean(matrix)` call made anywhere else in the session.
normalizecolumns(A::AbstractMatrix) = A ./ sum(A, dims=1)

round.(normalizecolumns(As[S] - As[1]), digits=1) # Inspect observation probabilities

16×8 Matrix{Float64}:
 0.5  0.5  0.0  0.0  0.0  0.0  0.0  0.0
 0.5  0.5  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.9  0.4  0.0  0.0  0.0  0.0
 0.0  0.0  0.1  0.6  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.9  0.9  0.0  0.0
 0.0  0.0  0.0  0.0  0.1  0.1  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  1.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0