In [26]:
import jax
from jax import grad, jit
import jax.numpy as jnp
from functools import partial

In [5]:
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661194 0.10499357]


In [6]:
def first_finite_differences(f, x, eps=1E-3):
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))

[0.24998187 0.1965761  0.10502338]


In [16]:
multiple_jit = jit(grad(jit(grad(jit(grad(sum_logistic))))))
one_jit = jit(grad(grad(grad(sum_logistic))))

In [25]:
# same result
print(multiple_jit(1.0))
print(one_jit(1.0))

-0.0353256
-0.0353256


`multiple_jit`, `one_jit`의 컴파일 시간을 측정해보면 `jit`이 중첩된 함수가 컴파일 시간이 더 긴 것으로 추정된다.

In [28]:
%%timeit

jit(grad(jit(grad(jit(grad(sum_logistic))))))(1.0)

53.5 ms ± 437 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [29]:
%%timeit

jit(grad(grad(grad(sum_logistic))))(1.0)

38.1 ms ± 567 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


compile이 끝난 후에는 `make_jaxpr`에는 표현에 차이가 있지만 속도에는 큰 차이가 없다.

In [19]:
%%timeit

multiple_jit(1.0)

35.7 µs ± 403 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [20]:
%%timeit

one_jit(1.0)

35.4 µs ± 554 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [21]:
jax.make_jaxpr(multiple_jit)(1.0)

{ lambda ; a:f32[]. let
    b:f32[] = pjit[
      name=sum_logistic
      jaxpr={ lambda ; c:f32[]. let
          _:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[] j:f32[] k:f32[] = pjit[
            name=sum_logistic
            jaxpr={ lambda ; l:f32[]. let
                _:f32[] m:f32[] n:f32[] o:f32[] p:f32[] q:f32[] = pjit[
                  name=sum_logistic
                  jaxpr={ lambda ; r:f32[]. let
                      s:f32[] = neg r
                      t:f32[] = exp s
                      u:f32[] = add 1.0 t
                      v:f32[] = div 1.0 u
                      _:f32[] = integer_pow[y=-2] u
                      _:f32[] = integer_pow[y=-2] u
                      w:f32[] = integer_pow[y=-3] u
                      _:f32[] = mul -2.0 w
                      x:f32[] = integer_pow[y=-2] u
                      y:f32[] = integer_pow[y=-3] u
                      z:f32[] = mul -2.0 y
                      ba:f32[] = integer_pow[y=-3] u
                   

In [22]:
jax.make_jaxpr(one_jit)(1.0)

{ lambda ; a:f32[]. let
    b:f32[] = pjit[
      name=sum_logistic
      jaxpr={ lambda ; c:f32[]. let
          d:f32[] = neg c
          e:f32[] = exp d
          f:f32[] = add 1.0 e
          g:f32[] = div 1.0 f
          _:f32[] = integer_pow[y=-2] f
          _:f32[] = integer_pow[y=-2] f
          h:f32[] = integer_pow[y=-3] f
          _:f32[] = mul -2.0 h
          i:f32[] = integer_pow[y=-2] f
          j:f32[] = integer_pow[y=-3] f
          k:f32[] = mul -2.0 j
          l:f32[] = integer_pow[y=-3] f
          m:f32[] = integer_pow[y=-4] f
          n:f32[] = mul -3.0 m
          o:f32[] = mul -2.0 l
          p:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
          _:f32[] = reduce_sum[axes=()] p
          q:f32[] = mul 1.0 i
          r:f32[] = mul q 1.0
          s:f32[] = neg r
          t:f32[] = mul s e
          _:f32[] = neg t
          u:f32[] = neg 1.0
          v:f32[] = mul s u
          w:f32[] = mul u e
          x:f32[] = neg w
          