In [1]:
import autograd
import autograd.numpy as np
from autograd import grad
from autograd import checkpoint
from autograd.extend import primitive

import numpy as onp
from time import time

%load_ext memory_profiler

from builtins import range, list as ag_list, tuple as ag_tuple
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import grad
from autograd.scipy.misc import logsumexp
from os.path import dirname, join
from autograd.misc.optimizers import adam

In [2]:
def create_rnn_params(input_size, state_size, output_size,
                      param_scale=0.01, rs=None):
    """Build Gaussian-initialised parameters for a vanilla RNN.

    Args:
        input_size: width of each input vector.
        state_size: width of the hidden state.
        output_size: width of each output vector.
        param_scale: multiplier applied to standard-normal draws, i.e. the
            standard deviation of the initial weights.
        rs: random state exposing ``randn``. When omitted, a fresh
            ``npr.RandomState(0)`` is created on every call so repeated
            default calls are reproducible. (A ``RandomState`` used directly
            as the default value is built once at definition time and shared,
            so its state would advance between calls.)

    Returns:
        dict with
          'change':       (input_size + state_size + 1, state_size) weights; the
                          extra row multiplies the ones column that
                          concat_and_multiply appends;
          'predict':      (state_size + 1, output_size) weights;
          'init hiddens': (1, state_size) initial hidden state.
        Draw order is 'change', 'predict', 'init hiddens'.
    """
    if rs is None:
        rs = npr.RandomState(0)
    return {'change': rs.randn(input_size + state_size + 1, state_size) * param_scale,
            'predict': rs.randn(state_size + 1, output_size) * param_scale,
            'init hiddens': rs.randn(1, state_size) * param_scale}

def sigmoid(x):
    """Map ``x`` elementwise to values between 0 and 1 via tanh.

    Note: 0.5 * (tanh(x) + 1) equals 1 / (1 + exp(-2 * x)), i.e. the logistic
    function evaluated at 2 * x. It is steeper than the standard logistic.
    """
    shifted = np.tanh(x) + 1.0
    return shifted * 0.5

def hiddens_to_output_probs(theta, hiddens):
    """Read out normalised log-probabilities from a batch of hidden states.

    Args:
        theta: parameter dict; only theta['predict'] is used.
        hiddens: 2-D array of hidden states, one row per sequence.

    Returns:
        The affine readout of ``hiddens`` (with bias) minus its row-wise
        logsumexp, i.e. a log-softmax over axis 1. Despite the function name,
        the values are log-probabilities, not probabilities.
    """
    logits = concat_and_multiply(theta['predict'], hiddens)
    log_normaliser = logsumexp(logits, axis=1, keepdims=True)
    return logits - log_normaliser

def concat_and_multiply(weights, *args):
    """Affine map of several row-aligned arrays: [args..., 1] @ weights.

    The arrays in ``args`` are stacked side by side with ``np.hstack``, a
    column of ones is appended (so the last row of ``weights`` acts as a
    bias), and the result is multiplied by ``weights``.
    """
    n_rows = args[0].shape[0]
    bias_column = np.ones((n_rows, 1))
    stacked = np.hstack(args + (bias_column,))
    return np.dot(stacked, weights)

# Model widths: input, hidden state and output all have 64 features here.
input_size, state_size, output_size = 64, 64, 64

# Benchmark size.
batch_size = 64
seq_len = 512
# NOTE(review): num_checkpoints equals seq_len — confirm this is the intended
# checkpoint budget for binomial_checkpoint.
num_checkpoints = 512

theta = create_rnn_params(input_size, state_size, output_size)

# seq_len input arrays of shape (batch_size, input_size), drawn from the
# global numpy RNG after seeding it with 0.
np.random.seed(0)
inputs = [np.random.randn(batch_size, input_size) for _t in range(seq_len)]

In [3]:
from autograd.differential_operators import binomial_checkpoint
from autograd.builtins import list as ag_list, tuple as ag_tuple

# Empty autograd-aware tuple used as the (ignored) first element of the RNN state.
unit = ag_tuple(())

def rnn(theta, state, x):
    """Advance the RNN by one time step.

    Args:
        theta: parameter dict; only theta['change'] is used.
        state: 2-tuple whose first element is ignored and whose second
            element is the current hidden state.
        x: input for this time step.

    Returns:
        ag_tuple (unit, new_hidden), where
        new_hidden = tanh([x, hidden, 1] @ theta['change']).
    """
    # NOTE(review): the (placeholder, hidden) layout is presumably the carried
    # state format binomial_checkpoint expects — confirm against that function.
    ignored, hidden = state
    new_hidden = np.tanh(concat_and_multiply(theta['change'], x, hidden))
    return ag_tuple((unit, new_hidden))

In [4]:
def rnn_predict(params, inputs):
    """Run the RNN over a whole sequence without checkpointing.

    Args:
        params: parameter dict ('change' is used by rnn, 'predict' by
            hiddens_to_output_probs).
        inputs: sequence of per-step input arrays with the batch on axis 0.

    Returns:
        list of len(inputs) + 1 log-probability arrays: one for the initial
        hidden state, then one per time step.
    """
    batch = inputs[0].shape[0]

    # The initial hidden state is a fixed draw from a fresh RandomState(0),
    # not params['init hiddens'], so no gradient flows into it.
    init_row = npr.RandomState(0).randn(1, state_size) * .01
    hidden = np.repeat(init_row, batch, axis=0)

    outputs = [hiddens_to_output_probs(params, hidden)]
    for x_t in inputs:
        _, hidden = rnn(params, (unit, hidden), x_t)
        outputs.append(hiddens_to_output_probs(params, hidden))
    return outputs

f = lambda theta: rnn_predict(theta, inputs)

In [5]:
f = lambda theta: rnn_predict(theta, inputs)
g = lambda theta: np.sum(sum(f(theta)))

x2 = g(theta)

start = time()
%memit g2 = grad(g)(theta)
end = time()

print("elapsed time: ", end - start)

peak memory: 223.22 MiB, increment: 119.25 MiB
elapsed time:  0.944720983505249


In [5]:
# NOTE(review): binomial_checkpoint comes from a locally modified autograd.
# The meaning of its arguments (step function, number of steps, checkpoint
# budget) is inferred from the names here — confirm.
loop = binomial_checkpoint(rnn, seq_len, num_checkpoints)

def rnn_predict(params, inputs):
    """Run the RNN over a sequence using the checkpointed ``loop``.

    This redefines the un-checkpointed rnn_predict from the earlier cell. The
    initial hidden state is built the same way, so the two versions can be
    compared directly.

    Returns:
        list with one log-probability array per element yielded by ``loop``
        (presumably the initial hidden state plus one per step — confirm).
    """
    batch = inputs[0].shape[0]
    init_row = npr.RandomState(0).randn(1, state_size) * .01
    initial_state = ag_tuple((unit, np.repeat(init_row, batch, axis=0)))
    hiddens = loop(params, initial_state, inputs)
    return [hiddens_to_output_probs(params, h) for h in hiddens]

f = lambda theta: rnn_predict(theta, inputs)
g = lambda theta: np.sum(sum(f(theta)))

x1 = g(theta)

start = time()
%memit g1 = grad(g)(theta)
end = time()

print("elapsed time: ", end - start)

called!
512 512 1
511 511 1
510 510 1
509 509 1
508 508 1
507 507 1
506 506 1
505 505 1
504 504 1
503 503 1
502 502 1
501 501 1
500 500 1
499 499 1
498 498 1
497 497 1
496 496 1
495 495 1
494 494 1
493 493 1
492 492 1
491 491 1
490 490 1
489 489 1
488 488 1
487 487 1
486 486 1
485 485 1
484 484 1
483 483 1
482 482 1
481 481 1
480 480 1
479 479 1
478 478 1
477 477 1
476 476 1
475 475 1
474 474 1
473 473 1
472 472 1
471 471 1
470 470 1
469 469 1
468 468 1
467 467 1
466 466 1
465 465 1
464 464 1
463 463 1
462 462 1
461 461 1
460 460 1
459 459 1
458 458 1
457 457 1
456 456 1
455 455 1
454 454 1
453 453 1
452 452 1
451 451 1
450 450 1
449 449 1
448 448 1
447 447 1
446 446 1
445 445 1
444 444 1
443 443 1
442 442 1
441 441 1
440 440 1
439 439 1
438 438 1
437 437 1
436 436 1
435 435 1
434 434 1
433 433 1
432 432 1
431 431 1
430 430 1
429 429 1
428 428 1
427 427 1
426 426 1
425 425 1
424 424 1
423 423 1
422 422 1
421 421 1
420 420 1
419 419 1
418 418 1
417 417 1
416 416 1
415 415 1
414 414 1
41

In [7]:
# check correctness of gradients: the checkpointed run (x1, g1) must agree
# with the plain run (x2, g2) up to floating-point round-off. The assertions
# make the cell fail loudly instead of only printing the differences.

print(x1 - x2)
assert onp.allclose(x1, x2), "forward values differ: %r vs %r" % (x1, x2)

assert set(g1) == set(g2), "gradient dicts have different keys"
for key in g1:
    print(key, np.sum((g1[key] - g2[key])**2))
    assert onp.allclose(g1[key], g2[key]), "gradient mismatch for %r" % (key,)

0.0
change 3.8644039759066313e-26
predict 0.0
init hiddens 0.0


In [6]:
# NOTE(review): `bpc` is not defined in any visible cell. It is presumably a
# counter added to autograd.core in the locally modified autograd build (the
# build that also provides binomial_checkpoint). Confirm what it counts, and
# whether it is reset between the benchmark runs above, before interpreting it.
print(autograd.core.bpc)

9227


In [9]:
# NOTE(review): unexplained constants. This rescales 3.92 by the ratio
# 14341 / 18443. Where these three numbers come from (other runs' timings or
# counts?) is not recorded in the notebook — TODO document their origin.
3.92 / (18443 / 14341)

3.0481331670552514