Link existing tests with PySpark backend (ibis-project#7)
icexelloss committed Jul 23, 2019
1 parent 56b421a commit 4afec31
Showing 7 changed files with 61 additions and 24 deletions.
33 changes: 23 additions & 10 deletions conftest.py
@@ -23,14 +23,15 @@ def data_directory():
 
 
 @pytest.fixture(scope='session')
-def spark_client_testing(data_directory):
+def spark_session(data_directory):
     pytest.importorskip('pyspark')
 
     import pyspark.sql.types as pt
+    from pyspark.sql import SparkSession
 
-    client = ibis.spark.connect()
+    spark = SparkSession.builder.getOrCreate()
 
-    df_functional_alltypes = client._session.read.csv(
+    df_functional_alltypes = spark.read.csv(
         path=str(data_directory / 'functional_alltypes.csv'),
         schema=pt.StructType([
             pt.StructField('index', pt.IntegerType(), True),
@@ -57,7 +58,7 @@ def spark_client_testing(data_directory):
         "bool_col", df_functional_alltypes["bool_col"].cast("boolean"))
     df_functional_alltypes.createOrReplaceTempView('functional_alltypes')
 
-    df_batting = client._session.read.csv(
+    df_batting = spark.read.csv(
         path=str(data_directory / 'batting.csv'),
         schema=pt.StructType([
             pt.StructField('playerID', pt.StringType(), True),
@@ -87,7 +88,7 @@ def spark_client_testing(data_directory):
     )
     df_batting.createOrReplaceTempView('batting')
 
-    df_awards_players = client._session.read.csv(
+    df_awards_players = spark.read.csv(
         path=str(data_directory / 'awards_players.csv'),
         schema=pt.StructType([
             pt.StructField('playerID', pt.StringType(), True),
@@ -101,16 +102,16 @@ def spark_client_testing(data_directory):
     )
     df_awards_players.createOrReplaceTempView('awards_players')
 
-    df_simple = client._session.createDataFrame([(1, 'a')], ['foo', 'bar'])
+    df_simple = spark.createDataFrame([(1, 'a')], ['foo', 'bar'])
     df_simple.createOrReplaceTempView('simple')
 
-    df_struct = client._session.createDataFrame(
+    df_struct = spark.createDataFrame(
         [((1, 2, 'a'),)],
         ['struct_col']
     )
     df_struct.createOrReplaceTempView('struct')
 
-    df_nested_types = client._session.createDataFrame(
+    df_nested_types = spark.createDataFrame(
         [
             (
                 [1, 2],
@@ -126,10 +127,22 @@ def spark_client_testing(data_directory):
     )
     df_nested_types.createOrReplaceTempView('nested_types')
 
-    df_complicated = client._session.createDataFrame(
+    df_complicated = spark.createDataFrame(
         [({(1, 3) : [[2, 4], [3, 5]]},)],
         ['map_tuple_list_of_list_of_ints']
     )
     df_complicated.createOrReplaceTempView('complicated')
 
-    return client
+    return spark
+
+
+@pytest.fixture(scope='session')
+def spark_client_testing(spark_session):
+    pytest.importorskip('pyspark')
+    return ibis.spark.connect(spark_session)
+
+
+@pytest.fixture(scope='session')
+def pyspark_client_testing(spark_session):
+    pytest.importorskip('pyspark')
+    return ibis.pyspark.connect(spark_session)
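The new fixture split decouples session setup from client setup: one session-scoped SparkSession registers every temp view, and each backend client then wraps that same session. A minimal sketch of the same wiring outside pytest (assumes pyspark is installed; the session-taking connect() functions are the ones introduced in this commit):

import ibis
from pyspark.sql import SparkSession

# One shared session plays the role of the session-scoped fixture.
spark = SparkSession.builder.getOrCreate()
spark.createDataFrame([(1, 'a')], ['foo', 'bar']).createOrReplaceTempView('simple')

# Both backend clients wrap the same session, so both see the same views.
spark_client = ibis.spark.connect(spark)
pyspark_client = ibis.pyspark.connect(spark)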
10 changes: 8 additions & 2 deletions ibis/pyspark/api.py
@@ -1,12 +1,18 @@
 from ibis.pyspark.client import PysparkClient
 
 
-def connect(**kwargs):
+def connect(session):
     """
     Create a `SparkClient` for use with Ibis. Pipes **kwargs into SparkClient,
     which pipes them into SparkContext. See documentation for SparkContext:
     https://spark.apache.org/docs/latest/api/python/_modules/pyspark/context.html#SparkContext
     """
-    client = PysparkClient(**kwargs)
+    client = PysparkClient(session)
+
+    # Spark internally stores timestamps as UTC values, and timestamp data that
+    # is brought in without a specified time zone is converted as local time to
+    # UTC with microsecond resolution.
+    # https://spark.apache.org/docs/latest/sql-pyspark-pandas-with-arrow.html#timestamp-with-time-zone-semantics
+    client._session.conf.set('spark.sql.session.timeZone', 'UTC')
 
     return client
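connect() now accepts a live session rather than building one from **kwargs, and it still pins the session time zone to UTC as a side effect. A hedged usage sketch (assumes pyspark and this commit's ibis are importable):

import ibis
from pyspark.sql import SparkSession

session = SparkSession.builder.getOrCreate()
client = ibis.pyspark.connect(session)

# The UTC setting lands on the shared session, so it is visible to
# every other user of that session, not just this client.
assert session.conf.get('spark.sql.session.timeZone') == 'UTC'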
9 changes: 5 additions & 4 deletions ibis/pyspark/client.py
@@ -1,6 +1,7 @@
-from ibis.spark.client import SparkClient
-from ibis.pyspark.operations import PysparkTable
 from ibis.pyspark.compiler import translate
+from ibis.pyspark.operations import PysparkTable
+from ibis.spark.client import SparkClient
+
 
 class PysparkClient(SparkClient):
     """
@@ -15,5 +16,5 @@ def compile(self, expr, *args, **kwargs):
         """
         return translate(expr)
 
-    def execute(self, df, params=None, limit='default', **kwargs):
-        return df.toPandas()
+    def execute(self, expr, params=None, limit='default', **kwargs):
+        return self.compile(expr).toPandas()
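execute() previously expected an already-compiled pyspark DataFrame; it now takes an ibis expression, translates it to a lazy pyspark DataFrame, and only then materializes it. A sketch of the two-step flow (client.table() is the generic ibis client method, assumed available on this backend; the 'simple' view mirrors conftest.py):

import ibis
from pyspark.sql import SparkSession

session = SparkSession.builder.getOrCreate()
session.createDataFrame([(1, 'a')], ['foo', 'bar']).createOrReplaceTempView('simple')

client = ibis.pyspark.connect(session)
expr = client.table('simple')   # ibis table expression; nothing runs yet
df = client.compile(expr)       # translated, still-lazy pyspark DataFrame
result = client.execute(expr)   # compile + toPandas() -> pandas DataFrame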
4 changes: 2 additions & 2 deletions ibis/spark/api.py
@@ -2,13 +2,13 @@
 from ibis.spark.compiler import dialect  # noqa: F401
 
 
-def connect(**kwargs):
+def connect(spark_session):
     """
     Create a `SparkClient` for use with Ibis. Pipes **kwargs into SparkClient,
     which pipes them into SparkContext. See documentation for SparkContext:
     https://spark.apache.org/docs/latest/api/python/_modules/pyspark/context.html#SparkContext
     """
-    client = SparkClient(**kwargs)
+    client = SparkClient(spark_session)
 
     # Spark internally stores timestamps as UTC values, and timestamp data that
     # is brought in without a specified time zone is converted as local time to
8 changes: 4 additions & 4 deletions ibis/spark/client.py
@@ -165,10 +165,10 @@ class SparkClient(SQLClient):
     query_class = SparkQuery
     table_class = SparkTable
 
-    def __init__(self, **kwargs):
-        self._context = ps.SparkContext(**kwargs)
-        self._session = ps.sql.SparkSession(self._context)
-        self._catalog = self._session.catalog
+    def __init__(self, session):
+        self._context = session.sparkContext
+        self._session = session
+        self._catalog = session.catalog
 
     def close(self):
         """
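SparkClient no longer owns its SparkContext; it borrows whatever session it is handed, which is what lets the test fixtures share one session across both backends. A sketch of the resulting invariants (_session and _context are the private attributes from the diff above):

import ibis
from pyspark.sql import SparkSession

session = SparkSession.builder.getOrCreate()
client = ibis.spark.connect(session)

assert client._session is session                # borrowed, not rebuilt
assert client._context is session.sparkContext   # derived from the session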
11 changes: 9 additions & 2 deletions ibis/tests/all/conftest.py
@@ -3,7 +3,7 @@
 import pytest
 
 import ibis.common as com
-from ibis.tests.backends import Backend, Spark
+from ibis.tests.backends import Backend, PySpark, Spark
 
 
 def subclasses(cls):
@@ -120,9 +120,16 @@ def pytest_pyfunc_call(pyfuncitem):
 
 
 @pytest.fixture(params=params_backend, scope='session')
-def backend(request, data_directory, spark_client_testing):
+def backend(
+    request,
+    data_directory,
+    spark_client_testing,
+    pyspark_client_testing
+):
     if request.param is Spark:
         Spark.client_testing = spark_client_testing
+    elif request.param is PySpark:
+        PySpark.client_testing = pyspark_client_testing
     return request.param(data_directory)
 
 
10 changes: 10 additions & 0 deletions ibis/tests/backends.py
@@ -533,3 +533,13 @@ def skip_if_missing_dependencies() -> None:
     @staticmethod
     def connect(data_directory):
         return Spark.client_testing
+
+
+class PySpark(Backend, RoundHalfToEven):
+    @staticmethod
+    def skip_if_missing_dependencies() -> None:
+        pass
+
+    @staticmethod
+    def connect(data_directory):
+        return PySpark.client_testing
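The backend fixture stashes the session-scoped client on the backend class, and connect() simply hands it back, so every test reuses one client per backend instead of reconnecting per test. A stripped-down sketch of that injection pattern (FakeBackend is hypothetical, not an ibis name):

class FakeBackend:
    client_testing = None  # populated by the fixture before instantiation

    @staticmethod
    def connect(data_directory):
        # data_directory is ignored: the client was prepared up front.
        return FakeBackend.client_testing

FakeBackend.client_testing = object()  # stands in for the real client
assert FakeBackend.connect(None) is FakeBackend.client_testing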
