Merge pull request #152 from icanbwell/gc-athd-4133
ATHD-4133 - Added support for pyspark base64 sql function
gagan-chawla committed Dec 21, 2022
2 parents e177d4d + ecbe289 commit 6504217
Showing 5 changed files with 127 additions and 0 deletions.
42 changes: 42 additions & 0 deletions spark_auto_mapper/data_types/base64.py
@@ -0,0 +1,42 @@
from typing import List, Optional, Union

from pyspark.sql import Column, DataFrame
from pyspark.sql.functions import base64

from spark_auto_mapper.data_types.data_type_base import AutoMapperDataTypeBase
from spark_auto_mapper.data_types.text_like_base import AutoMapperTextLikeBase
from spark_auto_mapper.type_definitions.wrapper_types import (
    AutoMapperColumnOrColumnLikeType,
)


class AutoMapperBase64DataType(AutoMapperTextLikeBase):
    """
    Computes the BASE64 encoding and returns it as a string
    """

    def __init__(self, column: AutoMapperColumnOrColumnLikeType):
        super().__init__()

        self.column: AutoMapperColumnOrColumnLikeType = column

    def get_column_spec(
        self,
        source_df: Optional[DataFrame],
        current_column: Optional[Column],
        parent_columns: Optional[List[Column]],
    ) -> Column:
        column_spec = base64(
            self.column.get_column_spec(
                source_df=source_df,
                current_column=current_column,
                parent_columns=parent_columns,
            )
        )
        return column_spec

    @property
    def children(
        self,
    ) -> Union[AutoMapperDataTypeBase, List[AutoMapperDataTypeBase]]:
        return self.column
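For context, the Spark function this class delegates to can be exercised on its own. A minimal standalone sketch (the session setup and sample data here are illustrative, not part of the diff) of what the generated column ultimately computes:

from pyspark.sql import SparkSession
from pyspark.sql.functions import base64, col

# Standalone sketch of pyspark.sql.functions.base64, which the class above
# wraps: Spark casts the string column to binary and emits a base64 string.
spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("hello",)], ["data"])
df.select(base64(col("data")).alias("encoded")).show()
# encoded: aGVsbG8=  (base64 of 'hello')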
19 changes: 19 additions & 0 deletions spark_auto_mapper/data_types/data_type_base.py
@@ -483,6 +483,25 @@ def join_using_delimiter(
            AutoMapperJoinUsingDelimiterDataType(column=self, delimiter=delimiter),
        )

    # noinspection PyMethodMayBeStatic
    def base64(self: _TAutoMapperDataType) -> "AutoMapperTextLikeBase":
        """
        Computes the BASE64 encoding of the column
        :param self: Set by Python. No need to pass.
        :return: a base64 automapper type
        :example: A.column("data").base64()
        """
        from spark_auto_mapper.data_types.base64 import (
            AutoMapperBase64DataType,
        )

        return cast(
            AutoMapperTextLikeBase,
            AutoMapperBase64DataType(column=self),
        )

    # override this if your inherited class has a defined schema
    # noinspection PyMethodMayBeStatic
    def get_schema(
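This method makes base64 available in the fluent style on any automapper expression, per the docstring example above. A usage sketch (the view, source view, and column names below are hypothetical, chosen only for illustration):

from spark_auto_mapper.automappers.automapper import AutoMapper
from spark_auto_mapper.helpers.automapper_helpers import AutoMapperHelpers as A

# Fluent style: chain .base64() off a column expression inside a mapping.
# "content", "responses", and "payload" are illustrative names only.
mapper = AutoMapper(view="content", source_view="responses", keys=["id"]).columns(
    encoded_payload=A.column("payload").base64()
)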
11 changes: 11 additions & 0 deletions spark_auto_mapper/helpers/automapper_helpers.py
@@ -7,6 +7,7 @@
from pyspark.sql import Column

from spark_auto_mapper.data_types.array_distinct import AutoMapperArrayDistinctDataType
from spark_auto_mapper.data_types.base64 import AutoMapperBase64DataType
from spark_auto_mapper.data_types.exists import AutoMapperExistsDataType
from spark_auto_mapper.data_types.nested_array_filter import (
    AutoMapperNestedArrayFilterDataType,
@@ -765,3 +766,13 @@ def unix_timestamp(value: AutoMapperNumberInputType) -> AutoMapperUnixTimestampType:
        :return: a join automapper type
        """
        return AutoMapperUnixTimestampType(value=value)

    @staticmethod
    def base64(column: AutoMapperColumnOrColumnLikeType) -> AutoMapperBase64DataType:
        """
        Computes the BASE64 encoding and returns it as a string
        :param column: column whose contents to use
        :return: a base64 automapper type
        """
        return AutoMapperBase64DataType(column=column)
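Both entry points construct the same AutoMapperBase64DataType, so the static helper and the fluent method are interchangeable; a quick sketch of the equivalence:

from spark_auto_mapper.helpers.automapper_helpers import AutoMapperHelpers as A

# Equivalent spellings of the same mapping expression:
helper_style = A.base64(A.column("data"))  # static helper added above
fluent_style = A.column("data").base64()   # method added on the base type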
Empty file added tests/base64/__init__.py
Empty file.
55 changes: 55 additions & 0 deletions tests/base64/test_automapper_base64.py
@@ -0,0 +1,55 @@
from typing import Dict

from pyspark.sql import SparkSession, Column, DataFrame
from pyspark.sql.functions import base64

# noinspection PyUnresolvedReferences
from pyspark.sql.functions import col
from spark_auto_mapper.helpers.expression_comparer import assert_compare_expressions

from spark_auto_mapper.automappers.automapper import AutoMapper
from spark_auto_mapper.helpers.automapper_helpers import AutoMapperHelpers as A


def test_auto_mapper_base64(spark_session: SparkSession) -> None:
    # Arrange
    spark_session.createDataFrame(
        [
            (1, "This is data 1"),
            (2, "This is data 2"),
        ],
        ["id", "data"],
    ).createOrReplaceTempView("responses")

    source_df: DataFrame = spark_session.table("responses")

    df = source_df.select("id")
    df.createOrReplaceTempView("content")

    # Act
    mapper = AutoMapper(view="content", source_view="responses", keys=["id"]).columns(
        encoded_column=A.base64(A.column("data"))
    )

    assert isinstance(mapper, AutoMapper)
    sql_expressions: Dict[str, Column] = mapper.get_column_specs(source_df=source_df)
    for column_name, sql_expression in sql_expressions.items():
        print(f"{column_name}: {sql_expression}")

    assert_compare_expressions(
        sql_expressions["encoded_column"], base64(col("b.data")).alias("encoded_column")
    )

    result_df: DataFrame = mapper.transform(df=df)

    # Assert
    result_df.printSchema()
    result_df.show()
    assert (
        result_df.where("id == 1").select("encoded_column").collect()[0][0]
        == "VGhpcyBpcyBkYXRhIDE="
    )
    assert (
        result_df.where("id == 2").select("encoded_column").collect()[0][0]
        == "VGhpcyBpcyBkYXRhIDI="
    )
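The expected strings in the assertions can be cross-checked against Python's standard-library encoder:

import base64

# Stdlib sanity check of the values asserted in the test above.
assert base64.b64encode(b"This is data 1").decode("ascii") == "VGhpcyBpcyBkYXRhIDE="
assert base64.b64encode(b"This is data 2").decode("ascii") == "VGhpcyBpcyBkYXRhIDI="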
