Skip to content

Commit

Permalink
limit memory
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Apr 19, 2018
1 parent a969a6e commit 9cc0234
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 9 deletions.
9 changes: 9 additions & 0 deletions autokeras/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import csv
import errno
import time
import tensorflow as tf

import scipy.ndimage as ndimage

import numpy as np
from keras import backend
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split, StratifiedKFold

Expand All @@ -32,6 +34,13 @@ def _validate(x_train, y_train):


def run_searcher_once(x_train, y_train, x_test, y_test, path):
if constant.LIMIT_MEMORY:
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
init = tf.global_variables_initializer()
sess.run(init)
backend.set_session(sess)
searcher = pickle_from_file(os.path.join(path, 'searcher'))
searcher.search(x_train, y_train, x_test, y_test)

Expand Down
8 changes: 0 additions & 8 deletions autokeras/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pickle
import numpy as np

import tensorflow as tf
from keras import backend
from keras.callbacks import Callback, LearningRateScheduler, ReduceLROnPlateau
from keras.losses import categorical_crossentropy
Expand Down Expand Up @@ -164,13 +163,6 @@ def train_model(self):
min_lr=0.5e-6)

callbacks = [terminator, lr_scheduler, lr_reducer]
if constant.LIMIT_MEMORY:
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
init = tf.global_variables_initializer()
sess.run(init)
backend.set_session(sess)
try:
if constant.DATA_AUGMENTATION:
flow = self.datagen.flow(self.x_train, self.y_train, batch_size)
Expand Down
1 change: 1 addition & 0 deletions tests/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def test_timout(_):

@patch('multiprocessing.Process', new=MockProcess)
def test_final_fit():
constant.LIMIT_MEMORY = True
path = 'tests/resources/temp'
clean_dir(path)
clf = ImageClassifier(path=path, verbose=False)
Expand Down
1 change: 0 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def test_model_trainer():

def test_model_trainer_not_augmented():
constant.DATA_AUGMENTATION = False
constant.LIMIT_MEMORY = True
model = RandomConvClassifierGenerator(3, (28, 28, 1)).generate()
ModelTrainer(model, np.random.rand(2, 28, 28, 1), np.random.rand(2, 3), np.random.rand(1, 28, 28, 1),
np.random.rand(1, 3), False).train_model()

0 comments on commit 9cc0234

Please sign in to comment.