diff --git a/src/init_strategies/strategies_gamma.jl b/src/init_strategies/strategies_gamma.jl index 95bfa89..1fb6a81 100644 --- a/src/init_strategies/strategies_gamma.jl +++ b/src/init_strategies/strategies_gamma.jl @@ -16,7 +16,8 @@ Silverman, Bernard W. Density estimation for statistics and data analysis. Routl struct RuleOfThumbSilverman <: InitializationStrategyGamma end function rule_of_thumb_silverman(data::Array{T,2}) where T <: Real - return (size(data, 2) * (size(data, 1) + 2) / 4.0)^(-1.0 / (size(data,1) + 4.0)) + s = (size(data, 2) * (size(data, 1) + 2) / 4.0)^(-1.0 / (size(data,1) + 4.0)) + return 1 / (2 * s^2) end calculate_gamma(model, strategy::RuleOfThumbSilverman) = rule_of_thumb_silverman(model.data) @@ -32,7 +33,8 @@ Scott, David W. Multivariate density estimation: theory, practice, and visualiza struct RuleOfThumbScott <: InitializationStrategyGamma end function rule_of_scott(data::Array{T,2}) where T <: Real - return size(data, 2)^(-1.0/(size(data, 1) + 4)) + s = size(data, 2)^(-1.0/(size(data, 1) + 4)) + return 1 / (2 * s^2) end calculate_gamma(model, strategy::RuleOfThumbScott) = rule_of_scott(model.data) diff --git a/test/init_strategies/init_strategies_test.jl b/test/init_strategies/init_strategies_test.jl index ff5e8f9..5a256d6 100644 --- a/test/init_strategies/init_strategies_test.jl +++ b/test/init_strategies/init_strategies_test.jl @@ -6,16 +6,10 @@ pools = fill(:U, size(dummy_data, 2)) model = SVDD.VanillaSVDD(dummy_data) - @testset "RuleOfThumbSilverman" begin - # see https://docs.scipy.org/doc/scipy-0.19.0/reference/generated/scipy.stats.gaussian_kde.html - expected = (n * (d + 2) / 4.0)^(-1.0 / (d + 4)) - @test expected == SVDD.calculate_gamma(model, SVDD.RuleOfThumbSilverman()) - end - - @testset "RuleOfThumbScott" begin - # see https://docs.scipy.org/doc/scipy-0.19.0/reference/generated/scipy.stats.gaussian_kde.html - expected = n^(-1.0 / (d + 4)) - @test expected == SVDD.calculate_gamma(model, SVDD.RuleOfThumbScott()) + for s in [:RuleOfThumbSilverman, :RuleOfThumbScott] + @testset "$s" begin + @test SVDD.calculate_gamma(model, SVDD.eval(s)()) > 0 + end end @testset "TaxErrorEstimate" begin