-
Notifications
You must be signed in to change notification settings - Fork 0
/
mv_gaussian.jl
98 lines (91 loc) · 4.12 KB
/
mv_gaussian.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
##############################################################################################################
# load packages
##############################################################################################################
cd(@__DIR__)
using Pkg
Pkg.activate("..")
using Revise
using NormalizingFlowsTutorials
using InvertibleNetworks: NetworkConditionalGlow
using Distributions
using LinearAlgebra
using Random
using PyPlot
Random.seed!(8744)
##############################################################################################################
# generate data
##############################################################################################################
# posterior inference on unseen observation
μ = [-1,1]
σ = 1
n_obs = 50
# what shape should this be?
data = reshape(rand(MvNormal(μ, σ * I(2)), n_obs), (1,2,n_obs,1))
##############################################################################################################
# generate training data
##############################################################################################################
function sample_prior()
μ = rand(Normal(0, 1), 2)
σ = rand(truncated(LogNormal(1, 1), 0, 10))
return [μ...,σ]
end
n_samples = 1000
n_parms = 4
n_train = 10000
# train using samples from joint distribution x,y ~ p(x,y) where x=[μ, σ] -> y = N(μ, σ)
# rows: μ, σ, y
x_train = mapreduce(x -> sample_prior(), hcat, 1:n_train)
# hack to make even number of parameters.
x_train = [x_train[1,:]'; x_train]
# what shape should this be?
# Dimensions: n_train, n_obs, MvNormal variables (2)
y_train = map(i -> rand(MvNormal(x_train[1:2,i],x_train[3,i] * I(2)), n_obs), 1:n_train)
##############################################################################################################
# sample prior distribution
##############################################################################################################
x_prior = mapreduce(x -> sample_prior(), hcat, 1:n_samples)'
##############################################################################################################
# train neural network
##############################################################################################################
n_epochs = 10
batch_size = 1000
n_batches = div(n_train, batch_size)
n_hidden = 32
n_multiscale = 3
n_coupling = 4
network = NetworkConditionalGlow(n_parms, n_obs, n_hidden, n_multiscale, n_coupling)
losses = train!(network, x_train, y_train; n_epochs, n_batches, batch_size, n_obs)
fig = figure()
plot(losses)
xlabel("iterations")
ylabel("loss")
fig
##############################################################################################################
# sample from posterior distribution
##############################################################################################################
x_post = sample_posterior(network, data; n_parms, n_samples)
##############################################################################################################
# plot results
##############################################################################################################
fig = figure()
subplot(1,2,1)
hist(x_prior[:,1];alpha=0.7,density=true,label="Prior")
hist(x_post[:,1];alpha=0.7,density=true,label="Posterior")
axvline(μ[1], color="k", linewidth=1,label="Ground truth")
xlabel(L"\mu_1"); ylabel("Density");
legend()
fig = figure()
subplot(1,2,1)
hist(x_prior[:,2];alpha=0.7,density=true,label="Prior")
hist(x_post[:,2];alpha=0.7,density=true,label="Posterior")
axvline(μ[2], color="k", linewidth=1,label="Ground truth")
xlabel(L"\mu_2"); ylabel("Density");
legend()
subplot(1,2,3)
hist(x_prior[:,3]; alpha=0.7,density=true,label="Prior")
hist(x_post[:,3]; alpha=0.7,density=true,label="Posterior")
axvline(σ, color="k", linewidth=1,label="Ground truth")
xlabel(L"\sigma"); ylabel("Density");
legend()
tight_layout()
fig