diff --git a/autokeras/utils.py b/autokeras/utils.py index 2645158d1..8b174b708 100644 --- a/autokeras/utils.py +++ b/autokeras/utils.py @@ -168,6 +168,8 @@ def train_model(self): config = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.Session(config=config) + init = tf.global_variables_initializer() + sess.run(init) backend.set_session(sess) try: if constant.DATA_AUGMENTATION: diff --git a/tests/test_utils.py b/tests/test_utils.py index 686cd6803..23d24cd72 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,8 +10,7 @@ def test_model_trainer(): np.random.rand(1, 3), False).train_model() -@patch('autokeras.utils.backend') -def test_model_trainer_not_augmented(_): +def test_model_trainer_not_augmented(): constant.DATA_AUGMENTATION = False constant.LIMIT_MEMORY = True model = RandomConvClassifierGenerator(3, (28, 28, 1)).generate()