<a href="https://colab.research.google.com/github/maxmatical/jax_projects/blob/master/02_jax_neural_networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import jax
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

defining hyperparams for NN

In [0]:
def random_layer_params(m, n, key, scale = 1e-2):
  # initialize the weights and bias for a hidden layer
  w, b = random.split(key)
  return scale*random.normal(w, (n,m)), scale*random.normal(b, (n,))

def init_nn_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

# definig parameters of the nn

szs = [28*28, 512, 256, 10] # flattened input will be 28*28
param_scale = 0.1
lr = 0.01
n_epochs = 10
bs = 128
n_classes = 10

params = init_nn_params(szs, random.PRNGKey(0))

defining NN

In [0]:
from jax.scipy.special import logsumexp

def relu(x):
  return np.maximum(0, x)

def net(params, input):
  # forward pass for a sinlge input
  out = input
  for w, b in params[:-1]: # go through hidden layers until classifier layer
    out = np.dot(w, out) + b
    out = relu(out)
  w_clas, b_clas = params[-1]
  logits = np.dot(w_clas, out)+b_clas
  # return logits
  return logits -logsumexp(logits)

checking to see if it works on a single sample

In [4]:
rand_input = random.normal(random.PRNGKey(1), (28*28,))
preds = net(params, rand_input)
print(preds.shape)

(10,)


using `vmap` to generate batch predictions

In [5]:
forward_batch = vmap(net, in_axes = (None, 0))
rand_input_batch = random.normal(random.PRNGKey(1), (128, 28*28,))
batch_preds = forward_batch(params, rand_input_batch)
print(batch_preds.shape)

(128, 10)


defining metrics and loss function

a cross_entropy loss would look something like

``` 
def cross_entropy_loss(logits, labels):
    log_softmax_logits = jax.nn.log_softmax(logits)
    num_classes = log_softmax_logits.shape[-1]
    one_hot_labels = common_utils.onehot(labels, num_classes)
    return -jnp.sum(one_hot_labels * log_softmax_logits) / labels.size
```

In [0]:
def one_hot(x, k, dtype = np.float32):
  # create 1-hot encoding of size k
  return np.array(x[:, None] == np.arange(k), dtype)


def accuracy(params, input, targets):
  target_class = np.argmax(targets, axis=1)
  pred_class = np.argmax(forward_batch(params, input), axis = 1)
  return np.mean(pred_class==target_class)
  
# def loss(params, x, targets):
#   preds = forward_batch(params, x)
#   return -np.mean(preds*targets)

"""

when the output is only logits

"""


def loss(params, input, targets): # cross entropy loss for logits
  logits = forward_batch(params, input)
  preds = jax.nn.log_softmax(logits)
  return -np.mean(targets*preds)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w-lr*grad_w, b-lr*grad_b) 
          for (w, b), (grad_w, grad_b) in zip(params, grads)]



Using TF Dataloader

In [0]:
import tensorflow_datasets as tfds
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
data_dir = '/tmp/tfds'

mnist, info = tfds.load(name="mnist", batch_size =-1, data_dir=data_dir, with_info = True)
mnist = tfds.as_numpy(mnist)
train, test = mnist['train'], mnist['test']
n_classes = info.features['label'].num_classes
h,w,c = info.features['image'].shape



In [0]:
# reshaping data into 1 contiguous vector
# and one hot encoding labels

train_img, train_label = train['image'], train['label']
train_img = np.reshape(train_img, (len(train_img), h*w*c))
train_label = one_hot(train_label, n_classes)

test_img, test_label = test['image'], test['label']
test_img = np.reshape(test_img, (len(test_img), h*w*c))
test_label = one_hot(test_label, n_classes)

In [9]:
# validate
print('Train:', train_img.shape, train_label.shape)
print('Test:', test_img.shape, test_label.shape)

Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)


training NN

In [10]:
import time

def get_train_batches():
  ds = tfds.load(name="mnist", split="train", as_supervised=True, data_dir=data_dir) # as_supervised = True returns tuple instead of a dict
  ds = ds.batch(bs).prefetch(1)
  return tfds.as_numpy(ds)

best_acc = 0

for epoch in range(n_epochs):
  start = time.time()
  for x, y, in get_train_batches():
    x = np.reshape(x, (len(x), h*w*c))
    y = one_hot(y, n_classes)
    params = update(params, x, y)
  epoch_time = time.time()-start
  train_loss = loss(params, train_img, train_label)
  test_loss = loss(params, test_img, test_label)
  train_acc = accuracy(params, train_img, train_label)
  test_acc = accuracy(params, test_img, test_label)

  # printing statistics
  print(f"epoch {epoch+1} completed in {epoch_time} seconds")
  print(f"train loss: {train_loss}; training accuracy: {train_acc}")
  print(f"val loss: {test_loss}; val accuracy: {test_acc}")

  # saving best params
  if test_acc > best_acc:
    best_params = params
    print(f"better model found at epoch {epoch+1}, saving model")


epoch 1 completed in 10.527390718460083 seconds
train loss: 0.029938776046037674; training accuracy: 0.9151999950408936
val loss: 0.028568075969815254; val accuracy: 0.9201000332832336
better model found at epoch 1, saving model
epoch 2 completed in 8.731337308883667 seconds
train loss: 0.023019123822450638; training accuracy: 0.9347833395004272
val loss: 0.022458553314208984; val accuracy: 0.9355000257492065
better model found at epoch 2, saving model
epoch 3 completed in 9.47747278213501 seconds
train loss: 0.019159628078341484; training accuracy: 0.9455500245094299
val loss: 0.019136376678943634; val accuracy: 0.9463000297546387
better model found at epoch 3, saving model
epoch 4 completed in 9.177500486373901 seconds
train loss: 0.016445571556687355; training accuracy: 0.9531166553497314
val loss: 0.01683495007455349; val accuracy: 0.9517000317573547
better model found at epoch 4, saving model
epoch 5 completed in 9.285526990890503 seconds
train loss: 0.01440518070012331; training 

Additional utilities

In [13]:
# saving params

best_params # list of length 3

3