datacompy/spark/helper.py (29 changes: 15 additions & 14 deletions)
@@ -22,6 +22,8 @@
 
 import logging
 
+import pyspark.sql
+
 from datacompy.spark.sql import SparkSQLCompare
 
 LOG = logging.getLogger(__name__)
@@ -98,10 +100,10 @@ def compare_by_row(
     sorted_base_df, sorted_compare_df = sort_rows(base_dataframe, compare_dataframe)
     column_to_join = ["row"]
 
-    LOG.info("Compared by column(s): ", column_to_join)
+    LOG.info("Compared by column(s): %s", column_to_join)
     if string2double_cols:
         LOG.info(
-            "String column(s) cast to doubles for numeric comparison: ",
+            "String column(s) cast to doubles for numeric comparison: %s",
            string2double_cols,
        )
    return SparkSQLCompare(
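
The logging change in this hunk matters because the stdlib logger formats messages lazily: extra arguments are bound to `%`-style placeholders only when the record is emitted, and a message with no placeholder cannot absorb them. A minimal standalone sketch (the logger name and values here are illustrative, not from the PR):

```python
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("demo")

column_to_join = ["row"]

# Buggy form: no %s placeholder, so the lazy `msg % args` interpolation
# raises "TypeError: not all arguments converted during string formatting",
# and the handler reports a logging error instead of the intended message.
log.info("Compared by column(s): ", column_to_join)

# Fixed form: %s binds the argument, and interpolation is deferred until
# the record is actually emitted.
log.info("Compared by column(s): %s", column_to_join)
```
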
@@ -222,25 +224,24 @@ def format_numeric_fields(df: "pyspark.sql.DataFrame") -> "pyspark.sql.DataFrame
     -------
     pyspark.sql.DataFrame
     """
-    fixed_cols = []
-    numeric_types = [
+    # Set-based lookup is faster than list-based for membership tests
+    numeric_types = {
         "tinyint",
         "smallint",
         "int",
         "bigint",
         "float",
         "double",
         "decimal",
-    ]
-
-    for c in df.dtypes:
-        # do not change non-numeric fields
-        if c[1] not in numeric_types:
-            fixed_cols.append(col(c[0]))
-        # round & truncate numeric fields
-        else:
-            new_val = format_number(col(c[0]), 5).alias(c[0])
-            fixed_cols.append(new_val)
+    }
+
+    # Build the column expressions with a generator expression
+    fixed_cols = (
+        format_number(col(c[0]), 5).alias(c[0]) if c[1] in numeric_types else col(c[0])
+        for c in df.dtypes
+    )
+
+    # No intermediate list is needed; the generator unpacks directly into select()
+
     formatted_df = df.select(*fixed_cols)
     return formatted_df
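
On the set-versus-list comment in this hunk: membership tests against a set hash the probe (average O(1)), while a list is scanned element by element (O(n)). With only seven type names the per-lookup difference is tiny, but it costs nothing to take. A rough `timeit` sketch (absolute numbers will vary by machine):

```python
from timeit import timeit

numeric_list = ["tinyint", "smallint", "int", "bigint", "float", "double", "decimal"]
numeric_set = set(numeric_list)

# "string" is absent, which is the worst case for the list:
# every element must be compared before the test fails.
print(timeit('"string" in numeric_list', globals=globals()))  # O(n) scan
print(timeit('"string" in numeric_set', globals=globals()))   # O(1) hash lookup
```
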
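And a self-contained sketch of the refactored pattern itself, assuming a local SparkSession (the session setup, column names, and sample data are illustrative, not from the PR):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, format_number

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("a", 1.234567), ("b", 2.5)], ["name", "value"])

numeric_types = {"tinyint", "smallint", "int", "bigint", "float", "double", "decimal"}

# df.dtypes yields (name, type_string) pairs,
# e.g. [("name", "string"), ("value", "double")]
fixed_cols = (
    format_number(col(name), 5).alias(name) if dtype in numeric_types else col(name)
    for name, dtype in df.dtypes
)

# select(*cols) consumes the generator directly; no intermediate list is built.
formatted_df = df.select(*fixed_cols)
formatted_df.show()  # "value" becomes a formatted string column: 1.23457, 2.50000
```

One caveat of the generator over a list: it can be consumed only once, which is fine here because `select` is its sole consumer.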