forked from mattjj/pybasicbayes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
EM_demo.py
72 lines (50 loc) · 1.86 KB
/
EM_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from __future__ import division
import numpy as np
np.seterr(invalid='raise')
from matplotlib import pyplot as plt
import copy
from pybasicbayes import models, distributions
from util.text import progprint_xrange
# EM is really terrible! Here's a demo of how to do it on really easy data
### generate and plot the data

# alpha_0 is handed to models.Mixture below; obs_hypparams is the keyword set
# used to build every Gaussian component in this script (the 2-vector mean and
# 2x2 identity matrix make the observations two-dimensional).
alpha_0 = 100.
obs_hypparams = {
    'mu_0': np.zeros(2),
    'sigma_0': np.eye(2),
    'kappa_0': 0.05,
    'nu_0': 5,
}

# Build a throwaway six-component mixture purely as a data source.
priormodel = models.Mixture(
    alpha_0=alpha_0,
    components=[distributions.Gaussian(**obs_hypparams) for _ in range(6)],
)

# rvs(200): presumably 200 draws from the mixture -- confirm against
# pybasicbayes. The result is indexed as a 2-D array with two columns below.
data = priormodel.rvs(200)

# Only the sampled data is needed from here on.
del priormodel

# Scatter the two data columns against each other as black crosses.
plt.figure()
plt.plot(data[:, 0], data[:, 1], 'kx')
plt.title('data')

# Range of component counts to try (inclusive at both ends in the search
# below) and how many independent fits to run for each count.
min_num_components = 1
max_num_components = 12
num_tries_each = 5
### search over models using BIC as a model selection criterion

# BICs[k] collects the BIC of every restart for the k-th candidate component
# count; examplemodels[k] holds one fitted model for that same count, so the
# two lists stay index-aligned for the selection step after the loop.
BICs = []
examplemodels = []
for num_components in progprint_xrange(min_num_components, max_num_components + 1):
    theseBICs = []
    # range (not xrange) so the loops run unchanged on Python 2 and Python 3.
    for _trial in range(num_tries_each):
        fitmodel = models.Mixture(
            alpha_0=10000,  # used for random initialization Gibbs sampling, big means use all components
            components=[distributions.Gaussian(**obs_hypparams)
                        for _ in range(num_components)])

        fitmodel.add_data(data)

        # use Gibbs sampling for initialization
        for _ in range(100):
            fitmodel.resample_model()

        # use EM to fit a model
        for _ in range(50):
            fitmodel.EM_step()

        theseBICs.append(fitmodel.BIC())

    # One example per component count: the model left over from the final
    # restart (not necessarily the restart with the best BIC).
    examplemodels.append(copy.deepcopy(fitmodel))
    BICs.append(theseBICs)
plt.figure()

# Summarise the restarts for each candidate component count: the mean BIC is
# the plotted curve and the standard deviation across restarts is the error bar.
candidate_sizes = np.arange(min_num_components, max_num_components + 1)
mean_BICs = [np.mean(trial_BICs) for trial_BICs in BICs]
std_BICs = [np.std(trial_BICs) for trial_BICs in BICs]

plt.errorbar(x=candidate_sizes, y=mean_BICs, yerr=std_BICs)
plt.xlabel('num components')
plt.ylabel('BIC')

# Plot the stored example model for the component count with the lowest mean
# BIC; examplemodels is indexed in the same order as BICs.
best_index = np.argmin(mean_BICs)
examplemodels[best_index].plot()
plt.title('a decent model')

plt.show()