In [1]:
import autograd
import autograd.numpy as np
from autograd import grad
from autograd import checkpoint
from autograd.extend import primitive

import numpy as onp
from time import time

%load_ext memory_profiler

from builtins import range, list as ag_list, tuple as ag_tuple
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import grad
from autograd.scipy.misc import logsumexp
from os.path import dirname, join
from autograd.misc.optimizers import adam

In [2]:
def create_rnn_params(input_size, state_size, output_size,
                      param_scale=0.01, rs=None):
    """Initialise RNN parameters as scaled standard-normal draws.

    Args:
        input_size: width of each per-step input batch.
        state_size: width of the recurrent hidden state.
        output_size: width of the prediction layer.
        param_scale: multiplier applied to every draw (the draws' std dev).
        rs: RandomState to draw from. When omitted, a fresh ``RandomState(0)``
            is created for this call, so repeated default calls return
            identical parameters.

    Returns:
        dict with keys
          'change'       -> (input_size + state_size + 1, state_size)
          'predict'      -> (state_size + 1, output_size)
          'init hiddens' -> (1, state_size)
        The extra ``+ 1`` row in the weight matrices is the bias row that
        pairs with the column of ones appended by concat_and_multiply.
    """
    # A RandomState used as a default argument would be built once at def
    # time and shared (and advanced) by every call; build one per call.
    if rs is None:
        rs = npr.RandomState(0)
    return {'change': rs.randn(input_size + state_size + 1, state_size) * param_scale,
            'predict': rs.randn(state_size + 1, output_size) * param_scale,
            'init hiddens': rs.randn(1, state_size) * param_scale,}

def sigmoid(x):
    """Squash ``x`` element-wise into the open interval (0, 1) via tanh.

    Note: (tanh(x) + 1) / 2 equals the standard logistic 1 / (1 + exp(-x))
    evaluated at ``2x``, i.e. this curve is twice as steep as the textbook
    sigmoid.
    """
    shifted = np.tanh(x) + 1.0   # in (0, 2)
    return shifted / 2.0

def hiddens_to_output_probs(theta, hiddens):
    """Map hidden states to per-row log-probabilities (a log-softmax).

    Applies the affine 'predict' layer from ``theta`` and then subtracts each
    row's log-sum-exp, so exp(result) sums to 1 along axis 1.
    """
    logits = concat_and_multiply(theta['predict'], hiddens)
    log_normaliser = logsumexp(logits, axis=1, keepdims=True)
    return logits - log_normaliser

def concat_and_multiply(weights, *args):
    """Affine map: place ``args`` side by side, append a bias column of ones,
    then right-multiply by ``weights``.

    ``args`` are 2-D arrays with the same number of rows; ``weights`` needs one
    row per stacked column plus one final row for the bias.
    """
    n_rows = args[0].shape[0]
    bias_column = np.ones((n_rows, 1))
    stacked = np.hstack(args + (bias_column,))
    return np.dot(stacked, weights)

# Model dimensions.
input_size, state_size, output_size = 64, 16, 16

# Workload size; num_checkpoints is consumed by binomial_checkpoint further down.
batch_size, seq_len, num_checkpoints = 64, 512, 64

theta = create_rnn_params(input_size, state_size, output_size)

# Synthetic inputs: seq_len Gaussian batches of shape (batch_size, input_size),
# drawn after seeding the global generator so re-runs reproduce the same data.
np.random.seed(0)
inputs = [np.random.randn(batch_size, input_size) for _step in range(seq_len)]

In [3]:
from autograd.differential_operators import binomial_checkpoint
from autograd.builtins import list as ag_list, tuple as ag_tuple

def rnn(theta, init_ch, x):
    """Advance the RNN by one time step (the step function handed to
    binomial_checkpoint).

    Args:
        theta: parameter dict; only theta['change'] and theta['predict'] are read.
        init_ch: pair ``(hidden, output)`` carried over from the previous step.
            Only the hidden state is used; the previous output is ignored.
        x: this step's input batch, stacked column-wise with the hidden state.

    Returns:
        ag_tuple ``(hidden, output)``: the new hidden state
        tanh(concat_and_multiply(theta['change'], x, prev_hidden)) and its
        log-softmax prediction from hiddens_to_output_probs.
    """
    # The update depends only on the previous hidden state and the new input;
    # init_ch[1] (the previous output) is intentionally not read.
    prev_hidden = init_ch[0]

    hidden = np.tanh(concat_and_multiply(theta['change'], x, prev_hidden))
    output = hiddens_to_output_probs(theta, hidden)

    # autograd's tuple type rather than a plain tuple; presumably so the pair
    # stays traceable through binomial_checkpoint -- confirm against its contract.
    return ag_tuple((hidden, output))

In [4]:
def rnn_predict(params, inputs):
    """Run the RNN over every step of ``inputs`` in a plain Python loop
    (no checkpointing).

    Args:
        params: parameter dict forwarded to ``rnn`` and
            ``hiddens_to_output_probs``.
        inputs: non-empty sequence of per-step input batches; the batch size
            is taken from ``inputs[0].shape[0]``.

    Returns:
        list of ``len(inputs) + 1`` output arrays: the prediction from the
        initial hidden state, followed by one prediction per time step.
    """
    num_sequences = inputs[0].shape[0]

    # NOTE(review): the initial hidden state is a fixed RandomState(0) draw, not
    # params['init hiddens'], so that parameter never affects the result and
    # its gradient is identically zero. Left unchanged so this baseline stays
    # comparable with the checkpointed rnn_predict -- confirm this is intended.
    hidden_single = npr.RandomState(0).randn(1, state_size) * .01
    hidden = np.repeat(hidden_single, num_sequences, axis=0)
    output = hiddens_to_output_probs(params, hidden)

    outputs = [output]
    # `step_input` rather than `input`, to avoid shadowing the builtin.
    for step_input in inputs:  # Iterate over time steps.
        hidden, output = rnn(params, (hidden, output), step_input)
        outputs.append(output)

    return outputs

def f(theta):
    """Forward pass of the plain-loop RNN over the notebook-level ``inputs``."""
    return rnn_predict(theta, inputs)

In [5]:
f = lambda theta: rnn_predict(theta, inputs)
g = lambda theta: np.sum(sum(f(theta)))

x2 = g(theta)

start = time()
%memit g2 = grad(g)(theta)
end = time()

print("elapsed time: ", end - start)

peak memory: 137.24 MiB, increment: 45.96 MiB
elapsed time:  0.6979401111602783


In [6]:
# Wrap the single-step `rnn` with binomial_checkpoint. Arguments are presumably
# (step function, number of steps, number of checkpoints) -- confirm against
# binomial_checkpoint's definition.
loop = binomial_checkpoint(rnn, seq_len, num_checkpoints)

def rnn_predict(params, inputs):
    """Checkpointed counterpart of the plain-loop rnn_predict.

    Builds the same initial (hidden, output) pair, then hands the whole input
    sequence to ``loop`` instead of stepping through it in Python.

    Returns:
        whatever ``loop`` returns; the loss defined afterwards unpacks it as an
        iterable of (hidden, output) pairs.
    """
    batch = inputs[0].shape[0]

    # NOTE(review): as in the plain version, the initial hidden state is a fixed
    # RandomState(0) draw and params['init hiddens'] is never read.
    init_hidden = np.repeat(npr.RandomState(0).randn(1, state_size) * .01,
                            batch, axis=0)
    init_output = hiddens_to_output_probs(params, init_hidden)

    return loop(params, ag_tuple((init_hidden, init_output)), inputs)

In [7]:
f = lambda theta: rnn_predict(theta, inputs)
g = lambda theta: np.sum(sum(h for c, h in f(theta)))

x1 = g(theta)

start = time()
%memit g1 = grad(g)(theta)
end = time()

print("elapsed time: ", end - start)

peak memory: 142.50 MiB, increment: 5.51 MiB
elapsed time:  3.488611936569214


In [8]:
# check correctness of gradients
# Checkpointing should change only memory/time, not results: the forward value
# and every parameter gradient must match the plain-loop run.
# NOTE(review): 'init hiddens' shows 0.0 because neither rnn_predict reads
# params['init hiddens'], so its gradient is identically zero in both runs;
# that comparison does not actually exercise the parameter.

print(x1 - x2)

for key in g1:
    print(key, np.sum((g1[key] - g2[key])**2))

# Fail loudly instead of relying on someone eyeballing the printed numbers.
assert onp.allclose(x1, x2), "forward values differ with checkpointing"
assert set(g1) == set(g2), "gradient dicts have different keys"
for key in g1:
    assert onp.allclose(g1[key], g2[key]), "gradient mismatch for {!r}".format(key)

0.0
change 5.968321299952532e-28
predict 5.811825116546847e-25
init hiddens 0.0
