In [1]:
using Pkg
Pkg.activate(".")

[32m[1m  Activating[22m[39m project at `~/PhD/GaussianProcessNode`


In [2]:
using RxInfer, ReactiveMP, GraphPPL, Rocket

In [3]:
using SpecialFunctions

In [4]:
function freq_ratio(λ_prev, λ_present)
    f_prev = λ_prev ./ sum(λ_prev)
    f_present = λ_present ./ sum(λ_present)
    γ = f_present ./ f_prev
    return γ
end

freq_ratio (generic function with 1 method)

In [5]:
#create MvLogNormal node 
struct MvLogNormalMeanCovariance end 

@node MvLogNormalMeanCovariance Stochastic [ y, μ, Σ ]

@rule MvLogNormalMeanCovariance(:y, Marginalisation) (m_μ::PointMass, m_Σ::PointMass) = begin
   return MvLogNormal(m_μ, m_Σ) 
end

In [6]:
#create element-wise Poisson node 
struct ElementwisePoisson end 

@node ElementwisePoisson Stochastic [ y, x ]

@rule ElementwisePoisson(:y, Marginalisation) (q_x::Vector{GammaDistributionsFamily},) = begin
    α_x, β_x = Distributions.params.(q_x)
    λ = exp.(digamma.(α_x) .- log.(β_x))
    return Poisson.(λ)
end

@rule ElementwisePoisson(:x, Marginalisation) (q_y::Vector{Poisson},)= begin
    λ_y = mean.(q_y)
    return GammaShapeRate.(1 .+ λ_y, ones(length(λ_y)))
end

@rule ElementwisePoisson(:x, Marginalisation) (m_y::PointMass,)= begin
    λ_y = mean.(q_y)
    
    return GammaShapeRate.(1 .+ λ_y, ones(length(λ_y)))
end

In [7]:
#create element-wise Lognormal node 
struct ElementwiseLogNormal end 

@node ElementwiseLogNormal Stochastic [ y, μ, w ]

@rule ElementwiseLogNormal(:y, Marginalisation) (q_μ::MultivariateNormalDistributionsFamily, q_w::Vector{GammaDistributionsFamily},) = begin
    mean_μ = mean(q_μ)
    mean_w = mean.(q_w)
    σ = sqrt.(inv.(mean_w))
    return LogNormal.(mean_μ, σ)
end

@rule ElementwiseLogNormal(:μ, Marginalisation) (q_y::Vector{LogNormal}, q_w::Vector{GammaDistributionsFamily},) = begin
    μ_y = Distributions.params.(q_y)
    mean_w = mean.(q_w)

    return MvNormalMeanPrecision(μ_y, Diagonal(mean_w))
end

@rule ElementwiseLogNormal(:w, Marginalisation) (q_y::Vector{LogNormal}, q_μ::MultivariateNormalDistributionsFamily,) = begin
    μ_y, σ_y = Distributions.params.(q_y)
    mean_μ,cov_μ = mean_cov(q_μ)
    var_μ = diag(cov_μ)

    return GammaShapeRate.(3/2*ones(length(μ_y)),0.5.*(mean_μ^2 .+ var_μ .- 2 .* mean_μ.*μ_y .+ μ_y.^2 .+ σ_y.^2))
end

In [8]:
struct ElementwiseGammaShapeRate end 

@node ElementwiseGammaShapeRate Stochastic [ y, α, β ]

@rule ElementwiseGammaShapeRate(:y, Marginalisation) (m_α::PointMass, m_β::PointMass,) = begin
    return GammaShapeRate.(mean.(m_α), mean.(m_β))
end

@rule ElementwiseGammaShapeRate(:y, Marginalisation) (q_α::PointMass, q_β::PointMass, ) = begin 
    @show mean.(q_α)
    return GammaShapeRate.(mean.(q_α), mean.(q_β))
end

In [9]:
function ReactiveMP.prod(::ProdAnalytical,left::Vector{GammaDistributionsFamily}, right::Vector{GammaDistributionsFamily})
    T = promote_samplefloattype(left[1], right[1])
    return GammaShapeRate.(shape.(left) .+ shape.(right) .- one(T), rate.(left) .+ rate.(right))
end

In [10]:
@model function pop_dynamics(n_time, n_organ)
    y_data = datavar(Vector{Float64},n_time)
    λ = randomvar(n_time)
    s = randomvar()
    w = randomvar()
    γ = randomvar(n_time)

    #prior 
    s ~ MvNormalMeanCovariance(zeros(n_organ), diageye(n_organ))
    w ~ ElementwiseGammaShapeRate(zeros(n_organ), diageye(n_organ)) # this is the inverse of σ

    λ_0 ~ ElementwiseGammaShapeRate(0.01*ones(n_organ), 0.01*ones(n_organ))
    λ_prev = λ_0 
    #consider each time step
    for t=1:n_time 
        λ[t] ~ ElementwiseGammaShapeRate(0.01*ones(n_organ), 0.01*ones(n_organ)) #prior 
        y_data[t] ~ ElementwisePoisson(λ[t])
        γ[t] ~ freq_ratio(λ_prev,λ[t])
        γ[t] ~ ElementwiseLogNormal(s,w)
        λ_prev = λ[t]
    end
end

In [11]:
@meta function pop_dynamics_meta()
    freq_ratio() -> Linearization()
end

pop_dynamics_meta (generic function with 1 method)

In [12]:
@constraints function pop_dynamics_constraints()
    q(γ,s,w) = q(γ)q(s)q(w)
end

pop_dynamics_constraints (generic function with 1 method)

In [13]:
y_data = [50 20;68 23;81 35]
num_time = 3;
num_organ = 2;

In [14]:
nits =2

iresult = inference(
    model = pop_dynamics(num_time,num_organ),
    iterations = nits, 
    data  = (y = y_data,),
    meta = pop_dynamics_meta(),
    constraints = pop_dynamics_constraints(),
    returnvars = (s = KeepLast(),)
)

LoadError: MethodError: no method matching length(::PointMass{Vector{Float64}})
[0mClosest candidates are:
[0m  length([91m::Union{Base.KeySet, Base.ValueIterator}[39m) at abstractdict.jl:58
[0m  length([91m::Union{ZMQ._Message, Base.RefValue{ZMQ._Message}}[39m) at ~/.julia/packages/ZMQ/lrABE/src/_message.jl:31
[0m  length([91m::Union{LinearAlgebra.Adjoint{T, <:Union{StaticArraysCore.StaticArray{Tuple{var"#s2"}, T, 1} where var"#s2", StaticArraysCore.StaticArray{Tuple{var"#s3", var"#s4"}, T, 2} where {var"#s3", var"#s4"}}}, LinearAlgebra.Diagonal{T, <:StaticArraysCore.StaticArray{Tuple{var"#s13"}, T, 1} where var"#s13"}, LinearAlgebra.Hermitian{T, <:StaticArraysCore.StaticArray{Tuple{var"#s10", var"#s11"}, T, 2} where {var"#s10", var"#s11"}}, LinearAlgebra.LowerTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s18", var"#s19"}, T, 2} where {var"#s18", var"#s19"}}, LinearAlgebra.Symmetric{T, <:StaticArraysCore.StaticArray{Tuple{var"#s7", var"#s8"}, T, 2} where {var"#s7", var"#s8"}}, LinearAlgebra.Transpose{T, <:Union{StaticArraysCore.StaticArray{Tuple{var"#s2"}, T, 1} where var"#s2", StaticArraysCore.StaticArray{Tuple{var"#s3", var"#s4"}, T, 2} where {var"#s3", var"#s4"}}}, LinearAlgebra.UnitLowerTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s24", var"#s25"}, T, 2} where {var"#s24", var"#s25"}}, LinearAlgebra.UnitUpperTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s21", var"#s22"}, T, 2} where {var"#s21", var"#s22"}}, LinearAlgebra.UpperTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s15", var"#s16"}, T, 2} where {var"#s15", var"#s16"}}, StaticArraysCore.StaticArray{Tuple{var"#s25"}, T, 1} where var"#s25", StaticArraysCore.StaticArray{Tuple{var"#s1", var"#s3"}, T, 2} where {var"#s1", var"#s3"}, StaticArraysCore.StaticArray{<:Tuple, T}} where T[39m) at ~/.julia/packages/StaticArrays/VLqRb/src/abstractarray.jl:1
[0m  ...