create hash_abs function to return absolute value of a hash value (#161)
* create hash_abs function to return absolute value of a hash value

* rename file

---------

Co-authored-by: kyle paul <kyle@icanbwell.com>
kpaul and kyle paul committed Oct 6, 2023
1 parent 172e48a commit d2081a2
Showing 4 changed files with 137 additions and 0 deletions.
57 changes: 57 additions & 0 deletions spark_auto_mapper/data_types/hash_abs.py
@@ -0,0 +1,57 @@
from typing import List, Optional, Union

from pyspark.sql import Column, DataFrame
from pyspark.sql.functions import hash, abs as abs_

from spark_auto_mapper.data_types.data_type_base import AutoMapperDataTypeBase
from spark_auto_mapper.data_types.text_like_base import AutoMapperTextLikeBase
from spark_auto_mapper.helpers.value_parser import AutoMapperValueParser
from spark_auto_mapper.type_definitions.native_types import AutoMapperNativeTextType
from spark_auto_mapper.type_definitions.wrapper_types import AutoMapperWrapperType


class AutoMapperHashAbsDataType(AutoMapperTextLikeBase):
"""
    Calculates the hash code of the given columns and returns the absolute value of the result as an int column.
"""

def __init__(
self,
*args: Union[
AutoMapperNativeTextType, AutoMapperWrapperType, AutoMapperTextLikeBase
],
):
super().__init__()

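        # Wrap any raw inputs (literals or column names) in AutoMapper data
        # types so every element exposes get_column_spec().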
self.value: List[AutoMapperDataTypeBase] = [
value
if isinstance(value, AutoMapperDataTypeBase)
else AutoMapperValueParser.parse_value(value=value)
for value in args
]

def get_column_spec(
self,
source_df: Optional[DataFrame],
current_column: Optional[Column],
parent_columns: Optional[List[Column]],
) -> Column:
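        # Spark's hash() returns a signed 32-bit integer; abs() drops the sign
        # and the cast keeps the resulting column typed as int.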
column_spec = abs_(
hash(
*[
col.get_column_spec(
source_df=source_df,
current_column=current_column,
parent_columns=parent_columns,
)
for col in self.value
]
)
)
return column_spec.cast("int")

@property
def children(
self,
) -> Union[AutoMapperDataTypeBase, List[AutoMapperDataTypeBase]]:
return self.value
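
For reference, here is a minimal plain-PySpark sketch of the expression this class builds, outside the AutoMapper machinery. The DataFrame and column names are illustrative only, not part of the commit:

from pyspark.sql import SparkSession
from pyspark.sql.functions import hash as hash_, abs as abs_

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1, "Qureshi")], ["member_id", "last_name"])

# hash() yields a signed 32-bit int; abs() strips the sign, and the cast
# mirrors the cast in get_column_spec above.
df.select(abs_(hash_("member_id", "last_name")).cast("int").alias("id_hash")).show()
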
16 changes: 16 additions & 0 deletions spark_auto_mapper/helpers/automapper_helpers.py
@@ -9,6 +9,7 @@
from spark_auto_mapper.data_types.array_distinct import AutoMapperArrayDistinctDataType
from spark_auto_mapper.data_types.base64 import AutoMapperBase64DataType
from spark_auto_mapper.data_types.exists import AutoMapperExistsDataType
from spark_auto_mapper.data_types.hash_abs import AutoMapperHashAbsDataType
from spark_auto_mapper.data_types.nested_array_filter import (
AutoMapperNestedArrayFilterDataType,
)
@@ -507,6 +508,21 @@ def hash(
"""
return AutoMapperHashDataType(*args)

@staticmethod
def hash_abs(
*args: Union[
AutoMapperNativeTextType, AutoMapperWrapperType, AutoMapperTextLikeBase
]
) -> AutoMapperHashAbsDataType:
"""
        Calculates the hash code of the given columns and returns the absolute value of the result as an int column.
        :param args: string or column
        :return: a hash_abs automapper type
"""
return AutoMapperHashAbsDataType(*args)

@staticmethod
def coalesce(*args: _TAutoMapperDataType) -> _TAutoMapperDataType:
"""
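
A quick usage sketch for the new helper (hedged: the view, source, and column names here are illustrative, following the pattern of the test below):

from spark_auto_mapper.automappers.automapper import AutoMapper
from spark_auto_mapper.helpers.automapper_helpers import AutoMapperHelpers as A

mapper = AutoMapper(
    view="members", source_view="patients", keys=["member_id"]
).columns(
    # name_hash = abs(hash(last_name)), cast to int
    name_hash=A.hash_abs(A.column("last_name"))
)
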
Empty file added tests/hash_abs/__init__.py
Empty file.
64 changes: 64 additions & 0 deletions tests/hash_abs/test_automapper_hash_abs.py
@@ -0,0 +1,64 @@
from typing import Dict

from pyspark.sql import SparkSession, Column, DataFrame

# noinspection PyUnresolvedReferences
from pyspark.sql.functions import col, hash, abs as abs_, concat
from spark_auto_mapper.data_types.text_like_base import AutoMapperTextLikeBase

from spark_auto_mapper.automappers.automapper import AutoMapper
from spark_auto_mapper.helpers.automapper_helpers import AutoMapperHelpers as A
from spark_auto_mapper.helpers.expression_comparer import assert_compare_expressions


def test_auto_mapper_hash_abs(spark_session: SparkSession) -> None:
# Arrange
spark_session.createDataFrame(
[
(1, "Qureshi", "54"),
(2, "Vidal", "67"),
(3, "Vidal", None),
(4, None, None),
],
["member_id", "last_name", "my_age"],
).createOrReplaceTempView("patients")

source_df: DataFrame = spark_session.table("patients")

source_df = source_df.withColumn("my_age", col("my_age").cast("int"))

df = source_df.select("member_id")
df.createOrReplaceTempView("members")

# create a function that returns the columns to hash
def get_columns_to_hash() -> AutoMapperTextLikeBase:
return A.concat(A.column("my_age"), A.column("last_name"))

# Act
mapper = AutoMapper(
view="members", source_view="patients", keys=["member_id"]
).columns(age=A.hash_abs(get_columns_to_hash()))

assert isinstance(mapper, AutoMapper)
sql_expressions: Dict[str, Column] = mapper.get_column_specs(source_df=source_df)
for column_name, sql_expression in sql_expressions.items():
print(f"{column_name}: {sql_expression}")

assert_compare_expressions(
sql_expressions["age"],
abs_(hash(concat(col("b.my_age"), col("b.last_name"))))
.cast("string")
.alias("age"),
)

result_df: DataFrame = mapper.transform(df=df)

# Assert
result_df.printSchema()
result_df.show()

assert result_df.where("member_id == 1").select("age").collect()[0][0] == 181944084
assert result_df.where("member_id == 2").select("age").collect()[0][0] == 1424244624
assert result_df.where("member_id == 3").select("age").collect()[0][0] == 42

assert dict(result_df.dtypes)["age"] == "int"
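
The expected value of 42 for member_id == 3 follows from two Spark behaviors: concat() returns NULL when any argument is NULL, and hash() is Murmur3 seeded with 42 and skips NULL inputs, so hashing a lone NULL returns the seed itself. A quick standalone check (illustrative, not part of the commit):

from pyspark.sql import SparkSession
from pyspark.sql.functions import lit, hash as hash_, abs as abs_

spark = SparkSession.builder.master("local[1]").getOrCreate()
# abs(hash(NULL)) == 42, matching the member_id == 3 assertion above
spark.range(1).select(abs_(hash_(lit(None).cast("string"))).alias("h")).show()
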
