feat(pyspark): read/write delta tables
lostmygithubaccount authored and jcrist committed Jul 10, 2023
1 parent cb0abfc commit d403187
Showing 2 changed files with 65 additions and 7 deletions.
56 changes: 56 additions & 0 deletions ibis/backends/pyspark/__init__.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import os
 from pathlib import Path
 from typing import TYPE_CHECKING, Any

@@ -587,6 +588,37 @@ def _clean_up_cached_table(self, op):
         t.unpersist()
         assert not t.is_cached

+    def read_delta(
+        self,
+        source: str | Path,
+        table_name: str | None = None,
+        **kwargs: Any,
+    ) -> ir.Table:
+        """Register a Delta Lake table as a table in the current database.
+
+        Parameters
+        ----------
+        source
+            The path to the Delta Lake table.
+        table_name
+            An optional name to use for the created table. This defaults to
+            a sequentially generated name.
+        kwargs
+            Additional keyword arguments passed to PySpark.
+            https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.load.html
+
+        Returns
+        -------
+        ir.Table
+            The just-registered table
+        """
+        source = util.normalize_filename(source)
+        spark_df = self._session.read.format("delta").load(source, **kwargs)
+        table_name = table_name or util.gen_name("read_delta")
+
+        spark_df.createOrReplaceTempView(table_name)
+        return self.table(table_name)
+
     def read_parquet(
         self,
         source: str | Path,
@@ -707,3 +739,27 @@ def _register_failure(self):

     def _to_sql(self, expr: ir.Expr, **kwargs) -> str:
         raise NotImplementedError(f"Backend '{self.name}' backend doesn't support SQL")
+
+    @util.experimental
+    def to_delta(
+        self,
+        expr: ir.Table,
+        path: str | Path,
+        **kwargs: Any,
+    ) -> None:
+        """Write the results of executing the given expression to a Delta Lake table.
+
+        This method is eager and will execute the associated expression
+        immediately.
+
+        Parameters
+        ----------
+        expr
+            The ibis expression to execute and persist to a Delta Lake table.
+        path
+            The data source. A string or Path to the Delta Lake table.
+        **kwargs
+            PySpark Delta Lake table write arguments.
+            https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.DataFrameWriter.save.html
+        """
+        expr.compile().write.format("delta").save(os.fspath(path), **kwargs)
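
Together these methods round-trip Delta Lake data through the PySpark backend: read_delta loads a Delta table into a Spark DataFrame and registers it as a temporary view so it can be addressed as an ibis table, while to_delta compiles the expression to a Spark DataFrame and eagerly writes it out in the delta format. A minimal usage sketch; the session, paths, and column names here are hypothetical illustrations, not part of this commit:

```python
import ibis

# assumes `session` is an existing SparkSession with the
# Delta Lake extensions already enabled
con = ibis.pyspark.connect(session=session)

t = con.read_delta("/tmp/penguins.delta")  # hypothetical source path
expr = t.filter(t.species == "Adelie")     # hypothetical column
# extra kwargs are forwarded to DataFrameWriter.save
con.to_delta(expr, "/tmp/adelie.delta", mode="overwrite")
```
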
16 changes: 9 additions & 7 deletions ibis/backends/tests/test_export.py
@@ -217,7 +217,7 @@ def test_table_to_parquet(tmp_path, backend, awards_players):

     df = pd.read_parquet(outparquet)

-    backend.assert_frame_equal(awards_players.execute(), df)
+    backend.assert_frame_equal(awards_players.to_pandas(), df)


 @pytest.mark.notimpl(
@@ -256,7 +256,7 @@ def test_roundtrip_partitioned_parquet(tmp_path, con, backend, awards_players):
     # avoid type comparison to appease duckdb: as of 0.8.0 it returns large_string
     assert reingest.schema().names == awards_players.schema().names

-    backend.assert_frame_equal(awards_players.execute(), awards_players.execute())
+    backend.assert_frame_equal(awards_players.to_pandas(), awards_players.to_pandas())


@pytest.mark.notimpl(["dask", "impala", "pyspark"])
@@ -270,7 +270,7 @@ def test_table_to_csv(tmp_path, backend, awards_players):

     df = pd.read_csv(outcsv, dtype=awards_players.schema().to_pandas())

-    backend.assert_frame_equal(awards_players.execute(), df)
+    backend.assert_frame_equal(awards_players.to_pandas(), df)


 @pytest.mark.parametrize(
@@ -327,7 +327,6 @@ def test_to_pyarrow_decimal(backend, dtype, pyarrow_dtype):
"mysql",
"oracle",
"postgres",
"pyspark",
"snowflake",
"sqlite",
"trino",
@@ -343,10 +342,13 @@ def test_to_pyarrow_decimal(backend, dtype, pyarrow_dtype):
     reason="arrow type conversion fails in `to_delta` call",
 )
 def test_roundtrip_delta(con, alltypes, tmp_path, monkeypatch):
-    pytest.importorskip("deltalake")
+    if con.name == "pyspark":
+        pytest.importorskip("delta")
+    else:
+        pytest.importorskip("deltalake")

     t = alltypes.head()
-    expected = t.execute()
+    expected = t.to_pandas()
     path = tmp_path / "test.delta"
     t.to_delta(path)

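The pyspark branch of test_roundtrip_delta skips unless delta (the delta-spark package) is importable, because Spark needs its own Delta Lake bindings rather than the standalone deltalake library the other backends use. For reference, a minimal sketch of building a Delta-enabled SparkSession with that package, following the Delta Lake docs; this setup code is not part of the commit:

```python
from delta import configure_spark_with_delta_pip
from pyspark.sql import SparkSession

builder = (
    SparkSession.builder.appName("ibis-delta")
    # standard Delta Lake session settings from the delta-io docs
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config(
        "spark.sql.catalog.spark_catalog",
        "org.apache.spark.sql.delta.catalog.DeltaCatalog",
    )
)
# adds the matching delta-spark jars to the session being built
session = configure_spark_with_delta_pip(builder).getOrCreate()
```
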
@@ -402,7 +404,7 @@ def test_dataframe_protocol(alltypes):
pytest.importorskip("pyarrow", minversion="12")
output = alltypes.__dataframe__()
assert list(output.column_names()) == alltypes.columns
assert alltypes.count().execute() == output.num_rows()
assert alltypes.count().to_pandas() == output.num_rows()


@pytest.mark.notimpl(["dask", "druid"])
