# Learning HMM parameters from annotated data
(10 points)

Complete the class `HMMLearner` so that it can learn the statistics of a Hidden Markov Model from a set of strings, each of which is an observation sequence annotated with its hidden states. The method `processSequence` is called once per sequence and should gather the statistics needed for the HMM. After all sequences have been passed to the class, the method `buildViterbi` is called to create an instance of the `ViterbiAlgorithm` class, which is already known from Exercise 1. The tests assume that a correctly learned HMM yields a correct `ViterbiAlgorithm` instance, i.e., that the paths determined by the Viterbi algorithm are correct.
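
As a rough guide to which statistics are meant, the usual maximum-likelihood estimates for a fully annotated HMM are relative frequencies of counts. The notation below ($a_{ij}$, $b_i(v_k)$, $C(\cdot)$, $N_{\text{seq}}$, and the explicit start state $0$ and end state $F$) is assumed here for illustration and is not prescribed by the exercise:

$$
a_{0j} = \frac{C(\text{start} \to q_j)}{N_{\text{seq}}}, \qquad
a_{ij} = \frac{C(q_i \to q_j)}{C(q_i)}, \qquad
a_{iF} = \frac{C(q_i \to \text{end})}{C(q_i)}, \qquad
b_i(v_k) = \frac{C(q_i \text{ emits } v_k)}{C(q_i)}
$$

Here $C(q_i)$ is the total number of occurrences of state $q_i$ in the training data and $N_{\text{seq}}$ is the number of training sequences. Every occurrence of $q_i$ is followed either by another state or by the end of its sequence, so $\sum_j a_{ij} + a_{iF} = 1$ for every state that occurs at least once.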

To ease the processing of the sequences, the constructor of the `HMMLearner` takes the number of states (i.e., the size of $Q$) and the size of the observation vocabulary (i.e., the size of $V$).
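
One way to use these two sizes is to allocate all count tables up front. The following is a minimal sketch under that assumption; the class name `CountTablesSketch`, its fields, and the helper `normalize` are hypothetical and not part of the required interface.

```java
import java.util.Arrays;

public class CountTablesSketch {
    // [from][to]; the extra last column counts "the sequence ended in this state"
    public final int[][] transitionCounts;
    // [state][observation]
    public final int[][] emissionCounts;

    public CountTablesSketch(int numberOfStates, int sizeOfVocab) {
        transitionCounts = new int[numberOfStates][numberOfStates + 1];
        emissionCounts = new int[numberOfStates][sizeOfVocab];
    }

    /** Turns a row of counts into relative frequencies; an all-zero row stays all zero. */
    public static double[] normalize(int[] counts) {
        long total = 0;
        for (int c : counts) {
            total += c;
        }
        double[] probs = new double[counts.length];
        if (total == 0) {
            return probs; // avoid 0/0 = NaN for states that never occurred
        }
        for (int i = 0; i < counts.length; i++) {
            probs[i] = (double) counts[i] / total;
        }
        return probs;
    }

    public static void main(String[] args) {
        CountTablesSketch tables = new CountTablesSketch(2, 3);
        tables.emissionCounts[0][0] = 1;
        tables.emissionCounts[0][2] = 3;
        System.out.println(Arrays.toString(normalize(tables.emissionCounts[0]))); // [0.25, 0.0, 0.75]
    }
}
```

Guarding against an all-zero row is a design choice: without it, a state that never appears in the training data would produce `NaN` probabilities.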

#### Example

Each sequence consists of annotated tokens. The tokens are separated by a single whitespace character. Every token comprises an observation and a state connected by a `_` character.

A tagged sequence will look like the following line:
```
1_COLD 2_HOT 2_HOT 1_COLD
```
The sequence starts with the state `COLD`, which emits a `1`. After that, a `2` is observed twice, each emitted by the state `HOT`, before a `1` is again emitted by a `COLD` state.
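
The token format can be parsed with plain string operations. The sketch below is only an illustration: the class and method names are hypothetical, and splitting at the last `_` rests on the assumption that state names never contain an underscore (the exercise text does not say so).

```java
import java.util.ArrayList;
import java.util.List;

public class TokenParsingSketch {

    /** Splits one annotated token such as "1_COLD" into {observation, state}. */
    public static String[] splitToken(String token) {
        // Assumption: the state name contains no '_', so the last '_' is the separator.
        int sep = token.lastIndexOf('_');
        if (sep <= 0 || sep == token.length() - 1) {
            throw new IllegalArgumentException("Malformed token: " + token);
        }
        return new String[] { token.substring(0, sep), token.substring(sep + 1) };
    }

    /** Parses a whole line into a list of {observation, state} pairs. */
    public static List<String[]> parseSequence(String line) {
        List<String[]> pairs = new ArrayList<>();
        for (String token : line.trim().split(" ")) {
            pairs.add(splitToken(token));
        }
        return pairs;
    }

    public static void main(String[] args) {
        for (String[] pair : parseSequence("1_COLD 2_HOT 2_HOT 1_COLD")) {
            System.out.println(pair[1] + " emits " + pair[0]);
        }
    }
}
```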

#### Hints

- For this task, it is **not necessary to implement Expectation Maximization** as we are learning from data that already has been annotated.
- The [example dataset](https://hobbitdata.informatik.uni-leipzig.de/teaching/SNLP/HMM/icecream-sequences.txt) can be downloaded.
- Make sure that the tests take less than 1 minute (the evaluation has a maximum runtime of 5 minutes per file, including the hidden tests).
- Observations **do not have to be** numbers. Please take this into account for your implementation.
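
Both the runtime hint and the hint about non-numeric observations can be addressed by mapping each distinct string to an array index on first sight. The following is a sketch of that idea, not a required design; the class `SymbolIndexSketch` and its methods are hypothetical names. A `HashMap` lookup takes expected constant time, whereas `List.indexOf` scans the list linearly.

```java
import java.util.HashMap;
import java.util.Map;

public class SymbolIndexSketch {
    private final Map<String, Integer> indexOf = new HashMap<>();

    /** Returns the index of the symbol, assigning the next free index on first sight. */
    public int indexFor(String symbol) {
        Integer idx = indexOf.get(symbol);
        if (idx == null) {
            idx = indexOf.size();
            indexOf.put(symbol, idx);
        }
        return idx;
    }

    public int size() {
        return indexOf.size();
    }

    public static void main(String[] args) {
        SymbolIndexSketch observations = new SymbolIndexSketch();
        SymbolIndexSketch states = new SymbolIndexSketch();
        int[][] emissionCounts = new int[2][2];
        // Observations are words here, not numbers.
        for (String token : "one_COLD two_HOT two_HOT one_COLD".split(" ")) {
            int sep = token.lastIndexOf('_');
            int obs = observations.indexFor(token.substring(0, sep));
            int state = states.indexFor(token.substring(sep + 1));
            emissionCounts[state][obs]++;
        }
        System.out.println(emissionCounts[0][0]); // COLD emitted "one" twice: prints 2
        System.out.println(emissionCounts[1][1]); // HOT emitted "two" twice: prints 2
    }
}
```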

#### Notes

- You are free to use a different IDE to develop your solution. However, you have to copy the solution into this notebook to submit it.
- Do not add additional external libraries.
- Interface
  - You can use _[TAB]_ for autocompletion and _[SHIFT]_+_[TAB]_ for code inspection.
  - Use _Menu_ -> _View_ -> _Toggle Line Numbers_ for debugging.
  - Check _Menu_ -> _Help_ -> _Keyboard Shortcuts_.
- Known issues
  - All global variables will be set to void after an import.
  - Missing spaces around `%` (modulo) can cause unexpected errors, so please make sure that there are spaces around every `%` character.
- Finish
  - Save your solution by clicking on the _disk icon_.
  - Make sure that all necessary imports are listed at the beginning of your cell.
  - Run a final check of your solution:
    - Click on _restart the kernel, then re-run the whole notebook_ (the fast forward arrow in the tool bar).
    - Wait for the kernel to restart and execute all cells (all executable cells should have numbers in front of them instead of a `[*]`).
    - Check all executed cells for errors. If an exception is thrown, please check your code. Although the error message might look cryptic, so far we have never encountered an exception that was not caused by the submitted code. A good way to check the code is to copy the solution into a new class in your favorite IDE and check
      - errors reported by the IDE
      - imports the IDE adds to your code which might be missing in your submission.
  - Finally, choose _Menu_ -> _File_ -> _Close and Halt_.
  - Do not forget to _Submit_ your solution in the _Assignments_ view.

In [3]:
// YOUR CODE HERE
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Simple structure for storing a sequence of states and its probability.
 */
public static class StateSequence {
    
	/**
	 * The sequence of states.
	 */
	public final String[] states;
	/**
	 * The logarithm of the probability of this sequence.
	 */
	public final double logProbability;

	public StateSequence(String[] states, double logProbability) {
		this.states = states;
		this.logProbability = logProbability;
	}

}

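/**
 * A class implementing the Viterbi algorithm based on a Hidden Markov Model
 * with the given states, observations, transition matrix and emission matrix.
 * 
 * In this implementation the transition matrix has one extra row for the
 * start state (index 0) and one extra column for the end state (last index).
 * 
 * You may want to copy it from Exercise-1.
 */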
public static class ViterbiAlgorithm {
   

	// YOUR CODE HERE

	String[] states;
	String[] observationVocab;
	double[][] transitionMatrix;
	double[][] emissionMatrix;

	/**
	 * Constructor.
	 */
	public ViterbiAlgorithm(String[] states, String[] observationVocab, double[][] transitionMatrix,
			double[][] emissionMatrix) {
		// YOUR CODE HERE
		this.states = states;
		this.observationVocab = observationVocab;
		this.transitionMatrix = transitionMatrix;
		this.emissionMatrix = emissionMatrix;
	}

	/**
	 * Returns the sequence of states which has the highest probability to
	 * create the given sequence of observations.
	 * 
	 * @param observations
	 *            a sequence of observations
	 * @return the sequence of states
	 */
	public StateSequence getStateSequence(String[] observations) {
		int numStates = states.length;
		int length = observations.length;
		// viterbi[s][t]: best log probability of any path that ends in state s at time t
		double[][] viterbi = new double[numStates][length];
		// backPointer[s][t]: predecessor of state s at time t on that best path
		int[][] backPointer = new int[numStates][length];
		List<String> observationVocabList = Arrays.asList(observationVocab);

		// Initialization: transition out of the start state (row 0) plus first emission.
		int firstObs = observationVocabList.indexOf(observations[0]);
		for (int s = 0; s < numStates; s++) {
			viterbi[s][0] = Math.log(transitionMatrix[0][s + 1]) + Math.log(emissionMatrix[s][firstObs]);
		}

		// Recursion: for every state keep the best predecessor. The running maximum
		// starts at negative infinity, so a legitimate log probability of 0 is not
		// mistaken for "not yet set".
		for (int t = 1; t < length; t++) {
			int obsIndex = observationVocabList.indexOf(observations[t]);
			for (int state = 0; state < numStates; state++) {
				double best = Double.NEGATIVE_INFINITY;
				int bestPrev = 0;
				for (int prev = 0; prev < numStates; prev++) {
					double score = viterbi[prev][t - 1] + Math.log(transitionMatrix[prev + 1][state + 1]);
					if (score > best) {
						best = score;
						bestPrev = prev;
					}
				}
				viterbi[state][t] = best + Math.log(emissionMatrix[state][obsIndex]);
				backPointer[state][t] = bestPrev;
			}
		}

		// Termination: add the transition into the end state (last column).
		double bestLogProb = Double.NEGATIVE_INFINITY;
		int lastState = 0;
		for (int s = 0; s < numStates; s++) {
			double score = viterbi[s][length - 1] + Math.log(transitionMatrix[s + 1][numStates + 1]);
			if (score > bestLogProb) {
				bestLogProb = score;
				lastState = s;
			}
		}

		// Backtrace: follow the back pointers from the last state to the first.
		String[] path = new String[length];
		int state = lastState;
		for (int t = length - 1; t >= 0; t--) {
			path[t] = states[state];
			state = backPointer[state][t];
		}

		return new StateSequence(path, bestLogProb);
	}


}

public class HMMLearner {



	// YOUR CODE HERE
	int[][] countEmission;
	int[] countState;
	int[][] count;
	int numberOfStates;
	int sizeOfVocab;
	int[] startProb;
	int[] endProb;
	int countSequences;
	List<String> obsVocab = new ArrayList<>();

	// names of the hidden states seen so far (e.g. HOT and COLD)
	List<String> statesVocab = new ArrayList<>();

	public HMMLearner(int numberOfStates, int sizeOfVocab) {
		// YOUR CODE HERE
		count = new int[numberOfStates][numberOfStates];
		startProb = new int[numberOfStates];
		endProb = new int[numberOfStates];
		countState = new int[numberOfStates];
		countEmission = new int[numberOfStates][sizeOfVocab];
		this.sizeOfVocab = sizeOfVocab;
		countSequences = 0;
		this.numberOfStates = numberOfStates;

	}

	public void processSequence(String sequence) {
		// YOUR CODE HERE
		countSequences++;
		String[] sequenceToken = sequence.split(" ");
		List<Integer> stateSeq = new ArrayList<>();
		List<Integer> obsSeq = new ArrayList<>();

		for (String token : sequenceToken) {
			String[] intToken = token.split("_");
			if (!obsVocab.contains(intToken[0]))
				obsVocab.add(intToken[0]);
			if (!statesVocab.contains(intToken[1]))
				statesVocab.add(intToken[1]);

			// obsSeq holds the index of each observation in obsVocab
			obsSeq.add(obsVocab.indexOf(intToken[0]));

			// stateSeq holds the index of each state in statesVocab
			stateSeq.add(statesVocab.indexOf(intToken[1]));
		}

		/*
		 * Transition counts, gathered in a single pass over the sequence:
		 * countState[s] is the number of occurrences of state s and
		 * count[i][j] is the number of transitions from state i to state j.
		 */
		for (int t = 0; t < stateSeq.size(); t++) {
			int current = stateSeq.get(t);
			countState[current] += 1;
			if (t + 1 < stateSeq.size()) {
				int next = stateSeq.get(t + 1);
				count[current][next] += 1;
			}
		}

		/*
		 * Emission counts: a single pass over the sequence that pairs each
		 * state with the observation it emitted.
		 */

		for (int indexState = 0; indexState < stateSeq.size(); indexState++) {

			// state represent the index in state vocab
			int state = stateSeq.get(indexState);

			// obs represent the index in observation vocab
			int obs = obsSeq.get(indexState);
			countEmission[state][obs] += 1;
		}

		/*
		 * Count the first and the last state of the sequence; these counts
		 * give the start and end transition probabilities.
		 */

		int startStateIndex = stateSeq.get(0);
		startProb[startStateIndex] += 1;

		int endStateIndex = stateSeq.get(stateSeq.size() - 1);
		endProb[endStateIndex] += 1;

	}

	public ViterbiAlgorithm buildViterbi() {
		// Here we make use of countEmission, count, countState, startProb and endProb.

		// Calculate the transition matrix. Row 0 is the start state and column
		// numberOfStates + 1 is the end state, so the real states use the
		// indices 1..numberOfStates.
		double[][] transitionMat = new double[numberOfStates + 2][numberOfStates + 2];
		for (int from = 1; from <= numberOfStates; from++) {
			// Normalize by the number of occurrences of the source state so that
			// each row, including the end-state column, sums to 1.
			int fromCount = countState[from - 1];
			for (int to = 1; to <= numberOfStates; to++) {
				transitionMat[from][to] = (double) count[from - 1][to - 1] / fromCount;
			}
			transitionMat[from][numberOfStates + 1] = (double) endProb[from - 1] / fromCount;
		}
		// Start probabilities: the fraction of sequences that begin in each state.
		for (int state = 1; state <= numberOfStates; state++) {
			transitionMat[0][state] = (double) startProb[state - 1] / countSequences;
		}

		// calculate emissionMatrix
		double emissionMatProb[][] = new double[numberOfStates][sizeOfVocab];
		for (int state = 0; state < numberOfStates; state++) {
			for (int obs = 0; obs < sizeOfVocab; obs++)
				emissionMatProb[state][obs] = (double) countEmission[state][obs] / countState[state];
		}
		String[] states = new String[numberOfStates];
		int i = 0;
		for (String s : statesVocab) {
			states[i] = s;
			i++;
		}

		String[] obs = new String[sizeOfVocab];
		i = 0;
		for (String s : obsVocab) {
			obs[i] = s;
			i++;
		}

		ViterbiAlgorithm viterbi = new ViterbiAlgorithm(states, obs, transitionMat, emissionMatProb);

		return viterbi;
	}

	public static void main(String[] args) {
		HMMLearner hmm = new HMMLearner(2, 3);// (numbState,sizeObs)

		/*
		 * hmm.processSequence(
		 * "3_HOT 2_HOT 3_HOT 2_HOT 1_COLD 3_HOT 2_HOT 1_COLD 1_COLD 2_COLD 2_HOT 1_HOT 2_COLD 2_HOT 3_HOT 3_HOT 1_COLD 3_HOT"
		 * );
		 */

		hmm.processSequence("2_HOT 3_HOT 1_COLD 2_HOT 3_COLD");
		ViterbiAlgorithm vit = hmm.buildViterbi();

	}

}

// This line should make sure that compile errors are directly identified when executing this cell
// (the line itself does not produce any meaningful result)
new ViterbiAlgorithm(new String[0],new String[0],new double[2][2],new double[0][0]);
new HMMLearner(0,0);
System.out.println("compiled");

compiled


# Evaluation

- Run the following cell to test your implementation.
- You can ignore the cells afterwards.

In [4]:
%maven org.junit.jupiter:junit-jupiter-api:5.3.1
%maven commons-io:commons-io:2.6
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.LineIterator;
import org.junit.jupiter.api.Assertions;
import org.opentest4j.AssertionFailedError;

public static void checkHMMLearningViaViterbi(String filename, int numberOfStates, int sizeOfVocab,
        String[][] testObservations, String[][] expectedSequences) throws Throwable {
    LineIterator iterator = null;
    try {
        // Learn
        iterator = FileUtils.lineIterator(new File(filename), "utf-8");
        long time1 = System.currentTimeMillis();
        HMMLearner learner = new HMMLearner(numberOfStates, sizeOfVocab);
        while (iterator.hasNext()) {
            learner.processSequence(iterator.next());
        }
        ViterbiAlgorithm viterbi = learner.buildViterbi();
        time1 = System.currentTimeMillis() - time1;
        System.out.println("Learning took " + time1 + "ms");
        // Test
        time1 = System.currentTimeMillis();
        for (int i = 0; i < expectedSequences.length; i++) {
            StateSequence sequence = viterbi.getStateSequence(testObservations[i]);
            Assertions.assertArrayEquals(expectedSequences[i], sequence.states,
                    "The calculated sequence " + Arrays.toString(sequence.states)
                            + " does not match the expected sequence " + Arrays.toString(expectedSequences[i]));
            System.out.println("Test " + i + " passed");
        }
        time1 = System.currentTimeMillis() - time1;
        System.out.println("Testing took " + time1 + "ms");
    } catch (AssertionFailedError e) {
        throw e;
    } catch (Throwable e) {
        System.err.println("Your solution caused an unexpected error:");
        throw e;
    } finally {
        if (iterator != null) {
            iterator.close();
        }
    }
}

String[][] testObservations;
String[][] expectedSequences;

System.out.println("---------- Ice cream example ----------");
testObservations = new String[][] { { "3", "1", "3" }, { "3", "2", "1", "1" },
        { "1", "3", "3", "2", "3", "2", "1", "3", "1", "1", "1" }, new String[1000] };
expectedSequences = new String[][] { { "HOT", "HOT", "HOT" }, { "HOT", "HOT", "COLD", "COLD" },
        { "HOT", "HOT", "HOT", "HOT", "HOT", "HOT", "HOT", "HOT", "COLD", "COLD", "COLD" }, new String[1000] };
Arrays.fill(testObservations[3], "3");
Arrays.fill(expectedSequences[3], "HOT");
checkHMMLearningViaViterbi("/srv/distribution/icecream-sequences.txt", 2, 3, testObservations, expectedSequences);


---------- Ice cream example ----------
Learning took 4679ms
Test 0 passed
Test 1 passed
Test 2 passed
Test 3 passed
Testing took 6ms


In [None]:
// Ignore this cell

In [None]:
// Ignore this cell

In [None]:
// Ignore this cell