Skip to content
This repository has been archived by the owner on Jan 9, 2024. It is now read-only.

Commit

Permalink
fix the command line option parsing for the multiprocess flag. (#169)
Browse files Browse the repository at this point in the history
  • Loading branch information
jzhang-gp committed Oct 11, 2019
1 parent ccd2450 commit e3dbb42
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 21 deletions.
34 changes: 24 additions & 10 deletions foreshadow/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,26 @@
# flake8: noqa
# isort: noqa
import argparse
import json
import sys
import warnings

import pandas as pd
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.model_selection import train_test_split

from foreshadow.config import config
from foreshadow.estimators import AutoEstimator
from foreshadow.foreshadow import Foreshadow


def generate_model(args): # noqa: C901
"""Process command line args and generate a Foreshadow model to fit.
def process_argument(args): # noqa: C901
"""Process command line args.
Args:
args (list): A list of string arguments to process
Returns:
tuple: A tuple of `fs, X_train, y_train, X_test, y_test` which \
represents the foreshadow model along with the split data.
Raises:
ValueError: if invalid file or invalid y.
cargs: processed arguments from the parser
"""
parser = argparse.ArgumentParser(
Expand All @@ -49,7 +44,7 @@ def generate_model(args): # noqa: C901
parser.add_argument(
"--multiprocess",
default=False,
type=bool,
action="store_true",
help="Whether to enable multiprocessing on the dataset, useful for "
"large datasets and/or computational heavy transformations.",
)
Expand Down Expand Up @@ -90,6 +85,25 @@ def generate_model(args): # noqa: C901
)
cargs = parser.parse_args(args)

return cargs


def generate_model(args): # noqa: C901
"""Process command line args and generate a Foreshadow model to fit.
Args:
args (list): A list of string arguments to process
Returns:
tuple: A tuple of `fs, X_train, y_train, X_test, y_test` which \
represents the foreshadow model along with the split data.
Raises:
ValueError: if invalid file or invalid y.
"""
cargs = process_argument(args)

if cargs.level == 3 and cargs.method is not None:
warnings.warn(
"WARNING: Level 3 model search enabled. Method will be ignored."
Expand Down
14 changes: 14 additions & 0 deletions foreshadow/tests/test_console.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,17 @@ def test_console_get_method_error():
get_method("InvalidRegression", None)

assert "Invalid method." in str(e.value)


def test_console_parse_args_multiprocess():
from foreshadow.console import process_argument

data_path = get_file_path("data", "boston_housing.csv")

args = ["--level", "1", data_path, "medv", "regression"]
cargs = process_argument(args)
assert cargs.multiprocess is False

args = ["--level", "1", "--multiprocess", data_path, "medv", "regression"]
cargs = process_argument(args)
assert cargs.multiprocess is True
23 changes: 12 additions & 11 deletions foreshadow/tests/test_foreshadow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@
from foreshadow.utils.testing import get_file_path


def check_slow():
import os

return os.environ.get("FORESHADOW_TESTS") != "ALL"


slow = pytest.mark.skipif(
check_slow(), reason="Skipping long-runnning integration tests"
)


def test_foreshadow_defaults():
from foreshadow.foreshadow import Foreshadow
from foreshadow.preparer import DataPreparer
Expand Down Expand Up @@ -736,6 +747,7 @@ def test_foreshadow_serialization_adults_small_classification():
assertions.assertAlmostEqual(score1, score2, places=7)


@slow
def test_foreshadow_serialization_adults_classification():
from foreshadow.foreshadow import Foreshadow
import pandas as pd
Expand Down Expand Up @@ -811,17 +823,6 @@ def test_foreshadow_serialization_boston_housing_regression():
assertions.assertAlmostEqual(score1, score2, places=7)


def check_slow():
import os

return os.environ.get("FORESHADOW_TESTS") != "ALL"


slow = pytest.mark.skipif(
check_slow(), reason="Skipping long-runnning integration tests"
)


@slow
def test_foreshadow_serialization_tpot():
from foreshadow.foreshadow import Foreshadow
Expand Down

0 comments on commit e3dbb42

Please sign in to comment.