Merged
40 changes: 37 additions & 3 deletions duckdb/experimental/spark/sql/dataframe.py
@@ -15,7 +15,6 @@
from duckdb import ColumnExpression, Expression, StarExpression

from ..errors import PySparkIndexError, PySparkTypeError, PySparkValueError
from ..exception import ContributionsAcceptedError
from .column import Column
from .readwriter import DataFrameWriter
from .type_utils import duckdb_to_spark_schema
@@ -569,6 +568,22 @@ def columns(self) -> list[str]:
"""
return [f.name for f in self.schema.fields]

@property
def dtypes(self) -> list[tuple[str, str]]:
"""Returns all column names and their data types as a list of tuples.

Returns:
-------
list of tuple
List of tuples, each tuple containing a column name and its data type as strings.

Examples:
--------
>>> df.dtypes
[('age', 'bigint'), ('name', 'string')]
"""
return [(f.name, f.dataType.simpleString()) for f in self.schema.fields]

def _ipython_key_completions_(self) -> list[str]:
# Provides tab-completion for column names in PySpark DataFrame
# when accessed in bracket notation, e.g. df['<TAB>]
@@ -982,8 +997,27 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc]
def write(self) -> DataFrameWriter: # noqa: D102
return DataFrameWriter(self)

def printSchema(self) -> None: # noqa: D102
raise ContributionsAcceptedError
def printSchema(self, level: Optional[int] = None) -> None:
"""Prints out the schema in the tree format.

Parameters
----------
level : int, optional
How many levels to print for nested schemas. Prints all levels by default.

Examples:
--------
>>> df.printSchema()
root
|-- age: bigint (nullable = true)
|-- name: string (nullable = true)
"""
if level is not None and level < 0:
raise PySparkValueError(
error_class="NEGATIVE_VALUE",
message_parameters={"arg_name": "level", "arg_value": str(level)},
)
print(self.schema.treeString(level))

def union(self, other: "DataFrame") -> "DataFrame":
"""Return a new :class:`DataFrame` containing union of rows in this and another
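Reviewer note: a minimal usage sketch of the two new `DataFrame` members added above. The import path and `builder.getOrCreate()` call are assumptions based on the experimental module layout, and the exact type strings in `dtypes` depend on how DuckDB column types map to Spark's `simpleString()` names.

```python
# Minimal sketch, assuming the DuckDB Spark-compatible SparkSession.
from duckdb.experimental.spark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("Alice", 25), ("Bob", 30)], ["name", "age"])

print(df.dtypes)         # e.g. [('name', 'string'), ('age', 'bigint')]

df.printSchema()         # full tree: root / |-- name ... / |-- age ...
df.printSchema(level=1)  # nested schemas are cut off after one level
```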
31 changes: 26 additions & 5 deletions duckdb/experimental/spark/sql/functions.py
@@ -30,6 +30,25 @@ def _invoke_function_over_columns(name: str, *cols: "ColumnOrName") -> Column:
return _invoke_function(name, *cols)


def _nan_constant() -> Expression:
"""Create a NaN constant expression.

Note: ConstantExpression(float("nan")) returns NULL instead of NaN because
TransformPythonValue() in the C++ layer has nan_as_null=true by default.
This is intentional for data import scenarios (CSV, Pandas, etc.) where NaN
represents missing data.

For mathematical functions that need to return NaN (not NULL) for out-of-range
inputs per PySpark/IEEE 754 semantics, we use SQLExpression as a workaround.

Returns:
-------
Expression
An expression that evaluates to NaN (not NULL)
"""
return SQLExpression("'NaN'::DOUBLE")


def col(column: str) -> Column: # noqa: D103
return Column(ColumnExpression(column))

@@ -617,11 +636,9 @@ def asin(col: "ColumnOrName") -> Column:
+--------+
"""
col = _to_column_expr(col)
# TODO: ConstantExpression(float("nan")) gives NULL and not NaN # noqa: TD002, TD003
# asin domain is [-1, 1]; return NaN for out-of-range values per PySpark semantics
return Column(
CaseExpression((col < -1.0) | (col > 1.0), ConstantExpression(float("nan"))).otherwise(
FunctionExpression("asin", col)
)
CaseExpression((col < -1.0) | (col > 1.0), _nan_constant()).otherwise(FunctionExpression("asin", col))
)


@@ -4177,7 +4194,11 @@ def acos(col: "ColumnOrName") -> Column:
| NaN|
+--------+
"""
return _invoke_function_over_columns("acos", col)
col = _to_column_expr(col)
# acos domain is [-1, 1]; return NaN for out-of-range values per PySpark semantics
return Column(
CaseExpression((col < -1.0) | (col > 1.0), _nan_constant()).otherwise(FunctionExpression("acos", col))
)


def call_function(funcName: str, *cols: "ColumnOrName") -> Column:
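Reviewer note: the reason `_nan_constant()` goes through `SQLExpression` can be checked directly in DuckDB, and the sketch below shows the out-of-range behaviour the rewritten `asin`/`acos` aim for. The column name and session setup are hypothetical, and the import paths are assumed to follow the experimental module layout.

```python
import duckdb
from duckdb.experimental.spark.sql import SparkSession
from duckdb.experimental.spark.sql import functions as F

# Casting the string 'NaN' to DOUBLE yields a real NaN in DuckDB, which is
# why _nan_constant() uses SQLExpression("'NaN'::DOUBLE") instead of
# ConstantExpression(float("nan")).
print(duckdb.sql("SELECT 'NaN'::DOUBLE").fetchall())  # [(nan,)]

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(0.5,), (1.5,)], ["v"])
# For v = 1.5 (outside [-1, 1]) both results should be NaN, not NULL,
# matching the PySpark / IEEE 754 semantics described in _nan_constant().
print(df.select(F.asin(F.col("v")), F.acos(F.col("v"))).collect())
```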
2 changes: 1 addition & 1 deletion duckdb/experimental/spark/sql/readwriter.py
@@ -125,7 +125,7 @@ def load( # noqa: D102
types, names = schema.extract_types_and_names()
df = df._cast_types(types)
df = df.toDF(names)
raise NotImplementedError
return df

def csv( # noqa: D102
self,
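Reviewer note: with `raise NotImplementedError` gone, `load()` with an explicit schema now returns the cast and renamed DataFrame. A hypothetical sketch only, reusing the `spark` session from the earlier sketch; the file path, format handling, and exact `load()` signature are assumptions (mirroring PySpark's reader API), not something this diff pins down.

```python
from duckdb.experimental.spark.sql.types import LongType, StringType, StructField, StructType

schema = StructType(
    [StructField("name", StringType()), StructField("age", LongType())]
)
# Previously this raised NotImplementedError after casting; now it returns
# the DataFrame with columns cast to the schema's types and renamed.
people = spark.read.load("people.csv", format="csv", schema=schema)
people.printSchema()
```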
8 changes: 7 additions & 1 deletion duckdb/experimental/spark/sql/type_utils.py
@@ -2,6 +2,7 @@

from duckdb.sqltypes import DuckDBPyType

from ..exception import ContributionsAcceptedError
from .types import (
ArrayType,
BinaryType,
@@ -79,7 +80,12 @@ def convert_nested_type(dtype: DuckDBPyType) -> DataType: # noqa: D103
if id == "list" or id == "array":
children = dtype.children
return ArrayType(convert_type(children[0][1]))
# TODO: add support for 'union' # noqa: TD002, TD003
if id == "union":
msg = (
"Union types are not supported in the PySpark interface. "
"DuckDB union types cannot be directly mapped to PySpark types."
)
raise ContributionsAcceptedError(msg)
if id == "struct":
children: list[tuple[str, DuckDBPyType]] = dtype.children
fields = [StructField(x[0], convert_type(x[1])) for x in children]
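Reviewer note: a small sketch of the new failure mode for UNION columns. It assumes `convert_type` in `type_utils` dispatches union types through `convert_nested_type`; `union_value` is DuckDB's built-in for constructing UNION values.

```python
import duckdb
from duckdb.experimental.spark.sql.type_utils import convert_type

union_type = duckdb.sql("SELECT union_value(num := 42) AS u").types[0]
# Instead of silently falling through the old TODO, mapping a UNION column
# to a Spark type now raises ContributionsAcceptedError with a clear message.
convert_type(union_type)
```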
71 changes: 71 additions & 0 deletions duckdb/experimental/spark/sql/types.py
@@ -894,6 +894,77 @@ def fieldNames(self) -> list[str]:
"""
return list(self.names)

def treeString(self, level: Optional[int] = None) -> str:
"""Returns a string representation of the schema in tree format.

Parameters
----------
level : int, optional
Maximum depth to print. If None, prints all levels.

Returns:
-------
str
Tree-formatted schema string

Examples:
--------
>>> schema = StructType([StructField("age", IntegerType(), True)])
>>> print(schema.treeString())
root
|-- age: integer (nullable = true)
"""

def _tree_string(schema: "StructType", depth: int = 0, max_depth: Optional[int] = None) -> list[str]:
"""Recursively build tree string lines."""
lines = []
if depth == 0:
lines.append("root")

if max_depth is not None and depth >= max_depth:
return lines

for field in schema.fields:
indent = " " * depth
prefix = " |-- "
nullable_str = "true" if field.nullable else "false"

# Handle nested StructType
if isinstance(field.dataType, StructType):
lines.append(f"{indent}{prefix}{field.name}: struct (nullable = {nullable_str})")
# Recursively handle nested struct - don't skip any lines, root only appears at depth 0
nested_lines = _tree_string(field.dataType, depth + 1, max_depth)
lines.extend(nested_lines)
# Handle ArrayType
elif isinstance(field.dataType, ArrayType):
element_type = field.dataType.elementType
if isinstance(element_type, StructType):
lines.append(f"{indent}{prefix}{field.name}: array (nullable = {nullable_str})")
lines.append(
f"{indent} | |-- element: struct (containsNull = {field.dataType.containsNull})"
)
nested_lines = _tree_string(element_type, depth + 2, max_depth)
lines.extend(nested_lines)
else:
type_str = element_type.simpleString()
lines.append(f"{indent}{prefix}{field.name}: array<{type_str}> (nullable = {nullable_str})")
# Handle MapType
elif isinstance(field.dataType, MapType):
key_type = field.dataType.keyType.simpleString()
value_type = field.dataType.valueType.simpleString()
lines.append(
f"{indent}{prefix}{field.name}: map<{key_type},{value_type}> (nullable = {nullable_str})"
)
# Handle simple types
else:
type_str = field.dataType.simpleString()
lines.append(f"{indent}{prefix}{field.name}: {type_str} (nullable = {nullable_str})")

return lines

lines = _tree_string(self, 0, level)
return "\n".join(lines)

def needConversion(self) -> bool: # noqa: D102
# We need convert Row()/namedtuple into tuple()
return True
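Reviewer note: `treeString()` can be exercised without a DataFrame at all, which may be handy when reviewing the formatting logic above. The imports follow the module being patched; the rendered type names come from each type's `simpleString()`, so they may differ slightly from PySpark's output.

```python
from duckdb.experimental.spark.sql.types import (
    ArrayType,
    IntegerType,
    StringType,
    StructField,
    StructType,
)

schema = StructType(
    [
        StructField("id", IntegerType(), True),
        StructField(
            "person",
            StructType([StructField("name", StringType(), True)]),
            True,
        ),
        StructField("tags", ArrayType(StringType()), True),
    ]
)

# Prints a "root" header, one " |-- field: type (nullable = ...)" line per
# field, nested struct fields indented one level deeper, and arrays of
# simple element types rendered as array<...>.
print(schema.treeString())

# With level=1 the nested "person" fields are omitted.
print(schema.treeString(level=1))
```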
170 changes: 170 additions & 0 deletions tests/fast/spark/test_spark_dataframe.py
@@ -427,3 +427,173 @@ def test_cache(self, spark):
assert df is not cached
assert cached.collect() == df.collect()
assert cached.collect() == [Row(one=1, two=2, three=3, four=4)]

def test_dtypes(self, spark):
data = [("Alice", 25, 5000.0), ("Bob", 30, 6000.0)]
df = spark.createDataFrame(data, ["name", "age", "salary"])
dtypes = df.dtypes

assert isinstance(dtypes, list)
assert len(dtypes) == 3
for col_name, col_type in dtypes:
assert isinstance(col_name, str)
assert isinstance(col_type, str)

col_names = [name for name, _ in dtypes]
assert col_names == ["name", "age", "salary"]
for _, col_type in dtypes:
assert len(col_type) > 0

def test_dtypes_complex_types(self, spark):
from spark_namespace.sql.types import ArrayType, IntegerType, StringType, StructField, StructType

schema = StructType(
[
StructField("name", StringType(), True),
StructField("scores", ArrayType(IntegerType()), True),
StructField(
"address",
StructType([StructField("city", StringType(), True), StructField("zip", StringType(), True)]),
True,
),
]
)
data = [
("Alice", [90, 85, 88], {"city": "NYC", "zip": "10001"}),
("Bob", [75, 80, 82], {"city": "LA", "zip": "90001"}),
]
df = spark.createDataFrame(data, schema)
dtypes = df.dtypes

assert len(dtypes) == 3
col_names = [name for name, _ in dtypes]
assert col_names == ["name", "scores", "address"]

def test_printSchema(self, spark, capsys):
data = [("Alice", 25, 5000), ("Bob", 30, 6000)]
df = spark.createDataFrame(data, ["name", "age", "salary"])
df.printSchema()
captured = capsys.readouterr()
output = captured.out

assert "root" in output
assert "name" in output
assert "age" in output
assert "salary" in output
assert "string" in output or "varchar" in output.lower()
assert "int" in output.lower() or "bigint" in output.lower()

def test_printSchema_nested(self, spark, capsys):
from spark_namespace.sql.types import ArrayType, IntegerType, StringType, StructField, StructType

schema = StructType(
[
StructField("id", IntegerType(), True),
StructField(
"person",
StructType([StructField("name", StringType(), True), StructField("age", IntegerType(), True)]),
True,
),
StructField("hobbies", ArrayType(StringType()), True),
]
)
data = [
(1, {"name": "Alice", "age": 25}, ["reading", "coding"]),
(2, {"name": "Bob", "age": 30}, ["gaming", "music"]),
]
df = spark.createDataFrame(data, schema)
df.printSchema()
captured = capsys.readouterr()
output = captured.out

assert "root" in output
assert "person" in output
assert "hobbies" in output

def test_printSchema_negative_level(self, spark):
data = [("Alice", 25)]
df = spark.createDataFrame(data, ["name", "age"])

with pytest.raises(PySparkValueError):
df.printSchema(level=-1)

def test_treeString_basic(self, spark):
data = [("Alice", 25, 5000)]
df = spark.createDataFrame(data, ["name", "age", "salary"])
tree = df.schema.treeString()

assert tree.startswith("root\n")
assert " |-- name:" in tree
assert " |-- age:" in tree
assert " |-- salary:" in tree
assert "(nullable = true)" in tree
assert tree.count(" |-- ") == 3

def test_treeString_nested_struct(self, spark):
from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType

schema = StructType(
[
StructField("id", IntegerType(), True),
StructField(
"person",
StructType([StructField("name", StringType(), True), StructField("age", IntegerType(), True)]),
True,
),
]
)
data = [(1, {"name": "Alice", "age": 25})]
df = spark.createDataFrame(data, schema)
tree = df.schema.treeString()

assert "root\n" in tree
assert " |-- id:" in tree
assert " |-- person: struct (nullable = true)" in tree
assert "name:" in tree
assert "age:" in tree

def test_treeString_with_level(self, spark):
from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType

schema = StructType(
[
StructField("id", IntegerType(), True),
StructField(
"person",
StructType(
[
StructField("name", StringType(), True),
StructField("details", StructType([StructField("address", StringType(), True)]), True),
]
),
True,
),
]
)

data = [(1, {"name": "Alice", "details": {"address": "123 Main St"}})]
df = spark.createDataFrame(data, schema)

# Level 1 should only show top-level fields
tree_level_1 = df.schema.treeString(level=1)
assert " |-- id:" in tree_level_1
assert " |-- person: struct" in tree_level_1
# Should not show nested field names at level 1
lines = tree_level_1.split("\n")
assert len([line for line in lines if line.strip()]) <= 3

def test_treeString_array_type(self, spark):
from spark_namespace.sql.types import ArrayType, StringType, StructField, StructType

schema = StructType(
[StructField("name", StringType(), True), StructField("hobbies", ArrayType(StringType()), True)]
)

data = [("Alice", ["reading", "coding"])]
df = spark.createDataFrame(data, schema)
tree = df.schema.treeString()

assert "root\n" in tree
assert " |-- name:" in tree
assert " |-- hobbies: array<" in tree
assert "(nullable = true)" in tree