diff --git a/spark_auto_mapper/data_types/hash_abs.py b/spark_auto_mapper/data_types/hash_abs.py
new file mode 100644
index 0000000..c93a82d
--- /dev/null
+++ b/spark_auto_mapper/data_types/hash_abs.py
@@ -0,0 +1,57 @@
+from typing import List, Optional, Union
+
+from pyspark.sql import Column, DataFrame
+from pyspark.sql.functions import hash, abs as abs_
+
+from spark_auto_mapper.data_types.data_type_base import AutoMapperDataTypeBase
+from spark_auto_mapper.data_types.text_like_base import AutoMapperTextLikeBase
+from spark_auto_mapper.helpers.value_parser import AutoMapperValueParser
+from spark_auto_mapper.type_definitions.native_types import AutoMapperNativeTextType
+from spark_auto_mapper.type_definitions.wrapper_types import AutoMapperWrapperType
+
+
+class AutoMapperHashAbsDataType(AutoMapperTextLikeBase):
+    """
+    Calculates the hash code of the given columns and returns the absolute value of the result as an int column.
+    """
+
+    def __init__(
+        self,
+        *args: Union[
+            AutoMapperNativeTextType, AutoMapperWrapperType, AutoMapperTextLikeBase
+        ],
+    ):
+        super().__init__()
+
+        self.value: List[AutoMapperDataTypeBase] = [
+            value
+            if isinstance(value, AutoMapperDataTypeBase)
+            else AutoMapperValueParser.parse_value(value=value)
+            for value in args
+        ]
+
+    def get_column_spec(
+        self,
+        source_df: Optional[DataFrame],
+        current_column: Optional[Column],
+        parent_columns: Optional[List[Column]],
+    ) -> Column:
+        column_spec = abs_(
+            hash(
+                *[
+                    col.get_column_spec(
+                        source_df=source_df,
+                        current_column=current_column,
+                        parent_columns=parent_columns,
+                    )
+                    for col in self.value
+                ]
+            )
+        )
+        return column_spec.cast("int")
+
+    @property
+    def children(
+        self,
+    ) -> Union[AutoMapperDataTypeBase, List[AutoMapperDataTypeBase]]:
+        return self.value
diff --git a/spark_auto_mapper/helpers/automapper_helpers.py b/spark_auto_mapper/helpers/automapper_helpers.py
index 8a03b11..d46c803 100644
--- a/spark_auto_mapper/helpers/automapper_helpers.py
+++ b/spark_auto_mapper/helpers/automapper_helpers.py
@@ -9,6 +9,7 @@
 from spark_auto_mapper.data_types.array_distinct import AutoMapperArrayDistinctDataType
 from spark_auto_mapper.data_types.base64 import AutoMapperBase64DataType
 from spark_auto_mapper.data_types.exists import AutoMapperExistsDataType
+from spark_auto_mapper.data_types.hash_abs import AutoMapperHashAbsDataType
 from spark_auto_mapper.data_types.nested_array_filter import (
     AutoMapperNestedArrayFilterDataType,
 )
@@ -507,6 +508,21 @@
         """
         return AutoMapperHashDataType(*args)
 
+    @staticmethod
+    def hash_abs(
+        *args: Union[
+            AutoMapperNativeTextType, AutoMapperWrapperType, AutoMapperTextLikeBase
+        ]
+    ) -> AutoMapperHashAbsDataType:
+        """
+        Calculates the hash code of the given columns and returns the absolute value of the result as an int column.
+
+
+        :param args: string or column
+        :return: a hash_abs automapper type
+        """
+        return AutoMapperHashAbsDataType(*args)
+
     @staticmethod
     def coalesce(*args: _TAutoMapperDataType) -> _TAutoMapperDataType:
         """
diff --git a/tests/hash_abs/__init__.py b/tests/hash_abs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/hash_abs/test_automapper_hash_abs.py b/tests/hash_abs/test_automapper_hash_abs.py
new file mode 100644
index 0000000..4307948
--- /dev/null
+++ b/tests/hash_abs/test_automapper_hash_abs.py
@@ -0,0 +1,64 @@
+from typing import Dict
+
+from pyspark.sql import SparkSession, Column, DataFrame
+
+# noinspection PyUnresolvedReferences
+from pyspark.sql.functions import col, hash, abs as abs_, concat
+from spark_auto_mapper.data_types.text_like_base import AutoMapperTextLikeBase
+
+from spark_auto_mapper.automappers.automapper import AutoMapper
+from spark_auto_mapper.helpers.automapper_helpers import AutoMapperHelpers as A
+from spark_auto_mapper.helpers.expression_comparer import assert_compare_expressions
+
+
+def test_auto_mapper_hash_abs(spark_session: SparkSession) -> None:
+    # Arrange
+    spark_session.createDataFrame(
+        [
+            (1, "Qureshi", "54"),
+            (2, "Vidal", "67"),
+            (3, "Vidal", None),
+            (4, None, None),
+        ],
+        ["member_id", "last_name", "my_age"],
+    ).createOrReplaceTempView("patients")
+
+    source_df: DataFrame = spark_session.table("patients")
+
+    source_df = source_df.withColumn("my_age", col("my_age").cast("int"))
+
+    df = source_df.select("member_id")
+    df.createOrReplaceTempView("members")
+
+    # create a function that returns the columns to hash
+    def get_columns_to_hash() -> AutoMapperTextLikeBase:
+        return A.concat(A.column("my_age"), A.column("last_name"))
+
+    # Act
+    mapper = AutoMapper(
+        view="members", source_view="patients", keys=["member_id"]
+    ).columns(age=A.hash_abs(get_columns_to_hash()))
+
+    assert isinstance(mapper, AutoMapper)
+    sql_expressions: Dict[str, Column] = mapper.get_column_specs(source_df=source_df)
+    for column_name, sql_expression in sql_expressions.items():
+        print(f"{column_name}: {sql_expression}")
+
+    assert_compare_expressions(
+        sql_expressions["age"],
+        abs_(hash(concat(col("b.my_age"), col("b.last_name"))))
+        .cast("int")
+        .alias("age"),
+    )
+
+    result_df: DataFrame = mapper.transform(df=df)
+
+    # Assert
+    result_df.printSchema()
+    result_df.show()
+
+    assert result_df.where("member_id == 1").select("age").collect()[0][0] == 181944084
+    assert result_df.where("member_id == 2").select("age").collect()[0][0] == 1424244624
+    assert result_df.where("member_id == 3").select("age").collect()[0][0] == 42
+
+    assert dict(result_df.dtypes)["age"] == "int"
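
Note for reviewers (not part of the patch): a minimal usage sketch of the new `A.hash_abs` helper, assuming a local Spark session. The view and column names here are illustrative; the AutoMapper calls mirror the test above. Spark's `hash` returns 42 for all-null input (42 is the default Murmur3 seed), which is why the `member_id == 3` assertion expects 42.

```python
from pyspark.sql import SparkSession

from spark_auto_mapper.automappers.automapper import AutoMapper
from spark_auto_mapper.helpers.automapper_helpers import AutoMapperHelpers as A

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Source view holding the column(s) to hash (illustrative names).
spark.createDataFrame(
    [(1, "Qureshi"), (2, "Vidal"), (3, None)], ["member_id", "last_name"]
).createOrReplaceTempView("patients")

# Destination view keyed by member_id, as in the test.
df = spark.table("patients").select("member_id")
df.createOrReplaceTempView("members")

# age_hash becomes abs(hash(last_name)) cast to int; a null last_name
# hashes to 42 (Spark's default Murmur3 seed).
mapper = AutoMapper(
    view="members", source_view="patients", keys=["member_id"]
).columns(age_hash=A.hash_abs(A.column("last_name")))

result_df = mapper.transform(df=df)
result_df.show()
```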