Skip to content

Commit

Permalink
Merge d010419 into db061c1
Browse files Browse the repository at this point in the history
  • Loading branch information
yufei-12 committed Sep 20, 2019
2 parents db061c1 + d010419 commit 4114e4b
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 9 deletions.
1 change: 0 additions & 1 deletion autokeras/hypermodel/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,6 @@ class FeatureEngineering(Preprocessor):
"""

def __init__(self, max_columns=1000, **kwargs):

super().__init__(**kwargs)
self.input_node = None
self.max_columns = max_columns
Expand Down
4 changes: 4 additions & 0 deletions autokeras/meta_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ def assemble(self, input_node):
self.infer_column_types()
if input_node.column_types is None:
input_node.column_types = self.column_types
# partial column_types is provided.
for key, value in self.column_types.items():
if key not in input_node.column_types:
input_node.column_types[key] = value
return hyperblock.StructuredDataBlock()(input_node)


Expand Down
11 changes: 8 additions & 3 deletions autokeras/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ def __init__(self, outputs, column_names, column_types, **kwargs):
inputs = node.StructuredDataInput()
inputs.column_types = column_types
inputs.column_names = column_names
if column_types:
if not all([column_type in ['categorical', 'numerical']
for column_type in column_types.values()]):
raise ValueError(
'Column_types should be either "categorical" or "numerical".')
if column_names and column_types:
if not all([column_name in column_names
for column_name in column_types]):
Expand All @@ -185,10 +190,10 @@ def __init__(self, outputs, column_names, column_types, **kwargs):
**kwargs)

def fit(self,
x=None, # file path of training data
y=None, # label name
x=None,
y=None,
validation_split=0,
validation_data=None, # file path of validataion data
validation_data=None,
**kwargs):
"""
# Arguments
Expand Down
16 changes: 16 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,22 @@
'deck_': 'categorical',
'embark_town_': 'categorical',
'alone_': 'categorical'}
# Deliberately invalid column-type mapping: 'cat' and 'num' are not accepted
# type names (only 'categorical' and 'numerical' are). Used by the test that
# expects a ValueError for unsupported column types.
false_column_types_from_csv = {
'sex_': 'cat',
'age_': 'num',
'n_siblings_spouses_': 'cat',
'parch_': 'categorical',
'fare_': 'numerical',
'class_': 'categorical',
'deck_': 'categorical',
'embark_town_': 'categorical',
'alone_': 'categorical'}
# Column-type mapping that names only some columns; used by the test that
# fits a classifier with a partially specified column_types argument.
# NOTE(review): keys here have no trailing underscore, unlike the mapping
# above -- presumably these match the real CSV header names; confirm.
partial_column_types_from_csv = {
'fare': 'numerical',
'class': 'categorical',
'deck': 'categorical',
'embark_town': 'categorical',
'alone': 'categorical'}


def structured_data(num_data=500):
Expand Down
32 changes: 27 additions & 5 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from tests.common import column_names_from_numpy
from tests.common import column_types_from_csv
from tests.common import column_types_from_numpy
from tests.common import false_column_types_from_csv
from tests.common import less_column_names_from_csv
from tests.common import partial_column_types_from_csv
from tests.common import structured_data


train_file_path = r'tests/resources/titanic/train.csv'
test_file_path = r'tests/resources/titanic/eval.csv'

Expand Down Expand Up @@ -88,7 +89,7 @@ def test_structured_data_from_numpy_regressor(tmp_dir):
x_train = data
y = np.random.rand(num_data, 1)
y_train = y
clf = ak.StructuredDataRegressor(directory=tmp_dir, max_trials=2)
clf = ak.StructuredDataRegressor(directory=tmp_dir, max_trials=1)
clf.fit(x_train, y_train, epochs=2, validation_data=(x_train, y_train))


Expand Down Expand Up @@ -125,7 +126,7 @@ def test_structured_data_from_numpy_col_type_classifier(tmp_dir):
clf = ak.StructuredDataClassifier(
column_types=column_types_from_numpy,
directory=tmp_dir,
max_trials=2)
max_trials=1)
clf.fit(x_train, y_train, epochs=2, validation_data=(x_train, y_train))
assert str(info.value) == 'Column names must be specified.'

Expand All @@ -140,7 +141,7 @@ def test_structured_data_from_numpy_col_name_type_classifier(tmp_dir):
column_names=column_names_from_numpy,
column_types=column_types_from_numpy,
directory=tmp_dir,
max_trials=2)
max_trials=1)
clf.fit(x_train, y_train, epochs=2, validation_data=(x_train, y_train))


Expand All @@ -151,7 +152,7 @@ def test_structured_data_classifier_transform_new_data(tmp_dir):
x_train, x_test = data[:num_train], data[num_train:]
y = np.random.randint(0, 3, num_data)
y_train, y_test = y[:num_train], y[num_train:]
clf = ak.StructuredDataClassifier(directory=tmp_dir, max_trials=2)
clf = ak.StructuredDataClassifier(directory=tmp_dir, max_trials=1)
clf.fit(x_train, y_train, epochs=2, validation_data=(x_test, y_test))


Expand Down Expand Up @@ -205,3 +206,24 @@ def test_structured_data_from_csv_col_type_mismatch_classifier(tmp_dir):
clf.fit(x=train_file_path, y='survived', epochs=2,
validation_data=test_file_path)
assert str(info.value) == 'Column_names and column_types are mismatched.'


def test_structured_data_from_csv_false_col_type_classifier(tmp_dir):
    """Unsupported column type names must be rejected with a ValueError."""
    expected_message = (
        'Column_types should be either "categorical" or "numerical".')
    with pytest.raises(ValueError) as exc_info:
        classifier = ak.StructuredDataClassifier(
            column_types=false_column_types_from_csv,
            directory=tmp_dir,
            max_trials=1)
        classifier.fit(
            x=train_file_path,
            y='survived',
            epochs=2,
            validation_data=test_file_path)
    # Checked after the context manager exits so the assertion actually runs.
    assert str(exc_info.value) == expected_message


def test_structured_data_from_csv_partial_col_type_classifier(tmp_dir):
    """Fit from CSV files when column_types names only some of the columns."""
    classifier = ak.StructuredDataClassifier(
        column_types=partial_column_types_from_csv,
        directory=tmp_dir,
        max_trials=1)
    # The test passes as long as fitting completes without raising.
    classifier.fit(
        x=train_file_path,
        y='survived',
        epochs=2,
        validation_data=test_file_path)

0 comments on commit 4114e4b

Please sign in to comment.