<a href="https://colab.research.google.com/github/divya28jain/Training-a-neural-network-with-tensorflow/blob/main/Training_a_neural_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [2]:
def random_layer_params(m, n, key, scale=1e-2):
  """Draw Gaussian-initialized weights and biases for one dense layer.

  Args:
    m: input width of the layer.
    n: output width of the layer.
    key: JAX PRNG key; split so weights and biases use independent subkeys.
    scale: multiplier applied to the standard-normal samples.

  Returns:
    (w, b) where w has shape (n, m) and b has shape (n,).
  """
  weight_key, bias_key = random.split(key)
  weights = random.normal(weight_key, (n, m))
  biases = random.normal(bias_key, (n,))
  return scale * weights, scale * biases

def init_network_params(sizes, key):
  """Initialize (w, b) pairs for a fully connected network.

  Args:
    sizes: layer widths, input layer first (e.g. [784, 512, 512, 10]).
    key: JAX PRNG key. It is split into len(sizes) subkeys, of which the
      first len(sizes) - 1 are consumed (one per layer).

  Returns:
    A list with one (w, b) tuple per consecutive pair of widths in `sizes`.
  """
  layer_keys = random.split(key, len(sizes))
  params = []
  for i in range(len(sizes) - 1):
    params.append(random_layer_params(sizes[i], sizes[i + 1], layer_keys[i]))
  return params

# Network architecture: 28*28 = 784 flattened input pixels, two hidden layers
# of width 512, and 10 output classes.
layer_sizes = [28 * 28, 512, 512, 10]

# Training hyperparameters.
step_size = 1e-2   # SGD learning rate used by update()
num_epochs = 10
batch_size = 128
n_targets = 10     # NOTE(review): not referenced in the cells shown; num_labels is used instead

# Fixed PRNG seed so the initial weights are reproducible.
params = init_network_params(layer_sizes, random.key(0))

In [3]:
from jax.scipy.special import logsumexp

def relu(x):
  """Rectified linear unit, applied elementwise: max(x, 0)."""
  return jnp.maximum(x, 0)

def predict(params, image):
  """Log-probabilities over classes for a single (unbatched) example.

  Args:
    params: list of (w, b) pairs as built by init_network_params.
    image: one flattened image vector (1-D). Batches must go through vmap;
      see batched_predict.

  Returns:
    Log-softmax of the final layer's logits.
  """
  hidden_layers = params[:-1]
  out_w, out_b = params[-1]

  x = image
  for weights, bias in hidden_layers:
    x = relu(jnp.dot(weights, x) + bias)

  logits = jnp.dot(out_w, x) + out_b
  # Subtracting logsumexp normalizes logits into log-probabilities stably.
  return logits - logsumexp(logits)

In [4]:
# Sanity check: one flattened 28x28 "image" of Gaussian noise should yield a
# single vector of class log-probabilities.
random_flattened_image = random.normal(random.key(1), (784,))
preds = predict(params, random_flattened_image)
print(preds.shape)

(10,)


In [5]:
# predict() handles one example at a time: feeding it a (10, 784) batch gives
# the first jnp.dot mismatched shapes, which JAX reports as a TypeError.
random_flattened_images = random.normal(random.key(1), (10, 784))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')

Invalid shapes!


In [6]:
# Vectorize predict over the leading axis of the images; params (in_axes=None)
# are shared across every example instead of being mapped.
batched_predict = vmap(predict, in_axes=(None, 0))
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10, 10)


In [7]:
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k.

  Args:
    x: integer class labels; array-like of any shape (1-D label vectors in
      this notebook, but scalars and higher-rank arrays also work).
    k: number of classes. Labels outside [0, k) encode as all zeros.
    dtype: dtype of the returned array.

  Returns:
    Array of shape x.shape + (k,) with a 1 at each label's index.
  """
  # asarray admits plain Python ints/lists; `...` (rather than `:`) appends
  # the class axis after *all* existing axes, so any input rank works.
  labels = jnp.asarray(x)
  return jnp.array(labels[..., None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
  """Fraction of examples whose predicted class matches the one-hot target.

  Args:
    params: network parameters as built by init_network_params.
    images: flattened images, one per row.
    targets: one-hot labels, one row per image.

  Returns:
    Scalar mean of correct predictions, between 0 and 1.
  """
  log_probs = batched_predict(params, images)
  correct = jnp.argmax(log_probs, axis=1) == jnp.argmax(targets, axis=1)
  return correct.mean()

def loss(params, images, targets):
  """Negative mean of the target-weighted log-probabilities.

  The mean is taken over every (example, class) entry, so with one-hot
  targets of width K this equals the average cross-entropy divided by K.
  NOTE(review): step_size is presumably tuned for this scaling; switching to
  a per-example sum would change the effective learning rate.
  """
  log_probs = batched_predict(params, images)
  return -(log_probs * targets).mean()

@jit
def update(params, x, y):
  """Take one plain SGD step: each parameter moves by -step_size * gradient.

  Args:
    params: list of (w, b) pairs.
    x: batch of flattened images.
    y: one-hot labels for the batch.

  Returns:
    A new list of (w, b) pairs; the input list is not modified.

  Note: under jit, the global step_size is captured as a constant when the
  function is traced, so reassigning it later does not affect compiled calls.
  """
  grads = grad(loss)(params, x, y)
  new_params = []
  for (w, b), (dw, db) in zip(params, grads):
    new_params.append((w - step_size * dw, b - step_size * db))
  return new_params

In [8]:
import tensorflow as tf
# TensorFlow is only used for dataset loading; hiding GPUs from it presumably
# keeps TF from reserving accelerator memory that JAX needs.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# batch_size=-1 asks tfds to return each whole split as a single batch.
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data = mnist_data['train']
test_data = mnist_data['test']

num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c


def _flatten_and_encode(split):
  """Flatten a split's images to (N, num_pixels) and one-hot its labels."""
  # NOTE(review): pixels are fed to the network without rescaling; confirm
  # whether normalizing them is intended.
  images = split['image']
  flat_images = jnp.reshape(images, (len(images), num_pixels))
  return flat_images, one_hot(split['label'], num_labels)


train_images, train_labels = _flatten_and_encode(train_data)
test_images, test_labels = _flatten_and_encode(test_data)



Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /tmp/tfds/mnist/3.0.1...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

Shuffling /tmp/tfds/mnist/incomplete.AAZQQE_3.0.1/mnist-train.tfrecord*...:   0%|          | 0/60000 [00:00<?,…

Generating test examples...: 0 examples [00:00, ? examples/s]

Shuffling /tmp/tfds/mnist/incomplete.AAZQQE_3.0.1/mnist-test.tfrecord*...:   0%|          | 0/10000 [00:00<?, …

Dataset mnist downloaded and prepared to /tmp/tfds/mnist/3.0.1. Subsequent calls will reuse this data.


In [9]:
import time

def get_train_batches():
  """Return an iterable of NumPy (images, labels) batches from the MNIST train split.

  Each call reloads the split from data_dir, so every epoch iterates a fresh
  pipeline. Batches hold up to batch_size examples, with one batch prefetched.
  """
  batches = (
      tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
      .batch(batch_size)
      .prefetch(1)
  )
  return tfds.as_numpy(batches)

# Train for num_epochs passes over the training set, reporting wall time and
# train/test accuracy after each epoch. Note: `params` is rebound at module
# level, so re-running this cell continues training from the current weights.
for epoch in range(num_epochs):
  # perf_counter is monotonic and high-resolution, unlike time.time().
  start_time = time.perf_counter()
  for x, y in get_train_batches():
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  # JAX dispatches work asynchronously: block until the final update's
  # results exist so epoch_time measures computation, not just dispatch.
  for layer_w, layer_b in params:
    layer_w.block_until_ready()
    layer_b.block_until_ready()
  epoch_time = time.perf_counter() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 5.49 sec
Training set accuracy 0.9226000308990479
Test set accuracy 0.9249999523162842
Epoch 1 in 4.07 sec
Training set accuracy 0.9415667057037354
Test set accuracy 0.9399999976158142
Epoch 2 in 4.47 sec
Training set accuracy 0.9513500332832336
Test set accuracy 0.9487999677658081
Epoch 3 in 3.57 sec
Training set accuracy 0.9586166739463806
Test set accuracy 0.9550999999046326
Epoch 4 in 3.73 sec
Training set accuracy 0.9637666940689087
Test set accuracy 0.9580000042915344
Epoch 5 in 4.57 sec
Training set accuracy 0.968250036239624
Test set accuracy 0.9610999822616577
Epoch 6 in 5.21 sec
Training set accuracy 0.9716500043869019
Test set accuracy 0.9631999731063843
Epoch 7 in 5.21 sec
Training set accuracy 0.9742833375930786
Test set accuracy 0.9651999473571777
Epoch 8 in 3.98 sec
Training set accuracy 0.9764666557312012
Test set accuracy 0.9674999713897705
Epoch 9 in 3.67 sec
Training set accuracy 0.9790500402450562
Test set accuracy 0.9690999984741211
