In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from numpy import random

In [6]:
def model(x):
    """Ground-truth linear map used to synthesize targets: y = 2 * x + 1."""
    return x * 2.0 + 1.0


# Training inputs and their exact targets under `model`.
x_train = jnp.array([0, 1])
y_train = model(x_train)

# Held-out inputs. The targets must be computed from x_test; they were
# previously computed from y_train, which gave the wrong values and a shape
# that did not match x_test.
x_test = jnp.array([2, 3, 4])
y_test = model(x_test)

In [7]:
# Seeded generator so a fresh-kernel run reproduces the same starting weights.
SEED = 0
rng = random.default_rng(SEED)

# Small initial values in [-0.01, 0.01) for slope (w) and intercept (b);
# each is a NumPy array of shape (1,).
w = rng.uniform(low=-0.01, high=0.01, size=(1,))
b = rng.uniform(low=-0.01, high=0.01, size=(1,))

In [9]:
def pred(w, b, x):
    """Linear prediction: w * x + b."""
    return w * x + b


def f(w, b, x, y):
    """Sum of squared residuals between targets y and pred(w, b, x)."""
    residual = y - pred(w, b, x)
    return jnp.sum(jnp.power(residual, 2))


# Gradient of f with respect to its first two arguments (w and b).
dfdwb = grad(f, argnums=[0, 1])

# Per-example gradients: w and b are shared (None), x and y are mapped
# along their leading axis.
vdfdwb = vmap(dfdwb, (None, None, 0, 0))

In [10]:
def update(w, b, x_a, y_a, x_v=None, y_v=None, mu=0.01):
    """Take one gradient-descent step and report the loss after the step.

    Args:
        w, b: current slope and intercept.
        x_a, y_a: batch used to compute the gradient (mapped along axis 0).
        x_v, y_v: batch on which the returned loss is evaluated. Both are
            optional; when omitted, the loss is evaluated on (x_a, y_a).
        mu: step size.

    Returns:
        (J, w, b): loss of the updated parameters on the evaluation batch,
        followed by the updated slope and intercept.

    Raises:
        ValueError: if only one of x_v / y_v is given.
    """
    # These checks only involve None-ness, so they are resolved while tracing
    # and work unchanged under jit.
    if (x_v is None) != (y_v is None):
        raise ValueError("x_v and y_v must be provided together")
    if x_v is None:
        x_v, y_v = x_a, y_a

    # Per-example gradients, averaged over the batch.
    dw_a, db_a = vdfdwb(w, b, x_a, y_a)
    dw, db = jnp.mean(dw_a), jnp.mean(db_a)

    # Rebind instead of using `-=` so that a mutable NumPy array passed in by
    # the caller is not modified in place.
    w = w - mu * dw
    b = b - mu * db

    # Loss of the updated parameters, summed over the evaluation batch.
    vf = vmap(f, (None, None, 0, 0))
    J = jnp.sum(vf(w, b, x_v, y_v))
    return J, w, b


update_jit = jit(update)

In [4]:
def train_slow(w, b, x_a, y_a, x_v=None, y_v=None, N_epoch=1000):
    """Train with the un-jitted `update` for N_epoch steps.

    Args:
        w, b: initial slope and intercept.
        x_a, y_a: training batch passed to every update step.
        x_v, y_v: batch on which the per-epoch loss is evaluated. Both are
            optional; when omitted, the training batch is used.
        N_epoch: number of update steps.

    Returns:
        (w, b, J_list): final parameters and the list of per-epoch losses.

    Raises:
        ValueError: if only one of x_v / y_v is given.
    """
    if (x_v is None) != (y_v is None):
        raise ValueError("x_v and y_v must be provided together")
    if x_v is None:
        x_v, y_v = x_a, y_a

    J_list = []
    for _ in range(N_epoch):
        # Always forward the evaluation batch explicitly; update() was
        # previously called without it.
        J, w, b = update(w, b, x_a, y_a, x_v, y_v)
        J_list.append(J)
    return w, b, J_list

def train_fast(w, b, x_a, y_a, N_epoch=1000, x_v=None, y_v=None):
    """Train with the jit-compiled `update_jit` for N_epoch steps.

    Args:
        w, b: initial slope and intercept.
        x_a, y_a: training batch passed to every update step.
        N_epoch: number of update steps.
        x_v, y_v: batch on which the per-epoch loss is evaluated. Both are
            optional (and placed after N_epoch so existing positional calls
            keep working); when omitted, the training batch is used.

    Returns:
        (w, b, J_list): final parameters and the list of per-epoch losses.

    Raises:
        ValueError: if only one of x_v / y_v is given.
    """
    if (x_v is None) != (y_v is None):
        raise ValueError("x_v and y_v must be provided together")
    if x_v is None:
        x_v, y_v = x_a, y_a

    J_list = []
    for _ in range(N_epoch):
        # Always forward the evaluation batch explicitly; update_jit() was
        # previously called without it.
        J, w, b = update_jit(w, b, x_a, y_a, x_v, y_v)
        J_list.append(J)
    return w, b, J_list

N_epoch = 1000
# Fit on the notebook-level training data. The names x_a / y_a used here
# before exist only as function parameters, so they raised NameError on a
# fresh kernel.
w, b, J_list = train_fast(w, b, x_train[:2], y_train[:2], N_epoch=N_epoch)
print(w, b)

[1.9784147] [1.0133404]


In [5]:
%time w, b, J_list = train_slow(w, b, x_a[:2], y_a[:2], N_epoch = 1000) 
%time w, b, J_list = train_fast(w, b, x_a[:2], y_a[:2], N_epoch = 1000) 

CPU times: user 4.81 s, sys: 0 ns, total: 4.81 s
Wall time: 4.8 s
CPU times: user 7.4 ms, sys: 0 ns, total: 7.4 ms
Wall time: 7.15 ms


In [None]:
# NOTE(review): presumably a slow/fast wall-time ratio (in ms) from an earlier
# run; the figures do not match the timings recorded in this notebook.
# Confirm, or compute the ratio from measured values instead.
4670 / 8.21