Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify spark session fixture usage #10915

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 2 additions & 6 deletions tests/evaluate/test_evaluation.py
Expand Up @@ -58,6 +58,8 @@
from mlflow.utils import insecure_hash
from mlflow.utils.file_utils import TempDir

from tests.utils.test_file_utils import spark_session # noqa: F401


def get_iris():
iris = sklearn.datasets.load_iris()
Expand Down Expand Up @@ -108,12 +110,6 @@ def get_local_artifact_path(run_id, artifact_path):
return get_artifact_uri(run_id, artifact_path).replace("file://", "")


@pytest.fixture(scope="module")
def spark_session():
    """Yield a module-scoped local SparkSession; stop it on teardown."""
    spark = SparkSession.builder.master("local[*]").getOrCreate()
    try:
        yield spark
    finally:
        # Mirrors SparkSession.__exit__: always release the JVM-side session.
        spark.stop()


@pytest.fixture(scope="module")
def iris_dataset():
X, y = get_iris()
Expand Down
8 changes: 1 addition & 7 deletions tests/spark/autologging/ml/test_pyspark_ml_autologging.py
Expand Up @@ -27,7 +27,6 @@
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import LinearRegression, LinearRegressionModel
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder, TrainValidationSplit
from pyspark.sql import SparkSession

import mlflow
from mlflow import MlflowClient
Expand All @@ -52,17 +51,12 @@
)

from tests.helper_functions import AnyStringWith
from tests.utils.test_file_utils import spark_session # noqa: F401

MODEL_DIR = "model"
MLFLOW_PARENT_RUN_ID = "mlflow.parentRunId"


@pytest.fixture(scope="module")
def spark_session():
    """Yield a module-scoped local SparkSession; stop it on teardown."""
    spark = SparkSession.builder.master("local[*]").getOrCreate()
    try:
        yield spark
    finally:
        # Equivalent to the `with` form: SparkSession.__exit__ calls stop().
        spark.stop()


@pytest.fixture(scope="module")
def dataset_binomial(spark_session):
return spark_session.createDataFrame(
Expand Down
Expand Up @@ -30,6 +30,7 @@
expect_status_code,
pyfunc_serve_and_score_model,
)
from tests.utils.test_file_utils import spark_session # noqa: F401

IS_TENSORFLOW_AVAILABLE = _is_available_on_pypi("tensorflow")
EXTRA_PYFUNC_SERVING_TEST_ARGS = [] if IS_TENSORFLOW_AVAILABLE else ["--env-manager", "local"]
Expand All @@ -51,14 +52,6 @@ def data():
return x, y


@pytest.fixture(scope="module")
def spark_session():
    """Yield a module-scoped SparkSession on two local cores; stop it on teardown."""
    # Imported lazily so the module stays importable without pyspark installed.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    try:
        yield spark
    finally:
        # Equivalent to the `with` form: SparkSession.__exit__ calls stop().
        spark.stop()


@pytest.fixture(scope="module")
def single_tensor_input_model(data):
x, y = data
Expand Down