[MRG]solve text preprocessor error when no gpu (#388)
* Solve the text preprocessor error when no GPU is available

* Change the expected test error from RuntimeError to Exception

* Directly delete the device selection

* Remove the GPUtil import
boyuangong authored and haifeng-jin committed Jan 4, 2019
1 parent 5ed6bb4 commit dd97602
Showing 4 changed files with 23 additions and 34 deletions.
51 changes: 21 additions & 30 deletions autokeras/text/text_preprocessor.py
@@ -1,7 +1,6 @@
 import os
 import re
 
-import GPUtil
 import numpy as np
 
 from autokeras.constant import Constant
@@ -79,7 +78,7 @@ def read_embedding_index(extract_path):
         embedding_index: Dictionary contains word with pre trained index.
     """
     embedding_index = {}
-    f = open(os.path.join(extract_path, Constant.PRE_TRAIN_FILE_NAME))
+    f = open(os.path.join(extract_path, Constant.PRE_TRAIN_FILE_NAME), encoding="utf-8")
     for line in f:
         values = line.split()
         word = values[0]
@@ -138,34 +137,26 @@ def processing(path, word_index, input_length, x_train):
 
     embedding_matrix = load_pretrain(path=path, word_index=word_index)
 
-    # Get the first available GPU
-    device_id_list = GPUtil.getFirstAvailable()
-    device_id = device_id_list[0] # grab first element from list
-
-    # Set CUDA_VISIBLE_DEVICES to mask out all other GPUs than the first available device id
-    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
-    device = '/gpu:0'
-    with tf.device(device):
-        from keras import Input, Model
-        from keras import backend
-        from keras.layers import Embedding
-        config = tf.ConfigProto(allow_soft_placement=True)
-        config.gpu_options.allow_growth = True
-        sess = tf.Session(config=config)
-        backend.set_session(sess)
-        print("generating preprocessing model...")
-        embedding_layer = Embedding(len(word_index) + 1,
-                                    Constant.EMBEDDING_DIM,
-                                    weights=[embedding_matrix],
-                                    input_length=input_length,
-                                    trainable=False)
-
-        sequence_input = Input(shape=(input_length,), dtype='int32')
-        embedded_sequences = embedding_layer(sequence_input)
-        model = Model(sequence_input, embedded_sequences)
-        print("converting text to vector...")
-        x_train = model.predict(x_train)
-        del model
+    from keras import Input, Model
+    from keras import backend
+    from keras.layers import Embedding
+    config = tf.ConfigProto(allow_soft_placement=True)
+    config.gpu_options.allow_growth = True
+    sess = tf.Session(config=config)
+    backend.set_session(sess)
+    print("generating preprocessing model...")
+    embedding_layer = Embedding(len(word_index) + 1,
+                                Constant.EMBEDDING_DIM,
+                                weights=[embedding_matrix],
+                                input_length=input_length,
+                                trainable=False)
+
+    sequence_input = Input(shape=(input_length,), dtype='int32')
+    embedded_sequences = embedding_layer(sequence_input)
+    model = Model(sequence_input, embedded_sequences)
+    print("converting text to vector...")
+    x_train = model.predict(x_train)
+    del model
 
     return x_train
 
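The reason this works without a GPU is that the explicit GPUtil device pinning is gone and the session is created with allow_soft_placement=True, so TensorFlow places the ops on a GPU when one exists and silently falls back to the CPU otherwise. The standalone sketch below (not part of the commit; plain TF 1.x as pinned in setup.py) illustrates that fallback behavior:

# Minimal sketch of soft placement with the TF 1.x API used in this repo.
# With allow_soft_placement=True, TensorFlow falls back to the CPU when no GPU
# is available, which is why the GPUtil-based device selection could be
# deleted outright.
import tensorflow as tf

config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True  # allocate GPU memory on demand, if a GPU exists

with tf.Session(config=config) as sess:
    # These ops run on a GPU when one is available, otherwise on the CPU.
    a = tf.constant([[1.0, 2.0]])
    b = tf.constant([[3.0], [4.0]])
    print(sess.run(tf.matmul(a, b)))  # prints [[11.]]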
2 changes: 1 addition & 1 deletion examples/text_cnn/text.py
@@ -27,4 +27,4 @@ def read_csv(file_path):
 file_path = "labeledTrainData.tsv"
 x_train, y_train = read_csv(file_path=file_path)
 clf = TextClassifier(verbose=True)
-clf.fit(x=x_train, y=y_train, batch_size=10, time_limit=12 * 60 * 60)
+clf.fit(x=x_train, y=y_train, time_limit=12 * 60 * 60)
1 change: 0 additions & 1 deletion setup.py
@@ -15,7 +15,6 @@
 'tensorflow==1.10.0',
 'imageio==2.4.1',
 'requests==2.20.1',
-'GPUtil==1.3.0',
 'lightgbm==2.2.2',
 'pandas==0.23.4',
 'opencv-python==3.4.4.19'],
3 changes: 1 addition & 2 deletions tests/text/test_text_preprocessor.py
@@ -55,9 +55,8 @@ def test_load_pretrain(_, _1):
     assert (embedding_matrix[1] == embedding_index.get("bar")).all()
 
 
-@patch('autokeras.text.text_preprocessor.GPUtil.getFirstAvailable', return_value=[0])
 @patch('autokeras.text.text_preprocessor.load_pretrain', side_effect=mock_load_pretrain)
-def test_processing(_, _1):
+def test_processing(_):
     train_x = np.full((1, 2), 1)
     train_x = processing(TEST_TEMP_DIR, word_index, 2, train_x)
     assert np.allclose(train_x[0][0], embedding_matrix[1])
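Dropping the GPUtil patch also drops one parameter from test_processing, because each @patch decorator injects one mock argument into the test, bottom decorator first. A small self-contained sketch of that behavior (the patched targets are arbitrary standard-library functions, not taken from the repository):

# Each @patch injects one mock; the bottom-most decorator maps to the first
# parameter, so removing a decorator means removing a parameter as well.
from unittest.mock import patch

@patch('os.path.exists', return_value=True)   # injected second
@patch('os.path.getsize', return_value=0)     # bottom-most -> injected first
def check(mock_getsize, mock_exists):
    assert mock_getsize('any/path') == 0
    assert mock_exists('any/path') is True

check()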
