Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

bayes

  • Loading branch information...
commit 789d6d91ef6347229c73e584f775584a02d53720 1 parent ec23813
Mischa Fierer authored

Showing 1 changed file with 54 additions and 0 deletions. Show diff stats Hide diff stats

  1. +54 0 6/docclass.rb
54 6/docclass.rb
@@ -11,6 +11,15 @@ def initialize(getfeatures, filename=nil)
11 11 @fc = Hash.new{|h,k| h[k] = Hash.new{|h1,k1| h1[k1] = 0}} #hash of hashes of 0
12 12 @cc = Hash.new{|h,k| h[k] = 0}
13 13 @getfeatures = getfeatures
  14 + @thresholds = Hash.new{|h,k| h[k] = 0}
  15 + end
  16 +
  17 + def setthreshold(cat, t)
  18 + @thresholds[cat] = t
  19 + end
  20 +
  21 + def getthreshold(cat)
  22 + @thresholds[cat] || 0
14 23 end
15 24
16 25 def incf(f, cat)
@@ -56,6 +65,29 @@ def weightedprob(f, cat, prf, weight=1.0, ap=0.5)
56 65
57 66 ((weight * ap ) + (totals * basicprob)) / (weight + totals)
58 67 end
  68 +
  69 + def classify(item, default='unknown')
  70 + probs = {}
  71 +
  72 + max = 0.0
  73 + best = 'hi'
  74 +
  75 + categories.each do |cat|
  76 + probs[cat] = prob(item,cat)
  77 + if probs[cat] > max
  78 + max = probs[cat]
  79 + best = cat
  80 + end
  81 + end
  82 +
  83 + probs.each do |cat, val|
  84 + next if cat == best
  85 + return default if probs[cat] * getthreshold(best) > probs[best]
  86 + end
  87 +
  88 + best
  89 + end
  90 +
59 91 end
60 92
61 93 def sampletrain(cl)
@@ -95,6 +127,7 @@ def test
95 127 test2
96 128 test3
97 129 test4
  130 + test5
98 131 end
99 132
100 133 def test1
@@ -136,4 +169,25 @@ def test4
136 169 puts("should be 0.05000...3")
137 170 end
138 171
  172 +def test5
  173 + cl = NaiveBayes.new(@@getfeatures)
  174 + sampletrain(cl)
  175 +
  176 + p cl.classify('quick rabbit')
  177 + puts "should be good"
  178 +
  179 + p cl.classify('quick money')
  180 + puts "should be bad"
  181 +
  182 + cl.setthreshold('bad', 3.0)
  183 +
  184 + p cl.classify('quick money')
  185 + puts "should be unknown"
  186 +
  187 + 10.times { sampletrain(cl) }
  188 +
  189 + p cl.classify('quick money')
  190 + puts "should be bad"
  191 +
  192 +end
139 193

0 comments on commit 789d6d9

Please sign in to comment.
Something went wrong with that request. Please try again.