<a href="https://colab.research.google.com/github/igorvere/toys/blob/master/rnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#!pip install --upgrade -q jax jaxlib

# Colab runtime set to TPU accel
# NOTE: both os.environ['COLAB_TPU_ADDR'] lookups below raise KeyError when the
# variable is not set (presumably any runtime without a TPU attached).
import requests
import os
# Guard on a notebook-level flag so that re-running this cell in the same kernel
# does not post the driver request a second time.
if 'TPU_DRIVER_MODE' not in globals():
  # Host part of COLAB_TPU_ADDR, port 8475: presumably asks the TPU host to
  # switch to the nightly TPU driver build -- TODO confirm against Colab docs.
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'
  # The response is stored but never inspected, so a failed request goes unnoticed.
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# TPU driver as backend for JAX
# NOTE(review): the jax.devices() output at the end of this notebook lists only
# CpuDevice(id=0), so these flags do not appear to have taken effect in the
# recorded run -- verify the backend before trusting the JAX timings.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)


grpc://10.108.67.42:8470


In [7]:
# Upgrade jax/jaxlib to the latest release (versions are not pinned).
# NOTE(review): this runs after the cell that already imported and configured
# jax.config; presumably a runtime restart is needed for the upgraded packages
# to be the ones actually in use -- confirm.
!pip install --upgrade -q jax jaxlib

[K     |████████████████████████████████| 491kB 12.4MB/s 
[K     |████████████████████████████████| 33.3MB 95kB/s 
[?25h  Building wheel for jax (setup.py) ... [?25l[?25hdone


In [8]:
import torch

class RNN(torch.nn.Module):
  """Single-layer GRU followed by a linear read-out to one value per time step.

  The GRU's hidden width is set equal to ``input_size``.
  """

  def __init__(self, input_size):
    """Build the GRU and the decoder.

    Args:
      input_size: number of features per time step; also used as the GRU
        hidden size.
    """
    super(RNN, self).__init__()
    self.rnn = torch.nn.GRU(input_size, input_size, 1)
    self.decoder = torch.nn.Linear(input_size, 1)

  def forward(self, x, hidden):
    """Run the GRU over ``x`` and decode every step's output.

    Args:
      x: input sequence; with the default (non batch-first) ``torch.nn.GRU``
        layout this is ``(seq_len, batch, input_size)``.
      hidden: initial hidden state, ``(1, batch, input_size)``.

    Returns:
      Tuple ``(fit, hidden)``: the decoded per-step outputs with a trailing
      dimension of 1, and the final GRU hidden state.
    """
    # Use the argument `x`. The previous code read the notebook-level global
    # `data`, so the module ignored its input and raised NameError whenever
    # that global did not exist.
    fit, hidden = self.rnn(x, hidden)
    fit = self.decoder(fit)
    return fit, hidden

In [9]:
# Benchmark dimensions: feature/hidden width, batch size and sequence length.
hidden_size = 64
batch_size = 128
seqlen = int(1e3)

# Zero initial state plus standard-normal inputs and targets, time-major.
h0 = torch.zeros(1, batch_size, hidden_size)
data = torch.normal(0, 1, size=(seqlen, batch_size, hidden_size))
target = torch.normal(0, 1, size=(seqlen, batch_size, 1))
net = RNN(hidden_size)

# Move the model and every tensor to the GPU when at least one is present.
if torch.cuda.device_count() > 0:
  net, h0 = net.cuda(), h0.cuda()
  data, target = data.cuda(), target.cuda()

In [10]:
%%time
for _ in range(10):
  net.zero_grad()
  fit, ht = net(data, h0)
  loss = torch.sum((fit - target)**2)
  loss.backward()

CPU times: user 412 ms, sys: 51.6 ms, total: 463 ms
Wall time: 466 ms


In [11]:
from jax.nn import sigmoid
from jax.nn.initializers import glorot_normal, normal

from functools import partial
from jax import lax
from jax import device_put, devices

def GRU(out_dim, W_init=glorot_normal(), b_init=normal()):
    """Build a stax-style GRU layer.

    Args:
        out_dim: width of the hidden state (and of each per-step output).
        W_init: initializer ``(key, shape) -> array`` for the weight matrices.
        b_init: initializer ``(key, shape) -> array`` for the biases and for
            the learned initial hidden state.

    Returns:
        ``(init_fun, apply_fun)`` in the form expected by ``stax.serial``.
    """
    def init_fun(rng, input_shape):
        """ Initialize the GRU layer for stax

        ``input_shape`` is read as ``(time, batch, features)``: entry 1 sizes
        the initial hidden state and entry 2 sizes the input weights.

        Returns ``(output_shape, params)`` where ``params`` is
        ``(hidden, (update_W, update_U, update_b),
        (reset_W, reset_U, reset_b), (out_W, out_U, out_b))``.
        """
        # One independent key per parameter group. Previously `rng` was split
        # three times with identical arguments, which produced the same three
        # keys each time and therefore identical update/reset/candidate
        # weights; `rng` itself was also reused for `hidden`.
        k_hidden, k_update, k_reset, k_out = random.split(rng, num=4)

        # The initial hidden state is a parameter with a fixed batch dimension.
        hidden = b_init(k_hidden, (input_shape[1], out_dim))

        def init_gate(gate_key):
            """ Draw the (W, U, b) triple for one gate from its own key """
            k_W, k_U, k_b = random.split(gate_key, num=3)
            return (
                W_init(k_W, (input_shape[2], out_dim)),
                W_init(k_U, (out_dim, out_dim)),
                b_init(k_b, (out_dim,)),)

        update_W, update_U, update_b = init_gate(k_update)
        reset_W, reset_U, reset_b = init_gate(k_reset)
        out_W, out_U, out_b = init_gate(k_out)

        # apply_fun scans over axis 0, so axis 0 is time and axis 1 is batch;
        # only the trailing feature axis changes size.
        output_shape = (input_shape[0], input_shape[1], out_dim)
        return (output_shape,
            (hidden,
             (update_W, update_U, update_b),
             (reset_W, reset_U, reset_b),
             (out_W, out_U, out_b),),)

    def apply_fun(params, inputs, **kwargs):
        """ Loop over the time steps of the input sequence

        ``inputs`` is scanned along its leading axis; the result stacks the
        hidden state after every step along that same axis. Extra keyword
        arguments (stax passes e.g. ``rng``) are accepted and ignored.
        """
        h = params[0]

        def apply_fun_scan(params, hidden, inp):
            """ Perform single step update of the network """
            _, (update_W, update_U, update_b), (reset_W, reset_U, reset_b), (
                out_W, out_U, out_b) = params

            update_gate = sigmoid(np.dot(inp, update_W) +
                                  np.dot(hidden, update_U) + update_b)
            reset_gate = sigmoid(np.dot(inp, reset_W) +
                                 np.dot(hidden, reset_U) + reset_b)
            # Candidate state, computed from the reset-gated previous state.
            output_gate = np.tanh(np.dot(inp, out_W)
                                  + np.dot(np.multiply(reset_gate, hidden), out_U)
                                  + out_b)
            # Blend the previous state with the candidate via the update gate.
            output = np.multiply(update_gate, hidden) + np.multiply(1-update_gate, output_gate)
            # lax.scan expects (carry, per-step output); both are the new state.
            return output, output

        # Inputs are already time-major, so no moveaxis is needed before scan.
        f = partial(apply_fun_scan, params)
        _, h_new = lax.scan(f, h, inputs)
        return h_new

    return init_fun, apply_fun

import jax.numpy as np
from jax.experimental import stax
from jax.experimental.stax import ( Dense, )

from jax import random
from jax import grad, jit, vmap, value_and_grad


# Fixed PRNG seed so that parameter initialisation is repeatable.
key = random.PRNGKey(1)

# The per-step input width and the GRU width both reuse the PyTorch model's size.
num_hidden_units = num_dims = hidden_size

# stax.serial returns an (init, apply) pair for the GRU -> Dense(1) stack.
init_fun, gru_rnn = stax.serial(
    GRU(num_hidden_units),
    Dense(1),
)

# Initialise the parameters; the returned output shape is not needed.
_, params = init_fun(key, (1, batch_size, num_dims))

def mse_loss(params, inputs, targets):
    """ Sum of squared prediction errors of the network (a sum, not a mean). """
    residual = gru_rnn(params, inputs) - targets
    return np.sum(residual**2)

# Left undecorated on purpose; the jitted version is built separately below.
def update(params, x, y):
    """ Evaluate the loss and its gradient w.r.t. params (no parameter update). """
    loss_and_grads = value_and_grad(mse_loss)
    return loss_and_grads(params, x, y)
# Compile `update` with jit, pinned to the first device JAX reports.
updatejit = jit(update, device=devices()[0])
# Copy the PyTorch tensors to that same device as JAX arrays, going through
# host NumPy arrays.
datajax = device_put(data.cpu().numpy(), device=devices()[0])
targetjax = device_put(target.cpu().numpy(), device=devices()[0])
# NOTE(review): `onp` is not used in the visible cells and `device_put` was
# already imported earlier; these two imports look like leftovers.
import numpy as onp
from jax import device_put

#datacpu = onp.moveaxis(datacpu, 1, 0)
#targetcpu = onp.moveaxis(targetcpu, 1, 0)


# First call of the jitted function: this timing presumably includes tracing
# and XLA compilation, so it is not comparable to the steady-state loop below.
%time loss, grads = updatejit(params, datajax, targetjax)

CPU times: user 723 ms, sys: 33.9 ms, total: 757 ms
Wall time: 752 ms


In [12]:
%%time 
# Time 10 calls of the already-compiled update function.
for i in range(10):
  loss, grads = updatejit(params, datajax, targetjax)
  # JAX dispatches asynchronously; block so the timer covers the real compute.
  loss.block_until_ready()  
  # Only the first gradient leaf is waited on; presumably sufficient because
  # all outputs come from the same jitted computation -- TODO confirm.
  grads[0][0].block_until_ready()  

CPU times: user 2.36 s, sys: 13 ms, total: 2.38 s
Wall time: 2.37 s


In [17]:
import jax
# List the devices JAX is actually using. The recorded output shows a single
# CpuDevice, i.e. the JAX timings in this notebook were measured on CPU.
jax.devices()

[CpuDevice(id=0)]

In [18]:
# The device that `updatejit` and the `device_put` calls were pinned to.
jax.devices()[0]

CpuDevice(id=0)