In [1]:
import numpy as np
import timeit
import torch
from torchist import histogramdd, histogram

In [3]:
print('CPU')
print('---')

x = np.random.rand(100000)
xdd = np.random.rand(100000, 5)
edges10 = np.linspace(0.0, 1.0, 11) ** 1.5
edges100 = np.linspace(0.0, 1.0, 101) ** 1.5

x_t = torch.from_numpy(x)
xdd_t = torch.from_numpy(xdd)
edges10_t = torch.from_numpy(edges10)
edges100_t = torch.from_numpy(edges100)

## Correctness
hdd, _ = np.histogramdd(xdd, bins=10)
hdd_t = histogramdd(xdd_t, bins=10)

assert np.all(hdd == hdd_t.numpy())

hdd, _ = np.histogramdd(xdd, bins=[edges10] * 5)
hdd_t = histogramdd(xdd_t, edges=[edges10_t] * 5)

assert np.all(hdd == hdd_t.numpy())

## Speed
for key, f in {
    'np.histogram': lambda: np.histogram(x, bins=100),
    'np.histogramdd': lambda: np.histogramdd(xdd, bins=10),
    'np.histogram (non-uniform)': lambda: np.histogram(x, bins=edges100),
    'np.histogramdd (non-uniform)': lambda: np.histogramdd(xdd, bins=[edges10] * 5),
    'torchist.histogram': lambda: histogram(x_t, bins=100),
    'torchist.histogramdd': lambda: histogramdd(xdd_t, bins=10),
    'torchist.histogram (non-uniform)': lambda: histogram(x_t, edges=edges100_t),
    'torchist.histogramdd (non-uniform)': lambda: histogramdd(xdd_t, edges=[edges10_t] * 5),
}.items():
    time = timeit.timeit(f, number=100)
    print(key, ':', '{:.04f}'.format(time), 's')



print()
print('MPS')
print('----')

def _to_mps(x):
    return x.to(torch.device("mps"), dtype=torch.float32)

x_t = _to_mps(x_t)

xdd_t = _to_mps(xdd_t)
edges10_t = _to_mps(edges10_t)
edges100_t = _to_mps(edges100_t)

## Correctness
hdd, _ = np.histogramdd(xdd, bins=10)
hdd_t = histogramdd(xdd_t, bins=10)

hdd, _ = np.histogramdd(xdd, bins=[edges10] * 5)
hdd_t = histogramdd(xdd_t, edges=[edges10_t] * 5)

## Speed
for key, f in {
    'torchist.histogram': lambda: histogram(x_t, bins=100),
    'torchist.histogramdd': lambda: histogramdd(xdd_t, bins=10),
    'torchist.histogram (non-uniform)': lambda: histogram(x_t, edges=edges100_t),
    'torchist.histogramdd (non-uniform)': lambda: histogramdd(xdd_t, edges=[edges10_t] * 5),
}.items():
    start = torch.mps.Event(enable_timing=True)
    end = torch.mps.Event(enable_timing=True)

    start.record()
    for _ in range(100):
        f()
    end.record()

    torch.mps.synchronize()
    time = start.elapsed_time(end) / 1000  # ms -> s

    print(key, ':', '{:.04f}'.format(time), 's')



CPU
---
np.histogram : 0.0714 s
np.histogramdd : 2.0422 s
np.histogram (non-uniform) : 0.5791 s
np.histogramdd (non-uniform) : 1.5480 s
torchist.histogram : 0.0568 s
torchist.histogramdd : 0.2129 s
torchist.histogram (non-uniform) : 0.1049 s
torchist.histogramdd (non-uniform) : 0.4601 s

MPS
----
torchist.histogram : 921.3686 s
torchist.histogramdd : 1.7230 s
torchist.histogram (non-uniform) : 1.2039 s
torchist.histogramdd (non-uniform) : 1011.5633 s
