In [None]:
import jax.numpy as np
import numpy as onp
import data
from jax import random, jit, vmap, grad
from jax.scipy.special import logsumexp

In [None]:
# Root PRNG key for the notebook, built from the fixed seed 42.
# It is passed explicitly to the `random.*` calls in the cells below.
key = random.PRNGKey(42)

In [None]:
# Side length of the square benchmark matrix.
size = 300
# (size, size) float32 matrix of normal samples drawn with `key`;
# used as the input to the timing cells that follow.
x = random.normal(key, (size, size), dtype=np.float32)
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, applied elementwise.

    Computes ``lmbda * x`` where ``x > 0`` and
    ``lmbda * (alpha * exp(x) - alpha)`` everywhere else.

    Args:
        x: Input scalar or array.
        alpha: Scale of the exponential (non-positive) branch.
        lmbda: Overall output scale.

    Returns:
        The activation, with the same shape as ``x``.
    """
    is_positive = x > 0
    # `np.where` evaluates BOTH branches. For large positive inputs exp(x)
    # overflows to inf, and although that value is discarded in the forward
    # pass, its gradient contribution is 0 * inf = NaN. Feeding the exponential
    # a harmless 0 wherever its result is discarded keeps gradients finite
    # without changing any forward value.
    safe_x = np.where(is_positive, 0.0, x)
    return lmbda * np.where(is_positive, x, alpha * np.exp(safe_x) - alpha)

def xxt_and_selu(x):
    """Apply `selu` to the product of `x` with its own transpose."""
    gram = np.dot(x, x.T)
    activated = selu(gram)
    return activated

In [None]:
# Baseline timing of the un-jitted function. `block_until_ready()` is part of
# the timed expression so the measurement waits for the result to be computed
# (JAX dispatches work asynchronously).
%timeit xxt_and_selu(x).block_until_ready()

In [None]:
#TODO: Time a JIT-Compiled version of selu(X@X.T)

In [None]:
def dot(x, y):
    """Return ``np.dot(x, y)``; a named wrapper around the dot product."""
    product = np.dot(x, y)
    return product
def naive_mv(A, v):
    """Multiply matrix `A` by vector `v` the slow way: one `dot` call per row.

    The per-row results are stacked vertically with `np.vstack`.
    """
    row_products = []
    for row in A:
        row_products.append(dot(row.T, v))
    return np.vstack(row_products)
%timeit naive_mv(x, x[0])

In [None]:
vmapped_mv = #TODO: use vmap to speed up naive_mv
%timeit vmapped_mv(x, x[0])

In [None]:
jit_vmap = #Try JITting the above function: how much do you expect this to help?
%timeit jit_vmap(x, x[0])

In [None]:
def trustfall(x):
    """Deliberately convoluted piecewise function built from Python control flow.

    Returns ``x**2`` when ``x`` is not negative, ``np.pi * x`` when
    ``-1 < x < 0``, and ``x`` added to itself five times otherwise.
    Negative inputs are routed through an exception handler on purpose.
    """
    try:
        if not x < 0:
            return x**2
        raise ValueError
    except ValueError:
        if x <= -1:
            # Accumulate x five times instead of multiplying.
            total = 0
            for _ in range(5):
                total += x
            return total
        return np.pi * x
# Gradient of `trustfall` with respect to its single argument. `grad` traces
# the Python control flow actually taken for each concrete input.
grad_fun = grad(trustfall)
print(grad_fun(-0.5))  # pi * x branch   -> pi
print(grad_fun(-3.))   # x added 5 times -> 5
print(grad_fun(2.))    # x**2 branch     -> 2 * x = 4

In [None]:
def tanhsin(x):
    """Return the sum of the hyperbolic tangent and the sine of `x`."""
    hyperbolic_term = np.tanh(x)
    periodic_term = np.sin(x)
    return hyperbolic_term + periodic_term
# `grad` returns an ordinary function, so higher derivatives are obtained by
# applying `grad` repeatedly.
grad1 = grad(tanhsin)  # first derivative
grad2 = grad(grad1)    # second derivative
grad3 = grad(grad2)    # third derivative
print(tanhsin(1.))
print(grad1(1.))
print(grad2(1.))
print(grad3(1.))

In [None]:
def get_normal(key, shape, scale=1e-2):
    """Draw a scaled normal sample and hand back a fresh key.

    Args:
        key: PRNG key to consume.
        shape: Shape of the sample to draw.
        scale: Multiplier applied to the standard-normal sample.

    Returns:
        A ``(new_key, sample)`` tuple. ``new_key`` should replace ``key`` in
        the caller so the consumed key is not reused.
    """
    # Split so the key used for sampling is never handed back to the caller.
    new_key, sample_key = random.split(key)
    return new_key, scale * random.normal(sample_key, shape)
# Draw a (100, 10) sample; note this rebinds both the notebook-level `key`
# and `x` (previously the square benchmark matrix).
key, x = get_normal(key, (100, 10))

In [None]:
def random_layer_params(m, n, key, scale=1e-2):
    """Initialize a random normal weight matrix and bias for an affine map R^m -> R^n.

    Args:
        m: Input dimension.
        n: Output dimension.
        key: PRNG key; split internally so weights and bias use distinct keys.
        scale: Multiplier applied to the standard-normal samples.

    Returns:
        ``(w, b)`` with ``w`` of shape ``(n, m)`` and ``b`` of shape ``(n,)``,
        the layout `predict` expects for ``np.dot(w, activations) + b``.
    """
    w_key, b_key = random.split(key)
    weights = scale * random.normal(w_key, (n, m))
    bias = scale * random.normal(b_key, (n,))
    return weights, bias
def init_network_params(sizes, key):
    """Initialize parameters for every layer of a fully connected network.

    Args:
        sizes: Layer widths, input first; consecutive pairs define one layer.
        key: PRNG key, split into independent per-layer keys.

    Returns:
        A list with one `random_layer_params` result per consecutive pair
        of entries in ``sizes``.
    """
    # One key per entry of `sizes`; `zip` stops at the shortest input, so the
    # single surplus key is simply never used.
    keys = random.split(key, len(sizes))
    tripzip = zip(sizes[:-1], sizes[1:], keys)
    return [random_layer_params(m, n, k) for m, n, k in tripzip]

In [None]:
def one_hot(x, k, dtype=np.float32):
    """One-hot encode the labels in `x` into `k` columns.

    Each entry of `x` is compared against ``0 .. k-1``; the resulting boolean
    table (one row per label) is converted to `dtype`.
    """
    class_ids = np.arange(k)
    hits = x[:, None] == class_ids
    return np.array(hits, dtype=dtype)

def predict(params, image):
    """Run one input through the network and return normalized log-scores.

    Every ``(w, b)`` pair except the last is applied as ``selu(w . h + b)``;
    the last pair is applied without the activation to produce logits.
    """
    hidden_layers, output_layer = params[:-1], params[-1]

    hidden = image
    for weight, bias in hidden_layers:
        hidden = selu(np.dot(weight, hidden) + bias)

    out_w, out_b = output_layer
    logits = np.dot(out_w, hidden) + out_b
    # Subtracting logsumexp normalizes the logits in log space.
    log_normalizer = logsumexp(logits)
    return logits - log_normalizer


def batched_predict_no_vmap(params, images):
    """Apply `predict` to each entry of `images` in a Python loop and stack the results."""
    per_example = []
    for index in range(len(images)):
        per_example.append(predict(params, images[index]))
    return np.vstack(per_example)

# Vectorize `predict` over the leading (batch) axis of the images while
# sharing the same parameters across the batch (in_axes entry None).
batched_predict = vmap(predict, in_axes=(None, 0))

def loss(params, images, targets):
    """Negated sum of the batched predictions weighted elementwise by `targets`."""
    log_scores = batched_predict(params, images)
    weighted = log_scores * targets
    return -np.sum(weighted)

def update(params, x, y, step_size):
    """Take one SGD step on `loss` and return the updated parameters.

    Args:
        params: List of ``(w, b)`` pairs, one per layer.
        x: Batch of inputs passed to `loss`.
        y: Batch of targets passed to `loss`.
        step_size: Learning rate.

    Returns:
        A new list of ``(w, b)`` pairs; `params` itself is not modified.
    """
    # By default, jax only differentiates with respect to the first argument,
    # so `grads` has the same list-of-pairs structure as `params`.
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]

def accuracy(params, images, targets):
    """Fraction of examples whose arg-max prediction matches the arg-max of `targets`."""
    true_labels = np.argmax(targets, axis=1)
    scores = batched_predict(params, images)
    guessed_labels = np.argmax(scores, axis=1)
    matches = guessed_labels == true_labels
    return np.mean(matches)

In [None]:
def train(layer_sizes = [784, 512, 256, 10],
          param_scale = 0.1,
          step_size = 0.0001,
          num_epochs = 8,
          batch_size = 128,
          n_targets = 10,
          jit_update = True,):
    """Train the network on MNIST with SGD, printing accuracy after every epoch.

    Args:
        layer_sizes: Layer widths passed to `init_network_params`.
        param_scale: Currently unused; it is not forwarded to
            `init_network_params`, so initialization uses that function's
            own scale.
        step_size: Learning rate passed to `update`.
        num_epochs: Number of passes over the training loader.
        batch_size: Mini-batch size handed to the data loader.
        n_targets: Number of classes used for one-hot encoding.
        jit_update: If true, run a `jit`-compiled `update`; otherwise call
            `update` directly.

    Returns:
        The parameters after the final epoch.
    """
    import time  # kept function-local, as in the original cell

    params = init_network_params(layer_sizes, random.PRNGKey(0))

    # Dataset loading via the project-local `data` module.
    mnist_dataset = data.get_mnist_dataset(train=True)
    mnist_dataset_test = data.get_mnist_dataset(train=False)
    # Honor the `batch_size` argument instead of a hard-coded 128.
    training_generator = data.NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)
    # Flatten every image to a vector and one-hot encode the labels.
    train_images = onp.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
    train_labels = one_hot(onp.array(mnist_dataset.train_labels), n_targets)
    test_images = np.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=np.float32)
    test_labels = one_hot(onp.array(mnist_dataset_test.test_labels), n_targets)

    # Pick the step function once instead of branching inside the inner loop.
    step = jit(update) if jit_update else update

    for epoch in range(num_epochs):
        start_time = time.time()
        for x, y in training_generator:
            y = one_hot(y, n_targets)
            params = step(params, x, y, step_size)
        epoch_time = time.time() - start_time

        train_acc = accuracy(params, train_images, train_labels)
        test_acc = accuracy(params, test_images, test_labels)

        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set accuracy {}".format(train_acc))
        print("Test set accuracy {}".format(test_acc))
    return params

In [None]:
# One epoch with the plain (un-jitted) update step; compare the printed epoch
# time with the jitted run.
final_params = train(num_epochs=1, jit_update=False)

In [None]:
# One epoch with the jit-compiled update step; rebinds `final_params`.
final_params = train(num_epochs=1, jit_update=True)