Skip to content

Commit

Permalink
removes six dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
Niru Maheswaranathan committed Mar 10, 2016
1 parent 1fb06ed commit 3aa43fb
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import (absolute_import, division, print_function, unicode_literals)
import numpy as np
from descent.utils import destruct, restruct, lrucache, check_grad
from six import StringIO
from io import StringIO
from time import sleep, time


Expand Down Expand Up @@ -42,17 +42,16 @@ def test_check_grad():
"""Tests the check_grad() function"""

def f_df_correct(x):
return x**2, 2*x
return x**2, 2 * x

def f_df_incorrect(x):
return x**3, 0.5*x**2
return x**3, 0.5 * x**2

output = StringIO()
check_grad(f_df_correct, 5, out=output)

# helper functions
getvalues = lambda o: [float(s.strip()) for s in \
o.getvalue().split('\n')[3].split('|')[:-1]]
getvalues = lambda o: [float(s.strip()) for s in o.getvalue().split('\n')[3].split('|')[:-1]]

# get the first row of data
values = getvalues(output)
Expand All @@ -63,8 +62,7 @@ def f_df_incorrect(x):
check_grad(f_df_incorrect, 5, out=output)
values = getvalues(output)
printed_error = values[2]
correct_error = np.abs(values[0] - values[1]) \
/ (np.abs(values[0]) + np.abs(values[1]))
correct_error = np.abs(values[0] - values[1]) / (np.abs(values[0]) + np.abs(values[1]))
assert np.isclose(printed_error, correct_error), "Correct relative error"


Expand Down Expand Up @@ -97,18 +95,19 @@ def test_destruct():
assert np.allclose(larray, destruct(tuple(lref))), "Tuple destruct"

# Composed / mixed types
lref = [np.eye(2), {'a': np.array([-1,1]), 'b': np.arange(3)}, 7]
lref = [np.eye(2), {'a': np.array([-1, 1]), 'b': np.arange(3)}, 7]
larray = np.array([1, 0, 0, 1, -1, 1, 0, 1, 2, 7])
assert np.allclose(larray, destruct(lref)), "List of mixed types"

dref = {
'a': [3.5, 1.2, -33],
'b': np.arange(6).reshape(2,3),
'b': np.arange(6).reshape(2, 3),
'c': 7
}
darray = np.array([3.5, 1.2, -33, 0, 1, 2, 3, 4, 5, 7])
assert np.allclose(darray, destruct(dref)), "Dictionary of mixed types"


def test_restruct():
"""Tests the destruct utility function"""

Expand Down Expand Up @@ -148,19 +147,19 @@ def test_restruct():
assert np.allclose(lref[idx], val), "Tuple restruct"

# Composed / mixed types
lref = [np.eye(2), {'a': np.array([-1,1]), 'b': np.arange(3)}, 7]
lzeros = [np.zeros((2,2)), {'a': np.zeros(2), 'b': np.zeros(3)}, 0]
lref = [np.eye(2), {'a': np.array([-1, 1]), 'b': np.arange(3)}, 7]
lzeros = [np.zeros((2, 2)), {'a': np.zeros(2), 'b': np.zeros(3)}, 0]
larray = np.array([1, 0, 0, 1, -1, 1, 0, 1, 2, 7])
assert np.allclose(destruct(restruct(larray, lzeros)), destruct(lref)), "List of mixed types"

dref = {
'a': [3.5, 1.2, -33],
'b': np.arange(6).reshape(2,3),
'b': np.arange(6).reshape(2, 3),
'c': 7
}
dzeros = {
'a': [0, 0, 0],
'b': np.zeros((2,3)),
'b': np.zeros((2, 3)),
'c': 0
}
darray = np.array([3.5, 1.2, -33, 0, 1, 2, 3, 4, 5, 7])
Expand Down

0 comments on commit 3aa43fb

Please sign in to comment.