Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Merge pull request #8 from rustyio/fix-infinite-recursion

Fix infinite recursion.
  • Loading branch information...
commit 7bf93a3a5e49318259ba610d98ee33f30e887dcb 2 parents a714dce + b451b32
@igrigorik authored
Showing with 21 additions and 1 deletion.
  1. +6 −1 lib/decisiontree/id3_tree.rb
  2. +15 −0 spec/id3_spec.rb
View
7 lib/decisiontree/id3_tree.rb
@@ -69,9 +69,14 @@ def id3_train(data, attributes, default, used={})
# return classification if all examples have the same classification
return data.first.last if data.classification.uniq.size == 1
- # Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
+ # Choose best attribute:
+ # 1. enumerate all attributes
+ # 2. Pick best attribute
+ # 3. If attributes all score the same, then pick a random one to avoid infinite recursion.
performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
max = performance.max { |a,b| a[0] <=> b[0] }
+ min = performance.min { |a,b| a[0] <=> b[0] }
+ max = performance.shuffle.first if max[0] == min[0]
best = Node.new(attributes[performance.index(max)], max[1], max[0])
best.threshold = nil if @type == :discrete
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
View
15 spec/id3_spec.rb
@@ -74,4 +74,19 @@
Then { tree.predict([2, "blue"]).should == "not angry" }
end
+ describe "infinite recursion case" do
+ Given(:labels) { [:a, :b, :c] }
+ Given(:data) do
+ [
+ ["a1", "b0", "c0", "RED"],
+ ["a1", "b1", "c1", "RED"],
+ ["a1", "b1", "c0", "BLUE"],
+ ["a1", "b0", "c1", "BLUE"]
+ ]
+ end
+ Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "RED", :discrete) }
+ When { tree.train }
+ Then { tree.predict(["a1","b0","c0"]).should == "RED" }
+ end
+
end
Please sign in to comment.
Something went wrong with that request. Please try again.