diff --git a/lib/classifier-reborn/bayes.rb b/lib/classifier-reborn/bayes.rb index 617a1ca..ce032e4 100644 --- a/lib/classifier-reborn/bayes.rb +++ b/lib/classifier-reborn/bayes.rb @@ -14,6 +14,7 @@ def initialize(*categories) categories.each { |category| @categories[CategoryNamer.prepare_name(category)] = Hash.new } @total_words = 0 @category_counts = Hash.new(0) + @category_word_count = Hash.new end # Provides a general training method for all categories specified in Bayes#new @@ -24,10 +25,12 @@ def initialize(*categories) # b.train "The other", "The other text" def train(category, text) category = CategoryNamer.prepare_name(category) - @category_counts[category] += 1 + @category_word_count[category] ||= 0 + @category_counts[category] += 1 Hasher.word_hash(text).each do |word, count| @categories[category][word] ||= 0 @categories[category][word] += count + @category_word_count[category] += count @total_words += count end end @@ -41,16 +44,21 @@ def train(category, text) # b.untrain :this, "This text" def untrain(category, text) category = CategoryNamer.prepare_name(category) + @category_word_count[category] ||= 0 @category_counts[category] -= 1 Hasher.word_hash(text).each do |word, count| if @total_words >= 0 orig = @categories[category][word] || 0 - @categories[category][word] ||= 0 - @categories[category][word] -= count + @categories[category][word] ||= 0 + @categories[category][word] -= count if @categories[category][word] <= 0 @categories[category].delete(word) count = orig end + + if @category_word_count[category] >= count + @category_word_count[category] -= count + end @total_words -= count end end @@ -62,13 +70,14 @@ def untrain(category, text) # The largest of these scores (the one closest to 0) is the one picked out by #classify def classifications(text) score = Hash.new - training_count = @category_counts.values.reduce(0, :+).to_f + word_hash = Hasher.word_hash(text) + training_count = @category_counts.values.reduce(:+).to_f @categories.each do |category, category_words| score[category.to_s] = 0 - total = category_words.values.reduce(0, :+) - Hasher.word_hash(text).each do |word, count| + total = (@category_word_count[category] || 1).to_f + word_hash.each do |word, count| s = category_words.has_key?(word) ? category_words[word] : 0.1 - score[category.to_s] += Math.log(s/total.to_f) + score[category.to_s] += Math.log(s/total) end # now add prior probability for the category s = @category_counts.has_key?(category) ? @category_counts[category] : 0.1