<a href="https://colab.research.google.com/github/ninikvn/hackathon-project/blob/main/pytorch_template.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Workshop: Pytorch and Jax

### Overview

In this workshop we are going to explore two other libraries that are commonly used for Deep Learning, Pytorch and Jax.

We will start by looking at Pytorch which is on the whole very similar to Tensorflow. We will highlight some of the differences and advantages of each then talk about some of the extras that come with Pytorch.

Jax on the other hand is not explicitly a deep learning library, but rather, is a *scientific computing library with auto-differientation.* The key to Jax is that it is fast, **really fast.** Jax is more similar to a super numpy than a Tensorflow or PyTorch

___
## Tensorflow v. Pytorch

Let's start out by taking a look at some familar Tensorflow code, then its PyTorch equivalent

In [None]:
# Start by importing all the libraries we need
import torch
import torch.nn as nn
import torch.nn.functional as F
# DataLoaders and Datasets allow us to easily preprocess data, batch, etc.
from torch.utils.data import DataLoader, Dataset

# torchvision has some preloaded datasets and image transformations
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Standard imports we will use
import numpy as np
import tensorflow as tf
import numpy as np
from tqdm import tqdm


First, let's load and preprocess our data using Tensorflow

In [None]:
# Loading in dataset; split into training (X0,Y0) and testing (X1,Y1)
mnist = tf.keras.datasets.mnist
(train_inputs, train_labels_clean), (test_inputs, test_labels_clean) = mnist.load_data()
assert train_inputs.shape == (60000, 28, 28) #width and height of image
assert test_inputs.shape == (10000, 28, 28)
assert train_labels_clean.shape == (60000,) #just labels, digits 0-9
assert test_labels_clean.shape == (10000,)

# Normalize inputs to [0, 1] and make sure they're float32
X0 = (train_inputs / 255.).astype(np.float32)
X1 = (test_inputs / 255.).astype(np.float32)

# Make labels one hot vectors
def one_hot(arr, num_classes):
  """Convert array to a one-hot matrix
      Hint: We can use np.eye and index by our array"""
  return np.eye(num_classes)[arr]

train_labels = one_hot(train_labels_clean, 10).astype(np.float32)
test_labels = one_hot(test_labels_clean, 10).astype(np.float32)


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 0us/step


The following code creates and trains our Tensorflow model, it should look familiar to you if you've completed the mini-project

In [None]:
class Model(tf.keras.Model):

  def __init__(self, **kwargs):
    """
    The model class inherits from tf.keras.Model.
    It stores the trainable weights as attributes.
      Hint: Using Dense layers with output size 256, 128, 10 should work well
    """
    super(Model, self).__init__(**kwargs)

    self.layer1 = tf.keras.layers.Dense(256, activation="relu")
    self.layer2 = tf.keras.layers.Dense(128, activation="relu")
    self.layer3 = tf.keras.layers.Dense(10, activation="softmax")

  def call(self, inputs):
    """
    Forward pass, predicts labels given an input image using fully connected layers
    :return: the probabilites of each label
    """

    layer1Output = self.layer1(inputs)
    layer2Output = self.layer2(layer1Output)
    prbs = self.layer3(layer2Output)
    return prbs

  def loss_fn(self, predictions, labels):
    """
    Calculates the model loss
    :return: the loss of the model as a tensor
    """
    nll_comps = -labels * tf.math.log(tf.clip_by_value(predictions,1e-10,1.0))
    return tf.reduce_mean(tf.reduce_sum(nll_comps, axis=[1]))

  def accuracy(self, predictions, labels):
    """
    Calculates the model accuracy
    :return: the accuracy of the model as a tensor
    """
    pred_classes = tf.argmax(predictions, 1)
    true_classes = tf.argmax(labels, 1)
    correct_prediction = tf.equal(pred_classes, true_classes)
    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

## END TODO
################################################################################

# Instantiate our model
model = Model()


# Choosing our optimizer
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)

# Loop through training steps
epochs = 10
batch_size = 1024
train_steps = len(train_inputs) // batch_size
for i in range(epochs):
  for j in tqdm(range(train_steps)):
    image = tf.reshape(train_inputs[j*batch_size:(j+1)*batch_size], (batch_size,-1))
    label = tf.reshape(train_labels[j*batch_size:(j+1)*batch_size], (batch_size,-1))
    # Implement backprop:
    with tf.GradientTape() as tape:
      y_pred = model(image) # this calls the call function conveniently
      label = tf.cast(label, tf.float32)
      loss = model.loss_fn(y_pred, label)

    # The keras Model class has the computed property trainable_variables to conveniently
    # return all the trainable variables you'd want to adjust based on the gradients

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  test_acc = model.accuracy(model(tf.reshape(test_inputs, (-1, 28*28))), test_labels)
  print(f"Accuracy on testing set after epoch {i}: {test_acc}")

print()
model.summary()

100%|██████████| 58/58 [00:05<00:00, 10.50it/s]


Accuracy on testing set after epoch 0: 0.5875999927520752


100%|██████████| 58/58 [00:00<00:00, 69.96it/s]


Accuracy on testing set after epoch 1: 0.6908000111579895


100%|██████████| 58/58 [00:00<00:00, 71.66it/s]


Accuracy on testing set after epoch 2: 0.7135000228881836


100%|██████████| 58/58 [00:00<00:00, 71.02it/s]


Accuracy on testing set after epoch 3: 0.7265999913215637


100%|██████████| 58/58 [00:00<00:00, 69.02it/s]


Accuracy on testing set after epoch 4: 0.7336999773979187


100%|██████████| 58/58 [00:00<00:00, 69.28it/s]


Accuracy on testing set after epoch 5: 0.7631000280380249


100%|██████████| 58/58 [00:00<00:00, 70.61it/s]


Accuracy on testing set after epoch 6: 0.8090000152587891


100%|██████████| 58/58 [00:00<00:00, 68.13it/s]


Accuracy on testing set after epoch 7: 0.817300021648407


100%|██████████| 58/58 [00:00<00:00, 66.22it/s]


Accuracy on testing set after epoch 8: 0.8228999972343445


100%|██████████| 58/58 [00:00<00:00, 68.54it/s]

Accuracy on testing set after epoch 9: 0.9045000076293945






Now we are going to look at the same code but using PyTorch instead of Tensforflow.

First, let's take a look at loading and preprocessing data with PyTorch

In [None]:

# pytorch has it's own methods of loading and batching data, which you can see below
""" Ordinarily the commented code below is all you need to load the MNIST data.
      For some reason, Yann Lecun.com is timing out so we instead make our own
      dataset to save time.
train_dataset = datasets.MNIST(root='./data', train=True, download=True,
                               transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./data', train=False, download=True,
                              transform=transforms.ToTensor())
"""

class MNISTDataset(Dataset):
  """
    Every Pytorch Dataset needs an __init__, __len__, and __getitem__
    These methods are used to get and batch the data using a DataLoader later
  """
  def __init__(self, images, labels):
    self.images = torch.Tensor(images)
    self.labels = torch.Tensor(labels)

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    return self.images[idx], self.labels[idx]


train_dataset = MNISTDataset(train_inputs, train_labels)
test_dataset = MNISTDataset(test_inputs, test_labels)


# dataloaders are an easy way to batch and shuffle datasets
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1024, shuffle=False)

Now that we have our data, let's create and train a PyTorch Model! It looks very similar to Tensorflow, with some small differences in naming conventions, etc.

In [None]:
class Model(torch.nn.Module):

  def __init__(self, **kwargs):
    """
    The model class inherits from tf.keras.Model.
    It stores the trainable weights as attributes.
    """
    super(Model, self).__init__(**kwargs)

    # Initialize our torch.nn.Linear layers again we use 256, 128, 10
    self.layer1 = torch.nn.Linear(784, 256)
    self.layer2 = torch.nn.Linear(256, 128)
    self.layer3 = torch.nn.Linear(128, 10)

    # PyTorch Linear Layers don't let you nicely initialize an activation function
    #   line TF does so we need to create these explicitly
    self.relu = torch.nn.ReLU()
    self.softmax = torch.nn.Softmax(dim=1)
  def forward(self, inputs):
    """
    Forward pass, predicts labels given an input image using fully connected layers
    :return: the probabilites of each label
    """

    out1 = self.layer1(inputs)
    out1 = self.relu(out1)
    out2 = self.layer2(out1)
    out2 = self.relu(out2)
    out3 = self.layer3(out2)
    prbs = self.softmax(out3)
    return prbs

  def loss(self, predictions, labels):
    """
    Calculates the model loss
    :return: the loss of the model as a tensor
    """
    nll_comps = -labels * torch.log(torch.clip(predictions,1e-10,1.0))
    return torch.mean(torch.sum(nll_comps, axis=[1]))

  def accuracy(self, predictions, labels):
    """
    Calculates the model accuracy
    :return: the accuracy of the model as a tensor
    """
    pred_classes = torch.argmax(predictions, 1)
    true_classes = torch.argmax(labels, 1)
    correct_prediction = torch.eq(pred_classes, true_classes)
    return torch.mean(torch.Tensor(correct_prediction).to(torch.float32))

## END TODO
################################################################################

# Instantiate our model
model = Model()

# Create our optimizer, notice that the parameters are passed into the init.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Loop through training steps
epochs = 10

for j in range(epochs):
  for batch_idx, (input, label) in tqdm(enumerate(train_loader)):
    # There isn't a "GradientTape" context manager for torch
    #   Instead, torch Tensors have a backward method which backpropagates
    #   automatically. We will talk a little about some of these differences later

    input = torch.reshape(input, (len(input),-1))
    y_pred = model(input) # this calls the call function conveniently
    loss = model.loss(y_pred, label) # compute the loss
    loss.backward() # compute and assign the gradients via backprop
    optimizer.step() # update the parameters
    optimizer.zero_grad() # reset the stored gradients for each of the parameters (can also move this above the line that starts with input)

  test_acc = 0
  for batch_idx, (input, label) in enumerate(test_loader):
    input = torch.reshape(input, (len(input),-1))
    test_acc += model.accuracy(model(input), label)
  print(f"Accuracy on testing set after epoch {j}: {test_acc/len(test_loader)}")
print()
print(model)

# Different optimizer used here; tf basic but here Adam

59it [00:01, 40.03it/s]


Accuracy on testing set after epoch 0: 0.9364855885505676


59it [00:01, 53.38it/s]


Accuracy on testing set after epoch 1: 0.9549247026443481


59it [00:01, 50.60it/s]


Accuracy on testing set after epoch 2: 0.9593949317932129


59it [00:01, 52.18it/s]


Accuracy on testing set after epoch 3: 0.9654256701469421


59it [00:01, 43.21it/s]


Accuracy on testing set after epoch 4: 0.9686263799667358


59it [00:01, 52.84it/s]


Accuracy on testing set after epoch 5: 0.9700155258178711


59it [00:01, 51.78it/s]


Accuracy on testing set after epoch 6: 0.9710897207260132


59it [00:01, 49.31it/s]


Accuracy on testing set after epoch 7: 0.9715701341629028


59it [00:02, 27.35it/s]


Accuracy on testing set after epoch 8: 0.9713448286056519


59it [00:01, 39.87it/s]


Accuracy on testing set after epoch 9: 0.9723971486091614

Model(
  (layer1): Linear(in_features=784, out_features=256, bias=True)
  (layer2): Linear(in_features=256, out_features=128, bias=True)
  (layer3): Linear(in_features=128, out_features=10, bias=True)
  (relu): ReLU()
  (softmax): Softmax(dim=1)
)


Let's zoom into some of the differences between these two Model and training methods.

For one, in Tensorflow we subclass

```
class Model(tf.keras.Model)
```
and in PyTorch we have
```
class Model(torch.nn.Module)
```

Also in PyTorch, Tensorflow's ```call``` method is known as the ```forward``` method, but they both complete one forward pass.

Dense/Linear layers are initialized with ```torch.nn.Linear(in_dimension, out_dimension)```
instead of ```tf.keras.layers.Dense(out_dimension)```. In torch you have to specify the in_dimension, whereas tensorflow figures this out after you make your first call to the model. This small difference makes PyTorch typically faster for the first handful of iterations. That said, Tensorflow's backpropagation is generally faster so it catches up as the number of iterations grows.

Some of the other methods have small name changes, like ```tf.clip_by_value -> torch.clip.```

Perhaps the biggest difference between these two frameworks is how backpropagation is done. Notice that in Tensorflow we use
```

with tf.GradientTape() as tape:
  y_pred = model(image) # this calls the call function conveniently
  loss = model.loss(y_pred, label)

gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
```

But in PyTorch we have

```
# Reset the gradient of the trainable params to 0
optimizer.zero_grad()
y_pred = model(image) # this calls the call function conveniently
loss = model.loss(y_pred, label)

# Compute the gradients
loss.backward()
# Apply the gradients to the trainable parameters
optimizer.step()
```
Additionally, in PyTorch we initialize our optimizer *with the trainable parameters* instead of passing the trainable parameters into every backward pass like we do in Tensorflow.

## Pytorch Add-Ons and Niceties

Pytorch has a handful of additional libraries that contain useful content like **torchvision and torchaudio.** These libraries contain tons of helpful content to streamline data preprocessing, transformations and postprocessing. For more information, you can refer to [torchvision](https://pytorch.org/vision/0.20/) and [torchaudio](https://pytorch.org/audio/stable/index.html).

##Computational Graph
Static Graph (e.g., TensorFlow 1.x): The computation graph is defined before execution. Think of it as writing down a complete recipe first, and then following it to cook. You need to compile and run the graph using a session. Good for optimization, but harder to debug and less flexible for dynamic tasks.

Dynamic Graph (e.g., PyTorch): The computation graph is built on the fly during execution. Imagine cooking step by step, deciding what to do at each step. You can directly use Python constructs like loops and print statements. More intuitive and flexible, great for debugging and dynamic tasks.
### Dynamic Graph




In [None]:
# Importing torch
import torch

# Initializing input tensors
a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)

# Computing the output
c = a * b
print('value of c: ', c)

# Displaying the outputs
print(f'c_out = {c}')

value of c:  tensor(2., grad_fn=<MulBackward0>)
c_out = 2.0


### Static Graph


In [None]:
# Importing tensorflow version 1
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Initializing placeholder variables of
# the graph
a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)

with tf.Graph().as_default():
  # Defining the operation
  c = tf.multiply(a, b)

with tf.Session() as sess:
  # Running the graph
  c_out = sess.run(c, feed_dict={a: 1.0, b: 2.0})

print(c_out)

Instructions for updating:
non-resource variables are not supported in the long term


2.0


In a static computation graph, you cannot use Python's print function to retrieve or display values within the graph during execution. This is because the graph defines operations symbolically, and values are only computed during a session run.

If you try to use print to inspect tensors directly in the graph, it will display the symbolic representation of the tensor (e.g., Tensor("add:0", shape=(), dtype=float32)) instead of the actual values.

In [None]:
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()  # Use static graph mode

# Define placeholders
x = tf.placeholder(tf.float32, shape=(None,))
y = tf.placeholder(tf.float32, shape=(None,))

# Example computation
z = x + y

# Add a print operation (does not modify the value of z)
print("Information from print: ", z)
z_print = tf.print('tf.Print: ', z)

# Create a session to execute the graph
with tf.Session() as sess:
    result, z_print = sess.run([z, z_print], feed_dict={x: [1.0, 2.0], y: [3.0, 4.0]})
    print("Result:", result)


Information from print:  Tensor("add:0", shape=(?,), dtype=float32)
tf.Print:  [4 6]
Result: [4. 6.]


In a static computation graph, **for loops** work differently compared to dynamic graphs. You cannot directly use Python's native for loops or if conditions when constructing the graph. Doing so may lead to unexpected behavior or errors because Python constructs are not part of the computation graph and won't execute as intended during graph execution.

Instead, if you're using TensorFlow 1.x, you need to use TensorFlow's built-in control flow operations, such as `tf.while_loop` for loops and tf.cond for conditions. These operations ensure that the control flow is properly represented within the graph.

In [None]:
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

# Example loop to compute the sum of numbers from 0 to 9
n = tf.constant(10)

# Loop variables
i = tf.constant(0)  # Initial index
total = tf.constant(0)  # Accumulator

# Define the loop condition and body
def condition(i, total):
    return i < n

def body(i, total):
    total += i
    i += 1
    return i, total

# Use tf.while_loop
final_i, final_total = tf.while_loop(condition, body, [i, total])

# Run the graph
with tf.Session() as sess:
    result = sess.run(final_total)
    print("Sum of numbers from 0 to 9:", result)


Sum of numbers from 0 to 9: 45


___

# Jax

JAX is a library for array-oriented numerical computation (à la NumPy), with automatic differentiation and JIT compilation to enable high-performance machine learning research.

By default, you can just use jax.numpy with all your favorite numpy operations as typical. Let's take a look at some examples of this.

In [None]:
import jax
import jax.numpy as jnp

In [None]:
def leaky_relu(x, alpha=0.01):
  return jnp.where(x > 0, x, alpha * x)

# We want an array of numbers [-2, -1, 0, 1, 2]. In numpy we have arange,
#    how about in jnp?
x = jnp.arange(-2, 3)
print(leaky_relu(x, alpha=0.1))

[-0.2 -0.1  0.   1.   2. ]


From the above, we can use standard numpy methods like where and arange in jax exactly as we would have in numpy.

One of Jax's most powerful features is "just-in-time" (jit) compilation. This feature allows jax to precompile the operations for python methods to compute the results very quickly. Let's take a look at how numpy's speed compares to jax and jax+jit.

In [None]:
import numpy as np

x = np.random.normal(0,1,1_000_000)

def numpy_leaky_relu(x, alpha=0.01):
  return np.where(x > 0, x, alpha * x)

%timeit numpy_leaky_relu(x, alpha=0.1)

9.43 ms ± 1.11 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
from jax import random

# initialize a pseudo-random number generation key.
key = random.key(1470)

# create a million standard normal random numbers
x = random.normal(key, (1_000_000))

# time how long it takes to run this function on our array
%timeit leaky_relu(x, alpha=0.1).block_until_ready()

562 µs ± 18.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [None]:
from jax import jit

jit_leaky_relu = jit(leaky_relu)

# we have to use the function once to let it compile with jit
_ = jit_leaky_relu(x, alpha=0.1)

%timeit jit_leaky_relu(x, alpha=0.1).block_until_ready()

243 µs ± 10.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


Notice that in this case, base jax is about 10x faster than numpy and using jit cuts standard jax's time in half again! This is a pretty remarkable speed-up but it's only a slice of what jax can do.

Jax isn't always faster than numpy, but will generally outperform numpy in situations where using a GPU is faster than CPU (like deep learning). You can read more about the comparison [here](https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy).

Now let's see how we can differentiate functions in jax.

In [None]:
from jax import grad #computes gradients

def sum_of_squares(x):
  # in numpy we can use np.sum and np.square
  #   How about in jnp?
  return jnp.sum(jnp.square(x))

input = jnp.arange(5.)

# Here we are going to create a new function that computes the gradient
#    of the sum_of_squares function
grad_sum_squares = grad(sum_of_squares)

# output the value of the function
print(sum_of_squares(input))

# output the gradient at these points
print(grad_sum_squares(input))

30.0
[0. 2. 4. 6. 8.]


Just like that, we can differentiate functions with Jax!

Now we have everything we need to train a model in Jax like the one we had using PyTorch and Tensorflow. We should note here that there is a pretty significant difference in how Jax goes about training a network as compared to Tensorflow and PyTorch. It's a good exercise to think about all the steps that go into training a model, ignorant of what framework you use.

Let's start by intializing the model weights and hyperparamters.

In [None]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(in_size, out_size, key, scale=1e-2):
  w_key, b_key = random.split(key)

  # notice our weights are out_size x in_size, that's just for matmul shaping.
  #   This initialization is called the He/Kaiming normal initialization
  scale = (jnp.sqrt(2/jnp.sum(jnp.array([in_size, out_size]))))
  return scale*random.normal(w_key, (out_size, in_size)), scale * random.normal(b_key, (out_size,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 256, 128, 10]
learning_rate = 0.01
num_epochs = 10
batch_size = 1024
n_classes = 10
params = init_network_params(layer_sizes, random.key(0))

In [None]:
from jax.scipy.special import logsumexp

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b # we need the dot product and the bias addition
    activations = jit_leaky_relu(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b # what goes here?
  return logits - logsumexp(logits) # Nothing more than log(softmax) here

In [None]:
# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

# But when we try with batch it breaks
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')

(10,)
Invalid shapes!


The above example illustrates an important principle of Jax. You write the code to handle one example, then use Jax operations like vmap to generalize to the batched case. It takes some getting used to, but in many cases becomes a more natural coding experience.

Let's see how we can generalize our pipeline to handle batched inputs using vmap.

In [None]:
from jax import vmap

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10, 10)


Now we can handle batches to predict our outputs. We have almost everything we need to train our network! Let's finish up with some utility functions and speed up the process using jit

In [None]:
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  # In numpy we used eye + indexing, what should we do in jnp?
  return jnp.eye(k, dtype=dtype)[x]

def accuracy(params, images, targets):
  """Compute the accuracy for a set of images and targets."""
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  """Compute the multi-class cross-entropy loss."""

  # here we want to get our predictions then compute
  preds = batched_predict(params, images)

  # Now we want to compute Categorical Cross Entropy then take the mean
  #    any guesses how we can compute that easily?
  #    hint: remember that we output log(prob) for each prediction
  #    hint: Categorical Cross Entropy = -sum(log(prob)*true_value)
  loss_value = -jnp.sum(preds*targets) # here
  return loss_value

@jit
def update(params, x, y):
  # This is standard SGD, nothing fancy here
  grads = grad(loss)(params, x, y) # Compute the gradient
  return [(w - learning_rate * dw, b - learning_rate * db)
          for (w, b), (dw, db) in zip(params, grads)] # update the weights

In [None]:
import time
# Ensure TF does not see GPU and grab all GPU memory.
#tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds
data_dir = '/tmp/tfds'


# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
jax_train_images, jax_train_labels = train_data['image'], train_data['label']
jax_train_images = jnp.reshape(jax_train_images, (len(jax_train_labels), num_pixels))
jax_train_labels = one_hot(jax_train_labels, num_labels)

# Full test set
jax_test_images, jax_test_labels = test_data['image'], test_data['label']
jax_test_images = jnp.reshape(jax_test_images, (len(jax_test_images), num_pixels))
jax_test_labels = one_hot(jax_test_labels, num_labels)


Now let's train out model and see how it does!

In [None]:
num_epochs=10
from tqdm import tqdm

for epoch in range(num_epochs):
  num_batches = len(jax_train_images)//batch_size
  start_time = time.time()
  for i in tqdm(range(num_batches)):
    # Get our batches
    x = jax_train_images[i*batch_size:(i+1)*batch_size]
    y = jax_train_labels[i*batch_size:(i+1)*batch_size]
    # update the parameters!
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, jax_train_images, jax_train_labels)
  test_acc = accuracy(params, jax_test_images, jax_test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

100%|██████████| 58/58 [00:00<00:00, 1009.92it/s]


Epoch 0 in 0.06 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027


100%|██████████| 58/58 [00:00<00:00, 1000.40it/s]


Epoch 1 in 0.06 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027


100%|██████████| 58/58 [00:00<00:00, 1179.73it/s]


Epoch 2 in 0.06 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027


100%|██████████| 58/58 [00:00<00:00, 1179.82it/s]


Epoch 3 in 0.05 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027


100%|██████████| 58/58 [00:00<00:00, 1039.99it/s]


Epoch 4 in 0.06 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027


100%|██████████| 58/58 [00:00<00:00, 1118.77it/s]


Epoch 5 in 0.06 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027


100%|██████████| 58/58 [00:00<00:00, 924.74it/s]


Epoch 6 in 0.07 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027


100%|██████████| 58/58 [00:00<00:00, 1142.28it/s]


Epoch 7 in 0.06 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027


100%|██████████| 58/58 [00:00<00:00, 949.29it/s]


Epoch 8 in 0.07 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027


100%|██████████| 58/58 [00:00<00:00, 1126.96it/s]


Epoch 9 in 0.06 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
