
Add csv & parquet write functions and toPandas to experimental PySpark API #9672

Merged
31 commits, merged on Jan 5, 2024
Changes from 15 commits
31 commits
18b40dc
feat - pyspark write.parquet and write.csv
TomBurdge Nov 10, 2023
06889e7
feat - pyspark toPandas
TomBurdge Nov 10, 2023
9d66a73
feat - PySpark API complete write csv
TomBurdge Nov 13, 2023
681655f
feat - comprehensive tests for PySpark write csv
TomBurdge Nov 13, 2023
a9a3c16
fix - very slightly fix imports
TomBurdge Nov 13, 2023
847f6d7
add parquet writer
TomBurdge Nov 14, 2023
469e521
reformat - spark to csv tests
TomBurdge Nov 14, 2023
ccd5b64
feat - spark to parquet tests
TomBurdge Nov 14, 2023
fc356e9
add spark dataframe toPandas test
TomBurdge Nov 14, 2023
e56b291
fix - fewer spark sessions in spark to csv tests
TomBurdge Nov 14, 2023
6088d2b
fix - fewer spark session in to parquet tests
TomBurdge Nov 14, 2023
4efc83b
Merge remote-tracking branch 'upstream/main' into extend-pyspark
TomBurdge Nov 14, 2023
98b36a8
remove question comment for spark test
TomBurdge Nov 14, 2023
3ba4cd1
reformat - spark tests
TomBurdge Nov 14, 2023
98fcbe8
Merge branch 'main' into extend-pyspark
TomBurdge Nov 14, 2023
b849a19
fix - move pandas import inside toPandas
TomBurdge Nov 14, 2023
fcea558
fix - unused imports in test
TomBurdge Nov 14, 2023
eea76cf
fix - type annotations to mirror PySpark API
TomBurdge Nov 14, 2023
0f6cc01
fix - use tmp_path feature for test csv
TomBurdge Nov 14, 2023
ff1e650
fix - use pytest tmp_path fixture for
TomBurdge Nov 14, 2023
4ed5362
fix - amend parquet temp files to .parquet
TomBurdge Nov 14, 2023
c394de3
fix - add DataFrame fixtures to spark write cv
TomBurdge Nov 14, 2023
d6349ea
reformatting from makefile
TomBurdge Nov 14, 2023
2c5e7b6
fix - PandasDataFrame typehint
TomBurdge Nov 14, 2023
4f65e10
fix - remove future import for type hints
TomBurdge Nov 14, 2023
02acfdb
Merge branch 'main' into extend-pyspark
TomBurdge Nov 15, 2023
3530a09
fix - put type hint in quotes
TomBurdge Nov 15, 2023
d315c7a
fix - mirror type hints for DataFrameWriter.csv
TomBurdge Nov 21, 2023
b14ef99
fix - remove note to self
TomBurdge Nov 21, 2023
58ff8f1
Merge branch 'main' into extend-pyspark
TomBurdge Nov 21, 2023
212e157
Merge branch 'main' into extend-pyspark
Mytherin Dec 12, 2023
4 changes: 4 additions & 0 deletions tools/pythonpkg/duckdb/experimental/spark/sql/dataframe.py
@@ -10,6 +10,7 @@
from .column import Column
import duckdb
from functools import reduce
from pandas import DataFrame as PandasDataFrame

if TYPE_CHECKING:
from .session import SparkSession
@@ -27,6 +28,9 @@ def __init__(self, relation: duckdb.DuckDBPyRelation, session: "SparkSession"):

def show(self, **kwargs) -> None:
self.relation.show()

def toPandas(self) -> PandasDataFrame:
Contributor:

I don't think we want a dependency on pandas. What we do in other places is set "PandasDataFrame" as the return type annotation and do the actual import inside the method, so the import only fails when toPandas is actually used and pandas is not installed.
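As an aside, a minimal sketch of the pattern described above, assuming the method body stays as in this diff (this is not the exact duckdb code): the annotation only names "PandasDataFrame" as a string, and the real import happens inside toPandas.

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; no pandas requirement at runtime.
    from pandas import DataFrame as PandasDataFrame


class DataFrame:
    def __init__(self, relation):
        self.relation = relation

    def toPandas(self) -> "PandasDataFrame":
        # Deferred import: an ImportError only surfaces if toPandas() is called
        # and pandas is not installed.
        import pandas  # noqa: F401

        # DuckDBPyRelation.df() materializes the relation as a pandas DataFrame.
        return self.relation.df()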

Contributor Author:

Good point. Fixed here: b849a19fc299b6e87a0069e1aacb82630a6671a3

Contributor:

Thanks, though I think the type annotation still needs to be in quotes?
Otherwise it references a type that doesn't exist, and I can't imagine the type checker would be happy with that.

Contributor Author:

Not fixed properly yet^

Contributor Author:

Ok, I have tried to fix it in 2c5e7b690e61c22c4eec18db15706111fce858f7.
I have tried to make it mirror how the type checking works in session.py.

return self.relation.df()

def createOrReplaceTempView(self, name: str) -> None:
"""Creates or replaces a local temporary view with this :class:`DataFrame`.
41 changes: 41 additions & 0 deletions tools/pythonpkg/duckdb/experimental/spark/sql/readwriter.py
@@ -20,6 +20,47 @@ def saveAsTable(self, table_name: str) -> None:
relation = self.dataframe.relation
relation.create(table_name)

def parquet(self, path: str, mode: Optional[str] = None, partitionBy: Union[str, List[str], None] = None, compression: Optional[str] = None) -> None:
relation = self.dataframe.relation
if mode:
raise NotImplementedError
if partitionBy:
raise NotImplementedError

relation.write_parquet(path, compression=compression)

def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
Contributor:

Can we please use the exact same prototype as PySpark?
https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameWriter.csv.html

This is missing type annotations for the parameters.
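For reference, a sketch of what a fully annotated prototype could look like if it mirrors the linked PySpark reference. The parameter names and order follow the method below in this diff; the Optional/Union annotations are assumptions based on the PySpark docs, not taken from this PR.

from typing import Optional, Union


class DataFrameWriter:
    # Signature sketch only; behavior is omitted.
    def csv(
        self,
        path: str,
        mode: Optional[str] = None,
        compression: Optional[str] = None,
        sep: Optional[str] = None,
        quote: Optional[str] = None,
        escape: Optional[str] = None,
        header: Optional[Union[bool, str]] = None,
        nullValue: Optional[str] = None,
        escapeQuotes: Optional[Union[bool, str]] = None,
        quoteAll: Optional[Union[bool, str]] = None,
        dateFormat: Optional[str] = None,
        timestampFormat: Optional[str] = None,
        ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None,
        ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None,
        charToEscapeQuoteEscaping: Optional[str] = None,
        encoding: Optional[str] = None,
        emptyValue: Optional[str] = None,
        lineSep: Optional[str] = None,
    ) -> None:
        ...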

Contributor Author:

eea76cf5c174b09d795a4e22e9e791698b8eaaba

header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None,
timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None,
charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None, lineSep=None):
if mode not in (None, "overwrite"):
raise NotImplementedError
if escapeQuotes:
raise NotImplementedError
if ignoreLeadingWhiteSpace:
raise NotImplementedError
if ignoreTrailingWhiteSpace:
raise NotImplementedError
if charToEscapeQuoteEscaping:
raise NotImplementedError
if emptyValue:
raise NotImplementedError
if lineSep:
raise NotImplementedError
relation = self.dataframe.relation
relation.write_csv(path,
sep=sep,
na_rep=nullValue,
quotechar=quote,
compression=compression,
escapechar=escape,
header=header if isinstance(header, bool) else header == "True",
encoding=encoding,
quoting=quoteAll, # ~ check this
Contributor:

Seems you left a TODO in there?

Contributor Author:

Yes - I already checked it and was happy with it. Thanks for noticing; I have removed the comment in the latest version of this branch.

date_format=dateFormat,
timestamp_format = timestampFormat
)


class DataFrameReader:
def __init__(self, session: "SparkSession"):
@@ -18,6 +18,7 @@
from duckdb.experimental.spark.sql.functions import col, struct, when
import duckdb
import re
from pandas.testing import assert_frame_equal


@pytest.fixture
@@ -48,3 +49,8 @@ def test_pd_conversion_schema(self, spark, pandasDF):
res = sparkDF.collect()
expected = "[Row(First Name='Scott', Age=50), Row(First Name='Jeff', Age=45), Row(First Name='Thomas', Age=54), Row(First Name='Ann', Age=34)]"
assert str(res) == expected

def test_spark_to_pandas_dataframe(self, spark, pandasDF):
sparkDF = spark.createDataFrame(pandasDF)
res = sparkDF.toPandas()
assert_frame_equal(res, pandasDF)
199 changes: 199 additions & 0 deletions tools/pythonpkg/tests/fast/spark/test_spark_to_csv.py
@@ -0,0 +1,199 @@
import pytest
import tempfile

import os

_ = pytest.importorskip("duckdb.experimental.spark")

from duckdb.experimental.spark.sql import SparkSession as session
from duckdb import connect, InvalidInputException, read_csv
from conftest import NumpyPandas, ArrowPandas
import pandas._testing as tm
import datetime
import csv


@pytest.fixture
def df(spark):
simpleData = (
("Java", 4000, 5),
("Python", 4600, 10),
("Scala", 4100, 15),
("Scala", 4500, 15),
("PHP", 3000, 20),
)
columns = ["CourseName", "fee", "discount"]
dataframe = spark.createDataFrame(data=simpleData, schema=columns)
yield dataframe


class TestSparkToCSV(object):
@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_basic_to_csv(self, pandas, spark):
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
Contributor:

pytest has a tmp_path fixture, which should be used instead of using tempfile manually (example from another test):

    def test_read_csv_glob(self, tmp_path, create_temp_csv):
        file1_path, file2_path = create_temp_csv

        # Use the temporary file paths to read CSV files
        con = duckdb.connect()
        rel = con.read_csv(f'{tmp_path}/file*.csv')
        res = con.sql("select * from rel order by all").fetchall()
        assert res == [(1,), (2,), (3,), (4,), (5,), (6,)]

Contributor Author:

Makes sense, fixed in ff1e65087754405fc15f632c20d2c5bb696e4fcb and 0f6cc01f0b398f236eaf50bb6be2912efde70f89


pandas_df = pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]})
Tishj (Contributor), Nov 14, 2023:

This is a bit of a nitpick, but there's a good bit of repetition in the dataframes used in these tests; with a fixture we can avoid duplicating them and make it more explicit where they are repeated.
3+ repetitions (where the pieces of code are standalone) are usually a good sign it should be a fixture 👍

test_spark_join.py has a good example of this

@pytest.fixture
def dataframe_a(spark):
    emp = [
        (1, "Smith", -1, "2018", "10", "M", 3000),
        (2, "Rose", 1, "2010", "20", "M", 4000),
        (3, "Williams", 1, "2010", "10", "M", 1000),
        (4, "Jones", 2, "2005", "10", "F", 2000),
        (5, "Brown", 2, "2010", "40", "", -1),
        (6, "Brown", 2, "2010", "50", "", -1),
    ]
    empColumns = ["emp_id", "name", "superior_emp_id", "year_joined", "emp_dept_id", "gender", "salary"]
    dataframe = spark.createDataFrame(data=emp, schema=empColumns)
    yield dataframe


@pytest.fixture
def dataframe_b(spark):
    dept = [("Finance", 10), ("Marketing", 20), ("Sales", 30), ("IT", 40)]
    deptColumns = ["dept_name", "dept_id"]
    dataframe = spark.createDataFrame(data=dept, schema=deptColumns)
    yield dataframe


class TestDataFrameJoin(object):
    def test_inner_join(self, dataframe_a, dataframe_b):
        df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, "inner")
        ...

Contributor Author:

Thanks for the suggestion. Wasn't sure how to parametrize fixtures, but figured it out: c394de3de3ed82470f75bd7dc98a392024223995
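For context, one common pytest way to achieve this (a sketch of the general technique, not necessarily what commit c394de3 does) is to give the fixture itself params, so every fixture and test depending on it runs once per pandas backend:

import pytest
from conftest import NumpyPandas, ArrowPandas  # assumption: same conftest helpers used by these tests


@pytest.fixture(params=[NumpyPandas(), ArrowPandas()])
def pandas(request):
    # Each test that depends on this fixture is collected once per backend.
    return request.param


@pytest.fixture
def pandas_df(pandas):
    # Shared DataFrame reused across the to_csv tests, built with the selected backend.
    return pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]})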

df = spark.createDataFrame(pandas_df)

df.write.csv(temp_file_name, header=False)

csv_rel = spark.read.csv(temp_file_name, header=False)

assert df.collect() == csv_rel.collect()

@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_to_csv_sep(self, pandas, spark):
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
pandas_df = pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]})

df = spark.createDataFrame(pandas_df)

df.write.csv(temp_file_name, sep=',', header=False)

csv_rel = spark.read.csv(temp_file_name, header=False, sep=',')
assert df.collect() == csv_rel.collect()

@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_to_csv_na_rep(self, pandas, spark):
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
pandas_df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]})

df = spark.createDataFrame(pandas_df)

df.write.csv(temp_file_name, nullValue="test", header=False)

csv_rel = spark.read.csv(temp_file_name, nullValue="test")
assert df.collect() == csv_rel.collect()

@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_to_csv_header(self, pandas, spark):
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
pandas_df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]})

df = spark.createDataFrame(pandas_df)

df.write.csv(temp_file_name, header=True)

csv_rel = spark.read.csv(temp_file_name, header=True)
assert df.collect() == csv_rel.collect()

@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_to_csv_quotechar(self, pandas, spark):
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))

pandas_df = pandas.DataFrame({'a': ["\'a,b,c\'", None, "hello", "bye"], 'b': [45, 234, 234, 2]})

df = spark.createDataFrame(pandas_df)

df.write.csv(temp_file_name, quote='\'', sep=',', header=False)

csv_rel = spark.read.csv(temp_file_name, sep=',', quote='\'')
assert df.collect() == csv_rel.collect()

@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_to_csv_escapechar(self, pandas, spark):
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
pandas_df = pandas.DataFrame(
{
"c_bool": [True, False],
"c_float": [1.0, 3.2],
"c_int": [42, None],
"c_string": ["a", "b,c"],
}
)

df = spark.createDataFrame(pandas_df)

df.write.csv(temp_file_name, header=True, quote='"', escape='!')
csv_rel = spark.read.csv(temp_file_name, quote='"', escape='!', header=True)
assert df.collect() == csv_rel.collect()

@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_to_csv_date_format(self, pandas, spark):
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
pandas_df = pandas.DataFrame(tm.getTimeSeriesData())
dt_index = pandas_df.index
pandas_df = pandas.DataFrame({"A": dt_index, "B": dt_index.shift(1)}, index=dt_index)

df = spark.createDataFrame(pandas_df)

df.write.csv(temp_file_name, dateFormat="%Y%m%d", header=False)

csv_rel = spark.read.csv(temp_file_name, dateFormat="%Y%m%d")

assert df.collect() == csv_rel.collect()

@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_to_csv_timestamp_format(self, pandas, spark):
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
data = [datetime.time(hour=23, minute=1, second=34, microsecond=234345)]
pandas_df = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')})

df = spark.createDataFrame(pandas_df)

df.write.csv(temp_file_name, timestampFormat='%m/%d/%Y', header=False)

csv_rel = spark.read.csv(temp_file_name, timestampFormat='%m/%d/%Y')

assert df.collect() == csv_rel.collect()

@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_to_csv_quoting_off(self, pandas, spark):
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
pandas_df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']})

df = spark.createDataFrame(pandas_df)
df.write.csv(temp_file_name, quoteAll=None, header=False)

csv_rel = spark.read.csv(temp_file_name)
assert df.collect() == csv_rel.collect()

@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_to_csv_quoting_on(self, pandas, spark):
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
pandas_df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']})
df = spark.createDataFrame(pandas_df)
df.write.csv(temp_file_name, quoteAll="force", header=False)

csv_rel = spark.read.csv(temp_file_name)
assert df.collect() == csv_rel.collect()

@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_to_csv_quoting_quote_all(self, pandas, spark):
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
pandas_df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']})
df = spark.createDataFrame(pandas_df)
df.write.csv(temp_file_name, quoteAll=csv.QUOTE_ALL, header=False)

csv_rel = spark.read.csv(temp_file_name)
assert df.collect() == csv_rel.collect()

@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_to_csv_encoding_incorrect(self, pandas, spark):
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
pandas_df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']})
df = spark.createDataFrame(pandas_df)
with pytest.raises(
InvalidInputException, match="Invalid Input Error: The only supported encoding option is 'UTF8"
):
df.write.csv(temp_file_name, encoding="nope", header=False)

@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_to_csv_encoding_correct(self, pandas, spark):
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
pandas_df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']})
df = spark.createDataFrame(pandas_df)
df.write.csv(temp_file_name, encoding="UTF-8", header=False)
csv_rel = spark.read.csv(temp_file_name)
assert df.collect() == csv_rel.collect()

@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_compression_gzip(self, pandas, spark):
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
pandas_df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']})
df = spark.createDataFrame(pandas_df)
df.write.csv(temp_file_name, compression="gzip", header=False)

# slightly convoluted - pyspark .read.csv does not take a compression argument
csv_rel = spark.createDataFrame(read_csv(temp_file_name, compression="gzip").df())
assert df.collect() == csv_rel.collect()
43 changes: 43 additions & 0 deletions tools/pythonpkg/tests/fast/spark/test_spark_to_parquet.py
@@ -0,0 +1,43 @@
import pytest
import tempfile

import os

_ = pytest.importorskip("duckdb.experimental.spark")

from duckdb.experimental.spark.sql import SparkSession as session
Tishj (Contributor), Nov 14, 2023:

This is unused, no?
Same goes for connect.

Contributor Author:

Ah yes, something must be overriding my normal VSCode settings, as that would normally point this out: fcea55801ce6e13e7962be7f424236f9e7fe1154

from duckdb import connect


@pytest.fixture
def df(spark):
simpleData = (
("Java", 4000, 5),
("Python", 4600, 10),
("Scala", 4100, 15),
("Scala", 4500, 15),
("PHP", 3000, 20),
)
columns = ["CourseName", "fee", "discount"]
dataframe = spark.createDataFrame(data=simpleData, schema=columns)
yield dataframe


class TestSparkToParquet(object):
def test_basic_to_parquet(self, df, spark):
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))

df.write.parquet(temp_file_name)

csv_rel = spark.read.parquet(temp_file_name)

assert df.collect() == csv_rel.collect()

def test_compressed_to_parquet(self, df, spark):
temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))

df.write.parquet(temp_file_name, compression="ZSTD")

csv_rel = spark.read.parquet(temp_file_name)

assert df.collect() == csv_rel.collect()