Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP

Loading…

Fix infinite recursion. #8

Merged
merged 2 commits into from

2 participants

@rustyio
require 'decisiontree'

labels = [:a, :b, :c]
data = [
  ["a1", "b0", "c0", "RED"],
  ["a1", "b1", "c1", "RED"],
  ["a1", "b1", "c0", "BLUE"],
  ["a1", "b0", "c1", "BLUE"]
]

tree = DecisionTree::ID3Tree.new(labels, data, "RED", :discrete)
tree.train

The preceding sample causes infinite recursion.

When all attributes result in the same fitness score, the code continuously chooses to partition on the first attribute. As you can see in the example above, this does nothing to partition the data.

While it may not be the provably best approach, this pull request solves the problem by checking if all attributes result in the same fitness score, and if so, choosing a best attribute randomly.

@igrigorik igrigorik merged commit 7bf93a3 into from
@igrigorik
Owner

Nice, thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Commits on Jan 15, 2013
  1. @rustyio

    Fix infinite recursion.

    rustyio authored
  2. @rustyio

    Only compare scores.

    rustyio authored
This page is out of date. Refresh to see the latest.
Showing with 21 additions and 1 deletion.
  1. +6 −1 lib/decisiontree/id3_tree.rb
  2. +15 −0 spec/id3_spec.rb
View
7 lib/decisiontree/id3_tree.rb
@@ -69,9 +69,14 @@ def id3_train(data, attributes, default, used={})
# return classification if all examples have the same classification
return data.first.last if data.classification.uniq.size == 1
- # Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
+ # Choose best attribute:
+ # 1. enumerate all attributes
+ # 2. Pick best attribute
+ # 3. If attributes all score the same, then pick a random one to avoid infinite recursion.
performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
max = performance.max { |a,b| a[0] <=> b[0] }
+ min = performance.min { |a,b| a[0] <=> b[0] }
+ max = performance.shuffle.first if max[0] == min[0]
best = Node.new(attributes[performance.index(max)], max[1], max[0])
best.threshold = nil if @type == :discrete
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
View
15 spec/id3_spec.rb
@@ -74,4 +74,19 @@
Then { tree.predict([2, "blue"]).should == "not angry" }
end
+ describe "infinite recursion case" do
+ Given(:labels) { [:a, :b, :c] }
+ Given(:data) do
+ [
+ ["a1", "b0", "c0", "RED"],
+ ["a1", "b1", "c1", "RED"],
+ ["a1", "b1", "c0", "BLUE"],
+ ["a1", "b0", "c1", "BLUE"]
+ ]
+ end
+ Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "RED", :discrete) }
+ When { tree.train }
+ Then { tree.predict(["a1","b0","c0"]).should == "RED" }
+ end
+
end
Something went wrong with that request. Please try again.