diff --git a/spark_auto_mapper/data_types/base64.py b/spark_auto_mapper/data_types/base64.py
new file mode 100644
index 0000000..60c3f03
--- /dev/null
+++ b/spark_auto_mapper/data_types/base64.py
@@ -0,0 +1,42 @@
+from typing import List, Optional, Union
+
+from pyspark.sql import Column, DataFrame
+from pyspark.sql.functions import base64
+
+from spark_auto_mapper.data_types.data_type_base import AutoMapperDataTypeBase
+from spark_auto_mapper.data_types.text_like_base import AutoMapperTextLikeBase
+from spark_auto_mapper.type_definitions.wrapper_types import (
+    AutoMapperColumnOrColumnLikeType,
+)
+
+
+class AutoMapperBase64DataType(AutoMapperTextLikeBase):
+    """
+    Computes the BASE64 encoding and returns it as a string
+    """
+
+    def __init__(self, column: AutoMapperColumnOrColumnLikeType):
+        super().__init__()
+
+        self.column: AutoMapperColumnOrColumnLikeType = column
+
+    def get_column_spec(
+        self,
+        source_df: Optional[DataFrame],
+        current_column: Optional[Column],
+        parent_columns: Optional[List[Column]],
+    ) -> Column:
+        column_spec = base64(
+            self.column.get_column_spec(
+                source_df=source_df,
+                current_column=current_column,
+                parent_columns=parent_columns,
+            )
+        )
+        return column_spec
+
+    @property
+    def children(
+        self,
+    ) -> Union[AutoMapperDataTypeBase, List[AutoMapperDataTypeBase]]:
+        return self.column
diff --git a/spark_auto_mapper/data_types/data_type_base.py b/spark_auto_mapper/data_types/data_type_base.py
index a957581..93a1964 100644
--- a/spark_auto_mapper/data_types/data_type_base.py
+++ b/spark_auto_mapper/data_types/data_type_base.py
@@ -483,6 +483,25 @@ def join_using_delimiter(
             AutoMapperJoinUsingDelimiterDataType(column=self, delimiter=delimiter),
         )

+    # noinspection PyMethodMayBeStatic
+    def base64(self: _TAutoMapperDataType) -> "AutoMapperTextLikeBase":
+        """
+        Computes the BASE64 encoding of the column
+
+
+        :param self: Set by Python. No need to pass.
+        :return: a base64 automapper type
+        :example: A.column("data").base64()
+        """
+        from spark_auto_mapper.data_types.base64 import (
+            AutoMapperBase64DataType,
+        )
+
+        return cast(
+            AutoMapperTextLikeBase,
+            AutoMapperBase64DataType(column=self),
+        )
+
     # override this if your inherited class has a defined schema
     # noinspection PyMethodMayBeStatic
     def get_schema(
diff --git a/spark_auto_mapper/helpers/automapper_helpers.py b/spark_auto_mapper/helpers/automapper_helpers.py
index 3ad331a..8bacc46 100644
--- a/spark_auto_mapper/helpers/automapper_helpers.py
+++ b/spark_auto_mapper/helpers/automapper_helpers.py
@@ -7,6 +7,7 @@
 from pyspark.sql import Column
 from spark_auto_mapper.data_types.array_distinct import AutoMapperArrayDistinctDataType
+from spark_auto_mapper.data_types.base64 import AutoMapperBase64DataType
 from spark_auto_mapper.data_types.exists import AutoMapperExistsDataType
 from spark_auto_mapper.data_types.nested_array_filter import (
     AutoMapperNestedArrayFilterDataType,
 )
@@ -765,3 +766,13 @@ def unix_timestamp(value: AutoMapperNumberInputType) -> AutoMapperUnixTimestampT
         :return: a join automapper type
         """
         return AutoMapperUnixTimestampType(value=value)
+
+    @staticmethod
+    def base64(column: AutoMapperColumnOrColumnLikeType) -> AutoMapperBase64DataType:
+        """
+        Computes the BASE64 encoding and returns it as a string
+
+        :param column: column whose contents to use
+        :return: a base64 automapper type
+        """
+        return AutoMapperBase64DataType(column=column)
diff --git a/tests/base64/__init__.py b/tests/base64/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/base64/test_automapper_base64.py b/tests/base64/test_automapper_base64.py
new file mode 100644
index 0000000..78256eb
--- /dev/null
+++ b/tests/base64/test_automapper_base64.py
@@ -0,0 +1,55 @@
+from typing import Dict
+
+from pyspark.sql import SparkSession, Column, DataFrame
+from pyspark.sql.functions import base64
+
+# noinspection PyUnresolvedReferences
+from pyspark.sql.functions import col
+from spark_auto_mapper.helpers.expression_comparer import assert_compare_expressions
+
+from spark_auto_mapper.automappers.automapper import AutoMapper
+from spark_auto_mapper.helpers.automapper_helpers import AutoMapperHelpers as A
+
+
+def test_auto_mapper_base64(spark_session: SparkSession) -> None:
+    # Arrange
+    spark_session.createDataFrame(
+        [
+            (1, "This is data 1"),
+            (2, "This is data 2"),
+        ],
+        ["id", "data"],
+    ).createOrReplaceTempView("responses")
+
+    source_df: DataFrame = spark_session.table("responses")
+
+    df = source_df.select("id")
+    df.createOrReplaceTempView("content")
+
+    # Act
+    mapper = AutoMapper(view="content", source_view="responses", keys=["id"]).columns(
+        encoded_column=A.base64(A.column("data"))
+    )
+
+    assert isinstance(mapper, AutoMapper)
+    sql_expressions: Dict[str, Column] = mapper.get_column_specs(source_df=source_df)
+    for column_name, sql_expression in sql_expressions.items():
+        print(f"{column_name}: {sql_expression}")
+
+    assert_compare_expressions(
+        sql_expressions["encoded_column"], base64(col("b.data")).alias("encoded_column")
+    )
+
+    result_df: DataFrame = mapper.transform(df=df)
+
+    # Assert
+    result_df.printSchema()
+    result_df.show()
+    assert (
+        result_df.where("id == 1").select("encoded_column").collect()[0][0]
+        == "VGhpcyBpcyBkYXRhIDE="
+    )
+    assert (
+        result_df.where("id == 2").select("encoded_column").collect()[0][0]
+        == "VGhpcyBpcyBkYXRhIDI="
+    )
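
Reviewer note (not part of the diff): a minimal usage sketch of the two entry points this change adds, the A.base64(...) helper and the fluent .base64() method. It reuses the view names and column from the test above; both forms compile down to pyspark.sql.functions.base64 over the source column, as get_column_spec shows.

    from spark_auto_mapper.automappers.automapper import AutoMapper
    from spark_auto_mapper.helpers.automapper_helpers import AutoMapperHelpers as A

    # Helper form (AutoMapperHelpers.base64, added in this diff):
    mapper = AutoMapper(view="content", source_view="responses", keys=["id"]).columns(
        encoded_column=A.base64(A.column("data"))
    )

    # Equivalent fluent form (AutoMapperDataTypeBase.base64, added in this diff):
    mapper = AutoMapper(view="content", source_view="responses", keys=["id"]).columns(
        encoded_column=A.column("data").base64()
    )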