In [49]:
import jax.numpy as np
import numpy as onp
import data
from jax import random, jit, vmap, grad
from jax.scipy.special import logsumexp

In [2]:
key = random.PRNGKey(42)



In [3]:
size = 300
x = random.normal(key, (size, size), dtype=np.float32)
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

def xxt_and_selu(x):
    return selu(np.dot(x, x.T))

In [4]:
%timeit xxt_and_selu(x).block_until_ready()

967 µs ± 113 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
%timeit jit(xxt_and_selu)(x).block_until_ready()

306 µs ± 11.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [6]:
def dot(x, y):
    return np.dot(x, y)
def naive_mv(A, v):
    return np.vstack([dot(a.T, v) for a in A])
%timeit naive_mv(x, x[0])

56.6 ms ± 1.19 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
vmapped_mv = vmap(dot, (None, 0), 0)
%timeit vmapped_mv(x, x[0])

31.6 ms ± 2.27 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
jit_vmap = jit(vmapped_mv)
%timeit jit_vmap(x, x[0])

11.8 ms ± 444 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
def trustfall(x):
    try:
        if x < 0:
            raise ValueError
        else:
            return x**2
    except ValueError:
        if x > -1:
            return np.pi * x
        else:
            ct = 0
            ret = 0
            while ct < 5:
                ret += x
                ct += 1
            return ret
grad_fun = grad(trustfall)
jitted = jit(grad_fun)
print(grad_fun(-0.5))
print(grad_fun(-3.))
print(grad_fun(2.))

3.1415927
5.0
4.0


In [70]:
def tanhsin(x):
    return np.tanh(x) + np.sin(x)
grad1 = grad(tanhsin)
grad2 = grad(grad1)
grad3 = grad(grad2)
print(tanhsin(1.))
print(grad1(1.))
print(grad2(1.))
print(grad3(1.))
print(jit(grad(jit(grad(grad2))))(1.))

1.6030651
0.9602766
-1.4811709
0.0813244
1.5065618


In [38]:
def get_normal(key, shape): # Helper function for grabbing a normal and a new key
    key, subkey = random.split(key)
    out = random.normal(key, shape, dtype=np.float32)
    return subkey, out
key, x = get_normal(key, (100, 10))

In [39]:
def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    tripzip = zip(sizes[:-1], sizes[1:], keys)
    return [random_layer_params(m, n, k) for m, n, k in tripzip]

In [50]:
def lrelu(x, leak=0.1):
    return np.maximum(-x, x)

def softmax(logits):
    expsum = np.sum(np.exp(logits))
    return np.exp(logits)/expsum

def one_hot(x, k, dtype=np.float32):
    return np.array(x[:, None] == np.arange(k), dtype)

def loss(params, images, targets):
    preds = batched_predict(params, images)
    return -np.sum(preds * targets)

def predict(params, image):
    activations = image
    for w, b in params[:-1]:
        outputs = np.dot(w, activations) + b
        activations = selu(outputs)
    final_w, final_b = params[-1]
    logits = np.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)
@jit
def update(params, x, y, step_size):
    grads = grad(loss)(params, x, y) # by default, jax only grads to first parameter
    return [(w - step_size * dw, b - step_size * db) \
            for (w, b), (dw, db) in zip(params, grads)]
    
batched_predict = vmap(predict, in_axes=(None, 0))

def accuracy(params, images, targets):
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(batched_predict(params, images,), axis=1)
    return np.mean(predicted_class == target_class)

In [53]:
def train(layer_sizes = [784, 512, 256, 10],
          param_scale = 0.1,
          step_size = 0.0001,
          num_epochs = 8,
          batch_size = 128,
          n_targets = 10,):
    params = init_network_params(layer_sizes, random.PRNGKey(0))
    # Dataset loading nonsense
    mnist_dataset = data.get_mnist_dataset(train=True)
    mnist_dataset_test = data.get_mnist_dataset(train=False)
    training_generator = data.NumpyLoader(mnist_dataset, batch_size=128, num_workers=0)
    train_images = onp.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
    train_labels = one_hot(onp.array(mnist_dataset.train_labels), n_targets)
    test_images = np.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=np.float32)
    test_labels = one_hot(onp.array(mnist_dataset_test.test_labels), n_targets)
    import time
    for epoch in range(num_epochs):
        start_time = time.time()
        for x, y in training_generator:
            y = one_hot(y, n_targets)
            params = update(params, x, y, step_size)
        epoch_time = time.time() - start_time
    
        train_acc = accuracy(params, train_images, train_labels)
        test_acc = accuracy(params, test_images, test_labels)
        
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set accuracy {}".format(train_acc))
        print("Test set accuracy {}".format(test_acc))
    return params

In [55]:
final_params = train(num_epochs=1)

Epoch 0 in 3.41 sec
Training set accuracy 0.9606500267982483
Test set accuracy 0.9571999907493591


In [None]:
f\