Skip to content

Commit

Permalink
autopep8
Browse files Browse the repository at this point in the history
  • Loading branch information
bshillingford committed Feb 19, 2017
1 parent 2c4ba7a commit 963865b
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 15 deletions.
3 changes: 3 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ script:
- nosetests --with-coverage --cover-package torchfile
after_success: coveralls

notifications:
email: false

matrix:
fast_finish: true
include:
Expand Down
34 changes: 19 additions & 15 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ def make_filename(fn):
TEST_FILE_DIRECTORY = 'testfiles_x86_64'
return os.path.join(TEST_FILE_DIRECTORY, fn)


def load(fn, **kwargs):
return torchfile.load(make_filename(fn), **kwargs)


class TestBasics(unittest.TestCase):

def test_dict(self):
obj = load('hello=123.t7')
self.assertEqual(dict(obj), {b'hello': 123})
Expand All @@ -34,8 +36,8 @@ def test_classnames_never_decoded(self):

def test_basic_tensors(self):
f64 = load('doubletensor.t7')
self.assertTrue((f64 == np.array([[1,2,3,], [4,5,6.9]],
dtype=np.float64)).all())
self.assertTrue((f64 == np.array([[1, 2, 3, ], [4, 5, 6.9]],
dtype=np.float64)).all())

f32 = load('floattensor.t7')
self.assertAlmostEqual(f32.sum(), 12.97241666913, delta=1e-5)
Expand All @@ -46,19 +48,20 @@ def test_function(self):

def test_dict_accessors(self):
obj = load('hello=123.t7',
use_int_heuristic=True,
utf8_decode_strings=True)
use_int_heuristic=True,
utf8_decode_strings=True)
self.assertIsInstance(obj['hello'], int)
self.assertIsInstance(obj.hello, int)

obj = load('hello=123.t7',
use_int_heuristic=True,
utf8_decode_strings=False)
use_int_heuristic=True,
utf8_decode_strings=False)
self.assertIsInstance(obj[b'hello'], int)
self.assertIsInstance(obj.hello, int)


class TestRecursiveObjects(unittest.TestCase):

def test_recursive_class(self):
obj = load('recursive_class.t7')
self.assertEqual(obj.a, obj)
Expand All @@ -72,6 +75,7 @@ def test_recursive_table(self):


class TestTDS(unittest.TestCase):

def test_hash(self):
obj = load('tds_hash.t7')
self.assertEqual(len(obj), 3)
Expand All @@ -85,16 +89,17 @@ def test_vec(self):


class TestHeuristics(unittest.TestCase):

def test_list_heuristic(self):
obj = load('list_table.t7', use_list_heuristic=True)
self.assertEqual(obj, [b'hello', b'world', b'third item', 123])

obj = load('list_table.t7',
use_list_heuristic=False,
use_int_heuristic=True)
use_list_heuristic=False,
use_int_heuristic=True)
self.assertEqual(
dict(obj),
{1: b'hello', 2: b'world', 3: b'third item', 4: 123})
dict(obj),
{1: b'hello', 2: b'world', 3: b'third item', 4: 123})

def test_int_heuristic(self):
obj = load('hello=123.t7', use_int_heuristic=True)
Expand All @@ -104,14 +109,13 @@ def test_int_heuristic(self):
self.assertNotIsInstance(obj[b'hello'], int)

obj = load('list_table.t7',
use_list_heuristic=False,
use_int_heuristic=False)
use_list_heuristic=False,
use_int_heuristic=False)
self.assertEqual(
dict(obj),
{1: b'hello', 2: b'world', 3: b'third item', 4: 123})
dict(obj),
{1: b'hello', 2: b'world', 3: b'third item', 4: 123})
self.assertNotIsInstance(list(obj.keys())[0], int)


if __name__ == '__main__':
unittest.main()

1 change: 1 addition & 0 deletions torchfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def __dir__(self):

type_handlers = {}


def register_handler(typename):
def do_register(handler):
type_handlers[typename] = handler
Expand Down

0 comments on commit 963865b

Please sign in to comment.