# Julia vs. Jax vs. TensorFlow

In [1]:
from jax import jit, grad

In [2]:
import tensorflow as tf

In [3]:
%load_ext julia.magic

Initializing Julia interpreter. This may take some time...


In [4]:
%%julia
using Zygote, BenchmarkTools

## Simple function

### Jax

In [3]:
@jit
def f(x):
    for i in range(1,100):
        x = x / (x + 1)
    return x
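Each update maps $1/x$ to $1 + 1/x$, so the loop has a closed form: after $n$ steps $x_n = x_0 / (1 + n x_0)$. `range(1, 100)` gives $n = 99$. The plain-Python reference below (no JAX) can be used to check the framework outputs in this notebook. `f_loop` and `f_closed` are illustrative names only:

```python
def f_loop(x, n=99):
    # Same update as the JAX cell; range(1, 100) runs 99 times
    for _ in range(n):
        x = x / (x + 1)
    return x

def f_closed(x, n=99):
    # 1/x grows by exactly 1 per step, so x_n = x0 / (1 + n * x0)
    return x / (1 + n * x)

print(f_loop(1.0), f_closed(1.0))
```

With $x_0 = 1$, both agree with $1/100 = 0.01$ up to floating-point rounding.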

In [4]:
# First call: JAX traces and compiles the function, so this timing includes compilation
%time f(1.)

CPU times: user 94.8 ms, sys: 7.41 ms, total: 102 ms
Wall time: 113 ms

DeviceArray(0.01, dtype=float32)
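Part of the cost of that first call is likely tracing. `jit` runs the Python body once with abstract values, and a Python `for` over `range` is unrolled during tracing instead of becoming a loop primitive. The sketch below assumes a recent JAX release. It uses `jax.make_jaxpr` on an un-jitted copy of `f` (the name `f_py` is illustrative) to count the division operations in the traced program:

```python
import jax

def f_py(x):
    # Un-jitted copy of f; make_jaxpr performs its own tracing
    for i in range(1, 100):
        x = x / (x + 1)
    return x

closed_jaxpr = jax.make_jaxpr(f_py)(1.0)
# Each Python iteration leaves its own add and div equations in the traced program
prims = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(len(prims), prims.count("div"))
```

Rewriting the loop with `jax.lax.fori_loop` should keep the traced program small. Whether that would change the timings shown here has not been measured.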

In [5]:
%timeit f(2.)

140 µs ± 7.41 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
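JAX dispatches work asynchronously, so `%timeit f(2.)` may partly measure dispatch rather than the computation itself. The sketch below times both variants. It assumes a recent JAX release and uses the array method `block_until_ready()` together with the standard-library `timeit`. For a scalar on CPU the gap may well be small:

```python
import timeit
import jax

@jax.jit
def f(x):
    for i in range(1, 100):
        x = x / (x + 1)
    return x

f(1.0).block_until_ready()  # warm-up: trace and compile once

runs = 1000
# Returns as soon as the computation has been dispatched
t_dispatch = min(timeit.repeat(lambda: f(2.0), number=runs, repeat=5)) / runs
# Waits until the result is actually available
t_total = min(timeit.repeat(lambda: f(2.0).block_until_ready(), number=runs, repeat=5)) / runs
print(f"dispatch only: {t_dispatch * 1e6:.1f} µs, with result: {t_total * 1e6:.1f} µs")
```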


In [6]:
# Likewise, the gradient function is traced and compiled on its first call
g = jit(grad(f))

In [7]:
%time g(1.)

CPU times: user 392 ms, sys: 0 ns, total: 392 ms
Wall time: 382 ms


DeviceArray(0.0001, dtype=float32)
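The gradient can be checked by hand. Write $x_k$ for the value after $k$ iterations. The update gives $1/x_{k+1} = 1 + 1/x_k$, so with $n = 99$ iterations (from `range(1, 100)`):

$$
f(x) = \frac{x}{1 + n x}, \qquad
f'(x) = \frac{(1 + n x) - n x}{(1 + n x)^2} = \frac{1}{(1 + n x)^2}.
$$

This gives $f(1) = 1/100 = 0.01$ and $f'(1) = 1/100^2 = 10^{-4}$, which matches the `DeviceArray` outputs above. It also matches TensorFlow's `9.9999976e-05` up to float32 rounding.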

In [8]:
%timeit g(2.)

148 µs ± 4.48 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


### Julia

In [8]:
%%julia

function f(x)
    for i in 1:100  # 100 iterations; Python's range(1, 100) runs only 99
        x = x / (x + 1)
    end
    return x
end;

In [9]:
%%julia
@btime f(1.);

  332.023 ns (0 allocations: 0 bytes)


In [10]:
%%julia
@btime gradient(f, 1.);

  51.269 μs (1564 allocations: 51.27 KiB)
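As an independent cross-check of the gradient values (about 1e-4 at x = 1 for the 99-iteration Python loop), here is a tiny forward-mode sketch using dual numbers in plain Python. This is a different technique from the reverse-mode AD behind `grad`, Zygote's `gradient` and `GradientTape`. `Dual` and `derivative` are illustrative names, not part of any library:

```python
from dataclasses import dataclass

@dataclass
class Dual:
    val: float  # primal value
    der: float  # derivative with respect to the input

    def __add__(self, other):
        # Constants (like the literal 1) have zero derivative
        if isinstance(other, Dual):
            return Dual(self.val + other.val, self.der + other.der)
        return Dual(self.val + other, self.der)

    def __truediv__(self, other):
        # Quotient rule: (u / v)' = (u' v - u v') / v^2; only Dual / Dual is needed here
        return Dual(self.val / other.val,
                    (self.der * other.val - self.val * other.der) / other.val ** 2)

def f(x):
    # Same loop as the Python cells in this notebook
    for i in range(1, 100):
        x = x / (x + 1)
    return x

def derivative(fn, x):
    # Seed the input's derivative with 1 and read off the output's derivative
    return fn(Dual(x, 1.0)).der

print(derivative(f, 1.0))
```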


### TensorFlow

This comparison is probably not entirely fair. For a function this small you are mostly measuring overhead. Still, it gives a sense of the cost paid on every Python <-> TF round trip. That cost does not exist in Julia and is probably much smaller for a `@jit`-compiled JAX function. A TF expert could likely improve this code.
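One rough way to put a number on that per-call overhead is to time a `tf.function` that does essentially no work. Treating its call time as a floor for the Python <-> TF dispatch cost is an assumption of this sketch, and `identity` is an illustrative name:

```python
import timeit
import tensorflow as tf

@tf.function
def identity(x):
    # Does (almost) no work, so a call costs roughly the fixed dispatch overhead
    return tf.identity(x)

x = tf.convert_to_tensor(1.0)
identity(x)  # warm-up: the first call traces the function

runs = 1000
per_call = min(timeit.repeat(lambda: identity(x), number=runs, repeat=5)) / runs
print(f"empty tf.function call: {per_call * 1e6:.1f} µs")
```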

In [9]:
@tf.function
def f(x):
    for i in range(1,100):
        x = x / (x + 1)
    return x

In [10]:
@tf.function
def gradf(x):
    with tf.GradientTape() as t:
        t.watch(x)
        out = f(x)
    return t.gradient(out, x)

In [11]:
x = tf.convert_to_tensor(1.0)
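Passing a tensor rather than a Python float matters here. `tf.function` keys its trace cache on Python scalars by value, so every new float value triggers a fresh trace, while tensors of the same dtype and shape reuse one. The sketch below counts traces with a Python side effect, which only runs while TF traces the body. `step` and `traces` are illustrative names. AutoGraph is disabled because the body has no control flow:

```python
import tensorflow as tf

traces = []  # appended to only while TF is tracing the Python body

@tf.function(autograph=False)
def step(x):
    traces.append(None)  # Python side effect: runs at trace time, not on every call
    return x / (x + 1)

step(tf.convert_to_tensor(1.0))
step(tf.convert_to_tensor(2.0))  # same dtype and shape: reuses the first trace
traces_after_tensors = len(traces)

step(1.0)
step(2.0)  # distinct Python floats: each one is traced separately
traces_after_floats = len(traces)
print(traces_after_tensors, traces_after_floats)
```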

In [12]:
# First call traces the Python function into a graph, so it is slow
%time f(x)

CPU times: user 882 ms, sys: 0 ns, total: 882 ms
Wall time: 925 ms


0.010000002
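The slow first call is consistent with tracing. AutoGraph leaves a Python `for` over `range` as an ordinary Python loop, so the loop is unrolled into 99 division ops in the graph. The gradient graph is larger still, which may help explain the 4 s first gradient call below. This sketch inspects the traced graph via `get_concrete_function` with a scalar `tf.TensorSpec`:

```python
import tensorflow as tf

@tf.function
def f(x):
    for i in range(1, 100):  # Python range: unrolled at trace time
        x = x / (x + 1)
    return x

# Trace once for a scalar float32 input and look at the resulting graph
concrete = f.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32))
op_types = [op.type for op in concrete.graph.get_operations()]
print(len(op_types), op_types.count("RealDiv"))
```

Iterating over `tf.range(1, 100)` instead should make AutoGraph emit a single `tf.while_loop`. Its effect on the timings shown here is untested.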

In [14]:
%timeit f(x)

276 µs ± 30.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [15]:
# Likewise for the gradient: the first call pays the tracing cost
%time gradf(x)

CPU times: user 4.16 s, sys: 0 ns, total: 4.16 s
Wall time: 4.18 s


<tf.Tensor: shape=(), dtype=float32, numpy=9.9999976e-05>

In [16]:
%timeit gradf(x)

1.08 ms ± 32.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
