From 9cc023403081c2c411d4c4c9fd7fd654ceb2651d Mon Sep 17 00:00:00 2001 From: Haifeng Jin Date: Thu, 19 Apr 2018 15:55:26 -0500 Subject: [PATCH] limit memory --- autokeras/classifier.py | 9 +++++++++ autokeras/utils.py | 8 -------- tests/test_classifier.py | 1 + tests/test_utils.py | 1 - 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/autokeras/classifier.py b/autokeras/classifier.py index 60ad5004e..5fb3339f1 100644 --- a/autokeras/classifier.py +++ b/autokeras/classifier.py @@ -4,10 +4,12 @@ import csv import errno import time +import tensorflow as tf import scipy.ndimage as ndimage import numpy as np +from keras import backend from sklearn.metrics import accuracy_score from sklearn.model_selection import train_test_split, StratifiedKFold @@ -32,6 +34,13 @@ def _validate(x_train, y_train): def run_searcher_once(x_train, y_train, x_test, y_test, path): + if constant.LIMIT_MEMORY: + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + sess = tf.Session(config=config) + init = tf.global_variables_initializer() + sess.run(init) + backend.set_session(sess) searcher = pickle_from_file(os.path.join(path, 'searcher')) searcher.search(x_train, y_train, x_test, y_test) diff --git a/autokeras/utils.py b/autokeras/utils.py index 8b174b708..c52cdb238 100644 --- a/autokeras/utils.py +++ b/autokeras/utils.py @@ -2,7 +2,6 @@ import pickle import numpy as np -import tensorflow as tf from keras import backend from keras.callbacks import Callback, LearningRateScheduler, ReduceLROnPlateau from keras.losses import categorical_crossentropy @@ -164,13 +163,6 @@ def train_model(self): min_lr=0.5e-6) callbacks = [terminator, lr_scheduler, lr_reducer] - if constant.LIMIT_MEMORY: - config = tf.ConfigProto() - config.gpu_options.allow_growth = True - sess = tf.Session(config=config) - init = tf.global_variables_initializer() - sess.run(init) - backend.set_session(sess) try: if constant.DATA_AUGMENTATION: flow = self.datagen.flow(self.x_train, self.y_train, batch_size) diff --git a/tests/test_classifier.py b/tests/test_classifier.py index 2557b5944..65d285f4a 100644 --- a/tests/test_classifier.py +++ b/tests/test_classifier.py @@ -77,6 +77,7 @@ def test_timout(_): @patch('multiprocessing.Process', new=MockProcess) def test_final_fit(): + constant.LIMIT_MEMORY = True path = 'tests/resources/temp' clean_dir(path) clf = ImageClassifier(path=path, verbose=False) diff --git a/tests/test_utils.py b/tests/test_utils.py index 23d24cd72..3562504ce 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,7 +12,6 @@ def test_model_trainer(): def test_model_trainer_not_augmented(): constant.DATA_AUGMENTATION = False - constant.LIMIT_MEMORY = True model = RandomConvClassifierGenerator(3, (28, 28, 1)).generate() ModelTrainer(model, np.random.rand(2, 28, 28, 1), np.random.rand(2, 3), np.random.rand(1, 28, 28, 1), np.random.rand(1, 3), False).train_model()