In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [10]:
def layer_params(m, n, key, scale=1e-2):
    """Randomly initialise a single dense layer mapping ``m`` inputs to ``n`` outputs.

    Args:
        m: input width of the layer.
        n: output width of the layer.
        key: JAX PRNG key; it is split into separate keys for weights and biases.
        scale: multiplier applied to the standard-normal draws.

    Returns:
        ``(weights, biases)`` with shapes ``(n, m)`` and ``(n,)``, each drawn
        from a standard normal and multiplied by ``scale``.
    """
    weight_key, bias_key = random.split(key)
    weights = random.normal(weight_key, (n, m))
    biases = random.normal(bias_key, (n,))
    return scale * weights, scale * biases

def init_net_params(shape, key, scale=1e-2):
    """Build parameters for a fully connected network with the given layer widths.

    Args:
        shape: sequence of layer widths, input first, output last; consecutive
            pairs ``(m, n)`` define one dense layer each.
        key: JAX PRNG key used to derive per-layer keys.
        scale: multiplier for the standard-normal initial weights and biases,
            forwarded to ``layer_params`` (default matches its own default).

    Returns:
        A list of ``(weights, biases)`` pairs, one per consecutive pair in ``shape``.
    """
    # Splits into len(shape) keys, one more than the number of layers; zip()
    # stops at the shorter input, so the last key goes unused. The count is kept
    # as-is because changing it may change every subkey and therefore the params
    # produced for an existing seed.
    keys = random.split(key, len(shape))
    return [
        layer_params(m, n, k, scale=scale)
        for m, n, k in zip(shape[:-1], shape[1:], keys)
    ]

# Layer widths from input to output. The input width is one flattened
# 28x28 image; the final width is the size of the network's output vector.
ls = [28 * 28, 512, 512, 10]
# Hyperparameters not used in this cell; presumably consumed by later cells
# (ss ~ step size, bs ~ batch size) -- TODO confirm against the training loop.
ss = 0.01
epochs = 8
bs = 128
# Derived from the architecture so the two values cannot drift apart.
dim_out = ls[-1]
# Fixed seed so the initial parameters are reproducible on Restart & Run All.
params = init_net_params(ls, random.PRNGKey(0))

In [11]:
from jax.scipy.special import logsumexp

def relu(x):
    """Elementwise rectified linear unit: the larger of ``x`` and 0."""
    return jnp.maximum(x, 0)

def predict(params, image):
    """Run the MLP on one flattened input and return log-probabilities.

    Args:
        params: sequence of ``(weights, biases)`` pairs, one per dense layer,
            in the layout built by ``init_net_params`` (weights shaped
            ``(n_out, n_in)``).
        image: a single flattened example. ``logsumexp`` below reduces over
            every element, so batches should be mapped over (e.g. with ``vmap``)
            rather than passed here directly.

    Returns:
        The log-softmax of the final layer's logits, ``logits - logsumexp(logits)``.
    """
    hidden_layers = params[:-1]
    w_out, b_out = params[-1]
    activations = image
    for weights, biases in hidden_layers:
        # Every layer except the last is a dense layer followed by ReLU.
        activations = relu(jnp.dot(weights, activations) + biases)
    # The final layer has no nonlinearity; it is normalised into log-probabilities instead.
    logits = jnp.dot(w_out, activations) + b_out
    return logits - logsumexp(logits)


In [12]:
# Smoke test: run the untrained network on one random flattened 28x28 input.
rand_image = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, rand_image)
# predict returns logits - logsumexp(logits), so the output width must equal
# the last layer size, and exp(preds) must sum to 1.
assert preds.shape == (ls[-1],), preds.shape
assert jnp.allclose(jnp.exp(preds).sum(), 1.0, atol=1e-5)
print(preds.shape)

(10,)
