In [1]:
from jax import grad, jit, vmap
import jax.numpy as jnp
from numpy import random

## Data Preparation

In [2]:
def model(x):
    """Ground-truth linear map used to generate the targets: ``2 * x + 1``."""
    return x * 2.0 + 1.0


# Two training inputs and three held-out inputs; targets are produced by `model`.
x_train, x_test = jnp.array([0, 1]), jnp.array([2, 3, 4])
y_train, y_test = model(x_train), model(x_test)



## Function Definitions

In [8]:
def get_fn():
    """Build the batched gradient and loss functions of the linear model.

    The model is ``w * x + b`` and the loss is the mean of the squared
    residual ``y - (w * x + b)``.

    Returns:
        A pair ``(vdfdwb, vf)``:
        * ``vdfdwb(w, b, x, y)`` -- gradients of the loss with respect to
          ``w`` and ``b``, evaluated independently for every entry along the
          leading axis of ``x`` and ``y``.
        * ``vf(w, b, x, y)`` -- the loss evaluated independently for every
          entry along the leading axis of ``x`` and ``y``.
        In both, ``w`` and ``b`` are shared (not mapped) across entries.
    """
    def predict(w, b, x):
        return w * x + b

    def loss(w, b, x, y):
        residual = y - predict(w, b, x)
        return jnp.mean(jnp.power(residual, 2))

    # Share w and b across the batch; map over the leading axis of x and y.
    per_sample_axes = (None, None, 0, 0)
    grad_wb = grad(loss, argnums=[0, 1])
    return vmap(grad_wb, per_sample_axes), vmap(loss, per_sample_axes)


vdfdwb, vf = get_fn()

In [9]:
def update(w, b, x_a, y_a, x_v, y_v, mu = 0.01):
    """Perform one gradient-descent step and evaluate the loss afterwards.

    Args:
        w, b: current parameters of the linear model.
        x_a, y_a: samples used to compute the gradient; the per-sample
            gradients from ``vdfdwb`` are averaged.
        x_v, y_v: samples on which the loss is evaluated after the step.
        mu: step size multiplying the averaged gradient.

    Returns:
        ``(w, b, J)`` -- the updated parameters and the sum of the
        per-sample losses from ``vf`` on ``(x_v, y_v)``, computed with the
        updated parameters.
    """
    dw_a, db_a = vdfdwb(w, b, x_a, y_a)
    dw, db = jnp.mean(dw_a), jnp.mean(db_a)
    # Rebind instead of using `-=`: when this function is called un-jitted
    # with NumPy arrays, `-=` would modify the caller's arrays in place,
    # whereas the jitted version never does. Rebinding keeps both paths
    # free of side effects on the inputs.
    w = w - mu * dw
    b = b - mu * db
    J = jnp.sum(vf(w, b, x_v, y_v))
    return w, b, J
update_jit = jit(update)

In [10]:
def train(w, b, x_a, y_a, x_v, y_v, N_epoch=1000, update=update_jit):
    """Apply ``update`` repeatedly and record the loss it reports each time.

    Args:
        w, b: initial parameters.
        x_a, y_a: samples forwarded to ``update`` for the gradient step.
        x_v, y_v: samples forwarded to ``update`` for the loss evaluation.
        N_epoch: number of times ``update`` is called.
        update: step function called as ``update(w, b, x_a, y_a, x_v, y_v)``
            and returning ``(w, b, J)``; its own default step size is used.

    Returns:
        ``(w, b, J_list)`` -- the final parameters and the list of the
        ``N_epoch`` loss values, in call order.
    """
    history = []
    for _ in range(N_epoch):
        w, b, loss = update(w, b, x_a, y_a, x_v, y_v)
        history.append(loss)
    return w, b, history

In [11]:
# Draw the initial parameters from a seeded generator so that a fresh
# "Restart & Run All" starts from the same w and b every time.
SEED = 0
rng = random.default_rng(SEED)
w = rng.uniform(low=-0.01, high=0.01, size=(1,))
b = rng.uniform(low=-0.01, high=0.01, size=(1,))

In [12]:
N_epoch = 1000
# Gradient steps use the training samples; the recorded loss is evaluated
# on the test samples after every step.
w, b, J_list = train(
    w, b,
    x_train, y_train,
    x_test, y_test,
    N_epoch=N_epoch,
    update=update_jit,
)
print(w, b)

[1.9781797] [1.0134858]


In [13]:
%time w, b, J_list = train(w, b, x_train, y_train, x_test, y_test, N_epoch = 1000, update=update) 
%time w, b, J_list = train(w, b, x_train, y_train, x_test, y_test, N_epoch = 1000, update=update_jit) 

CPU times: user 6.12 s, sys: 16.8 ms, total: 6.14 s
Wall time: 6.13 s
CPU times: user 8.7 ms, sys: 0 ns, total: 8.7 ms
Wall time: 8.48 ms


In [14]:
# Ratio of two hard-coded numbers, presumably the un-jitted and jitted wall
# times in milliseconds (i.e. the speed-up from jit).
# NOTE(review): 6190 and 8.63 do not match the timings printed by the
# previous cell (6.13 s and 8.48 ms); they look copied from an earlier run.
6190/8.63

717.2653534183081