Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
create hash_abs function to return absolute value of a hash value (#161)
* create hash_abs function to return absolute value of a hash value * rename file --------- Co-authored-by: kyle paul <kyle@icanbwell.com>
- Loading branch information
Showing
4 changed files
with
137 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from typing import List, Optional, Union | ||
|
||
from pyspark.sql import Column, DataFrame | ||
from pyspark.sql.functions import hash, abs as abs_ | ||
|
||
from spark_auto_mapper.data_types.data_type_base import AutoMapperDataTypeBase | ||
from spark_auto_mapper.data_types.text_like_base import AutoMapperTextLikeBase | ||
from spark_auto_mapper.helpers.value_parser import AutoMapperValueParser | ||
from spark_auto_mapper.type_definitions.native_types import AutoMapperNativeTextType | ||
from spark_auto_mapper.type_definitions.wrapper_types import AutoMapperWrapperType | ||
|
||
|
||
class AutoMapperHashAbsDataType(AutoMapperTextLikeBase):
    """
    Hashes the given columns with Spark's ``hash`` function and exposes the
    absolute value of that hash, cast to an int column.
    """

    def __init__(
        self,
        *args: Union[
            AutoMapperNativeTextType, AutoMapperWrapperType, AutoMapperTextLikeBase
        ],
    ):
        """
        :param args: one or more columns (or native values) whose combined
            hash should be computed
        """
        super().__init__()

        # Wrap raw native values in AutoMapper data types; anything that is
        # already an AutoMapperDataTypeBase passes through unchanged.
        parsed_children: List[AutoMapperDataTypeBase] = []
        for item in args:
            if isinstance(item, AutoMapperDataTypeBase):
                parsed_children.append(item)
            else:
                parsed_children.append(AutoMapperValueParser.parse_value(value=item))
        self.value: List[AutoMapperDataTypeBase] = parsed_children

    def get_column_spec(
        self,
        source_df: Optional[DataFrame],
        current_column: Optional[Column],
        parent_columns: Optional[List[Column]],
    ) -> Column:
        """
        Builds the Spark column expression ``abs(hash(col1, col2, ...))``
        cast to int.

        :param source_df: source DataFrame (passed through to children)
        :param current_column: current column context (passed through)
        :param parent_columns: parent column context (passed through)
        :return: the resulting Spark Column
        """
        child_specs: List[Column] = [
            child.get_column_spec(
                source_df=source_df,
                current_column=current_column,
                parent_columns=parent_columns,
            )
            for child in self.value
        ]
        # hash() takes the child columns variadically; abs_ is pyspark's abs
        return abs_(hash(*child_specs)).cast("int")

    @property
    def children(
        self,
    ) -> Union[AutoMapperDataTypeBase, List[AutoMapperDataTypeBase]]:
        # Expose the wrapped child columns for tree traversal by the framework
        return self.value
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from typing import Dict | ||
|
||
from pyspark.sql import SparkSession, Column, DataFrame | ||
|
||
# noinspection PyUnresolvedReferences | ||
from pyspark.sql.functions import col, hash, abs as abs_, concat | ||
from spark_auto_mapper.data_types.text_like_base import AutoMapperTextLikeBase | ||
|
||
from spark_auto_mapper.automappers.automapper import AutoMapper | ||
from spark_auto_mapper.helpers.automapper_helpers import AutoMapperHelpers as A | ||
from spark_auto_mapper.helpers.expression_comparer import assert_compare_expressions | ||
|
||
|
||
def test_auto_mapper_hash_abs(spark_session: SparkSession) -> None:
    """
    Verifies that A.hash_abs produces abs(hash(...)) as an int column, both
    at the expression level and in the transformed output values.
    """
    # Arrange: build the source view with a mix of populated and NULL values
    rows = [
        (1, "Qureshi", "54"),
        (2, "Vidal", "67"),
        (3, "Vidal", None),
        (4, None, None),
    ]
    spark_session.createDataFrame(
        rows, ["member_id", "last_name", "my_age"]
    ).createOrReplaceTempView("patients")

    source_df: DataFrame = spark_session.table("patients")
    source_df = source_df.withColumn("my_age", col("my_age").cast("int"))

    members_df: DataFrame = source_df.select("member_id")
    members_df.createOrReplaceTempView("members")

    # Helper returning the columns whose concatenation is hashed
    def get_columns_to_hash() -> AutoMapperTextLikeBase:
        return A.concat(A.column("my_age"), A.column("last_name"))

    # Act
    mapper = AutoMapper(
        view="members", source_view="patients", keys=["member_id"]
    ).columns(age=A.hash_abs(get_columns_to_hash()))

    assert isinstance(mapper, AutoMapper)
    sql_expressions: Dict[str, Column] = mapper.get_column_specs(source_df=source_df)
    for column_name, sql_expression in sql_expressions.items():
        print(f"{column_name}: {sql_expression}")

    # NOTE(review): the expected expression casts to string while hash_abs
    # casts to int — assert_compare_expressions presumably tolerates the final
    # cast; confirm against that helper's implementation.
    expected_age: Column = (
        abs_(hash(concat(col("b.my_age"), col("b.last_name"))))
        .cast("string")
        .alias("age")
    )
    assert_compare_expressions(sql_expressions["age"], expected_age)

    result_df: DataFrame = mapper.transform(df=members_df)

    # Assert
    result_df.printSchema()
    result_df.show()

    assert result_df.where("member_id == 1").select("age").collect()[0][0] == 181944084
    assert result_df.where("member_id == 2").select("age").collect()[0][0] == 1424244624
    # member 3 has a NULL my_age, so concat yields NULL — the expected value
    # 42 is whatever Spark's hash of that input produces here (per this test)
    assert result_df.where("member_id == 3").select("age").collect()[0][0] == 42

    assert dict(result_df.dtypes)["age"] == "int"