Add csv & parquet write functions and toPandas to experimental PySpark API #9672
The first diff adds `parquet` and `csv` write methods to the experimental `DataFrameWriter` (hunk `@@ -20,6 +20,47 @@`):

```python
    def saveAsTable(self, table_name: str) -> None:
        relation = self.dataframe.relation
        relation.create(table_name)

    def parquet(self, path: str, mode: Optional[str] = None, partitionBy: Union[str, List[str], None] = None, compression: Optional[str] = None) -> None:
        relation = self.dataframe.relation
        if mode:
            raise NotImplementedError
        if partitionBy:
            raise NotImplementedError

        relation.write_parquet(path, compression=compression)

    def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
```
Reviewer: Can we please use the exact prototype as PySpark? This is missing type annotations for the parameters.
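For reference, the prototype being requested annotates every parameter. A sketch of what that signature could look like, modeled on PySpark's `DataFrameWriter.csv` stubs; the exact `Optional`/`Union` choices here are an assumption based on PySpark's public API, not the code as committed:

```python
from typing import Optional, Union

def csv(
    self,
    path: str,
    mode: Optional[str] = None,
    compression: Optional[str] = None,
    sep: Optional[str] = None,
    quote: Optional[str] = None,
    escape: Optional[str] = None,
    header: Optional[Union[bool, str]] = None,
    nullValue: Optional[str] = None,
    escapeQuotes: Optional[Union[bool, str]] = None,
    quoteAll: Optional[Union[bool, str]] = None,
    dateFormat: Optional[str] = None,
    timestampFormat: Optional[str] = None,
    ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None,
    ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None,
    charToEscapeQuoteEscaping: Optional[str] = None,
    encoding: Optional[str] = None,
    emptyValue: Optional[str] = None,
    lineSep: Optional[str] = None,
) -> None:
    ...
```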
The `csv` method continues:

```python
            header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None,
            timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None,
            charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None, lineSep=None):
        if mode not in (None, "overwrite"):
            raise NotImplementedError
        if escapeQuotes:
            raise NotImplementedError
        if ignoreLeadingWhiteSpace:
            raise NotImplementedError
        if ignoreTrailingWhiteSpace:
            raise NotImplementedError
        if charToEscapeQuoteEscaping:
            raise NotImplementedError
        if emptyValue:
            raise NotImplementedError
        if lineSep:
            raise NotImplementedError
        relation = self.dataframe.relation
        relation.write_csv(
            path,
            sep=sep,
            na_rep=nullValue,
            quotechar=quote,
            compression=compression,
            escapechar=escape,
            header=header if isinstance(header, bool) else header == "True",
            encoding=encoding,
            quoting=quoteAll,  # ~ check this
```
Reviewer: Seems you left a TODO in there?

Author: Yes - already checked it and was happy. Thanks for noticing; I have removed this comment in the latest of this branch.
```python
            date_format=dateFormat,
            timestamp_format=timestampFormat,
        )


class DataFrameReader:
    def __init__(self, session: "SparkSession"):
```
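Taken together, the new writer methods mirror PySpark's call style. A minimal usage sketch, with the output paths as placeholders and `df` an experimental-Spark DataFrame as in the tests below:

```python
# Hypothetical round trip through the new writer API; paths are placeholders.
df.write.csv("out.csv", sep=",", header=True, nullValue="NULL")
df.write.parquet("out.parquet", compression="ZSTD")
```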
The second diff adds a new test file for the CSV writer (hunk `@@ -0,0 +1,199 @@`):

```python
import pytest
import tempfile

import os

_ = pytest.importorskip("duckdb.experimental.spark")

from duckdb.experimental.spark.sql import SparkSession as session
from duckdb import connect, InvalidInputException, read_csv
from conftest import NumpyPandas, ArrowPandas
import pandas._testing as tm
import datetime
import csv


@pytest.fixture
def df(spark):
    simpleData = (
        ("Java", 4000, 5),
        ("Python", 4600, 10),
        ("Scala", 4100, 15),
        ("Scala", 4500, 15),
        ("PHP", 3000, 20),
    )
    columns = ["CourseName", "fee", "discount"]
    dataframe = spark.createDataFrame(data=simpleData, schema=columns)
    yield dataframe


class TestSparkToCSV(object):
    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
    def test_basic_to_csv(self, pandas, spark):
        temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
```

Reviewer: pytest has a `tmp_path` fixture for this, e.g.:

```python
def test_read_csv_glob(self, tmp_path, create_temp_csv):
    file1_path, file2_path = create_temp_csv

    # Use the temporary file paths to read CSV files
    con = duckdb.connect()
    rel = con.read_csv(f'{tmp_path}/file*.csv')
    res = con.sql("select * from rel order by all").fetchall()
    assert res == [(1,), (2,), (3,), (4,), (5,), (6,)]
```

Author: Makes sense, fixed in
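Applied to the test above, a `tmp_path`-based version would look roughly like the following sketch of the pattern (the file name `out.csv` is arbitrary; this is not necessarily the committed fix):

```python
@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_basic_to_csv(self, pandas, spark, tmp_path):
    # tmp_path is a pytest-provided pathlib.Path, cleaned up automatically
    temp_file_name = str(tmp_path / 'out.csv')

    pandas_df = pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]})
    df = spark.createDataFrame(pandas_df)

    df.write.csv(temp_file_name, header=False)
    csv_rel = spark.read.csv(temp_file_name, header=False)

    assert df.collect() == csv_rel.collect()
```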
The diff continues:

```python
        pandas_df = pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]})
```

Reviewer: This is a bit of a nitpick, but there's a good bit of repetition in the dataframes used to test. With a fixture we can avoid duplicating these and make it more explicit where they are repeated:

```python
@pytest.fixture
def dataframe_a(spark):
    emp = [
        (1, "Smith", -1, "2018", "10", "M", 3000),
        (2, "Rose", 1, "2010", "20", "M", 4000),
        (3, "Williams", 1, "2010", "10", "M", 1000),
        (4, "Jones", 2, "2005", "10", "F", 2000),
        (5, "Brown", 2, "2010", "40", "", -1),
        (6, "Brown", 2, "2010", "50", "", -1),
    ]
    empColumns = ["emp_id", "name", "superior_emp_id", "year_joined", "emp_dept_id", "gender", "salary"]
    dataframe = spark.createDataFrame(data=emp, schema=empColumns)
    yield dataframe


@pytest.fixture
def dataframe_b(spark):
    dept = [("Finance", 10), ("Marketing", 20), ("Sales", 30), ("IT", 40)]
    deptColumns = ["dept_name", "dept_id"]
    dataframe = spark.createDataFrame(data=dept, schema=deptColumns)
    yield dataframe


class TestDataFrameJoin(object):
    def test_inner_join(self, dataframe_a, dataframe_b):
        df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, "inner")
        ...
```

Author: Thanks for the suggestion. Wasn't sure how to parametrize fixtures, but figured it out:
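For reference, the fixture parametrization referred to folds the two pandas implementations into the fixture itself; a minimal sketch of the pattern (fixture name and data are illustrative, not the committed code):

```python
@pytest.fixture(params=[NumpyPandas(), ArrowPandas()])
def pandas_df(request):
    # Each test using this fixture runs once per pandas implementation.
    pandas = request.param
    return pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]})
```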
The remaining tests:

```python
        df = spark.createDataFrame(pandas_df)

        df.write.csv(temp_file_name, header=False)

        csv_rel = spark.read.csv(temp_file_name, header=False)

        assert df.collect() == csv_rel.collect()

    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
    def test_to_csv_sep(self, pandas, spark):
        temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
        pandas_df = pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]})

        df = spark.createDataFrame(pandas_df)

        df.write.csv(temp_file_name, sep=',', header=False)

        csv_rel = spark.read.csv(temp_file_name, header=False, sep=',')
        assert df.collect() == csv_rel.collect()

    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
    def test_to_csv_na_rep(self, pandas, spark):
        temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
        pandas_df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]})

        df = spark.createDataFrame(pandas_df)

        df.write.csv(temp_file_name, nullValue="test", header=False)

        csv_rel = spark.read.csv(temp_file_name, nullValue="test")
        assert df.collect() == csv_rel.collect()

    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
    def test_to_csv_header(self, pandas, spark):
        temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
        pandas_df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]})

        df = spark.createDataFrame(pandas_df)

        df.write.csv(temp_file_name, header=True)

        csv_rel = spark.read.csv(temp_file_name, header=True)
        assert df.collect() == csv_rel.collect()

    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
    def test_to_csv_quotechar(self, pandas, spark):
        temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))

        pandas_df = pandas.DataFrame({'a': ["\'a,b,c\'", None, "hello", "bye"], 'b': [45, 234, 234, 2]})

        df = spark.createDataFrame(pandas_df)

        df.write.csv(temp_file_name, quote='\'', sep=',', header=False)

        csv_rel = spark.read.csv(temp_file_name, sep=',', quote='\'')
        assert df.collect() == csv_rel.collect()

    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
    def test_to_csv_escapechar(self, pandas, spark):
        temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
        pandas_df = pandas.DataFrame(
            {
                "c_bool": [True, False],
                "c_float": [1.0, 3.2],
                "c_int": [42, None],
                "c_string": ["a", "b,c"],
            }
        )

        df = spark.createDataFrame(pandas_df)

        df.write.csv(temp_file_name, header=True, quote='"', escape='!')
        csv_rel = spark.read.csv(temp_file_name, quote='"', escape='!', header=True)
        assert df.collect() == csv_rel.collect()

    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
    def test_to_csv_date_format(self, pandas, spark):
        temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
        pandas_df = pandas.DataFrame(tm.getTimeSeriesData())
        dt_index = pandas_df.index
        pandas_df = pandas.DataFrame({"A": dt_index, "B": dt_index.shift(1)}, index=dt_index)

        df = spark.createDataFrame(pandas_df)

        df.write.csv(temp_file_name, dateFormat="%Y%m%d", header=False)

        csv_rel = spark.read.csv(temp_file_name, dateFormat="%Y%m%d")

        assert df.collect() == csv_rel.collect()

    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
    def test_to_csv_timestamp_format(self, pandas, spark):
        temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
        data = [datetime.time(hour=23, minute=1, second=34, microsecond=234345)]
        pandas_df = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')})

        df = spark.createDataFrame(pandas_df)

        df.write.csv(temp_file_name, timestampFormat='%m/%d/%Y', header=False)

        csv_rel = spark.read.csv(temp_file_name, timestampFormat='%m/%d/%Y')

        assert df.collect() == csv_rel.collect()

    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
    def test_to_csv_quoting_off(self, pandas, spark):
        temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
        pandas_df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']})

        df = spark.createDataFrame(pandas_df)
        df.write.csv(temp_file_name, quoteAll=None, header=False)

        csv_rel = spark.read.csv(temp_file_name)
        assert df.collect() == csv_rel.collect()

    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
    def test_to_csv_quoting_on(self, pandas, spark):
        temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
        pandas_df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']})
        df = spark.createDataFrame(pandas_df)
        df.write.csv(temp_file_name, quoteAll="force", header=False)

        csv_rel = spark.read.csv(temp_file_name)
        assert df.collect() == csv_rel.collect()

    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
    def test_to_csv_quoting_quote_all(self, pandas, spark):
        temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
        pandas_df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']})
        df = spark.createDataFrame(pandas_df)
        df.write.csv(temp_file_name, quoteAll=csv.QUOTE_ALL, header=False)

        csv_rel = spark.read.csv(temp_file_name)
        assert df.collect() == csv_rel.collect()

    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
    def test_to_csv_encoding_incorrect(self, pandas, spark):
        temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
        pandas_df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']})
        df = spark.createDataFrame(pandas_df)
        with pytest.raises(
            InvalidInputException, match="Invalid Input Error: The only supported encoding option is 'UTF8"
        ):
            df.write.csv(temp_file_name, encoding="nope", header=False)

    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
    def test_to_csv_encoding_correct(self, pandas, spark):
        temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
        pandas_df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']})
        df = spark.createDataFrame(pandas_df)
        df.write.csv(temp_file_name, encoding="UTF-8", header=False)
        csv_rel = spark.read.csv(temp_file_name)
        assert df.collect() == csv_rel.collect()

    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
    def test_compression_gzip(self, pandas, spark):
        temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
        pandas_df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']})
        df = spark.createDataFrame(pandas_df)
        df.write.csv(temp_file_name, compression="gzip", header=False)

        # slightly convoluted - pyspark .read.csv does not take a compression argument
        csv_rel = spark.createDataFrame(read_csv(temp_file_name, compression="gzip").df())
        assert df.collect() == csv_rel.collect()
```
The third diff adds a test file for the Parquet writer (hunk `@@ -0,0 +1,43 @@`):

```python
import pytest
import tempfile

import os

_ = pytest.importorskip("duckdb.experimental.spark")

from duckdb.experimental.spark.sql import SparkSession as session
```
Reviewer: This is unused, no?

Author: Ah yes, there must be something overriding my normal VSCode settings, as normally that would point this out:
The diff continues:

```python
from duckdb import connect


@pytest.fixture
def df(spark):
    simpleData = (
        ("Java", 4000, 5),
        ("Python", 4600, 10),
        ("Scala", 4100, 15),
        ("Scala", 4500, 15),
        ("PHP", 3000, 20),
    )
    columns = ["CourseName", "fee", "discount"]
    dataframe = spark.createDataFrame(data=simpleData, schema=columns)
    yield dataframe


class TestSparkToParquet(object):
    def test_basic_to_parquet(self, df, spark):
        temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))

        df.write.parquet(temp_file_name)

        csv_rel = spark.read.parquet(temp_file_name)

        assert df.collect() == csv_rel.collect()

    def test_compressed_to_parquet(self, df, spark):
        temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))

        df.write.parquet(temp_file_name, compression="ZSTD")

        csv_rel = spark.read.parquet(temp_file_name)

        assert df.collect() == csv_rel.collect()
```
A separate thread on the `toPandas` addition:

Reviewer: I think we don't want to have a dependency on `pandas`. What we do in other places is set `"PandasDataFrame"` as the return type annotation and then do the actual import inside the method, so the import only fails when `toPandas` is used and pandas is not installed.

Author: Good point. Fixed here: b849a19fc299b6e87a0069e1aacb82630a6671a3

Reviewer: Thanks, though I think the type annotation still needs to be in quotes? Otherwise this references a type that doesn't exist; I can't imagine the type checker being happy with that.

Reviewer: Not fixed properly yet^

Author: Ok, I have tried to fix it in 2c5e7b690e61c22c4eec18db15706111fce858f7. I have tried to make it mirror how it works in session.py with the type checking.
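The idiom being settled on here is the standard deferred-import pattern. A minimal sketch, assuming the DataFrame wraps a DuckDB relation as in the writer code above (the class body and attribute names are illustrative, not the committed code):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; creates no runtime pandas dependency.
    from pandas import DataFrame as PandasDataFrame


class DataFrame:
    def toPandas(self) -> "PandasDataFrame":
        # pandas is only needed once toPandas() is actually called;
        # the wrapped relation materializes into a pandas DataFrame here.
        return self.relation.df()
```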