Skip to content

Commit

Permalink
Remove ViterbiAlgorithmParams
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanholder committed Nov 13, 2016
1 parent dd4ad74 commit b7d90a3
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 86 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@
.settings/
/deploy_snapshot.sh
/deploy_release.sh
/.idea
*.iml
*~
57 changes: 45 additions & 12 deletions src/main/java/com/bmw/hmm/ViterbiAlgorithm.java
Original file line number Diff line number Diff line change
Expand Up @@ -128,22 +128,55 @@ private static class ForwardStepResult<S, O, D> {

/**
* Need to construct a new instance for each sequence of observations.
* Does not keep the message history.
*/
public ViterbiAlgorithm() {
this(new ViterbiAlgorithmParams());
}
public ViterbiAlgorithm() { }

/**
* Need to construct a new instance for each sequence of observations.
* Whether to store intermediate forward messages
* (probabilities of intermediate most likely paths) for debugging.
* Default: false
* Must be called before processing is started.
*/
public ViterbiAlgorithm(ViterbiAlgorithmParams params) {
if (params.isKeepMessageHistory()) {
public ViterbiAlgorithm<S, O, D> setKeepMessageHistory(boolean keepMessageHistory) {
if (processingStarted()) {
throw new IllegalStateException("Processing has already started.");
}

if (keepMessageHistory) {
messageHistory = new ArrayList<>();
} else {
messageHistory = null;
}
if (params.isComputeSmoothingProbabilities()) {
return this;
}

/**
* Whether to compute smoothing probabilities using the {@link ForwardBackwardAlgorithm}
* for the states of the most likely sequence. Note that this significantly increases
* computation time and memory footprint.
* Default: false
* Must be called before processing is started.
*/
public ViterbiAlgorithm<S, O, D> setComputeSmoothingProbabilities(
boolean computeSmoothingProbabilities) {
if (processingStarted()) {
throw new IllegalStateException("Processing has already started.");
}

if (computeSmoothingProbabilities) {
forwardBackward = new ForwardBackwardAlgorithm<>();
} else {
forwardBackward = null;
}
return this;
}

/**
* Returns whether {@link #startWithInitialObservation(Object, Collection, Map)}
* or {@link #startWithInitialStateProbabilities(Collection, Map)} has already been called.
*/
public boolean processingStarted() {
return message != null;
}

/**
Expand Down Expand Up @@ -216,7 +249,7 @@ public void nextStep(O observation, Collection<S> candidates,
Map<S, Double> emissionLogProbabilities,
Map<Transition<S>, Double> transitionLogProbabilities,
Map<Transition<S>, D> transitionDescriptors) {
if (message == null) {
if (!processingStarted()) {
throw new IllegalStateException(
"startWithInitialStateProbabilities() or startWithInitialObservation() "
+ "must be called first.");
Expand Down Expand Up @@ -285,14 +318,14 @@ public boolean isBroken() {
}

/**
* @see ViterbiAlgorithmParams
* @see #setComputeSmoothingProbabilities(boolean)
*/
public boolean isComputeSmoothingProbabilities() {
return forwardBackward != null;
}

/**
* @see ViterbiAlgorithmParams
* @see #setKeepMessageHistory(boolean)
*/
public boolean isKeepMessageHistory() {
return messageHistory != null;
Expand Down Expand Up @@ -343,7 +376,7 @@ private boolean hmmBreak(Map<S, Double> message) {
*/
private void initializeStateProbabilities(O observation, Collection<S> candidates,
Map<S, Double> initialLogProbabilities) {
if (message != null) {
if (processingStarted()) {
throw new IllegalStateException("Initial probabilities have already been set.");
}

Expand Down
38 changes: 0 additions & 38 deletions src/main/java/com/bmw/hmm/ViterbiAlgorithmParams.java

This file was deleted.

79 changes: 43 additions & 36 deletions src/test/java/com/bmw/hmm/ViterbiAlgorithmTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,12 @@

package com.bmw.hmm;

import static java.lang.Math.log;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.junit.Test;

import com.bmw.hmm.SequenceState;
import com.bmw.hmm.Transition;
import com.bmw.hmm.ViterbiAlgorithm;
import java.util.*;

import static java.lang.Math.log;
import static org.junit.Assert.*;

public class ViterbiAlgorithmTest {

Expand Down Expand Up @@ -118,20 +108,20 @@ public void testComputeMostLikelySequence() {
emissionLogProbabilitiesForNoUmbrella.put(Rain.F, log(0.8));

final Map<Transition<Rain>, Double> transitionLogProbabilities = new LinkedHashMap<>();
transitionLogProbabilities.put(new Transition<Rain>(Rain.T, Rain.T), log(0.7));
transitionLogProbabilities.put(new Transition<Rain>(Rain.T, Rain.F), log(0.3));
transitionLogProbabilities.put(new Transition<Rain>(Rain.F, Rain.T), log(0.3));
transitionLogProbabilities.put(new Transition<Rain>(Rain.F, Rain.F), log(0.7));
transitionLogProbabilities.put(new Transition<>(Rain.T, Rain.T), log(0.7));
transitionLogProbabilities.put(new Transition<>(Rain.T, Rain.F), log(0.3));
transitionLogProbabilities.put(new Transition<>(Rain.F, Rain.T), log(0.3));
transitionLogProbabilities.put(new Transition<>(Rain.F, Rain.F), log(0.7));

final Map<Transition<Rain>, Descriptor> transitionDescriptors = new LinkedHashMap<>();
transitionDescriptors.put(new Transition<Rain>(Rain.T, Rain.T), Descriptor.R2R);
transitionDescriptors.put(new Transition<Rain>(Rain.T, Rain.F), Descriptor.R2S);
transitionDescriptors.put(new Transition<Rain>(Rain.F, Rain.T), Descriptor.S2R);
transitionDescriptors.put(new Transition<Rain>(Rain.F, Rain.F), Descriptor.S2S);
transitionDescriptors.put(new Transition<>(Rain.T, Rain.T), Descriptor.R2R);
transitionDescriptors.put(new Transition<>(Rain.T, Rain.F), Descriptor.R2S);
transitionDescriptors.put(new Transition<>(Rain.F, Rain.T), Descriptor.S2R);
transitionDescriptors.put(new Transition<>(Rain.F, Rain.F), Descriptor.S2S);

final ViterbiAlgorithm<Rain, Umbrella, Descriptor> viterbi =
new ViterbiAlgorithm<>(new ViterbiAlgorithmParams().setKeepMessageHistory(true).
setComputeSmoothingProbabilities(true));
new ViterbiAlgorithm<Rain, Umbrella, Descriptor>().setKeepMessageHistory(true).
setComputeSmoothingProbabilities(true);
viterbi.startWithInitialObservation(Umbrella.T, candidates,
emissionLogProbabilitiesForUmbrella);
viterbi.nextStep(Umbrella.T, candidates, emissionLogProbabilitiesForUmbrella,
Expand Down Expand Up @@ -190,6 +180,23 @@ public void testComputeMostLikelySequence() {
checkMessageHistory(expectedMessageHistory, actualMessageHistory);
}

@Test
public void testSetParams() {
final ViterbiAlgorithm<Rain, Umbrella, Descriptor> viterbi = new ViterbiAlgorithm<>();

assertFalse(viterbi.isKeepMessageHistory());
viterbi.setKeepMessageHistory(true);
assertTrue(viterbi.isKeepMessageHistory());
viterbi.setKeepMessageHistory(false);
assertFalse(viterbi.isKeepMessageHistory());

assertFalse(viterbi.isComputeSmoothingProbabilities());
viterbi.setComputeSmoothingProbabilities(true);
assertTrue(viterbi.isComputeSmoothingProbabilities());
viterbi.setComputeSmoothingProbabilities(false);
assertFalse(viterbi.isComputeSmoothingProbabilities());
}

private void checkMessageHistory(List<Map<Rain, Double>> expectedMessageHistory,
List<Map<Rain, Double>> actualMessageHistory) {
assertEquals(expectedMessageHistory.size(), actualMessageHistory.size());
Expand Down Expand Up @@ -298,19 +305,19 @@ public void testBreakAtSecondTransition() {
assertFalse(viterbi.isBroken());

Map<Transition<Rain>, Double> transitionLogProbabilities = new LinkedHashMap<>();
transitionLogProbabilities.put(new Transition<Rain>(Rain.T, Rain.T), log(0.5));
transitionLogProbabilities.put(new Transition<Rain>(Rain.T, Rain.F), log(0.5));
transitionLogProbabilities.put(new Transition<Rain>(Rain.F, Rain.T), log(0.5));
transitionLogProbabilities.put(new Transition<Rain>(Rain.F, Rain.F), log(0.5));
transitionLogProbabilities.put(new Transition<>(Rain.T, Rain.T), log(0.5));
transitionLogProbabilities.put(new Transition<>(Rain.T, Rain.F), log(0.5));
transitionLogProbabilities.put(new Transition<>(Rain.F, Rain.T), log(0.5));
transitionLogProbabilities.put(new Transition<>(Rain.F, Rain.F), log(0.5));
viterbi.nextStep(Umbrella.T, candidates, emissionLogProbabilities,
transitionLogProbabilities);
assertFalse(viterbi.isBroken());

transitionLogProbabilities = new LinkedHashMap<>();
transitionLogProbabilities.put(new Transition<Rain>(Rain.T, Rain.T), log(0.0));
transitionLogProbabilities.put(new Transition<Rain>(Rain.T, Rain.F), log(0.0));
transitionLogProbabilities.put(new Transition<Rain>(Rain.F, Rain.T), log(0.0));
transitionLogProbabilities.put(new Transition<Rain>(Rain.F, Rain.F), log(0.0));
transitionLogProbabilities.put(new Transition<>(Rain.T, Rain.T), log(0.0));
transitionLogProbabilities.put(new Transition<>(Rain.T, Rain.F), log(0.0));
transitionLogProbabilities.put(new Transition<>(Rain.F, Rain.T), log(0.0));
transitionLogProbabilities.put(new Transition<>(Rain.F, Rain.F), log(0.0));
viterbi.nextStep(Umbrella.T, candidates, emissionLogProbabilities,
transitionLogProbabilities);

Expand Down Expand Up @@ -338,10 +345,10 @@ public void testDeterministicCandidateOrder() {
emissionLogProbabilitiesForNoUmbrella.put(Rain.T, log(0.5));

final Map<Transition<Rain>, Double> transitionLogProbabilities = new LinkedHashMap<>();
transitionLogProbabilities.put(new Transition<Rain>(Rain.F, Rain.T), log(0.5));
transitionLogProbabilities.put(new Transition<Rain>(Rain.F, Rain.F), log(0.5));
transitionLogProbabilities.put(new Transition<Rain>(Rain.T, Rain.T), log(0.5));
transitionLogProbabilities.put(new Transition<Rain>(Rain.T, Rain.F), log(0.5));
transitionLogProbabilities.put(new Transition<>(Rain.F, Rain.T), log(0.5));
transitionLogProbabilities.put(new Transition<>(Rain.F, Rain.F), log(0.5));
transitionLogProbabilities.put(new Transition<>(Rain.T, Rain.T), log(0.5));
transitionLogProbabilities.put(new Transition<>(Rain.T, Rain.F), log(0.5));

final ViterbiAlgorithm<Rain, Umbrella, Descriptor> viterbi = new ViterbiAlgorithm<>();
viterbi.startWithInitialObservation(Umbrella.T, candidates,
Expand Down

0 comments on commit b7d90a3

Please sign in to comment.