Navigation Menu

Skip to content

Commit

Permalink
Fix bayesian classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
erelsgl committed Jan 30, 2017
1 parent 5ddd122 commit ee958aa
Showing 1 changed file with 50 additions and 40 deletions.
90 changes: 50 additions & 40 deletions classifiers/bayesian/bayesian.js
Expand Up @@ -2,9 +2,9 @@ var _ = require("underscore")._;

/**
* A multi-class single-label Bayes classifier.
*
*
* @author Erel Segal-Halevi based on code by Heather Arthur (https://github.com/harthur/classifier)
*
*
* @param options
*/
var Bayesian = function(options) {
Expand Down Expand Up @@ -37,7 +37,7 @@ Bayesian.prototype = {

/**
* Tell the classifier that the given document belongs to the given category.
* @param document [string] a training sample - a feature-value hash: {feature1: value1, feature2: value2, ...}
* @param document [string] a training sample - a feature-value hash: {feature1: value1, feature2: value2, ...}
* @param category [string] the correct category of this sample.
*/
trainOnline: function(document, category) {
Expand All @@ -47,12 +47,12 @@ Bayesian.prototype = {
/**
* Train the classifier with all the given documents.
* @param data an array with objects of the format: {input: sample1, output: category1}
* where sample1 is a feature-value hash: {feature1: value1, feature2: value2, ...}
* where sample1 is a feature-value hash: {feature1: value1, feature2: value2, ...}
*/
trainBatch: function(data) {
this.incDocCounts(data);
},

/**
* Ask the classifier what category the given document belongs to.
* @param document a hash {feature1: value1, feature2: value2, ...}
Expand All @@ -64,7 +64,7 @@ Bayesian.prototype = {
throw new Error("document should be a feature-value hash, but it is "+JSON.stringify(document));
}
var probs = this.getProbsSync(document);

var max = this.bestMatch(probs);
if (explain>0) {
return {
Expand All @@ -82,28 +82,29 @@ Bayesian.prototype = {
* A subroutine used for classification.
* Gets the probabilities of the words in the given sentence.
* @param document a hash {feature1: value1, feature2: value2, ...}
* Values are numeric and represent number of occurences.
*/
getProbsSync: function(document) {
var words = Object.keys(document); // an array with the unique words in the text, for example: [ 'free', 'watches' ]
var cats = this.getCats(); // a hash with the possible categories: { 'cat1': 1, 'cat2': 1 }
var counts = this.getWordCounts(words, cats); // For each word encountered during training, the counts of times it occurred in each category.
var probs = this.getCatProbs(cats, words, counts); // The probabilities that the given document belongs to each of the categories, i.e.: { 'cat1': 0.1875, 'cat2': 0.0625 }

var counts = this.getWordCounts(Object.keys(document), cats); // For each word encountered during training, the counts of times it occurred in each category.

var probs = this.getCatProbs(cats, document, counts); // The probabilities that the given document belongs to each of the categories, i.e.: { 'cat1': 0.1875, 'cat2': 0.0625 }

if (this.normalizeOutputProbabilities) {
var sum = _(probs).reduce(function(memo, num) { return memo + num; }, 0);
for (var cat in probs)
probs[cat] = probs[cat]/sum;
}

var pairs = _.pairs(probs); // pairs of [category,probability], for all categories that appeared in the training set.
//console.dir(pairs);
if (pairs.length==0) {
return {category: this.default, probability: 0};
}
pairs.sort(function(a,b) { // sort by decreasing prob
pairs.sort(function(a,b) { // sort by decreasing prob
return b[1]-a[1];
});

return pairs;
},

Expand All @@ -115,23 +116,23 @@ Bayesian.prototype = {
bestMatch: function(pairs) {
var maxCategory = pairs[0][0];
var maxProbability = pairs[0][1];

if (pairs.length>1) {
var nextProbability = pairs[1][1];
var threshold = this.thresholds[maxCategory] || this.globalThreshold;
if (nextProbability * threshold > maxProbability)
maxCategory = this.default; // not greater than other category by enough
if (this.calculateRelativeProbabilities)
if (this.calculateRelativeProbabilities)
maxProbability /= nextProbability;
}

return {
category: maxCategory,
probability: maxProbability
};
};
},


toJSON: function(callback) {
return this.backend.toJSON(callback);
},
Expand All @@ -143,55 +144,64 @@ Bayesian.prototype = {
getCats: function(callback) {
return this.backend.getCats(callback);
},



/*
*
* Internal functions (should be private):
*
* Internal functions (should be private):
*
*/
wordProb: function(word, cat, cats, counts) {

wordProb: function(word, cat, cats, wordCounts) {
// times word appears in a doc in this cat / docs in this cat
var prob = (counts[cat] || 0) / cats[cat];
var probWordGivenCat = (wordCounts[cat] || 0) / cats[cat];

// get weighted average with assumed so prob won't be extreme on rare words
var total = _(cats).reduce(function(sum, p, cat) {
return sum + (counts[cat] || 0);
var totalWordCount = _(cats).reduce(function(sum, p, cat) {
return sum + (wordCounts[cat] || 0);
}, 0, this);
return (this.weight * this.assumed + total * prob) / (this.weight + total);
// get weighted average with assumed so prob won't be extreme on rare words
var modifiedProbGivenCat = (this.weight * this.assumed + totalWordCount * probWordGivenCat) / (this.weight + totalWordCount);

//console.log("word="+word+" cat="+cat+" probWordGivenCat="+probWordGivenCat+" totalWordCount="+totalWordCount+" modifiedProbGivenCat="+modifiedProbGivenCat)
return modifiedProbGivenCat
},

getCatProbs: function(cats, words, counts) {
getCatProbs: function(cats, document, counts) {
var numDocs = _(cats).reduce(function(sum, count) {
return sum + count;
}, 0);
}, 0); // total number of training samples in all categories

var probs = {};
_(cats).each(function(catCount, cat) {
var catProb = (catCount || 0) / numDocs;
var catPriorProb = (catCount || 0) / numDocs;

var docProb = _(words).reduce(function(prob, word) {
// The probability to see a document is the product
// of the probability to see each word in the document.
var docProb = _(Object.keys(document)).reduce(function(prob, word) {
var wordCounts = counts[word] || {};
return prob * this.wordProb(word, cat, cats, wordCounts);
var probWordGivenCat = this.wordProb(word, cat, cats, wordCounts);
var probWordsGivenCat = Math.pow(probWordGivenCat, document[word]);
//console.log("probWordGivenCat="+probWordGivenCat+" probWordsGivenCat="+probWordsGivenCat+" document[word]="+document[word])
return prob * probWordsGivenCat;
}, 1, this);
//console.log("docProb="+docProb)

// the probability this doc is in this category
probs[cat] = catProb * docProb;
probs[cat] = catPriorProb * docProb;
}, this);
return probs;
},


getWordCounts: function(words, cats, callback) {
return this.backend.getWordCounts(words, cats, callback);
},

/**
* Increment the feature counts.
* Increment the feature counts.
* @param data an array with objects of the format: {input: sample1, output: class1}
* where sample1 is a feature-value hash: {feature1: value1, feature2: value2, ...}
* where sample1 is a feature-value hash: {feature1: value1, feature2: value2, ...}
*/
incDocCounts: function(samples, callback) {
// accumulate all the pending increments
Expand Down

0 comments on commit ee958aa

Please sign in to comment.