# Introduction to Colab, JAX and Haiku
Authors: Viorica Patraucean, David Szepesvari

Contact: vpatrauc@gmail.com

Thanks to Carl Doersch and Stanislaw Jastrzebski for proofreading and advice.

## What is Colab?

[Colaboratory](https://colab.sandbox.google.com/notebooks/welcome.ipynb) is a [Jupyter](http://jupyter.org/) notebook environment that requires no setup to use. It allows you to create and share documents that contain

* Live, runnable code
* Visualizations
* Explanatory text

It's also a great tool for prototyping and quick development. Let's give it a try. 

Run the following so-called *(Code) Cell* by moving the cursor into it, and either

* Pressing the "play" icon on the left of the cell, or
* Hitting **`Shift + Enter`**.

In [None]:
print('Hello, onlineEEML2020!')

You should see `Hello, onlineEEML2020!` printed under the code.

The code is executed on a virtual machine dedicated to your account, with the results sent back to your browser. This has some positive and negative consequences.

### Using a GPU

You can connect to a virtual machine with a GPU. To select the hardware you want to use, follow either

* **Edit > Notebook settings**, or
* **Runtime > Change runtime type**

and choose an accelerator.

### Losing Connection

You may lose connection to your virtual machine. The two most common causes are

* Virtual machines are recycled when idle for a while, and have a maximum lifetime enforced by the system.
* Long-running background computations, particularly on GPUs, may be stopped.

**If you lose connection**, the state of your notebook will also be lost. You will need to **rerun all cells** up to the one you are currently working on. To do so

1. Select (place the cursor into) the cell you are working on. 
2. Follow **Runtime > Run before**.

### Pretty Printing by colab
1) If the **last operation** of a given cell returns a value, it will be pretty printed by colab.


In [None]:
6 * 7

In [None]:
my_dict = {'one': 1, 'some set': {4, 2, 2}, 'a regular list': range(5)}

There is no output from the second cell, as assignment does not return anything.

2) You can explicitly **print** anything before the last operation, or **suppress** the output of the last operation by adding a semicolon.

In [None]:
print(my_dict)
my_dict['one'] * 10 + 1;

### Scoping and Execution Model

Notice that in the previous code cell we worked with `my_dict`, while it was defined in an even earlier cell.

1) In Colab, variables defined at the top level of a cell have **global** scope.

Modify `my_dict`:

In [None]:
my_dict['I\'ve been changed!'] = True

2) Cells can be **run** in any **arbitrary order**, and global state is maintained between them.

Try re-running the cell where we printed `my_dict`. You should now see the additional item `"I've been changed!": True`.

3) Unintentionally reusing a global variable can lead to bugs. If all else fails, you can uncomment and run the following line to **clear all global variables**, then rerun all the cells.

In [None]:
# %reset -f

### Autocomplete / Documentation

* Pressing *`<TAB>`* after typing a prefix will show the available variables / commands.
* Pressing *`<TAB>`* on a function parameter list will show the function documentation.

Note: this only works for variables that have already been defined (not while you are writing your code).

### Setup and Imports

Python packages can and need to be imported into your colab notebook, the same way you would import them in a python script. For example, to use `numpy`, you would do

In [None]:
# import numpy as np

While many packages can just be imported, some (e.g. `haiku`, a neural network library from DeepMind) may not be prepackaged in the runtime. With Colab, you can install any Python package with `pip` for the duration of your connection.

In [None]:
# we will use haiku on top of jax 
# !pip install -q dm-haiku
# import haiku as hk

### Forms

With Colab it is easy to take input from the user in code cells through so-called forms. A simple example is shown below.

In [None]:
#@title This text shows up as a title.

a = 2  #@param {type: 'integer'}
b = 3  #@param

print('a+b =', str(a+b))

You can change parameters on the right hand side, then rerun the cell to use these values. **Try setting the value of a=5 and rerun the cell above.**

In order to expose a variable as a parameter, just add `#@param` after it. There are various kinds of params; if you're interested, you can read more about this on the official starting colab.
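
As a rough sketch of the other kinds of form fields, the cell below uses a few `#@param` annotations (string, boolean, number, slider and dropdown). The variable names and values are made up for illustration, and outside Colab the annotations are just ordinary Python comments.

```python
#@title Hypothetical training settings

# Each `#@param` annotation turns the assignment on its line into a form field.
run_name = 'eeml-demo'  #@param {type:"string"}
use_gpu = True  #@param {type:"boolean"}
learning_rate = 0.001  #@param {type:"number"}
num_layers = 3  #@param {type:"slider", min:1, max:10, step:1}
optimiser_name = 'adam'  #@param ["adam", "sgd"]

print(run_name, use_gpu, learning_rate, num_layers, optimiser_name)
```

Changing a field in the form rewrites the value on the corresponding line, so the code and the form stay in sync.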

Cells with forms allow you to toggle whether

* the code,
* the form,
* or both

are visible.

**Try switching between these 3 options for the above cell.** This is how you do this:

1. Click anywhere over the area of the cell with the form to highlight it.
2. Click on the "three vertically arranged dots" icon in the top right of the cell.
3. Go to "Form >", select your desired action.

## JAX
[JAX](https://jax.readthedocs.io/en/latest/jax.html) allows NumPy-like code to execute on CPU or on accelerators such as GPU and TPU, with great automatic differentiation for high-performance machine learning research.

- JAX automatically differentiates Python and NumPy code (with [Autograd](https://github.com/hips/autograd))
- JAX uses [XLA](https://www.tensorflow.org/xla) to compile and run NumPy code efficiently on accelerators

### JAX and random number generators (WIP)
Unlike many ML frameworks, JAX does not hide the pseudo-random number generator state. You need to explicitly generate a random key and pass it to the operations that work with random numbers (e.g. initialising a model, dropout, etc.). Calling a random function does not change the state of the generator, so calling it again with the same key returns the same values. To get different values you have to derive new keys explicitly, with `split()` in JAX or `next_rng_key()` in `haiku`.

In [None]:
import numpy as np
import jax.numpy as jnp
from jax import random
key = random.PRNGKey(0)
x1 = random.normal(key, (3,))
print(x1)
x2 = random.normal(key, (3,))
print(x2)

In [None]:
# Let's split the key to be able to generate different random values
key, new_key = random.split(key)
x1 = random.normal(key, (3,))
print(x1)
x2 = random.normal(new_key, (3,))
print(x2)
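
When several independent random draws are needed, one common pattern is to split a key into several subkeys in one call. The sketch below is only an illustration (the names `subkeys` and `samples` are ours): it checks that reusing a key reproduces the same values, then draws from three distinct subkeys.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# Reusing the same key gives exactly the same numbers.
a = random.normal(key, (3,))
b = random.normal(key, (3,))
same_key_same_values = bool(jnp.all(a == b))
print('Same key, same values:', same_key_same_values)

# Split into one key to carry forward plus three subkeys to consume.
key, *subkeys = random.split(key, num=4)
samples = [random.normal(k, (3,)) for k in subkeys]
for s in samples:
  print(s)
```

Keeping one key "for later" and consuming the subkeys avoids accidentally reusing a key in two places.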

### JAX program transformations with examples 
* `jit` (just-in-time compilation) -- speeds up your code by running all the ops inside the jit-ed function as a *fused* op; it compiles the function when it's called the first time and uses the compiled (optimised) version from the second call onwards.
* `grad` -- returns a new function that computes the derivatives of the original function with respect to its first argument by default (e.g. the model weights passed as parameters)
* `vmap` -- automatic batching; returns a new function that can apply the original (per-sample) function to a batch.



In [None]:
from jax import grad, jit
# Let's use jit to speed up a function
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# execute the function without jit
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()   # block_until_ready is needed as jax, by default, runs operations asynchronously

# Execute the function with jit and compare timing with above -- it should be much faster
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
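
To make the "compiled on the first call" behaviour visible, the sketch below counts how often the Python body of a jit-ed function actually runs. The function `scaled_sum` and the counter `trace_count` are hypothetical names used only for this illustration; the Python body runs while JAX traces the function, and later calls with the same input shape and dtype reuse the compiled version.

```python
import jax.numpy as jnp
from jax import jit

trace_count = 0

@jit
def scaled_sum(x):
  global trace_count
  trace_count += 1  # Python side effect: executed only while tracing
  return jnp.sum(2.0 * x)

x = jnp.ones((4,))
first = scaled_sum(x)               # traces and compiles
second = scaled_sum(x)              # same shape: reuses the compiled function
third = scaled_sum(jnp.ones((5,)))  # new shape: traced and compiled again

print('Results:', first, second, third)
print('Number of traces:', trace_count)
```

This is also why side effects such as `print` inside a jit-ed function do not fire on every call.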

In [None]:
# Let's use grad to compute gradient of a simple function
def simple_fun(x):
  return jnp.sin(x) / x

# Get the gradient of simple_fun with respect to x
grad_simple_fun = grad(simple_fun)

# We can also get higher order derivatives, e.g. the second derivative
# (the Hessian of a scalar function of a scalar input)
grad_grad_simple_fun = grad(grad(simple_fun))
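
As a sanity check, the gradient returned by `grad` can be compared with the derivative worked out by hand, here d/dx [sin(x)/x] = (x cos(x) - sin(x)) / x^2. The sketch also shows `argnums`, which selects the argument to differentiate; the function `weighted_square` is a made-up example.

```python
import jax.numpy as jnp
from jax import grad

def simple_fun(x):
  return jnp.sin(x) / x

def analytic_derivative(x):
  # Derivative of sin(x) / x computed by hand with the quotient rule.
  return (x * jnp.cos(x) - jnp.sin(x)) / x**2

x0 = 2.0
auto = grad(simple_fun)(x0)
exact = analytic_derivative(x0)
print('autodiff:', auto, 'analytic:', exact)

def weighted_square(w, x):
  return w * x**2

# By default grad differentiates with respect to the first argument (w).
dw = grad(weighted_square)(3.0, 2.0)             # d/dw = x^2
# argnums=1 differentiates with respect to the second argument (x).
dx = grad(weighted_square, argnums=1)(3.0, 2.0)  # d/dx = 2 * w * x
print('dw:', dw, 'dx:', dx)
```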

In [None]:
# Let's plot the result
import matplotlib.pyplot as plt
x_range = jnp.arange(-8, 8, .1)
plt.plot(x_range, simple_fun(x_range), 'b')
plt.plot(x_range, [grad_simple_fun(xi) for xi in x_range], 'r')
plt.plot(x_range, [grad_grad_simple_fun(xi) for xi in x_range], '--g')
plt.show()

In [None]:
from jax import vmap
# Let's see how vmap can be used to vectorize computations efficiently
# In the example above, we can use vmap instead of a loop to compute gradients

grad_vect_simple_fun = vmap(grad_simple_fun)(x_range)

# plot again and check that the gradients are identical 
plt.plot(x_range, simple_fun(x_range), 'b')
plt.plot(x_range, [grad_simple_fun(xi) for xi in x_range], 'r')
plt.plot(x_range, grad_vect_simple_fun, 'oc', mfc='none')
plt.show()


In [None]:
# Let's time them!

# naive batching
def naively_batched(x):
  return jnp.stack([grad_simple_fun(xi) for xi in x])

# manual batching with jit
@jit
def manual_batched(x):
  return jnp.stack([grad_simple_fun(xi) for xi in x])

# Batching using vmap and jit
@jit
def vmap_batched(x):
  return vmap(grad_simple_fun)(x)

print('Naively batched')
%timeit naively_batched(x_range).block_until_ready()
print('jit batched')
%timeit manual_batched(x_range).block_until_ready()
print('With jit vmap')
%timeit vmap_batched(x_range).block_until_ready()
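
In practice, a per-sample function often takes some arguments that should be batched (the inputs) and some that should be shared across the batch (the weights). A minimal sketch of this with `in_axes`, using a made-up `predict` function:

```python
import jax.numpy as jnp
from jax import vmap

def predict(w, x):
  # Per-sample function: w has shape (3,), x has shape (3,), result is a scalar.
  return jnp.dot(w, x)

w = jnp.array([1.0, 2.0, 3.0])
batch = jnp.arange(12.0).reshape(4, 3)  # 4 samples with 3 features each

# in_axes=(None, 0): do not batch over w, batch over axis 0 of x.
batched_predict = vmap(predict, in_axes=(None, 0))
out = batched_predict(w, batch)

print(out.shape)
print(out)
```

The result matches the matrix-vector product `batch @ w`, but `predict` itself never had to be written with a batch dimension in mind.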

### Read the doc for [common gotchas](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) in JAX!

## Haiku -- object-oriented neural network library on top of JAX
Notable functions / entities
* `hk.Module` base class: implement your own modules by deriving from it
* `hk.transform`: converts a function that uses Haiku modules (objects, hence not pure) into a pair of pure functions, `init` and `apply`.
* `hk.next_rng_key()`: returns a unique random key


### Example: Train MLP classifier on MNIST

In [None]:
import contextlib
from typing import Any, Mapping, Generator, Tuple 

# we will use haiku on top of jax 
!pip install -q dm-haiku
import haiku as hk

import jax
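# Package for optimisers. Note: `jax.experimental.optix` has since moved out of
# JAX into the standalone `optax` package; on recent JAX versions, run
# `!pip install -q optax` and replace this import with `import optax as optix`.
from jax.experimental import optix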
import jax.numpy as jnp
import numpy as np
import enum

# Dataset library
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

# Plotting library.
from matplotlib import pyplot as plt
import pylab as pl
from IPython import display

# Don't forget to select GPU runtime environment in Runtime -> Change runtime type
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

# define some useful types
OptState = Any
Batch = Mapping[str, np.ndarray]

### Define the dataset: MNIST

In [None]:
# We use TF datasets; JAX does not support data loading or preprocessing.
NUM_CLASSES = 10  # MNIST has 10 classes, corresponding to the different digits.
def load_dataset(
    split: str,
    *,
    is_training: bool,
    batch_size: int,
) -> Generator[Batch, None, None]:
  """Loads the dataset as a generator of batches."""
  ds = tfds.load('mnist:3.*.*', split=split).cache().repeat()
  if is_training:
    ds = ds.shuffle(10 * batch_size, seed=0)
  ds = ds.batch(batch_size)
  return tfds.as_numpy(ds)

In [None]:
# Function to display images
MAX_IMAGES = 10
def gallery(images, label, title='Input images'):  
  class_dict = [u'zero', u'one', u'two', u'three', u'four', u'five', u'six', u'seven', u'eight', u'nine']
  num_frames, h, w, num_channels = images.shape
  num_frames = min(num_frames, MAX_IMAGES)
  ff, axes = plt.subplots(1, num_frames,
                          figsize=(30, 30),
                          subplot_kw={'xticks': [], 'yticks': []})
  if images.min() < 0:
    images = (images + 1.) / 2.
  for i in range(0, num_frames):
    if num_channels == 3:
      axes[i].imshow(np.squeeze(images[i]))
    else:
      axes[i].imshow(np.squeeze(images[i]), cmap='gray')
    axes[i].set_title(class_dict[label[i]], fontsize=28)
    plt.setp(axes[i].get_xticklabels(), visible=False)
    plt.setp(axes[i].get_yticklabels(), visible=False)
  ff.subplots_adjust(wspace=0.1)
  plt.show()

In [None]:
# Make datasets for train and test
train_dataset = load_dataset('train', is_training=True, batch_size=1000)
train_eval_dataset = load_dataset('train', is_training=False, batch_size=10000)
test_eval_dataset = load_dataset('test', is_training=False, batch_size=10000)

### Define classifier: a simple MLP

In [None]:
def net_fn(batch: Batch) -> jnp.ndarray:
  """Standard LeNet-300-100 MLP network."""
  # The images are in [0, 255], uint8; we need to convert to float and normalise 
  x = batch['image'].astype(jnp.float32) / 255.
  # We use hk.Sequential to chain the modules in the network
  mlp = hk.Sequential([
  # The input images are 28x28, so we first flatten them to apply linear (fully-connected) layers                     
      hk.Flatten(),  
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  return mlp(x)

### Retrieve pure functions for our model (`init`, `apply`) using `hk.transform`



In [None]:
# Since we don't store additional state statistics, e.g. needed in batch norm,
# we use `hk.transform`. When we use batch_norm, we will use `hk.transform_with_state`
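# Our network uses no randomness when applied, so in recent Haiku versions we
# wrap the transformed function with `hk.without_apply_rng`; `net.apply` then
# takes `(params, batch)` without an extra rng argument.
net = hk.without_apply_rng(hk.transform(net_fn))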

### Define the optimiser

In [None]:
# We use Adam optimizer here. Others are possible, e.g. sgd with momentum.
lr = 1e-3
opt = optix.adam(lr)

### Define the optimisation objective

In [None]:
# Training loss: mean softmax cross-entropy plus L2 weight decay regularization
def loss(params: hk.Params, batch: Batch) -> jnp.ndarray:
  """Compute the loss of the network, including L2 for regularization."""
  
  # Get network predictions
  logits = net.apply(params, batch)

  # Generate one_hot labels from index classes
  labels = jax.nn.one_hot(batch['label'], NUM_CLASSES)

  # Compute mean softmax cross entropy over the batch
  softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
  softmax_xent /= labels.shape[0]

  # Compute the weight decay loss by penalising the norm of parameters
  l2_loss = 0.5 * sum(jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
  
  return softmax_xent + 1e-4 * l2_loss
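
To see what the two terms of the loss compute, here is a small worked example on made-up logits and parameters (all values are illustrative). With uniform logits over 3 classes, the cross-entropy of any label is log(3).

```python
import jax
import jax.numpy as jnp

# Two examples, three classes.
logits = jnp.array([[2.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0]])
label_idx = jnp.array([0, 2])
labels = jax.nn.one_hot(label_idx, 3)

# The one-hot labels pick out the log-probability of the correct class.
per_example = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
mean_xent = jnp.sum(per_example) / labels.shape[0]
print('per-example cross-entropy:', per_example)
print('mean cross-entropy:', mean_xent)

# L2 penalty over all leaves of a toy parameter tree: 0.5 * (1 + 4 + 9).
toy_params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0])}
l2 = 0.5 * sum(jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(toy_params))
print('l2 penalty:', l2)
```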

### Evaluation metric

In [None]:
# Classification accuracy
@jax.jit
def accuracy(params: hk.Params, batch: Batch) -> jnp.ndarray:
  # Get network predictions
  predictions = net.apply(params, batch)
  # Return accuracy = how many predictions match the ground truth
  return jnp.mean(jnp.argmax(predictions, axis=-1) == batch['label'])
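
The accuracy computation can be checked by hand on a tiny made-up batch: take the `argmax` of each row of logits and compare with the labels.

```python
import jax.numpy as jnp

# Three examples, three classes (illustrative values).
logits = jnp.array([[0.1, 2.0, -1.0],
                    [1.5, 0.2, 0.3],
                    [0.0, 0.1, 0.9]])
labels = jnp.array([1, 2, 2])

predicted = jnp.argmax(logits, axis=-1)  # class with the largest logit per row
acc = jnp.mean(predicted == labels)      # fraction of correct predictions

print('predicted:', predicted)
print('accuracy:', acc)
```

Here the first and third examples are classified correctly, so the accuracy is 2/3.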

### Define training step (parameters update)

In [None]:
@jax.jit
def update(
    params: hk.Params,
    opt_state: OptState,
    batch: Batch,
) -> Tuple[hk.Params, OptState]:
  """Learning rule (stochastic gradient descent)."""
  # Use jax transformation `grad` to compute gradients;
  # it expects the parameters of the model and the input batch
  grads = jax.grad(loss)(params, batch)

  # Compute parameters updates based on gradients and optimiser state
  updates, opt_state = opt.update(grads, opt_state)

  # Apply updates to parameters
  new_params = optix.apply_updates(params, updates)
  return new_params, opt_state
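
The same "compute gradients, turn them into updates, apply them" pattern can be sketched without an optimiser library. The example below swaps Adam for plain gradient descent and fits a toy linear model; `toy_loss`, `sgd_update` and the data are made up for illustration and are not part of the MNIST pipeline.

```python
import jax
import jax.numpy as jnp

def toy_loss(params, batch):
  pred = params['w'] * batch['x'] + params['b']
  return jnp.mean((pred - batch['y']) ** 2)

@jax.jit
def sgd_update(params, batch, lr=0.1):
  grads = jax.grad(toy_loss)(params, batch)
  # Plain gradient descent: move each parameter against its gradient.
  return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {'w': jnp.array(0.0), 'b': jnp.array(0.0)}
# Data generated from y = 2x + 1.
batch = {'x': jnp.array([0.0, 1.0, 2.0]), 'y': jnp.array([1.0, 3.0, 5.0])}

loss_before = toy_loss(params, batch)
for _ in range(200):
  params = sgd_update(params, batch)
loss_after = toy_loss(params, batch)

print('loss before:', loss_before, 'loss after:', loss_after)
print('learned w, b:', params['w'], params['b'])
```

Because the parameters are a pytree (here a dictionary), `tree_map` applies the update to every leaf, much as `apply_updates` does for the network parameters.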

### Initialise the model and the optimiser

In [None]:
# Initialize model and optimiser; note that a sample input is needed to compute
# shapes of parameters

# Draw a data batch
batch = next(train_dataset)
# Initialize model
params = net.init(jax.random.PRNGKey(42), batch)
# Initialize optimiser
opt_state = opt.init(params)

### Visualise data and parameter shapes

In [None]:
# Display shapes and images
print(batch['image'].shape)
print(batch['label'].shape)
gallery(batch['image'], batch['label'])

In [None]:
# Let's see how many parameters our network has, and their shapes
def get_num_params(params: hk.Params) -> int:
  num_params = 0
  for p in jax.tree_util.tree_leaves(params):
    print(p.shape)
    # p.size is the number of elements in the array (the product of its shape)
    num_params += p.size
  return num_params
print('Total number of parameters %d' % get_num_params(params))
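
The expected total can be worked out by hand: each `hk.Linear` layer has `fan_in * fan_out` weights plus `fan_out` biases. A short sketch of that arithmetic for the 784-300-100-10 network (variable names are ours):

```python
# Layer widths: flattened 28x28 input, two hidden layers, 10 output classes.
layer_sizes = [28 * 28, 300, 100, 10]

total = 0
for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
  weights = fan_in * fan_out
  biases = fan_out
  print(f'{fan_in} -> {fan_out}: {weights} weights + {biases} biases')
  total += weights + biases

print('Total:', total)
```

This gives 235500 + 30100 + 1010 = 266610 parameters.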

### Accuracy of the untrained model (should be ~10%)

In [None]:
# Run accuracy on the test dataset
test_accuracy = accuracy(params, next(test_eval_dataset))
print('Test accuracy %f '% test_accuracy)

In [None]:
# Let's visualise some network predictions before training; if some are correct,
# they are correct by chance.
predictions = net.apply(params, batch)
pred_labels = jnp.argmax(predictions, axis=-1)
gallery(batch['image'], pred_labels)

### Run one training step

In [None]:
# First, let's do one step and check if the updates lead to decrease in error
loss_before_train = loss(params, batch) 
print('Loss before train %f' % loss_before_train)
params, opt_state = update(params, opt_state, batch)
new_loss = loss(params, next(train_dataset))
new_loss_same_batch = loss(params, batch)
print('Loss after one step of training, same batch %f, different batch %f' % (new_loss_same_batch, new_loss))

### Run training steps in a loop. We also run evaluation periodically.

In [None]:
# Train/eval loop.
for step in range(5001):
  if step % 1000 == 0:
    # Periodically evaluate classification accuracy on train & test sets.
    train_accuracy = accuracy(params, next(train_eval_dataset))
    test_accuracy = accuracy(params, next(test_eval_dataset))
    train_accuracy, test_accuracy = jax.device_get(
        (train_accuracy, test_accuracy))
    print('Step %d Train / Test accuracy: %f / %f' % (step, train_accuracy, test_accuracy))

  # Do SGD on a batch of training examples.
  params, opt_state = update(params, opt_state, next(train_dataset))

### Visualise network predictions after training; most of the predictions should be correct.

In [None]:
# Get predictions for the same batch
predictions = net.apply(params, batch)
pred_labels = jnp.argmax(predictions, axis=-1)
gallery(batch['image'], pred_labels)