Skip to content

Commit

Permalink
Merge pull request #159 from icanbwell/sg-HDE-3445
Browse files: browse the repository at this point in the history
HDE-3445: Fixing skip_if_column_null functionality
  • Loading branch information
shubhamgoeljtg committed Oct 4, 2023
2 parents 7dad80e + c532703 commit 79ebe68
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 4 deletions.
10 changes: 7 additions & 3 deletions spark_auto_mapper/automappers/with_column_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# noinspection PyUnresolvedReferences
from pyspark.sql.functions import col, when, lit, size
from pyspark.sql.types import DataType, StructField
from pyspark.sql.types import DataType, StructField, ArrayType
from pyspark.sql.utils import AnalysisException
from spark_data_frame_comparer.schema_comparer import SchemaComparer

Expand Down Expand Up @@ -62,12 +62,16 @@ def get_column_spec(self, source_df: Optional[DataFrame]) -> Column:
source_df=source_df, current_column=None, parent_columns=None
)
if self.skip_if_columns_null_or_empty:
column_type_dict = dict(source_df.dtypes) # type: ignore
is_first_when_case = True
for column in self.skip_if_columns_null_or_empty:
column_type = (
source_df.select(column).schema[column.split(".")[-1]].dataType # type: ignore
if "." in column
else source_df.schema[column].dataType # type: ignore
)
column_to_check = f"b.{column}"
# wrap column spec in when
if column_type_dict[column].startswith("array"):
if isinstance(column_type, ArrayType):
column_spec = (
when(
col(column_to_check).isNull()
Expand Down
87 changes: 86 additions & 1 deletion tests/complex/test_automapper_complex_with_skip_if_null.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@

# noinspection PyUnresolvedReferences
from pyspark.sql.functions import col, when, lit, size
from pyspark.sql.types import StructType, StructField, StringType, LongType, DataType
from pyspark.sql.types import (
StructType,
StructField,
StringType,
LongType,
DataType,
IntegerType,
ArrayType,
)

from spark_auto_mapper.automappers.automapper import AutoMapper
from spark_auto_mapper.data_types.complex.complex_base import (
Expand Down Expand Up @@ -226,3 +234,80 @@ def test_automapper_complex_with_skip_if_null(spark_session: SparkSession) -> No

assert result_df.count() == 1
assert result_df.where("id == 1").select("name").collect()[0][0] == "Qureshi"

# Case when nested columns (e.g. "exploded.ssn") are listed in skip_if_columns_null_or_empty
spark_session.createDataFrame(
[
(1, "Qureshi", "Imran", 45, {"nid": 123, "ssn": "", "lis": ["123"]}),
(2, "Goel", "Shubham", 35, {"nid": 456, "ssn": "456", "lis": ["123"]}),
(3, "Chawla", "Gagan", 12, {"nid": 456, "ssn": "456", "lis": []}),
],
StructType(
[
StructField("member_id", IntegerType()),
StructField("last_name", StringType()),
StructField("first_name", StringType()),
StructField("my_age", IntegerType()),
StructField(
"exploded",
StructType(
[
StructField("nid", StringType()),
StructField("ssn", StringType()),
StructField("lis", ArrayType(StringType())),
]
),
),
]
),
).createOrReplaceTempView("patients")

source_df = spark_session.table("patients")

df = source_df.select("member_id")
df.createOrReplaceTempView("members")

# Map
mapper = AutoMapper(
view="members",
source_view="patients",
keys=["member_id"],
drop_key_columns=True,
skip_if_columns_null_or_empty=["first_name", "exploded.ssn", "exploded.lis"],
).complex(
MyClass(
id_=A.column("member_id"),
name=A.column("last_name"),
age=A.number(A.column("my_age")),
)
)

assert isinstance(mapper, AutoMapper)
sql_expressions = mapper.get_column_specs(source_df=source_df)
for column_name, sql_expression in sql_expressions.items():
print(f"{column_name}: {sql_expression}")

result_df = mapper.transform(df=df)

# Assert
assert_compare_expressions(
sql_expressions["name"],
when(
col("b.first_name").isNull() | col("b.first_name").eqNullSafe(""), lit(None)
)
.when(
col("b.exploded.ssn").isNull() | col("b.exploded.ssn").eqNullSafe(""),
lit(None),
)
.when(col("b.exploded.lis").isNull() | (size("b.exploded.lis") == 0), lit(None))
.otherwise(col("b.last_name"))
.cast(StringType())
.alias("name"),
)

result_df.printSchema()

result_df.show()

assert result_df.count() == 1
assert result_df.where("id == 2").select("name").collect()[0][0] == "Goel"

0 comments on commit 79ebe68

Please sign in to comment.