In [1]:
import os
import brunoflow as bf
import jax
from jax import numpy as jnp
import numpy as np
# Turn on 64-bit mode before any jnp array is created below, so jnp arrays default
# to float64 like NumPy's do and the jnp-vs-np timings compare equal precisions.
jax.config.update("jax_enable_x64", True)
print("Running JAX on {}".format(jax.devices()[0].device_kind))

Running JAX on NVIDIA GeForce RTX 3050 Ti Laptop GPU


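### Generating large arrays (jnp vs np)

JAX dispatches work asynchronously, so a jnp timing only reflects the full computation if the result is synchronized (e.g. with `.block_until_ready()`) inside the timed code.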

In [2]:
%%timeit
x = bf.Node(jnp.zeros(shape=(10000, 1000)))

261 ms ± 26.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [3]:
%%timeit
# NumPy counterpart of the jnp construction timing; np.zeros executes synchronously.
x = bf.Node(np.zeros(shape=(10000, 1000)))

107 ms ± 5.53 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Calling exp on a large (jnp vs np) array

In [4]:
# Input for the jnp exp timing, built outside the timed cell. Blocking ensures the
# asynchronously dispatched zeros array is materialized here, so its cost does not
# spill into the first timed exp loop.
x = bf.Node(jnp.zeros(shape=(10000, 1000)).block_until_ready())

In [5]:
%%timeit
# Forward exp on the jnp-backed node x from the previous cell.
# NOTE(review): JAX dispatches asynchronously, so unless bf.exp synchronizes
# internally this may time only the dispatch of exp, not its completion —
# TODO confirm (e.g. block on the array wrapped by the returned node).
y = bf.exp(x)

135 ms ± 6.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
# NumPy-backed input for the exp timing below; float64 is np.zeros' default and
# matches the jnp dtype under the x64 setting enabled at the top of the notebook.
x = bf.Node(np.zeros((10000, 1000), dtype=np.float64))

In [7]:
%%timeit
# Forward exp on the np-backed node x from the previous cell.
y = bf.exp(x)

125 ms ± 8.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Calling backprop on a node with large (jnp vs np) arrays

In [8]:
# Graph for the jnp backprop timings: exp of a (10000, 1000) zeros node. Blocking
# on the zeros array keeps its asynchronous fill out of the first timed loop.
# NOTE(review): the exp result itself is not synchronized here — TODO confirm
# whether bf.exp's pending work can still leak into the first backprop loop.
y = bf.exp(bf.Node(jnp.zeros(shape=(10000, 1000)).block_until_ready()))

In [9]:
%%timeit
# Backprop through the jnp-backed graph, requesting only "grad".
# NOTE(review): timeit calls backprop repeatedly on the same y; whether results
# accumulate between calls is not visible here — TODO confirm it doesn't skew timing.
y.backprop(values_to_compute=("grad",))

145 ms ± 7.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
%%timeit
# Same jnp-backed graph, additionally requesting "abs_val_grad" and "entropy".
y.backprop(values_to_compute=("grad", "abs_val_grad", "entropy"))

531 ms ± 30.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
# Graph for the NumPy backprop timings: exp of a (10000, 1000) float64 zeros node.
y = bf.exp(bf.Node(np.zeros((10000, 1000), dtype=np.float64)))

In [12]:
%%timeit
# Backprop through the np-backed graph, requesting only "grad".
y.backprop(values_to_compute=("grad",))

137 ms ± 3.74 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
%%timeit
# Same np-backed graph, additionally requesting "abs_val_grad" and "entropy".
y.backprop(values_to_compute=("grad", "abs_val_grad", "entropy"))

495 ms ± 5.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


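### Calling backprop on a node with large (jnp vs np) arrays and a bigger computation graph

This section uses a smaller (1000, 1000) array than the sections above, so its timings are not directly comparable with theirs.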

In [2]:
# Input for the larger-graph benchmark: a (1000, 1000) jnp zeros node. Blocking
# ensures the asynchronously dispatched fill finishes before any timing starts.
x = bf.Node(jnp.zeros(shape=(1000, 1000)).block_until_ready())
y = bf.exp(x)
# Operator precedence: `y / (x + y)` binds tighter than `+`, so `out` is a
# left-associated chain of `y +` additions with y / (x + y) added as the last term.
out = y + y + y + y + y + y + y + y + y + y + y + y + y + y + y + y + y + y + y + y + y / (x + y)


In [3]:
%%timeit
# Backprop through the larger jnp-backed graph `out`, requesting only "grad".
out.backprop(values_to_compute=("grad",))

26.6 ms ± 498 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
