Hidden Markov model with Gaussian emissions not training properly #400
Howdy. It would be easiest for me to debug if you could provide the data as well. Can you also check the version of pomegranate that you have? Taking a look at the results when you toggle |
I'm on 0.7.3. |
Can you also try upgrading to 0.9.0? |
The data file is attached: The error persists unchanged with 0.9.0
We see that when the data and regime centroids are projected onto the 0th and 3rd axes, there is no relation to the region the points lie in: |
@joshdorrington Those are beautiful plots that accomplish Edward Tufte's goals of self-explanation. Could you include a pointer to the code for both of them? |
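I agree, those plots are super nice. I'll take a look at it soon; I'm currently travelling. If you need an immediate solution, can you try the code below? There might be some simple typos, but the gist should be the same.
from pomegranate import *
from sklearn.cluster import KMeans
import numpy as np
# Load the flat float64 file as one sequence of 200001 samples with 6 features.
data = np.fromfile("np_input_file.txt").reshape([1, 200001, 6])
cov = np.eye(6)
# Initialise the state means from the k-means cluster centres (KMeans expects a 2D array).
mu = KMeans(3).fit(data[0]).cluster_centers_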
s1 = State(MultivariateGaussianDistribution(mu[0], cov))
s2 = State(MultivariateGaussianDistribution(mu[1], cov))
s3 = State(MultivariateGaussianDistribution(mu[2], cov))
model=HiddenMarkovModel()
model.add_states(s1, s2, s3)
model.add_transition(model.start, s1, 0.33)
model.add_transition(model.start, s2, 0.33)
model.add_transition(model.start, s3, 0.34)
model.add_transition(s1, s1, 0.33)
model.add_transition(s1, s2, 0.33)
model.add_transition(s1, s3, 0.34)
model.add_transition(s2, s1, 0.33)
model.add_transition(s2, s2, 0.33)
model.add_transition(s2, s3, 0.34)
model.add_transition(s3, s1, 0.33)
model.add_transition(s3, s2, 0.33)
model.add_transition(s3, s3, 0.34)
# Finalise the model topology before training.
model.bake()
model.fit(data, verbose=True) |
@jkleckner , thanks for your kind comment. Sorry, I'm not exactly clear about what you'd like me to link to; the plotting code, or the state fitting code? I implemented a slight variant of your proposed code @jmschrei, just to tidy up a couple of little bugs, but it's basically the same:
And this returns a full output of:
However, when plotted, the same problem persists. Here we see that the k-means centres were actually pretty close to where the hmmlearn HMM ended up: |
@joshdorrington It would be helpful to see both the fit and the plot. If they are too big to put inline you could always create a gist. When the pomegranate fit is working, it would also be useful to see the two side by side for comparison. Note that there is a notebook example in [1] that does this and compares their speed. [1] https://github.com/jmschrei/pomegranate/blob/master/tutorials/Tutorial_0_pomegranate_overview.ipynb |
@joshdorrington, could it be that you have a mismatched expectation of which state indices correspond between the models? When I run it, the numbers, judging from your plots (code would be better), suggest that indices 0 and 3 might correspond to what you have. Note that the k-means process can produce the clusters in a non-deterministic order under some circumstances.
|
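As a rough illustration of the index-matching concern raised above, the sketch below pairs up the states of two fitted models by the distance between their mean vectors rather than by index. The function name `match_states` and the toy means are hypothetical and not taken from either library; for a 3-state model a brute-force search over permutations is sufficient.

```python
from itertools import permutations

import numpy as np


def match_states(means_a, means_b):
    """Return perm such that means_b[perm[i]] is the state paired with means_a[i]."""
    means_a = np.asarray(means_a, dtype=float)
    means_b = np.asarray(means_b, dtype=float)
    n = len(means_a)
    # Pairwise Euclidean distances between every state mean in A and in B.
    dist = np.linalg.norm(means_a[:, None, :] - means_b[None, :, :], axis=-1)
    # Brute force is fine for a handful of states (3! = 6 candidates here).
    best = min(
        permutations(range(n)),
        key=lambda p: sum(dist[i, p[i]] for i in range(n)),
    )
    return list(best)


# Hypothetical means: model B found the same three states in a different order.
means_a = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, -2.0]])
means_b = means_a[[2, 0, 1]] + 0.01
print(match_states(means_a, means_b))  # [1, 2, 0]
```

With the real models, the two inputs would be the per-state mean vectors extracted from each fit, so the comparison no longer depends on the order in which the clusters happened to be produced.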
@jkleckner I have made a gist with my full fitting and plotting included for both programs: As hmmlearn still has no Python 3 compatibility they are separate files. Your results do look right, but unfortunately I can't reproduce them! I don't think the indices are the problem, as I don't make any assumptions about which states are which; I only filter on a subspace of the data, and presumably the order of the dimensions in the fitted states will always correspond to the data. |
@joshdorrington In trying to reproduce this my own build environment is now not working. I have been using conda which has been quite reliable for me but I think you have hit something incompatible with current revisions. What package manager do you use and can you upload the package versions? Specifically |
I also use Anaconda, here is the full list of packages and versions: Key details are conda v4.5.0, Python v3.6.0 and pomegranate v0.9.0 |
When I run the first script, I think I'm getting the answers that you'd expect.
from pomegranate import *
import numpy as np
data = np.fromfile("np_input_file.txt").reshape([1, 200001, 6])
Gaussmodel = HiddenMarkovModel.from_samples(MultivariateGaussianDistribution, n_components=3, X=data, verbose=True)
# Print the fitted mean vector of each state.
print(Gaussmodel.states[0].distribution.parameters[0]) # [0.7826816683883048, 0.24117942443206677, -0.26013475663367996, -0.3822771140018437, -0.1924771863914723, 0.2848611308798533]
print(Gaussmodel.states[1].distribution.parameters[0]) # [0.9011424396640376, 0.18619544314987144, -0.0415208854330756, -0.5671348934410917, -0.22473978132127376, 0.07526019711117092]
print(Gaussmodel.states[2].distribution.parameters[0]) # [0.8524057518122062, 0.20865981112786372, -0.2155342911203444, -0.44643304424702357, -0.17002630608375832, 0.24277978252249166]
I noticed a bug in your gist: you reshaped your data to be 200001 by 6, which is why you might've gotten LinAlgErrors or NaN improvements. When I comment that line out, I get the same improvement scores as before, and the following plot (after adding in a plt.show command at the end). I am also on pomegranate v0.9.0, though I am running py2.7. |
I just ran the same code on python 3.6 and got the same results as above. Can you try commenting out the reshape command and running again? |
I have been plunged into a new level of confusion: like you, I was suddenly seeing correct states being fitted with the same code I was using before. I have now identified what seems to be the problem, but I have absolutely no idea why it happens. If I import the data into Python from a .csv with pandas.read_csv, the training doesn't work, but if I import it from a .np file using np.fromfile, it does. I present the test code I have used to check this below. The data passed to pomegranate is in the same format, in the same type of array, and contains the same values, so I see no reason why this should happen. The csv data file (suffixed .txt for GitHub reasons): The comparison gist: https://gist.github.com/joshdorrington/e5a1fc3c651ee6b2d5812f31e13903c4 When I run this gist I get incorrect results for pd_gaussmodel, as posted before, and correct results for np_gaussmodel. |
The bug is due to a difference in memory layout between numpy arrays and pandas data frames. Since data frames store data contiguously by column, not by row, the data is essentially stored as a transposed numpy array. When you extract the values, numpy does not create a new array where each row is contiguous; it just returns an array that has been transposed. For numpy operations this is fine, because they all use the appropriate strides whether the array is transposed or not. You can see the actual difference between the two arrays below:
print(pd_data.strides) # (1600008, 8, 1600008)
print(np_data.strides) # (9600048, 48, 8)
In the pandas-derived array, consecutive rows are 8 bytes apart and consecutive columns are 1600008 bytes (200001 × 8) apart, which is column-major order; in the numpy array, rows are 48 bytes (6 × 8) apart and columns 8 bytes apart, which is row-major order. If you add in
I should do more error checking to ensure that the input is a row-contiguous array before training starts, to stop this type of situation from arising in the future. tl;dr: The underlying Cython code behind pomegranate assumes that the array is row-contiguous, like most numpy arrays are, not column-contiguous, like pandas data frames are. |
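A minimal sketch of one way to guard against this layout problem, assuming the remedy is to copy the array into row-major order with `np.ascontiguousarray` before training (the exact line suggested in the comment above is not preserved). The helper name `as_row_contiguous` and the toy array are hypothetical; `np.asfortranarray` is used only as a stand-in for the column-major array that a data frame's values can produce.

```python
import numpy as np


def as_row_contiguous(x):
    """Return x as a float64 C-contiguous array, copying only if needed."""
    return np.ascontiguousarray(x, dtype=np.float64)


# Toy stand-in for the thread's data: 5 samples with 6 features.
row_major = np.arange(30, dtype=np.float64).reshape(5, 6)  # like np.fromfile(...).reshape(...)
col_major = np.asfortranarray(row_major)                   # column-major, mimicking the pandas case

print(row_major.strides)  # (48, 8)
print(col_major.strides)  # (8, 40)

# Same values, different memory layout.
print(np.array_equal(row_major, col_major))   # True
print(col_major.flags['C_CONTIGUOUS'])        # False

fixed = as_row_contiguous(col_major)
print(fixed.strides)                          # (48, 8)

# Add the leading "one sequence" axis used throughout the thread.
batch = fixed[np.newaxis, :, :]
print(batch.shape, batch.flags['C_CONTIGUOUS'])  # (1, 5, 6) True
```

Checking `array.flags['C_CONTIGUOUS']` before passing data to a compiled routine is a cheap way to catch this class of problem, since equality comparisons on the values alone will not reveal it.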
@jmschrei thanks very much, I'm glad we could get to the bottom of this! |
I am trying to fit a time series of a chaotic system with a 3-state hidden Markov model with Gaussian emissions. I format the data as a numpy array called input_data, and run the code below:
Unfortunately the returned states of the model are not accurate; all components of the Gaussian means are nearly identical and strictly descending in absolute value:
The data set has no such symmetry, and in fact these mean values are in some cases totally outside the range of values present in the data, so I do not think it is a problem with the k-means initialisation of the states.
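The observation that some fitted means fall outside the range of the data can be turned into a quick sanity check. The sketch below is only an illustration with hypothetical names and toy numbers; it flags every component of every state mean that lies below the per-feature minimum or above the per-feature maximum of the training data.

```python
import numpy as np


def means_outside_range(data, means):
    """Boolean mask (n_states, n_features): True where a mean lies outside the data range."""
    data = np.asarray(data, dtype=float)
    # Flatten any leading axes, e.g. (1, n_samples, n_features) -> (n_samples, n_features).
    data = data.reshape(-1, data.shape[-1])
    means = np.asarray(means, dtype=float)
    lo, hi = data.min(axis=0), data.max(axis=0)
    return (means < lo) | (means > hi)


# Hypothetical toy data: 4 samples with 2 features, and two state means.
data = np.array([[0.0, 1.0], [0.5, 2.0], [1.0, 3.0], [0.2, 1.5]])
means = np.array([[0.4, 2.0], [1.7, 2.5]])
print(means_outside_range(data, means))
# [[False False]
#  [ True False]]
```

Any `True` entry indicates a fitted mean that no sample could plausibly have produced, which points to a problem in training or data handling rather than in the initialisation.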