Browse files

added support for continuous and discrete attributes in the same dataset

  • Loading branch information...
1 parent 3e3badc commit d1dce9be91dac7cde4c14995f809f15ab9e66faf @superchris superchris committed Oct 26, 2012
Showing with 150 additions and 36 deletions.
  1. +19 −0 ..gemspec
  2. +17 −0 .gitignore
  3. +4 −0 Gemfile
  4. +4 −1 decisiontree.gemspec
  5. +39 −35 lib/decisiontree/id3_tree.rb
  6. +64 −0 spec/id3_spec.rb
  7. +3 −0 spec/spec_helper.rb
View
19 ..gemspec
@@ -0,0 +1,19 @@
+# -*- encoding: utf-8 -*-
+lib = File.expand_path('../lib', __FILE__)
+$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
+require './version'
+
+Gem::Specification.new do |gem|
+ gem.name = "."
+ gem.version = .::VERSION
+ gem.authors = ["Chris Nelson"]
+ gem.email = ["chris@gaslightsoftware.com"]
+ gem.description = %q{TODO: Write a gem description}
+ gem.summary = %q{TODO: Write a gem summary}
+ gem.homepage = ""
+
+ gem.files = `git ls-files`.split($/)
+ gem.executables = gem.files.grep(%r{^bin/}).map{ |f| File.basename(f) }
+ gem.test_files = gem.files.grep(%r{^(test|spec|features)/})
+ gem.require_paths = ["lib"]
+end
View
17 .gitignore
@@ -0,0 +1,17 @@
+*.gem
+*.rbc
+.bundle
+.config
+.yardoc
+Gemfile.lock
+InstalledFiles
+_yardoc
+coverage
+doc/
+lib/bundler/man
+pkg
+rdoc
+spec/reports
+test/tmp
+test/version_tmp
+tmp
View
4 Gemfile
@@ -0,0 +1,4 @@
+source 'https://rubygems.org'
+
+# Specify your gem's dependencies in decisiontree.gemspec
+gemspec
View
5 decisiontree.gemspec
@@ -47,7 +47,10 @@ Gem::Specification.new do |s|
"examples/simple.rb"
]
s.add_runtime_dependency "graphr"
-
+ s.add_development_dependency "rspec"
+ s.add_development_dependency "rspec-given"
+ s.add_development_dependency "pry"
+
if s.respond_to? :specification_version then
current_version = Gem::Specification::CURRENT_SPECIFICATION_VERSION
s.specification_version = 3
View
74 lib/decisiontree/id3_tree.rb
@@ -15,9 +15,9 @@ def self.load_from_file(filename)
end
end
-class Array
- def classification; collect { |v| v.last }; end
-
+class Array
+ def classification; collect { |v| v.last }; end
+
# calculate information entropy
def entropy
return 0 if empty?
@@ -51,28 +51,34 @@ def train(data=@data, attributes=@attributes, default=@default)
@tree = id3_train(data2, attributes, default)
end
-
- def id3_train(data, attributes, default, used={})
- # Choose a fitness algorithm
- case @type
- when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
+
+ def type(attribute)
+ @type.is_a?(Hash) ? @type[attribute.to_sym] : @type
+ end
+
+ def fitness_for(attribute)
+ case type(attribute)
+ when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
when :continuous; fitness = proc{|a,b,c| id3_continuous(a,b,c)}
end
-
- return default if data.empty?
+ end
+
+ def id3_train(data, attributes, default, used={})
+ return default if data.empty?
# return classification if all examples have the same classification
return data.first.last if data.classification.uniq.size == 1
# Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
- performance = attributes.collect { |attribute| fitness.call(data, attributes, attribute) }
+ performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
max = performance.max { |a,b| a[0] <=> b[0] }
best = Node.new(attributes[performance.index(max)], max[1], max[0])
best.threshold = nil if @type == :discrete
- @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
+ @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
tree, l = {best => {}}, ['>=', '<']
-
- case @type
+
+ fitness = fitness_for(best.attribute)
+ case type(best.attribute)
when :continuous
data.partition { |d| d[attributes.index(best.attribute)] >= best.threshold }.each_with_index { |examples, i|
tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
@@ -82,7 +88,7 @@ def id3_train(data, attributes, default, used={})
partitions = values.collect { |val| data.select { |d| d[attributes.index(best.attribute)] == val } }
partitions.each_with_index { |examples, i|
tree[best][values[i]] = id3_train(examples, attributes-[values[i]], (data.classification.mode rescue 0), &fitness)
- }
+ }
end
tree
@@ -96,32 +102,32 @@ def id3_continuous(data, attributes, attribute)
thresholds.pop
#thresholds -= used[attribute] if used.has_key? attribute
- gain = thresholds.collect { |threshold|
+ gain = thresholds.collect { |threshold|
sp = data.partition { |d| d[attributes.index(attribute)] >= threshold }
pos = (sp[0].size).to_f / data.size
neg = (sp[1].size).to_f / data.size
-
+
[data.classification.entropy - pos*sp[0].classification.entropy - neg*sp[1].classification.entropy, threshold]
}.max { |a,b| a[0] <=> b[0] }
return [-1, -1] if gain.size == 0
gain
end
-
+
# ID3 information gain for a discrete-valued attribute
def id3_discrete(data, attributes, attribute)
values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } }
remainder = partitions.collect {|p| (p.size.to_f / data.size) * p.classification.entropy}.inject(0) {|i,s| s+=i }
-
+
[data.classification.entropy - remainder, attributes.index(attribute)]
end
def predict(test)
- return (@type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test))
+ descend(@tree, test)
end
- def graph(filename)
+ def graph(filename)
dgp = DotGraphPrinter.new(build_tree)
dgp.write_to_file("#{filename}.png", "png")
end
@@ -151,22 +157,20 @@ def build_rules(tree=@tree)
end
private
- def descend_continuous(tree, test)
+ def descend(tree, test)
attr = tree.to_a.first
return @default if !attr
- return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
- return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold
- return descend_continuous(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
- return descend_continuous(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
- end
-
- def descend_discrete(tree, test)
- attr = tree.to_a.first
- return @default if !attr
- return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
- return descend_discrete(attr[1][test[@attributes.index(attr[0].attribute)]],test)
+ if type(attr.first.attribute) == :continuous
+ return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
+ return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold
+ return descend(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
+ return descend(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
+ else
+ return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
+ return descend(attr[1][test[@attributes.index(attr[0].attribute)]],test)
+ end
end
-
+
def build_tree(tree = @tree)
return [] unless tree.is_a?(Hash)
return [["Always", @default]] if tree.empty?
@@ -282,7 +286,7 @@ def to_s
def predict(test)
@rules.each do |r|
- prediction = r.predict(test)
+ prediction = r.predict(test)
return prediction, r.accuracy unless prediction.nil?
end
return @default, 0.0
View
64 spec/id3_spec.rb
@@ -0,0 +1,64 @@
+require 'spec_helper'
+
+describe describe DecisionTree::ID3Tree do
+
+ describe "discrete attributes" do
+ Given(:labels) { ["hungry", "color"] }
+ Given(:data) do
+ [
+ ["yes", "red", "angry"],
+ ["no", "blue", "not angry"],
+ ["yes", "blue", "not angry"],
+ ["no", "red", "not angry"]
+ ]
+ end
+ Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", :discrete) }
+ When { tree.train }
+ Then { tree.predict(["yes", "red"]).should == "angry" }
+ Then { tree.predict(["no", "red"]).should == "not angry" }
+ end
+
+ describe "discrete attributes" do
+ Given(:labels) { ["hunger", "happiness"] }
+ Given(:data) do
+ [
+ [8, 7, "angry"],
+ [6, 7, "angry"],
+ [7, 9, "angry"],
+ [7, 1, "not angry"],
+ [2, 9, "not angry"],
+ [3, 2, "not angry"],
+ [2, 3, "not angry"],
+ [1, 4, "not angry"]
+ ]
+ end
+ Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", :continuous) }
+ When { tree.train }
+ Then { tree.graph("continuous") }
+ Then { tree.predict([7, 7]).should == "angry" }
+ Then { tree.predict([2, 3]).should == "not angry" }
+ end
+
+ describe "a mixture" do
+ Given(:labels) { ["hunger", "color"] }
+ Given(:data) do
+ [
+ [8, "red", "angry"],
+ [6, "red", "angry"],
+ [7, "red", "angry"],
+ [7, "blue", "not angry"],
+ [2, "red", "not angry"],
+ [3, "blue", "not angry"],
+ [2, "blue", "not angry"],
+ [1, "red", "not angry"]
+ ]
+ end
+ Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", color: :discrete, hunger: :continuous) }
+ When { tree.train }
+ Then { tree.graph("continuous") }
+ Then { tree.predict([7, "red"]).should == "angry" }
+ Then { tree.predict([2, "blue"]).should == "not angry" }
+ end
+
+
+end
View
3 spec/spec_helper.rb
@@ -0,0 +1,3 @@
+require 'rspec/given'
+require 'decisiontree'
+require 'pry'

0 comments on commit d1dce9b

Please sign in to comment.