Skip to content

Commit

Permalink
Merge f4ff4d1 into 2a2be1d
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Jun 29, 2021
2 parents 2a2be1d + f4ff4d1 commit 3da8052
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 64 deletions.
76 changes: 29 additions & 47 deletions docs/Manifest.toml
Expand Up @@ -23,9 +23,9 @@ version = "0.10.0"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "d659e42240c2162300b321f05173cab5cc40a5ba"
git-tree-sha1 = "be770c08881f7bb928dfd86d1ba83798f76cf62a"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.10.4"
version = "0.10.9"

[[ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
Expand All @@ -35,9 +35,9 @@ version = "0.11.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.30.0"
version = "3.31.0"

[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
Expand All @@ -54,9 +54,9 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.4"

[[DataAPI]]
git-tree-sha1 = "dfb3b7e89e395be1e25c2ad6d7690dc29cc53b1d"
git-tree-sha1 = "ee400abb2298bd13bfc3df1c412ed228061a2385"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.6.0"
version = "1.7.0"

[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
Expand All @@ -77,21 +77,15 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[Distances]]
deps = ["LinearAlgebra", "Statistics", "StatsAPI"]
git-tree-sha1 = "abe4ad222b26af3337262b8afb28fab8d215e9f8"
uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
version = "0.10.3"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[Distributions]]
deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"]
git-tree-sha1 = "013020ec9a5cdf1dd454eba3704dbffa69d3047e"
git-tree-sha1 = "2733323e5c02a9d7f48e7a3c4bc98d764fb704da"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.3"
version = "0.25.6"

[[DocStringExtensions]]
deps = ["LibGit2"]
Expand All @@ -101,9 +95,9 @@ version = "0.8.5"

[[Documenter]]
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
git-tree-sha1 = "5acbebf1be22db43589bc5aa1bb5fcc378b17780"
git-tree-sha1 = "47f13b6305ab195edb73c86815962d84e31b0f48"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.27.0"
version = "0.27.3"

[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
Expand Down Expand Up @@ -214,28 +208,16 @@ uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
version = "0.7.2"

[[MLJBase]]
deps = ["CategoricalArrays", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "MLJScientificTypes", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "c841d75dcd7dad3e3faee3a49efaf533a2c8d1df"
deps = ["CategoricalArrays", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "e1996657b66ba5c3a1bdbf73835640460958712d"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
version = "0.18.11"
version = "0.18.13"

[[MLJModelInterface]]
deps = ["Random", "ScientificTypes", "StatisticalTraits"]
git-tree-sha1 = "cafa0e923ce1ae659a4b4cb8eb03c98b916f0d4d"
deps = ["Random", "ScientificTypesBase", "StatisticalTraits"]
git-tree-sha1 = "55c785a68d71c5fd7b64b490e0d9ab18cf13a04c"
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
version = "1.1.0"

[[MLJModels]]
deps = ["CategoricalArrays", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJBase", "MLJModelInterface", "MLJScientificTypes", "OrderedCollections", "Parameters", "Pkg", "REPL", "Random", "Requires", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "6a430717810ca3ef7ba182235f07c634ede4c412"
uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
version = "0.14.7"

[[MLJScientificTypes]]
deps = ["CategoricalArrays", "ColorTypes", "Dates", "PersistenceDiagramsBase", "PrettyTables", "ScientificTypes", "StatisticalTraits", "Tables"]
git-tree-sha1 = "59ef6602733869cc695de7e2524f75359ba1930f"
uuid = "2e2323e0-db8b-457b-ae0d-bdfb3bc63afd"
version = "0.4.8"
version = "1.1.1"

[[MLJXGBoostInterface]]
deps = ["MLJModelInterface", "Tables", "XGBoost"]
Expand Down Expand Up @@ -351,12 +333,6 @@ git-tree-sha1 = "5f6c21241f0f655da3952fd60aa18477cf96c220"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "1.1.0"

[[Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.1.3"

[[Rmath]]
deps = ["Random", "Rmath_jll"]
git-tree-sha1 = "bf3188feca147ce108c76ad82c2792c57abe7b1f"
Expand All @@ -373,9 +349,15 @@ version = "0.3.0+0"
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[ScientificTypes]]
git-tree-sha1 = "b4e89a674804025c4a5843e35e562910485690c2"
deps = ["CategoricalArrays", "ColorTypes", "Dates", "PersistenceDiagramsBase", "PrettyTables", "ScientificTypesBase", "StatisticalTraits", "Tables"]
git-tree-sha1 = "345e33061ad7c49c6e860e42a04c62ecbea3eabf"
uuid = "321657f4-b219-11e9-178b-2701a2544e81"
version = "1.1.2"
version = "2.0.0"

[[ScientificTypesBase]]
git-tree-sha1 = "3f7ddb0cf0c3a4cff06d9df6f01135fa5442c99b"
uuid = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
version = "1.0.0"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Expand Down Expand Up @@ -404,10 +386,10 @@ uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.5.1"

[[StatisticalTraits]]
deps = ["ScientificTypes"]
git-tree-sha1 = "2d882a163c295d5d754e4102d92f4dda5a1f906b"
deps = ["ScientificTypesBase"]
git-tree-sha1 = "5114841829816649ecc957f07f6a621671e4a951"
uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9"
version = "1.1.0"
version = "2.0.0"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
Expand Down Expand Up @@ -452,9 +434,9 @@ version = "1.0.1"

[[Tables]]
deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"]
git-tree-sha1 = "aa30f8bb63f9ff3f8303a06c604c8500a69aa791"
git-tree-sha1 = "8ed4a3ea724dac32670b062be3ef1c1de6773ae8"
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
version = "1.4.3"
version = "1.4.4"

[[Tar]]
deps = ["ArgTools", "SHA"]
Expand Down
2 changes: 0 additions & 2 deletions docs/Project.toml
Expand Up @@ -2,7 +2,6 @@
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
InferenceDiagnostics = "be115224-59cd-429b-ad48-344e309966f0"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -11,6 +10,5 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Documenter = "0.27"
InferenceDiagnostics = "0.1"
MLJBase = "0.18"
MLJModels = "0.14"
MLJXGBoostInterface = "0.1"
julia = "1.3"
8 changes: 3 additions & 5 deletions src/rstar.jl
Expand Up @@ -30,9 +30,7 @@ is returned (algorithm 2).
# Examples
```jldoctest rstar; setup = :(using Random; Random.seed!(100))
julia> using MLJBase, MLJModels, Statistics
julia> XGBoost = @load XGBoostClassifier verbosity=0;
julia> using MLJBase, MLJXGBoostInterface, Statistics
julia> samples = fill(4.0, 300, 2);
Expand All @@ -43,7 +41,7 @@ One can compute the distribution of the ``R^*`` statistic (algorithm 2) with the
probabilistic classifier.
```jldoctest rstar
julia> distribution = rstar(XGBoost(), samples, chain_indices);
julia> distribution = rstar(XGBoostClassifier(), samples, chain_indices);
julia> isapprox(mean(distribution), 1; atol=0.1)
true
Expand All @@ -54,7 +52,7 @@ Deterministic classifiers can also be derived from probabilistic classifiers by
predicting the mode. In MLJ this corresponds to a pipeline of models.
```jldoctest rstar
julia> @pipeline XGBoost name = XGBoostDeterministic operation = predict_mode;
julia> @pipeline XGBoostClassifier name = XGBoostDeterministic operation = predict_mode;
julia> value = rstar(XGBoostDeterministic(), samples, chain_indices);
Expand Down
2 changes: 0 additions & 2 deletions test/rstar/Project.toml
Expand Up @@ -3,7 +3,6 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
InferenceDiagnostics = "be115224-59cd-429b-ad48-344e309966f0"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

Expand All @@ -12,6 +11,5 @@ Distributions = "0.25"
InferenceDiagnostics = "0.1"
MLJBase = "0.18"
MLJLIBSVMInterface = "0.1"
MLJModels = "0.14"
MLJXGBoostInterface = "0.1"
julia = "1.3"
15 changes: 7 additions & 8 deletions test/rstar/runtests.jl
Expand Up @@ -2,16 +2,15 @@ using InferenceDiagnostics

using Distributions
using MLJBase
using MLJModels
using MLJLIBSVMInterface
using MLJXGBoostInterface

using Test

XGBoost = @load XGBoostClassifier verbosity = 0
@pipeline XGBoost name = XGBoostDeterministic operation = predict_mode
SVC = @load SVC verbosity = 0
@pipeline XGBoostClassifier name = XGBoostDeterministic operation = predict_mode

@testset "rstar.jl" begin
classifiers = (XGBoost(), XGBoostDeterministic(), SVC())
classifiers = (XGBoostClassifier(), XGBoostDeterministic(), SVC())
N = 1_000

@testset "examples (classifier = $classifier)" for classifier in classifiers
Expand All @@ -21,7 +20,7 @@ SVC = @load SVC verbosity = 0

# Mean of the statistic should be focused around 1, i.e., the classifier does not
# perform better than random guessing.
if classifier isa MLJModels.Deterministic
if classifier isa MLJBase.Deterministic
@test dist isa Float64
else
@test dist isa LocationScale
Expand All @@ -38,7 +37,7 @@ SVC = @load SVC verbosity = 0

# Mean of the statistic should be closte to 1, i.e., the classifier does not perform
# better than random guessing.
if classifier isa MLJModels.Deterministic
if classifier isa MLJBase.Deterministic
@test dist isa Float64
else
@test dist isa LocationScale
Expand All @@ -58,7 +57,7 @@ SVC = @load SVC verbosity = 0

# Mean of the statistic should be close to 2, i.e., the classifier should be able to
# learn an almost perfect decision boundary between chains.
if classifier isa MLJModels.Deterministic
if classifier isa MLJBase.Deterministic
@test dist isa Float64
else
@test dist isa LocationScale
Expand Down

0 comments on commit 3da8052

Please sign in to comment.