# Viterbi Algorithm
(10 points)

Finish the implementation of the `ViterbiAlgorithm` class using the dynamic programming approach presented in the lecture. The class has a constructor that takes the parameters of a Hidden Markov Model (HMM) and a `getStateSequence` method that returns the most probable sequence of states for a given sequence of observations.
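For reference, one common way to write this dynamic program is shown here in log space. Treat it as a sketch in standard HMM notation and check it against the lecture's formulation. $N$ is `states.length`, $q_0$ is the start state, $q_F$ is the final state (index `transitionMatrix.length - 1`), $o_1, \dots, o_T$ are the observations, and $b_j(o_t)$ stands for $b_{jk}$ with $o_t = v_k$:

$$
\begin{aligned}
\log v_1(j) &= \log a_{0j} + \log b_j(o_1), \quad 1 \le j \le N \\
\log v_t(j) &= \max_{1 \le i \le N} \left[ \log v_{t-1}(i) + \log a_{ij} \right] + \log b_j(o_t), \quad 2 \le t \le T \\
bt_t(j) &= \operatorname*{arg\,max}_{1 \le i \le N} \left[ \log v_{t-1}(i) + \log a_{ij} \right], \quad 2 \le t \le T \\
\log P^* &= \max_{1 \le i \le N} \left[ \log v_T(i) + \log a_{iF} \right]
\end{aligned}
$$

Starting from the state that maximizes the last line and following the back pointers $bt_t$ backwards recovers the most probable state sequence.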

* The constructor takes the following parameters:
  * `String[] states` - an array containing the set of states $Q$ in the same order as in the transition matrix and the emission matrix.
  * `String[] observationVocab` - an array containing all possible observations $V$ in the order they have in the emission matrix.
  * `double[][] transitionMatrix` - the transition matrix $A$ where `transitionMatrix[i][j]` contains the probability $a_{ij}$ for a transition from $q_i$ to $q_j$. Note that the transition matrix contains two additional states, as described in the lecture slides: the start state always has the index `0` and the final state always has the index `transitionMatrix.length - 1`. Since the `states` array starts at index 0 and contains neither of these two states, `transitionMatrix[i][j]` gives the transition from `states[i - 1]` to `states[j - 1]`.
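  * `double[][] emissionMatrix` - the emission matrix $B$ containing the probabilities $b_{ik}$ that state $q_i$ emits the observation $v_k$. It has no rows for the start and final states, so `emissionMatrix[i][k]` is the probability that `states[i]` emits `observationVocab[k]`.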
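A minimal sketch of the index shift, using the ice cream matrices from the example. The class `HmmIndexing` and its helper methods are hypothetical illustrations, not part of the assignment template. Indices into `states` are shifted by one before they are used in $A$, but are used directly in $B$:

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper class illustrating how the 0-based arrays map onto A and B.
class HmmIndexing {
    static final String[] STATES = { "HOT", "COLD" };
    static final String[] VOCAB = { "1", "2", "3" };
    static final double[][] A = { { 0, 0.8, 0.2, 0 }, { 0, 0.6, 0.3, 0.1 }, { 0, 0.4, 0.5, 0.1 }, { 0, 0, 0, 0 } };
    static final double[][] B = { { 0.2, 0.4, 0.4 }, { 0.5, 0.4, 0.1 } };

    // Transition from states[from] to states[to]: both indices shift by one in A
    static double transition(int from, int to) {
        return A[from + 1][to + 1];
    }

    // Transition from the start state (row 0 of A) into states[to]
    static double fromStart(int to) {
        return A[0][to + 1];
    }

    // Transition from states[from] into the final state (last column of A)
    static double toEnd(int from) {
        return A[from + 1][A.length - 1];
    }

    // Emission: B uses the 0-based indices of states and observationVocab directly;
    // observations are looked up as strings, so they do not have to be numbers
    static double emission(int state, String observation) {
        List<String> vocab = Arrays.asList(VOCAB);
        int k = vocab.indexOf(observation);
        if (k < 0) {
            throw new IllegalArgumentException("Unknown observation: " + observation);
        }
        return B[state][k];
    }

    public static void main(String[] args) {
        System.out.println(transition(0, 1)); // HOT -> COLD: 0.3
        System.out.println(fromStart(0));     // start -> HOT: 0.8
        System.out.println(toEnd(1));         // COLD -> end: 0.1
        System.out.println(emission(1, "3")); // COLD emits "3": 0.1
    }
}
```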

* The `getStateSequence` method takes the following parameters:
  * `String[] observations` - an array containing the observations for which the most probable sequence of states should be returned.

* The method should return an instance of the `StateSequence` class. This class contains the following two attributes:
  * `String[] states` - the sequence of states which emitted a certain sequence of observations.
  * `double logProbability` - the logarithm of the probability of the state sequence (including the emission probabilities for the sequence of observations).
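To make `logProbability` concrete: for one fixed state path, it is the sum of four kinds of terms. These are the log transition out of the start state, the log emission of every observation, the log transitions between consecutive states, and the log transition into the final state. The sketch uses `PathScore.logJoint`, a hypothetical helper that is not part of the template. It should reproduce the value that the first visible test expects for the observations `3 1 3` with the path `HOT HOT HOT`, namely $\log 0.0009216$:

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper: log of the joint probability of one fixed state path
// and the observations it emits (start and final transitions included).
class PathScore {
    static double logJoint(String[] states, String[] vocab, double[][] a, double[][] b,
                           String[] path, String[] observations) {
        List<String> stateList = Arrays.asList(states);
        List<String> vocabList = Arrays.asList(vocab);
        int finalState = a.length - 1;
        int prev = 0; // row 0 of A is the start state
        double logP = 0.0;
        // Assumes path and observations have the same length
        for (int t = 0; t < observations.length; t++) {
            int cur = stateList.indexOf(path[t]) + 1;                          // index in A
            logP += Math.log(a[prev][cur]);                                    // transition
            logP += Math.log(b[cur - 1][vocabList.indexOf(observations[t])]);  // emission
            prev = cur;
        }
        return logP + Math.log(a[prev][finalState]); // transition into the final state
    }

    public static void main(String[] args) {
        String[] states = { "HOT", "COLD" };
        String[] vocab = { "1", "2", "3" };
        double[][] a = { { 0, 0.8, 0.2, 0 }, { 0, 0.6, 0.3, 0.1 }, { 0, 0.4, 0.5, 0.1 }, { 0, 0, 0, 0 } };
        double[][] b = { { 0.2, 0.4, 0.4 }, { 0.5, 0.4, 0.1 } };
        double logP = logJoint(states, vocab, a, b,
                new String[] { "HOT", "HOT", "HOT" }, new String[] { "3", "1", "3" });
        // 0.8 * 0.4 * 0.6 * 0.2 * 0.6 * 0.4 * 0.1 = 0.0009216
        System.out.println(logP + " vs. " + Math.log(0.0009216));
    }
}
```

The Viterbi algorithm returns the path that maximizes this quantity, together with the maximum itself.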

#### Example

The visible tests use the ice cream example from the lecture slides (see lecture slide 14 / page 22 in the PDF). The arrays for the states, the observations and the transition matrix are defined as follows:
```java
String[] states = new String[] { "HOT", "COLD" };
String[] observationVocab = new String[] { "1", "2", "3" };
double[][] transitionMatrix = new double[][] { { 0, 0.8, 0.2, 0 }, { 0, 0.6, 0.3, 0.1 }, { 0, 0.4, 0.5, 0.1 },
        { 0, 0, 0, 0 } };
```

The transition matrix $A$ taken from the automaton picture is shown in the following table. (**Please note** that although the states `HOT` and `COLD` have the ids `0` and `1` in the `states` array, they have the ids `1` and `2` in the transition matrix.)

<table>
    <tr>
        <th>from \ to</th>
        <th><p align="center">start</p></th>
        <th><p align="center">HOT</p></th>
        <th><p align="center">COLD</p></th>
        <th><p align="center">end</p></th>
    </tr>
    <tr>
        <td><p align="center"><b>start</b></p></td>
        <td><p align="center">0</p></td>
        <td><p align="center">0.8</p></td>
        <td><p align="center">0.2</p></td>
        <td><p align="center">0</p></td>
    </tr>
    <tr>
        <td><p align="center"><b>HOT</b></p></td>
        <td><p align="center">0</p></td>
        <td><p align="center">0.6</p></td>
        <td><p align="center">0.3</p></td>
        <td><p align="center">0.1</p></td>
    </tr>
    <tr>
        <td><p align="center"><b>COLD</b></p></td>
        <td><p align="center">0</p></td>
        <td><p align="center">0.4</p></td>
        <td><p align="center">0.5</p></td>
        <td><p align="center">0.1</p></td>
    </tr>
    <tr>
        <td><p align="center"><b>end</b></p></td>
        <td><p align="center">0</p></td>
        <td><p align="center">0</p></td>
        <td><p align="center">0</p></td>
        <td><p align="center">0</p></td>
    </tr>
</table>

The emission matrix $B$ taken from the automaton picture is shown in the following table. (Please note that this matrix contains neither the start nor the final state, since neither of them emits an observation.)

<table>
    <tr>
        <th>state \ observation</th>
        <th><p align="center">1</p></th>
        <th><p align="center">2</p></th>
        <th><p align="center">3</p></th>
    </tr>
    <tr>
        <td><p align="center"><b>HOT</b></p></td>
        <td><p align="center">0.2</p></td>
        <td><p align="center">0.4</p></td>
        <td><p align="center">0.4</p></td>
    </tr>
    <tr>
        <td><p align="center"><b>COLD</b></p></td>
        <td><p align="center">0.5</p></td>
        <td><p align="center">0.4</p></td>
        <td><p align="center">0.1</p></td>
    </tr>
</table>
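Because $A$ and $B$ use different layouts, a quick shape check can catch indexing mix-ups before they turn into `ArrayIndexOutOfBoundsException`s. This is a hypothetical sketch, not required by the assignment. The dimension checks follow the layout described above. The checks that every row except the final state's row sums to 1 and that nothing transitions into the start state are assumptions about well-formed HMMs, not rules stated by the assignment:

```java
// Hypothetical sanity check for the matrix layout (not part of the template).
class HmmShapeCheck {
    static final double EPS = 1e-9;

    static boolean isWellFormed(String[] states, String[] vocab, double[][] a, double[][] b) {
        int n = states.length;
        // A has the emitting states plus start and final; B has the emitting states only
        if (a.length != n + 2 || b.length != n) return false;
        for (int i = 0; i < a.length; i++) {
            // Assumption: column 0 (the start state) is never entered
            if (a[i].length != n + 2 || a[i][0] != 0) return false;
            boolean isFinalRow = (i == a.length - 1);
            // Assumption: every row except the final state's row is a probability distribution
            if (!isFinalRow && Math.abs(sum(a[i]) - 1.0) > EPS) return false;
        }
        for (double[] row : b) {
            if (row.length != vocab.length || Math.abs(sum(row) - 1.0) > EPS) return false;
        }
        return true;
    }

    static double sum(double[] row) {
        double total = 0.0;
        for (double p : row) total += p;
        return total;
    }

    public static void main(String[] args) {
        String[] states = { "HOT", "COLD" };
        String[] vocab = { "1", "2", "3" };
        double[][] a = { { 0, 0.8, 0.2, 0 }, { 0, 0.6, 0.3, 0.1 }, { 0, 0.4, 0.5, 0.1 }, { 0, 0, 0, 0 } };
        double[][] b = { { 0.2, 0.4, 0.4 }, { 0.5, 0.4, 0.1 } };
        System.out.println(isWellFormed(states, vocab, a, b)); // true for the ice cream example
    }
}
```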

#### Hints

- The input matrices will contain probabilities (not their logarithms)
- The test is split into two cells. The first cell tests some example observation sequences and compares the results with the expected ones. The second cell does the same for a larger, generated sequence.
- The number of states and observations is not limited to the ice cream example. The hidden tests use a different scenario with different observations and states.
- Make sure that the tests take less than 2 minutes (the evaluation has a maximum runtime of 5 minutes per file, including the hidden tests)
- Observations **do not have to be** numbers. Please take this into account for your implementation.
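Working in log space matters here because the visible tests include a sequence of 1000 observations. In the minimal illustration below, `UnderflowDemo` is a hypothetical class. Multiplying 1000 factors of 0.4 (the probability that `HOT` emits `3`) underflows to `0.0` in double precision. Summing their logarithms stays well within range:

```java
// Hypothetical demo: products of many probabilities underflow, sums of logs do not.
class UnderflowDemo {
    static double product(double p, int times) {
        double result = 1.0;
        for (int i = 0; i < times; i++) result *= p;
        return result;
    }

    static double logSum(double p, int times) {
        double result = 0.0;
        for (int i = 0; i < times; i++) result += Math.log(p);
        return result;
    }

    public static void main(String[] args) {
        // 1000 emissions of "3" by HOT, as in the long visible test
        System.out.println(product(0.4, 1000)); // 0.0 (underflow)
        System.out.println(logSum(0.4, 1000));  // about -916.29
    }
}
```

Once a probability has underflowed to `0.0`, every path looks equally impossible and the comparison between paths breaks down. The log sums, by contrast, remain distinguishable.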

#### Notes

- You are free to use a different IDE to develop your solution. However, you have to copy the solution into this notebook to submit it.
- Do not add additional external libraries.
- Interface
  - You can use _[TAB]_ for autocompletion and _[SHIFT]_+_[TAB]_ for code inspection.
  - Use _Menu_ -> _View_ -> _Toggle Line Numbers_ for debugging.
  - Check _Menu_ -> _Help_ -> _Keyboard Shortcuts_.
- Known issues
  - All global variables will be set to void after an import.
  - Missing spaces around `%` (modulo) can cause unexpected errors, so please make sure that you have added spaces around every `%` character.
- Finish
  - Save your solution by clicking on the _disk icon_.
  - Make sure that all necessary imports are listed at the beginning of your cell.
  - Run a final check of your solution by
    - clicking on _restart the kernel, then re-run the whole notebook_ (the fast-forward arrow in the toolbar)
    - waiting for the kernel to restart and execute all cells (all executable cells should have numbers in front of them instead of a `[*]`)
    - checking all executed cells for errors. If an exception is thrown, please check your code. Although the error might look cryptic, so far every exception we have seen had a valid cause in the submitted code. A good way to check the code is to copy the solution into a new class in your favorite IDE and check
      - errors reported by the IDE
      - imports the IDE adds to your code which might be missing in your submission.
  - Finally, choose _Menu_ -> _File_ -> _Close and Halt_.
  - Do not forget to _Submit_ your solution in the _Assignments_ view.

In [3]:
import java.util.Arrays;
import java.util.List;

class StateSequence {
	/**
	 * The sequence of states.
	 */
	public final String[] states;
	/**
	 * The logarithm of the probability of this sequence.
	 */
	public final double logProbability;

	public StateSequence(String[] states, double logProbability) {
		this.states = states;
		this.logProbability = logProbability;
	}
}

In [4]:
import java.util.Arrays;
import java.util.List;

public class ViterbiAlgorithm {

	String[] states;
	String[] observationVocab;
	double[][] transitionMatrix;
	double[][] emissionMatrix;

	public ViterbiAlgorithm(String[] states, String[] observationVocab, double[][] transitionMatrix,
			double[][] emissionMatrix) {
		// YOUR CODE HERE

		this.states = states;
		this.observationVocab = observationVocab;
		this.transitionMatrix = transitionMatrix;
		this.emissionMatrix = emissionMatrix;
	}

	public StateSequence getStateSequence(String[] observations) {
		int numStates = states.length;
		int length = observations.length;
		int finalState = transitionMatrix.length - 1; // index of the final state in A

		// Look up the emission-matrix column of every observation once instead of
		// searching the vocabulary inside the innermost loop (observations may be any string)
		List<String> vocab = Arrays.asList(observationVocab);
		int[] obsIndex = new int[length];
		for (int t = 0; t < length; t++) {
			obsIndex[t] = vocab.indexOf(observations[t]);
		}

		// Precompute the log transition probabilities (Math.log(0) = -Infinity is fine here)
		double[][] logA = new double[transitionMatrix.length][];
		for (int i = 0; i < transitionMatrix.length; i++) {
			logA[i] = new double[transitionMatrix[i].length];
			for (int j = 0; j < transitionMatrix[i].length; j++) {
				logA[i][j] = Math.log(transitionMatrix[i][j]);
			}
		}

		// viterbi[s][t]: log probability of the best path that ends in states[s] at time t
		// backPointer[s][t]: the state preceding states[s] on that path
		double[][] viterbi = new double[numStates][length];
		int[][] backPointer = new int[numStates][length];

		// Initialization: start state (row 0 of A) -> states[s] (column s + 1 of A)
		for (int s = 0; s < numStates; s++) {
			viterbi[s][0] = logA[0][s + 1] + Math.log(emissionMatrix[s][obsIndex[0]]);
		}

		// Recursion: keep the best predecessor for every state and time step.
		// Starting from -Infinity (instead of using 0 as an "unset" marker) also
		// handles paths whose log probability is exactly 0.
		for (int t = 1; t < length; t++) {
			for (int s = 0; s < numStates; s++) {
				double best = Double.NEGATIVE_INFINITY;
				int bestPrev = 0;
				for (int prev = 0; prev < numStates; prev++) {
					double candidate = viterbi[prev][t - 1] + logA[prev + 1][s + 1];
					if (candidate > best) {
						best = candidate;
						bestPrev = prev;
					}
				}
				// The emission term does not depend on the predecessor, so it is added once
				viterbi[s][t] = best + Math.log(emissionMatrix[s][obsIndex[t]]);
				backPointer[s][t] = bestPrev;
			}
		}

		// Termination: add the transition from the last emitting state into the final state
		double bestLogProbability = Double.NEGATIVE_INFINITY;
		int lastState = 0;
		for (int s = 0; s < numStates; s++) {
			double candidate = viterbi[s][length - 1] + logA[s + 1][finalState];
			if (candidate > bestLogProbability) {
				bestLogProbability = candidate;
				lastState = s;
			}
		}

		// Backtracking: follow the back pointers from the last time step to the first
		String[] path = new String[length];
		int current = lastState;
		for (int t = length - 1; t >= 0; t--) {
			path[t] = states[current];
			current = backPointer[current][t];
		}

		return new StateSequence(path, bestLogProbability);
	}

}
// This line should make sure that compile errors are directly identified when executing this cell
// (the line itself does not produce any meaningful result)
new ViterbiAlgorithm(new String[0],new String[0],new double[2][2],new double[0][0]);
System.out.println("compiled");

compiled


# Evaluation

- Run the following cell to test your implementation.
- You can ignore the cells afterwards.
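For very short inputs, an exhaustive search offers an independent cross-check of `getStateSequence`. The following is a hypothetical sketch (`BruteForceHmm` is not part of the assignment). It enumerates all `states.length`^T state paths, scores each one with the same log-space joint probability, and keeps the best one. This is exponential, so it is only usable for a handful of observations:

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical brute-force reference: tries every state path (exponential in T).
class BruteForceHmm {
    static final class Result {
        final String[] path;
        final double logProbability;

        Result(String[] path, double logProbability) {
            this.path = path;
            this.logProbability = logProbability;
        }
    }

    static Result best(String[] states, String[] vocab, double[][] a, double[][] b, String[] obs) {
        List<String> vocabList = Arrays.asList(vocab);
        int n = states.length;
        int finalState = a.length - 1;
        long total = 1;
        for (int i = 0; i < obs.length; i++) total *= n; // number of candidate paths
        int[] path = new int[obs.length];
        int[] bestPath = new int[obs.length];
        double bestLog = Double.NEGATIVE_INFINITY;
        for (long code = 0; code < total; code++) {
            // Decode the candidate path from the digits of code in base n
            long rest = code;
            for (int t = obs.length - 1; t >= 0; t--) {
                path[t] = (int) (rest % n);
                rest /= n;
            }
            double logP = 0.0;
            int prev = 0; // start state
            for (int t = 0; t < obs.length; t++) {
                logP += Math.log(a[prev][path[t] + 1]) + Math.log(b[path[t]][vocabList.indexOf(obs[t])]);
                prev = path[t] + 1;
            }
            logP += Math.log(a[prev][finalState]); // transition into the final state
            if (logP > bestLog) {
                bestLog = logP;
                bestPath = path.clone();
            }
        }
        String[] names = new String[obs.length];
        for (int t = 0; t < obs.length; t++) names[t] = states[bestPath[t]];
        return new Result(names, bestLog);
    }

    public static void main(String[] args) {
        String[] states = { "HOT", "COLD" };
        String[] vocab = { "1", "2", "3" };
        double[][] a = { { 0, 0.8, 0.2, 0 }, { 0, 0.6, 0.3, 0.1 }, { 0, 0.4, 0.5, 0.1 }, { 0, 0, 0, 0 } };
        double[][] b = { { 0.2, 0.4, 0.4 }, { 0.5, 0.4, 0.1 } };
        Result r = best(states, vocab, a, b, new String[] { "3", "1", "3" });
        System.out.println(Arrays.toString(r.path) + " " + r.logProbability);
    }
}
```

Comparing `path` and `logProbability` with the output of `getStateSequence` on the same small input is a cheap sanity check. For the 1000-observation test, the enumeration is infeasible (2^1000 paths).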

In [5]:
%maven org.junit.jupiter:junit-jupiter-api:5.3.1
import org.junit.jupiter.api.Assertions;
import org.opentest4j.AssertionFailedError;
import java.util.Arrays;

/**
 * Simple structure representing a list of expected state sequences.
 * Note that in very rare cases more than one solution is possible.
 * This is why this class offers an array of state sequences.
 * It is not necessary to use this in the student's implementation!
 */
public static class ExpectedStateSequence {
    /**
     * An array of expected states (and their alternatives).
     */
    public final String[][] states;
    /**
     * The logarithm of the probability of this sequence.
     */
    public final double logProbability;

    public ExpectedStateSequence(double logProbability, String[]...states) {
        this.states = states;
        this.logProbability = logProbability;
    }
}

public static final double DELTA = 0.000001;

public static void checkViterbi(String[] states, double[][] transitionMatrix, String[] observationVocab,
        double[][] emissionMatrix, String[] observations, ExpectedStateSequence expectedSequence) {
    try {
        ViterbiAlgorithm viterbi = new ViterbiAlgorithm(states, observationVocab, transitionMatrix, emissionMatrix);
        long time1 = System.currentTimeMillis();
        StateSequence sequence = viterbi.getStateSequence(observations);
        time1 = System.currentTimeMillis() - time1;
        System.out.println("Viterbi took " + time1 + "ms");
        // Check whether the result state sequence matches one of the expected
        // sequences 
        int id = 0;
        while((id < expectedSequence.states.length) && (!Arrays.equals(sequence.states, expectedSequence.states[id]))) {
            ++id;
        }
        // If no expected sequence matches the given result
        if(id >= expectedSequence.states.length) {
            StringBuilder message = new StringBuilder();
            message.append("The determined sequence ");
            message.append(Arrays.toString(sequence.states));
            message.append("\n does not match the expected state");
            if(expectedSequence.states.length > 1) {
                message.append("s ");
                for(int i = 0; i < expectedSequence.states.length; ++i) {
                    message.append('\n');
                    message.append(Arrays.toString(expectedSequence.states[i]));
                }
            } else {
                message.append(' ');
                message.append(Arrays.toString(expectedSequence.states[0]));
            }
            Assertions.fail(message.toString());
        }
        double diff = Math.abs(expectedSequence.logProbability - sequence.logProbability);
        Assertions.assertTrue(diff < DELTA, "The calculated probability (" + sequence.logProbability
                + ") does not match the expected probability (" + expectedSequence.logProbability + ").");
        System.out.println("Test passed");
    } catch (AssertionFailedError e) {
        throw e;
    } catch (Throwable e) {
        System.err.println("Your solution caused an unexpected error:");
        throw e;
    }
}

String observations[];
ExpectedStateSequence expectedSequence;
String[] states;
String[] sequence;
double[][] transitionMatrix;
String[] observationVocab;
double[][] emissionMatrix;

System.out.println("---------- Ice cream example ----------");
states = new String[] { "HOT", "COLD" };
transitionMatrix = new double[][] { { 0, 0.8, 0.2, 0 }, { 0, 0.6, 0.3, 0.1 }, { 0, 0.4, 0.5, 0.1 },
        { 0, 0, 0, 0 } };
observationVocab = new String[] { "1", "2", "3" };
emissionMatrix = new double[][] { { 0.2, 0.4, 0.4 }, { 0.5, 0.4, 0.1 } };
observations = new String[] { "3", "1", "3" };
expectedSequence = new ExpectedStateSequence(Math.log(0.0009216), new String[] { "HOT", "HOT", "HOT" });
checkViterbi(states, transitionMatrix, observationVocab, emissionMatrix, observations, expectedSequence);

observations = new String[] { "3", "2", "1", "1" };
expectedSequence = new ExpectedStateSequence(Math.log(0.000288), new String[] { "HOT", "HOT", "COLD", "COLD" });
checkViterbi(states, transitionMatrix, observationVocab, emissionMatrix, observations, expectedSequence);

observations = new String[] { "1", "3", "3", "2", "3", "2", "1", "3", "1", "1", "1" };
expectedSequence = new ExpectedStateSequence(Math.log(3.439853568E-9), 
        new String[] { "HOT", "HOT", "HOT", "HOT", "HOT", "HOT", "HOT", "HOT", "COLD", "COLD", "COLD" });
checkViterbi(states, transitionMatrix, observationVocab, emissionMatrix, observations, expectedSequence);


---------- Ice cream example ----------
Viterbi took 0ms
Test passed
Viterbi took 0ms
Test passed
Viterbi took 0ms
Test passed


In [6]:
System.out.println("---------- Ice cream example ----------");
observations = new String[1000];
Arrays.fill(observations, "3");
sequence = new String[1000];
Arrays.fill(sequence, "HOT");
expectedSequence = new ExpectedStateSequence(
        Math.log(0.8) + (999 * Math.log(0.6)) + (1000 * Math.log(0.4) + Math.log(0.1)), sequence);
checkViterbi(states, transitionMatrix, observationVocab, emissionMatrix, observations, expectedSequence);

---------- Ice cream example ----------
Viterbi took 1ms
Test passed
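For reference, the expected value in the long test is the log probability of the all-`HOT` path. Written out term by term, it consists of the start transition, 999 `HOT` to `HOT` transitions, 1000 emissions of `3` by `HOT`, and the transition into the final state:

$$
\log P^* = \log a_{0,\mathrm{HOT}} + 999 \log a_{\mathrm{HOT},\mathrm{HOT}} + 1000 \log b_{\mathrm{HOT}}(3) + \log a_{\mathrm{HOT},F}
= \log 0.8 + 999 \log 0.6 + 1000 \log 0.4 + \log 0.1 \approx -1429.13
$$

The probability itself, roughly $e^{-1429}$, is far too small to be represented as a `double`, which is why only its logarithm is compared.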


In [4]:
// Ignore this cell

In [5]:
// Ignore this cell