In [1]:
import autograd
import autograd.numpy as np
from autograd import grad
from autograd import checkpoint
from autograd.extend import primitive

import numpy as onp
from time import time

%load_ext memory_profiler

from builtins import range, list as ag_list, tuple as ag_tuple
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import grad
from autograd.scipy.misc import logsumexp
from os.path import dirname, join
from autograd.misc.optimizers import adam

In [128]:
def create_rnn_params(input_size, state_size, output_size,
                      param_scale=0.01, rs=None):
    """Draw the initial RNN parameters from a scaled standard normal.

    Args:
        input_size: width of one input vector.
        state_size: width of the hidden state.
        output_size: width of the output vector.
        param_scale: factor applied to every standard-normal draw.
        rs: random state providing ``randn``; when None, a fresh
            ``npr.RandomState(0)`` is created for this call.

    Returns:
        dict with three arrays:
            'change'       -- (input_size + state_size + 1, state_size)
            'predict'      -- (state_size + 1, output_size)
            'init hiddens' -- (1, state_size)
    """
    if rs is None:
        # A RandomState placed directly in the signature is created once, when
        # the function is defined, and then shared by every call; a second
        # default call would continue the stream and return different values.
        rs = npr.RandomState(0)
    # The draw order (change, predict, init hiddens) fixes the values for a given seed.
    return {'change': rs.randn(input_size + state_size + 1, state_size) * param_scale,
            'predict': rs.randn(state_size + 1, output_size) * param_scale,
            'init hiddens': rs.randn(1, state_size) * param_scale}

def sigmoid(x):
    """Squash ``x`` into the interval [0, 1] using a shifted, halved tanh.

    Computes 0.5 * (tanh(x) + 1), which is mathematically the logistic
    function evaluated at 2 * x.
    """
    shifted = np.tanh(x) + 1.0  # tanh lies in [-1, 1], so this lies in [0, 2]
    return 0.5 * shifted

def hiddens_to_output_probs(theta, hiddens):
    """Map hidden states to row-normalized log-probabilities.

    Applies the 'predict' weights (with a bias column, via
    concat_and_multiply) and subtracts the log-sum-exp taken along axis 1.
    """
    logits = concat_and_multiply(theta['predict'], hiddens)
    log_normalizer = logsumexp(logits, axis=1, keepdims=True)
    return logits - log_normalizer

def concat_and_multiply(weights, *args):
    """Stack the given 2-D arrays side by side, append a ones column, and multiply by ``weights``.

    The ones column has as many rows as the first array in ``args`` and
    acts as a bias input for the last row of ``weights``.
    """
    row_count = args[0].shape[0]
    bias_column = np.ones((row_count, 1))
    stacked = np.hstack(args + (bias_column,))
    return np.dot(stacked, weights)

# Layer widths: input, hidden state and output all share the same size.
input_size = state_size = output_size = 64

# Problem size: sequences per batch, steps per sequence, and the checkpoint
# budget handed to the checkpointed gradient routines.
batch_size = 64
seq_len = 512
num_checkpoints = 512

# Initial parameters (uses the function's default scale and random state).
theta = create_rnn_params(input_size, state_size, output_size)

# One (batch_size, input_size) standard-normal array per time step, seeded
# so the inputs are the same on every run.
np.random.seed(0)
inputs = [
    np.random.randn(batch_size, input_size)
    for _ in range(seq_len)
]

In [129]:
from autograd.differential_operators import binomial_checkpoint
from autograd.builtins import list as ag_list, tuple as ag_tuple

def rnn(theta, state, x):
    """One recurrent step: tanh of the 'change' weights applied to [x, state, 1]."""
    pre_activation = concat_and_multiply(theta['change'], x, state)
    return np.tanh(pre_activation)

In [130]:
def rnn_predict(params, inputs):
    """Run the RNN over ``inputs`` and collect the output at every step.

    The initial hidden state is a fixed row drawn from ``npr.RandomState(0)``
    (scaled by .01) and repeated for each sequence in the batch; it is not
    read from ``params``.

    Returns:
        list of ``len(inputs) + 1`` outputs of hiddens_to_output_probs: one
        for the initial state, then one after each input.
    """
    batch_rows = inputs[0].shape[0]
    seed_row = npr.RandomState(0).randn(1, state_size) * .01
    state = np.repeat(seed_row, batch_rows, axis=0)

    step_outputs = [hiddens_to_output_probs(params, state)]
    for step_input in inputs:
        state = rnn(params, state, step_input)
        step_outputs.append(hiddens_to_output_probs(params, state))
    return step_outputs

# Checkpointed loop built from the `rnn` step and the `hiddens_to_output_probs`
# output function, sized by `seq_len` and `num_checkpoints`.
# NOTE(review): binomial_checkpoint's argument semantics are not visible in this
# notebook -- confirm against autograd.differential_operators.
loop = binomial_checkpoint(rnn, seq_len, num_checkpoints, hiddens_to_output_probs)

def rnn_predict_checkpointed(params, inputs):
    """Checkpointed counterpart of rnn_predict: delegates the loop to ``loop.fun``.

    Builds the same fixed initial hidden state (a ``npr.RandomState(0)`` row
    scaled by .01, repeated once per sequence) and hands it, together with
    ``params`` and ``inputs``, to the module-level ``loop``.
    """
    batch_rows = inputs[0].shape[0]
    seed_row = npr.RandomState(0).randn(1, state_size) * .01
    initial_hidden = np.repeat(seed_row, batch_rows, axis=0)
    return loop.fun(params, initial_hidden, inputs)

In [133]:
# change rnn_predict to rnn_predict_checkpointed and restart the notebook for comparison

# Imported in this cell because make_vjp and vspace are used here, and the only
# other import of them sits in a cell further down the notebook.
from autograd.differential_operators import make_vjp, vspace

f = lambda theta: rnn_predict(theta, inputs)
g = lambda theta: np.sum(sum(rnn_predict(theta, inputs)))

# Upstream gradient fed to the VJP: the `ones` element of the vector space of
# the forward outputs. Built here so the cell runs on a fresh kernel.
output_grads = vspace(f(theta)).ones()

# NOTE(review): the original comment said the first grad call has extra memory
# overhead and is done here to be ignored, but no warm-up call is made, so the
# measurement below includes any such first-call cost.

start = time()
g1 = make_vjp(rnn_predict_checkpointed, 0)(theta, inputs)[0](output_grads)
end = time()

print(end - start)

0.7677478790283203


In [134]:
# Initial hidden state built the same way as inside the predict functions:
# one RandomState(0) row scaled by 0.01, repeated once per sequence.
# NOTE(review): vspace and vjp_general come from the cell labeled In [10],
# which appears further down in this notebook.
hidden_single = npr.RandomState(0).randn(1, state_size) * 0.01
initial_state = np.repeat(hidden_single, batch_size, axis=0)

# Upstream gradient: the `ones` element of the forward outputs' vector space.
output_grads = vspace(f(theta)).ones()

# Time the hand-written checkpointed VJP; element 0 of its result is the
# gradient with respect to the parameters.
start = time()
g2 = vjp_general(
    parameters=theta,
    state_0=initial_state,
    inputs=inputs,
    postprocess_grads=output_grads,
    state_grad_wrt_final_state=None,
    num_checkpoints=num_checkpoints,
    fst=True,
)[0]
end = time()

end - start

1.4373729228973389

In [123]:
# NOTE(review): stale cell -- the saved output is a dict of gradient arrays, but
# the cell labeled In [133] binds `g` to a lambda, so a top-to-bottom re-run
# would display a function object here instead.
g

{'change': array([[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        ...,
        [-8.62700490e-02,  1.25206451e-01,  9.03192207e-01, ...,
          2.69645134e-01,  5.19298640e-01,  3.88437962e-01],
        [-1.55154406e-02,  2.25149011e-02,  1.62348787e-01, ...,
          4.84923640e-02,  9.33537326e-02,  6.98254223e-02],
        [-6.77322536e+00,  9.83053247e+00,  7.09205017e+01, ...,
          2.11706358e+01,  4.07753732e+01,  3.05005622e+01]]),
 'init hiddens': array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [88]:
# Time autograd's `grad` applied to `g`, evaluated at `theta`.
# Note: this rebinds `g1`, which the checkpointed-VJP timing cell also assigns.
start = time()
g1 = grad(g)(theta)
end = time()

print("elapsed time: ", end - start)

elapsed time:  0.5895800590515137


In [15]:
# Print the sum of every gradient array in g1, one line per entry, as a
# compact fingerprint for comparing gradient computations.
for grad_entry in g1.values():
    print(np.sum(grad_entry))

156.55160449204888
-8.249401162174763e-12
0.0


In [10]:
from autograd.differential_operators import vspace, forward_loop_no_saving, checkpoint_policy, make_vjp

# Step function differentiated by the manual VJP routines below.
function = rnn


def postprocess(params, state):
    """Per-step output function: forwards to hiddens_to_output_probs."""
    return hiddens_to_output_probs(params, state)


def curried_function(param_and_state, input):
    """Step function taking (parameters, state) packed as one leading argument."""
    return function(param_and_state[0], param_and_state[1], input)


# VJP with respect to argument 0, i.e. the packed (parameters, state) pair.
curried_function_vjp = make_vjp(curried_function, 0)


def curried_postprocess(param_and_state):
    """Output function taking (parameters, state) packed as one argument."""
    return postprocess(param_and_state[0], param_and_state[1])


curried_postprocess_vjp = make_vjp(curried_postprocess, 0)

def vjp_one_checkpoint(parameters, state_0, inputs, postprocess_grads, state_grad_wrt_next_state, fst):
    """Back-propagate through one segment using ``state_0`` as the only stored state.

    Steps are visited last to first. For each step the state before it is
    obtained by calling forward_loop_no_saving from ``state_0`` over the
    preceding inputs, then the step and its output are differentiated.

    Args:
        parameters: model parameters.
        state_0: state at the start of the segment.
        inputs: non-empty sequence of per-step inputs.
        postprocess_grads: output gradients, one more than ``inputs``;
            entry ``i + 1`` belongs to the state produced by step ``i`` and
            entry 0 to ``state_0``.
        state_grad_wrt_next_state: gradient arriving at the segment's final
            state, or None for zero.
        fst: when true, ``postprocess_grads[0]`` is also back-propagated
            through the output computed from ``state_0``.

    Returns:
        (parameter gradient, gradient with respect to ``state_0``).
    """
    num_steps = len(inputs)
    assert num_steps > 0
    assert len(postprocess_grads) > 0
    assert num_steps + 1 == len(postprocess_grads)

    state_space = vspace(state_0)
    param_space = vspace(parameters)

    total_param_grad = param_space.zeros()
    carried_state_grad = state_grad_wrt_next_state
    if carried_state_grad is None:
        carried_state_grad = state_space.zeros()

    for step in reversed(range(num_steps)):
        # Recompute the state entering this step from the segment start.
        state_before = forward_loop_no_saving(function, parameters, state_0, inputs[:step])

        step_vjp, state_after = curried_function_vjp(ag_tuple((parameters, state_before)), inputs[step])
        output_vjp = curried_postprocess_vjp((parameters, state_after))[0]

        param_grad_from_output, state_grad_from_output = output_vjp(postprocess_grads[step + 1])
        # The state after this step receives gradient from its own output and
        # from everything downstream of it.
        incoming_state_grad = state_space.add(state_grad_from_output, carried_state_grad)
        param_grad_from_step, carried_state_grad = step_vjp(incoming_state_grad)

        step_param_grad = param_space.add(param_grad_from_output, param_grad_from_step)
        total_param_grad = param_space.add(total_param_grad, step_param_grad)

    if fst:
        # The segment start also has an output of its own to account for.
        output_vjp = curried_postprocess_vjp((parameters, state_0))[0]
        param_grad_from_output, state_grad_from_output = output_vjp(postprocess_grads[0])
        total_param_grad = param_space.add(total_param_grad, param_grad_from_output)
        carried_state_grad = state_space.add(state_grad_from_output, carried_state_grad)
    return total_param_grad, carried_state_grad

def vjp_general(parameters, state_0, inputs, postprocess_grads, state_grad_wrt_final_state, num_checkpoints, fst):
    """Recursive checkpointed VJP over a sequence of steps.

    With one checkpoint left, or a single step, the work is delegated to
    vjp_one_checkpoint. Otherwise the sequence is split at the index chosen
    by ``checkpoint_policy``: the state at the split is recomputed with
    forward_loop_no_saving, the tail is back-propagated with one fewer
    checkpoint, and the resulting state gradient is fed into the head.

    Args:
        parameters: model parameters.
        state_0: state at the start of the sequence.
        inputs: non-empty sequence of per-step inputs.
        postprocess_grads: output gradients, one more than ``inputs``.
        state_grad_wrt_final_state: gradient arriving at the final state, or
            None (treated as zero by vjp_one_checkpoint).
        num_checkpoints: checkpoint budget for this call.
        fst: whether ``postprocess_grads[0]`` should be back-propagated
            through the output of ``state_0``.

    Returns:
        (parameter gradient, gradient with respect to ``state_0``).
    """
    assert len(inputs) > 0
    assert len(postprocess_grads) > 0
    assert len(inputs) + 1 == len(postprocess_grads)

    if num_checkpoints == 1 or len(inputs) == 1:
        return vjp_one_checkpoint(parameters, state_0, inputs, postprocess_grads, state_grad_wrt_final_state, fst)

    y = checkpoint_policy(len(inputs), num_checkpoints)
    # Both halves must be non-empty; otherwise a recursive call would trip the
    # assertions above after doing needless work. Fail early and say why.
    assert 0 < y < len(inputs), "checkpoint_policy returned an unusable split index"

    state_y = forward_loop_no_saving(function, parameters, state_0, inputs[:y])

    # Tail first: steps y.. end. The output gradient at index y is handled by
    # the head, so the tail is told not to process its first output (fst=False).
    parameter_grad_wrt_case2, state_grad_wrt_case2 = vjp_general(
        parameters, state_y, inputs[y:], postprocess_grads[y:], state_grad_wrt_final_state, num_checkpoints - 1, False
    )
    # Head: steps 0 .. y-1, seeded with the state gradient coming out of the tail.
    parameter_grad_wrt_case1, state_grad_wrt_case1 = vjp_general(
        parameters, state_0, inputs[:y], postprocess_grads[:y + 1], state_grad_wrt_case2, num_checkpoints, fst
    )
    return vspace(parameters).add(parameter_grad_wrt_case1, parameter_grad_wrt_case2), state_grad_wrt_case1