In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [51]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 1
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))

In [26]:
from jax.scipy.special import logsumexp

def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)
  
  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)

In [27]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

(10,)


In [28]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')

Invalid shapes!


In [29]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10, 10)


In [34]:
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)
  
def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

# # @jit
# def update(params, x, y):
#   grads = grad(loss)(params, x, y)
#   return [(w - step_size * dw, b - step_size * db)
#           for (w, b), (dw, db) in zip(params, grads)]

In [35]:
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
  return tree_map(np.asarray, data.default_collate(batch))

class NumpyLoader(data.DataLoader):
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32))

In [22]:
# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz


0.3%

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz


100.0%

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz



2.0%

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


100.0%


Extracting /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz


100.0%

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw






In [36]:
# Get the full train dataset (for checking accuracy while training)
train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)

# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)

In [87]:
my_ls = [1,2,3,4]

[x for x in my_ls[:-1]]

[1, 2, 3]

In [93]:
def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    print('w = ', w.shape, 'b = ', b.shape)
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)
  
  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)


batched_predict = vmap(predict, in_axes=(None, 0))

def loss(params, images, targets):
  preds = batched_predict(params, images)
  print('preds.shape = ', preds.shape, 'targets.shape = ', targets.shape)
  return -jnp.mean(preds * targets)

def update(params, x, y):
  # for param in params:
    # print(param[0].shape, end=' ')
  # print('\n-----------------------------------------')
  print('y.shape = ', y.shape)
  grads = grad(loss)(params, x, y)
  
  for idx, param_tuple in enumerate(params, start=1):
    for param in param_tuple:
      print(f'({idx}) param_shape:', param.shape, end=' ')
    print(' ')
  
  for idx, grad_tuple in enumerate(grads, start=1):
    for grad_ in grad_tuple:
      print(f'idx: ({idx}) gradient_shape:', grad_.shape, end=' ')
    print(' ')

  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

In [94]:
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_image)

w =  (512, 784) b =  (512,)
w =  (512, 512) b =  (512,)


In [95]:
import time

print('num_layers = ', len(params))

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in training_generator:
    y = one_hot(y, n_targets)
    params = update(params, x, y)
    # Jums
    break
  break 
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

num_layers =  3
y.shape =  (128, 10)
w =  (512, 784) b =  (512,)
w =  (512, 512) b =  (512,)
preds.shape =  (128, 10) targets.shape =  (128, 10)
(1) param_shape: (512, 784) (1) param_shape: (512,)  
(2) param_shape: (512, 512) (2) param_shape: (512,)  
(3) param_shape: (10, 512) (3) param_shape: (10,)  
idx: (1) gradient_shape: (512, 784) idx: (1) gradient_shape: (512,)  
idx: (2) gradient_shape: (512, 512) idx: (2) gradient_shape: (512,)  
idx: (3) gradient_shape: (10, 512) idx: (3) gradient_shape: (10,)  


In [68]:
import math
num_hidden = 10
input_size = 1
output_size = 1
key = random.PRNGKey(420)
num_instance = 1000

# Weights for Shallow Net
W1 = random.normal(key, (input_size, num_hidden))
W2 = random.normal(key, (num_hidden, output_size))

# print(W1.shape)

# print(type(random.split(key, 3)))
# Input Data
x_tr = jnp.arange(0, stop=math.pi / 2, step= math.pi / (2 * num_instance ), dtype=jnp.float32)
x_tr = x_tr.reshape((-1, 1))
# Expected Output Labels
y_tr = jnp.sin(x_tr)


print(x_tr.shape)
print(y_tr.shape)
# @jit

print(x_tr[5])
print(y_tr[5])
(x_tr * y_tr)[3]


(1000, 1)
(1000, 1)
[0.00785398]
[0.0078539]


Array([2.2206528e-05], dtype=float32)