Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions datacompy/spark/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

import logging

import pyspark.sql

from datacompy.spark.sql import SparkSQLCompare

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -98,10 +100,10 @@ def compare_by_row(
sorted_base_df, sorted_compare_df = sort_rows(base_dataframe, compare_dataframe)
column_to_join = ["row"]

LOG.info("Compared by column(s): ", column_to_join)
LOG.info("Compared by column(s): %s", column_to_join)
if string2double_cols:
LOG.info(
"String column(s) cast to doubles for numeric comparison: ",
"String column(s) cast to doubles for numeric comparison: %s",
string2double_cols,
)
return SparkSQLCompare(
Expand Down Expand Up @@ -162,22 +164,28 @@ def sort_rows(
compare_cols = compare_df.columns

# Ensure both DataFrames have the same columns

# Ensure both DataFrames have the same columns (optimized: use set for faster membership tests)
compare_cols_set = set(compare_cols)
for x in base_cols:
if x not in compare_cols:
if x not in compare_cols_set:
raise Exception(
f"{x} is present in base_df but does not exist in compare_df"
)

if set(base_cols) != set(compare_cols):
if set(base_cols) != compare_cols_set:
LOG.warning(
"WARNING: There are columns present in Compare df that do not exist in Base df. "
"The Base df columns will be used for row-wise sorting and may produce unanticipated "
"report output if the extra fields are not null."
)

w = Window.orderBy(*base_cols)
sorted_base_df = base_df.select("*", row_number().over(w).alias("row"))
sorted_compare_df = compare_df.select("*", row_number().over(w).alias("row"))

# Use withColumn instead of select("*", ...) to reduce column metadata copying overhead
sorted_base_df = base_df.withColumn("row", row_number().over(w))
sorted_compare_df = compare_df.withColumn("row", row_number().over(w))

return sorted_base_df, sorted_compare_df


Expand Down