Skip to content

Commit

Permalink
bug fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Jan 5, 2018
1 parent 2fb723f commit f48545b
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions tests/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ def simple_transform(_):
@patch('autokeras.search.transform', side_effect=simple_transform)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=lambda: None)
def test_fit_predict(_, _1):
os.remove(os.path.join(constant.DEFAULT_SAVE_PATH, 'classifier'))
os.remove(os.path.join(constant.DEFAULT_SAVE_PATH, 'searcher'))
clear_dir()
constant.MAX_ITER_NUM = 2
constant.MAX_MODEL_NUM = 2
constant.EPOCHS_EACH = 1
Expand All @@ -47,6 +46,16 @@ def test_fit_predict(_, _1):
assert all(map(lambda result: result in np.array(['a', 'b']), results))


def clear_dir():
ensure_dir(constant.DEFAULT_SAVE_PATH)
directory = os.path.join(constant.DEFAULT_SAVE_PATH, 'classifier')
if os.path.exists(directory):
os.remove(directory)
directory = os.path.join(constant.DEFAULT_SAVE_PATH, 'searcher')
if os.path.exists(directory):
os.remove(directory)


def simple_transform2(_):
generator = RandomConvClassifierGenerator(input_shape=(25, 1), n_classes=5)
return [generator.generate(), generator.generate()]
Expand All @@ -55,8 +64,7 @@ def simple_transform2(_):
@patch('autokeras.search.transform', side_effect=simple_transform2)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=lambda: None)
def test_fit_predict2(_, _1):
os.remove(os.path.join(constant.DEFAULT_SAVE_PATH, 'classifier'))
os.remove(os.path.join(constant.DEFAULT_SAVE_PATH, 'searcher'))
clear_dir()
constant.MAX_ITER_NUM = 1
constant.MAX_MODEL_NUM = 1
constant.EPOCHS_EACH = 1
Expand Down

0 comments on commit f48545b

Please sign in to comment.