refactor(pyspark): remove custom implementation of cursors (#9161)
Co-authored-by: Chloe He <chloe@chloe-mac.lan>
chloeh13q and Chloe He committed May 10, 2024
1 parent 1258575 commit 9caa552
Showing 4 changed files with 45 additions and 110 deletions.
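
Summary of the change: on both backends, raw_sql now hands back the query object native to the underlying library instead of a homegrown DB-API-style cursor wrapper (BigQueryCursor / _PySparkCursor). A minimal before/after sketch — bq_con and spark_con are illustrative stand-ins for connected ibis backends, not names from this commit:

    # before: raw_sql returned a wrapper exposing fetchall()
    # rows = con.raw_sql("SELECT 1").fetchall()

    # after (BigQuery): a google.cloud.bigquery QueryJob comes back
    job = bq_con.raw_sql("SELECT 1")
    rows = [row.values() for row in job.result()]

    # after (PySpark): a pyspark.sql.DataFrame comes back
    df = spark_con.raw_sql("SELECT 1 AS x")
    rows = df.collect()  # list of pyspark.sql.Row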
31 changes: 15 additions & 16 deletions ibis/backends/bigquery/__init__.py
@@ -27,7 +27,6 @@
 from ibis import util
 from ibis.backends import CanCreateDatabase, CanCreateSchema
 from ibis.backends.bigquery.client import (
-    BigQueryCursor,
     bigquery_param,
     parse_project_and_dataset,
     rename_partitioned_column,
@@ -628,7 +627,7 @@ def _execute(self, stmt, query_parameters=None):
             stmt, job_config=job_config, project=self.billing_project
         )
         query.result()  # blocks until finished
-        return BigQueryCursor(query)
+        return query
 
     def _to_sqlglot(
         self,
@@ -738,9 +737,9 @@ def execute(self, expr, params=None, limit="default", **kwargs):
 
         sql = self.compile(expr, limit=limit, params=params, **kwargs)
         self._log(sql)
-        cursor = self.raw_sql(sql, params=params, **kwargs)
+        query = self.raw_sql(sql, params=params, **kwargs)
 
-        result = self.fetch_from_cursor(cursor, expr.as_table().schema())
+        result = self.fetch_from_query(query, expr.as_table().schema())
 
         return expr.__pandas_result__(result)
 
@@ -782,27 +781,27 @@ def insert(
             overwrite=overwrite,
         )
 
-    def fetch_from_cursor(self, cursor, schema):
+    def fetch_from_query(self, query, schema):
         from ibis.backends.bigquery.converter import BigQueryPandasData
 
-        arrow_t = self._cursor_to_arrow(cursor)
+        arrow_t = self._query_to_arrow(query)
         df = arrow_t.to_pandas(timestamp_as_object=True)
         return BigQueryPandasData.convert_table(df, schema)
 
-    def _cursor_to_arrow(
+    def _query_to_arrow(
         self,
-        cursor,
+        query,
         *,
-        method: Callable[[RowIterator], pa.Table | Iterable[pa.RecordBatch]]
-        | None = None,
+        method: (
+            Callable[[RowIterator], pa.Table | Iterable[pa.RecordBatch]] | None
+        ) = None,
         chunk_size: int | None = None,
     ):
         if method is None:
             method = lambda result: result.to_arrow(
                 progress_bar_type=None,
                 bqstorage_client=self.storage_client,
             )
-        query = cursor.query
         query_result = query.result(page_size=chunk_size)
         # workaround potentially not having the ability to create read sessions
         # in the dataset project
@@ -826,8 +825,8 @@ def to_pyarrow(
         self._register_in_memory_tables(expr)
         sql = self.compile(expr, limit=limit, params=params, **kwargs)
         self._log(sql)
-        cursor = self.raw_sql(sql, params=params, **kwargs)
-        table = self._cursor_to_arrow(cursor)
+        query = self.raw_sql(sql, params=params, **kwargs)
+        table = self._query_to_arrow(query)
         return expr.__pyarrow_result__(table)
 
     def to_pyarrow_batches(
@@ -846,9 +845,9 @@ def to_pyarrow_batches(
         self._register_in_memory_tables(expr)
         sql = self.compile(expr, limit=limit, params=params, **kwargs)
         self._log(sql)
-        cursor = self.raw_sql(sql, params=params, **kwargs)
-        batch_iter = self._cursor_to_arrow(
-            cursor,
+        query = self.raw_sql(sql, params=params, **kwargs)
+        batch_iter = self._query_to_arrow(
+            query,
             method=lambda result: result.to_arrow_iterable(
                 bqstorage_client=self.storage_client
             ),
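
For context, `_query_to_arrow` above is a thin layer over the google-cloud-bigquery API; a sketch of the equivalent direct calls, assuming ambient credentials and a recent google-cloud-bigquery release (the BigQuery Storage client is an optional accelerator and is omitted here):

    from google.cloud import bigquery

    client = bigquery.Client()  # assumes a default project and credentials
    job = client.query("SELECT 1 AS x")
    result = job.result(page_size=1000)  # RowIterator; blocks until the job finishes

    # whole-table path (what fetch_from_query/to_pyarrow use by default)
    table = result.to_arrow(progress_bar_type=None)

    # streaming alternative (what to_pyarrow_batches selects via the `method` hook);
    # use a fresh RowIterator, since an iterator is consumed once
    for batch in job.result(page_size=1000).to_arrow_iterable():
        print(batch.num_rows)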
36 changes: 0 additions & 36 deletions ibis/backends/bigquery/client.py
@@ -29,42 +29,6 @@ def schema_from_bigquery_table(table):
     return schema
 
 
-class BigQueryCursor:
-    """BigQuery cursor.
-
-    This allows the BigQuery client to reuse machinery in
-    :file:`ibis/client.py`.
-    """
-
-    def __init__(self, query):
-        """Construct a BigQueryCursor with query `query`."""
-        self.query = query
-
-    def fetchall(self):
-        """Fetch all rows."""
-        result = self.query.result()
-        return [row.values() for row in result]
-
-    @property
-    def columns(self):
-        """Return the columns of the result set."""
-        result = self.query.result()
-        return [field.name for field in result.schema]
-
-    @property
-    def description(self):
-        """Get the fields of the result set's schema."""
-        result = self.query.result()
-        return list(result.schema)
-
-    def __enter__(self):
-        """No-op for compatibility."""
-        return self
-
-    def __exit__(self, *_):
-        """No-op for compatibility."""
-
-
 @functools.singledispatch
 def bigquery_param(dtype, value, name):
     raise NotImplementedError(dtype)
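
The deleted wrapper duplicated what the client library already exposes: QueryJob.result() returns a RowIterator whose rows support .values() and whose .schema is a list of SchemaField objects. One-to-one replacements, given some `job` returned by raw_sql (a sketch, not code from this commit):

    rows = [row.values() for row in job.result()]    # was cursor.fetchall()
    columns = [f.name for f in job.result().schema]  # was cursor.columns
    description = list(job.result().schema)          # was cursor.description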
3 changes: 2 additions & 1 deletion ibis/backends/bigquery/tests/system/test_client.py
@@ -186,7 +186,8 @@ def test_repr_struct_of_array_of_struct():
 
 
 def test_raw_sql(con):
-    assert con.raw_sql("SELECT 1").fetchall() == [(1,)]
+    result = con.raw_sql("SELECT 1").result()
+    assert [row.values() for row in result] == [(1,)]
 
 
 def test_parted_column_rename(parted_alltypes):
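
A hypothetical PySpark counterpart to this test under the same refactor — not part of the commit, and `con` here would be a PySpark backend fixture:

    def test_raw_sql_pyspark(con):
        # raw_sql now returns a pyspark.sql.DataFrame
        rows = con.raw_sql("SELECT 1 AS x").collect()
        assert [tuple(row) for row in rows] == [(1,)]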
85 changes: 28 additions & 57 deletions ibis/backends/pyspark/__init__.py
@@ -10,7 +10,7 @@
 import sqlglot.expressions as sge
 from packaging.version import parse as vparse
 from pyspark import SparkConf
-from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql import SparkSession
 from pyspark.sql.functions import PandasUDFType, pandas_udf
 from pyspark.sql.types import BooleanType, DoubleType, LongType, StringType
 
@@ -85,51 +85,6 @@ def nullify_type_mismatched_value(raw):
     return unwrap
 
 
-class _PySparkCursor:
-    """Spark cursor.
-
-    This allows the Spark client to reuse machinery in
-    `ibis/backends/base/sql/client.py`.
-    """
-
-    def __init__(self, query: DataFrame) -> None:
-        """Construct a cursor with query `query`.
-
-        Parameters
-        ----------
-        query
-            PySpark query
-        """
-        self.query = query
-
-    def fetchall(self):
-        """Fetch all rows."""
-        result = self.query.collect()  # blocks until finished
-        return result
-
-    def fetchmany(self, nrows: int):
-        raise NotImplementedError()
-
-    @property
-    def columns(self):
-        """Return the columns of the result set."""
-        return self.query.columns
-
-    @property
-    def description(self):
-        """Get the fields of the result set's schema."""
-        return self.query.schema
-
-    def __enter__(self):
-        # For compatibility when constructed from Query.execute()
-        """No-op for compatibility."""
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        """No-op for compatibility."""
-
-
 class Backend(SQLBackend, CanListCatalog, CanCreateDatabase):
     name = "pyspark"
     compiler = PySparkCompiler()
@@ -210,8 +165,8 @@ def disconnect(self) -> None:
         self._session.stop()
 
     def _get_schema_using_query(self, query: str) -> sch.Schema:
-        cursor = self.raw_sql(query)
-        struct_dtype = PySparkType.to_ibis(cursor.query.schema)
+        df = self.raw_sql(query)
+        struct_dtype = PySparkType.to_ibis(df.schema)
         return sch.Schema(struct_dtype)
 
     @property
@@ -354,18 +309,34 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
         df = self._session.createDataFrame(data=op.data.to_frame(), schema=schema)
         df.createOrReplaceTempView(op.name)
 
-    def _fetch_from_cursor(self, cursor, schema):
-        df = cursor.query.toPandas()  # blocks until finished
-        return PySparkPandasData.convert_table(df, schema)
 
-    def _safe_raw_sql(self, query: str) -> _PySparkCursor:
-        return self.raw_sql(query)
+    @contextlib.contextmanager
+    def _safe_raw_sql(self, query: str) -> Any:
+        yield self.raw_sql(query)
 
-    def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> _PySparkCursor:
+    def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
         with contextlib.suppress(AttributeError):
             query = query.sql(dialect=self.dialect)
-        query = self._session.sql(query)
-        return _PySparkCursor(query)
+        return self._session.sql(query, **kwargs)
+
+    def execute(
+        self,
+        expr: ir.Expr,
+        params: Mapping | None = None,
+        limit: str | None = "default",
+        **kwargs: Any,
+    ) -> Any:
+        """Execute an expression."""
+
+        self._run_pre_execute_hooks(expr)
+        table = expr.as_table()
+        sql = self.compile(table, params=params, limit=limit, **kwargs)
+
+        schema = table.schema()
+
+        with self._safe_raw_sql(sql) as query:
+            df = query.toPandas()  # blocks until finished
+            result = PySparkPandasData.convert_table(df, schema)
+        return expr.__pandas_result__(result)
 
     def create_database(
         self,
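
The contextmanager shim keeps _safe_raw_sql usable in `with` blocks even though a Spark DataFrame needs no cleanup, so shared call sites stay uniform across backends. A standalone sketch of the same pattern — the session setup is illustrative, not from this commit:

    import contextlib

    from pyspark.sql import SparkSession

    session = SparkSession.builder.master("local[1]").getOrCreate()

    @contextlib.contextmanager
    def safe_raw_sql(query: str):
        # nothing to close: the context manager exists only so callers
        # can uniformly write `with safe_raw_sql(...) as df:`
        yield session.sql(query)

    with safe_raw_sql("SELECT 1 AS x") as df:
        print(df.toPandas())  # blocks until the query finishes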
