In [63]:
# imports
import numpy as np
import matplotlib.pyplot as plt
import descent
from ipywidgets import interact
%matplotlib inline

# Helper utilities

Demonstrations of some useful helper functions and utilities in `descent`:

## Numerical gradient checks

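`descent.check_grad` compares the analytic gradient returned by an objective function against a numerical (finite-difference) estimate, computed with a small step $\delta x$ along each coordinate: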

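$$ \frac{f(x + \delta x) - f(x)}{\delta x} \approx \frac{\partial f}{\partial x} $$

The objective below deliberately returns a wrong value for gradient entry 4, so the check has a mismatch to flag.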

In [34]:
def f_df(x):
    """Return ``(objective, gradient)`` for f(x) = 0.5 * ||x||^2.

    The gradient is intentionally corrupted: the true gradient is ``x``
    itself, but entry 4 is replaced by the constant 5.0 so that the
    gradient check below has a mismatch to report.
    """
    grad = x.copy()
    grad[4] = 5.     # Error! incorrect gradient here
    value = 0.5 * np.linalg.norm(x) ** 2
    return value, grad

In [35]:
# Seed a local generator so the gradient-check table below is reproducible
# on Restart & Run All (an unseeded draw gives a different x0 every run).
x0 = np.random.RandomState(0).randn(10)
descent.check_grad(f_df, x0)

------------------------------------
Numerical  | Analytic   | Error          
------------------------------------
0.4725     | 0.4725     | 0.000000 | [32m✔[0m
0.2603     | 0.2603     | 0.000000 | [32m✔[0m
0.2569     | 0.2569     | 0.000000 | [32m✔[0m
-0.8082    | -0.8082    | 0.000000 | [32m✔[0m
-2.2036    | 5.0000     | 1.000000 | [31m✗[0m
0.6512     | 0.6512     | 0.000000 | [32m✔[0m
-0.1794    | -0.1794    | 0.000000 | [32m✔[0m
1.2999     | 1.2999     | 0.000000 | [32m✔[0m
0.3329     | 0.3329     | 0.000000 | [32m✔[0m
0.1381     | 0.1381     | 0.000000 | [32m✔[0m


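## Function wrapping

The helpers also accept parameters stored as a dictionary of arrays. Below, the least-squares objective $\frac{1}{2}\|A w - b\|^2$ is written over `theta = {'w': ..., 'b': ...}`, and both `check_grad` and `sgd` are called on it directly.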

In [46]:
# Fixed seed so the design matrix (and hence every output below) is reproducible.
A = np.random.RandomState(0).randn(10, 5)


def f_df(theta):
    """Least-squares objective 0.5 * ||A w - b||^2 over a dict of parameters.

    Note: this redefines the ``f_df`` from the gradient-check example above;
    every later cell uses this version.

    Parameters
    ----------
    theta : dict
        ``'w'`` multiplies ``A`` (10 x 5), so it is a length-5 vector;
        ``'b'`` is compared against ``A w``, so it is a length-10 vector.

    Returns
    -------
    objective : float
        0.5 * ||A w - b||^2.
    gradient : dict
        Same keys as ``theta``: partial derivatives with respect to
        ``'w'`` (``A^T (A w - b)``) and ``'b'`` (``b - A w``).
    """
    Aw = A.dot(theta['w'])          # computed once instead of three times
    residual = Aw - theta['b']
    objective = 0.5 * np.linalg.norm(residual) ** 2
    gradient = {
        'w': A.T.dot(residual),
        'b': theta['b'] - Aw,       # d/db of 0.5||Aw - b||^2 is -(Aw - b)
    }
    return objective, gradient

In [47]:
# Fixed seed so the starting point (and the optimizer traces below) is
# reproducible on Restart & Run All.
_theta_rng = np.random.RandomState(1)
theta_init = {'w': _theta_rng.randn(5), 'b': _theta_rng.randn(10)}

In [48]:
# check_grad also accepts a dict of arrays: the table below has one row per
# scalar parameter (15 rows = 5 entries of 'w' + 10 entries of 'b').
descent.check_grad(f_df, theta_init)

------------------------------------
Numerical  | Analytic   | Error          
------------------------------------
4.1944     | 4.1944     | 0.000000 | [32m✔[0m
-3.7402    | -3.7402    | 0.000000 | [32m✔[0m
3.3267     | 3.3267     | 0.000000 | [32m✔[0m
0.8700     | 0.8700     | 0.000000 | [32m✔[0m
-0.1878    | -0.1878    | 0.000000 | [32m✔[0m
2.5599     | 2.5599     | 0.000000 | [32m✔[0m
0.1329     | 0.1329     | 0.000000 | [32m✔[0m
3.6567     | 3.6567     | 0.000000 | [32m✔[0m
-0.8752    | -0.8752    | 0.000000 | [32m✔[0m
0.7195     | 0.7195     | 0.000000 | [32m✔[0m
-8.0961    | -8.0961    | 0.000000 | [32m✔[0m
-2.7421    | -2.7421    | 0.000000 | [32m✔[0m
20.8833    | 20.8833    | 0.000000 | [32m✔[0m
1.2850     | 1.2850     | 0.000000 | [32m✔[0m
5.5346     | 5.5346     | 0.000000 | [32m✔[0m


In [59]:
# Run descent.sgd on the least-squares objective, printing a progress row
# every 1000 iterations.
opt = descent.sgd(f_df, theta_init)
opt.display.every = 1000
opt.run(maxiter=10000)   # iteration budget as an int (was the float literal 1e4)

+----------------+-----------------+----------------+
|Iteration       | Objective       | Runtime        |
+----------------+-----------------+----------------+
|              0 |          31.732 |             0 s|
|           1000 |         0.52727 |      285.864 μs|
|           2000 |        0.071144 |      324.965 μs|
|           3000 |       0.0096186 |      272.989 μs|
|           4000 |       0.0013004 |      326.872 μs|
|           5000 |      0.00017582 |      274.897 μs|
|           6000 |      2.3771e-05 |      275.135 μs|
|           7000 |      3.2138e-06 |      272.989 μs|
|           8000 |      4.3451e-07 |       275.85 μs|
|           9000 |      5.8745e-08 |      285.864 μs|
+----------------+-----------------+----------------+
➛ Final objective: 7.958262477606792e-09
➛ Total runtime: 3.20679 s
➛ All done!



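## Interrupts

The run below asks for `maxiter=1e5` but was stopped early by interrupting the kernel. The output ends well short of that budget, yet the optimizer still prints its final summary instead of a traceback.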

In [60]:
# Same setup as above, but with a much larger iteration budget so there is
# time to interrupt the kernel mid-run; the output below stops well short of
# the budget and still ends with the optimizer's summary.
opt = descent.sgd(f_df, theta_init)
opt.display.every = 1000
opt.run(maxiter=100000)   # iteration budget as an int (was the float literal 1e5)

+----------------+-----------------+----------------+
|Iteration       | Objective       | Runtime        |
+----------------+-----------------+----------------+
|              0 |          31.732 |             0 s|
|           1000 |         0.52727 |      271.082 μs|
|           2000 |        0.071144 |      326.157 μs|
|           3000 |       0.0096186 |      338.078 μs|
|           4000 |       0.0013004 |      426.054 μs|
|           5000 |      0.00017582 |      405.073 μs|
|           6000 |      2.3771e-05 |      306.129 μs|
|           7000 |      3.2138e-06 |      385.046 μs|
|           8000 |      4.3451e-07 |      386.953 μs|
|           9000 |      5.8745e-08 |      381.947 μs|
+----------------+-----------------+----------------+
➛ Final objective: 3.059645471663519e-08
➛ Total runtime: 3.29281 s
➛ All done!

