<a href="https://colab.research.google.com/github/ddrous/neuralhub/blob/main/benchmark_nan_to_num.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install equinox diffrax

In [2]:
import os

# Disable XLA's GPU memory preallocation. This is read when JAX initialises its
# backend, so set it before importing jax to guarantee it takes effect.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp

# Show which devices JAX will run on (the benchmark below expects a GPU).
jax.devices()

[CudaDevice(id=0)]

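## Benchmark: overhead of `jnp.nan_to_num` on a diffrax solution

This section times the same jitted diffrax ODE solve twice: once returning the raw solution, and once passing it through `jnp.nan_to_num`.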

In [5]:
### Benchmarks the cost of applying nan_to_num to our results
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController

# Linear decay dy/dt = -y. With the y0=1 used below, the exact solution is exp(-t).
vector_field = lambda t, y, args: -y
# Alternative (y-independent) vector field, kept for experimentation:
# vector_field = lambda t, y, args: 1/ (t-1)
ts = jnp.linspace(0, 0.1, 100)  # 100 evenly spaced save times on [0, 0.1]
term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=ts)
# Adaptive step-size control with relative/absolute tolerances of 1e-5.
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)


@jax.jit
def benchmark_1(ts):
    """Solve the ODE over [ts[0], ts[-1]] and return the raw saved states.

    Baseline for the benchmark: the solution is returned without any
    post-processing.

    Args:
        ts: save times (assumed 1-D and increasing). The first and last
            entries define the integration interval.

    Returns:
        ``sol.ys`` evaluated at ``ts``. Because ``throw=False``, a failed solve
        (e.g. exceeding ``max_steps``) does not raise. Presumably diffrax fills
        the unreached entries with non-finite values (TODO: confirm).
    """
    # Build SaveAt from the argument rather than the module-level `saveat`, so
    # the save points always match the [t0, t1] interval derived from `ts`.
    sol = diffeqsolve(term, solver, t0=ts[0], t1=ts[-1], dt0=0.1, y0=1,
                      saveat=SaveAt(ts=ts),
                      stepsize_controller=stepsize_controller,
                      max_steps=100, throw=False)
    return sol.ys

@jax.jit
def benchmark_2(ts):
    """Solve the same ODE as the baseline, then replace non-finite values with 0.

    This is identical to the baseline solve, plus a final ``jnp.nan_to_num`` pass.
    It measures the overhead of sanitising the output inside the jitted function.

    Args:
        ts: save times (assumed 1-D and increasing). The first and last
            entries define the integration interval.

    Returns:
        ``sol.ys`` evaluated at ``ts``, with NaN, +inf and -inf all mapped to 0.
    """
    # Build SaveAt from the argument rather than the module-level `saveat`, so
    # the save points always match the [t0, t1] interval derived from `ts`.
    sol = diffeqsolve(term, solver, t0=ts[0], t1=ts[-1], dt0=0.1, y0=1,
                      saveat=SaveAt(ts=ts),
                      stepsize_controller=stepsize_controller,
                      max_steps=100, throw=False)
    # Map every non-finite entry (e.g. from a failed solve with throw=False) to 0.
    return jnp.nan_to_num(sol.ys, nan=0., neginf=0., posinf=0.)


print("======= Clean function ======== ")
# The first call triggers JIT compilation, so %time reports compile + one run.
# block_until_ready() waits for JAX's asynchronous dispatch to finish, so the
# timing includes device execution. %timeit then measures the already-compiled
# function: 30 repeats (-r30) of 40 loops each (-n40).
%time benchmark_1(ts).block_until_ready()
%timeit -r30 -n40 benchmark_1(ts).block_until_ready()

print("\n======= With NaNs to Nums ======== ")
# Same protocol for the variant that post-processes the output with nan_to_num.
%time benchmark_2(ts).block_until_ready()
%timeit -r30 -n40 benchmark_2(ts).block_until_ready()

CPU times: user 799 ms, sys: 30 ms, total: 829 ms
Wall time: 861 ms
3.38 ms ± 74.3 µs per loop (mean ± std. dev. of 30 runs, 40 loops each)

CPU times: user 403 ms, sys: 9 ms, total: 412 ms
Wall time: 441 ms
3.37 ms ± 93.8 µs per loop (mean ± std. dev. of 30 runs, 40 loops each)
