Skip to content

Commit

Permalink
Moving back model_uri to predict
Browse files Browse the repository at this point in the history
Signed-off-by: Sunish Sheth <sunishsheth2009@gmail.com>
Use the spark_session fixture from test_file_utils instead of duplicating it

Signed-off-by: Thomas Coquereau <thomas.coquereau@klm.com>
  • Loading branch information
Cokral committed Jan 26, 2024
1 parent 5d13d6e commit 900cf55
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 21 deletions.
7 changes: 1 addition & 6 deletions tests/evaluate/test_evaluation.py
Expand Up @@ -57,6 +57,7 @@
from mlflow.tracking.artifact_utils import get_artifact_uri
from mlflow.utils import insecure_hash
from mlflow.utils.file_utils import TempDir
from tests.utils.test_file_utils import spark_session # pylint: disable=unused-import


def get_iris():
Expand Down Expand Up @@ -108,12 +109,6 @@ def get_local_artifact_path(run_id, artifact_path):
return get_artifact_uri(run_id, artifact_path).replace("file://", "")


@pytest.fixture(scope="module")
def spark_session():
    """Yield a module-scoped SparkSession on a local master, stopping it on teardown."""
    session = SparkSession.builder.master("local[*]").getOrCreate()
    try:
        yield session
    finally:
        # Mirrors SparkSession.__exit__, which calls stop() when the
        # context-manager form is used.
        session.stop()


@pytest.fixture(scope="module")
def iris_dataset():
X, y = get_iris()
Expand Down
8 changes: 1 addition & 7 deletions tests/spark/autologging/ml/test_pyspark_ml_autologging.py
Expand Up @@ -27,7 +27,6 @@
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import LinearRegression, LinearRegressionModel
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder, TrainValidationSplit
from pyspark.sql import SparkSession

import mlflow
from mlflow import MlflowClient
Expand All @@ -52,17 +51,12 @@
)

from tests.helper_functions import AnyStringWith
from tests.utils.test_file_utils import spark_session # pylint: disable=unused-import

MODEL_DIR = "model"
MLFLOW_PARENT_RUN_ID = "mlflow.parentRunId"


@pytest.fixture(scope="module")
def spark_session():
    """Provide one shared local SparkSession for every test in this module."""
    builder = SparkSession.builder.master("local[*]")
    # The session's context manager stops it once the module's tests finish.
    with builder.getOrCreate() as spark:
        yield spark


@pytest.fixture(scope="module")
def dataset_binomial(spark_session):
return spark_session.createDataFrame(
Expand Down
Expand Up @@ -30,6 +30,7 @@
expect_status_code,
pyfunc_serve_and_score_model,
)
from tests.utils.test_file_utils import spark_session # pylint: disable=unused-import

IS_TENSORFLOW_AVAILABLE = _is_available_on_pypi("tensorflow")
EXTRA_PYFUNC_SERVING_TEST_ARGS = [] if IS_TENSORFLOW_AVAILABLE else ["--env-manager", "local"]
Expand All @@ -51,14 +52,6 @@ def data():
return x, y


@pytest.fixture(scope="module")
def spark_session():
    """Yield a local two-worker SparkSession shared across this test module.

    pyspark is imported lazily so test collection succeeds even when the
    package is not installed.
    """
    from pyspark.sql import SparkSession

    session = SparkSession.builder.master("local[2]").getOrCreate()
    try:
        yield session
    finally:
        # Equivalent to SparkSession.__exit__, which stops the session.
        session.stop()


@pytest.fixture(scope="module")
def single_tensor_input_model(data):
x, y = data
Expand Down

0 comments on commit 900cf55

Please sign in to comment.