
Commit

Add explicit spark session termination after tests finish. Rename `spark_session` to `spark`. Add docs.
mansenfranzen committed Feb 21, 2019
1 parent e6cd630 commit ce15357
Showing 2 changed files with 12 additions and 6 deletions.
tests/conftest.py: 8 changes (6 additions & 2 deletions)
@@ -56,7 +56,7 @@ def pytest_collection_modifyitems(config, items):
 
 
 @pytest.fixture(scope="session")
-def spark_session():
+def spark(request):
     """Provide session wide Spark Session to avoid expensive recreation for
     each test.
@@ -66,6 +66,10 @@ def spark_session():
 
     try:
         from pyspark.sql import SparkSession
-        return SparkSession.builder.getOrCreate()
+        spark = SparkSession.builder.getOrCreate()
+
+        request.addfinalizer(lambda: spark.stop())
+        return spark
 
     except ImportError:
         pytest.skip("Pyspark not available.")
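
Note: the fixture registers teardown via `request.addfinalizer`, so `spark.stop()` runs once after the last test of the session. The same behavior can also be written as a yield fixture; a minimal sketch for comparison (not part of this commit):

import pytest


@pytest.fixture(scope="session")
def spark():
    """Session wide Spark session with yield-based teardown."""
    pytest.importorskip("pyspark")  # skip dependent tests if pyspark is missing
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    yield spark   # tests using the fixture run here
    spark.stop()  # teardown: runs once after the test session finishes
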
tests/test_environment.py: 10 changes (6 additions & 4 deletions)
@@ -13,7 +13,8 @@ def test_java_environment():
     """Pyspark requires Java to be available. It uses Py4J to start and
     communicate with the JVM. Py4J looks for JAVA_HOME or falls back calling
     java directly. This test explicitly checks for the java prerequisites for
-    pyspark to work correctly.
+    pyspark to work correctly. If errors occur regarding the instantiation of
+    a spark session, this test helps to rule out potential java related causes.
     """
 
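The body of this test is collapsed in the view above. As a rough illustration of the lookup the docstring describes, a helper like the following could check JAVA_HOME first and then fall back to a `java` executable on the PATH; the name and logic here are assumptions, not the repository's code:

import os
import shutil


def java_prerequisites_met():
    """Hypothetical check mirroring Py4J's lookup order."""
    java_home = os.environ.get("JAVA_HOME")
    if java_home:
        # Py4J launches $JAVA_HOME/bin/java when JAVA_HOME is set.
        return os.path.exists(os.path.join(java_home, "bin", "java"))
    # Fallback: a `java` executable reachable on the PATH.
    return shutil.which("java") is not None
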
@@ -30,7 +31,8 @@ def test_java_environment():
 
 @pytest.mark.pyspark
 def test_pyspark_import():
-    """Fail if pyspark can't be imported.
+    """Fail if pyspark can't be imported. This test is mandatory because other
+    spark tests will be skipped if the spark session fixture fails.
     """
 
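This body is collapsed as well; a minimal sketch of a test that fails (rather than skips) when the import is broken, assuming nothing beyond plain pytest:

import pytest


@pytest.mark.pyspark
def test_pyspark_import():
    try:
        import pyspark  # noqa: F401
    except ImportError:
        pytest.fail("pyspark could not be imported.")

Tests carrying this marker can then be selected with `pytest -m pyspark`.
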
@@ -42,7 +44,7 @@ def test_pyspark_import():
 
 
 @pytest.mark.pyspark
-def test_pyspark_pandas_interaction(spark_session):
+def test_pyspark_pandas_interaction(spark):
     """Check simple interaction between pyspark and pandas.
     """
@@ -51,7 +53,7 @@ def test_pyspark_pandas_interaction(spark_session):
     import numpy as np
 
     df_pandas = pd.DataFrame(np.random.rand(10, 2), columns=["a", "b"])
-    df_spark = spark_session.createDataFrame(df_pandas)
+    df_spark = spark.createDataFrame(df_pandas)
     df_converted = df_spark.toPandas()
 
     pd.testing.assert_frame_equal(df_pandas, df_converted)
