In [None]:
# Nonlinear MAX model definition
#
# One-step prediction model. A hidden state `h` is produced from `h_0` by an AR node,
# passed through a flow (`Flow` node), and its components are summed. A deterministic
# term `NN(y_prev, u, params_f)` is added, and the result is the mean of the Gaussian
# output `y`, whose precision is `w`.
#
# Arguments
# - `model_flow::FlowModel`: flow model; it is compiled with the first
#   `nr_params(model_flow)` entries of `params`.
# - `params`: flat parameter vector laid out as
#   `[flow params ; NN params ; 2 noise-prior params]` (see the slicing below).
# - `h_prior`: `(mean, precision)` of `h_0`; `length(h_prior[1])` sets the state dimension.
# - `w_prior`: NOTE(review): not used inside the model — the prior and fixed marginal of
#   `w` are built from the last two entries of `params` instead. Confirm this is intended.
# - `y_prev`, `u`: previous observation and control input, consumed by `NN` when the
#   graph is built.
#
# Returns the model variables `(h, h_0, θ, w, y_lat_1, y_lat_2, y_lat_3, y, y_node, y_pred)`.
@model [default_factorisation=MeanField()] function nonlinear_max_prediction(model_flow::FlowModel, params, h_prior, w_prior, y_prev, u)
    
    # state dimension, taken from the prior mean of h_0
    dim = length(h_prior[1])
    # initialize variables
    
    # auxiliary data variable; a value is pushed into it to drive inference
    y_pred       = datavar(Float64)
    
    # split the flat parameter vector: flow params | NN params | last 2 entries for w's prior
    del = nr_params(model_flow)
    params_flow = params[1:del]
    params_f = params[del+1:end-2]
    params_w = params[end-1:end]

    # The marginals of θ, w and h_0 are fixed (FixedMarginalConstraint) to the same
    # distributions used as their priors below, so inference does not update them.
    # NOTE(review): precision `ReactiveMP.huge` makes θ's prior (and fixed marginal) almost a
    # point mass at zero — confirm this is intended rather than a vague prior.
    # `melu` is defined elsewhere; presumably it maps params_w to positive shape/rate values.
    θ   = randomvar() where {form_constraint = FixedMarginalConstraint(MvNormalMeanPrecision(zeros(dim), ReactiveMP.huge*diageye(dim)))}  
    w   = randomvar() where {form_constraint = FixedMarginalConstraint(GammaShapeRate(melu(params_w[1]), melu(params_w[2])))}
    h_0 = randomvar() where {form_constraint = FixedMarginalConstraint(MvNormalMeanPrecision(h_prior[1], h_prior[2]))}

    # Plain Julia call (not `~`): evaluated once while the graph is built. It is added to
    # the scalar y_lat_2 below, so presumably it returns a scalar.
    sigmoid_pred = NN(y_prev, u, params_f)

    # compile flow model
    Flow_meta  = FlowMeta(compile(model_flow, params_flow)) # default: FlowMeta(model, Linearization())
    h_0 ~ MvNormalMeanPrecision(h_prior[1], h_prior[2])
    θ   ~ MvNormalMeanPrecision(zeros(dim), ReactiveMP.huge*diageye(dim))
    w   ~ GammaShapeRate(melu(params_w[1]), melu(params_w[2])) where {q=MeanField()}
    
    # specify transformed latent value
    
    # multivariate AR transition of dimension `dim`
    AR_meta = ARMeta(Multivariate, dim, ARsafe())

    # specify observations
    
    # Structured factorisation: presumably the joint over the AR node's output/input (y, x),
    # with separate factors for its precision γ and coefficients θ — confirm against the
    # AR node's interface names.
    ar_node, h ~ AR(h_0, θ, w) where {q = q(y, x)q(γ)q(θ), meta = AR_meta}
    
    y_lat_1 ~ Flow(h) where { meta = Flow_meta, q = FullFactorisation() }
    
    # dot with a ones vector = sum of the flow output's components
    y_lat_2 ~ dot(y_lat_1, ones(dim))
    
    y_lat_3 ~ y_lat_2 + sigmoid_pred

    # observation model with precision w; joint factor over (y, y_lat_3)
    y_node, y ~ NormalMeanPrecision(y_lat_3, w) where { q = q(y, y_lat_3)q(w) }
    
    # Near-zero precision (1e-12) link from y to the data variable y_pred. Presumably this
    # lets a dummy value pushed into y_pred trigger message passing without constraining
    # y in practice — confirm.
    y_pred ~ NormalMeanPrecision(y, 1e-12)
    
    return h, h_0, θ, w, y_lat_1, y_lat_2, y_lat_3, y, y_node, y_pred
end

In [None]:
"""
    nonlinear_max_prediction(observation_prev, control, model_flow, params;
                             h_prior=(ones(2), diageye(2)), w_prior=(1.0, 1.0), vmp_its=50)

Build the nonlinear MAX prediction model for one time step and run `vmp_its` inference
updates on it.

# Arguments
- `observation_prev::Float64`: previous observation, passed to the model as `y_prev`.
- `control::Float64`: control input, passed to the model as `u`.
- `model_flow::FlowModel`: flow model; compiled with the leading `nr_params(model_flow)`
  entries of `params`.
- `params`: flat parameter vector `[flow params ; NN params ; 2 noise-prior params]`.

# Keywords
- `h_prior`: `(mean, precision)` of the initial hidden state `h_0`; `length(h_prior[1])`
  sets the state dimension.
- `w_prior`: `(shape, rate)` used to initialise the marginal of the noise precision `w`.
- `vmp_its`: number of times a dummy `0.0` is pushed into the `y_pred` data variable.

# Returns
`(fe_buffer, h_buffer, h0_buffer, w_buffer, y_pred_buffer)`: every Bethe free energy value
emitted during inference, followed by the last emitted marginals of `h`, `h_0`, `w` and `y`.
Each marginal is `nothing` if it was never emitted.
"""
function nonlinear_max_prediction(observation_prev::T, control::T, model_flow::FlowModel, params; 
    h_prior=(ones(2), diageye(2)), w_prior=(1.0, 1.0), vmp_its = 50) where T<:Float64

    # State dimension, derived exactly as in the @model so the θ marginal below has the
    # matching size. Previously `dim` was not defined in this scope.
    dim = length(h_prior[1])

    # define model
    model, (h, h_0, θ, w, y_lat_1, y_lat_2, y_lat_3, y, y_node, y_pred) =
        nonlinear_max_prediction(model_flow, params, h_prior, w_prior,
                                 observation_prev, control,
                                 options = (limit_stack_depth = 500, ))

    # Latest marginals emitted by the inference streams (`nothing` until first emission).
    h_buffer = nothing
    h0_buffer = nothing
    w_buffer = nothing
    y_pred_buffer = nothing

    # Left untyped (Vector{Any}): the element type of the free-energy stream is not fixed
    # here, presumably it can differ from Float64 depending on `params`.
    fe_buffer = Vector()

    # Subscribe to the marginals and to the free energy before any marginal is set.
    subscriptions = (
        subscribe!(getmarginal(h), (x) -> h_buffer = x),
        subscribe!(getmarginal(h_0), (x) -> h0_buffer = x),
        # the marginal of `y` is what this function returns as the prediction
        subscribe!(getmarginal(y), (x) -> y_pred_buffer = x),
        subscribe!(getmarginal(w), (x) -> w_buffer = x),
        subscribe!(score(BetheFreeEnergy(), model), (f) -> push!(fe_buffer, f)),
    )

    try
        # initial marginals required to start variational message passing
        setmarginal!(w, GammaShapeRate(w_prior[1], w_prior[2]))
        setmarginal!(θ, MvNormalMeanPrecision(zeros(dim), ReactiveMP.huge*diageye(dim)))
        setmarginal!(y, NormalMeanPrecision(0.0, 1.0))

        # Each update pushes a dummy value into `y_pred`; presumably each push triggers
        # one round of message passing.
        for _ in 1:vmp_its
            ReactiveMP.update!(y_pred, 0.0)
        end
    finally
        # Release the subscriptions so the closures stop writing into the buffers after
        # this function returns (or throws).
        foreach(unsubscribe!, subscriptions)
    end

    # return the marginal values
    return fe_buffer, h_buffer, h0_buffer, w_buffer, y_pred_buffer
end

In [None]:
# TODO: Wrong, must be fixed
# (Author's note; what exactly is wrong is not recorded here.)
#
# One-step-ahead filtering over the test set. After each step, the returned posteriors
# of the hidden state `h` and of the noise precision `w` are converted into
# `(mean, precision)` / `(shape, rate)` tuples and used as the priors for the next step.
# Depends on `h`, `w`, `X_test`, `U_test`, `model_flow` and `inf_params` from earlier cells.
# `rw` is currently unused.

predictions = []
h_pred  = (mean(h[end]), precision(h[end]))
h0_pred = (mean(h[end]), precision(h[end]))   # overwritten by every step's h_0 marginal
w_pred  = (shape(w), rate(w))
rw = []
for t in eachindex(X_test)
    fe, h_post, h0_pred, w_post, y_pred = nonlinear_max_prediction(
        X_test[t], U_test[t], model_flow, inf_params;
        h_prior = h_pred, w_prior = w_pred, vmp_its = 1)
    w_pred = (shape(w_post), rate(w_post))
    push!(predictions, y_pred)
    h_pred = (mean(h_post), precision(h_post))
end