## hessQuik Timing Tests

In this notebook, we provide the timing tests used in our paper.  To run the tests on a GPU, change the runtime type (Runtime > Change runtime type) to use a GPU hardware accelerator.  See [Making the Most of your Colab Subscription](https://colab.research.google.com/?utm_source=scs-index) for further details.

## Colab Computing Resources Info
Optional printouts of the runtime's disk, CPU, GPU, and memory resources.

CPU info

In [None]:
# disk information
!df -h

# CPU specs
!cat /proc/cpuinfo

# CPU memory
!cat /proc/meminfo
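Outside Colab, where these shell commands may not be available, a rough summary can also be gathered with Python's standard library. The helper below is a hypothetical sketch (`summarize_cpu_and_disk` is not part of hessQuik) and reports far less detail than `/proc/cpuinfo` or `/proc/meminfo`.

```python
import os
import platform
import shutil


def summarize_cpu_and_disk(path="/"):
    """Return a short summary of the CPU count, processor, and disk usage for `path`."""
    usage = shutil.disk_usage(path)  # named tuple: total, used, free (bytes)
    return {
        "logical_cpus": os.cpu_count(),  # may be None if it cannot be determined
        "processor": platform.processor() or platform.machine(),
        "disk_total_gb": usage.total / 1e9,
        "disk_free_gb": usage.free / 1e9,
    }


print(summarize_cpu_and_disk())
```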

GPU info

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)
  !nvidia-smi -L
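The plotting cells below choose `device` with `torch.cuda.is_available()`, and that string becomes part of the results file names. A small sketch (the helper name `describe_device` is hypothetical) shows what PyTorch itself sees, which can serve as a cross-check on the `nvidia-smi` output:

```python
import torch


def describe_device():
    """Summarize the GPU PyTorch will use, or report that only the CPU is available."""
    if not torch.cuda.is_available():
        return "Not connected to a GPU (device = 'cpu')"
    props = torch.cuda.get_device_properties(0)  # first visible GPU
    name = torch.cuda.get_device_name(0)
    return f"{name}: {props.total_memory / 1e9:.1f} GB (device = 'cuda')"


print(describe_device())
```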

Virtual memory

In [None]:
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

## Clone and Install hessQuik 

To run and save the timing results, clone the hessQuik repository first.  The repository only needs to be cloned once per runtime.

In [None]:
%cd ~
%cd ../content/

!git clone https://github.com/elizabethnewman/hessQuik.git

Install hessQuik for the current runtime.

In [None]:
%cd ~
%cd ../content/hessQuik/

!git pull
!python -m pip install git+https://github.com/elizabethnewman/hessQuik.git

# run code from examples folder
%cd hessQuik/examples/

## Setup Parameters for Experiments

There are many available parameters for the timing experiments.  We provide the options used in the paper here.  To find all options, see the `run_timing_test.py` script in the `examples` folder.

In [None]:
num_input = 11            # powers of 2 from 2^0 to 2^(num_input - 1)
num_examples = 10         # number of examples/samples
num_trials = 10           # number of trials
num_threads = 1           # number of computational threads
network_type = 'resnet'   # network architecture


# store flags
flags = ' --num-input ' + str(num_input)
flags += ' --num-examples ' + str(num_examples) 
flags += ' --num-trials ' + str(num_trials)
flags += ' --num-threads ' + str(num_threads)
flags += ' --network-type ' + network_type

# store scalar flags
num_output = 1            # powers of 2 from 2^0 to 2^(num_output - 1)
flags_scalar = flags + ' --num-output ' + str(num_output) 

# store vector flags
num_output = 4            # powers of 2 from 2^0 to 2^(num_output - 1)
flags_vector = flags + ' --num-output ' + str(num_output) 
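The comments above describe the tested sizes as powers of 2 from $2^0$ to $2^{n-1}$. Under that reading, a quick sketch of the resulting input and output feature ranges (`powers_of_two` is a hypothetical helper, not a function from `run_timing_test.py`):

```python
def powers_of_two(n):
    """Return [2**0, 2**1, ..., 2**(n - 1)]."""
    return [2 ** k for k in range(n)]


num_input = 11
in_feature_range = powers_of_two(num_input)
out_scalar = powers_of_two(1)   # scalar-output test (num_output = 1)
out_vector = powers_of_two(4)   # vector-output test (num_output = 4)

print(in_feature_range)  # [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
print(out_vector)        # [1, 2, 4, 8]

# number of (input, output) size combinations if every pair is timed
print(len(in_feature_range) * len(out_vector))  # 44
```

If every (input, output) pair is timed, as the heatmaps under Create Plots suggest, the vector-output test covers 11 * 4 = 44 size combinations per method.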


## Tests

We present three different tests: hessQuik, PytorchAD, and PytorchHessian.

### hessQuik

This is our AD-free code to compute gradients and Hessians of a feed-forward network with respect to its inputs.

In [None]:
# hessQuik - scalar output
!python run_timing_test.py $flags_scalar --network-wrapper hessQuik --verbose --save

In [None]:
# hessQuik - vector output
!python run_timing_test.py $flags_vector --network-wrapper hessQuik --verbose --save
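To illustrate what "AD-free" means, here is a minimal sketch for a single hypothetical layer $f(x) = w^\top \tanh(Kx + b)$. With $z = Kx + b$, the chain rule gives $\nabla f(x) = K^\top (w \odot \tanh'(z))$ and $\nabla^2 f(x) = K^\top \operatorname{diag}(w \odot \tanh''(z))\, K$. This is not hessQuik's implementation (which handles deep networks and vector outputs); `K`, `b`, `w`, and the helper names are illustrative assumptions, and the result is checked against central finite differences.

```python
import numpy as np

# Hypothetical single-layer network f(x) = w^T tanh(K x + b) with scalar output
rng = np.random.default_rng(0)
d, m = 3, 5  # input features, hidden width
K = rng.standard_normal((m, d))
b = rng.standard_normal(m)
w = rng.standard_normal(m)


def f(x):
    return w @ np.tanh(K @ x + b)


def grad_and_hessian(x):
    """Gradient and Hessian of f at x via the chain rule (no automatic differentiation)."""
    t = np.tanh(K @ x + b)
    dt = 1.0 - t ** 2                      # tanh'(z)
    d2t = -2.0 * t * dt                    # tanh''(z)
    grad = K.T @ (w * dt)                  # K^T (w * tanh'(z)), shape (d,)
    hess = K.T @ ((w * d2t)[:, None] * K)  # K^T diag(w * tanh''(z)) K, shape (d, d)
    return grad, hess


# Check against central finite differences
x = rng.standard_normal(d)
g, H = grad_and_hessian(x)
eps = 1e-5
E = np.eye(d)
g_fd = np.array([(f(x + eps * e) - f(x - eps * e)) / (2 * eps) for e in E])
H_fd = np.array([(grad_and_hessian(x + eps * e)[0] - grad_and_hessian(x - eps * e)[0]) / (2 * eps)
                 for e in E])
print(np.abs(g - g_fd).max(), np.abs(H - H_fd).max())
```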

### PytorchAD
This uses PyTorch's AD to compute gradients and Hessians.  Our implementation follows that of [CP-Flow](https://github.com/CW-Huang/CP-Flow).

In [None]:
# PytorchAD - scalar output
!python run_timing_test.py $flags_scalar --network-wrapper PytorchAD --verbose --save

In [None]:
# PytorchAD - vector output
!python run_timing_test.py $flags_vector --network-wrapper PytorchAD --verbose --save
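For comparison, the general double-backward pattern behind AD-based Hessians can be sketched as follows. This is an illustration under stated assumptions (a small hypothetical `net`, not the resnet from the tests) and not a copy of CP-Flow's or hessQuik's code. Note the loop: this sketch needs one extra backward pass per input feature, so its cost grows with the number of inputs.

```python
import torch

torch.manual_seed(0)

# Hypothetical small network R^d -> R (not the resnet used in the timing tests)
d = 4
net = torch.nn.Sequential(torch.nn.Linear(d, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1))

x = torch.randn(10, d, requires_grad=True)  # batch of 10 examples
f = net(x)                                  # shape (10, 1)

# Gradient w.r.t. the inputs. Summing over the batch is valid because each output
# depends only on its own example; create_graph=True allows a second backward pass.
(grad,) = torch.autograd.grad(f.sum(), x, create_graph=True)  # shape (10, d)

# Hessian: one extra backward pass per input feature, giving row i of each Hessian
rows = []
for i in range(d):
    (row,) = torch.autograd.grad(grad[:, i].sum(), x, retain_graph=True)
    rows.append(row)
hess = torch.stack(rows, dim=1)  # shape (10, d, d)
```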

### PytorchHessian
This test uses PyTorch's built-in [Hessian function](https://pytorch.org/docs/stable/generated/torch.autograd.functional.hessian.html) to compute the Hessian of the network with respect to the inputs.  Currently, our implementation is only available for scalar outputs.

In [None]:
# PytorchHessian (only for scalar outputs (num_output = 1))
!python run_timing_test.py $flags_scalar --network-wrapper PytorchHessian --verbose --save
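`torch.autograd.functional.hessian` expects its function to return a single-element tensor, which fits the scalar-output restriction noted above. A minimal sketch of applying it to a batch, one example at a time (the network is hypothetical):

```python
import torch
from torch.autograd.functional import hessian

torch.manual_seed(0)

# Hypothetical small network R^d -> R
d = 4
net = torch.nn.Sequential(torch.nn.Linear(d, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1))
x = torch.randn(10, d)  # batch of 10 examples


def scalar_net(z):
    # hessian() requires a function returning a single-element tensor, so evaluate
    # one example z of shape (d,) at a time and drop the singleton output dimension
    return net(z).squeeze()


H = torch.stack([hessian(scalar_net, xi) for xi in x])  # shape (10, d, d)
```

This sketch loops over examples in Python; how `run_timing_test.py` batches its calls is not shown here.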

## Create Plots

In [None]:
%cd ~
%cd ../content/

import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from datetime import datetime
now = datetime.now()
date = now.strftime("%m-%d-%Y--")

# archive the results folder as <mm-dd-yyyy>--<device>.zip
filename = date + device + '.zip'
!zip -r $filename hessQuik/hessQuik/examples/results/
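Where the `zip` command is unavailable, Python's standard library can produce a similar archive. A sketch with a hypothetical `archive_results` helper; unlike the `!zip -r` call, it stores paths relative to the results folder rather than under the full `hessQuik/hessQuik/examples/results/` prefix:

```python
import os
import shutil
from datetime import datetime


def archive_results(results_dir, device, out_dir="."):
    """Zip `results_dir` into <out_dir>/<mm-dd-yyyy>--<device>.zip and return the archive path."""
    date = datetime.now().strftime("%m-%d-%Y--")
    base_name = os.path.join(out_dir, date + device)  # make_archive appends '.zip'
    return shutil.make_archive(base_name, "zip", root_dir=results_dir)
```

In Colab, the archive can then be downloaded with `from google.colab import files` followed by `files.download(path)`.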

In [None]:
%cd ~
%cd ../content/
%cd hessQuik/hessQuik/examples/results/

import pickle
from datetime import datetime

import matplotlib.pyplot as plt
import torch

# plot parameters
plt.rcParams.update({'font.size': 16})
plt.rcParams.update({'image.interpolation': 'none'})
plt.rcParams['figure.figsize'] = [7, 6]
plt.rcParams['figure.dpi'] = 100

names = ['hessQuik', 'PytorchAD', 'PytorchHessian']
markers = ['o', '^', 's']
linewidth = 3
markersize = 10
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# results files are prefixed with the run date; this assumes the tests above were run today
date = datetime.now().strftime("%m-%d-%Y--")

# solid lines for CPU timings, dashed lines for GPU timings
linestyle = '-' if device == 'cpu' else '--'

plt.figure()
for i, name in enumerate(names):

    # load the scalar-output (out1) results for this method
    with open(date + name + '-resnet-' + device + '-w16-d4-out1.p', 'rb') as f:
        output = pickle.load(f)
    results = output['results']

    x = results['in_feature_range']
    y = results['timing_trials_mean'].squeeze()

    plt.loglog(x, y, linestyle + markers[i], linewidth=linewidth, markersize=markersize, label=name + ': ' + device)

# input features are powers of 2, so use base-2 ticks on the x-axis
plt.xscale('log', base=2)
plt.xlabel('in features')
plt.ylabel('time (seconds)')
plt.grid()
plt.ylim(1e-3, 2e1)
plt.legend()
plt.show()
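The plotting cells rebuild `date` from today's date, so they find results only if the tests ran on the same day. A hypothetical helper that instead picks the most recently modified matching file might look like this (the file-name pattern is copied from the cells above; `latest_result_file` is not part of hessQuik):

```python
import glob
import os


def latest_result_file(name, device, tail="-w16-d4-out1.p", directory="."):
    """Return the newest file named '<date>--<name>-resnet-<device><tail>' in `directory`, or None."""
    pattern = os.path.join(directory, f"*--{name}-resnet-{device}{tail}")
    matches = glob.glob(pattern)
    return max(matches, key=os.path.getmtime) if matches else None
```

For example, `with open(latest_result_file('hessQuik', device), 'rb') as f: output = pickle.load(f)` would replace the date-based `open(...)` call; pass `tail='-w16-d4-out4.p'` for the vector-output results.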

In [None]:
import pickle
from datetime import datetime

import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import torch

names = ['hessQuik', 'PytorchAD']  # PytorchHessian supports scalar outputs only
device = 'cuda' if torch.cuda.is_available() else 'cpu'
date = datetime.now().strftime("%m-%d-%Y--")

fig, axes = plt.subplots(nrows=1, ncols=2)
for i, name in enumerate(names):

    # load the vector-output (out4) results for this method
    with open(date + name + '-resnet-' + device + '-w16-d4-out4.p', 'rb') as f:
        output = pickle.load(f)
    results = output['results']
    timing_trials_mean = results['timing_trials_mean']
    in_feature_range = results['in_feature_range']
    out_feature_range = results['out_feature_range']

    # flip rows so the number of input features increases from bottom to top
    im = axes[i].imshow(torch.flipud(timing_trials_mean), norm=colors.LogNorm(vmin=1e-3, vmax=1e2))

    # local subplot info
    plt.sca(axes[i])
    plt.xticks(np.arange(len(out_feature_range)), out_feature_range)

    if i == 0:
        plt.yticks(np.arange(len(in_feature_range)), list(np.flip(in_feature_range)))
    else:
        plt.tick_params(axis='y', left=False, right=False, labelleft=False)

    plt.title(name + ': ' + device)

# shared colorbar and axis labels
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.27, 0.02, 0.45])
fig.colorbar(im, cax=cbar_ax, label='time (seconds)')
fig.text(0.02, 0.5, 'input features', va='center', rotation='vertical')
fig.text(0.5, 0.02, 'output features', ha='center')
plt.show()

