
HMM: calculate likelihood for data stream with/without pre-calculated emission probability #2142

Merged
merged 75 commits into mlpack:master on Jan 12, 2021

Conversation

@aabghari (Contributor) commented on Jan 6, 2020

I am introducing new APIs to address the following scenarios:

  1. There are cases, such as live audio or video streams, in which one would like to calculate the HMM likelihood as the data arrives.

  2. The emission probabilities may also be pre-calculated, for example by a neural network trained on the individual states.
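
As a rough illustration of the second scenario (this is only a sketch, not part of the proposed API; emissionLogProbs, initialLog, and transitionLog are hypothetical names), the forward recursion can be driven directly from pre-computed emission log-probabilities:

#include <armadillo>
#include <cmath>

// Illustrative only: compute the stream's log-likelihood from a hypothetical
// (numStates x numSteps) matrix of pre-computed emission log-probabilities,
// e.g. produced by a neural network trained on the individual states.
// Assumed convention: transitionLog(i, j) = log P(state i at t | state j at t - 1).
double StreamLogLikelihood(const arma::mat& emissionLogProbs,
                           const arma::vec& initialLog,
                           const arma::mat& transitionLog)
{
  // t = 0: initialize with the initial state log-probabilities.
  arma::vec forwardLog = initialLog + emissionLogProbs.col(0);

  for (size_t t = 1; t < emissionLogProbs.n_cols; ++t)
  {
    arma::vec next(forwardLog.n_elem);
    for (size_t i = 0; i < forwardLog.n_elem; ++i)
    {
      // log-sum-exp over the previous states, for numerical stability.
      const arma::vec terms = forwardLog + transitionLog.row(i).t();
      const double maxTerm = terms.max();
      next[i] = maxTerm + std::log(arma::accu(arma::exp(terms - maxTerm))) +
          emissionLogProbs(i, t);
    }
    forwardLog = next;
  }

  // Log-likelihood of everything seen so far.
  const double maxTerm = forwardLog.max();
  return maxTerm + std::log(arma::accu(arma::exp(forwardLog - maxTerm)));
}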

@rcurtin (Member) commented on Jan 15, 2020

Hey there @aabghari, thanks for writing this up. Really nice to see improvements to the HMM code. I think we should discuss what the API for this should look like first though. So let me make sure that I understand the task correctly:

The goal here is to provide an interface so that a user can do streaming HMMs; e.g., pass in one time step at a time, and get a log-likelihood of the full sequence up to that point. If I understood right, this is the "user facing" method to do that:

  double LogLikelihood(size_t t,
                       const arma::vec &data,
                       double &logScale,
                       arma::vec& prevForwardLogProb,
                       arma::vec& forwardLogProb) const;

I think maybe we could match this a little bit better to the existing API. For instance, I think it would be cleaner if the API for using HMMs in this streaming sense were just some extra parameters to the way users normally use LogLikelihood(). Consider this: one would currently get the log-likelihood of a sequence like below:

double loglik = hmm.LogLikelihood(dataSeq);

and in order to use it in a streaming sense, they should do something like...

// this is kind of pseudocode
double currentLoglik = 0.0;
while (there are more time steps)
{
  arma::vec step = get next time step...
  currentLoglik = hmm.LogLikelihood(... something ...);

  // do some other actions perhaps...
}

Ideally, we'd want that call to hmm.LogLikelihood() to look as similar as possible to the other call in non-streaming mode, and we'd want there to be minimal extra parameters and overhead.

Since all we need to predict the next time step is the previous log-probabilities of each state, we could use this signature:

  double LogLikelihood(const arma::mat& data,
                       arma::vec& stateLogProbs,
                       const arma::vec& inputStateLogProbs = arma::vec()) const;

So, that's basically exactly the same as the existing LogLikelihood(); but now you can pass in input state probabilities and get back output state probabilities. You'd then use this like this:

double currentLoglik = 0.0;
arma::vec stateLogProbs;
while (there are more time steps)
{
  arma::vec step = get next time step(s)...
  currentLoglik = hmm.LogLikelihood(step, stateLogProbs, stateLogProbs);

  // do some other actions perhaps...
}

In this way we can now also take multiple steps at once, and it matches the existing API quite closely, so that the HMM object feels natural to use both in a streaming and non-streaming setting. Let me know what you think of the idea. 👍

@mlpack-bot (bot) commented on Feb 14, 2020

This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍

@mlpack-bot mlpack-bot bot added the s: stale label Feb 14, 2020
@aabghari (Contributor, Author) commented:
> Hey there @aabghari, thanks for writing this up. Really nice to see improvements to the HMM code. I think we should discuss what the API for this should look like first though. So let me make sure that I understand the task correctly:
>
> The goal here is to provide an interface so that a user can do streaming HMMs; e.g., pass in one time step at a time, and get a log-likelihood of the full sequence up to that point. If I understood right, this is the "user facing" method to do that:
Yes, you are right.

>   double LogLikelihood(size_t t,
>                        const arma::vec &data,
>                        double &logScale,
>                        arma::vec& prevForwardLogProb,
>                        arma::vec& forwardLogProb) const;
>
> I think maybe we could match this a little bit better to the existing API. For instance, I think it would be cleaner if the API for using HMMs in this streaming sense were just some extra parameters to the way users normally use LogLikelihood(). Consider this: one would currently get the log-likelihood of a sequence like below:
>
> double loglik = hmm.LogLikelihood(dataSeq);
>
> and in order to use it in a streaming sense, they should do something like...
>
> // this is kind of pseudocode
> double currentLoglik = 0.0;
> while (there are more time steps)
> {
>   arma::vec step = get next time step...
>   currentLoglik = hmm.LogLikelihood(... something ...);
>
>   // do some other actions perhaps...
> }
>
> Ideally, we'd want that call to hmm.LogLikelihood() to look as similar as possible to the other call in non-streaming mode, and we'd want there to be minimal extra parameters and overhead.

I agree.

> Since all we need to predict the next time step is the previous log-probabilities of each state, we could use this signature:
>
>   double LogLikelihood(const arma::mat& data,
>                        arma::vec& stateLogProbs,
>                        const arma::vec& inputStateLogProbs = arma::vec()) const;

I disagree with this API suggestion. We need to keep track of the log-likelihood value somewhere, since it is an accumulation over time, and we cannot do that inside the API. Every time the API gets called for a new time step, we need to pass in the log-likelihood accumulated up to that time and get the updated value back. That is what logScale does in my original suggestion; logScale may not be a good name, though, and should perhaps have been logLikelihood. Another factor we need to consider is indicating to the API whether this call is the first one, since the log-likelihood calculation at time zero is different from the others (you need to use the Initial probabilities and also initialize the state probabilities at t = 0); this is done by passing t to the API.
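
In standard forward-algorithm notation (with $\pi$ the initial state probabilities, $T_{ij}$ the probability of moving from state $j$ to state $i$, and $b_i(x_t)$ the emission probability of observation $x_t$ in state $i$), the special case is just the initialization step:

$$\alpha_0(i) = \pi_i \, b_i(x_0), \qquad \alpha_t(i) = b_i(x_t) \sum_j T_{ij} \, \alpha_{t-1}(j) \quad (t > 0),$$

so the update at t = 0 uses the Initial probabilities rather than a transition from a previous step.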

> So, that's basically exactly the same as the existing LogLikelihood(); but now you can pass in input state probabilities and get back output state probabilities. You'd then use this like this:
>
> double currentLoglik = 0.0;
> arma::vec stateLogProbs;
> while (there are more time steps)
> {
>   arma::vec step = get next time step(s)...
>   currentLoglik = hmm.LogLikelihood(step, stateLogProbs, stateLogProbs);
>
>   // do some other actions perhaps...
> }
>
> In this way we can now also take multiple steps at once, and it matches the existing API quite closely, so that the HMM object feels natural to use both in a streaming and non-streaming setting. Let me know what you think of the idea. +1

@mlpack-bot mlpack-bot bot removed the s: stale label Feb 19, 2020
@rcurtin (Member) commented on Mar 4, 2020

Hey @aabghari, sorry for the slow response.

> I disagree with this API suggestion. We need to keep track of the log-likelihood value somewhere, since it is an accumulation over time, and we cannot do that inside the API. Every time the API gets called for a new time step, we need to pass in the log-likelihood accumulated up to that time and get the updated value back. That is what logScale does in my original suggestion; logScale may not be a good name, though, and should perhaps have been logLikelihood. Another factor we need to consider is indicating to the API whether this call is the first one, since the log-likelihood calculation at time zero is different from the others (you need to use the Initial probabilities and also initialize the state probabilities at t = 0); this is done by passing t to the API.

I'm wondering if maybe I don't understand the full details of what you're hoping to do with the changes, so please point it out if I've misunderstood or overlooked anything. :)

The log-likelihood of a whole sequence is just the sum of log-likelihoods of each individual observation in that sequence. The log-likelihood of each individual observation, for an HMM, also depends on probabilities for the current internal state of the HMM. (And if this is the first observation in the sequence, then the probabilities for the internal state of the HMM are the initial probabilities.)
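
In forward-algorithm terms, with $\alpha_t(i) = P(x_0, \ldots, x_t, \text{state}_t = i)$, that decomposition is

$$\log P(x_0, \ldots, x_T) = \sum_{t = 0}^{T} \log P(x_t \mid x_0, \ldots, x_{t-1}), \qquad P(x_t \mid x_0, \ldots, x_{t-1}) = \frac{\sum_i \alpha_t(i)}{\sum_i \alpha_{t-1}(i)},$$

where the $t = 0$ term is just $\sum_i \alpha_0(i)$; each increment only needs the state probabilities from the previous step.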

So the idea would be that, to compute the log-likelihood of part of a sequence, the accumulation would actually happen outside of the API, like this:

// We want to compute the log-likelihood of elements 10 to 20 only.
double logLikelihood = 0.0;
arma::vec stateLogProbs;
for (size_t i = 0; i < sequence.n_cols; ++i)
{
  const double stepLogLikelihood =
      hmm.LogLikelihood(sequence.col(i), stateLogProbs, stateLogProbs);
  // Only add the step's log-likelihood if it's in the range of the sequence that we care about.
  if (i >= 10 && i <= 20)
    logLikelihood += stepLogLikelihood;
}

And for a streaming case you might do...

double currentLoglik = 0.0;
arma::vec stateLogProbs;
size_t steps = 0;
while (there are more time steps)
{
  arma::vec step = get next time step(s)...
  // Add to the current log-likelihood and count the steps seen so far.
  currentLoglik += hmm.LogLikelihood(step, stateLogProbs, stateLogProbs);
  ++steps;

  std::cout << "after " << steps << " steps, the log-likelihood is " << currentLoglik << std::endl;
  // do some other actions perhaps...
}

Maybe I overlooked something? I think that still covers the use cases you originally proposed. 👍

@mlpack-bot (bot) commented on Apr 3, 2020

This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍

@mlpack-bot mlpack-bot bot added the s: stale label Apr 3, 2020
@aabghari (Contributor, Author) commented on Apr 9, 2020

My initial design was to calculate the accumulated likelihood inside the API; that's why I have the double& logScale argument, which keeps the accumulation. If you want the accumulation done outside the API, I can get rid of this argument.
I agree that I know the internal state, but that's not enough if someone wants to reuse the hmm object for a different data stream. Assume I have one stream coming in and I calculate the likelihood for it. Once another stream comes in, I need a way to restart the internal state and calculate the likelihood from scratch. I indicate the start of a new stream to the API by passing t = 0, although this could be replaced with a boolean flag too. Another way of dealing with several data streams, without the restart flag, would be to destroy the existing hmm object and create a new one so we always start from state zero, but I assume that is not efficient at all. I prefer to have a flag that restarts the internal state.

@mlpack-bot mlpack-bot bot removed the s: stale label Apr 9, 2020
@mlpack-bot (bot) commented on May 9, 2020

This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍

@mlpack-bot mlpack-bot bot added the s: stale label May 9, 2020
@rcurtin (Member) commented on May 14, 2020

I'm sorry this has sat for so long. Everything is overwhelming...

Personally I think accumulation outside the API is a little bit more flexible, and it helps us ensure that no internal state is changed during the course of a prediction. That helps us not have objects in weird state situations. Ideally, we should be able to mark this overload of LogLikelihood() as const---intuitively, the operation of computing the log-likelihood should not affect the actual HMM object itself, and I think it's possible to achieve that while still solving the use cases we're trying to here. :)

For starting a new stream, actually what you could do in the example I proposed above is just call .clear() on stateLogProbs; i.e., if stateLogProbs is an empty vector, then this is t = 0; otherwise, it holds the state probabilities from the previous step of the stream. In fact, you could juggle arbitrarily many streams simply by maintaining a different stateLogProbs vector for each stream. Since there is no internal state, there's no need to reset anything. It should even be thread-safe!
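
Roughly, juggling two streams would then look something like this (same pseudocode style as the snippets above):

// Each stream keeps its own state vector and its own accumulated log-likelihood.
arma::vec stateLogProbsA, stateLogProbsB;
double loglikA = 0.0, loglikB = 0.0;

// Restarting a stream is just clearing its state vector (back to t = 0).
stateLogProbsA.clear();

while (there are more time steps)
{
  arma::mat stepA = get next time step(s) for stream A...
  arma::mat stepB = get next time step(s) for stream B...

  loglikA += hmm.LogLikelihood(stepA, stateLogProbsA, stateLogProbsA);
  loglikB += hmm.LogLikelihood(stepB, stateLogProbsB, stateLogProbsB);

  // do some other actions perhaps...
}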

Again, sorry it took so long for this simple response---let me know if I overlooked something! 👍

@mlpack-bot mlpack-bot bot removed the s: stale label May 14, 2020
aabghari and others added 18 commits December 7, 2020 10:17
@rcurtin (Member) left a review comment:
Thank you @aabghari! There is a merge conflict but it should be easy to resolve. I found a few other small style issues, which I had hoped would be easy to add suggestions for but I can't do multiline suggestions on a phone, so there are lots of them. Sorry. :(

If you want to merge those suggestions and fix the merge conflict, feel free; if not, I'll do it during merge (when I am not using a phone :)).

Thank you for this nice support! 💯

Review comments on src/mlpack/methods/hmm/hmm.hpp and src/mlpack/methods/hmm/hmm_impl.hpp (all threads resolved).
@mlpack-bot (bot) left a review comment:
Second approval provided automatically after 24 hours. 👍

@rcurtin (Member) commented on Jan 11, 2021

I went ahead and committed the suggestions and merged master; when I see that the tests pass, I'll go ahead and merge. Thanks again @aabghari! Great to finally merge this in. :)

@aabghari (Contributor, Author) commented:
I was going to apply the changes, but you already did. Thanks for merging.

@rcurtin (Member) commented on Jan 12, 2021

Looks like some of the builds had problems but they don't appear to be related to the changes here.

@rcurtin rcurtin merged commit 287c3bc into mlpack:master Jan 12, 2021
This was referenced Oct 14, 2022
@rcurtin rcurtin mentioned this pull request Oct 23, 2022