# T-Maze Inference for Planning with the Generalized Bethe Free Energy (GBFE)

In [None]:
using LinearAlgebra
using ForneyLab
using Plots

# T-maze layout
# [2| |3]
#   | |
#   |1|
#   |4|

include("factor_nodes/GFECategorical.jl")
include("update_rules/GFECategorical.jl")
;

┌ Info: Precompiling ForneyLab [9fc3f58a-c2cc-5bff-9419-6a294fefdca9]
└ @ Base loading.jl:1662
┌ Info: Precompiling Plots [91a5bcdd-55d7-5caf-9e0b-520d859cae80]
└ @ Base loading.jl:1662


# Model

In [None]:
# Planning horizon: number of future time steps in the model
T = 2

fg = FactorGraph()

u = Vector{Variable}(undef, T) # per-step control variables, clamped to data below
x = Vector{Variable}(undef, T) # per-step hidden-state variables

# Initial state with a categorical prior; its 8-vector parameter D_t_min is data
@RV x_t_min ~ Categorical(placeholder(:D_t_min, dims=(8,)))

for k in 1:T
    # Predecessor of x[k]. Indexing `x` (instead of rebinding a global such as
    # `x_k_min` inside the loop) keeps this cell valid under Julia's soft-scope
    # rules when it is run non-interactively (e.g. via `include`). In that
    # context, assigning to a global inside a top-level loop creates a new,
    # initially undefined local.
    x_prev = k == 1 ? x_t_min : x[k-1]

    @RV u[k]
    @RV x[k] ~ Transition(x_prev, u[k], id=:x_*k)

    # u[k] is clamped to the k-th entry of the :u data (an 8×8 matrix per step)
    placeholder(u[k], :u, index=k, dims=(8,8))

    # Attach the project-defined GFECategorical node to x[k].
    # The (16×8) :A placeholder is shared by all steps; :C is a 16-vector per step.
    GFECategorical(x[k],
                   placeholder(:A, dims=(16,8), var_id=:A_*k),
                   placeholder(:C, dims=(16,), index=k, var_id=:C_*k),
                   n_factors=8)
end
;

In [None]:
# Create a posterior factorization for the factor graph `fg` built above. It must
# exist before the algorithm is derived.
q = PosteriorFactorization(fg)
# Derive a message-passing algorithm with `x_t_min` as target variable, including
# free-energy evaluation
algo = messagePassingAlgorithm(x_t_min, free_energy=true)
# Generate the algorithm as Julia source code (a string)
code = algorithmSourceCode(algo, free_energy=true)
# Evaluate the generated code. It is produced locally by ForneyLab, not read from
# external input.
# NOTE(review): this presumably defines `init`, `step!` and `freeEnergy`, which the
# Results section calls and which are not defined elsewhere in this notebook.
eval(Meta.parse(code))
;

In [None]:
println(stdout, code) # show the generated algorithm source for inspection

# Results

In [None]:
# Reward probability and utility, uncomment scenario of interest
α = 0.9; c = 2.0

include("environment.jl")
include("agent.jl")
include("helpers.jl")

(A, B, C, D) = constructABCD(α, c)
;

In [None]:
# Single policy: entry t selects the transition matrix B[a] applied at step t.
# It is named `single_policy` rather than `pi`: assigning `pi` shadows `Base.pi`,
# and fails outright if `pi` was already used in Main.
single_policy = [4, 2]
length(single_policy) == T || throw(DimensionMismatch(
    "policy has $(length(single_policy)) steps, but the model was built for T = $T"))

n_its = 10       # number of message-passing iterations
G = zeros(n_its) # free energy after each iteration

n_states = length(D) # state-space size (the D_t_min placeholder has dims (8,))

data = Dict(:u       => [B[a] for a in single_policy],
            :A       => A,
            :C       => fill(C, length(single_policy)),
            :D_t_min => D)

# Initialise every state marginal x_1 ... x_T to a uniform categorical
marginals = Dict{Symbol, ProbabilityDistribution}(
    Symbol(:x_, t) => ProbabilityDistribution(Univariate, Categorical,
                                              p=ones(n_states)./n_states)
    for t in eachindex(single_policy))

messages = init()

for k in 1:n_its
    step!(data, marginals, messages)
    G[k] = freeEnergy(data, marginals)
end

G ./= log(2) # Convert from nats to bits
;


In [None]:
# Free-energy trace (in bits) over the iterations of the single-policy run
plot(1:n_its, G;
     color=:black, grid=true, linewidth=2, legend=false,
     xlabel="Iteration", ylabel="GFE [bits]")

In [None]:
# Evaluate the GBFE of every policy and plot the results, passing `minimum`
# as the `highlight` selector
GBFE = evaluatePoliciesGBFE(A, B, C, D)
plotResults(GBFE;
            clim=(15.0, 60.0),
            dpi=300,
            highlight=minimum)
# Uncomment to save the figure to disk:
#savefig("GBFE_c_$(c)_a_$(α).png")