diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 9dba64e4..e7519e81 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -15,7 +15,6 @@ from duckdb import ColumnExpression, Expression, StarExpression from ..errors import PySparkIndexError, PySparkTypeError, PySparkValueError -from ..exception import ContributionsAcceptedError from .column import Column from .readwriter import DataFrameWriter from .type_utils import duckdb_to_spark_schema @@ -569,6 +568,22 @@ def columns(self) -> list[str]: """ return [f.name for f in self.schema.fields] + @property + def dtypes(self) -> list[tuple[str, str]]: + """Returns all column names and their data types as a list of tuples. + + Returns: + ------- + list of tuple + List of tuples, each tuple containing a column name and its data type as strings. + + Examples: + -------- + >>> df.dtypes + [('age', 'bigint'), ('name', 'string')] + """ + return [(f.name, f.dataType.simpleString()) for f in self.schema.fields] + def _ipython_key_completions_(self) -> list[str]: # Provides tab-completion for column names in PySpark DataFrame # when accessed in bracket notation, e.g. df['] @@ -982,8 +997,27 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] def write(self) -> DataFrameWriter: # noqa: D102 return DataFrameWriter(self) - def printSchema(self) -> None: # noqa: D102 - raise ContributionsAcceptedError + def printSchema(self, level: Optional[int] = None) -> None: + """Prints out the schema in the tree format. + + Parameters + ---------- + level : int, optional + How many levels to print for nested schemas. Prints all levels by default. + + Examples: + -------- + >>> df.printSchema() + root + |-- age: bigint (nullable = true) + |-- name: string (nullable = true) + """ + if level is not None and level < 0: + raise PySparkValueError( + error_class="NEGATIVE_VALUE", + message_parameters={"arg_name": "level", "arg_value": str(level)}, + ) + print(self.schema.treeString(level)) def union(self, other: "DataFrame") -> "DataFrame": """Return a new :class:`DataFrame` containing union of rows in this and another diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index 79a2a8e2..49c475a4 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -30,6 +30,25 @@ def _invoke_function_over_columns(name: str, *cols: "ColumnOrName") -> Column: return _invoke_function(name, *cols) +def _nan_constant() -> Expression: + """Create a NaN constant expression. + + Note: ConstantExpression(float("nan")) returns NULL instead of NaN because + TransformPythonValue() in the C++ layer has nan_as_null=true by default. + This is intentional for data import scenarios (CSV, Pandas, etc.) where NaN + represents missing data. + + For mathematical functions that need to return NaN (not NULL) for out-of-range + inputs per PySpark/IEEE 754 semantics, we use SQLExpression as a workaround. + + Returns: + ------- + Expression + An expression that evaluates to NaN (not NULL) + """ + return SQLExpression("'NaN'::DOUBLE") + + def col(column: str) -> Column: # noqa: D103 return Column(ColumnExpression(column)) @@ -617,11 +636,9 @@ def asin(col: "ColumnOrName") -> Column: +--------+ """ col = _to_column_expr(col) - # TODO: ConstantExpression(float("nan")) gives NULL and not NaN # noqa: TD002, TD003 + # asin domain is [-1, 1]; return NaN for out-of-range values per PySpark semantics return Column( - CaseExpression((col < -1.0) | (col > 1.0), ConstantExpression(float("nan"))).otherwise( - FunctionExpression("asin", col) - ) + CaseExpression((col < -1.0) | (col > 1.0), _nan_constant()).otherwise(FunctionExpression("asin", col)) ) @@ -4177,7 +4194,11 @@ def acos(col: "ColumnOrName") -> Column: | NaN| +--------+ """ - return _invoke_function_over_columns("acos", col) + col = _to_column_expr(col) + # acos domain is [-1, 1]; return NaN for out-of-range values per PySpark semantics + return Column( + CaseExpression((col < -1.0) | (col > 1.0), _nan_constant()).otherwise(FunctionExpression("acos", col)) + ) def call_function(funcName: str, *cols: "ColumnOrName") -> Column: diff --git a/duckdb/experimental/spark/sql/readwriter.py b/duckdb/experimental/spark/sql/readwriter.py index b3d08561..eef99043 100644 --- a/duckdb/experimental/spark/sql/readwriter.py +++ b/duckdb/experimental/spark/sql/readwriter.py @@ -125,7 +125,7 @@ def load( # noqa: D102 types, names = schema.extract_types_and_names() df = df._cast_types(types) df = df.toDF(names) - raise NotImplementedError + return df def csv( # noqa: D102 self, diff --git a/duckdb/experimental/spark/sql/type_utils.py b/duckdb/experimental/spark/sql/type_utils.py index 90dac658..0874f2da 100644 --- a/duckdb/experimental/spark/sql/type_utils.py +++ b/duckdb/experimental/spark/sql/type_utils.py @@ -2,6 +2,7 @@ from duckdb.sqltypes import DuckDBPyType +from ..exception import ContributionsAcceptedError from .types import ( ArrayType, BinaryType, @@ -79,7 +80,12 @@ def convert_nested_type(dtype: DuckDBPyType) -> DataType: # noqa: D103 if id == "list" or id == "array": children = dtype.children return ArrayType(convert_type(children[0][1])) - # TODO: add support for 'union' # noqa: TD002, TD003 + if id == "union": + msg = ( + "Union types are not supported in the PySpark interface. " + "DuckDB union types cannot be directly mapped to PySpark types." + ) + raise ContributionsAcceptedError(msg) if id == "struct": children: list[tuple[str, DuckDBPyType]] = dtype.children fields = [StructField(x[0], convert_type(x[1])) for x in children] diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 856885e9..3213169b 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -894,6 +894,77 @@ def fieldNames(self) -> list[str]: """ return list(self.names) + def treeString(self, level: Optional[int] = None) -> str: + """Returns a string representation of the schema in tree format. + + Parameters + ---------- + level : int, optional + Maximum depth to print. If None, prints all levels. + + Returns: + ------- + str + Tree-formatted schema string + + Examples: + -------- + >>> schema = StructType([StructField("age", IntegerType(), True)]) + >>> print(schema.treeString()) + root + |-- age: integer (nullable = true) + """ + + def _tree_string(schema: "StructType", depth: int = 0, max_depth: Optional[int] = None) -> list[str]: + """Recursively build tree string lines.""" + lines = [] + if depth == 0: + lines.append("root") + + if max_depth is not None and depth >= max_depth: + return lines + + for field in schema.fields: + indent = " " * depth + prefix = " |-- " + nullable_str = "true" if field.nullable else "false" + + # Handle nested StructType + if isinstance(field.dataType, StructType): + lines.append(f"{indent}{prefix}{field.name}: struct (nullable = {nullable_str})") + # Recursively handle nested struct - don't skip any lines, root only appears at depth 0 + nested_lines = _tree_string(field.dataType, depth + 1, max_depth) + lines.extend(nested_lines) + # Handle ArrayType + elif isinstance(field.dataType, ArrayType): + element_type = field.dataType.elementType + if isinstance(element_type, StructType): + lines.append(f"{indent}{prefix}{field.name}: array (nullable = {nullable_str})") + lines.append( + f"{indent} | |-- element: struct (containsNull = {field.dataType.containsNull})" + ) + nested_lines = _tree_string(element_type, depth + 2, max_depth) + lines.extend(nested_lines) + else: + type_str = element_type.simpleString() + lines.append(f"{indent}{prefix}{field.name}: array<{type_str}> (nullable = {nullable_str})") + # Handle MapType + elif isinstance(field.dataType, MapType): + key_type = field.dataType.keyType.simpleString() + value_type = field.dataType.valueType.simpleString() + lines.append( + f"{indent}{prefix}{field.name}: map<{key_type},{value_type}> (nullable = {nullable_str})" + ) + # Handle simple types + else: + type_str = field.dataType.simpleString() + lines.append(f"{indent}{prefix}{field.name}: {type_str} (nullable = {nullable_str})") + + return lines + + lines = _tree_string(self, 0, level) + return "\n".join(lines) + def needConversion(self) -> bool: # noqa: D102 # We need convert Row()/namedtuple into tuple() return True diff --git a/tests/fast/spark/test_spark_dataframe.py b/tests/fast/spark/test_spark_dataframe.py index 95a6b3a8..e242092e 100644 --- a/tests/fast/spark/test_spark_dataframe.py +++ b/tests/fast/spark/test_spark_dataframe.py @@ -427,3 +427,173 @@ def test_cache(self, spark): assert df is not cached assert cached.collect() == df.collect() assert cached.collect() == [Row(one=1, two=2, three=3, four=4)] + + def test_dtypes(self, spark): + data = [("Alice", 25, 5000.0), ("Bob", 30, 6000.0)] + df = spark.createDataFrame(data, ["name", "age", "salary"]) + dtypes = df.dtypes + + assert isinstance(dtypes, list) + assert len(dtypes) == 3 + for col_name, col_type in dtypes: + assert isinstance(col_name, str) + assert isinstance(col_type, str) + + col_names = [name for name, _ in dtypes] + assert col_names == ["name", "age", "salary"] + for _, col_type in dtypes: + assert len(col_type) > 0 + + def test_dtypes_complex_types(self, spark): + from spark_namespace.sql.types import ArrayType, IntegerType, StringType, StructField, StructType + + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("scores", ArrayType(IntegerType()), True), + StructField( + "address", + StructType([StructField("city", StringType(), True), StructField("zip", StringType(), True)]), + True, + ), + ] + ) + data = [ + ("Alice", [90, 85, 88], {"city": "NYC", "zip": "10001"}), + ("Bob", [75, 80, 82], {"city": "LA", "zip": "90001"}), + ] + df = spark.createDataFrame(data, schema) + dtypes = df.dtypes + + assert len(dtypes) == 3 + col_names = [name for name, _ in dtypes] + assert col_names == ["name", "scores", "address"] + + def test_printSchema(self, spark, capsys): + data = [("Alice", 25, 5000), ("Bob", 30, 6000)] + df = spark.createDataFrame(data, ["name", "age", "salary"]) + df.printSchema() + captured = capsys.readouterr() + output = captured.out + + assert "root" in output + assert "name" in output + assert "age" in output + assert "salary" in output + assert "string" in output or "varchar" in output.lower() + assert "int" in output.lower() or "bigint" in output.lower() + + def test_printSchema_nested(self, spark, capsys): + from spark_namespace.sql.types import ArrayType, IntegerType, StringType, StructField, StructType + + schema = StructType( + [ + StructField("id", IntegerType(), True), + StructField( + "person", + StructType([StructField("name", StringType(), True), StructField("age", IntegerType(), True)]), + True, + ), + StructField("hobbies", ArrayType(StringType()), True), + ] + ) + data = [ + (1, {"name": "Alice", "age": 25}, ["reading", "coding"]), + (2, {"name": "Bob", "age": 30}, ["gaming", "music"]), + ] + df = spark.createDataFrame(data, schema) + df.printSchema() + captured = capsys.readouterr() + output = captured.out + + assert "root" in output + assert "person" in output + assert "hobbies" in output + + def test_printSchema_negative_level(self, spark): + data = [("Alice", 25)] + df = spark.createDataFrame(data, ["name", "age"]) + + with pytest.raises(PySparkValueError): + df.printSchema(level=-1) + + def test_treeString_basic(self, spark): + data = [("Alice", 25, 5000)] + df = spark.createDataFrame(data, ["name", "age", "salary"]) + tree = df.schema.treeString() + + assert tree.startswith("root\n") + assert " |-- name:" in tree + assert " |-- age:" in tree + assert " |-- salary:" in tree + assert "(nullable = true)" in tree + assert tree.count(" |-- ") == 3 + + def test_treeString_nested_struct(self, spark): + from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType + + schema = StructType( + [ + StructField("id", IntegerType(), True), + StructField( + "person", + StructType([StructField("name", StringType(), True), StructField("age", IntegerType(), True)]), + True, + ), + ] + ) + data = [(1, {"name": "Alice", "age": 25})] + df = spark.createDataFrame(data, schema) + tree = df.schema.treeString() + + assert "root\n" in tree + assert " |-- id:" in tree + assert " |-- person: struct (nullable = true)" in tree + assert "name:" in tree + assert "age:" in tree + + def test_treeString_with_level(self, spark): + from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType + + schema = StructType( + [ + StructField("id", IntegerType(), True), + StructField( + "person", + StructType( + [ + StructField("name", StringType(), True), + StructField("details", StructType([StructField("address", StringType(), True)]), True), + ] + ), + True, + ), + ] + ) + + data = [(1, {"name": "Alice", "details": {"address": "123 Main St"}})] + df = spark.createDataFrame(data, schema) + + # Level 1 should only show top-level fields + tree_level_1 = df.schema.treeString(level=1) + assert " |-- id:" in tree_level_1 + assert " |-- person: struct" in tree_level_1 + # Should not show nested field names at level 1 + lines = tree_level_1.split("\n") + assert len([line for line in lines if line.strip()]) <= 3 + + def test_treeString_array_type(self, spark): + from spark_namespace.sql.types import ArrayType, StringType, StructField, StructType + + schema = StructType( + [StructField("name", StringType(), True), StructField("hobbies", ArrayType(StringType()), True)] + ) + + data = [("Alice", ["reading", "coding"])] + df = spark.createDataFrame(data, schema) + tree = df.schema.treeString() + + assert "root\n" in tree + assert " |-- name:" in tree + assert " |-- hobbies: array<" in tree + assert "(nullable = true)" in tree diff --git a/tests/fast/spark/test_spark_functions_numeric.py b/tests/fast/spark/test_spark_functions_numeric.py index 8378aafa..98966548 100644 --- a/tests/fast/spark/test_spark_functions_numeric.py +++ b/tests/fast/spark/test_spark_functions_numeric.py @@ -5,7 +5,6 @@ import math import numpy as np -from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql import functions as sf from spark_namespace.sql.types import Row @@ -288,12 +287,7 @@ def test_asin(self, spark): res = df.select("asin_value").collect() assert res[0].asin_value == 0 - if USE_ACTUAL_SPARK: - assert np.isnan(res[1].asin_value) - else: - # TODO: DuckDB should return NaN here. Reason is that # noqa: TD002, TD003 - # ConstantExpression(float("nan")) gives NULL and not NaN - assert res[1].asin_value is None + assert np.isnan(res[1].asin_value) def test_corr(self, spark): N = 20 diff --git a/tests/fast/test_insert.py b/tests/fast/test_insert.py index 455e6e48..c5de1589 100644 --- a/tests/fast/test_insert.py +++ b/tests/fast/test_insert.py @@ -27,4 +27,8 @@ def test_insert_with_schema(self, duckdb_cursor): res = duckdb_cursor.table("not_main.tbl").fetchall() assert len(res) == 10 - duckdb_cursor.table("not_main.tbl").insert((42,)) + # Insert into a schema-qualified table should work; table has a single column from range(10) + duckdb_cursor.table("not_main.tbl").insert([42]) + res2 = duckdb_cursor.table("not_main.tbl").fetchall() + assert len(res2) == 11 + assert (42,) in res2 diff --git a/tests/fast/test_relation.py b/tests/fast/test_relation.py index 4d6f6591..f386b091 100644 --- a/tests/fast/test_relation.py +++ b/tests/fast/test_relation.py @@ -280,8 +280,13 @@ def test_value_relation(self, duckdb_cursor): rel = duckdb_cursor.values((const(1), const(2), const(3)), const(4)) # Using Expressions that can't be resolved: + # Accept both historical and current Binder error message variants with pytest.raises( - duckdb.BinderException, match='Referenced column "a" was not found because the FROM clause is missing' + duckdb.BinderException, + match=( + r'Referenced column "a" not found in FROM clause!|' + r'Referenced column "a" was not found because the FROM clause is missing' + ), ): duckdb_cursor.values(duckdb.ColumnExpression("a"))