In [1]:
import autograd
import autograd.numpy as np
from autograd import grad
from autograd import checkpoint
from autograd.extend import primitive

import numpy as onp
from time import time

%load_ext memory_profiler

from builtins import range, list as ag_list, tuple as ag_tuple
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import grad
from autograd.scipy.misc import logsumexp
from os.path import dirname, join
from autograd.misc.optimizers import adam

In [2]:
def create_rnn_params(input_size, state_size, output_size,
                      param_scale=0.01, rs=None):
    """Create the RNN parameter dictionary from scaled standard-normal draws.

    Args:
        input_size: Width of each per-step input.
        state_size: Width of the hidden state.
        output_size: Width of the output.
        param_scale: Factor applied to every standard-normal draw.
        rs: Random state exposing ``randn``. When ``None`` (the default), a
            fresh ``npr.RandomState(0)`` is created for this call.

    Returns:
        A dict with three arrays:
          'change'       -- shape (input_size + state_size + 1, state_size)
          'predict'      -- shape (state_size + 1, output_size)
          'init hiddens' -- shape (1, state_size)
        The extra ``+ 1`` row in 'change' and 'predict' lines up with the
        column of ones that ``concat_and_multiply`` appends.
    """
    # A RandomState placed directly in the signature is constructed once, when
    # the function is defined, and then shared: every default call would keep
    # advancing the same generator and return different parameters. Building
    # it here makes default calls reproducible.
    if rs is None:
        rs = npr.RandomState(0)
    # Draw order (change, predict, init hiddens) determines the values; keep it.
    return {'change': rs.randn(input_size + state_size + 1, state_size) * param_scale,
            'predict': rs.randn(state_size + 1, output_size) * param_scale,
            'init hiddens': rs.randn(1, state_size) * param_scale,}

def sigmoid(x):
    """Squash ``x`` into the interval [0, 1] using a shifted, halved tanh."""
    shifted = np.tanh(x) + 1.0   # tanh lies in [-1, 1]; the shift moves it to [0, 2]
    return 0.5 * shifted

def hiddens_to_output_probs(theta, hiddens):
    """Turn hidden states into row-normalized log-probabilities.

    The hidden states are passed through the affine map stored in
    ``theta['predict']``; each row then has its log-sum-exp subtracted, so
    exponentiating a row of the result gives values that sum to one.
    """
    logits = concat_and_multiply(theta['predict'], hiddens)
    log_normalizer = logsumexp(logits, axis=1, keepdims=True)
    return logits - log_normalizer

def concat_and_multiply(weights, *args):
    """Stack ``args`` side by side, append a column of ones, and multiply by ``weights``.

    The ones column has as many rows as the first argument, so the last row
    of ``weights`` acts as an additive offset.
    """
    row_count = args[0].shape[0]
    ones_column = np.ones((row_count, 1))
    stacked = np.hstack(args + (ones_column,))
    return np.dot(stacked, weights)

# Widths of the per-step input, the hidden state, and the output.
input_size, state_size, output_size = 64, 64, 64

# Benchmark workload: sequences per batch and time steps per sequence.
batch_size, seq_len = 64, 512
# Checkpoint budget handed to binomial_checkpoint further down.
num_checkpoints = 64

theta = create_rnn_params(input_size, state_size, output_size)

# Seed the global generator so the synthetic inputs are the same on every run.
np.random.seed(0)
# One (batch_size, input_size) array of standard-normal values per time step.
inputs = [np.random.randn(batch_size, input_size) for _step in range(seq_len)]

In [3]:
from autograd.differential_operators import binomial_checkpoint
from autograd.builtins import list as ag_list, tuple as ag_tuple

def rnn(theta, state, x):
    """Advance the hidden state by one step.

    The input ``x`` and the current ``state`` are concatenated (in that
    order), passed through the affine map in ``theta['change']``, and
    squashed with tanh.
    """
    pre_activation = concat_and_multiply(theta['change'], x, state)
    return np.tanh(pre_activation)

In [4]:
def rnn_predict(params, inputs):
    """Unroll the RNN over ``inputs`` and collect the output at every step.

    Args:
        params: Parameter dict; 'change' and 'predict' are used (through
            ``rnn`` and ``hiddens_to_output_probs``). 'init hiddens' is not
            read here.
        inputs: Sequence of per-step arrays; the first axis of ``inputs[0]``
            gives the number of sequences in the batch.

    Returns:
        A list with ``len(inputs) + 1`` entries: the output for the initial
        state followed by the output after each step.
    """
    batch = inputs[0].shape[0]
    # The initial state is a fixed draw from a freshly seeded generator,
    # repeated for every sequence in the batch.
    initial_row = npr.RandomState(0).randn(1, state_size) * .01
    state = np.repeat(initial_row, batch, axis=0)

    step_outputs = [hiddens_to_output_probs(params, state)]

    for step_input in inputs:  # One iteration per time step.
        state = rnn(params, state, step_input)
        step_outputs.append(hiddens_to_output_probs(params, state))
    return step_outputs

# Checkpointed loop used by rnn_predict_checkpointed. It is built from the
# single-step function `rnn`, the step count `seq_len`, the checkpoint budget
# `num_checkpoints`, and the per-state output map `hiddens_to_output_probs`.
# NOTE(review): binomial_checkpoint's argument semantics are not visible in
# this notebook; presumably it returns a callable taking (params, state,
# inputs) that mirrors rnn_predict's loop -- confirm against its definition
# in autograd.differential_operators.
loop = binomial_checkpoint(rnn, seq_len, num_checkpoints, hiddens_to_output_probs)

def rnn_predict_checkpointed(params, inputs):
    """Run the RNN over ``inputs`` through the module-level checkpointed ``loop``.

    The initial hidden state is built exactly as in ``rnn_predict``: a fixed
    draw from a freshly seeded generator, repeated once per sequence in the
    batch. Whatever ``loop`` returns is returned unchanged.
    """
    batch = inputs[0].shape[0]

    initial_row = npr.RandomState(0).randn(1, state_size) * .01
    initial_state = np.repeat(initial_row, batch, axis=0)

    return loop(params, initial_state, inputs)

In [5]:
# To benchmark the checkpointed variant, replace rnn_predict with
# rnn_predict_checkpointed inside f and restart the notebook.
def f(theta):
    """Evaluate the RNN on the fixed benchmark ``inputs``."""
    return rnn_predict(theta, inputs)


def g(theta):
    """Scalar objective: add up the per-step outputs, then sum every entry."""
    return np.sum(sum(f(theta)))


# The first grad call carries one-time memory overhead, so it is made here
# and its result discarded before anything is measured.
_ = grad(g)(theta)

In [6]:
# Wall-clock timing brackets the whole %memit invocation, so the elapsed time
# printed below covers the memory profiling as well as the gradient evaluation.
start = time()
# %memit reports the memory use of this statement; g1 is read by a later cell.
%memit g1 = grad(g)(theta)
end = time()

print("elapsed time: ", end - start)

peak memory: 218.01 MiB, increment: 60.02 MiB
elapsed time:  0.8502910137176514


In [7]:
# Print one scalar per parameter gradient: the sum of all its entries.
for gradient in g1.values():
    print(np.sum(gradient))

156.55160449204902
-5.4498627832799684e-12
0.0
