In [1]:
import Pkg; Pkg.activate(".."); Pkg.instantiate();

[32m[1m  Activating[22m[39m project at `c:\Users\matti\Documents\TUe\Thesis\ThesisParallelMP`


In [2]:
using Distributed

In [3]:
addprocs(8);

In [4]:
@everywhere using RxInfer, Random, StatsPlots, LinearAlgebra, ProgressMeter, PGFPlotsX, Revise

In [5]:
@everywhere includet("../distributed.jl")

In [6]:
pgfplotsx()

Plots.PGFPlotsXBackend()

In [7]:
@everywhere begin
    n = 200  # Number of coin flips
    p = 0.75 # Bias of a coin

    distribution = Bernoulli(p)
    dataset      = float.(rand(Bernoulli(p), n))
    prior = Beta(2.0, 6.0)
end

In [8]:
@everywhere begin
    @model function coin_model(n, prior, prod_strategy, pipeline, meta, prod_constraint)

        y = datavar(Float64, n)
        θ = randomvar() where { prod_strategy = prod_strategy, prod_constraint = prod_constraint }

        θ ~ prior

        for i in 1:n
            y[i] ~ Bernoulli(θ) where { pipeline = pipeline, meta = meta  }
        end

    end
end

In [9]:
@everywhere begin
    struct MyCustomRule
        size::Int
    end

    function factorial(n)
        if n == 0 || n == 1
            return 1
        else
            return factorial(n - 1) + factorial(n - 2)
        end
    end

    @rule Bernoulli(:p, Marginalisation) (q_out::PointMass, meta::MyCustomRule) = begin
        @logscale -log(2)
        r = mean(q_out)
        s = factorial(meta.size) / 10^8
        return Beta(one(r) + r, 2one(r) - r + s)
    end
end

In [10]:
struct MyCustomProd
    size::Int
end

function RxInfer.BayesBase.prod(custom::MyCustomProd, left, right)
    r = mean(left)
	s = factorial(custom.size) / 10^8
    return prod(GenericProd(), Beta(one(r) + r, 2one(r) - r + s), right)
end

function RxInfer.BayesBase.prod(custom::MyCustomProd, left::Missing, right)
    return right
end

In [23]:
result_full = infer(
    model = coin_model(length(dataset), prior, nothing, nothing, MyCustomRule(1), nothing),
    data  = (y = dataset, ),
    iterations = 1
)

Inference results:
  Posteriors       | available for (θ)


In [24]:
result_full.posteriors[:θ]

1-element Vector{Beta{Float64}}:
 Beta{Float64}(α=139.0, β=69.00000200000001)

In [14]:
channel = RemoteChannel(1)

RemoteChannel{Channel{Any}}(1, 1, 210)

In [25]:
res_prod = Vector{Any}(undef, 1)
start = time_ns()

infer(
    model = coin_model(Int(length(dataset) / 2), prior, nothing, nothing, MyCustomRule(1), nothing),
    data  = (y = dataset[1:Int(length(dataset) / 2)], ),
    callbacks = (
        after_iteration = (model, iteration) -> begin
            put!(channel, ReactiveMP.getdata(Rocket.getrecent(model.variables.random[1].marginal)))
            return false
        end,
    ),
    iterations = 1
)

results_split = infer(
    model = coin_model(Int(length(dataset) / 2), prior, nothing, nothing, MyCustomRule(1), nothing),
    data  = (y = dataset[Int(length(dataset) / 2)+1:end], ),
    callbacks = (
        after_iteration = (model, iteration) -> begin
            res_prod[1] = prod(GenericProd(), take!(channel), ReactiveMP.getdata(Rocket.getrecent(model.variables.random[1].marginal)))
            return false
        end,
    ),
    iterations = 1
)

println("Time: ", (time_ns() - start) / 1e9)

Time: 0.2612629


In [26]:
res_prod[1]

Beta{Float64}(α=140.0, β=74.00000200000015)

In [32]:
res_prod = Vector{Any}(undef, 1)
start = time_ns()

@spawnat 2 infer(
    model = coin_model(Int(length(dataset) / 2), prior, nothing, nothing, MyCustomRule(1), nothing),
    data  = (y = dataset[1:Int(length(dataset) / 2)], ),
    callbacks = (
        after_iteration = (model, iteration) -> begin
            put!(channel, ReactiveMP.getdata(Rocket.getrecent(model.variables.random[1].marginal)))
            return false
        end,
    ),
    iterations = 1
)

results_split = infer(
    model = coin_model(Int(length(dataset) / 2), prior, nothing, nothing, MyCustomRule(1), nothing),
    data  = (y = dataset[Int(length(dataset) / 2)+1:end], ),
    callbacks = (
        after_iteration = (model, iteration) -> begin
            res_prod[1] = prod(GenericProd(), take!(channel), ReactiveMP.getdata(Rocket.getrecent(model.variables.random[1].marginal)))
            return false
        end,
    ),
    iterations = 1
)

println("Time: ", (time_ns() - start) / 1e9)

Time: 0.185384


In [33]:
res_prod[1]

Beta{Float64}(α=140.0, β=74.00000200000015)