Skip to content

Commit

Permalink
add CV tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bensadeghi committed Sep 17, 2020
1 parent 6d73208 commit 971ebb9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 19 deletions.
41 changes: 22 additions & 19 deletions test/classification/adult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ preds = apply_tree(model, features)
cm = confusion_matrix(labels, preds)
@test cm.accuracy > 0.99

features = string.(features)
labels = string.(labels)

n_subfeatures = 3
n_trees = 5
model = build_forest(labels, features, n_subfeatures, n_trees)
Expand All @@ -23,24 +26,24 @@ preds = apply_adaboost_stumps(model, coeffs, features);
cm = confusion_matrix(labels, preds);
@test cm.accuracy > 0.8

# println("\n##### 3 foldCV Classification Tree #####")
# pruning_purity = 0.9
# nfolds = 3
# accuracy = nfoldCV_tree(labels, features, nfolds, pruning_purity);
# @test mean(accuracy) > 0.8

# println("\n##### 3 foldCV Classification Forest #####")
# n_subfeatures = 2
# n_trees = 10
# n_folds = 3
# partial_sampling = 0.5
# accuracy = nfoldCV_forest(labels, features, nfolds, n_subfeatures, n_trees, partial_sampling)
# @test mean(accuracy) > 0.8

# println("\n##### nfoldCV Classification Adaboosted Stumps #####")
# n_iterations = 15
# nfolds = 3
# accuracy = nfoldCV_stumps(labels, features, nfolds, n_iterations);
# @test mean(accuracy) > 0.8
println("\n##### 3 foldCV Classification Tree #####")
pruning_purity = 0.9
nfolds = 3
accuracy = nfoldCV_tree(labels, features, nfolds, pruning_purity; verbose=false);
@test mean(accuracy) > 0.8

println("\n##### 3 foldCV Classification Forest #####")
n_subfeatures = 2
n_trees = 10
n_folds = 3
partial_sampling = 0.5
accuracy = nfoldCV_forest(labels, features, n_folds, n_subfeatures, n_trees, partial_sampling; verbose=false)
@test mean(accuracy) > 0.8

println("\n##### nfoldCV Classification Adaboosted Stumps #####")
n_iterations = 15
n_folds = 3
accuracy = nfoldCV_stumps(labels, features, n_folds, n_iterations; verbose=false);
@test mean(accuracy) > 0.8

end # @testset
13 changes: 13 additions & 0 deletions test/regression/digits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,17 @@ model = build_forest(
preds = apply_forest(model, X)
@test R2(Y, preds) > 0.8

println("\n##### 3 foldCV Regression Tree #####")
n_folds = 5
r2 = nfoldCV_tree(Y, X, n_folds; verbose=false);
@test mean(r2) > 0.6

println("\n##### 3 foldCV Regression Forest #####")
n_subfeatures = 2
n_trees = 10
n_folds = 5
partial_sampling = 0.5
r2 = nfoldCV_forest(Y, X, n_folds, n_subfeatures, n_trees, partial_sampling; verbose=false)
@test mean(r2) > 0.6

end # @testset

0 comments on commit 971ebb9

Please sign in to comment.