From ee958aaa4bacd2197fc098921b7962cda1c57dcb Mon Sep 17 00:00:00 2001 From: Erel Segal-Halevi cron Date: Mon, 30 Jan 2017 10:37:16 +0200 Subject: [PATCH] Fix bayesian classifier --- classifiers/bayesian/bayesian.js | 90 ++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 40 deletions(-) diff --git a/classifiers/bayesian/bayesian.js b/classifiers/bayesian/bayesian.js index 262b5c7..044f18a 100644 --- a/classifiers/bayesian/bayesian.js +++ b/classifiers/bayesian/bayesian.js @@ -2,9 +2,9 @@ var _ = require("underscore")._; /** * A multi-class single-label Bayes classifier. - * + * * @author Erel Segal-Halevi based on code by Heather Arthur (https://github.com/harthur/classifier) - * + * * @param options */ var Bayesian = function(options) { @@ -37,7 +37,7 @@ Bayesian.prototype = { /** * Tell the classifier that the given document belongs to the given category. - * @param document [string] a training sample - a feature-value hash: {feature1: value1, feature2: value2, ...} + * @param document [string] a training sample - a feature-value hash: {feature1: value1, feature2: value2, ...} * @param category [string] the correct category of this sample. */ trainOnline: function(document, category) { @@ -47,12 +47,12 @@ Bayesian.prototype = { /** * Train the classifier with all the given documents. * @param data an array with objects of the format: {input: sample1, output: category1} - * where sample1 is a feature-value hash: {feature1: value1, feature2: value2, ...} + * where sample1 is a feature-value hash: {feature1: value1, feature2: value2, ...} */ trainBatch: function(data) { this.incDocCounts(data); }, - + /** * Ask the classifier what category the given document belongs to. * @param document a hash {feature1: value1, feature2: value2, ...} @@ -64,7 +64,7 @@ Bayesian.prototype = { throw new Error("document should be a feature-value hash, but it is "+JSON.stringify(document)); } var probs = this.getProbsSync(document); - + var max = this.bestMatch(probs); if (explain>0) { return { @@ -82,28 +82,29 @@ Bayesian.prototype = { * A subroutine used for classification. * Gets the probabilities of the words in the given sentence. * @param document a hash {feature1: value1, feature2: value2, ...} + * Values are numeric and represent number of occurences. */ getProbsSync: function(document) { - var words = Object.keys(document); // an array with the unique words in the text, for example: [ 'free', 'watches' ] var cats = this.getCats(); // a hash with the possible categories: { 'cat1': 1, 'cat2': 1 } - var counts = this.getWordCounts(words, cats); // For each word encountered during training, the counts of times it occurred in each category. - var probs = this.getCatProbs(cats, words, counts); // The probabilities that the given document belongs to each of the categories, i.e.: { 'cat1': 0.1875, 'cat2': 0.0625 } - + var counts = this.getWordCounts(Object.keys(document), cats); // For each word encountered during training, the counts of times it occurred in each category. + + var probs = this.getCatProbs(cats, document, counts); // The probabilities that the given document belongs to each of the categories, i.e.: { 'cat1': 0.1875, 'cat2': 0.0625 } + if (this.normalizeOutputProbabilities) { var sum = _(probs).reduce(function(memo, num) { return memo + num; }, 0); for (var cat in probs) probs[cat] = probs[cat]/sum; } - + var pairs = _.pairs(probs); // pairs of [category,probability], for all categories that appeared in the training set. //console.dir(pairs); if (pairs.length==0) { return {category: this.default, probability: 0}; } - pairs.sort(function(a,b) { // sort by decreasing prob + pairs.sort(function(a,b) { // sort by decreasing prob return b[1]-a[1]; }); - + return pairs; }, @@ -115,23 +116,23 @@ Bayesian.prototype = { bestMatch: function(pairs) { var maxCategory = pairs[0][0]; var maxProbability = pairs[0][1]; - + if (pairs.length>1) { var nextProbability = pairs[1][1]; var threshold = this.thresholds[maxCategory] || this.globalThreshold; if (nextProbability * threshold > maxProbability) maxCategory = this.default; // not greater than other category by enough - if (this.calculateRelativeProbabilities) + if (this.calculateRelativeProbabilities) maxProbability /= nextProbability; } return { category: maxCategory, probability: maxProbability - }; + }; }, - - + + toJSON: function(callback) { return this.backend.toJSON(callback); }, @@ -143,55 +144,64 @@ Bayesian.prototype = { getCats: function(callback) { return this.backend.getCats(callback); }, - - - + + + /* * - * Internal functions (should be private): - * + * Internal functions (should be private): + * */ - - wordProb: function(word, cat, cats, counts) { + + wordProb: function(word, cat, cats, wordCounts) { // times word appears in a doc in this cat / docs in this cat - var prob = (counts[cat] || 0) / cats[cat]; + var probWordGivenCat = (wordCounts[cat] || 0) / cats[cat]; - // get weighted average with assumed so prob won't be extreme on rare words - var total = _(cats).reduce(function(sum, p, cat) { - return sum + (counts[cat] || 0); + var totalWordCount = _(cats).reduce(function(sum, p, cat) { + return sum + (wordCounts[cat] || 0); }, 0, this); - return (this.weight * this.assumed + total * prob) / (this.weight + total); + // get weighted average with assumed so prob won't be extreme on rare words + var modifiedProbGivenCat = (this.weight * this.assumed + totalWordCount * probWordGivenCat) / (this.weight + totalWordCount); + + //console.log("word="+word+" cat="+cat+" probWordGivenCat="+probWordGivenCat+" totalWordCount="+totalWordCount+" modifiedProbGivenCat="+modifiedProbGivenCat) + return modifiedProbGivenCat }, - getCatProbs: function(cats, words, counts) { + getCatProbs: function(cats, document, counts) { var numDocs = _(cats).reduce(function(sum, count) { return sum + count; - }, 0); + }, 0); // total number of training samples in all categories var probs = {}; _(cats).each(function(catCount, cat) { - var catProb = (catCount || 0) / numDocs; + var catPriorProb = (catCount || 0) / numDocs; - var docProb = _(words).reduce(function(prob, word) { + // The probability to see a document is the product + // of the probability to see each word in the document. + var docProb = _(Object.keys(document)).reduce(function(prob, word) { var wordCounts = counts[word] || {}; - return prob * this.wordProb(word, cat, cats, wordCounts); + var probWordGivenCat = this.wordProb(word, cat, cats, wordCounts); + var probWordsGivenCat = Math.pow(probWordGivenCat, document[word]); + //console.log("probWordGivenCat="+probWordGivenCat+" probWordsGivenCat="+probWordsGivenCat+" document[word]="+document[word]) + return prob * probWordsGivenCat; }, 1, this); + //console.log("docProb="+docProb) // the probability this doc is in this category - probs[cat] = catProb * docProb; + probs[cat] = catPriorProb * docProb; }, this); return probs; }, - - + + getWordCounts: function(words, cats, callback) { return this.backend.getWordCounts(words, cats, callback); }, /** - * Increment the feature counts. + * Increment the feature counts. * @param data an array with objects of the format: {input: sample1, output: class1} - * where sample1 is a feature-value hash: {feature1: value1, feature2: value2, ...} + * where sample1 is a feature-value hash: {feature1: value1, feature2: value2, ...} */ incDocCounts: function(samples, callback) { // accumulate all the pending increments