# In [0]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, lit, unbase64, base64, aes_encrypt, aes_decrypt, when, length
from pyspark.sql.types import StringType
import base64 as py_base64

class EncryptionHelper:
    """Encrypt and decrypt DataFrame columns with Spark's AES functions.

    The key is read from the secret ``encryption_key`` in scope
    ``secret-scope`` and must be stored there base64-encoded. Encrypted values
    are written back as base64 text, so an encrypted column is a string column.

    NOTE(security): ECB mode maps equal plaintexts to equal ciphertexts, so
    equality and value frequencies remain visible in encrypted columns. The
    mode is deliberately left unchanged here because switching it would make
    previously encrypted data undecryptable; plan any migration separately.
    """

    # Key sizes, in bytes, accepted by Spark's aes_encrypt / aes_decrypt.
    _VALID_KEY_BYTES = (16, 24, 32)
    # AES block size in bytes.
    _AES_BLOCK_BYTES = 16
    # Modes whose output is always a whole number of AES blocks.
    _BLOCK_ALIGNED_MODES = ("ECB", "CBC")
    # Shape of a base64 string: alphabet characters plus at most two '='.
    _BASE64_PATTERN = "^[A-Za-z0-9+/]*={0,2}$"

    def __init__(self, dbutils):
        """Load and validate the AES key from the secret store.

        Args:
            dbutils: Object exposing ``secrets.get(scope=..., key=...)``,
                typically the Databricks ``dbutils`` handle.

        Raises:
            ValueError: If the decoded key is not 16, 24 or 32 bytes long.
        """
        self.dbutils = dbutils
        secret_key = dbutils.secrets.get(scope='secret-scope', key='encryption_key')
        self.encryption_key = py_base64.b64decode(secret_key)
        # Fail fast on a malformed secret instead of at the first Spark
        # action. Only the length is reported, never the key material.
        if len(self.encryption_key) not in self._VALID_KEY_BYTES:
            raise ValueError(
                "Decoded encryption key must be 16, 24 or 32 bytes long, "
                f"got {len(self.encryption_key)} bytes"
            )
        self.aes_mode = 'ECB'
        self.aes_padding = 'PKCS'

    @staticmethod
    def _column_names(columns) -> list:
        """Normalize a column argument to a list of column names."""
        if columns is None:
            return []
        # A bare string is one column name, not an iterable of characters.
        if isinstance(columns, str):
            return [columns]
        return list(columns)

    @classmethod
    def _partition_columns(cls, df, columns) -> tuple:
        """Split the requested names into (present in df, missing from df)."""
        available = set(df.columns)
        present, missing = [], []
        for name in cls._column_names(columns):
            (present if name in available else missing).append(name)
        return present, missing

    def _looks_encrypted(self, name):
        """Build the row-level predicate deciding whether a value is ciphertext."""
        value = col(name)
        predicate = (
            value.isNotNull() &
            (length(value) > 0) &
            ((length(value) % 4) == 0) &
            value.rlike(self._BASE64_PATTERN)
        )
        # Ordinary text such as "John" or "1234" also passes the base64 shape
        # test, and handing it to aes_decrypt aborts the whole job. Real
        # ECB/CBC output is always a whole number of 16-byte blocks, so
        # require that as well before attempting decryption.
        if str(self.aes_mode).upper() in self._BLOCK_ALIGNED_MODES:
            predicate = predicate & (
                (length(unbase64(value)) % self._AES_BLOCK_BYTES) == 0
            )
        return predicate

    def encrypt_dataframe(
        self, df: "DataFrame", columns_to_encrypt: "tuple | list | str"
    ) -> "DataFrame":
        """Return ``df`` with the given columns replaced by base64 AES ciphertext.

        Each column is cast to string before encryption. Names that are not
        columns of ``df`` are skipped and reported. A column whose
        transformation raises is reported and left unchanged (best effort).

        Args:
            df: Input DataFrame.
            columns_to_encrypt: Column names; a single string is treated as
                one column name.

        Returns:
            The transformed DataFrame.
        """
        present, missing = self._partition_columns(df, columns_to_encrypt)

        for column in present:
            try:
                df = df.withColumn(
                    column,
                    base64(aes_encrypt(col(column).cast(StringType()), lit(self.encryption_key), lit(self.aes_mode), lit(self.aes_padding)))
                )
            except Exception as e:
                print(f"Exception thrown while encrypting column '{column}': {e}")

        # Surface likely typos: a skipped column stays in plaintext.
        if missing:
            print(f"Columns not found in DataFrame, skipped encryption: {missing}")

        return df

    def decrypt_dataframe(
        self, df: "DataFrame", columns_to_decrypt: "list | tuple | str"
    ) -> "DataFrame":
        """Return ``df`` with encrypted values in the given columns decrypted.

        Only values that look like ciphertext produced by ``encrypt_dataframe``
        are decrypted; every other value is passed through unchanged. Names
        that are not columns of ``df`` are skipped and reported, matching
        ``encrypt_dataframe``.

        Args:
            df: Input DataFrame.
            columns_to_decrypt: Column names; a single string is treated as
                one column name.

        Returns:
            The transformed DataFrame.
        """
        present, missing = self._partition_columns(df, columns_to_decrypt)

        out = df
        for c in present:
            out = out.withColumn(
                c,
                when(
                    self._looks_encrypted(c),
                    aes_decrypt(
                        unbase64(col(c)),
                        lit(self.encryption_key),
                        lit(self.aes_mode),
                        lit(self.aes_padding)
                    ).cast('string')
                ).otherwise(col(c))
            )

        if missing:
            print(f"Columns not found in DataFrame, skipped decryption: {missing}")

        return out