# Exercise 1 – Naive Bayes Classification
(10 points)

Implement a Naive Bayes classifier by finalizing the two given classes. The `BayesianLearner` class acts like a builder for the `BayesianClassifier` instances. That means that the learner gets the set of classes during its creation and the `learnExample` method of the learner is called once for each document of the training set. Internally, the learner should gather all statistics that are necessary for the classifier when processing the training examples.
After the learner saw all training documents, the `createClassifier` method is called which creates an instance of the `BayesianClassifier` class and initializes it with the statistics gathered before. 
The classification itself is carried out by the `classify` method which takes an unknown document and assigns it one of the classes learned before.

#### Hints

- Please do not forget to preprocess your documents. What exactly the preprocessing does is up to you.
- The evaluation will measure the accuracy of your classifier.
- The evaluation in the hidden tests has three stages. 
  1. Your solution will get 4 points as soon as it is better than the baselines. The baselines are:
     - For each class, a classifier that always returns this class.
     - A random guesser that returns a random class.
  2. If your solution has an accuracy >= 0.7, you will get 3 more points.
  3. If your solution has an accuracy >= 0.8, you will get 3 more points.
- You can download the [single-class-train.tsv](https://hobbitdata.informatik.uni-leipzig.de/teaching/SNLP/classification/single-class-train.tsv) file. It comprises one document per line. The first word is the class, followed by a tab character (`\t`). The remaining content of the line is the text of the document.

#### Notes

- You are free to use a different IDE to develop your solution. However, you have to copy the solution into this notebook to submit it.
- Do not add additional external libraries.
- Interface
  - You can use _[TAB]_ for autocompletion and _[SHIFT]_+_[TAB]_ for code inspection.
  - Use _Menu_ -> _View_ -> _Toggle Line Numbers_ for debugging.
  - Check _Menu_ -> _Help_ -> _Keyboard Shortcuts_.
- Known issues
  - All global variables will be set to void after an import.
  - Missing spaces arround `%` (Modulo) can cause unexpected errors so please make sure that you have added spaces around every `%` character.
- Finish
  - Save your solution by clicking on the _disk icon_.
  - Make sure that all necessary imports are listed at the beginning of your cell.
  - Run a final check of your solution by
    - click on _restart the kernel, then re-run the whole notebook_ (the fast forward arrow in the tool bar)
    - wait fo the kernel to restart and execute all cells (all executable cells should have numbers in front of them instead of a `[*]`) 
    - Check all executed cells for errors. If an exception is thrown, please check your code. Note that although the error might look cryptic, until now we never encounter that an exception was caused without a valid reason inside of the submitted code. A good way to check the code is to copy the solution into a new class in your favorite IDE and check
      - errors reported by the IDE
      - imports the IDE adds to your code which might be missing in your submission.
  - Finally, choose _Menu_ -> _File_ -> _Close and Halt_.
  - Do not forget to _Submit_ your solution in the _Assignments_ view.

In [1]:
// package NaiveBayesian;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

// YOUR CODE HERE

/**
 * Classifier implementing naive Bayes classification.
 */
public class BayesianClassifier {
	// YOUR CODE HERE
	List<String> vocab;
	int N_docs = 0;
	List<Integer> Nclass;
	double[] priorProbs;
	Map<String, int[]> tokenFreqMap;
	List<String> classList;
	Map<String, Integer> docSize;

	BayesianClassifier(List<String> vocab, int N_docs, List<Integer> Nclass, Map<String, int[]> tokenFreqMap,
			List<String> classList, Map<String, Integer> docSize) {
		this.vocab = vocab;
		this.N_docs = N_docs;
		this.Nclass = Nclass;
		this.tokenFreqMap = tokenFreqMap;
		this.classList = classList;
		this.docSize = docSize;
		// This method will store probabilities in the map.

		priorProbs = new double[Nclass.size()];

		for (int i = 0; i < priorProbs.length; i++) {
			priorProbs[i] = (double) Nclass.get(i) / N_docs;
		}

	}

	String[] stopwords = { "i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your", "yours",
			"yourself", "yourselves", "he", "him", "his", "himself", "she", "her", "hers", "herself", "it", "its",
			"itself", "they", "them", "their", "theirs", "themselves", "what", "which", "who", "whom", "this", "that",
			"these", "those", "am", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "having",
			"do", "does", "did", "doing", "a", "an", "the", "and", "but", "if", "or", "because", "as", "until", "while",
			"of", "at", "by", "for", "with", "about", "against", "between", "into", "through", "during", "before",
			"after", "above", "below", "to", "from", "up", "down", "in", "out", "on", "off", "over", "under", "again",
			"further", "then", "once", "here", "there", "when", "where", "why", "how", "all", "any", "both", "each",
			"few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own", "same", "so", "than",
			"too", "very", "s", "t", "can", "will", "just", "don", "should", "now" };

	List<String> stopWordsList = Arrays.asList(stopwords);

	public String preprocessing(String text) {
		double[][] condProbArray = new double[vocab.size()][classList.size()];

		String[] tokens = text.toLowerCase().replaceAll("[^ a-zA-Z0-9]", "").split(" ");
		double score[] = new double[classList.size()];
		
		for (String token : tokens) {

			// check if the token is not a stop word
			if (!stopWordsList.contains(token)) {
				// condiProb for that token:
				// freq(token,class:C)/freq(length(text:class:c)
				int indexToken = 0, freqToken = 0;
				boolean flag = false;

				// calculate the index if the token is in vocab
				if (vocab.contains(token)) {
					indexToken = vocab.indexOf(token);
					flag = true; // that token is in vocab
				}
				double condProb = 0.0;

				int i = 0;
				// calculate the score of each class for this token.
				for (String c : classList) {
					int size = docSize.get(c);
					int[] wordList = tokenFreqMap.get(c);
					/*if(docSizeClass[i] == 0){
						int sum=0;
						for(int j=0;j<wordList.length;j++){
							 sum += wordList[i];
						}
						docSizeClass[i]=sum;
					}*/
					
					
					double logToken;

					if (flag) {

						condProb = condProbArray[indexToken][classList.indexOf(c)];
						if (condProb > 0) {
							// as the condprob is already there, calculate its
							// log
							logToken = Math.log(condProb);
							score[i] += logToken;
							i++;
							continue;
						}
						// if the condProb is not there, calculate
						else {
							// our aim here is to calculate only the freqToken
							// so it can be zero if the token is not in vocab.
							freqToken = wordList[indexToken];
							condProbArray[indexToken][classList.indexOf(c)] = (double) (freqToken + 1)
									/ (size + vocab.size());
							logToken = Math.log(condProbArray[indexToken][classList.indexOf(c)]);
							score[i] += logToken;
							i++;
							continue;
						}
					}

					// here if token is not in vocab then freq would be zero.

					// use this freq for condProbArray
					else {

						condProb = (double) (freqToken + 1) / (size + vocab.size());
						logToken = Math.log(condProb);
						score[i] += logToken;
						i++;
						continue;
					}
				}

			}

		}

		for (int i = 0; i < score.length; i++) {
			score[i] += Math.log(priorProbs[i]);
		}

		double maxScore = score[0];
		int indexMax = 0;
		for (int i = 1; i < score.length; i++) {

			if (score[i] > maxScore) {
				maxScore = score[i];
				indexMax = i;
			}

		}
		return classList.get(indexMax);
	}

	/**
	 * Classifies the given document and returns the class name.
	 */
	public String classify(String text) {
		String clazz = null;

		clazz = preprocessing(text);

		// YOUR CODE HERE
		return clazz;
	}
}

/**
 * Learner (or Builder) class for a naive Bayes classifier.
 */
class BayesianLearner {
	// YOUR CODE HERE
	List<String> vocab = new ArrayList();
	int N_docs = 0;
	Map<String, String> classText = new HashMap<>();

	List<String> classList = new ArrayList<>();
	List<Integer> Nclass = new ArrayList<>();

	/**
	 * Constructor taking the set of classes the classifier should be able to
	 * distinguish.
	 */
	public BayesianLearner(Set<String> classes) {
		/*
		 * in classlist, we store each class at an index. Thus we can count no.
		 * of documents in the preprocessing step only.
		 */
		for (String clas : classes) {
			classList.add(clas);
			Nclass.add(0);
		}

	}

	String[] stopwords = { "i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your", "yours",
			"yourself", "yourselves", "he", "him", "his", "himself", "she", "her", "hers", "herself", "it", "its",
			"itself", "they", "them", "their", "theirs", "themselves", "what", "which", "who", "whom", "this", "that",
			"these", "those", "am", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "having",
			"do", "does", "did", "doing", "a", "an", "the", "and", "but", "if", "or", "because", "as", "until", "while",
			"of", "at", "by", "for", "with", "about", "against", "between", "into", "through", "during", "before",
			"after", "above", "below", "to", "from", "up", "down", "in", "out", "on", "off", "over", "under", "again",
			"further", "then", "once", "here", "there", "when", "where", "why", "how", "all", "any", "both", "each",
			"few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own", "same", "so", "than",
			"too", "very", "s", "t", "can", "will", "just", "don", "should", "now" };

	List<String> stopWordsList = Arrays.asList(stopwords);

	/*
	 * This method will update a map with the key being class and value being a
	 * vector with term frequency. The size of vector is equal to vacab size
	 */
	public void preprocessing(String clazz, String text) {
		// First a term from text is added to vocab. Now tempVect will represent
		// the text in numerical form. So, add 1 to the vector too. 1 is set to
		// index wrt vocab.
		// List<Integer> tempVect = new ArrayList<Integer>();
		String tempVect = new String();
		if (classText.containsKey(clazz)) {
			tempVect = classText.get(clazz);
		}

		// String[] tokens = text.toLowerCase().replaceAll("[^ a-zA-Z0-9]",
		// "").split(" ");
		// Performance measure test with the preprocessed textual content
		String text_processed = text.toLowerCase().replaceAll("[^ a-zA-Z0-9]", "");

		// Performance measure: test with unprocessed textual content so that
		// term freq. has real textual count.
		String text_unprocessed = text;

		classText.put(clazz, tempVect.concat(text_processed));

	}

	List<String> vocab1 = new ArrayList<>();

	void addVocab(String text) {
		String tokens[] = text.split(" ");
		for (String token : tokens) {
			if (!stopWordsList.contains(token)) {
				if (!vocab1.contains(token)) {
					vocab1.add(token);
				}
			}
		}
	}

	/**
	 * The method used to learn the training examples. It takes the name of the
	 * class as well as the text of the training document.
	 */
	public void learnExample(String clazz, String text) {
		// YOUR CODE HERE
		int index = classList.indexOf(clazz);
		int count = 0;
		// NClass refer to the number of documents for a class(i) at index i.
		if (Nclass.size() <= index) {
			Nclass.set(index, 1);
		} else {
			count = Nclass.get(index) + 1;
			Nclass.set(index, count);
		}

		// updates the count of documents for total
		N_docs++;

		preprocessing(clazz, text);

	}

	Map<String, int[]> classFreqMap = new HashMap<>();
	Map<String, Integer> docSizeClass = new HashMap<>();
	
	void prepareArray(String text, String clazz) {
		String[] tokens = text.split(" ");
		int size = tokens.length;
//		docSizeClass.put(clazz, size);
		int[] wordArray = new int[vocab1.size()];
		int sum=0;
		for (String token : tokens) {
			// check if the token is not in stop word
			if (!stopWordsList.contains(token)) {
				// find the index of token
				int indexToken = vocab1.indexOf(token);
				// add this token to the count present in the array at the index
				// of Vocab
				wordArray[indexToken] += 1;
				sum++;
			}
		}
		docSizeClass.put(clazz, sum);
		// store the class and word freq array to the map.
		classFreqMap.put(clazz, wordArray);
	}

	/**
	 * Creates a BayesianClassifier instance based on the statistics gathered
	 * from the training example.
	 */

	public BayesianClassifier createClassifier() {

		/*
		 * Now conditional prob. is needed to be calculated.
		 * :=(freq(term->V,class:c)+1)/(length(text,class:c)+V.length)
		 */
		for (String clazz : classText.keySet()) {
			String text = classText.get(clazz);
			// prepare the vocab list by iterating each class
			addVocab(text);
		}

		// for each class, prepare a new array with tokens from the list and
		// store into a new map.
		// right now our text is already preprocessed.
		
		for (String clazz : classText.keySet()) {
			String text = classText.get(clazz);
			// prepare the vocab list by iterating each class
			prepareArray(text, clazz);
		}
		

		BayesianClassifier classifier = new BayesianClassifier(vocab1, N_docs, Nclass, classFreqMap, classList,docSizeClass);

		// YOUR CODE HERE
		return classifier;
	}
}
// This line should make sure that compile errors are directly identified when
// executing this cell
// (the line itself does not produce any meaningful result)
new BayesianLearner(new HashSet<>(Arrays.asList("good","bad")));
System.out.println("compiled");

compiled


# Evaluation

- Run the following cell to test your implementation.
- You can ignore the cells afterwards.

In [2]:
%maven org.junit.jupiter:junit-jupiter-api:5.3.1
%maven commons-io:commons-io:2.6
import org.junit.jupiter.api.Assertions;
import org.opentest4j.AssertionFailedError;
import java.util.stream.Collectors;
import org.apache.commons.io.FileUtils;
import java.util.Map.Entry;

/**
 * Simple method for reading classification examples from a file as a list of (class, text) pairs.
 */
public static List<String[]> readClassData(String filename) throws IOException {
    return FileUtils.readLines(new File(filename), "utf-8").stream().map(s -> s.split("\t"))
            .filter(s -> s.length > 1).collect(Collectors.toList());
}

/**
 * Method for cecking the given classifier. The method expects training and evaluation data.
 * The data should have a String array for each document in which the first cell of the
 * array contains the class while the second cell contains the text of the document. During
 * the check, some statistics like the accuracies of different baseline classifiers are
 * printed. Finally, the calculated accuracy is returned.
 *
 * @param trainingCorpus the data that is used for training the classifier. 
 * @param evaluationCorpus the data that is used for evaluating the classifier. 
 * @param minAccuracy minimum accuracy the classifier should achieve.
 * @return the accuracy achieved by the classifier
 */
public static double checkClassifier(List<String[]> trainingCorpus, List<String[]> evaluationCorpus,
        double minAccuracy) {
    double accuracy = 0;
    try {
        System.out.print("Training corpus size: ");
        System.out.println(trainingCorpus.size());
        System.out.print("Eval. corpus size   : ");
        System.out.println(evaluationCorpus.size());
        // Determine the classes
        Set<String> classes = Arrays.asList(trainingCorpus, evaluationCorpus).stream().flatMap(l -> l.stream())
                .map(d -> d[0]).distinct().collect(Collectors.toSet());
        // Determine the number of instances per class in the evaluation set
        Map<String, Long> evalClassCounts = evaluationCorpus.stream()
                .collect(Collectors.groupingBy(d -> d[0], Collectors.counting()));
        for (String clazz : classes) {
            if (!evalClassCounts.containsKey(clazz)) {
                evalClassCounts.put(clazz, 0L);
            }
        }

        // Determine the expected accuracies of the baselines
        Map<String, Double> accForClassGuessers = new HashMap<>();
        for (Entry<String, Long> e : evalClassCounts.entrySet()) {
            accForClassGuessers.put(e.getKey(), e.getValue() / (double) evaluationCorpus.size());
        }
        double accRandomGuesser = 1.0 / accForClassGuessers.size();

        // Train the classifier
        long time1 = System.currentTimeMillis();
        BayesianLearner learner = new BayesianLearner(classes);
        for (String[] trainingExample : trainingCorpus) {
            learner.learnExample(trainingExample[0], trainingExample[1]);
        }
        BayesianClassifier classifier = learner.createClassifier();
        time1 = System.currentTimeMillis() - time1;
        System.out.println("Training took       : " + time1 + "ms");

        // Classify the evaluation corpus
        long time2 = System.currentTimeMillis();
        int tp = 0, errors = 0, id = 0;
        String result;
        List<String[]> fpDetails = new ArrayList<>();
        for (String[] evalExample : evaluationCorpus) {
            result = classifier.classify(evalExample[1]);
            if (evalExample[0].equals(result)) {
                ++tp;
            } else {
                ++errors;
                fpDetails.add(new String[] { Integer.toString(id), evalExample[0], result });
            }
            ++id;
        }
        time2 = System.currentTimeMillis() - time2;
        System.out.println("Classification took : " + time2 + "ms");
        accuracy = tp / (double) (tp + errors);

        System.out.println("Baseline classifiers: ");
        for (Entry<String, Double> baseResult : accForClassGuessers.entrySet()) {
            System.out.println(String.format("Always %-13s: %-7.5f", baseResult.getKey(), baseResult.getValue()));
        }
        System.out.println(String.format("Random guesser      : %-7.5f", accRandomGuesser));
        System.out.println(String.format("Your solution       : %-7.5f (%d tp, %d errors)", accuracy, tp, errors));
        if (fpDetails.size() > 0) {
            System.out.println("  Wrong classifications are:");
            for (int i = 0; i < Math.min(fpDetails.size(), 20); ++i) {
                System.out.print("    id=");
                System.out.print(fpDetails.get(i)[0]);
                System.out.print(" expected=");
                System.out.print(fpDetails.get(i)[1]);
                System.out.print(" result=");
                System.out.println(fpDetails.get(i)[2]);
            }
            if (fpDetails.size() > 20) {
                System.out.println("    ...");
            }
        }

        // Make sure that the students solution is better than all baselines
        for (Entry<String, Double> baseResult : accForClassGuessers.entrySet()) {
            if (baseResult.getValue() >= accuracy) {
                StringBuilder builder = new StringBuilder();
                builder.append("Your solution is not better than a classifier that always chooses the \"");
                builder.append(baseResult.getKey());
                builder.append("\" class.");
                Assertions.fail(builder.toString());
            }
        }
        if (accRandomGuesser >= accuracy) {
            Assertions.fail("Your solution is not better than a random guesser.");
        }
        if ((minAccuracy > 0) && (minAccuracy > accuracy)) {
            Assertions.fail("Your solution did not reach the expected accuracy of " + minAccuracy);
        }
        System.out.println("Test successfully completed.");
    } catch (AssertionFailedError e) {
        throw e;
    } catch (Throwable e) {
        System.err.println("Your solution caused an unexpected error:");
        throw e;
    }
    return accuracy;
}

/*
 * Test case 1: a simple example corpus which is easy to do by hand.
 */
System.out.println("---------- Simple example corpus ----------");
List<String[]> exampleCorpusTrain = Arrays.asList(
        new String[] {"chess", "white king, black rook, black queen, white pawn, black knight, white bishop." },
        new String[] {"history", "knight person granted honorary title knighthood" },
        new String[] {"history", "knight order eligibility, knighthood, head of state, king, prelate, middle ages." },
        new String[] {"chess", "Defense knight pawn opening game opponent." },
        new String[] {"literature", "Knights Round Table. King Arthur. literary cycle Matter of Britain."}
        );
List<String[]> exampleCorpusTest = Arrays.asList(
        new String[] {"history", "Knighthood Middle Ages." },
        new String[] {"chess", "player king knight opponent king checkmate game draw." },
        // document with unknown words
        new String[] {"literature", "britain king arthur. Sir Galahad." }
        );

double accuracy = checkClassifier(exampleCorpusTrain, exampleCorpusTest, 0);


/*
 * Test case 2: a more complex example on real-world data.
 */
System.out.println();
System.out.println("---------- Larger example corpus ----------");
List<String[]> classificationData =readClassData("/srv/distribution/single-class-train.tsv");
accuracy = checkClassifier(classificationData.subList(0, 750), classificationData.subList(750, classificationData.size()), 0);

---------- Simple example corpus ----------
Training corpus size: 5
Eval. corpus size   : 3
Training took       : 1ms
Classification took : 0ms
Baseline classifiers: 
Always literature   : 0.33333
Always chess        : 0.33333
Always history      : 0.33333
Random guesser      : 0.33333
Your solution       : 1.00000 (3 tp, 0 errors)
Test successfully completed.

---------- Larger example corpus ----------
Training corpus size: 750
Eval. corpus size   : 260
Training took       : 4039ms
Classification took : 1270ms
Baseline classifiers: 
Always gold         : 0.03462
Always money-fx     : 0.19231
Always trade        : 0.16923
Always interest     : 0.11538
Always coffee       : 0.06154
Always money-supply : 0.10000
Always ship         : 0.05000
Always sugar        : 0.06538
Always crude        : 0.21154
Random guesser      : 0.11111
Your solution       : 0.85385 (222 tp, 38 errors)
  Wrong classifications are:
    id=5 expected=money-fx result=interest
    id=10 expected=money-supply resul

In [3]:
// Ignore this cell

In [4]:
// Ignore this cell

In [5]:
// Ignore this cell