In [1]:
import numpy
import scipy
import torch

from torchegranate.distributions import *

numpy.random.seed(0)
numpy.set_printoptions(suppress=True)

%load_ext watermark
%watermark -m -n -p numpy,scipy,torch,pomegranate

numpy      : 1.23.4
scipy      : 1.9.3
torch      : 1.12.1
pomegranate: 0.14.8

Compiler    : GCC 11.2.0
OS          : Linux
Release     : 4.15.0-197-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 8
Architecture: 64bit



### Normal Distribution with Diagonal Covariance

In [2]:
# Benchmark data: n samples with d features drawn from a standard normal.
n, d = 100_000, 500

X = torch.randn(n, d)
Xn = X.numpy()  # numpy view sharing X's memory, for the scipy calls

# Per-feature parameters (draw order matters for reproducibility: X, mus, covs).
mus = torch.randn(d)
covs = torch.randn(d).abs()  # non-negative per-feature variances
stds = covs.sqrt()           # torch/scipy Normal are parameterized by std. dev.

In [3]:
%timeit Normal(mus, covs, covariance_type='diag').log_probability(X)
%timeit torch.distributions.Normal(mus, stds).log_prob(X).sum(dim=-1)
%timeit scipy.stats.norm.logpdf(Xn, mus, stds).sum(axis=1)

143 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
227 ms ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.12 s ± 18.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Normal Distribution with Full Covariance

In [4]:
# Fit a Normal to X (no covariance_type given, unlike the diagonal benchmark)
# and reuse its fitted means/covariance so every library evaluates the same model.
d0 = Normal().fit(X)
mu = d0.means
cov = d0.covs

In [5]:
%timeit Normal(mu, cov).log_probability(X)
%timeit torch.distributions.MultivariateNormal(mu, cov).log_prob(X).sum(dim=-1)
%timeit scipy.stats.multivariate_normal.logpdf(Xn, mu, cov).sum(axis=-1)

211 ms ± 19.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
205 ms ± 22.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
765 ms ± 36.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Exponential Distribution

In [6]:
# Non-negative data (|N(0, 1)| draws) for the positive-support distributions below.
X = torch.randn(n, d).abs()
Xn = X.numpy()

means = torch.randn(d).abs()  # one non-negative parameter per feature

In [7]:
%timeit Exponential(means).log_probability(X)
%timeit torch.distributions.Exponential(means).log_prob(X)
%timeit scipy.stats.expon.logpdf(X, means)

150 ms ± 1.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
89 ms ± 3.47 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
1.36 s ± 86.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Gamma Distribution

In [8]:
# Per-feature Gamma parameters: shape (concentration) and rate, both non-negative.
shapes = torch.randn(d).abs()
rates = torch.randn(d).abs()

In [9]:
%timeit Gamma(shapes, rates).log_probability(X)
%timeit torch.distributions.Gamma(shapes, rates).log_prob(X)
%timeit scipy.stats.gamma.logpdf(X, shapes, rates)

270 ms ± 9.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
250 ms ± 30.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.67 s ± 75.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Bernoulli Distribution

In [10]:
# Binary data in {0, 1}; probs are the per-feature empirical frequencies of 1.
X = torch.tensor(numpy.random.choice(2, size=(n, d)), dtype=torch.float32)
Xn = X.numpy()  # refresh so Xn no longer holds the previous section's data
probs = torch.mean(X, dim=0)

In [11]:
%timeit Bernoulli(probs).log_probability(X)
%timeit torch.distributions.Bernoulli(probs).log_prob(X)
%timeit scipy.stats.bernoulli.logpmf(X, probs)

181 ms ± 8.04 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
419 ms ± 20.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.78 s ± 66.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
