Skip to content

Commit

Permalink
Merge pull request #445 from mv1388/test-util-copy-fn
Browse files Browse the repository at this point in the history
Test util copy_fn
  • Loading branch information
mv1388 committed Apr 10, 2020
2 parents ad71294 + 1264c43 commit 15a43c0
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
Binary file modified dist/aitoolbox-1.0.tar.gz
Binary file not shown.
77 changes: 77 additions & 0 deletions tests/test_utils/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,45 @@ def test_function_exists(self):
self.assertFalse(util.function_exists(DummyOptimizer(), 'zero_grad_ctr'))
self.assertFalse(util.function_exists(DummyOptimizer(), 'step_ctr'))

def test_copy_function(self):
src_fn_obj = SourceFn()
my_fn_copy = util.copy_function(src_fn_obj.my_fn)
my_fn_input_copy = util.copy_function(src_fn_obj.my_fn_input)

self.assertEqual(my_fn_copy(None), 'my_fn_return_value')
self.assertEqual(my_fn_input_copy(None, 'Value 1'), 'my_fn_return_value: Value 1')

def test_copy_function_another_object(self):
src_fn_obj = SourceFn()
my_fn_copy = util.copy_function(src_fn_obj.my_fn)
my_fn_input_copy = util.copy_function(src_fn_obj.my_fn_input)

target_fn_obj = TargetFnCopy(my_fn_copy, my_fn_input_copy)

self.assertEqual(target_fn_obj.copy_my_fn(), 'my_fn_return_value')
self.assertEqual(target_fn_obj.copy_my_fn_input('Value 2'), 'my_fn_return_value: Value 2')

def test_copy_function_another_object_fn_call_another_fn(self):
src_fn_obj = SourceFnCallAnotherFn()
my_fn_copy = util.copy_function(src_fn_obj.my_fn)
my_fn_input_copy = util.copy_function(src_fn_obj.my_fn_input)

target_fn_obj = TargetFnCopy(my_fn_copy, my_fn_input_copy)

self.assertEqual(target_fn_obj.copy_my_fn(), 'my_fn_return_value: MyValue_another_fn_call')
self.assertEqual(target_fn_obj.copy_my_fn_input('Value 2'), 'my_fn_return_value: Value 2')

def test_copy_function_another_object_access_attribute(self):
src_fn_obj = SourceFnAccessAttrVal()
my_fn_copy = util.copy_function(src_fn_obj.my_fn)
my_fn_input_copy = util.copy_function(src_fn_obj.my_fn_input)

target_fn_obj = TargetFnCopy(my_fn_copy, my_fn_input_copy, attribute_val='my_attribute_value')

self.assertEqual(target_fn_obj.copy_my_fn(), 'my_fn attribute value: my_attribute_value')
self.assertEqual(target_fn_obj.copy_my_fn_input('Value 2'),
'my_fn_input attribute value: my_attribute_value; fn input value Value 2')

def test_is_empty_function(self):
def empty_fn():
pass
Expand Down Expand Up @@ -89,3 +128,41 @@ def full_fn_arg(a):
def full_fn_arg_sum(self, a, b):
c = a + b + self.a
return c


class SourceFn:
def my_fn(self):
return 'my_fn_return_value'

def my_fn_input(self, value):
return f'my_fn_return_value: {value}'


class SourceFnCallAnotherFn:
def my_fn(self):
return self.copy_my_fn_input('MyValue_another_fn_call')

def my_fn_input(self, value):
return f'my_fn_return_value: {value}'


class SourceFnAccessAttrVal:
def my_fn(self):
return f'my_fn attribute value: {self.attribute_val}'

def my_fn_input(self, value):
return f'my_fn_input attribute value: {self.attribute_val}; fn input value {value}'


class TargetFnCopy:
def __init__(self, source_my_fn, source_my_fn_input, attribute_val='my_attribute_value'):
self.source_my_fn = source_my_fn
self.source_my_fn_input = source_my_fn_input

self.attribute_val = attribute_val

def copy_my_fn(self):
return self.source_my_fn(self)

def copy_my_fn_input(self, value):
return self.source_my_fn_input(self, value)

0 comments on commit 15a43c0

Please sign in to comment.