Skip to content

Commit

Permalink
update classification rand tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bensadeghi committed Sep 17, 2020
1 parent 971ebb9 commit 7f1d1bc
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions test/classification/low_precision.jl
Expand Up @@ -5,7 +5,7 @@ Random.seed!(16)

n,m = 10^3, 5;
features = Array{Any}(undef, n, m);
features[:,:] = randn(n, m);
features[:,:] = rand(n, m);
features[:,1] = round.(Int32, features[:,1]); # convert a column of 32bit integers
weights = rand(-1:1,m);
labels = round.(Int32, features * weights);
Expand Down Expand Up @@ -51,7 +51,7 @@ model, coeffs = build_adaboost_stumps(labels, features, n_iterations);
preds = apply_adaboost_stumps(model, coeffs, features);
cm = confusion_matrix(labels, preds)
@test typeof(preds) == Vector{Int32}
@test cm.accuracy > 0.2
@test cm.accuracy > 0.7

println("\n##### nfoldCV Classification Tree #####")
n_folds = Int32(3)
Expand Down Expand Up @@ -88,12 +88,12 @@ accuracy = nfoldCV_forest(
min_samples_leaf,
min_samples_split,
min_purity_increase)
@test mean(accuracy) > 0.6
@test mean(accuracy) > 0.7

println("\n##### nfoldCV Adaboosted Stumps #####")
n_iterations = Int32(15)
accuracy = nfoldCV_stumps(labels, features, n_folds, n_iterations)
@test mean(accuracy) > 0.1
@test mean(accuracy) > 0.7


# Test Int8 labels, and Float16 features
Expand Down
4 changes: 2 additions & 2 deletions test/classification/random.jl
Expand Up @@ -120,11 +120,11 @@ cm = confusion_matrix(labels, preds)
println("\n##### nfoldCV Classification Tree #####")
nfolds = 3
pruning_purity = 1.0
max_depth = 3
max_depth = 5
accuracy = nfoldCV_tree(labels, features, nfolds, pruning_purity, max_depth; rng=10, verbose=false)
accuracy2 = nfoldCV_tree(labels, features, nfolds, pruning_purity, max_depth; rng=10)
accuracy3 = nfoldCV_tree(labels, features, nfolds, pruning_purity, max_depth; rng=5)
@test mean(accuracy) > 0.6
@test mean(accuracy) > 0.7
@test accuracy == accuracy2
@test accuracy != accuracy3

Expand Down

0 comments on commit 7f1d1bc

Please sign in to comment.