In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from numpy import random
from math import sqrt
import numpy as pnp

In [2]:
# load dataset, MNIST
from tensorflow.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
assert x_train.shape == (60000, 28, 28)
assert x_test.shape == (10000, 28, 28)
assert y_train.shape == (60000,)
assert y_test.shape == (10000,)
print(y_train[:10]) # not onehot yet

[5 0 4 1 9 2 1 3 1 4]


In [3]:
N_classes = 10
X_tr, X_te = x_train.reshape(-1,28*28), x_test.reshape(-1,28*28)
N_features = X_tr.shape[1] 
N_features, N_classes

(784, 10)

In [68]:
def onehot_enc(values, n_classes):
    # values = [1, 0, 3], n_classes=4
    return pnp.eye(n_classes)[values]
Y_tr = onehot_enc(y_train, N_classes)
Y_te = onehot_enc(y_test, N_classes)
Y_tr.shape, Y_te.shape, Y_tr[0]

((60000, 10), (10000, 10), array([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]))

In [89]:
def cross_entropy_each(y_hat, y):
    if y == 1:
        return -jnp.log(y_hat)
    else:
        return -jnp.log(1 - y_hat)

def cross_entropy(y_hat_vec, y_vec):
    s = 0
    for y_hat, y in zip(y_hat_vec, y_vec):
        s += cross_entropy_each(y_hat, y)
    return s / len(y_vec)
    
def sigmoid(z):
    return 1 / (1 + jnp.exp(-z))

def softmax(x): # x should be a vector 
    ex = jnp.exp(x)
    return ex/jnp.sum(ex)

def cl_predict(Wb_list, x_vec):
    # W: matrix, b: vector
    for W, b in Wb_list[:-1]:
        o_vec = jnp.dot(x_vec, W) + b
        x_vec = sigmoid(o_vec) # for multiple layers 
    W, b = Wb_list[-1]
    o_vec =  jnp.dot(x_vec, W) + b
    return softmax(o_vec)

def cost(Wb_list, x_vec, y_vec):
    y_hat_vec = cl_predict(Wb_list, x_vec)
    return cross_entropy(y_hat_vec, y_vec)

def update_inplace(Wb_list, delta_Wb_list, mu =0.01):
    for Wb, dWb in zip(Wb_list, delta_Wb_list):
        W, b = Wb
        dW, db = dWb 
        W += mu * dW
        b += mu * db

In [90]:
## Prepare weights
N_hidden = int(sqrt(N_features))
W1 = random.uniform(low=-0.01, high=0.01, size=(N_features,N_hidden))
b1 = random.uniform(low=-0.01, high=0.01, size=(N_hidden,))
W2 = random.uniform(low=-0.01, high=0.01, size=(N_hidden,N_classes))
b2 = random.uniform(low=-0.01, high=0.01, size=(N_classes,))
Wb_list = [[W1,b1],[W2,b2]]

In [91]:
dcost = grad(cost, argnums=0)

In [92]:
N = 2
for x, y in zip(X_tr[:N], Y_tr[:N]):
    J = cost(Wb_list, x, y)
    delta_Wb_list = dcost(Wb_list, x, y)
    update_inplace(Wb_list, delta_Wb_list, mu=0.1)
    print(delta_Wb_list[0][0][0])
    print(J)

[-0. -0. -0. -0.  0. -0.  0. -0. -0. -0. -0.  0. -0.  0. -0. -0. -0. -0.
 -0. -0. -0. -0.  0. -0.  0.  0.  0. -0.]
0.32573888
[-0.  0.  0.  0.  0.  0.  0. -0.  0.  0. -0.  0.  0. -0. -0.  0. -0. -0.
 -0. -0.  0.  0.  0.  0. -0.  0. -0.  0.]
0.32725117


In [93]:
x_vec = X_tr[0]
y_vec = Y_tr[0]

def f(Wb_list, x_vec, y_vec):
    y_hat_vec = cl_predict(Wb_list, x_vec)
    return cross_entropy(y_hat_vec, y_vec)

In [94]:
df = grad(f, argnums=0)
dWb = df(Wb_list, x_vec, y_vec)
f(Wb_list, x_vec, y_vec)

DeviceArray(0.32573888, dtype=float32)

In [95]:
dWb[0][0][0]

DeviceArray([-0., -0., -0., -0.,  0., -0.,  0., -0., -0., -0., -0.,  0.,
             -0.,  0., -0., -0., -0., -0., -0., -0., -0., -0.,  0., -0.,
              0.,  0.,  0., -0.], dtype=float32)