In [1]:
import functools
import os
import psutil
import autograd.numpy as np
from autograd import make_vjp, make_jvp, grad

import time
from memory_profiler import memory_usage

import plotting_standards.colors as colors
import plotting_standards.utils as plotting_utils

from plotly.subplots import make_subplots
import plotly.offline as plotly
import plotly.graph_objs as go

%matplotlib inline
plotly.init_notebook_mode(connected=True)

In [2]:
def get_end_memory(fn, args):
    """Return how much this process's virtual memory size changed across one call.

    The VMS of the current process is sampled, ``fn(*args)`` is executed (its
    return value is discarded), and the VMS is sampled again.

    Args:
        fn: Callable to execute.
        args: Sequence of positional arguments unpacked into ``fn``.

    Returns:
        VMS after the call minus VMS before the call, as reported by psutil.
    """
    this_process = psutil.Process(os.getpid())
    vms_before = this_process.memory_info().vms
    fn(*args)
    vms_after = this_process.memory_info().vms
    return vms_after - vms_before

In [3]:
def get_fwd_grad(f, x):
    """Build the full Jacobian of ``f`` at ``x`` column by column with forward mode.

    One Jacobian-vector product is evaluated per input coordinate, using the
    rows of the identity matrix as tangent directions; the i-th product is
    stored as column ``i`` of the result.

    Args:
        f: Function to differentiate. Its input and output are indexed along
            their first axis only, i.e. the code is written for 1-D vectors.
        x: Point at which to differentiate; ``x.shape[0]`` is the input size.

    Returns:
        Array of shape ``(out_dim, in_dim)`` holding the Jacobian.
    """
    in_dim = x.shape[0]

    jvp_at_x = make_jvp(f)(x)
    id_mat = np.identity(in_dim)

    if in_dim == 0:
        # No tangent direction to push forward, so the output size can only be
        # learned by evaluating f directly.
        return np.zeros((f(x).shape[0], in_dim))

    # Every JVP call returns (f(x), tangent). Reading the output size from the
    # first call avoids a separate f(x) evaluation done only to get a shape.
    value, first_column = jvp_at_x(id_mat[0])
    output_vec = np.zeros((value.shape[0], in_dim))
    output_vec[:, 0] = first_column
    for i in range(1, in_dim):
        output_vec[:, i] = jvp_at_x(id_mat[i])[1]
    return output_vec

In [4]:
def get_rev_grad(f, x):
    """Build the full Jacobian of ``f`` at ``x`` row by row with reverse mode.

    One vector-Jacobian product is evaluated per output coordinate, using the
    rows of the identity matrix as cotangent directions; the i-th product is
    stored as row ``i`` of the result.

    Args:
        f: Function to differentiate. Its input and output are indexed along
            their first axis only, i.e. the code is written for 1-D vectors.
        x: Point at which to differentiate; ``x.shape[0]`` is the input size.

    Returns:
        Array of shape ``(out_dim, in_dim)`` holding the Jacobian.
    """
    in_dim = x.shape[0]

    # make_vjp(f)(x) returns the pull-back together with f(x). Taking the
    # output size from that value avoids a second, untraced evaluation of f
    # whose only purpose was to read a shape.
    pull_back, f_at_x = make_vjp(f)(x)
    out_dim = f_at_x.shape[0]

    output_vec = np.zeros((out_dim, in_dim))
    id_mat = np.identity(out_dim)
    for i in range(out_dim):
        output_vec[i] = pull_back(id_mat[i])
    return output_vec

In [5]:
def get_fn_to_diff_no_mem_cost(n_start, n_end, n_middle, n_layers):
    """Return a tanh network whose all-zero weights are allocated on every call.

    The returned function applies ``tanh(W @ v)`` once with a
    ``(n_middle, n_start)`` matrix, ``n_layers`` times with a
    ``(n_middle, n_middle)`` matrix, and once with a ``(n_end, n_middle)``
    matrix. Each ``W`` is a fresh ``np.zeros`` array created at the point of
    use, so the closure itself stores no weight arrays.

    Args:
        n_start: Length of the input vector.
        n_end: Length of the output vector.
        n_middle: Width of every hidden layer.
        n_layers: Number of hidden-to-hidden layers.

    Returns:
        A function mapping a length-``n_start`` vector to a length-``n_end`` vector.
    """
    def zero_layer(n_rows, n_cols, vec):
        # The weight matrix exists only for the duration of this product.
        return np.tanh(np.matmul(np.zeros((n_rows, n_cols)), vec))

    def fn(x):
        hidden = zero_layer(n_middle, n_start, x)
        for _ in range(n_layers):
            hidden = zero_layer(n_middle, n_middle, hidden)
        return zero_layer(n_end, n_middle, hidden)

    return fn

In [6]:
def get_fn_to_diff(n_start, n_end, n_middle, n_layers):
    """Return a tanh network with uniformly random weights drawn once up front.

    Weights come from ``np.random.rand`` and are captured by the returned
    closure. The returned function applies ``tanh(W @ v)`` with the input
    matrix, then each hidden matrix in turn, then the output matrix.

    Args:
        n_start: Length of the input vector.
        n_end: Length of the output vector.
        n_middle: Width of every hidden layer.
        n_layers: Number of hidden-to-hidden layers.

    Returns:
        A function mapping a length-``n_start`` vector to a length-``n_end`` vector.
    """
    # Draw order (input, output, hidden) is kept so that a seeded global RNG
    # yields the same weights as before.
    w_in = np.random.rand(n_middle, n_start)
    w_out = np.random.rand(n_end, n_middle)
    w_hidden = [np.random.rand(n_middle, n_middle) for _ in range(n_layers)]

    # Application order differs from draw order: input, hidden..., output.
    layer_weights = [w_in, *w_hidden, w_out]

    def fn(x):
        activation = x
        for weights in layer_weights:
            activation = np.tanh(np.matmul(weights, activation))
        return activation

    return fn

# Memory

In [7]:
# Forward-mode memory sweep: depth grows geometrically (2.5**5 .. 2.5**9,
# truncated to int) while every layer stays 100 wide.
fwd_ns = [int(2.5 ** n) for n in range(5, 10)]
fwd_mems = []

for n_layers in fwd_ns:
    fn = get_fn_to_diff_no_mem_cost(100, 100, 100, n_layers)
    x = np.random.rand(100)
    # VMS change of this process across one full forward-mode Jacobian.
    vms_delta = get_end_memory(get_fwd_grad, (fn, x))
    fwd_mems.append(vms_delta)

In [8]:
# Reverse-mode memory sweep: depth grows geometrically (2.5**5 .. 2.5**9,
# truncated to int) while every layer stays 100 wide.
rev_ns = [int(2.5 ** n) for n in range(5, 10)]
rev_mems = []

for n_layers in rev_ns:
    fn = get_fn_to_diff_no_mem_cost(100, 100, 100, n_layers)
    x = np.random.rand(100)
    # VMS change of this process across one full reverse-mode Jacobian.
    vms_delta = get_end_memory(get_rev_grad, (fn, x))
    rev_mems.append(vms_delta)

# Time

In [None]:
# Forward-mode timing sweep. For each size n, three networks are timed: one
# where only the input size is n, one where only the hidden width is n, and one
# where only the output size is n (every other dimension, and the depth, is 10).
fwd_n_start_times = []
fwd_n_middle_times = []
fwd_n_end_times = []

# Number of freshly sampled networks averaged per data point.
n_trials = 30

for n in range(10, 1000, 10):
    # (n_start, n_end, n_middle) per sweep, kept in the measurement order
    # start -> middle -> end.
    sweeps = {
        "start": (n, 10, 10),
        "middle": (10, 10, n),
        "end": (10, n, 10),
    }
    totals = dict.fromkeys(sweeps, 0.0)

    for _ in range(n_trials):
        for sweep, (n_start, n_end, n_middle) in sweeps.items():
            # Network construction and input sampling stay outside the timed region.
            fn = get_fn_to_diff(n_start, n_end, n_middle, 10)
            x = np.random.rand(n_start)
            # perf_counter is monotonic and high-resolution; time.time() is a
            # wall clock that can jump when the system clock is adjusted.
            start = time.perf_counter()
            get_fwd_grad(fn, x)
            totals[sweep] += time.perf_counter() - start

    fwd_n_start_times.append(totals["start"] / n_trials)
    fwd_n_middle_times.append(totals["middle"] / n_trials)
    fwd_n_end_times.append(totals["end"] / n_trials)

In [None]:
# Reverse-mode timing sweep. For each size n, three networks are timed: one
# where only the input size is n, one where only the hidden width is n, and one
# where only the output size is n (every other dimension, and the depth, is 10).
rev_n_start_times = []
rev_n_middle_times = []
rev_n_end_times = []

# Number of freshly sampled networks averaged per data point.
n_trials = 30

for n in range(10, 1000, 10):
    # (n_start, n_end, n_middle) per sweep, kept in the measurement order
    # start -> middle -> end.
    sweeps = {
        "start": (n, 10, 10),
        "middle": (10, 10, n),
        "end": (10, n, 10),
    }
    totals = dict.fromkeys(sweeps, 0.0)

    for _ in range(n_trials):
        for sweep, (n_start, n_end, n_middle) in sweeps.items():
            # Network construction and input sampling stay outside the timed region.
            fn = get_fn_to_diff(n_start, n_end, n_middle, 10)
            x = np.random.rand(n_start)
            # perf_counter is monotonic and high-resolution; time.time() is a
            # wall clock that can jump when the system clock is adjusted.
            start = time.perf_counter()
            get_rev_grad(fn, x)
            totals[sweep] += time.perf_counter() - start

    rev_n_start_times.append(totals["start"] / n_trials)
    rev_n_middle_times.append(totals["middle"] / n_trials)
    rev_n_end_times.append(totals["end"] / n_trials)

# Discussion And Results

The following plots show how forward-mode and reverse-mode time performance vary as a function of the input ("start"), hidden ("middle"), and output ("end") layer sizes.

In [18]:
# Side-by-side timing curves: forward mode on the left, reverse mode on the
# right, sharing one y axis so the two panels are directly comparable.
fig = make_subplots(rows=1, cols=2, shared_yaxes=True, subplot_titles=("Forward-Mode Performance", "Reverse-Mode Performance"))

# The timing cells sweep the layer size over range(10, 1000, 10), so the k-th
# averaged timing (1-based) was measured at size sweep_step * k. Plotting the
# bare index k would mislabel the "Layer Size" axis by a factor of ten.
sweep_step = 10

# One entry per subplot column; traces are listed in drawing order. Only the
# left panel carries names/legend entries, since the colours mean the same
# thing in both panels.
panels = [
    (1, [(fwd_n_start_times, "start size", colors.core_1),
         (fwd_n_end_times, "end size", colors.core_2),
         (fwd_n_middle_times, "middle size", colors.core_3)]),
    (2, [(rev_n_start_times, None, colors.core_1),
         (rev_n_end_times, None, colors.core_2),
         (rev_n_middle_times, None, colors.core_3)]),
]

for panel_col, traces in panels:
    for times, label, color in traces:
        layer_sizes = sweep_step * np.arange(1, len(times) + 1)
        # NOTE(review): plotly documents `smoothing` as effective only when the
        # line shape is "spline"; the setting is left as the author wrote it.
        trace = go.Scatter(
            x=layer_sizes, y=times, mode="lines",
            name=label, showlegend=label is not None,
            line={"smoothing": 0.5, 'color': color})
        fig.add_trace(trace, row=1, col=panel_col)

fig.update_layout(height=300)
plotting_utils.set_layout(fig['layout'])
fig['layout']['xaxis']['title']['text'] = "Layer Size"
fig['layout']['yaxis']['title']['text'] = "Time"
fig['layout']['xaxis2']['title']['text'] = "Layer Size"
fig.show()

In [19]:
# Export the timing figure as SVG into the current user's Downloads folder.
# The location is resolved from the home directory instead of a hard-coded
# /home/<user> prefix so the cell also runs on other machines/accounts.
svg_path = os.path.join(os.path.expanduser("~"), "Downloads", "fwd_vs_rev_time.svg")
os.makedirs(os.path.dirname(svg_path), exist_ok=True)
with open(svg_path, 'wb') as f:
    f.write(fig.to_image(format="svg"))

In [24]:
# Memory growth versus network depth for both differentiation modes.
fig = go.Figure()

series = [
    (fwd_ns, fwd_mems, colors.core_1, 'Forward Mode'),
    (rev_ns, rev_mems, colors.core_2, 'Reverse Mode'),
]
for layer_counts, vms_deltas, color, label in series:
    # NOTE(review): plotly documents `smoothing` as effective only when the
    # line shape is "spline"; the setting is left as the author wrote it.
    fig.add_trace(go.Scatter(
        x=layer_counts, y=vms_deltas, mode="lines",
        line={"smoothing": 0.5, 'color': color}, name=label))

plotting_utils.set_layout(fig['layout'])
fig.update_layout(height=500)
fig['layout']['xaxis']['title']['text'] = "Number of Layers"
fig['layout']['yaxis']['title']['text'] = "Bytes"
# Upper bound covers BOTH curves; using only the reverse-mode maximum would
# clip the forward-mode line whenever it happened to be the larger one.
# The fixed negative lower bound is presumably headroom so that near-zero
# deltas are not drawn on the axis itself - TODO confirm.
fig['layout']['yaxis']['range'] = [-50000000, max([*fwd_mems, *rev_mems])]

fig.show()

In [25]:
# Export the memory figure as SVG into the current user's Downloads folder.
# The location is resolved from the home directory instead of a hard-coded
# /home/<user> prefix so the cell also runs on other machines/accounts.
svg_path = os.path.join(os.path.expanduser("~"), "Downloads", "fwd_vs_rev_mem.svg")
os.makedirs(os.path.dirname(svg_path), exist_ok=True)
with open(svg_path, 'wb') as f:
    f.write(fig.to_image(format="svg"))