In [114]:
from typing import *
from enum import Enum

from pydantic import BaseModel, conlist, StrictStr, StrictBool, StrictInt, StrictFloat, model_validator

In [115]:
class GenericOperator(str, Enum):
    """Equality / inequality operators, stored as their textual symbol."""

    EQUAL = "="
    DOUBLEEQUAL = "=="
    NOTEQUAL = "!="

    def __str__(self) -> str:
        # Show the bare symbol (e.g. "=") instead of "GenericOperator.EQUAL".
        return self.value


class NumericalOperator(str, Enum):
    """Ordering comparison operators, stored as their textual symbol."""

    LESS = "<"
    LESSEQUAL = "<="
    MORE = ">"
    MOREEQUAL = ">="

    def __str__(self) -> str:
        # Show the bare symbol (e.g. "<=") instead of "NumericalOperator.LESSEQUAL".
        return self.value


class CategoricalOperator(str, Enum):
    """Membership operators, stored as their textual symbol."""

    IN = "in"
    NOTIN = "not in"

    def __str__(self) -> str:
        # Show the bare symbol (e.g. "not in") instead of "CategoricalOperator.NOTIN".
        return self.value


class Filter(BaseModel):
    """A single column filter of the form ``<col_name> <operator> <value>``.

    The operator family constrains which value types are accepted:

    * ``NumericalOperator`` (``<``, ``<=``, ``>``, ``>=``) requires an ``int`` or
      ``float``. A ``bool`` is rejected even though ``bool`` subclasses ``int``.
    * ``CategoricalOperator`` (``in``, ``not in``) requires a list.
    * ``GenericOperator`` (``=``, ``==``, ``!=``) accepts any scalar but not a list.

    Violations raise ``ValueError`` from the validator, which pydantic reports
    as a validation error.
    """

    col_name: StrictStr
    operator: Union[GenericOperator, NumericalOperator, CategoricalOperator]
    value: Union[
        list[Union[StrictStr, StrictBool, StrictInt]],
        StrictStr,
        StrictInt,
        StrictFloat,
        StrictBool,
    ]

    @model_validator(mode='after')
    def check_value_matches_operator(self) -> 'Filter':
        """Reject values whose type does not fit the operator's family."""
        # Compare the value's *exact* type on purpose: isinstance(True, int) is
        # True, and a boolean must not be accepted as a numerical value.
        value_type = type(self.value)
        if isinstance(self.operator, NumericalOperator) and value_type not in (int, float):
            raise ValueError(f'Value ({self.value}) not allowed for numerical operator ({self.operator})')
        elif isinstance(self.operator, CategoricalOperator) and value_type is not list:
            raise ValueError(f'Value ({self.value}) not allowed for categorical operator ({self.operator})')
        elif isinstance(self.operator, GenericOperator) and value_type is list:
            raise ValueError(f'Value ({self.value}) not allowed for generic operator ({self.operator})')

        return self

In [116]:
# A categorical operator ("in" / "not in") takes a list of values; pydantic
# coerces the string "in" into CategoricalOperator.IN.
example_filter = Filter(col_name="nome_colonna", operator="in", value=[4, 5])
example_filter

Filter(col_name='nome_colonna', operator=<CategoricalOperator.IN: 'in'>, value=[4, 5])

In [117]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

# Sample rows. The column names are Italian; presumably: client id, transaction
# date, amount, category, paid-by-card flag. TODO confirm.
transactions = [
    (348272371, "2023-01-01", 5.50, "shopping", True),
    (348272371, "2023-01-01", 6.10, "salute", False),
    (348272371, "2023-01-01", 8.20, "trasporti", False),
    (348272371, "2023-01-01", 1.50, "trasporti", True),
    (348272371, "2023-01-06", 20.20, "shopping", False),
    (348272371, "2023-01-06", 43.00, "shopping", True),
    (348272371, "2023-01-06", 72.20, "shopping", False),
    (234984832, "2023-01-01", 15.34, "salute", True),
    (234984832, "2023-01-01", 36.22, "salute", True),
    (234984832, "2023-01-01", 78.35, "salute", False),
    (234984832, "2023-01-02", 2.20, "trasporti", True),
]
# Only names are given, so Spark infers column types from the data
# (dates therefore stay plain strings).
transaction_columns = [
    "ID_BIC_CLIENTE",
    "DATA_TRANSAZIONE",
    "IMPORTO",
    "CA_CATEGORY_LIV0",
    "IS_CARTA",
]

df = spark.createDataFrame(data=transactions, schema=transaction_columns)

df.show()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/08/10 16:54:45 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
[Stage 0:>                                                          (0 + 1) / 1]

+--------------+----------------+-------+----------------+--------+
|ID_BIC_CLIENTE|DATA_TRANSAZIONE|IMPORTO|CA_CATEGORY_LIV0|IS_CARTA|
+--------------+----------------+-------+----------------+--------+
|     348272371|      2023-01-01|    5.5|        shopping|    true|
|     348272371|      2023-01-01|    6.1|          salute|   false|
|     348272371|      2023-01-01|    8.2|       trasporti|   false|
|     348272371|      2023-01-01|    1.5|       trasporti|    true|
|     348272371|      2023-01-06|   20.2|        shopping|   false|
|     348272371|      2023-01-06|   43.0|        shopping|    true|
|     348272371|      2023-01-06|   72.2|        shopping|   false|
|     234984832|      2023-01-01|  15.34|          salute|    true|
|     234984832|      2023-01-01|  36.22|          salute|    true|
|     234984832|      2023-01-01|  78.35|          salute|   false|
|     234984832|      2023-01-02|    2.2|       trasporti|    true|
+--------------+----------------+-------+-------

                                                                                

In [118]:
# Exact float equality: it matches because the literal and the stored value are the same double (5.5).
importo_equals_5_50 = F.col("IMPORTO") == 5.50
df.where(importo_equals_5_50).show()

+--------------+----------------+-------+----------------+--------+
|ID_BIC_CLIENTE|DATA_TRANSAZIONE|IMPORTO|CA_CATEGORY_LIV0|IS_CARTA|
+--------------+----------------+-------+----------------+--------+
|     348272371|      2023-01-01|    5.5|        shopping|    true|
+--------------+----------------+-------+----------------+--------+



## Operator enums and parsing an operator from its string symbol

In [119]:
from pydantic import model_validator

# NOTE(review): identical to the definition in an earlier cell; running this cell
# rebinds the name to a new, distinct class.
class GenericOperator(str, Enum):
    """Equality / inequality operators, stored as their textual symbol."""

    EQUAL = "="
    DOUBLEEQUAL = "=="
    NOTEQUAL = "!="

    def __str__(self) -> str:
        # Show the bare symbol (e.g. "=") instead of "GenericOperator.EQUAL".
        return self.value


# NOTE(review): identical to the definition in an earlier cell; running this cell
# rebinds the name to a new, distinct class.
class NumericalOperator(str, Enum):
    """Ordering comparison operators, stored as their textual symbol."""

    LESS = "<"
    LESSEQUAL = "<="
    MORE = ">"
    MOREEQUAL = ">="

    def __str__(self) -> str:
        # Show the bare symbol (e.g. "<=") instead of "NumericalOperator.LESSEQUAL".
        return self.value


# NOTE(review): identical to the definition in an earlier cell; running this cell
# rebinds the name to a new, distinct class.
class CategoricalOperator(str, Enum):
    """Membership operators, stored as their textual symbol."""

    IN = "in"
    NOTIN = "not in"

    def __str__(self) -> str:
        # Show the bare symbol (e.g. "not in") instead of "CategoricalOperator.NOTIN".
        return self.value



from typing_extensions import Annotated
from pydantic import BaseModel, BeforeValidator, AfterValidator


def parse_operator(operator_str: str) -> Union[GenericOperator, NumericalOperator, CategoricalOperator]:
    """Map an operator symbol such as ``"in"`` or ``"<="`` to its enum member.

    The generic, numerical and categorical enums are searched in that order.
    The first member whose value equals ``operator_str`` is returned.

    Raises:
        ValueError: if no operator enum has a member with that value.
    """
    candidates = (*GenericOperator, *NumericalOperator, *CategoricalOperator)
    found = next((member for member in candidates if operator_str == member.value), None)
    if found is None:
        raise ValueError("Operator not allowed")
    return found

# `dataclass` was never imported, so this cell raised NameError on a fresh kernel.
# It must be pydantic's dataclass: the stdlib one ignores Annotated validators
# and would have kept operator="in" as a plain string.
from pydantic.dataclasses import dataclass


@dataclass
class Filter:
    """Minimal filter holding only an operator, parsed from its string symbol.

    ``operator`` is first validated as ``str``. ``parse_operator`` then converts
    it into the matching ``GenericOperator`` / ``NumericalOperator`` /
    ``CategoricalOperator`` member. An unknown symbol makes ``parse_operator``
    raise ``ValueError``, which pydantic reports as a validation error.
    """

    # Annotated as ``str`` because callers pass the symbol as text. After
    # validation, the stored value is the enum member returned by parse_operator.
    operator: Annotated[
        str,
        AfterValidator(parse_operator),
    ]
    

# The symbol "in" is parsed into CategoricalOperator.IN during validation.
Filter(operator="in")

Filter(operator=<CategoricalOperator.IN: 'in'>)

23/08/10 16:54:59 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors
