In [3]:
import numpy as np
import minterpy as mp
import matplotlib.pyplot as plt

from newton_eval import (
    eval_driver_base_cpu,
    eval_driver_numba_cpu,
    eval_driver_numba_cpu_par,
    eval_driver_numba_gpu,
)

In [4]:
# Parameters passed to `mp.MultiIndexSet.from_degree` in the next cell:
# number of variables, polynomial degree, and l_p-degree.
spatial_dimension, poly_degree, lp_degree = 6, 5, 2.0

In [5]:
# Multi-index set for the chosen dimension / degree / l_p-degree.
# Its length (shown as the cell output) is the number of Newton
# coefficients drawn further down.
mi = mp.MultiIndexSet.from_degree(
    spatial_dimension=spatial_dimension, poly_degree=poly_degree, lp_degree=lp_degree
)
len(mi)

3819

In [22]:
num_points = 10_000  # number of evaluation points (rows of `xx_test`)

In [23]:
# Evaluation points drawn uniformly from [-1, 1)^spatial_dimension,
# one row per point. A dedicated, seeded generator makes the points --
# and therefore every result and timing that depends on them --
# reproducible under Restart & Run All (the previous unseeded
# `np.random.rand` call produced different points on every run).
rng_points = np.random.default_rng(seed=0)
xx_test = rng_points.uniform(-1.0, 1.0, size=(num_points, spatial_dimension))

In [8]:
# Random Newton coefficients in [0, 1), one per multi-index. A separate
# seeded generator keeps the evaluated polynomial identical across runs
# (the previous unseeded `np.random.rand` call did not).
rng_coeffs = np.random.default_rng(seed=1)
nwt_coeffs = rng_coeffs.random(len(mi))

# Remaining inputs passed unchanged to every evaluation driver below.
exponents = mi.exponents
grd = mp.Grid(mi)
gen_points = grd.generating_points

## Base implementation

In [9]:
# Reference result from the base CPU implementation.
yy_base = eval_driver_base_cpu(
    xx_test,
    nwt_coeffs,
    exponents,
    gen_points,
)

In [11]:
%%timeit
eval_driver_base_cpu(xx_test, nwt_coeffs, exponents, gen_points)

58.8 ms ± 2.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Numba CPU implementation

In [12]:
# First call to the Numba CPU driver. Presumably this also triggers JIT
# compilation, which is why it is kept out of the timed cell.
yy_numba = eval_driver_numba_cpu(
    xx_test,
    nwt_coeffs,
    exponents,
    gen_points,
)

In [13]:
# Check the Numba result against the base reference. Unlike a bare
# `assert np.allclose(...)`, this:
#   - reports the max abs/rel mismatch on failure;
#   - rejects shape mismatches instead of silently broadcasting;
#   - is not stripped when Python runs with -O.
# rtol/atol/equal_nan reproduce np.allclose's defaults.
np.testing.assert_allclose(yy_numba, yy_base, rtol=1e-5, atol=1e-8, equal_nan=False)

In [14]:
%%timeit
eval_driver_numba_cpu(xx_test, nwt_coeffs, exponents, gen_points)

216 µs ± 7.33 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


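## Numba parallel implementation

> **Note:** the base and Numba CPU timings above were recorded (In [11], In [14]) before `num_points` was set in In [22]. They may therefore refer to a different set of test points than the timings below. Restart the kernel and run all cells before comparing the numbers.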

In [24]:
# First call to the parallel Numba CPU driver, kept outside the timed cell.
yy_numba_cpu_par = eval_driver_numba_cpu_par(
    xx_test,
    nwt_coeffs,
    exponents,
    gen_points,
)

In [16]:
# Validate the parallel result directly against the base reference, so
# every implementation is checked against the same ground truth rather
# than against another Numba variant. `assert_allclose` reports the
# mismatch on failure and rejects shape mismatches; rtol/atol/equal_nan
# reproduce np.allclose's defaults.
np.testing.assert_allclose(
    yy_numba_cpu_par, yy_base, rtol=1e-5, atol=1e-8, equal_nan=False
)

In [25]:
%%timeit
eval_driver_numba_cpu_par(xx_test, nwt_coeffs, exponents, gen_points)

26.2 ms ± 626 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Numba GPU implementation

In [26]:
# First call to the Numba GPU driver, kept outside the timed cells.
yy_numba_gpu = eval_driver_numba_gpu(
    xx_test,
    nwt_coeffs,
    exponents,
    gen_points,
)



In [28]:
# Validate the GPU result directly against the base CPU reference.
# `assert_allclose` reports the mismatch on failure and rejects shape
# mismatches; rtol/atol/equal_nan reproduce np.allclose's defaults.
np.testing.assert_allclose(
    yy_numba_gpu, yy_base, rtol=1e-5, atol=1e-8, equal_nan=False
)

In [29]:
%%timeit
eval_driver_numba_gpu(xx_test, nwt_coeffs, exponents, gen_points)

8.39 ms ± 248 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [30]:
%%timeit
eval_driver_numba_gpu(xx_test, nwt_coeffs, exponents, gen_points, threads_per_block=512)



13.7 ms ± 43.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
