Abstract interface of ScikitLearn.jl



This package exposes the scikit-learn interface. Packages that implement this interface can be used in conjunction with ScikitLearn.jl (pipelines, cross-validation, hyperparameter tuning, ...).
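To illustrate why a shared interface is useful, the sketch below writes a small holdout evaluation that calls only `clone`, `fit!` and `predict`, so it should work with any model that implements them. This is a simplified, hypothetical stand-in for what tools like cross-validation or hyperparameter search do, not ScikitLearn.jl's actual implementation; `holdout_accuracy` and the `Stump` model are invented for this example.

```julia
import ScikitLearnBase

# Hypothetical one-feature "decision stump" for Bool labels
mutable struct Stump
    feature::Int          # hyperparameter: which column to threshold
    threshold::Float64    # learned from data
    Stump(; feature=1) = new(feature, NaN)
end

ScikitLearnBase.@declare_hyperparameters(Stump, [:feature])
ScikitLearnBase.is_classifier(::Stump) = true

function ScikitLearnBase.fit!(model::Stump, X, y::AbstractVector{Bool})
    x = X[:, model.feature]
    # Threshold halfway between the two class means (assumes both classes occur in y)
    model.threshold = (sum(x[y]) / count(y) + sum(x[.!y]) / count(!, y)) / 2
    return model
end

ScikitLearnBase.predict(model::Stump, X) = X[:, model.feature] .> model.threshold

# Generic code: relies only on clone, fit! and predict
function holdout_accuracy(model, X, y, train, test)
    m = ScikitLearnBase.clone(model)    # fresh copy, so `model` itself is never fitted
    ScikitLearnBase.fit!(m, X[train, :], y[train])
    return sum(ScikitLearnBase.predict(m, X[test, :]) .== y[test]) / length(test)
end

X = [1.0 0.0; 2.0 0.0; 8.0 0.0; 9.0 0.0; 1.5 0.0; 8.5 0.0]
y = [false, false, true, true, false, true]
holdout_accuracy(Stump(feature=1), X, y, 1:4, 5:6)   # → 1.0
holdout_accuracy(Stump(feature=2), X, y, 1:4, 5:6)   # → 0.5 (column 2 carries no signal)
```

Comparing the two calls is, in miniature, what a hyperparameter search does: the same generic function is evaluated for different hyperparameter values.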

This is an intentionally slim package (~100 LOC, no dependencies). That way, ML libraries can import ScikitLearnBase without pulling in all of ScikitLearn.jl's dependencies.


The docs contain an overview of the API and a more thorough specification.

There are two implementation strategies for an existing machine learning package:

  • Create a new type that wraps the existing type. The new type can usually be written entirely on top of the existing codebase (i.e. without modifying it). This gives more implementation freedom and a more consistent interface among the various ScikitLearn.jl models. Here's an example from DecisionTree.jl.
  • Use the existing type. This requires less code, and is usually better when the model type already contains the hyperparameters / fitting arguments.
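As a rough sketch of the first strategy, the block below wraps a stand-in for an existing package. `LegacyLinearFit`, `legacy_train`, `legacy_apply` and `RidgeWrapper` are all hypothetical names invented for this illustration; only `@declare_hyperparameters`, `fit!` and `predict` come from ScikitLearnBase. The wrapper stores the hyperparameter (`lambda`) and the wrapped fit result, and the "existing" code is called without being modified. The second strategy would instead attach the same methods directly to the package's own type.

```julia
import ScikitLearnBase
using LinearAlgebra: I

# --- Hypothetical stand-in for an existing package's API (left unmodified) ---
struct LegacyLinearFit
    coefs::Vector{Float64}
end
# Ridge regression via the normal equations: (X'X + lambda*I) \ X'y
legacy_train(X, y; lambda=0.0) = LegacyLinearFit((X' * X + lambda * I) \ (X' * y))
legacy_apply(result::LegacyLinearFit, X) = X * result.coefs

# --- Hypothetical wrapper type implementing the ScikitLearnBase interface ---
mutable struct RidgeWrapper
    lambda::Float64            # hyperparameter (not learned from data)
    fitted::LegacyLinearFit    # result of the wrapped library's training call
    RidgeWrapper(; lambda=1.0) = new(lambda)
end

ScikitLearnBase.@declare_hyperparameters(RidgeWrapper, [:lambda])

function ScikitLearnBase.fit!(model::RidgeWrapper, X, y)
    model.fitted = legacy_train(X, y; lambda=model.lambda)
    return model
end

ScikitLearnBase.predict(model::RidgeWrapper, X) = legacy_apply(model.fitted, X)

# Usage: with lambda = 0 this is ordinary least squares and recovers y = 2a + 3b
X = [1.0 0.0; 0.0 1.0; 1.0 1.0]
y = [2.0, 3.0, 5.0]
model = ScikitLearnBase.fit!(RidgeWrapper(lambda=0.0), X, y)
ScikitLearnBase.predict(model, [2.0 1.0])   # ≈ [7.0]
```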


For models with simple hyperparameters, it boils down to this:

```julia
import ScikitLearnBase

mutable struct NaiveBayes
    # The model hyperparameters (not learned from data)
    bias::Float64

    # The parameters learned from data
    counts::Matrix{Int}

    # A constructor that accepts the hyperparameters as keyword arguments
    # with sensible defaults
    NaiveBayes(; bias=0.0) = new(bias)
end

# This will define `clone`, `set_params!` and `get_params` for the model
ScikitLearnBase.@declare_hyperparameters(NaiveBayes, [:bias])

# NaiveBayes is a classifier
ScikitLearnBase.is_classifier(::NaiveBayes) = true   # not required for transformers

function ScikitLearnBase.fit!(model::NaiveBayes, X, y)
    # X should be of size (n_sample, n_feature)
    # ... modify model.counts here ...
    return model
end

function ScikitLearnBase.predict(model::NaiveBayes, X)
    # ... return a vector of predicted classes here ...
end
```
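The template above leaves the bodies of `fit!` and `predict` as placeholders. As a complete, runnable sketch of the same recipe, here is a hypothetical k-nearest-neighbours classifier (`KNNClassifier` is invented for illustration, not part of any package) with a single hyperparameter, `k`. The usage lines exercise the methods that `@declare_hyperparameters` generates: `get_params` returns the hyperparameters as a dictionary, `set_params!` updates them by keyword, and `clone` is expected to build a new model with the same hyperparameters.

```julia
import ScikitLearnBase

mutable struct KNNClassifier
    k::Int                  # hyperparameter: number of neighbours that vote
    X::Matrix{Float64}      # learned: memorised training samples
    y::Vector               # learned: training labels
    KNNClassifier(; k=3) = new(k)
end

ScikitLearnBase.@declare_hyperparameters(KNNClassifier, [:k])
ScikitLearnBase.is_classifier(::KNNClassifier) = true

function ScikitLearnBase.fit!(model::KNNClassifier, X, y)
    model.X = Float64.(X)   # X is (n_sample, n_feature)
    model.y = collect(y)
    return model
end

function ScikitLearnBase.predict(model::KNNClassifier, X)
    n_train = size(model.X, 1)
    k = min(model.k, n_train)
    return map(1:size(X, 1)) do i
        # squared Euclidean distance from sample i to every training sample
        d = [sum(abs2, X[i, :] .- model.X[j, :]) for j in 1:n_train]
        # labels of the k closest training samples, nearest first
        nearest = model.y[partialsortperm(d, 1:k)]
        # majority vote; ties go to the label met first, i.e. the nearer one
        candidates = unique(nearest)
        candidates[argmax([count(==(c), nearest) for c in candidates])]
    end
end

X = [0.0 0.0; 0.1 0.2; 5.0 5.0; 5.1 4.9]
y = [1, 1, 2, 2]
model = ScikitLearnBase.fit!(KNNClassifier(k=1), X, y)
pred = ScikitLearnBase.predict(model, [0.2 0.1; 4.8 5.2])   # → [1, 2]
hp = ScikitLearnBase.get_params(model)                      # hyperparameters, e.g. hp[:k] == 1
ScikitLearnBase.set_params!(model; k=3)                     # update a hyperparameter in place
fresh = ScikitLearnBase.clone(model)                        # new model sharing k = 3
```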

Models with more complex hyperparameter specifications should implement clone, get_params and set_params! explicitly instead of using @declare_hyperparameters.
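A minimal sketch of the explicit route, assuming a model whose hyperparameters live in a nested options struct rather than as plain fields of the model type. `SolverOptions`, `IterativeModel` and `ITERATIVE_PARAMS` are hypothetical, and the fitting code is omitted. Here `get_params` flattens the options into a `Dict{Symbol,Any}`, `set_params!` validates and merges new keyword values, and `clone` rebuilds a model from the current hyperparameters.

```julia
import ScikitLearnBase

# Hypothetical options struct holding the hyperparameters
struct SolverOptions
    tol::Float64
    maxiter::Int
end

# Hypothetical model: hyperparameters sit inside `opts`, not as plain fields
mutable struct IterativeModel
    opts::SolverOptions
    IterativeModel(; tol=1e-6, maxiter=100) = new(SolverOptions(tol, maxiter))
end

const ITERATIVE_PARAMS = (:tol, :maxiter)

# Flatten the nested options; any extra keyword options a caller passes are ignored
ScikitLearnBase.get_params(model::IterativeModel; kwargs...) =
    Dict{Symbol,Any}(:tol => model.opts.tol, :maxiter => model.opts.maxiter)

function ScikitLearnBase.set_params!(model::IterativeModel; params...)
    for name in keys(params)
        name in ITERATIVE_PARAMS ||
            throw(ArgumentError("invalid hyperparameter $name for IterativeModel"))
    end
    merged = merge(ScikitLearnBase.get_params(model), Dict{Symbol,Any}(params))
    model.opts = SolverOptions(merged[:tol], merged[:maxiter])
    return model
end

# New model with the same hyperparameters, built through the keyword constructor
ScikitLearnBase.clone(model::IterativeModel) =
    IterativeModel(; ScikitLearnBase.get_params(model)...)

model = IterativeModel(tol=1e-3)
ScikitLearnBase.set_params!(model; maxiter=50)
fresh = ScikitLearnBase.clone(model)   # fresh.opts matches model.opts
```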

More examples of PRs that implement the interface: GaussianMixtures.jl, GaussianProcesses.jl, DecisionTree.jl, LowRankModels.jl.

Note: if the model performs unsupervised learning, implement transform instead of predict.
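For illustration, a hypothetical standardizing transformer: `fit!` learns per-column means and standard deviations, and `transform` rescales the data. `Standardizer` and its `with_mean` hyperparameter are invented for this sketch (loosely modelled on scikit-learn's `StandardScaler`). `fit!` takes an optional `y` that it ignores; that is an assumption made so the transformer can also be called with labels, for example inside a pipeline.

```julia
import ScikitLearnBase
using Statistics: mean, std

# Hypothetical transformer: centres and scales each column (feature)
mutable struct Standardizer
    with_mean::Bool           # hyperparameter: subtract the column means?
    mu::Matrix{Float64}       # learned: 1 × n_feature column means
    sigma::Matrix{Float64}    # learned: 1 × n_feature column standard deviations
    Standardizer(; with_mean=true) = new(with_mean)
end

ScikitLearnBase.@declare_hyperparameters(Standardizer, [:with_mean])

# Unsupervised: y is accepted but ignored
function ScikitLearnBase.fit!(model::Standardizer, X, y=nothing)
    model.mu = model.with_mean ? mean(X, dims=1) : zeros(1, size(X, 2))
    model.sigma = std(X, dims=1)   # a constant column gives 0 here, so transform yields Inf/NaN
    return model
end

ScikitLearnBase.transform(model::Standardizer, X) = (X .- model.mu) ./ model.sigma

X = [1.0 10.0; 3.0 30.0; 5.0 50.0]
Z = ScikitLearnBase.transform(ScikitLearnBase.fit!(Standardizer(), X), X)
# Z ≈ [-1.0 -1.0; 0.0 0.0; 1.0 1.0]
```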

Once your library implements the API, file an issue/PR to add it to the list of models.