# Julia vs. Jax vs. TensorFlow

In [1]:
import numpy as np
import jax.numpy as jnp

In [2]:
from jax import jit, grad, jacfwd, jacrev

In [3]:
import tensorflow as tf

In [4]:
%load_ext julia.magic

Initializing Julia interpreter. This may take some time...


In [5]:
%%julia
using Zygote, ForwardDiff, BenchmarkTools, LinearAlgebra

## Simple function

### Jax

In [6]:
@jit
def f(x):
    for i in range(1,100):
        x = x / (x + 1)
    return x

In [7]:
# First evaluation where JAX is compiling the code
%time f(1.)

CPU times: user 95.1 ms, sys: 8.67 ms, total: 104 ms
Wall time: 117 ms




DeviceArray(0.01, dtype=float32)

In [8]:
%timeit f(2.)

138 µs ± 6.06 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [9]:
# Same thing for the gradients we need to compile them
g = jit(grad(f))

In [10]:
%time g(1.)

CPU times: user 540 ms, sys: 14 ms, total: 554 ms
Wall time: 544 ms


DeviceArray(0.0001, dtype=float32)

In [11]:
%timeit g(2.)

149 µs ± 4.33 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


### Julia

In [12]:
%%julia
function f(x)
    for i in 1:100
        x = x / (x + 1)
    end
    return x
end;

In [13]:
%%julia
@btime f(1.);

  332.207 ns (0 allocations: 0 bytes)


In [14]:
%%julia
@btime gradient(f, 1.);

  57.808 μs (1564 allocations: 51.27 KiB)


### TensorFlow

This is surely not that fair since you're mostly measuring overhead, but at least it gives you a sense of the overhead everytime you go from Python <-> TF, which is non-existant in Julia, and probably much less in a `@jit`'ed Jax function. A TF expert should probably improve this code. 

In [15]:
@tf.function
def f(x):
    for i in range(1,100):
        x = x / (x + 1)
    return x

In [16]:
@tf.function
def gradf(x):
    with tf.GradientTape() as t:
        t.watch(x)
        out = f(x)
    return t.gradient(out, x)

In [17]:
x = tf.convert_to_tensor(1.0)

In [18]:
# First evaluation might be slow
%time f(x)

CPU times: user 802 ms, sys: 188 ms, total: 990 ms
Wall time: 1.44 s


<tf.Tensor: shape=(), dtype=float32, numpy=0.010000002>

In [19]:
%timeit f(x)

1.15 ms ± 51.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [20]:
# Same thing for gradients
%time gradf(x)

CPU times: user 1.3 s, sys: 15.5 ms, total: 1.32 s
Wall time: 1.33 s


<tf.Tensor: shape=(), dtype=float32, numpy=9.9999976e-05>

In [21]:
%timeit gradf(x)

3.56 ms ± 62.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Jamie's HZPT model

For this model, we are computing its Jacobian matrix with shape (50, 3).

### Analytic

In [24]:
r = np.logspace(-1,2)

def hzpt_g_analytic(r,A0,R,R1h):
    A_grad = - np.exp(-r/R) / (4*np.pi*r*R**2) *(1 - (R/R1h)**2 * np.exp(-(R+R1h)*r/(R*R1h)))
    R_grad = A0*(-(r- 2*R)/R**2 + 2*np.exp(-r*(1/R + 1/R1h))*r /R1h**2) * np.exp(-r/R) / (4*np.pi*r*R**2)
    R1h_grad = A0* np.exp(-r*(1/R + 1/R1h))* R**2 *(r-2*R1h)/R1h**4 * np.exp(-r/R) / (4*np.pi*r*R**2)
    return np.array([A_grad,R_grad,R1h_grad])

%timeit hzpt_g_analytic(r, 750.,26.,2.)

50.9 µs ± 426 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


### Jax

In [25]:
r = jnp.logspace(-1,2)

@jit
def hzpt_f_jax(r,A0,R,R1h):
    F2 = jnp.exp(-r/R) / (4*jnp.pi*r*R**2)
    return -A0 * F2 * (1 - (R/R1h)**2 * jnp.exp(-(R+R1h)*r/(R*R1h)))



In [26]:
@jit
def hzpt_jacfwd_jax(r, A0, R, R1h):
    return jacfwd(hzpt_f_jax, argnums=(1, 2, 3))(r, A0, R, R1h)

hzpt_jacfwd_jax(r, 750., 26., 2.)
%timeit hzpt_jacfwd_jax(r, 750., 26., 2.)

280 µs ± 12.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [27]:
r.shape

(50,)

In [28]:
hzpt_jacfwd_jax(r, 750., 26., 2.)[1].shape

(50,)

In [29]:
@jit
def hzpt_jacrev_jax(r, A0, R, R1h):
    return jacrev(hzpt_f_jax, argnums=(1, 2, 3))(r, A0, R, R1h)

hzpt_jacrev_jax(r, 750., 26., 2.)
%timeit hzpt_jacrev_jax(r, 750., 26., 2.)

260 µs ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## Julia

In [30]:
%%julia
const r = 10 .^ range(-1, 2, length=50);

In [31]:
%%julia
function hzpt_f_jax(r, A0, R, R1h)
    F2 = exp(-r/R) / (4π*r*R^2)
    return -A0 * F2 * (1 - (R/R1h)^2 * exp(-(R+R1h)*r/(R*R1h)))
end;

In [33]:
%%julia
@btime hzpt_f_jax.(r, 750., 26., 2.);

  1.136 μs (1 allocation: 496 bytes)


In [34]:
%%julia
@btime ForwardDiff.jacobian(((A0, R, R1h),) -> hzpt_f_jax.(r, A0, R, R1h), [750., 26., 2.]);

  2.911 μs (5 allocations: 3.47 KiB)


In [35]:
%%julia
jac = ForwardDiff.jacobian(((A0, R, R1h),) -> hzpt_f_jax.(r, A0, R, R1h), [750., 26., 2.])
size(jac)

(50, 3)

We can also try reverse-mode, where note I added the `norm` since we need the gradient of a scalar function. With only 3 arguments, the overhead of reverse-mode makes this significantly worse than forward-mode, although still faster than either of the Jax gradients. 

In [36]:
%%julia
@btime Zygote.gradient((A0, R, R1h) -> norm(hzpt_f_jax.(r, A0, R, R1h)), 750., 26., 2.)

  75.158 μs (1122 allocations: 68.70 KiB)


(0.3614060423837588, 0.26535983236863103, -261.8240445770739)