In [None]:
using Pkg;Pkg.activate("../..");Pkg.instantiate();
using RxInfer,ReactiveMP,GraphPPL,Rocket, LinearAlgebra, Distributions, Random
Random.seed!(666)

[32m[1m  Activating[22m[39m project at `~/biaslab/repos/EpistemicMessagePassing`


In [None]:
# Load all the helpers
include("transition_mixture/transition_mixture.jl")
include("transition_mixture/marginals.jl")
include("transition_mixture/in.jl")
include("transition_mixture/out.jl")
include("transition_mixture/switch.jl")
include("goal_observation.jl")
include("helpers.jl")

In [None]:
# We need to make pointmass constraints for discrete vars by hand
import RxInfer.default_point_mass_form_constraint_optimizer
import RxInfer.PointMassFormConstraint

# Point-mass form-constraint optimizer for univariate discrete distributions.
# Collapses `distribution` onto a one-hot probability vector that puts all mass on
# its most probable category and wraps the result in a `PointMass`.
# `constraint` only takes part in dispatch; its value is not read.
function default_point_mass_form_constraint_optimizer(
    ::Type{Univariate},
    ::Type{Discrete},
    constraint::PointMassFormConstraint,
    distribution
)
    probabilities = probvec(distribution)

    # One-hot encoding of the mode (`argmax` returns the first index on ties).
    onehot = zeros(length(probabilities))
    onehot[argmax(probabilities)] = 1.0

    return PointMass(onehot)
end

In [None]:
# Create the model
#
# t_maze(A, D, B1, B2, B3, B4, T): generative model unrolled over T time steps.
#   D           - parameter of the Categorical prior on the initial state `z_0`
#   B1,...,B4   - four arguments forwarded to every `TransitionMixture` node
#                 (presumably one transition matrix per control option - confirm
#                 against transition_mixture/*.jl)
#   A           - argument forwarded to every `GoalObservation` node
#   T           - number of time steps; sizes `z`, `switch` and `c`
@model function t_maze(A,D,B1,B2,B3,B4,T)

    # Prior over the initial state
    z_0 ~ Categorical(D)

    # One latent state and one control ("switch") variable per time step
    z = randomvar(T)
    switch = randomvar(T)

    # One observed vector per time step; values are supplied through `data = (c = ...,)`
    c = datavar(Vector{Float64}, T)
    z_prev = z_0

    for t in 1:T
        # Uniform categorical prior over the 4 switch values
        switch[t] ~ Categorical(fill(1. /4. ,4))
        # Next state depends on the previous state and the current switch
        z[t] ~ TransitionMixture(z_prev,switch[t], B1,B2,B3,B4)
        # Observation node; the pipeline is built from a vague 8-category Categorical
        # NOTE(review): the 8 here presumably has to match the dimension of `z` - confirm
        c[t] ~ GoalObservation(z[t], A) where {pipeline = GeneralizedPipeline(vague(Categorical, 8)) }
        # Chain the time slices together
        z_prev = z[t]
    end
end;

In [None]:
# Point-mass constraints
# Restricts the functional form of the posterior marginal of `switch` to `PointMass`.
# Passed to `inference` through the `constraints` keyword.
# NOTE(review): for discrete marginals this presumably relies on the
# `default_point_mass_form_constraint_optimizer` method defined earlier in this file - confirm.
@constraints function pointmass_q()
    q(switch) :: PointMass
end

# Node constraints
# Attaches a `GeneralizedMeta()` object to the `GoalObservation` nodes that connect
# `c` and `z`. Passed to `inference` through the `meta` keyword.
# NOTE(review): `GeneralizedMeta` is not defined in this file; presumably it comes from
# goal_observation.jl - confirm.
@meta function t_maze_meta()
    GoalObservation(c,z) -> GeneralizedMeta()
end

In [None]:
# Experiment configuration
T = 2;      # planning horizon: number of time steps unrolled in the model
its = 10;   # number of inference iterations to run

# Initial marginals: one uniform Categorical over 8 categories for every z[t]
initmarginals = (z = [Categorical(fill(0.125, 8)) for _ in 1:T],)

# Generate the matrices we need (A, B, C, D) with the helper `constructABCD`;
# the second argument is a length-T vector filled with 2.0
A, B, C, D = constructABCD(0.9, fill(2.0, T), T);

In [None]:
# Run inference on the T-maze model without form constraints on the posteriors.
# `C` is fed in as the observations for the data variable `c`.
result = inference(
    model = t_maze(A, D, B[1], B[2], B[3], B[4], T),
    data = (c = C,),
    meta = t_maze_meta(),
    initmarginals = initmarginals,
    iterations = its,
)

In [None]:
# Inspect results: print the probability vector of the posterior over the control
# variable `switch` at time steps 1 and 2, read from the last stored entry (`[end]`)
# of `result.posteriors[:switch]` (presumably the final inference iteration).
# `probvec` is called directly on each single marginal, with no broadcast dot: there is
# nothing to broadcast over, and this matches the calls used for the point-mass run.
println("Posterior controls as T=1, ", probvec(result.posteriors[:switch][end][1]), "\n")
println("Posterior controls as T=2, ", probvec(result.posteriors[:switch][end][2]))


In [None]:
# Repeat the experiment with point-mass constraints: same model, data, meta and
# initial marginals as before, with the `pointmass_q()` constraints added.
result = inference(
    model = t_maze(A, D, B[1], B[2], B[3], B[4], T),
    data = (c = C,),
    constraints = pointmass_q(),
    meta = t_maze_meta(),
    initmarginals = initmarginals,
    iterations = its,
)

# Print the posterior over the control variable `switch` at each of the two time steps
println("Posterior controls as T=1, ", probvec(result.posteriors[:switch][end][1]), "\n")
println("Posterior controls as T=2, ", probvec(result.posteriors[:switch][end][2]), "\n")