In [1]:
import time
from jax import grad, jit, lax
import jax.config as config; config.update('jax_platform_name', 'cpu')

In [2]:
def dual_averaging(t0=10, kappa=0.75, gamma=0.05):
    """Build a dual-averaging optimizer as an ``(init_fn, update_fn)`` pair.

    Args:
        t0: offset added to the step count in the gradient-averaging
            weight ``1 / (t + t0)``.
        kappa: exponent of the primal averaging weight ``t ** (-kappa)``.
        gamma: the primal iterate is moved ``sqrt(t) / gamma`` times the
            averaged gradient away from the proximal centre.

    Returns:
        ``(init_fn, update_fn)``. The state is the plain tuple
        ``(x_t, x_avg, g_avg, t, prox_center)`` so it can be carried through
        ``lax.fori_loop`` as-is.
    """
    def init_fn(prox_center=0.):
        # (iterate, averaged iterate, averaged gradient, step count, centre)
        return 0., 0., 0., 0, prox_center

    def update_fn(g, state):
        _, primal_avg, grad_avg, step, center = state
        step = step + 1
        denom = step + t0
        # running weighted mean of the gradients seen so far
        grad_avg = (1 - 1 / denom) * grad_avg + g / denom
        # new iterate: centre moved against the averaged gradient
        x_new = center - (step ** 0.5) / gamma * grad_avg
        # polynomially decaying weight for folding x_new into the average
        w = step ** (-kappa)
        primal_avg = (1 - w) * primal_avg + w * x_new
        return x_new, primal_avg, grad_avg, step, center

    return init_fn, update_fn


def optimize(f):
    """Run 1000 dual-averaging steps on scalar ``f`` inside ``lax.fori_loop``.

    Baseline variant: the loop body is a plain (un-jitted) function that
    ``fori_loop`` traces. Returns the averaged iterate ``x_avg``.
    """
    init, update = dual_averaging(gamma=0.5)
    grad_f = grad(f)

    def step(_, carry):
        # carry[0] is the current iterate x_t
        return update(grad_f(carry[0]), carry)

    final_state = lax.fori_loop(0, 1000, step, init())
    return final_state[1]


def optimize_v1(f):
    """Like ``optimize``, but the ``fori_loop`` body is also wrapped in ``jit``.

    ``lax.fori_loop`` already traces its body, so the inner ``jit`` adds no
    semantics; this variant is kept to compare its timing against the others.
    Returns the averaged iterate ``x_avg``.
    """
    init, update = dual_averaging(gamma=0.5)
    grad_f = grad(f)

    @jit
    def step(_, carry):
        return update(grad_f(carry[0]), carry)

    final_state = lax.fori_loop(0, 1000, step, init())
    return final_state[1]


def optimize_v2(f):
    """Dual averaging driven by a Python ``for`` loop over a jitted step.

    Called directly, each of the 1000 iterations is a separate dispatch of
    the compiled ``step``; under an outer ``jit`` the Python loop is unrolled
    during tracing instead. Returns the averaged iterate ``x_avg``.
    """
    init, update = dual_averaging(gamma=0.5)
    grad_f = grad(f)

    @jit
    def step(i, carry):
        return update(grad_f(carry[0]), carry)

    state = init()
    for i in range(1000):
        state = step(i, state)
    return state[1]


def optimize_v3(f):
    """Dual averaging where ``lax.fori_loop`` itself is wrapped in ``jit``.

    The loop body is passed as static argument 2. It is a fresh closure on
    every call, so the compiled loop is most likely not reused across calls
    (NOTE(review): confirm against jit cache behaviour). Returns ``x_avg``.
    """
    init, update = dual_averaging(gamma=0.5)
    grad_f = grad(f)

    def step(_, carry):
        return update(grad_f(carry[0]), carry)

    jitted_loop = jit(lax.fori_loop, static_argnums=(2,))
    final_state = jitted_loop(0, 1000, step, init())
    return final_state[1]


def f(x):
    """Toy objective ``(x + 1) ** 2``, minimised at ``x = -1``."""
    return (x + 1) ** 2

In [3]:
def _timed(label, func, *args):
    """Call ``func(*args)`` and print its result, then ``label`` and the elapsed seconds.

    JAX dispatches work asynchronously, so the result is synchronised with
    ``block_until_ready()`` before the clock stops. Printing happens after
    the measurement, so its cost is not counted.
    """
    tic = time.perf_counter()
    result = func(*args)
    result.block_until_ready()
    elapsed = time.perf_counter() - tic
    print(result)
    print(label, elapsed)
    return result


# (header, optimizer) pairs; the baseline is printed without a header.
benchmarks = [
    (None, optimize),
    ("===v1===", optimize_v1),
    ("===v2===", optimize_v2),
    ("===v3===", optimize_v3),
]

for header, optimize_fn in benchmarks:
    if header is not None:
        print(header)
    # Call the optimizer as-is, without an outer jit.
    _timed("time before compiling:", optimize_fn, f)
    # Wrap it in an outer jit; `f` is static, so it is hashed rather than traced.
    fn = jit(optimize_fn, static_argnums=(0,))
    # The first call includes tracing and compilation ...
    _timed("time with compiling:", fn, f)
    # ... the second call reuses the cached executable.
    _timed("time after compiling:", fn, f)

-0.99569756
time before compiling: 1.3949103355407715
-0.99569756
time with compiling: 1.3248636722564697
-0.99569756
time after compiling: 0.0004830360412597656
===v1===
-0.99569726
time before compiling: 15.286142110824585
-0.99569726
time with compiling: 15.5829918384552
-0.99569726
time after compiling: 0.00047397613525390625
===v2===
-0.99569726
time before compiling: 0.29132914543151855
-0.99569726
time with compiling: 0.2943596839904785
-0.99569726
time after compiling: 0.00048613548278808594
===v3===
-0.99569726
time before compiling: 0.03526735305786133
-0.99569726
time with compiling: 0.039703369140625
-0.99569726
time after compiling: 0.0007233619689941406
