Skip to content

Commit

Permalink
Move check_input to utils to avoid code-duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
barucden committed May 17, 2021
1 parent c0184f4 commit c516bea
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 65 deletions.
33 changes: 1 addition & 32 deletions src/classification/tree.jl
Expand Up @@ -225,37 +225,6 @@ module treeclassifier
node.r = NodeMeta{S}(features, region[ind+1:end], node.depth+1)
end

function check_input(
X :: AbstractMatrix{S},
Y :: AbstractVector{T},
W :: AbstractVector{U},
max_features :: Int,
max_depth :: Int,
min_samples_leaf :: Int,
min_samples_split :: Int,
min_purity_increase :: Float64) where {S, T, U}
n_samples, n_features = size(X)
if length(Y) != n_samples
throw("dimension mismatch between X and Y ($(size(X)) vs $(size(Y))")
elseif length(W) != n_samples
throw("dimension mismatch between X and W ($(size(X)) vs $(size(W))")
elseif max_depth < -1
throw("unexpected value for max_depth: $(max_depth) (expected:"
* " max_depth >= 0, or max_depth = -1 for infinite depth)")
elseif n_features < max_features
throw("number of features $(n_features) is less than the number "
* "of max features $(max_features)")
elseif max_features < 0
throw("number of features $(max_features) must be >= zero ")
elseif min_samples_leaf < 1
throw("min_samples_leaf must be a positive integer "
* "(given $(min_samples_leaf))")
elseif min_samples_split < 2
throw("min_samples_split must be at least 2 "
* "(given $(min_samples_split))")
end
end

function _fit(
X :: AbstractMatrix{S},
Y :: AbstractVector{Int},
Expand Down Expand Up @@ -321,7 +290,7 @@ module treeclassifier
W = fill(1, n_samples)
end

check_input(
util.check_input(
X, Y, W,
max_features,
max_depth,
Expand Down
33 changes: 1 addition & 32 deletions src/regression/tree.jl
Expand Up @@ -228,37 +228,6 @@ module treeregressor
node.r = NodeMeta{S}(features, region[ind+1:end], node.depth + 1)
end

function check_input(
X :: AbstractMatrix{S},
Y :: AbstractVector{T},
W :: AbstractVector{U},
max_features :: Int,
max_depth :: Int,
min_samples_leaf :: Int,
min_samples_split :: Int,
min_purity_increase :: Float64) where {S, T, U}
n_samples, n_features = size(X)
if length(Y) != n_samples
throw("dimension mismatch between X and Y ($(size(X)) vs $(size(Y))")
elseif length(W) != n_samples
throw("dimension mismatch between X and W ($(size(X)) vs $(size(W))")
elseif max_depth < -1
throw("unexpected value for max_depth: $(max_depth) (expected:"
* " max_depth >= 0, or max_depth = -1 for infinite depth)")
elseif n_features < max_features
throw("number of features $(n_features) is less than the number "
* "of max features $(max_features)")
elseif max_features < 0
throw("number of features $(max_features) must be >= zero ")
elseif min_samples_leaf < 1
throw("min_samples_leaf must be a positive integer "
* "(given $(min_samples_leaf))")
elseif min_samples_split < 2
throw("min_samples_split must be at least 2 "
* "(given $(min_samples_split))")
end
end

function _fit(
X :: AbstractMatrix{S},
Y :: AbstractVector{Float64},
Expand Down Expand Up @@ -318,7 +287,7 @@ module treeregressor
W = fill(1.0, n_samples)
end

check_input(
util.check_input(
X,
Y,
W,
Expand Down
34 changes: 33 additions & 1 deletion src/util.jl
Expand Up @@ -3,7 +3,7 @@

module util

export gini, entropy, zero_one, q_bi_sort!, hypergeometric
export gini, entropy, zero_one, q_bi_sort!, hypergeometric, check_input

function assign(Y :: AbstractVector{T}, list :: AbstractVector{T}) where T
dict = Dict{T, Int}()
Expand Down Expand Up @@ -297,5 +297,37 @@ module util
end
end

function check_input(
X :: AbstractMatrix{S},
Y :: AbstractVector{T},
W :: AbstractVector{U},
max_features :: Int,
max_depth :: Int,
min_samples_leaf :: Int,
min_samples_split :: Int,
min_purity_increase :: Float64) where {S, T, U}
n_samples, n_features = size(X)
if length(Y) != n_samples
throw("dimension mismatch between X and Y ($(size(X)) vs $(size(Y))")
elseif length(W) != n_samples
throw("dimension mismatch between X and W ($(size(X)) vs $(size(W))")
elseif max_depth < -1
throw("unexpected value for max_depth: $(max_depth) (expected:"
* " max_depth >= 0, or max_depth = -1 for infinite depth)")
elseif n_features < max_features
throw("number of features $(n_features) is less than the number "
* "of max features $(max_features)")
elseif max_features < 0
throw("number of features $(max_features) must be >= zero ")
elseif min_samples_leaf < 1
throw("min_samples_leaf must be a positive integer "
* "(given $(min_samples_leaf))")
elseif min_samples_split < 2
throw("min_samples_split must be at least 2 "
* "(given $(min_samples_split))")
end
end


end

0 comments on commit c516bea

Please sign in to comment.