Merge pull request #724 from mv1388/add_tests_to_improve_coverage
Add new unittests to improve coverage
mv1388 committed Aug 7, 2022
2 parents 70becd2 + 71c0e36 commit 98c56c0
Showing 13 changed files with 524 additions and 12 deletions.
3 changes: 3 additions & 0 deletions aitoolbox/experiment/local_save/local_results_save.py
@@ -249,6 +249,9 @@ def save_experiment_results_separate_files(self, result_package, training_histor
The first file path should be pointing to the main experiment results file.
"""
if experiment_timestamp is None:
experiment_timestamp = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d_%H-%M-%S')

experiment_results_local_path = self.create_experiment_local_folder_structure(project_name, experiment_name,
experiment_timestamp)

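The hunk above makes the experiment timestamp optional. A minimal usage sketch of the resulting call, mirroring the new test_experiment_timestamp_not_provided tests further down — it assumes LocalResultsSaver is the concrete saver defined in the module changed here (as the tests suggest), result_pkg and training_history are placeholders for a result package and training history built the way those tests build their dummies, and '/tmp/results' is an arbitrary example path:

from aitoolbox.experiment.local_save.local_results_save import LocalResultsSaver

# Saver rooted at an arbitrary local folder; 'pickle' is one of the formats the tests exercise.
saver = LocalResultsSaver(local_model_result_folder_path='/tmp/results', file_format='pickle')

# experiment_timestamp is simply omitted: with this commit the method stamps the
# run itself, using the current time formatted as '%Y-%m-%d_%H-%M-%S'.
experiment_results_paths = saver.save_experiment_results_separate_files(
    result_pkg, training_history,
    'projectDir', 'experimentSubDir'
)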
@@ -11,9 +11,10 @@ def __init__(self, pkg_name=None, strict_content_check=False, np_array=True, **k
"""Base Result package used to derive specific result packages from
Functions which the user should potentially override in a specific result package:
- prepare_results_dict()
- list_additional_results_dump_paths()
- set_experiment_dir_path_for_additional_results()
* :meth:`aitoolbox.experiment.result_package.abstract_result_packages.AbstractResultPackage.prepare_results_dict`
* :meth:`aitoolbox.experiment.result_package.abstract_result_packages.AbstractResultPackage.list_additional_results_dump_paths`
* :meth:`aitoolbox.experiment.result_package.abstract_result_packages.AbstractResultPackage.set_experiment_dir_path_for_additional_results`
Args:
pkg_name (str or None): result package name used just for clarity
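The docstring now cross-links the three hooks a concrete result package may override; the DummyFullResultPackage added to the tests below does exactly that. A compressed, illustrative sketch of the same pattern — the class name, metric values and file path are invented, while the method names, the positional pkg_name argument and the [[file_name, file_path], ...] shape come from this diff and the tests:

from aitoolbox.experiment.result_package.abstract_result_packages import AbstractResultPackage


class MyResultPackage(AbstractResultPackage):
    def __init__(self):
        AbstractResultPackage.__init__(self, 'myPkg')

    def prepare_results_dict(self):
        # A real package would compute these metrics from the predictions stored on the package.
        return {'acc': 0.9, 'loss': 0.1}

    def list_additional_results_dump_paths(self):
        # Optional hook: extra artifacts the package dumped to disk,
        # listed as [[file_name, file_path], ...].
        return [['attention_heatmap.png', '/tmp/results/attention_heatmap.png']]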
@@ -141,7 +142,7 @@ def get_additional_results_dump_paths(self):
"""Return paths to the additional results which are stored to local drive when the package is evaluated
For example if package plots attention heatmaps and saves pictures to disk, this function will return
paths to these picture files. This is achieved via the call to the use implemented function
paths to these picture files. This is achieved via the call to the user-implemented function
list_additional_results_dump_paths().
Returns:
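On the consumer side, the saving machinery picks those artifact paths up through get_additional_results_dump_paths(); the new test_additional_results_dump_paths tests below assert exactly this hand-off. Continuing the sketch above (the return value is assumed to echo the user-implemented list; the base class may additionally validate or cache it):

pkg = MyResultPackage()
extra_dump_paths = pkg.get_additional_results_dump_paths()
# expected shape: [['attention_heatmap.png', '/tmp/results/attention_heatmap.png']]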
4 changes: 2 additions & 2 deletions aitoolbox/torchtrain/train_loop/components/message_passing.py
@@ -109,7 +109,7 @@ def validate_msg_handling_settings(msg_handling_settings):
raise ValueError(f'Provided two incompatible msg_handling_settings {msg_handling_settings}. '
'Only OVERRIDE setting can currently be combined with another available setting.')
elif type(msg_handling_settings) != MessageHandling:
raise ValueError(f'Provided msg_handling_settings {msg_handling_settings} type not of the supported '
'MessageHandling or list of MessageHandling.')
raise TypeError(f'Provided msg_handling_settings {msg_handling_settings} type not of the supported '
'MessageHandling or list of MessageHandling.')

return msg_handling_settings if type(msg_handling_settings) is list else [msg_handling_settings]
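A hedged sketch of what the retyped exception means for callers. It assumes validate_msg_handling_settings is importable at module level from the file changed above; in the repo it may instead be a (static) method on the message-handling class, so treat the import as illustrative:

from aitoolbox.torchtrain.train_loop.components.message_passing import validate_msg_handling_settings

try:
    # Anything that is neither a MessageHandling member nor a list of them is rejected.
    validate_msg_handling_settings('not_a_MessageHandling_member')
except TypeError as exc:  # raised ValueError before this commit
    print(f'Rejected: {exc}')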
17 changes: 17 additions & 0 deletions tests/test_experiment/test_local_experiment_saver.py
@@ -33,6 +33,23 @@ def get_hyperparameters(self):
return self.hyper_params


class DummyNonAbstractLocalModelSaver:
pass


class TestBaseFullExperimentLocalSaver(unittest.TestCase):
def test_init(self):
project_dir_name = 'projectPyTorchLocalModelSaver'
exp_dir_name = 'experimentSubDirPT'

with self.assertRaises(TypeError):
BaseFullExperimentLocalSaver(
DummyNonAbstractLocalModelSaver(),
project_name=project_dir_name, experiment_name=exp_dir_name,
local_model_result_folder_path=THIS_DIR
)


class TestFullPyTorchExperimentLocalSaver(unittest.TestCase):
def test_init(self):
project_dir_name = 'projectPyTorchLocalModelSaver'
180 changes: 179 additions & 1 deletion tests/test_experiment/test_local_save/test_local_results_save.py
@@ -24,13 +24,15 @@ def _build_epoch_list(self):


class DummyFullResultPackage(AbstractResultPackage):
def __init__(self, result_dict, hyper_params):
def __init__(self, result_dict, hyper_params, additional_results=None):
AbstractResultPackage.__init__(self, 'dummyFullPkg')
self.result_dict = result_dict
self.hyper_params = hyper_params
self.y_true = [10.0] * 100
self.y_predicted = [123.4] * 100

self.additional_results = additional_results

def prepare_results_dict(self):
return self.result_dict

@@ -40,6 +42,9 @@ def get_results(self):
def get_hyperparameters(self):
return self.hyper_params

def list_additional_results_dump_paths(self):
return self.additional_results


class TestBaseLocalResultsSaver(unittest.TestCase):
def test_init(self):
@@ -155,6 +160,33 @@ def save_file_result(self, file_format, expected_extension):
if os.path.exists(project_dir_path_true):
shutil.rmtree(project_dir_path_true)

def test_forced_unsupported_file_format_error(self):
project_dir_name = 'projectDir'
exp_dir_name = 'experimentSubDir'
current_time = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d_%H-%M-%S')
file_name = 'test_dump'

project_dir_path_true = os.path.join(THIS_DIR, project_dir_name)
experiment_dir_path_true = os.path.join(project_dir_path_true, f'{exp_dir_name}_{current_time}')
experiment_results_dir_path_true = os.path.join(experiment_dir_path_true, 'results')

result_dict = {'acc': 10, 'loss': 101010.2, 'rogue': 4445.5}

saver = BaseLocalResultsSaver(local_model_result_folder_path=THIS_DIR, file_format='my_fancy_format')

self.assertEqual(saver.file_format, 'pickle')

saver.create_experiment_local_folder_structure(project_dir_name, exp_dir_name, current_time)

# Force format re-set to unsupported my_fancy_format
saver.file_format = 'my_fancy_format'

with self.assertRaises(ValueError):
saver.save_file(result_dict, file_name, f'{experiment_results_dir_path_true}/{file_name}')

if os.path.exists(project_dir_path_true):
shutil.rmtree(project_dir_path_true)


class TestLocalResultsSaverSingleFile(unittest.TestCase):
def test_save_experiment_results_pickle(self):
@@ -220,6 +252,75 @@ def save_experiment_results(self, file_format, expected_extension, save_true_pre
if os.path.exists(project_path):
shutil.rmtree(project_path)

def test_experiment_timestamp_not_provided(self):
project_dir_name = 'projectDir'
exp_dir_name = 'experimentSubDir'
current_time = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d_%H-%M-%S')
file_format = 'pickle'
expected_extension = '.p'
result_pkg = DummyFullResultPackage({'metric1': 33434, 'acc1': 223.43, 'loss': 4455.6},
{'epoch': 20, 'lr': 0.334})
training_history = DummyTrainingHistory().wrap_pre_prepared_history({})
result_file_name_true = f'results_hyperParams_hist_{exp_dir_name}_{current_time}{expected_extension}'

project_path = os.path.join(THIS_DIR, project_dir_name)
exp_path = os.path.join(project_path, f'{exp_dir_name}_{current_time}')
results_path = os.path.join(exp_path, 'results')
result_file_path_true = os.path.join(results_path, result_file_name_true)

saver = LocalResultsSaver(local_model_result_folder_path=THIS_DIR, file_format=file_format)
experiment_results_paths = saver.save_experiment_results(
result_pkg, training_history,
project_dir_name, exp_dir_name
)

self.assertEqual(result_file_path_true, experiment_results_paths[0][1])

if os.path.exists(project_path):
shutil.rmtree(project_path)

def test_additional_results_dump_paths(self):
project_dir_name = 'projectDir'
exp_dir_name = 'experimentSubDir'
current_time = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d_%H-%M-%S')
file_format = 'pickle'
expected_extension = '.p'

training_history = DummyTrainingHistory().wrap_pre_prepared_history({})
result_file_name_true = f'results_hyperParams_hist_{exp_dir_name}_{current_time}{expected_extension}'

project_path = os.path.join(THIS_DIR, project_dir_name)
exp_path = os.path.join(project_path, f'{exp_dir_name}_{current_time}')
results_path = os.path.join(exp_path, 'results')
result_file_path_true = os.path.join(results_path, result_file_name_true)

additional_results_paths = [
['BLAAAAA.txt', os.path.join(results_path, 'BLAAAAA.txt')],
['uuuuu.p', os.path.join(results_path, 'uuuuu.p')],
['aaaaaa.json', os.path.join(results_path, 'aaaaaa.json')]
]

result_pkg = DummyFullResultPackage(
{'metric1': 33434, 'acc1': 223.43, 'loss': 4455.6},
{'epoch': 20, 'lr': 0.334},
additional_results=additional_results_paths
)

saver = LocalResultsSaver(local_model_result_folder_path=THIS_DIR, file_format=file_format)
experiment_results_paths = saver.save_experiment_results(
result_pkg, training_history,
project_dir_name, exp_dir_name,
current_time
)

self.assertEqual(
[[result_file_name_true, result_file_path_true]] + additional_results_paths,
experiment_results_paths
)

if os.path.exists(project_path):
shutil.rmtree(project_path)


class TestLocalResultsSaverSeparateFiles(unittest.TestCase):
def test_save_experiment_results_pickle(self):
@@ -318,3 +419,80 @@ def read_result_file(f_path):

if os.path.exists(project_path):
shutil.rmtree(project_path)

def test_experiment_timestamp_not_provided(self):
project_dir_name = 'projectDir'
exp_dir_name = 'experimentSubDir'
current_time = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d_%H-%M-%S')
file_format = 'pickle'
expected_extension = '.p'
result_pkg = DummyFullResultPackage({'metric1': 33434, 'acc1': 223.43, 'loss': 4455.6},
{'epoch': 20, 'lr': 0.334})
training_history = DummyTrainingHistory().wrap_pre_prepared_history({})
result_file_name_true = f'results_{exp_dir_name}_{current_time}{expected_extension}'

project_path = os.path.join(THIS_DIR, project_dir_name)
exp_path = os.path.join(project_path, f'{exp_dir_name}_{current_time}')
results_path = os.path.join(exp_path, 'results')
result_file_path_true = os.path.join(results_path, result_file_name_true)

saver = LocalResultsSaver(local_model_result_folder_path=THIS_DIR, file_format=file_format)
experiment_results_paths = saver.save_experiment_results_separate_files(
result_pkg, training_history,
project_dir_name, exp_dir_name
)

self.assertEqual(result_file_path_true, experiment_results_paths[0][1])

if os.path.exists(project_path):
shutil.rmtree(project_path)

def test_additional_results_dump_paths(self):
project_dir_name = 'projectDir'
exp_dir_name = 'experimentSubDir'
file_format = 'pickle'
expected_extension = '.p'
current_time = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d_%H-%M-%S')

training_history = DummyTrainingHistory().wrap_pre_prepared_history({})
result_file_name_true = f'results_{exp_dir_name}_{current_time}{expected_extension}'
hyper_param_file_name_true = f'hyperparams_{exp_dir_name}_{current_time}{expected_extension}'
train_hist_file_name_true = f'train_history_{exp_dir_name}_{current_time}{expected_extension}'

project_path = os.path.join(THIS_DIR, project_dir_name)
exp_path = os.path.join(project_path, f'{exp_dir_name}_{current_time}')
results_path = os.path.join(exp_path, 'results')
result_file_path_true = os.path.join(results_path, result_file_name_true)
hyper_param_file_path_true = os.path.join(results_path, hyper_param_file_name_true)
train_hist_file_path_true = os.path.join(results_path, train_hist_file_name_true)

additional_results_paths = [
['BLAAAAA.txt', os.path.join(results_path, 'BLAAAAA.txt')],
['uuuuu.p', os.path.join(results_path, 'uuuuu.p')],
['aaaaaa.json', os.path.join(results_path, 'aaaaaa.json')]
]

result_pkg = DummyFullResultPackage(
{'metric1': 33434, 'acc1': 223.43, 'loss': 4455.6},
{'epoch': 20, 'lr': 0.334},
additional_results=additional_results_paths
)

saver = LocalResultsSaver(local_model_result_folder_path=THIS_DIR, file_format=file_format)
experiment_results_paths = saver.save_experiment_results_separate_files(
result_pkg, training_history,
project_dir_name, exp_dir_name,
current_time
)

self.assertEqual(
[
[result_file_name_true, result_file_path_true],
[hyper_param_file_name_true, hyper_param_file_path_true],
[train_hist_file_name_true, train_hist_file_path_true]
] + additional_results_paths,
experiment_results_paths
)

if os.path.exists(project_path):
shutil.rmtree(project_path)
18 changes: 18 additions & 0 deletions tests/test_experiment/test_training_history.py
@@ -233,6 +233,24 @@ def test__add_methods(self):
'NEW_METRIC': [13323.4, 133323.4], 'ADDITIONAL_metric': [122.3], 'addi': [344]}
)

with self.assertRaises(TypeError):
th + [123.4, 1223.4, 13323.4, 13323.4]

with self.assertRaises(TypeError):
th + 123.4

with self.assertRaises(TypeError):
[123.4, 1223.4, 13323.4, 13323.4] + th

with self.assertRaises(TypeError):
123.4 + th

with self.assertRaises(TypeError):
th += [123.4, 1223.4, 13323.4, 13323.4]

with self.assertRaises(TypeError):
th += 123.4

@staticmethod
def _build_dummy_history():
th = TrainingHistory()
13 changes: 12 additions & 1 deletion tests/test_torchtrain/test_data/test_dataset.py
@@ -2,7 +2,18 @@
from random import randint

from torch.utils.data.dataloader import DataLoader
from aitoolbox.torchtrain.data.dataset import ListDataset
from aitoolbox.torchtrain.data.dataset import BasicDataset, ListDataset


class TestBasicDataset(unittest.TestCase):
def test_len(self):
ds = BasicDataset(list(range(100)))
self.assertEqual(len(ds), 100)

def test_get_item(self):
ds = BasicDataset(list(range(100)))
for i in range(100):
self.assertEqual(ds[i], i)


class TestListDataset(unittest.TestCase):
