# Policy Inference for the T-Maze Navigation Task

This demo is based on Parts I and II of the "Realising Synthetic Active Inference Agents" paper series.

In [2]:
using Pkg
Pkg.activate("../..")

using RxAIF
using RxInfer

include("../../src/fixes.jl")

[32m[1m  Activating[22m[39m project at `c:\Simulations\RxAIF`


## Define the Simulation Parameters

In [3]:
include("helpers.jl")

# Simulation parameters: A, B, C and the initial state belief D_0 come from the helper
A, B, C, D_0 = constructABCD(0.9, 2.0)

r   = [1, 0]    # Reward at position 2
x_0 = zeros(8)
x_0[1:2] .= r   # Start from position 1

# Environmental parameters (fields: x_0, A, B)
env = (; x_0, A, B)

# Model parameters (fields: T, A, B, C)
params = (; T = 2, A, B, C)

# Model prior statistics
stats = Dict(:D_t_min => D_0) # Initial state belief
;

## Define the Generative Model and Constraints

In [4]:
# Define the regulator (planning) model: a chain of `tau` future time slices, each with
# a control `u[k]`, a state `x[k]` and a goal-constrained outcome `c[k]`.
#
# Arguments:
#   tau    - number of time slices to unroll (upper bound of the loop below)
#   params - model parameters; only `params.A` is read here
#   stats  - prior statistics; `stats[:D_t_min]` parameterizes the initial state prior
#   c      - indexed per slice as `c[k]`; presumably the goal priors supplied by the
#            caller -- TODO confirm against agent.jl
#
# NOTE: the transition matrices are read from the global `B`, not from `params.B`
# (see the TODO below), so this model relies on `B` being defined in the enclosing scope.
@model function t_maze_plan(tau, params, stats, c)
    x_t_min ~ Categorical(stats[:D_t_min]) # State prior

    # Previous-slice state; starts at the prior state and is advanced at the end of each slice
    x_k_min = x_t_min
    for k=1:tau
        # Uniform prior over the four control options
        u[k] ~ Categorical(ones(4)./4)
        # State transition: one candidate transition matrix per control option, mixed by u[k]
        x[k] ~ TransitionMixture(x_k_min, u[k], B[1], B[2], B[3], B[4]) # TODO: params.B[1] gives error
        # Goal-observation node on the state; the pipeline is seeded with a vague
        # 8-category Categorical
        c[k] ~ GoalObservation(x[k], params.A) where {
            meta         = GeneralizedMeta(), 
            dependencies = GeneralizedPipeline(vague(Categorical,8))} # With breaker message

        x_k_min = x[k] # Reset for next slice
    end
end

# Define constraints on the variational distributions:
# the posterior over the controls `u` is restricted to a point-mass form.
@constraints function structured()
    q(u) :: PointMassFormConstraint()
end

# Initialize variational distributions:
# the marginal over the states `x` starts as a Categorical over 8 categories whose
# probabilities are the softmax of 8 random numbers.
# NOTE(review): `rand` is called without an explicit RNG or seed here, so the
# initialization differs between runs unless a seed is set elsewhere -- confirm
# whether reproducibility is required.
@initialization function init_marginals()
    q(x) = Categorical(softmax(rand(8)))
end
;

In [5]:
# Define the state-estimation model for a single time step: the previous state
# `x_t_min` transitions to the current state `x_t`, and `x_t` generates `y_t`.
#
# Arguments:
#   params - model parameters; only `params.A` is read here
#   stats  - prior statistics; `stats[:D_t_min]` parameterizes the prior on `x_t_min`
#   y_t    - variable on the left of the second `Transition`; presumably the current
#            (one-hot) observation -- TODO confirm against agent.jl
#   u_t    - index into `B` selecting the transition matrix for this step
#
# NOTE: `B` is read from the global scope rather than from `params.B` (see the TODO
# below), so this model relies on `B` being defined in the enclosing scope.
@model function t_maze_estimate(params, stats, y_t, u_t)
    x_t_min ~ Categorical(stats[:D_t_min]) # State prior
    x_t     ~ Transition(x_t_min, B[u_t]) # TODO: params.B[u_t] doesn't work
    y_t     ~ Transition(x_t, params.A)
end
;

## Execute the Perception-Action Cycle

In [6]:
# Load the helper, environment and agent definitions
include("helpers.jl")
include("environment.jl")
include("agent.jl")

# Let there be a world, and an agent acting in it
execute, observe    = initializeWorld(env)
plan, act, estimate = initializeAgent(params, stats)

# Storage for the action and the (one-hot) observation at each time step
a = Vector{Int64}(undef, params.T)
o = Vector{Vector}(undef, params.T)

# Perception-action cycle: plan, act, execute, observe, then estimate
for t in eachindex(a)
    plan(t)
    a[t] = act()
    execute(a[t])
    o[t] = observe()
    estimate(o[t], a[t])
end


## Results

In [7]:
a

2-element Vector{Int64}:
 4
 2