# T-Maze Goal Learning

In [None]:
# Use the project environment located in the parent directory and make sure
# the packages it declares are installed before anything is loaded.
using Pkg
Pkg.activate("..")
Pkg.instantiate()

In [None]:
using RxInfer, LinearAlgebra, Plots

## Generalized Agent

In [None]:
# NOTE(review): GoalObservation, GeneralizedMeta and GeneralizedPipeline are not
# defined in this notebook; presumably this include provides them -- confirm.
include("../goal_observation.jl")

# Define the generative model
#
# Two-step model with a categorical initial state, a transition per step, and a
# goal-observation factor per step whose goal parameters carry a Dirichlet prior.
#
# Arguments (as used in the body):
#   A   - second argument of every GoalObservation node
#   C_s - indexed per step; C_s[k] parameterizes the Dirichlet prior on c[k]
#   D   - parameter of the Categorical prior on the initial state z_0
#   x   - indexed per step; x[k] is wrapped in GeneralizedMeta for step k
@model function t_maze_generalized(A, C_s, D, x)
    u = datavar(Matrix{Int64}, 2) # Policy for evaluations
    z = randomvar(2) # Latent states
    c = randomvar(2) # Goal priors

    z_0 ~ Categorical(D) # State prior

    # z_k_min tracks the state variable of the previous step; it starts at z_0.
    z_k_min = z_0
    for k=1:2
        # The policy matrix u[k] (supplied as data) is passed to the Transition node.
        z[k] ~ Transition(z_k_min, u[k])
        # c[k] is connected twice: to the GoalObservation node here, and to its
        # Dirichlet prior below.
        c[k] ~ GoalObservation(z[k], A) where {
            meta=GeneralizedMeta(x[k]), 
            pipeline=GeneralizedPipeline(vague(Categorical,8))} # With breaker message
        c[k] ~ Dirichlet(C_s[k])

        z_k_min = z[k] # Reset for next slice
    end
end

# Define constraints on the variational density
#
# The joint posterior over (z_0, z, c) is factorized into two blocks: the states
# (z_0 together with z) stay jointly dependent, while c is independent of them.
@constraints function structured()
    q(z_0, z, c) = q(z_0, z)q(c)
end

## Generalized Simulation

In [None]:
# Experimental setting
α = 0.9     # probability of reward
S = 10      # how many trials to run
seed = 666  # seed handed to the goal-sequence generator

# Pull in the helper, environment and agent definitions, in this order
for file in ("helpers.jl", "environment.jl", "goal_agent.jl")
    include(file)
end

A, B, C, D = constructABCD(α, 0.0)
C_0 = constructGoalPriors() # prior statistics for the C's

rs = generateGoalSequence(seed, S) # seeds the RNG; reproducible goal sequence
reset, execute, observe = initializeWorld(A, B, C, D, rs) # world interface
infer, act = initializeGoalAgent(A, B, C_0, D; # agent interface
                                 t_maze_model=t_maze_generalized)

# Storage for the experimental protocol, one entry per trial
Cs = Vector{Vector}(undef, S)                       # posterior statistics for the C's
Gs = map(_ -> Vector{Matrix}(undef, 3), 1:S)        # free energy values per time
as = map(_ -> Vector{Int64}(undef, 2), 1:S)         # actions per time
os = map(_ -> Vector{Vector}(undef, 2), 1:S)        # one-hot observations per time

for trial in 1:S
    reset(trial) # put the world back in its initial condition for this trial

    # Aliases into the per-trial storage (same underlying arrays, no copies)
    G_trial, a_trial, o_trial = Gs[trial], as[trial], os[trial]

    for t in 1:2
        # Plan, pick an action, apply it, then record what the world returns
        G_trial[t], _ = infer(t, a_trial, o_trial)
        a_trial[t] = act(t, G_trial[t])
        execute(a_trial[t])
        o_trial[t] = observe()
    end

    # Third inference pass also yields the learned goal statistics
    G_trial[3], Cs[trial] = infer(3, a_trial, o_trial)
end
;

## Results

In [None]:
# Load the plotting helpers, draw the goal-learning figure and save it to disk.
include("visualizations.jl")
# NOTE(review): the posterior statistics collected in `Cs` during the simulation
# are never used; `C` from constructABCD is passed here instead. Confirm against
# plotLearnedGoals' signature whether `C` or (an element of) `Cs` is intended.
plotLearnedGoals(C_0, C, Gs, S)
# NOTE(review): no file extension is given, so the output format is whatever
# savefig chooses by default -- confirm this is the desired format.
savefig("figures/GFE_C")