diff --git a/src/run.jl b/src/run.jl index 19ad9a5..e380d53 100644 --- a/src/run.jl +++ b/src/run.jl @@ -66,6 +66,9 @@ function active_learn(experiment::Dict{Symbol, Any}, data::Array{T, 2}, labels:: set_data!(model, train_data) set_pools!(model, labelmap(train_pools)) + classify_precision = get(experiment[:param], :classify_precision, SVDD.OPT_PRECISION) + debug(LOGGER, "Classify precision: $classify_precision") + debug(LOGGER, "Start active learning cycle with $(experiment[:param][:num_al_iterations]) queries.") for i in 0:experiment[:param][:num_al_iterations] info(LOGGER, "Iteration $(i)") @@ -89,7 +92,7 @@ function active_learn(experiment::Dict{Symbol, Any}, data::Array{T, 2}, labels:: test_data, _, test_indices = get_test(split_strategy, data, pools) debug(LOGGER, "[TEST] Testing by predicting $(format_observations(test_data)) observations.") predictions = SVDD.predict(model, test_data) - push_evaluation!(res.al_history, i, predictions, labels[test_indices]) + push_evaluation!(res.al_history, i, predictions, labels[test_indices], classify_precision) debug(LOGGER, "[TEST] Testing done.") if i < experiment[:param][:num_al_iterations] @@ -220,14 +223,14 @@ function push_evaluation_cm!(al_history, i, cm) end end -function push_evaluation!(al_history::MVHistory, i, predictions::Vector{Vector{Float64}}, labels) - cm = ConfusionMatrix(SVDD.classify(predictions, Val(:Global)), labels) +function push_evaluation!(al_history::MVHistory, i, predictions::Vector{Vector{Float64}}, labels, classify_precision) + cm = ConfusionMatrix(SVDD.classify(predictions, Val(:Global), opt_precision=classify_precision), labels) push_evaluation_cm!(al_history, i, cm) return nothing end -function push_evaluation!(al_history::MVHistory, i, predictions, labels) - cm = ConfusionMatrix(SVDD.classify.(predictions), labels) +function push_evaluation!(al_history::MVHistory, i, predictions, labels, classify_precision) + cm = ConfusionMatrix(SVDD.classify.(predictions, opt_precision=classify_precision), labels) push_evaluation_cm!(al_history, i, cm) push!(al_history, :auc, i, roc_auc(predictions, labels)) for k in [0.01, 0.02, 0.05, 0.1, 0.2]