In [1]:
import numpy as np
import torch
from hummingbird import convert_sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer

In [2]:
# We are going to use the breast cancer dataset from scikit-learn for this example.
X, y = load_breast_cancer(return_X_y=True)
nrows=15000
X = X[0:nrows]
y = y[0:nrows]
X_torch = torch.from_numpy(X).float() # We create a torch version of X to use for inference later.

In [3]:
# Create and train a random forest model.
model = RandomForestClassifier(n_estimators=10, max_depth=10)
model.fit(X, y)

RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
                       max_depth=10, max_features='auto', max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=10,
                       n_jobs=None, oob_score=False, random_state=None,
                       verbose=0, warm_start=False)

In [4]:
# Use Hummingbird to convert your scikit-learn model to PyTorch.
# In this specific case we force the tree implementation to use the GEMM strategy.
pytorch_model = convert_sklearn(model, extra_config = {"tree_implementation": "gemm"})

In [5]:
%%timeit -r 3

# Time for scikit-learn.
model.predict(X)

1.51 ms ± 4.43 µs per loop (mean ± std. dev. of 3 runs, 1000 loops each)


In [6]:
%%timeit -r 3

# Time for Hummingbird - By default CPU execution is used.
pytorch_model(X_torch)

4.21 ms ± 145 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)


In [7]:
%%timeit -r 3

# Time for Hummingbird - GPU. Note that you must have a GPU-enabled machine.
pytorch_model.to('cuda')
pytorch_model(X_torch.to('cuda'))

439 µs ± 543 ns per loop (mean ± std. dev. of 3 runs, 1000 loops each)


In [8]:
# Make sure Hummingbird output matches scikit-learn.
skl = model.predict_proba(X)
hum = pytorch_model(X_torch.to('cuda'))

np.testing.assert_allclose(skl, hum[1].data.to('cpu').numpy(), rtol=1e-6, atol=1e-6)