In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

## The forward and backward passes

In [2]:
#export
from exp.nb_01 import *

def get_data():
    path = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(path, 'rb') as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
    return map(jax.device_put, (x_train, y_train, x_valid, y_valid))

def normalize(x, m, s): return (x-m)/s

In [5]:
x_train,y_train,x_valid,y_valid = get_data()



In [6]:
train_mean,train_std = np.mean(x_train), np.std(x_train)
train_mean,train_std

(DeviceArray(0.130466, dtype=float32), DeviceArray(0.306289, dtype=float32))

In [7]:
x_train = normalize(x_train, train_mean, train_std)
x_valid = normalize(x_valid, train_mean, train_std)

In [8]:
train_mean,train_std = np.mean(x_train), np.std(x_train)
train_mean,train_std

(DeviceArray(-5.165295e-05, dtype=float32),
 DeviceArray(0.999688, dtype=float32))

In [9]:
#export
def test_near_zero(a,tol=1e-3): assert np.abs(a)<tol, f"Near zero: {a}"

In [10]:
test_near_zero(train_mean)
test_near_zero(1 - train_std)

In [11]:
n,m = x_train.shape
c = y_train.max()+1
n,m,c

(50000, 784, DeviceArray(10, dtype=int32))

## Foundations version

### Basic Architecture

In [12]:
nh = 50
key = jax.random.PRNGKey(3)

In [13]:
w1 = jax.random.normal(key=key, shape=(m, nh))/math.sqrt(m) 
b1 = np.zeros(nh)
w2 = jax.random.normal(key=key, shape=(nh, 1))/math.sqrt(nh)
b2 = np.zeros(1)

In [15]:
def relu(x): return np.clip(x, a_min=0.)

In [16]:
def lin(x, w, b): return np.dot(x, w) + b

In [17]:
%timeit -n 10 t = relu(lin(x_valid, w1, b1))

7.39 ms ± 1.63 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [18]:
w1 = jax.random.normal(key=key, shape=(m, nh))*math.sqrt(2/m) 
w2 = jax.random.normal(key=key, shape=(nh, 1))*math.sqrt(2/nh)

In [19]:
def relu(x): return  np.clip(x, a_min=0.) - 0.5

In [20]:
def model(xb):
    l1 = lin(xb, w1, b1)
    l2 = relu(l1)
    l3 = lin(l2, w2, b2)
    return l3

In [21]:
%timeit -n 10 _=model(x_valid)

8.45 ms ± 2.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [22]:
assert model(x_valid).shape==torch.Size([x_valid.shape[0],1])

### Loss Function: MSE

In [28]:
#export
def mse(output, targ): return np.mean(np.dot((output.squeeze(-1) - targ), (output.squeeze(-1) - targ)))

In [25]:
y_train,y_valid = y_train.astype(float), y_valid.astype(float)

In [29]:
preds = model(x_train)

In [30]:
mse(preds, y_train)

DeviceArray(1167455.4, dtype=float32)

### Gradients and backward pass

In [44]:
mse_grad = jax.grad(mse)
relu_grad = jax.grad(relu)
lin_grad = jax.grad(lin)

In [45]:
def forward_and_backward(inp, targ):
    # forward pass:
    l1 = lin(inp, w1, b1)
    l2 = relu(l1)
    out = lin(l2, w2, b2)
    # we don't actually need the loss in backward!
    loss = mse(out, targ)
    
    # backward pass:
    mse_grad(out, targ)
    lin_grad(l2, w2, b2)
    relu_grad(l1)
    lin_grad(inp, w1, b1)

In [46]:
forward_and_backward(x_train, y_train)

TypeError: Gradient only defined for scalar-output functions. Output had shape: (50000, 1).

In [89]:
@jax.jit
def update(params, batch):
    grads = grad(mse)(params, batch)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]

In [91]:
!python notebook2script.py 02_fully_connected_jax.ipynb

Converted 02_fully_connected_jax.ipynb to exp/nb_02.py
