# Simple Tests
> Some simple tests to check that the fundamental functionality works

The tests include functions with a single input as well as functions with multiple inputs.

In [None]:
from javiche import jaxit
from jax import grad

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
@jaxit()
def plus(A, B):
  """adds two numbers/arrays"""
  return A+B

@jaxit()
def times(A, B):
  """multiplies two numbers/arrays"""
  return A*B

@jaxit()
def square(A):
  """squares number/array"""
  return A**2

In [None]:
assert grad(plus, argnums=[0,1])(2.0,2.0) == (1,1)

In [None]:
assert grad(times, argnums=[0,1])(2.0,3.0) == (3,2)

In [None]:
assert grad(square, argnums=0)(2.0) == 4
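For reference, the expected values in these assertions can be reproduced with plain `jax.grad` on undecorated versions of the same functions. This is a minimal sketch that does not use `jaxit` at all; the `*_ref` names are illustrative. It only shows what the decorated functions are compared against. Passing a sequence as `argnums` makes `jax.grad` return a tuple with one gradient per listed argument.

```python
import jax

# Undecorated counterparts of plus, times and square (no jaxit), for comparison only
def plus_ref(a, b):
    return a + b

def times_ref(a, b):
    return a * b

def square_ref(a):
    return a ** 2

# argnums=(0, 1): differentiate w.r.t. both arguments, the result is a tuple
d_plus = jax.grad(plus_ref, argnums=(0, 1))(2.0, 2.0)
d_times = jax.grad(times_ref, argnums=(0, 1))(2.0, 3.0)
# default argnums=0: a single gradient w.r.t. the first argument
d_square = jax.grad(square_ref)(2.0)

print([float(g) for g in d_plus])   # [1.0, 1.0]
print([float(g) for g in d_times])  # [3.0, 2.0]
print(float(d_square))              # 4.0
```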

In [None]:
import time
import timeit

## Caching
Passing `cache=True` to `jaxit` also enables caching, so that calling a function again with the same input parameters returns the stored result instead of recomputing it.
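The core idea of such a cache can be sketched in plain Python as a dictionary keyed on the tuple of arguments. The decorator `memoize` and the function `slow_cube` are hypothetical, and this is not javiche's implementation; the sketch only illustrates why the inputs must be hashable: the arguments themselves serve as the dictionary key.

```python
import time
from functools import wraps

def memoize(fn):
    """Hypothetical result cache keyed on the positional arguments."""
    cache = {}

    @wraps(fn)
    def wrapper(*args):
        # Looking up `args` hashes it, so unhashable inputs raise TypeError here
        if args not in cache:
            cache[args] = fn(*args)
        return cache[args]

    wrapper.cache = cache  # exposed only so the cache can be inspected
    return wrapper

calls = []

@memoize
def slow_cube(a):
    calls.append(a)
    time.sleep(0.1)  # stand-in for an expensive computation
    return a ** 3

print(slow_cube(3), slow_cube(3))  # 27 27
print(len(calls))                  # 1: the second call was served from the cache
```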

In [None]:
@jaxit(cache=True)
def cubed(A):
  """computes A to the power of three"""
  time.sleep(1)
  return A**3

In [None]:
%timeit -r 1 -n 1 cubed(3)

1.01 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [None]:
%timeit -r 1 -n 1 cubed(3)

342 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
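`%timeit` is an IPython magic. Outside a notebook, a comparable measurement can be taken with the standard-library `timeit` module, where `repeat=1, number=1` mirrors the `-r 1 -n 1` flags used above. The function `expensive` is a hypothetical stand-in that sleeps instead of calling `cubed`, so the sketch runs without javiche.

```python
import time
import timeit

def expensive(a):
    time.sleep(0.2)  # hypothetical stand-in for an uncached, slow call
    return a ** 3

# repeat=1, number=1 corresponds to `%timeit -r 1 -n 1`: one run of one loop
timings = timeit.repeat(lambda: expensive(3), repeat=1, number=1)
print(f"{timings[0]:.2f} s")
```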


At the same time, the cached function can still be differentiated with JAX:

In [None]:
%timeit -r 1 -n 1 grad(cubed)(2.0)

1.55 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [None]:
%timeit -r 1 -n 1 grad(cubed)(3.0)

1.01 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [None]:
assert grad(cubed, argnums=0)(2.0) == 12
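One way a cache and `jax.grad` can coexist is to cache only concrete Python numbers and pass everything else straight through. Inside `jax.grad`, JAX calls the function with tracer objects rather than floats, so those calls skip the cache and remain differentiable. This is a hypothetical sketch (`cache_scalars` and `cube` are illustrative names), not necessarily how `jaxit(cache=True)` works internally.

```python
import jax

def cache_scalars(fn):
    """Hypothetical cache that only memoizes plain Python int/float inputs."""
    cache = {}

    def wrapper(x):
        if isinstance(x, (int, float)):
            if x not in cache:
                cache[x] = fn(x)
            return cache[x]
        # Not a plain number (e.g. a JAX tracer during jax.grad): call through uncached
        return fn(x)

    return wrapper

calls = []

@cache_scalars
def cube(a):
    calls.append(type(a).__name__)
    return a ** 3

value = cube(3)
value_again = cube(3)          # served from the cache, the body of cube is not run again
slope = jax.grad(cube)(2.0)    # d/dA A**3 = 3 * A**2 = 12 at A = 2

print(value, value_again, float(slope))  # 27 27 12.0
print(len(calls))                        # 2: one cached evaluation, one traced evaluation
```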

Note that caching relies on all inputs to the function being hashable.
In addition, JAX arrays and NumPy arrays are supported as inputs.
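Since NumPy arrays are not hashable, a cache needs some other key for them. One possible approach, sketched below with a hypothetical `cache_key` helper, is to identify an array by its dtype, shape and raw bytes; a concrete JAX array could be converted with `np.asarray` first. This is an assumption about how such support might be built, not a description of javiche's internals.

```python
import numpy as np

def cache_key(*args):
    """Hypothetical helper: build a hashable cache key from the call arguments."""
    key = []
    for a in args:
        if isinstance(a, np.ndarray):
            # Arrays are unhashable; identify them by dtype, shape and contents instead
            key.append(("ndarray", a.dtype.str, a.shape, a.tobytes()))
        else:
            hash(a)  # fail early (TypeError) for any other unhashable input
            key.append(a)
    return tuple(key)

x = np.array([1.0, 2.0])
y = np.array([1.0, 2.0])   # different object, same contents
z = np.array([1.0, 3.0])

print(cache_key(x, 2) == cache_key(y, 2))  # True
print(cache_key(x, 2) == cache_key(z, 2))  # False
```

Including the shape and dtype keeps arrays that happen to share the same bytes (for example a reshaped copy) from colliding.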