Skip to content

Commit

Permalink
Merge pull request #158 from icanbwell/sg-HDE-3445
Browse files Browse the repository at this point in the history
HDE-3445: Add support for column as list data type in skip_if_column_null
  • Loading branch information
shubhamgoeljtg committed Oct 4, 2023
2 parents b8e6ca8 + 87a7351 commit 7dad80e
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 13 deletions.
40 changes: 28 additions & 12 deletions spark_auto_mapper/automappers/with_column_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pyspark.sql import Column, DataFrame

# noinspection PyUnresolvedReferences
from pyspark.sql.functions import col, when, lit
from pyspark.sql.functions import col, when, lit, size
from pyspark.sql.types import DataType, StructField
from pyspark.sql.utils import AnalysisException
from spark_data_frame_comparer.schema_comparer import SchemaComparer
Expand Down Expand Up @@ -62,23 +62,39 @@ def get_column_spec(self, source_df: Optional[DataFrame]) -> Column:
source_df=source_df, current_column=None, parent_columns=None
)
if self.skip_if_columns_null_or_empty:
column_type_dict = dict(source_df.dtypes) # type: ignore
is_first_when_case = True
for column in self.skip_if_columns_null_or_empty:
column_to_check = f"b.{column}"
# wrap column spec in when
column_spec = (
when(
col(column_to_check).isNull()
| col(column_to_check).eqNullSafe(""),
lit(None),
if column_type_dict[column].startswith("array"):
column_spec = (
when(
col(column_to_check).isNull()
| (size(col(column_to_check)) == 0),
lit(None),
)
if is_first_when_case
else column_spec.when(
col(column_to_check).isNull()
| (size(col(column_to_check)) == 0),
lit(None),
)
)
if is_first_when_case
else column_spec.when(
col(column_to_check).isNull()
| col(column_to_check).eqNullSafe(""),
lit(None),
else:
column_spec = (
when(
col(column_to_check).isNull()
| col(column_to_check).eqNullSafe(""),
lit(None),
)
if is_first_when_case
else column_spec.when(
col(column_to_check).isNull()
| col(column_to_check).eqNullSafe(""),
lit(None),
)
)
)
is_first_when_case = False

column_spec = column_spec.otherwise(
Expand Down
68 changes: 67 additions & 1 deletion tests/complex/test_automapper_complex_with_skip_if_null.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pyspark.sql import SparkSession, Column, DataFrame

# noinspection PyUnresolvedReferences
from pyspark.sql.functions import col, when, lit
from pyspark.sql.functions import col, when, lit, size
from pyspark.sql.types import StructType, StructField, StringType, LongType, DataType

from spark_auto_mapper.automappers.automapper import AutoMapper
Expand Down Expand Up @@ -160,3 +160,69 @@ def test_automapper_complex_with_skip_if_null(spark_session: SparkSession) -> No
assert result_df.where("id == 1").select("name").collect()[0][0] == "Qureshi"

assert dict(result_df.dtypes)["age"] in ("int", "long", "bigint")

# Case when list column type is given in skip_if_column_null_or_empty field
# Arrange
spark_session.createDataFrame(
[
(1, "Qureshi", "Imran", 45, ["123", "456"]),
(2, "Goel", "Shubham", 35, []),
],
["member_id", "last_name", "first_name", "my_age", "list_of_ids"],
).createOrReplaceTempView("patients")

source_df = spark_session.table("patients")

df = source_df.select("member_id")
df.createOrReplaceTempView("members")

# Map
mapper = AutoMapper(
view="members",
source_view="patients",
keys=["member_id"],
drop_key_columns=True,
skip_if_columns_null_or_empty=["first_name", "list_of_ids"],
).complex(
MyClass(
id_=A.column("member_id"),
name=A.column("last_name"),
age=A.number(A.column("my_age")),
)
)

assert isinstance(mapper, AutoMapper)
sql_expressions = mapper.get_column_specs(source_df=source_df)
for column_name, sql_expression in sql_expressions.items():
print(f"{column_name}: {sql_expression}")

result_df = mapper.transform(df=df)

# Assert
assert_compare_expressions(
sql_expressions["name"],
when(
col("b.first_name").isNull() | col("b.first_name").eqNullSafe(""), lit(None)
)
.when(col("b.list_of_ids").isNull() | (size("b.list_of_ids") == 0), lit(None))
.otherwise(col("b.last_name"))
.cast(StringType())
.alias("name"),
)
assert_compare_expressions(
sql_expressions["age"],
when(
col("b.first_name").isNull() | col("b.first_name").eqNullSafe(""), lit(None)
)
.when(col("b.list_of_ids").isNull() | (size("b.list_of_ids") == 0), lit(None))
.otherwise(col("b.my_age"))
.cast(LongType())
.alias("age"),
)

result_df.printSchema()

result_df.show()

assert result_df.count() == 1
assert result_df.where("id == 1").select("name").collect()[0][0] == "Qureshi"

0 comments on commit 7dad80e

Please sign in to comment.