In [8]:
using RxInfer, LinearAlgebra

In [19]:
# Sum-product message towards the `out` edge of a Categorical node given the
# Dirichlet message `m_p` on its `p` edge: a Categorical built from the Dirichlet
# mean, L1-normalized so the probabilities sum to one.
@rule Categorical(:out, Marginalisation) (m_p::Dirichlet,) = begin
    μ = mean(m_p)
    return Categorical(normalize(μ, 1))
end
# Message towards the `p` edge of a Categorical node when the `out` edge carries a
# point-mass marginal `q_out`: a Dirichlet whose concentration is that point's
# probability vector shifted by one. `m_out` is part of the rule signature but is
# not used; only the point-mass marginal contributes.
@rule Categorical(:p, Marginalisation) (m_out::Categorical, q_out::PointMass) = begin
    probs = probvec(q_out)
    return Dirichlet(probs .+ one(eltype(probs)))
end

In [15]:
?@call_rule

No documentation found.

`ReactiveMP.@call_rule` is a macro.

```
# 1 method for macro "@call_rule":
[1] var"@call_rule"(__source__::LineNumberNode, __module__::Module, fform, args) in ReactiveMP at d:\OneDrive - TU Eindhoven\phd\Projects\Packages\ReactiveMP.jl\src\rule.jl:406
```


In [10]:
"""
    EnforceMarginalFunctionalDependency(edge::Symbol)

Node functional-dependencies pipeline that keeps ReactiveMP's default message
dependencies, and additionally makes rules computed for the node's other
interfaces depend on the marginal of the interface named `edge`.
"""
struct EnforceMarginalFunctionalDependency <: ReactiveMP.AbstractNodeFunctionalDependenciesPipeline
    edge::Symbol  # name of the node interface whose marginal is enforced
end

# Message dependencies are not altered by this pipeline: forward everything to
# ReactiveMP's default functional-dependencies pipeline.
ReactiveMP.message_dependencies(::EnforceMarginalFunctionalDependency, nodeinterfaces, nodelocalmarginals, varcluster, cindex, iindex) =
    ReactiveMP.message_dependencies(
        ReactiveMP.DefaultFunctionalDependencies(),
        nodeinterfaces, nodelocalmarginals, varcluster, cindex, iindex,
    )

# Marginal dependencies for `EnforceMarginalFunctionalDependency`: start from
# ReactiveMP's default marginal dependencies and, unless the rule being computed is
# for the enforced edge itself (`iindex`), add a local marginal that streams the
# marginal of the variable connected to interface `enforce.edge`.
#
# Throws `ArgumentError` if the node has no interface named `enforce.edge`.
function ReactiveMP.marginal_dependencies(enforce::EnforceMarginalFunctionalDependency, nodeinterfaces, nodelocalmarginals, varcluster, cindex, iindex)
    default = ReactiveMP.marginal_dependencies(ReactiveMP.DefaultFunctionalDependencies(), nodeinterfaces, nodelocalmarginals, varcluster, cindex, iindex)
    index   = ReactiveMP.findnext(i -> name(i) === enforce.edge, nodeinterfaces, 1)
    # `findnext` returns `nothing` when no interface has that name; report it here
    # instead of failing later with an opaque `nodeinterfaces[nothing]` error.
    if isnothing(index)
        throw(ArgumentError("node has no interface named :$(enforce.edge) to enforce a marginal dependency on"))
    end
    # The rule for the enforced edge itself keeps the default dependencies.
    if index === iindex
        return default
    end
    # Local marginal for interface `index`, fed by the connected variable's marginal stream.
    # NOTE(review): `-1` looks like a placeholder first argument — confirm against
    # ReactiveMP's `FactorNodeLocalMarginal` constructor.
    vmarginal = ReactiveMP.getmarginal(ReactiveMP.connectedvar(nodeinterfaces[index]), IncludeAll())
    loc = ReactiveMP.FactorNodeLocalMarginal(-1, index, enforce.edge)
    ReactiveMP.setstream!(loc, vmarginal)
    # Insert after every default dependency `el` with `first(el) < iindex`.
    # NOTE(review): ordering is relative to `iindex` (the interface the rule is
    # computed for), not `index` (the enforced edge) — confirm this is intended.
    insertafter = count(el -> first(el) < iindex, default)
    return ReactiveMP.TupleTools.insertafter(default, insertafter, (loc,))
end

In [11]:
# Point-mass optimizer for univariate discrete distributions: collapse the
# distribution's probability vector onto a one-hot vector at `argmax(p)`
# (the first maximizing index when there are ties), wrapped in a `PointMass`.
# `constraint` is unused; it is only needed for dispatch.
function RxInfer.default_point_mass_form_constraint_optimizer(::Type{Univariate}, ::Type{Discrete}, constraint::RxInfer.PointMassFormConstraint, distribution)
    p = probvec(distribution)

    # Keep the (floating-point) element type of the input instead of hard-coding
    # Float64: Float64 inputs are unchanged, Float32 inputs stay Float32.
    T = float(eltype(p))
    onehot = zeros(T, length(p))
    onehot[argmax(p)] = one(T)

    return PointMass(onehot)
end

In [12]:
@model function model_issue()
    # Observed vector-valued data for `y`.
    y = datavar(Vector{Float64})

    # Dirichlet prior on α with concentration 0.01 for each of the 3 categories.
    α ~ Dirichlet(0.01 .* ones(3))
    # z_old ~ Categorical(α). The pipeline makes this node's rules also depend on the
    # marginal of its :out interface (presumably z_old — the `~` left-hand side).
    z_old ~ Categorical(α) where { pipeline = EnforceMarginalFunctionalDependency(:out) }
    # Transition from z_old with a 3×3 identity matrix. The pipeline enforces the
    # marginal of the :in interface (presumably z_old — confirm against ReactiveMP's
    # Transition node interface names).
    z_new ~ Transition(z_old, diagm(ones(3))) where { pipeline = EnforceMarginalFunctionalDependency(:in) }
    # Observation y through a second identity Transition from z_new (default pipeline).
    y ~ Transition(z_new, diagm(ones(3)))

    return y, z_new, z_old, α

end

In [13]:
# Form constraint: the posterior marginal q(z_old) is restricted to a point mass.
# No factorization constraints are specified.
@constraints function constraints_issue()
    q(z_old) :: PointMass
end

constraints_issue (generic function with 1 method)

In [20]:
# Run inference on `model_issue` with a single one-hot observation for `y`, the
# point-mass constraint from `constraints_issue`, and α's posterior requested via
# `KeepLast()`. The observation is bound in a `let` so no extra global is created.
results_combination = let observation = (y = [1.0, 0.0, 0.0],)
    inference(
        model       = model_issue(),
        data        = observation,
        constraints = constraints_issue(),
        returnvars  = (α = KeepLast(),),
    )
end

Inference results:
  Posteriors       | available for (α)


In [21]:
# Posterior marginal of α from the inference results (displayed as the cell output).
results_combination.posteriors[:α]

Marginal(Dirichlet{Float64, Vector{Float64}, Float64}(alpha=[1.0099999999999998, 0.010000000000000009, 0.010000000000000009]))