In [1]:
import numpy
import scipy
import torch

from sklearn.datasets import make_blobs

from torchegranate.distributions import *
from torchegranate.bayes_classifier import BayesClassifier

from sklearn.naive_bayes import GaussianNB, BernoulliNB

import matplotlib.pyplot as plt
import seaborn; seaborn.set_style('whitegrid')

numpy.random.seed(0)
numpy.set_printoptions(suppress=True)

%load_ext watermark
%watermark -m -n -p numpy,scipy,torch,pomegranate

numpy      : 1.23.4
scipy      : 1.9.3
torch      : 1.12.1
pomegranate: 0.14.8

Compiler    : GCC 11.2.0
OS          : Linux
Release     : 4.15.0-197-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 8
Architecture: 64bit



### Gaussian Naive Bayes

In [2]:
# Problem size for the Gaussian benchmark: n samples, d features, k centers
# (k is also the number of per-class distributions built in the cells below).
n = 200000
d = 500
k = 50

# Synthetic clustered data; random_state pins the generated dataset.
X, y = make_blobs(
    n_samples=n,
    n_features=d,
    centers=k,
    cluster_std=0.75,
    random_state=0,
)

In [3]:
%timeit model_sklearn = GaussianNB().fit(X, y)
%timeit model_pom = BayesClassifier([Normal(covariance_type='diag') for i in range(k)]).fit(X, y)

787 ms ± 22.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
872 ms ± 16.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
model_sklearn = GaussianNB().fit(X, y)
model_pom = BayesClassifier([Normal(covariance_type='diag') for i in range(k)]).fit(X, y)

%timeit model_sklearn.predict(X)
%timeit model_pom.predict(X)

20.9 s ± 24.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
15.6 s ± 152 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Bernoulli Naive Bayes

In [5]:
# Problem size for the Bernoulli benchmark: n samples, d binary features,
# k class labels.
n = 200000
d = 200
k = 25

# Features are drawn uniformly from {0, 1}; labels uniformly from {0, ..., k-1}.
# X is drawn before y so the global RNG is consumed in the same order.
X = numpy.random.choice(2, size=(n, d))
y = numpy.random.choice(k, size=n)

In [6]:
%timeit model_sklearn = BernoulliNB().fit(X, y)
%timeit model_pom = BayesClassifier([Bernoulli() for i in range(k)]).fit(X, y)

14 s ± 242 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
359 ms ± 905 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
model_sklearn = BernoulliNB().fit(X, y)
model_pom = BayesClassifier([Bernoulli() for i in range(k)]).fit(X, y)

%timeit model_sklearn.predict(X)
%timeit model_pom.predict(X)

628 ms ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.01 s ± 35.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
