Skip to content

Commit

Permalink
Merge 2d51e31 into dda13d2
Browse files Browse the repository at this point in the history
  • Loading branch information
Eight1911 committed Jul 5, 2018
2 parents dda13d2 + 2d51e31 commit 3c8d788
Show file tree
Hide file tree
Showing 14 changed files with 537 additions and 412 deletions.
4 changes: 4 additions & 0 deletions src/DecisionTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ convert(::Type{Node}, x::Leaf) = Node(0, nothing, x, Leaf(nothing,[nothing]))
promote_rule(::Type{Node}, ::Type{Leaf}) = Node
promote_rule(::Type{Leaf}, ::Type{Node}) = Node

"""
    mean(l)

Return the arithmetic mean of the elements of `l`, computed as
`sum(l) / length(l)`.
"""
function mean(l)
    total = sum(l)
    return total / length(l)
end

##############################
########## Includes ##########

Expand Down
96 changes: 44 additions & 52 deletions src/classification/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,79 +38,71 @@ end

################################################################################

"""
    _split_neg_z1_loss(labels, features, weights)

Search every feature column and every candidate threshold for the split
`features[:, i] .< thresh` that maximizes the sum of `_neg_z1_loss` over the
two sides of the split.

# Arguments
- `labels`: one label per row of `features`.
- `features`: matrix whose column `i` holds the values of feature `i`.
- `weights`: one weight per row, forwarded to `_neg_z1_loss` with the labels.

# Returns
`(feature_index, threshold)` of the best split, or `NO_BEST` when no candidate
scored above `-Inf` (for example when every column is constant, so there is no
threshold to try). The comparison is strict, so ties keep the first split found.
"""
function _split_neg_z1_loss(labels::Vector, features::Matrix, weights::Vector)
    best = NO_BEST
    best_val = -Inf
    for i in 1:size(features, 2)
        # Reference the column once per feature instead of copying it for
        # every threshold.
        column = view(features, :, i)
        # `unique` already returns a fresh vector, so sorting it in place is safe.
        domain_i = sort!(unique(column))
        # The smallest value is skipped: `column .< minimum` selects nothing.
        for thresh in Iterators.drop(domain_i, 1)
            cur_split = column .< thresh
            # Build the complementary mask once and reuse it for both lookups.
            rest = .!cur_split
            value = _neg_z1_loss(labels[cur_split], weights[cur_split]) +
                    _neg_z1_loss(labels[rest], weights[rest])
            if value > best_val
                best_val = value
                best = (i, thresh)
            end
        end
    end
    return best
end

"""
    build_stump(labels, features, weights=[0]; rng=Random.GLOBAL_RNG)

Fit a tree of depth at most one with `treeclassifier.fit_zero_one` and convert
the result into this module's `Leaf` / `Node` representation.

# Arguments
- `labels`: one label per row of `features`.
- `features`: matrix of feature values, one column per feature.
- `weights`: per-sample weights. The default `[0]` is a sentinel meaning
  "no weights supplied"; it is replaced by `nothing` before fitting.

# Keywords
- `rng`: forwarded unchanged to `treeclassifier.fit_zero_one`.

# Returns
A `Leaf` when the fitted root is a leaf, otherwise a `Node` holding the chosen
feature and threshold with the converted left and right subtrees.
"""
function build_stump(labels::Vector, features::Matrix, weights=[0];
                     rng=Random.GLOBAL_RNG)
    # Translate the "no weights" sentinel into the value passed to the fitter.
    if weights == [0]
        weights = nothing
    end

    # All features are considered and the depth is capped at 1, so the result
    # has at most a single split.
    t = treeclassifier.fit_zero_one(
        X = features,
        Y = labels,
        W = weights,
        max_features = size(features, 2),
        max_depth = 1,
        min_samples_leaf = 1,
        min_samples_split = 2,
        min_purity_increase = 0.0,
        rng = rng)

    # Recursively turn the fitter's node structure into Leaf/Node values.
    # NOTE(review): presumably `node.region` indexes into the reordered label
    # vector passed as `labels` -- confirm against `treeclassifier.NodeMeta`.
    function _convert(node::treeclassifier.NodeMeta, labels_list::Array, labels::Array)
        if node.is_leaf
            return Leaf(labels_list[node.label], labels[node.region])
        else
            left = _convert(node.l, labels_list, labels)
            right = _convert(node.r, labels_list, labels)
            return Node(node.feature, node.threshold, left, right)
        end
    end

    # `labels[t.labels]` reorders the labels with the index vector produced by
    # the fitter, so each leaf can take its samples as a slice.
    return _convert(t.root, t.list, labels[t.labels])
end

"""
    build_tree(labels, features, n_subfeatures=0, max_depth=-1,
               min_samples_leaf=1, min_samples_split=2, min_purity_increase=0.0;
               rng=Random.GLOBAL_RNG)

Fit a classification tree with `treeclassifier.fit` and convert the result
into this module's `Leaf` / `Node` representation.

# Arguments
- `labels`: one label per row of `features`.
- `features`: matrix of feature values, one column per feature.
- `n_subfeatures`: passed as `max_features`; `0` means use every column of
  `features`.
- `max_depth`: maximum depth; `-1` means unlimited (mapped to `typemax(Int64)`).
- `min_samples_leaf`, `min_samples_split`: converted to `Int64` and forwarded.
- `min_purity_increase`: converted to `Float64` and forwarded.

# Keywords
- `rng`: normalized through `mk_rng` before being forwarded to the fitter.

# Returns
A `Leaf` when the fitted root is a leaf, otherwise a `Node` holding the chosen
feature and threshold with the converted left and right subtrees.

# Throws
An error (via `error`) when `max_depth < -1`.
"""
function build_tree(labels::Vector, features::Matrix, n_subfeatures=0, max_depth=-1,
                    min_samples_leaf=1, min_samples_split=2, min_purity_increase=0.0;
                    rng=Random.GLOBAL_RNG)
    if max_depth < -1
        error("Unexpected value for max_depth: $(max_depth) (expected: max_depth >= 0, or max_depth = -1 for infinite depth)")
    end

    # -1 is the "no limit" sentinel.
    if max_depth == -1
        max_depth = typemax(Int64)
    end

    # 0 is the "use all features" sentinel.
    if n_subfeatures == 0
        n_subfeatures = size(features, 2)
    end

    rng = mk_rng(rng)::Random.AbstractRNG
    t = treeclassifier.fit(
        X = features,
        Y = labels,
        W = nothing,
        max_features = n_subfeatures,
        max_depth = max_depth,
        min_samples_leaf = Int64(min_samples_leaf),
        min_samples_split = Int64(min_samples_split),
        min_purity_increase = Float64(min_purity_increase),
        rng = rng)

    # Recursively turn the fitter's node structure into Leaf/Node values.
    # NOTE(review): presumably `node.region` indexes into the reordered label
    # vector passed as `labels` -- confirm against `treeclassifier.NodeMeta`.
    function _convert(node::treeclassifier.NodeMeta, labels_list::Array, labels::Array)
        if node.is_leaf
            return Leaf(labels_list[node.label], labels[node.region])
        else
            left = _convert(node.l, labels_list, labels)
            right = _convert(node.r, labels_list, labels)
            return Node(node.feature, node.threshold, left, right)
        end
    end

    # `labels[t.labels]` reorders the labels with the index vector produced by
    # the fitter, so each leaf can take its samples as a slice.
    return _convert(t.root, t.list, labels[t.labels])
end

function prune_tree(tree::LeafOrNode, purity_thresh=1.0)
Expand Down

0 comments on commit 3c8d788

Please sign in to comment.