In [2]:
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from numpy import random

In [3]:
def update(w, b, x_a, y_a, x_v, y_v, mu = 0.01):
    """Apply one gradient-descent step to ``w`` and ``b``, then score the result.

    The step direction is the mean of the values the module-level ``vdfdwb``
    returns for ``(x_a, y_a)``. The score is the sum of the values the
    module-level ``vf`` returns for ``(x_v, y_v)``, evaluated with the
    already-stepped parameters.

    Args:
        w, b: current parameters.
        x_a, y_a: inputs/targets used to compute the gradient.
        x_v, y_v: inputs/targets used to compute the returned score.
        mu: factor multiplying the averaged gradient before it is subtracted.

    Returns:
        ``(w, b, J)``: the stepped parameters and the summed score.
    """
    grads_w, grads_b = vdfdwb(w, b, x_a, y_a)
    step_w = mu * jnp.mean(grads_w)
    step_b = mu * jnp.mean(grads_b)
    # Both steps are computed before either parameter is touched.
    w -= step_w
    b -= step_b
    return w, b, jnp.sum(vf(w, b, x_v, y_v))


# jit-wrapped variant of the same step function.
update_jit = jit(update)

def get_fn():
    """Build the vectorised gradient and loss functions for the model ``w * x + b``.

    Returns:
        ``(vdfdwb, vf)``. ``vf(w, b, x, y)`` evaluates the summed squared
        difference between ``y`` and ``w * x + b`` once per entry along axis 0
        of ``x`` and ``y`` (``w`` and ``b`` are passed through unmapped).
        ``vdfdwb`` takes the same arguments and returns the gradients of that
        loss with respect to ``w`` and ``b``, again one per entry.
    """
    def pred(w, b, x):
        return w * x + b

    def f(w, b, x, y):
        residual = y - pred(w, b, x)
        return jnp.sum(jnp.power(residual, 2))

    # w and b are shared across the batch; x and y are mapped over axis 0.
    batch_axes = (None, None, 0, 0)
    dfdwb = grad(f, argnums=(0, 1))
    return vmap(dfdwb, batch_axes), vmap(f, batch_axes)


vdfdwb, vf = get_fn()

def train(w, b, x_a, y_a, x_v, y_v, N_epoch=1000, update=update_jit):
    """Run ``N_epoch`` successive update steps and record each step's score.

    Args:
        w, b: starting parameters.
        x_a, y_a: data forwarded to ``update`` as its gradient inputs.
        x_v, y_v: data forwarded to ``update`` as its scoring inputs.
        N_epoch: number of times ``update`` is called.
        update: step function called as ``update(w, b, x_a, y_a, x_v, y_v)``
            and expected to return ``(w, b, J)``. No step size is passed, so
            the step function's own default applies.

    Returns:
        ``(w, b, J_list)``: the final parameters and the list of ``J`` values,
        one per epoch in call order.
    """
    history = []
    for _ in range(N_epoch):
        w, b, score = update(w, b, x_a, y_a, x_v, y_v)
        history += [score]
    return w, b, history

In [4]:
def model(x):
    """Reference line ``y = 2x + 1`` used to generate the targets."""
    return x * 2.0 + 1.0


# Inputs for the two splits, then their targets computed from `model`.
x_train, x_test = jnp.array([0, 1]), jnp.array([2,3,4])
y_train, y_test = model(x_train), model(x_test)



In [5]:
# Seed NumPy's global generator so the random starting point, and therefore
# the fitted values printed below, are identical on every fresh run.
random.seed(0)

# Start both parameters from a small random value in [-0.01, 0.01).
w = random.uniform(low=-0.01, high=0.01, size=(1,))
b = random.uniform(low=-0.01, high=0.01, size=(1,))

N_epoch = 1000
# Fit on the training split with the jit-compiled step; J_list holds the
# per-epoch score computed on the test split.
w, b, J_list = train(w, b, x_train, y_train, x_test, y_test, N_epoch = N_epoch, update=update_jit)
print(w, b)

[1.9783688] [1.0133686]


In [6]:
# Time 1000 epochs using the plain (un-jitted) `update` step.
# The run resumes from the current w, b and overwrites w, b and J_list.
%time w, b, J_list = train(w, b, x_train, y_train, x_test, y_test, N_epoch = 1000, update=update) 

CPU times: user 6.48 s, sys: 0 ns, total: 6.48 s
Wall time: 6.46 s


In [7]:
# Time the same 1000 epochs using the jit-wrapped `update_jit` step.
# The run resumes from the current w, b and overwrites w, b and J_list.
# NOTE(review): update_jit was already called by an earlier train(...) cell,
# so this presumably excludes compilation time — confirm.
# NOTE(review): JAX dispatches work asynchronously; consider blocking on the
# result (block_until_ready) before trusting this number — confirm.
%time w, b, J_list = train(w, b, x_train, y_train, x_test, y_test, N_epoch = 1000, update=update_jit) 

CPU times: user 14.5 ms, sys: 0 ns, total: 14.5 ms
Wall time: 11.6 ms


In [8]:
# Ratio of two hard-coded timings, presumably un-jitted ms / jitted ms.
# NOTE(review): 5830 and 11.8 do not match the wall times printed by the two
# %time cells in this file (6.46 s and 11.6 ms) — presumably taken from an
# earlier run; confirm or recompute from the measured values.
5830/11.8

494.06779661016947