In [4]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## The forward and backward passes

In [9]:
#export
from exp.nb_01 import *

def get_data():
    x_train, y_train, x_valid, y_valid  = datasets.mnist()
    return map(jax.device_put, (x_train, y_train, x_valid, y_valid))

def normalize(x, m, s): return (x-m)/s

In [10]:
x_train,y_train,x_valid,y_valid = get_data()



In [12]:
train_mean,train_std = np.mean(x_train), np.std(x_train)
train_mean,train_std

(DeviceArray(0.13067462, dtype=float32),
 DeviceArray(0.30700648, dtype=float32))

In [13]:
x_train = normalize(x_train, train_mean, train_std)
x_valid = normalize(x_valid, train_mean, train_std)

In [14]:
train_mean,train_std = np.mean(x_train), np.std(x_train)
train_mean,train_std

(DeviceArray(-2.3471399e-05, dtype=float32),
 DeviceArray(0.99988186, dtype=float32))

In [19]:
#export
def test_near_zero(a,tol=1e-3): assert np.abs(a)<tol, f"Near zero: {a}"

In [21]:
test_near_zero(train_mean)
test_near_zero(1 - train_std)

In [22]:
n,m = x_train.shape
c = y_train.max()+1
n,m,c

(60000, 784, DeviceArray(10, dtype=int32))

## Foundations version

### Basic Architecture

In [25]:
nh = 50
key = jax.random.PRNGKey(3)

In [26]:
w1 = jax.random.normal(key=key, shape=(m, nh))/math.sqrt(m) 
b1 = np.zeros(nh)
w2 = jax.random.normal(key=key, shape=(nh, 1))/math.sqrt(nh)
b2 = np.zeros(1)

In [67]:
@jax.jit
def relu(x): return np.where(x > 0., x, 0.)

In [68]:
def lin(x, w, b): return np.dot(x, w) + b

In [69]:
%timeit -n 10 t = relu(lin(x_valid, w1, b1))

11 ms ± 7.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [63]:
w1 = jax.random.normal(key=key, shape=(m, nh))*math.sqrt(2/m) 
w2 = jax.random.normal(key=key, shape=(nh, 1))*math.sqrt(2/nh)

In [70]:
@jax.jit
def relu(x): return np.where(x > 0., x, 0.) - 0.5

In [72]:
@jax.jit
def model(xb):
    l1 = lin(xb, w1, b1)
    l2 = relu(l1)
    l3 = lin(l2, w2, b2)
    return l3

In [73]:
%timeit -n 10 _=model(x_valid)

The slowest run took 135.67 times longer than the fastest. This could mean that an intermediate result is being cached.
5.31 ms ± 3.48 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Loss Function: MSE

In [76]:
#export
def mse(output, targ): return np.mean(np.dot((output.squeeze(-1) - targ), (output.squeeze(-1) - targ)))

In [80]:
y_train,y_valid = y_train.astype(float), y_valid.astype(float)

In [81]:
preds = model(x_train)

In [82]:
mse(preds, y_train)

DeviceArray(1403289., dtype=float32)

In [89]:
@jax.jit
def update(params, batch):
    grads = grad(mse)(params, batch)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]

In [91]:
!python notebook2script.py 02_fully_connected_jax.ipynb

Converted 02_fully_connected_jax.ipynb to exp/nb_02.py
