Load necessary libraries

In [17]:
using Gen
using GenViz
using Statistics
include("DistributionsBacked.jl")
include("IterDeep.jl")
using PyPlot

const my_inv_gamma = DistributionsBacked{Float64}((alpha, theta) -> Distributions.InverseGamma(alpha, theta), [true, true], true)

DistributionsBacked{Float64}(var"#23#24"(), Bool[1, 1], true)

Set up visualization server. (Careful: run only once)

In [2]:
server = VizServer(8092);

Space for declaring constants.

In [3]:
INV_GAMMA_PRIOR_CONSTANT = 0.1;

Canonical Rats data, from from section 6 of Gelfand et al (1990). See http://www.openbugs.net/Examples/Rats.html

In [4]:
ys_raw = ([151, 145, 147, 155, 135, 159, 141, 159, 177, 134,
    160, 143, 154, 171, 163, 160, 142, 156, 157, 152, 154, 139, 146,
    157, 132, 160, 169, 157, 137, 153, 199, 199, 214, 200, 188, 210,
    189, 201, 236, 182, 208, 188, 200, 221, 216, 207, 187, 203, 212,
    203, 205, 190, 191, 211, 185, 207, 216, 205, 180, 200, 246, 249,
    263, 237, 230, 252, 231, 248, 285, 220, 261, 220, 244, 270, 242,
    248, 234, 243, 259, 246, 253, 225, 229, 250, 237, 257, 261, 248,
    219, 244, 283, 293, 312, 272, 280, 298, 275, 297, 350, 260, 313,
    273, 289, 326, 281, 288, 280, 283, 307, 286, 298, 267, 272, 285,
    286, 303, 295, 289, 258, 286, 320, 354, 328, 297, 323, 331, 305,
    338, 376, 296, 352, 314, 325, 358, 312, 324, 316, 317, 336, 321,
    334, 302, 302, 323, 331, 345, 333, 316, 291, 324])
ys = reshape([Float64(y) for y = ys_raw],30,5)
xs = [8.0, 15.0, 22.0, 29.0, 36.0]

5-element Array{Float64,1}:
  8.0
 15.0
 22.0
 29.0
 36.0

These data are the weights of 30 rats, measured at 5 common age values in days. The model is a simple hierarchical/random effects one; a random intercept and random slope for each rat, plus an error term (presumably representing not measurement error, but just the random variation of individual growth curves).

In [5]:
for i in 1:30
    plot(xs,ys[i,:])
end
xlabel("Age (days)")
ylabel("Weight (g)")
title("Rat growth")

UndefVarError: UndefVarError: plot not defined

In [6]:
@gen function model_prematurely_optimized(xs::Vector,N::Int32,) #single objects across rats (the dim of size N)
    T = length(xs)
    xbar = mean(xs) #could be precomputed, but YKWTS about premature optimization...

    mu_alpha ~ normal(0, 100)
    mu_beta ~ normal(0, 100)
    sigmasq_y ~ my_inv_gamma(INV_GAMMA_PRIOR_CONSTANT, INV_GAMMA_PRIOR_CONSTANT)
    sigmasq_alpha ~ my_inv_gamma(INV_GAMMA_PRIOR_CONSTANT, INV_GAMMA_PRIOR_CONSTANT)
    sigmasq_beta ~ my_inv_gamma(INV_GAMMA_PRIOR_CONSTANT, INV_GAMMA_PRIOR_CONSTANT)

    alpha ~ normal(fill(mu_alpha,N), sqrt(sigmasq_alpha)) # vectorized
    beta ~ normal(fill(mu_beta,N), sqrt(sigmasq_beta))  # vectorized
    ys ~ normal([alpha[n] + beta[n] * (x[t] - xbar) for n in 1:N, t in 1:T], sqrt(sigmasq_y))
end
;

In [7]:
@gen function model(xs::Vector,N::Int32,)
    T = length(xs)
    xbar = mean(xs) #could be precomputed, but YKWTS about premature optimization...

    mu_alpha ~ normal(0, 100)
    mu_beta ~ normal(0, 10)
    sigmasq_y ~ my_inv_gamma(INV_GAMMA_PRIOR_CONSTANT, INV_GAMMA_PRIOR_CONSTANT)
    sigmasq_alpha ~ my_inv_gamma(INV_GAMMA_PRIOR_CONSTANT, INV_GAMMA_PRIOR_CONSTANT)
    sigmasq_beta ~ my_inv_gamma(INV_GAMMA_PRIOR_CONSTANT, INV_GAMMA_PRIOR_CONSTANT)

    alpha = Vector{Float64}(undef,N)
    beta = Vector{Float64}(undef,N)
    y = Vector{Vector{Float64}}(undef,N)
    for n in 1:N
        alpha[n] = ({:data => n => :alpha} ~ normal(mu_alpha, sqrt(sigmasq_alpha))) # vectorized
        beta[n] = ({:data => n => :beta} ~ normal(mu_beta, sqrt(sigmasq_beta)))  # vectorized
        y[n] = ({:data => n => :y} ~ broadcasted_normal([alpha[n] + beta[n] * (xs[t] - xbar) for t = 1:T],
                                            sqrt(sigmasq_y)))
    end
end
;

In [8]:
function make_constraints(ys::Array)
    (N,T) = size(ys)
    constraints = Gen.choicemap()
    for n in 1:N
        constraints[:data => n => :y] = ys[n,:]
    end
    constraints
end
;

In [9]:
function mcmc_inference(xs, ys, num_iters, update)
    N = size(ys)[1]
    observations = make_constraints(ys)
    (trace, _) = generate(model, (xs, N), observations)
    for iter=1:num_iters
        trace = update(trace,N,observations)
    end
    trace
end
;

In [10]:
function block_mh(tr,N,observations)
    (tr, _) = mh(tr, select(:mu_alpha, :mu_beta))
    (tr, _) = mh(tr, select(:sigmasq_alpha, :sigmasq_beta, :sigmasq_y))

    for n in 1:N
        (tr, _) = mh(tr, select(:data => n => :y,
                                :data => n => :alpha,:data => n => :beta))
    end
    tr
end
;

In [23]:
function simple_hmc(tr,N,observations)
    (tr, _) = hmc(tr, Gen.complement(select(observations)))
    tr
end
;

In [29]:
tr = mcmc_inference(xs, ys, 200, block_mh)
get_choices(tr)
;

In [19]:
print("mu_alpha: $(tr[:mu_alpha]), mu_beta: $(tr[:mu_beta]), sigma_y: $(sqrt(tr[:sigmasq_y]))")

mu_alpha: -62.08911913701273, mu_beta: -2.904699224028846, sigma_y: 10.424544392090809

In [20]:

function visualize_rats(trace,ys)
    assmt = Gen.get_choices(trace)
    (xs,N) = Gen.get_args(trace)
    Dict("slope" => assmt[:slope],
        "intercept" => assmt[:intercept],
        "inlier_std" => assmt[:noise],
        "y-coords" => [assmt[:data => i => :y] for i in 1:length(xs)],
        "outliers" => [assmt[:data => i => :is_outlier] for i in 1:length(xs)])
end;

In [30]:
tr2 = mcmc_inference(xs, ys, 200, simple_hmc)
get_choices(tr2)
;

In [27]:
print("mu_alpha: $(tr2[:mu_alpha]), mu_beta: $(tr2[:mu_beta]), sigma_y: $(sqrt(tr2[:sigmasq_y]))\n")

mu_alpha: 25.733551442646895, mu_beta: -0.7245588767483507, sigma_y: 28.082571338699566

In [28]:
for i = 1:5
    tr2 = mcmc_inference(xs, ys, 300, simple_hmc)
    print("mu_alpha: $(tr2[:mu_alpha]), mu_beta: $(tr2[:mu_beta]), sigma_y: $(sqrt(tr2[:sigmasq_y]))\n")
end
    

mu_alpha: 14.267175410432042, mu_beta: 14.392528336152559, sigma_y: 4.569658902596242
mu_alpha: -61.531685609791595, mu_beta: -4.125014033246265, sigma_y: 54.05637335084533


DomainError: DomainError with -0.06038205559423504:
log will only return a complex result if called with a complex argument. Try log(Complex(x)).