6 changes: 6 additions & 0 deletions .cursorrules
@@ -94,6 +94,12 @@ Additional for integration tests:
# Run local tests
./bin/test-local

# Run a specific test file
./bin/test-local tests/unit/test_file.py

# ... or run a specific test from a file
./bin/test-local tests/unit/test_file.py::TestClass::test_method

# Run specific test type
export TEST_TYPE="unit|integration"
export TOOLKIT_VERSION="local-build"
48 changes: 18 additions & 30 deletions deepnote_toolkit/ocelots/pandas/analyze.py
@@ -6,6 +6,11 @@
import pandas as pd

from deepnote_toolkit.ocelots.constants import DEEPNOTE_INDEX_COLUMN
from deepnote_toolkit.ocelots.pandas.utils import (
is_numeric_or_temporal,
is_type_datetime_or_timedelta,
safe_convert_to_string,
)
from deepnote_toolkit.ocelots.types import ColumnsStatsRecord, ColumnStats


@@ -24,7 +29,10 @@ def _get_categories(np_array):
# special treatment for empty values
num_nans = pandas_series.isna().sum().item()

counter = Counter(pandas_series.dropna().astype(str))
try:
counter = Counter(pandas_series.dropna().astype(str))
except (TypeError, UnicodeDecodeError, AttributeError):
counter = Counter(pandas_series.dropna().apply(safe_convert_to_string))

max_items = 3
if num_nans > 0:
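For context, a minimal sketch of the failure mode the new fallback guards against, using a hypothetical Unstringable class (safe_convert_to_string mirrors the helper added in ocelots/pandas/utils.py below):

import pandas as pd
from collections import Counter

class Unstringable:
    # hypothetical object whose __str__ raises, e.g. a broken proxy type
    def __str__(self):
        raise TypeError("cannot stringify")

def safe_convert_to_string(value):
    # mirrors the helper added in ocelots/pandas/utils.py
    try:
        return str(value)
    except Exception:
        return "<unconvertible>"

s = pd.Series([Unstringable(), None, Unstringable()])
# s.dropna().astype(str) raises TypeError here, so the except branch
# converts element-by-element instead of failing the whole analysis:
counter = Counter(s.dropna().apply(safe_convert_to_string))
assert counter == Counter({"<unconvertible>": 2})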
@@ -46,33 +54,9 @@ def _get_categories(np_array):
return [{"name": name, "count": count} for name, count in categories]


def _is_type_numeric(dtype):
"""
Returns True if dtype is numeric, False otherwise

Numeric means either a number (int, float, complex) or a datetime or timedelta.
It means e.g. that a range of these values can be plotted on a histogram.
"""

# datetime doesn't play nice with np.issubdtype, so we need to check explicitly
if pd.api.types.is_datetime64_any_dtype(dtype) or pd.api.types.is_timedelta64_dtype(
dtype
):
return True

try:
return np.issubdtype(dtype, np.number)
except TypeError:
# np.issubdtype crashes on categorical column dtype, and also on others, e.g. geopandas types
return False


def _get_histogram(pd_series):
try:
if pd.api.types.is_datetime64_any_dtype(
pd_series
) or pd.api.types.is_timedelta64_dtype(pd_series):
# convert datetime or timedelta to an integer so that a histogram can be created
if is_type_datetime_or_timedelta(pd_series):
np_array = np.array(pd_series.dropna().astype(int))
else:
# let's drop infinite values because they break histograms
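As context for the temporal branch: datetime64/timedelta64 values are cast to integers so np.histogram can bin them. A rough sketch with illustrative data:

import numpy as np
import pandas as pd

ts = pd.Series(pd.to_datetime(["2023-01-01", "2023-01-02", "2023-01-05"]))
# int64 nanoseconds since the epoch; bin edges can be mapped back
# to timestamps with pd.to_datetime if needed
counts, edges = np.histogram(ts.dropna().astype("int64"), bins=2)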
@@ -104,11 +88,15 @@ def _calculate_min_max(column):
"""
Calculate min and max values for a given column.
"""
if _is_type_numeric(column.dtype):
if not is_numeric_or_temporal(column.dtype):
return None, None

try:
min_value = str(min(column.dropna())) if len(column.dropna()) > 0 else None
max_value = str(max(column.dropna())) if len(column.dropna()) > 0 else None
return min_value, max_value
return None, None
except (TypeError, ValueError):
return None, None
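A quick illustration of the two paths through the rewritten guard (values are illustrative):

import pandas as pd

s = pd.Series(pd.to_datetime(["2023-01-02", "2023-01-01"]))
# temporal dtypes pass is_numeric_or_temporal, so min/max are
# computed and stringified
assert str(min(s.dropna())) == "2023-01-01 00:00:00"
assert str(max(s.dropna())) == "2023-01-02 00:00:00"

obj = pd.Series(["a", 1])
# object dtype fails the guard, so the function returns (None, None)
# before min/max can raise a TypeError on the mixed values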


def analyze_columns(
Expand Down Expand Up @@ -167,7 +155,7 @@ def analyze_columns(
unique_count=_count_unique(column), nan_count=column.isnull().sum().item()
)

if _is_type_numeric(column.dtype):
if is_numeric_or_temporal(column.dtype):
min_value, max_value = _calculate_min_max(column)
columns[i].stats.min = min_value
columns[i].stats.max = max_value
@@ -187,7 +175,7 @@ def analyze_columns(
for i in range(max_columns_to_analyze, len(df.columns)):
# Ignore columns that are not numeric
column = df.iloc[:, i]
if not _is_type_numeric(column.dtype):
if not is_numeric_or_temporal(column.dtype):
continue

column_name = columns[i].name
58 changes: 49 additions & 9 deletions deepnote_toolkit/ocelots/pandas/utils.py
@@ -5,6 +5,19 @@
from deepnote_toolkit.ocelots.constants import MAX_STRING_CELL_LENGTH


def safe_convert_to_string(value):
"""
Safely convert a value to string, handling cases where str() might fail.

Note: For bytes, this returns Python's standard string representation (e.g., b'hello')
rather than base64-encoding them; the repr form is more human-readable.
"""
try:
return str(value)
except Exception:
return "<unconvertible>"
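A quick check of the bytes behaviour described in the docstring:

assert safe_convert_to_string(b"hello") == "b'hello'"  # repr form, not base64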


# like fillna, but only fills NaT (not a time) values in datetime columns with the specified value
def fill_nat(df, value):
df_datetime_columns = df.select_dtypes(
@@ -76,36 +89,63 @@ def deduplicate_columns(df):
# Cast dataframe contents to strings and trim them to avoid sending too much data
def cast_objects_to_string(df):
def to_string_truncated(elem):
elem_string = str(elem)
elem_string = safe_convert_to_string(elem)
return (
(elem_string[: MAX_STRING_CELL_LENGTH - 1] + "…")
if len(elem_string) > MAX_STRING_CELL_LENGTH
else elem_string
)

for column in df:
if not _is_type_number(df[column].dtype):
if not is_pure_numeric(df[column].dtype):
# if the dtype is not a number, we want to convert it to string and truncate
df[column] = df[column].apply(to_string_truncated)

return df


def _is_type_number(dtype):
def is_type_datetime_or_timedelta(series_or_dtype):
"""
Returns True if dtype is a number, False otherwise. Datetime and timedelta will return False.
Returns True if the series or dtype is datetime or timedelta, False otherwise.
"""
return pd.api.types.is_datetime64_any_dtype(
series_or_dtype
) or pd.api.types.is_timedelta64_dtype(series_or_dtype)


The primary intent of this is to recognize a value that will be converted to a JSON number during serialization.
def is_numeric_or_temporal(dtype):
"""
Returns True if dtype is numeric or temporal (datetime/timedelta), False otherwise.

if pd.api.types.is_datetime64_any_dtype(dtype) or pd.api.types.is_timedelta64_dtype(
dtype
):
This includes numbers (int, float), datetime, and timedelta types.
Use this to determine if values can be plotted on a histogram or have min/max calculated.
"""
if is_type_datetime_or_timedelta(dtype):
return True

try:
return np.issubdtype(dtype, np.number) and not np.issubdtype(
dtype, np.complexfloating
)
except TypeError:
# np.issubdtype crashes on categorical column dtype, and also on others, e.g. geopandas types
return False


def is_pure_numeric(dtype):
"""
Returns True if dtype is a pure number (int, float), False otherwise.

Use this to determine if a value will be serialized as a JSON number.
"""
if is_type_datetime_or_timedelta(dtype):
# np.issubdtype(dtype, np.number) returns True for timedelta, which we don't want
return False

try:
return np.issubdtype(dtype, np.number)
return np.issubdtype(dtype, np.number) and not np.issubdtype(
dtype, np.complexfloating
)
except TypeError:
# np.issubdtype crashes on categorical column dtype, and also on others, e.g. geopandas types
return False
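Taken together, the two predicates split dtypes like this (a quick sketch; CategoricalDtype stands in for the dtypes that make np.issubdtype raise):

import numpy as np
import pandas as pd

assert is_numeric_or_temporal(np.dtype("datetime64[ns]"))  # temporal: histogram/min-max OK
assert not is_pure_numeric(np.dtype("datetime64[ns]"))     # but not a JSON number
assert is_pure_numeric(np.dtype("float64"))                # plain numbers satisfy both
assert not is_pure_numeric(np.dtype("complex128"))         # complex is now excluded
assert not is_numeric_or_temporal(pd.CategoricalDtype())   # issubdtype raises; the guard returns False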
27 changes: 21 additions & 6 deletions deepnote_toolkit/ocelots/pyspark/implementation.py
@@ -232,6 +232,22 @@ def to_records(self, mode: Literal["json", "python"]) -> List[Dict[str, Any]]:
StructField,
)

def binary_to_string_repr(
binary_data: Optional[Union[bytes, bytearray]]
) -> Optional[str]:
"""Convert binary data to Python string representation (e.g., b'hello').

Args:
binary_data: Binary data as bytes or bytearray. PySpark passes BinaryType
as bytearray by default, but Spark 4.1+ with
spark.sql.execution.pyspark.binaryAsBytes=true passes bytes instead.
"""
if binary_data is None:
return None
return str(bytes(binary_data))

binary_udf = F.udf(binary_to_string_repr, StringType())

def select_column(field: StructField) -> Column:
col = F.col(field.name)
# Numbers are already JSON-serialisable, except Decimal
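The helper above just mirrors Python's repr of bytes; a quick check of the plain function (exercising the UDF itself would need a live Spark session):

assert binary_to_string_repr(bytearray(b"hi")) == "b'hi'"  # bytearray (PySpark default)
assert binary_to_string_repr(b"hi") == "b'hi'"             # bytes (Spark 4.1+ option)
assert binary_to_string_repr(None) is None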
@@ -240,11 +256,12 @@ def select_column(field: StructField) -> Column:
):
return col

# We slice binary field before encoding to avoid encoding potentially big blob. Round slicing to
# 4 bytes to avoid breaking multi-byte sequences
# We slice the binary field before converting it to a string representation
if isinstance(field.dataType, BinaryType):
sliced = F.substring(field, 1, keep_bytes)
return F.base64(sliced)
# Each byte becomes up to 4 chars (\xNN) in string repr, plus b'' overhead
max_binary_bytes = (MAX_STRING_CELL_LENGTH - 3) // 4
sliced = F.substring(F.col(field.name), 1, max_binary_bytes)
return binary_udf(sliced)

# String just needs to be trimmed
if isinstance(field.dataType, StringType):
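A sanity check on the byte budget used in the binary branch above (a sketch; the constant's value here is illustrative, the real one lives in ocelots.constants):

MAX_STRING_CELL_LENGTH = 100  # illustrative value
max_binary_bytes = (MAX_STRING_CELL_LENGTH - 3) // 4  # 24

# worst case: every byte renders as \xNN (4 chars) inside b'...' (3 chars of overhead)
worst = str(bytes([0xFF] * max_binary_bytes))
assert len(worst) == 3 + 4 * max_binary_bytes  # 99
assert len(worst) <= MAX_STRING_CELL_LENGTH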
@@ -253,8 +270,6 @@ def select_column(field: StructField) -> Column:
# Everything else gets stringified (Decimal, Date, Timestamp, Struct, …)
return F.substring(col.cast("string"), 1, MAX_STRING_CELL_LENGTH)

keep_bytes = (MAX_STRING_CELL_LENGTH // 4) * 3

if mode == "python":
return [row.asDict() for row in self._df.collect()]
elif mode == "json":
2 changes: 2 additions & 0 deletions tests/unit/helpers/testing_dataframes.py
@@ -261,12 +261,14 @@ def create_dataframe_with_duplicate_column_names():
datetime.datetime(2023, 1, 1, 12, 0, 0),
datetime.datetime(2023, 1, 2, 12, 0, 0),
],
"binary": [b"hello", b"world"],
}
),
"pyspark_schema": pst.StructType(
[
pst.StructField("list", pst.ArrayType(pst.IntegerType()), True),
pst.StructField("datetime", pst.TimestampType(), True),
pst.StructField("binary", pst.BinaryType(), True),
]
),
},