Skip to content

Commit

Permalink
Merge pull request #33 from bhmm/test_patho
Browse files Browse the repository at this point in the history
debug
  • Loading branch information
franknoe committed Feb 13, 2016
2 parents b37b01b + 39e961a commit 88abf76
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
18 changes: 7 additions & 11 deletions bhmm/output_models/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,6 @@ def sample(self, observations, prior=None):
>>> nobs = 1000
>>> output_model = GaussianOutputModel(nstates=nstates, means=[-1, 0, 1], sigmas=[0.5, 1, 2])
>>> observations = [ output_model.generate_observations_from_state(state_index, nobs) for state_index in range(nstates) ]
>>> weights = [ np.zeros([nobs,nstates], np.float32).T for _ in range(nstates) ]
Update output parameters by sampling.
Expand All @@ -412,16 +411,13 @@ def sample(self, observations, prior=None):
# Skip update if no observations.
if nsamples_in_state == 0:
logger().warn('Warning: State %d has no observations.' % state_index)
continue

# Sample new mu.
self.means[state_index] = np.random.randn()*self.sigmas[state_index]/np.sqrt(nsamples_in_state) + np.mean(observations_in_state)

# Sample new sigma.
# This scheme uses the improper Jeffreys prior on sigma^2, P(mu, sigma^2) \propto 1/sigma
chisquared = np.random.chisquare(nsamples_in_state-1)
sigmahat2 = np.mean((observations_in_state - self.means[state_index])**2)
self.sigmas[state_index] = np.sqrt(sigmahat2) / np.sqrt(chisquared / nsamples_in_state)
if nsamples_in_state > 0: # Sample new mu.
self.means[state_index] = np.random.randn()*self.sigmas[state_index]/np.sqrt(nsamples_in_state) + np.mean(observations_in_state)
if nsamples_in_state > 1: # Sample new sigma
# This scheme uses the improper Jeffreys prior on sigma^2, P(mu, sigma^2) \propto 1/sigma
chisquared = np.random.chisquare(nsamples_in_state-1)
sigmahat2 = np.mean((observations_in_state - self.means[state_index])**2)
self.sigmas[state_index] = np.sqrt(sigmahat2) / np.sqrt(chisquared / nsamples_in_state)

return

Expand Down
2 changes: 1 addition & 1 deletion bhmm/tests/test_bhmm_patho.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_2state_nonrev_step(self):
sampled = bhmm.bayesian_hmm([obs], mle, reversible=False, nsample=2000,
p0_prior='mixed', transition_matrix_prior='mixed')
assert np.all(sampled.transition_matrix_std[0] > 0)
assert np.allclose(sampled.transition_matrix_std[1], [0, 0])
assert np.max(np.abs(sampled.transition_matrix_std[1])) < 1e-3

def test_2state_rev_2step(self):
obs = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0], dtype=int)
Expand Down

0 comments on commit 88abf76

Please sign in to comment.