Skip to content

Commit

Permalink
Merge 5d442d8 into 579062f
Browse files (browse the repository at this point in the history)
  • Loading branch information
baggepinnen committed Feb 6, 2020
2 parents (579062f + 5d442d8), commit 81a8d51
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 9 deletions.
7 changes: 5 additions & 2 deletions src/classification/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,9 @@ function build_tree(
min_samples_leaf = 1,
min_samples_split = 2,
min_purity_increase = 0.0;
weights::Union{Nothing,AbstractVector{U}} = nothing,
loss = util.entropy :: Function,
rng = Random.GLOBAL_RNG) where {S, T}
rng = Random.GLOBAL_RNG) where {S, T, U <: Integer}

if max_depth == -1
max_depth = typemax(Int)
Expand All @@ -93,7 +94,7 @@ function build_tree(
t = treeclassifier.fit(
X = features,
Y = labels,
W = nothing,
W = weights,
loss = loss,
max_features = Int(n_subfeatures),
max_depth = Int(max_depth),
Expand Down Expand Up @@ -195,6 +196,7 @@ function build_forest(
min_samples_leaf = 1,
min_samples_split = 2,
min_purity_increase = 0.0;
weights = nothing,
rng = Random.GLOBAL_RNG) where {S, T}

if n_trees < 1
Expand Down Expand Up @@ -229,6 +231,7 @@ function build_forest(
min_samples_leaf,
min_samples_split,
min_purity_increase,
weights = (weights === nothing ? nothing : weights[inds]),
loss = loss,
rng = rngs)
end
Expand Down
4 changes: 2 additions & 2 deletions src/classification/tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ module treeclassifier
function fit(;
X :: Matrix{S},
Y :: Vector{T},
W :: Union{Nothing, Vector{U}},
W :: Union{Nothing, AbstractVector{U}},
loss=util.entropy :: Function,
max_features :: Int,
max_depth :: Int,
Expand All @@ -318,7 +318,7 @@ module treeclassifier

n_samples, n_features = size(X)
list, Y_ = util.assign(Y)
if W == nothing
if W === nothing
W = fill(1, n_samples)
end

Expand Down
5 changes: 4 additions & 1 deletion src/regression/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ function build_tree(
min_samples_leaf = 5,
min_samples_split = 2,
min_purity_increase = 0.0;
weights = nothing,
rng = Random.GLOBAL_RNG) where {S, T <: Float64}

if max_depth == -1
Expand All @@ -35,7 +36,7 @@ function build_tree(
t = treeregressor.fit(
X = features,
Y = labels,
W = nothing,
W = weights,
max_features = Int(n_subfeatures),
max_depth = Int(max_depth),
min_samples_leaf = Int(min_samples_leaf),
Expand All @@ -56,6 +57,7 @@ function build_forest(
min_samples_leaf = 5,
min_samples_split = 2,
min_purity_increase = 0.0;
weights = nothing,
rng = Random.GLOBAL_RNG) where {S, T <: Float64}

if n_trees < 1
Expand Down Expand Up @@ -86,6 +88,7 @@ function build_forest(
min_samples_leaf,
min_samples_split,
min_purity_increase,
weights = (weights === nothing ? nothing : weights[inds]),
rng = rngs)
end

Expand Down
12 changes: 8 additions & 4 deletions src/scikitlearnAPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ get_classes(dt::DecisionTreeClassifier) = dt.classes
[:pruning_purity_threshold, :max_depth, :min_samples_leaf,
:min_samples_split, :min_purity_increase, :rng])

function fit!(dt::DecisionTreeClassifier, X, y)
function fit!(dt::DecisionTreeClassifier, X, y, weights=nothing)
n_samples, n_features = size(X)
dt.root = build_tree(
y, X,
Expand All @@ -58,6 +58,7 @@ function fit!(dt::DecisionTreeClassifier, X, y)
dt.min_samples_leaf,
dt.min_samples_split,
dt.min_purity_increase;
weights = weights,
rng = dt.rng)

dt.root = prune_tree(dt.root, dt.pruning_purity_threshold)
Expand Down Expand Up @@ -136,7 +137,7 @@ end
[:pruning_purity_threshold, :min_samples_leaf, :n_subfeatures,
:max_depth, :min_samples_split, :min_purity_increase, :rng])

function fit!(dt::DecisionTreeRegressor, X::Matrix, y::Vector)
function fit!(dt::DecisionTreeRegressor, X::Matrix, y::Vector, weights=nothing)
n_samples, n_features = size(X)
dt.root = build_tree(
float.(y), X,
Expand All @@ -145,6 +146,7 @@ function fit!(dt::DecisionTreeRegressor, X::Matrix, y::Vector)
dt.min_samples_leaf,
dt.min_samples_split,
dt.min_purity_increase;
weights = weights,
rng = dt.rng)
dt.pruning_purity_threshold
dt.root = prune_tree(dt.root, dt.pruning_purity_threshold)
Expand Down Expand Up @@ -213,7 +215,7 @@ get_classes(rf::RandomForestClassifier) = rf.classes
:min_samples_leaf, :min_samples_split, :min_purity_increase,
:rng])

function fit!(rf::RandomForestClassifier, X::Matrix, y::Vector)
function fit!(rf::RandomForestClassifier, X::Matrix, y::Vector, weights=nothing)
n_samples, n_features = size(X)
rf.ensemble = build_forest(
y, X,
Expand All @@ -224,6 +226,7 @@ function fit!(rf::RandomForestClassifier, X::Matrix, y::Vector)
rf.min_samples_leaf,
rf.min_samples_split,
rf.min_purity_increase;
weights = weights,
rng = rf.rng)
rf.classes = sort(unique(y))
rf
Expand Down Expand Up @@ -297,7 +300,7 @@ end
# since it'll change throughout fitting, but it works
:max_depth, :rng])

function fit!(rf::RandomForestRegressor, X::Matrix, y::Vector)
function fit!(rf::RandomForestRegressor, X::Matrix, y::Vector, weights=nothing)
n_samples, n_features = size(X)
rf.ensemble = build_forest(
float.(y), X,
Expand All @@ -308,6 +311,7 @@ function fit!(rf::RandomForestRegressor, X::Matrix, y::Vector)
rf.min_samples_leaf,
rf.min_samples_split,
rf.min_purity_increase;
weights = weights,
rng = rf.rng)
rf
end
Expand Down

0 comments on commit 81a8d51

Please sign in to comment.