# Julia vs. Jax vs. TensorFlow

In [1]:
from jax import jit, grad

In [2]:
import tensorflow as tf

In [3]:
%load_ext julia.magic

Initializing Julia interpreter. This may take some time...


In [4]:
%%julia
using Zygote, BenchmarkTools

## Simple function
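All three implementations below iterate the map `x -> x / (x + 1)`. Composing it `n` times has a closed form, `f_n(x) = x / (1 + n*x)` (by induction: substituting `y = x / (1 + n*x)` into `y / (y + 1)` gives `x / (1 + (n+1)*x)`), so `f_n'(x) = 1 / (1 + n*x)**2`. At `x = 1` and `n = 100` that is `1/101` and `1/101**2 ≈ 9.80e-5`. This gives every framework a known answer to check its gradient against. Here is a minimal pure-Python sketch. The helper names are hypothetical, and the finite difference serves only as an independent cross-check.

```python
def iterate(x, n):
    """Apply x -> x / (x + 1) n times, like f in the cells below."""
    for _ in range(n):
        x = x / (x + 1)
    return x

def closed_form(x, n):
    """Closed form of n iterations: x / (1 + n*x)."""
    return x / (1 + n * x)

def closed_form_grad(x, n):
    """Derivative of the closed form: 1 / (1 + n*x)**2."""
    return 1 / (1 + n * x) ** 2

def finite_diff(x, n, h=1e-5):
    """Central finite difference of iterate(), as an independent check."""
    return (iterate(x + h, n) - iterate(x - h, n)) / (2 * h)

n = 100
print(iterate(1.0, n), closed_form(1.0, n))
print(finite_diff(1.0, n), closed_form_grad(1.0, n))
```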

### Jax

In [5]:
@jit
def f(x):
    for _ in range(100):  # 100 iterations, matching Julia's 1:100
        x = x / (x + 1)
    return x
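Because `f` uses a plain Python `for` loop, `jit` unrolls it at trace time into a straight-line program with one add and one divide per iteration. The compiled program, and the compile time, therefore grow with the loop length. The sketch below writes the same function with `jax.lax.fori_loop`, which keeps the loop rolled. With static Python-int bounds, `fori_loop` is implemented via `scan` and stays reverse-mode differentiable. The names `f_fori`/`df_fori` are hypothetical, and 100 iterations is an assumption chosen to match the Julia version. Whether this variant is faster here has not been measured.

```python
import jax
from jax import lax

@jax.jit
def f_fori(x):
    # Body receives (iteration index, carry); the index is unused.
    return lax.fori_loop(0, 100, lambda i, carry: carry / (carry + 1), x)

# Static integer bounds, so reverse-mode grad is supported.
df_fori = jax.jit(jax.grad(f_fori))

print(f_fori(1.0), df_fori(1.0))
```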

In [6]:
%timeit f(1.)



137 µs ± 7.57 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
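JAX dispatches work asynchronously, so `%timeit f(1.)` can return before the computation has actually finished. The JAX documentation recommends calling `.block_until_ready()` when benchmarking. JAX also computes in 32-bit floats by default, while Julia's `1.` is a `Float64`. `jax.config.update("jax_enable_x64", True)` switches JAX to 64-bit if matching precision matters. The sketch below times a blocking call with the standard-library `timeit` module. The name `f_jax` and the call count are assumptions, and no timing result is implied.

```python
import timeit
import jax

@jax.jit
def f_jax(x):
    for _ in range(100):  # unrolled at trace time
        x = x / (x + 1)
    return x

f_jax(1.0).block_until_ready()  # warm-up: compile once, outside the timing

# Block on each result so the measurement covers execution, not just dispatch.
n_calls = 1000
secs = timeit.timeit(lambda: f_jax(1.0).block_until_ready(), number=n_calls)
print(f"{secs / n_calls * 1e6:.1f} µs per call")
```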


In [7]:
%timeit grad(f)(1.)

19.9 ms ± 863 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
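`grad(f)(1.)` builds a fresh gradient function on every call, and the differentiation itself runs outside `jit`. The 19.9 ms above is therefore most likely dominated by tracing overhead rather than execution. A fairer sketch compiles the gradient once with `jit(grad(f))`, warms it up, and reuses it. `f_jax` and `df_jax` are hypothetical names, and no timing result is implied.

```python
import timeit
import jax

@jax.jit
def f_jax(x):
    for _ in range(100):
        x = x / (x + 1)
    return x

# Build and compile the gradient function once, outside the timed region.
df_jax = jax.jit(jax.grad(f_jax))
df_jax(1.0).block_until_ready()  # warm-up compile

n_calls = 1000
secs = timeit.timeit(lambda: df_jax(1.0).block_until_ready(), number=n_calls)
print(f"{secs / n_calls * 1e6:.1f} µs per gradient call")
```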


### Julia

In [8]:
%%julia

function f(x)
    for i in 1:100
        x = x / (x + 1)
    end
    return x
end;

In [9]:
%%julia
@btime f(1.);

  332.023 ns (0 allocations: 0 bytes)


In [10]:
%%julia
@btime gradient(f, 1.);

  51.269 μs (1564 allocations: 51.27 KiB)


### TensorFlow

This comparison is not entirely fair, since for such a small function you are mostly measuring overhead. It does, however, give a sense of the cost of every Python <-> TF round trip. That cost is nonexistent in Julia and probably much smaller in a `@jit`'ed Jax function. A TF expert could probably improve this code.
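One standard way to cut that per-operation overhead is to trace the function into a graph with `tf.function`. The Python loop is then unrolled once at trace time, instead of each eager op being dispatched from Python on every call. The sketch below assumes TF 2.x eager mode. `f_graph`, `gradf_graph` and `x0` are hypothetical names, and whether this closes the gap with Jax or Julia has not been measured here.

```python
import tensorflow as tf

@tf.function
def f_graph(x):
    for _ in range(100):  # Python loop: unrolled into the graph at trace time
        x = x / (x + 1)
    return x

@tf.function
def gradf_graph(x):
    # The tape records the ops of f_graph inside this traced function.
    with tf.GradientTape() as t:
        t.watch(x)
        out = f_graph(x)
    return t.gradient(out, x)

x0 = tf.constant(1.0)
print(f_graph(x0).numpy(), gradf_graph(x0).numpy())
```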

In [11]:
def f(x):
    for _ in range(100):  # 100 iterations, matching Julia's 1:100
        x = x / (x + 1)
    return x

In [12]:
def gradf(x):
    with tf.GradientTape() as t:
        t.watch(x)
        out = f(x)
    return t.gradient(out, x)

In [13]:
x = tf.convert_to_tensor(1.0)

In [14]:
%timeit f(x).numpy()

10.3 ms ± 2.93 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
%timeit gradf(x).numpy()

45.3 ms ± 3.58 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
