Skip to content

Commit

Permalink
Modifies check_grad to not use the lrucache
Browse files Browse the repository at this point in the history
  • Loading branch information
Niru Maheswaranathan committed Mar 18, 2016
1 parent b8b311f commit da02ad1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
35 changes: 31 additions & 4 deletions descent/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@ def wrap(f_df, xref, size=1):
"""

memoized_f_df = lrucache(lambda x: f_df(restruct(x, xref)), size)
if size == 0:
memoized_f_df = lambda x: f_df(restruct(x, xref))
elif size > 0:
memoized_f_df = lrucache(lambda x: f_df(restruct(x, xref)), size)
else:
raise ValueError("size argument must be a positive integer")

objective = compose(first, memoized_f_df)
gradient = compose(destruct, second, memoized_f_df)
return objective, gradient
Expand Down Expand Up @@ -120,7 +126,7 @@ def check_grad(f_df, xref, stepsize=1e-6, n=50, tol=1e-6, out=sys.stdout):
CORRECT = u'\x1b[32mPass\x1b[0m'
INCORRECT = u'\x1b[31mFail\x1b[0m'

obj, grad = wrap(f_df, xref)
obj, grad = wrap(f_df, xref, size=0)
x0 = destruct(xref)
df = grad(x0)

Expand All @@ -132,6 +138,27 @@ def check_grad(f_df, xref, stepsize=1e-6, n=50, tol=1e-6, out=sys.stdout):
out.write(("{}".format("------------------------------------\n")))
out.flush()

# helper function to parse a number
def parse_error(number):

# colors
failure = "\033[91m"
passing = "\033[92m"
warning = "\033[93m"
end = "\033[0m"

# correct
if error < 0.1 * tol:
return "{:<e}".format(error)

# warning
elif error < tol:
return "{}{:<e}{}".format(warning, error, end)

# failure
else:
return "{}{:<e}{}".format(failure, error, end)

# check each dimension
for j in range(x0.size):

Expand All @@ -149,8 +176,8 @@ def check_grad(f_df, xref, stepsize=1e-6, n=50, tol=1e-6, out=sys.stdout):
if normsum > 0 else 0

errstr = CORRECT if error < tol else INCORRECT
out.write(("{:<10.4f} | {:<10.4f} | {:<5.6f} | {:^2}\n"
.format(df_approx, df_analytic, error, errstr)))
out.write(("{:<10.4f} | {:<10.4f} | {} | {:^2}\n"
.format(df_approx, df_analytic, parse_error(error), errstr)))
out.flush()


Expand Down
7 changes: 5 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from descent.utils import destruct, restruct, lrucache, check_grad
from io import StringIO
from time import sleep, time
import re


def test_lrucache():
Expand Down Expand Up @@ -51,12 +52,14 @@ def f_df_incorrect(x):
check_grad(f_df_correct, 5, out=output)

# helper functions
getvalues = lambda o: [float(s.strip()) for s in o.getvalue().split('\n')[3].split('|')[:-1]]
ansi_escape = re.compile(r'\x1b[^m]*m')
getvalues = lambda o: [float(ansi_escape.sub('', s.strip())) for s in o.getvalue().split('\n')[3].split('|')[:-1]]

# get the first row of data
values = getvalues(output)
print(values)
assert values[0] == values[1] == 10.0, "Correct gradient computation"
assert values[2] == 0.0, "Correct error computation"
assert values[2] <= 1e-10, "Correct error computation"

output = StringIO()
check_grad(f_df_incorrect, 5, out=output)
Expand Down

0 comments on commit da02ad1

Please sign in to comment.