In [1]:
import os
import brunoflow as bf
import jax
from jax import numpy as jnp
import numpy as np
# Turn on 64-bit types before any array is created, so jnp.zeros defaults to float64
# just like np.zeros does — otherwise the jnp/np comparison would mix dtypes.
jax.config.update("jax_enable_x64", True)
# Report which accelerator the timings below were collected on.
print("Running JAX on {}".format(jax.devices()[0].device_kind))

Running JAX on NVIDIA GeForce RTX 3050 Ti Laptop GPU


### Generating large arrays (jnp vs np)

In [2]:
%%timeit
x = bf.Node(jnp.zeros(shape=(10000, 1000)))

2.21 ms ± 3.78 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [3]:
%%timeit
x = bf.Node(np.zeros(shape=(10000, 1000)))

1.66 ms ± 294 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### Calling `exp` on large arrays (jnp vs np)

In [4]:
x = bf.Node(jnp.zeros((10_000, 1_000)))  # JAX-backed input reused by the exp benchmark below

In [5]:
%%timeit
y = bf.exp(x)

4.4 ms ± 31.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
x = bf.Node(np.zeros((10_000, 1_000)))  # NumPy-backed input reused by the exp benchmark below

In [7]:
%%timeit
y = bf.exp(x)

14 ms ± 550 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Calling backprop on a node with large arrays (jnp vs np)

In [8]:
# One-op graph over a JAX-backed input, used by the backprop benchmarks that follow.
y = bf.exp(
    bf.Node(jnp.zeros((10_000, 1_000)))
)

In [9]:
%%timeit
y.backprop(values_to_compute=("grad",))

8.33 ms ± 206 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
%%timeit
y.backprop(values_to_compute=("grad", "abs_val_grad", "entropy"))

34.5 ms ± 340 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
# One-op graph over a NumPy-backed input, used by the backprop benchmarks that follow.
y = bf.exp(
    bf.Node(np.zeros((10_000, 1_000)))
)

In [12]:
%%timeit
y.backprop(values_to_compute=("grad",))

7.89 ms ± 64 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
%%timeit
y.backprop(values_to_compute=("grad", "abs_val_grad", "entropy"))

34 ms ± 120 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


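### Calling backprop on a node with large arrays and a bigger computation graph (jnp only)

Note: the cells below were run in a fresh kernel session (their execution counts restart), so run the setup cell at the top first. Only the jnp version is benchmarked here; there is no np counterpart yet.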

In [3]:
# A slightly larger graph: x and y each feed several downstream ops, so backprop has more
# nodes to visit than the single-exp graphs above.
x = bf.Node(jnp.zeros((10_000, 1_000)))
y = bf.exp(x)
out = y + (y / (x + y))


In [5]:
%%timeit
out.backprop(values_to_compute=("grad", "abs_val_grad", "entropy"))

164 ms ± 577 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
