Skip to content

Commit

Permalink
Improved coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
belltailjp committed Oct 11, 2018
1 parent 491eba4 commit f36d23e
Showing 1 changed file with 57 additions and 26 deletions.
83 changes: 57 additions & 26 deletions tests/test_computational_cost_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ def calc_custom(func, in_data, **kwargs):


def test_custom_cost_calculator_invalid():
def calc_invalid_signature(a, b):
pass

def calc_no_type_annotation(func, in_data, **kwargs):
pass

Expand All @@ -198,14 +201,21 @@ def calc_insufficient_return(func, in_data, **kwargs):
def calc_wrong_type(func, in_data, **kwargs):
return (1, 1, 1, None)

calculators = [
calc_invalid_signature, calc_no_type_annotation, calc_not_tuple,
calc_insufficient_return, calc_wrong_type
]

x = np.random.randn(1, 3, 32, 32).astype(np.float32)
x = chainer.Variable(x)
for f in [calc_not_tuple, calc_not_tuple, calc_insufficient_return]:
with chainer.using_config('train', False):
with ComputationalCostHook() as cost:
with chainer.using_config('train', False):
with ComputationalCostHook() as cost:
for f in calculators:
with pytest.raises(TypeError), pytest.warns(UserWarning):
cost.add_custom_cost_calculator(AddConstant, f)
x = x + 1
with pytest.raises(TypeError):
cost.add_custom_cost_calculator(1, f)


def test_report_ignored_layer():
Expand Down Expand Up @@ -271,8 +281,8 @@ def show_report(**kwargs):
def show_summary_report(**kwargs):
return _report(ccost.show_summary_report, **kwargs)

def assert_table(rep, expected):
for col, val in expected.items():
def assert_table(rep, expect):
for col, val in expect.items():
assert rep[col] == val

col_flops, col_mr, col_mw, col_mrw = (1, 2, 3, 4)
Expand All @@ -286,34 +296,41 @@ def assert_table(rep, expected):
assert summary_cols == show_summary_report()[0]

# Case unit=None: raw values are shown
expected = {col_flops: '2415919104', col_mr: '16851200',
col_mw: '33554432', col_mrw: '50405632'}
assert_table(show_report(unit=None)[-1], expected) # default CSV
assert_table(show_report(unit=None, mode='md')[-1], expected)
assert_table(show_report(unit=None, mode='table')[-1], expected)
expect = {col_flops: '2415919104', col_mr: '16851200',
col_mw: '33554432', col_mrw: '50405632'}
assert_table(show_report(unit=None)[-1], expect) # default CSV
assert_table(show_report(unit=None, mode='md')[-1], expect)
assert_table(show_report(unit=None, mode='table')[-1], expect)

assert_table(show_report(unit=None)[-2], expected)
assert_table(show_report(unit=None)[-2], expect)

# Case unit=G: FLOPs/=1000^3, mem/=1024^3, 3 digits after the decimal point
expected = {col_flops: '2.416', col_mr: '0.016',
col_mw: '0.031', col_mrw: '0.047'}
assert_table(show_report(unit='G')[-1], expected)
assert_table(show_report(unit='G', mode='md')[-1], expected)
assert_table(show_report(unit='G', mode='table')[-1], expected)
expect = {col_flops: '2.416', col_mr: '0.016',
col_mw: '0.031', col_mrw: '0.047'}
assert_table(show_report(unit='G')[-1], expect)
assert_table(show_report(unit='G', mode='md')[-1], expect)
assert_table(show_report(unit='G', mode='table')[-1], expect)

# Case unit=G, n_digits=6: more digits will be shown
expected = {col_flops: '2.415919', col_mr: '0.015694',
col_mw: '0.03125', col_mrw: '0.046944'}
assert_table(show_report(unit='G', n_digits=6)[-1], expected)
assert_table(show_report(unit='G', n_digits=6, mode='md')[-1], expected)
assert_table(show_report(unit='G', n_digits=6, mode='table')[-1], expected)
expect = {col_flops: '2.415919', col_mr: '0.015694',
col_mw: '0.03125', col_mrw: '0.046944'}
assert_table(show_report(unit='G', n_digits=6)[-1], expect)
assert_table(show_report(unit='G', n_digits=6, mode='md')[-1], expect)
assert_table(show_report(unit='G', n_digits=6, mode='table')[-1], expect)

# Case unit=G, n_digits>10: truncated to 10 digits
expect = {col_flops: '2.415919104', col_mr: '0.015693903',
col_mw: '0.03125', col_mrw: '0.046943903'}
assert_table(show_report(unit='G', n_digits=11)[-1], expect)
assert_table(show_report(unit='G', n_digits=11, mode='md')[-1], expect)
assert_table(show_report(unit='G', n_digits=11, mode='table')[-1], expect)

# Case unit=M, n_digits=0: Values are rounded to integer
expected = {col_flops: '2416', col_mr: '16',
col_mw: '32', col_mrw: '48'}
assert_table(show_report(unit='M', n_digits=0)[-1], expected)
assert_table(show_report(unit='M', n_digits=0, mode='md')[-1], expected)
assert_table(show_report(unit='M', n_digits=0, mode='table')[-1], expected)
expect = {col_flops: '2416', col_mr: '16',
col_mw: '32', col_mrw: '48'}
assert_table(show_report(unit='M', n_digits=0)[-1], expect)
assert_table(show_report(unit='M', n_digits=0, mode='md')[-1], expect)
assert_table(show_report(unit='M', n_digits=0, mode='table')[-1], expect)

# Case only some columns are specified
rep = show_report(unit='G', columns=['name', 'mrw'])[-1]
Expand All @@ -331,6 +348,20 @@ def assert_table(rep, expected):
show_report(unit='G', columns=['name', 'wooooohoooooooo'])


def test_error_when_invalid_report_type_is_specified():
x = np.random.randn(1, 3, 32, 32).astype(np.float32)
net = SimpleConvNet()
with chainer.using_config('train', False):
with ComputationalCostHook() as cost:
net(x)
with pytest.raises(ValueError):
cost.show_report(mode='unknown')
with pytest.raises(ValueError):
cost.show_report(unit='unknown')
with pytest.raises(ValueError):
cost.show_report(n_digits=-1)


def test_nest():
x = chainer.Variable(np.zeros((1, 3, 32, 32)).astype(np.float32))
c = chainer.Variable(np.ones((1, 3, 32, 32)).astype(np.float32))
Expand Down

0 comments on commit f36d23e

Please sign in to comment.