Skip to content

Commit

Permalink
fix(pyspark): default to inferring the schema of CSV files and assuming they have a header with `header=True`
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Aug 11, 2023
1 parent 3323156 commit 0ffda75
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
6 changes: 5 additions & 1 deletion ibis/backends/pyspark/__init__.py
Expand Up @@ -671,8 +671,12 @@ def read_csv(
ir.Table
The just-registered table
"""
inferSchema = kwargs.pop("inferSchema", True)
header = kwargs.pop("header", True)
source_list = normalize_filenames(source_list)
spark_df = self._session.read.csv(source_list, **kwargs)
spark_df = self._session.read.csv(
source_list, inferSchema=inferSchema, header=header, **kwargs
)
table_name = table_name or util.gen_name("read_csv")

spark_df.createOrReplaceTempView(table_name)
Expand Down
34 changes: 34 additions & 0 deletions ibis/backends/tests/test_register.py
Expand Up @@ -10,6 +10,7 @@
import pytest
from pytest import param

import ibis
from ibis.backends.conftest import TEST_TABLES

if TYPE_CHECKING:
Expand Down Expand Up @@ -444,6 +445,21 @@ def num_diamonds(data_dir):
return sum(1 for _ in f) - 1


# Per-backend overrides of the column types `test_read_csv` expects after
# reading the diamonds CSV. The test looks a backend up by `con.name` and
# merges its entry over the default expected schema; backends without an
# entry use the defaults unchanged.
DIAMONDS_COLUMN_TYPES = {
    # snowflake's `INFER_SCHEMA` returns these decimal types for the
    # fractional-valued columns of the diamonds CSV (the test's default for
    # these columns is float64)
    "snowflake": {
        "carat": "decimal(3, 2)",
        "depth": "decimal(3, 1)",
        "table": "decimal(3, 1)",
        "x": "decimal(4, 2)",
        "y": "decimal(4, 2)",
        "z": "decimal(4, 2)",
    },
    # NOTE(review): presumably Spark's CSV schema inference yields a 32-bit
    # integer for `price` rather than the default int64 -- confirm
    "pyspark": {"price": "int32"},
}


@pytest.mark.parametrize(
("in_table_name", "out_table_name"),
[
Expand Down Expand Up @@ -474,4 +490,22 @@ def test_read_csv(con, data_dir, in_table_name, out_table_name, num_diamonds):
table = con.read_csv(fname, table_name=in_table_name)

assert any(out_table_name in t for t in con.list_tables())

special_types = DIAMONDS_COLUMN_TYPES.get(con.name, {})

assert table.schema() == ibis.schema(
{
"carat": "float64",
"cut": "string",
"color": "string",
"clarity": "string",
"depth": "float64",
"table": "float64",
"price": "int64",
"x": "float64",
"y": "float64",
"z": "float64",
**special_types,
}
)
assert table.count().execute() == num_diamonds

0 comments on commit 0ffda75

Please sign in to comment.