In [None]:
import numpy as np
import polars as pl
from typing import Callable, Dict, Type, Union, List
from abc import ABC, abstractmethod


class BasePolluter(ABC):
    def __init__(self, transformation: Callable):
        """
        Base class for data polluters.

        Args:
            transformation: A callable that applies a transformation to a value.
        """
        self.transformation = transformation
        self.type_mapping = self._get_type_mapping()

    @abstractmethod
    def _get_type_mapping(self) -> Dict[Type[pl.DataType], Type[pl.DataType]]:
        """
        Define the mapping of input types to output types for this polluter.
        Must be implemented by subclasses.

        Returns:
            Dict mapping input Polars datatypes to output Polars datatypes
        """
        pass

    def _validate_column_type(self, input_type: pl.DataType) -> pl.DataType:
        """
        Validate that the input type is supported and get its corresponding output type.

        Args:
            input_type: The Polars datatype of the input column

        Returns:
            The expected output Polars datatype

        Raises:
            ValueError: If the input type is not supported by this polluter
        """
        for supported_input_type, output_type in self.type_mapping.items():
            if isinstance(input_type, supported_input_type):
                return output_type
        raise ValueError(
            f"Input type {input_type} is not supported by this polluter. "
            f"Supported types: {list(self.type_mapping.keys())}"
        )

    def apply(
        self,
        df: pl.DataFrame,
        target_columns: Union[str, List[str]],
        level: str = "column",
    ) -> pl.DataFrame:
        """
        Applies the pollution to the data.

        Args:
            df: Input DataFrame
            target_columns: Column(s) to apply transformation to
            level: Level of application ('column', 'cell', or 'row')

        Returns:
            Modified DataFrame
        """
        df = df.clone()

        # Convert single column to list
        if isinstance(target_columns, str):
            target_columns = [target_columns]

        # Validate all column types before proceeding
        for column in target_columns:
            input_type = df[column].dtype
            self._validate_column_type(input_type)

        if level == "column":
            for column in target_columns:
                df = self._apply_column(df, column)
        elif level == "cell":
            pass
        elif level == "row":
            pass
        else:
            raise ValueError(f"Invalid level: {level}. Must be 'column', 'cell', or 'row'")

        return df

    def _apply_column(self, df: pl.DataFrame, column: str) -> pl.DataFrame:
        """Apply transformation to the entire column."""
        input_type = df[column].dtype
        output_type = self._validate_column_type(input_type)

        return df.with_columns(
            pl.col(column)
            .map_elements(self.transformation, return_dtype=output_type)
            .alias(column)
        )

    def _apply_row(self, df: pl.DataFrame, column: str) -> pl.DataFrame:
        """Apply transformation to specific rows."""
        # Implementation will depend on specific needs
        pass

    def _apply_cell(self, df: pl.DataFrame, column: str, ratio: float = 0.1) -> pl.DataFrame:
        """Apply transformation to random cells in the column."""
        # Implementation will depend on specific needs
        pass


# Example subclasses with different type mappings:


class NumericPolluter(BasePolluter):
    """Polluter that handles numeric types and preserves their type."""

    def _get_type_mapping(self) -> Dict[Type[pl.DataType], Type[pl.DataType]]:
        return {pl.Int64: pl.Int64, pl.Float64: pl.Float64}


class StringifyPolluter(BasePolluter):
    """Polluter that converts any supported input to strings."""

    def _get_type_mapping(self) -> Dict[Type[pl.DataType], Type[pl.DataType]]:
        return {pl.Int64: pl.Utf8, pl.Float64: pl.Utf8, pl.Utf8: pl.Utf8, pl.Boolean: pl.Utf8}


class FloatPolluter(BasePolluter):
    """Polluter that converts numeric inputs to float."""

    def _get_type_mapping(self) -> Dict[Type[pl.DataType], Type[pl.DataType]]:
        return {pl.Int64: pl.Float64, pl.Float64: pl.Float64}


# Example usage:
if __name__ == "__main__":
    # Create sample data
    df = pl.DataFrame(
        {
            "int_col": [1, 2, 3, 4, 5],
            "float_col": [1.1, 2.2, 3.3, 4.4, 5.5],
            "str_col": ["a", "b", "c", "d", "e"],
        }
    )

    # Example numeric polluter that doubles values
    numeric_polluter = NumericPolluter(lambda x: x * 2)
    result_numeric = numeric_polluter.apply(df, ["int_col", "float_col"])
    print("Numeric pollution result:")
    print(result_numeric)

    # Example stringify polluter that adds prefix
    stringify_polluter = StringifyPolluter(lambda x: f"prefix_{x}")
    result_string = stringify_polluter.apply(df, ["int_col", "str_col"])
    print("\nStringify pollution result:")
    print(result_string)

    # Example float polluter that adds 0.5
    float_polluter = FloatPolluter(lambda x: float(x) + 0.5)
    result_float = float_polluter.apply(df, ["int_col", "float_col"])
    print("\nFloat pollution result:")
    print(result_float)

Numeric pollution result:
shape: (5, 3)
┌─────────┬───────────┬─────────┐
│ int_col ┆ float_col ┆ str_col │
│ ---     ┆ ---       ┆ ---     │
│ i64     ┆ f64       ┆ str     │
╞═════════╪═══════════╪═════════╡
│ 2       ┆ 2.2       ┆ a       │
│ 4       ┆ 4.4       ┆ b       │
│ 6       ┆ 6.6       ┆ c       │
│ 8       ┆ 8.8       ┆ d       │
│ 10      ┆ 11.0      ┆ e       │
└─────────┴───────────┴─────────┘

Stringify pollution result:
shape: (5, 3)
┌──────────┬───────────┬──────────┐
│ int_col  ┆ float_col ┆ str_col  │
│ ---      ┆ ---       ┆ ---      │
│ str      ┆ f64       ┆ str      │
╞══════════╪═══════════╪══════════╡
│ prefix_1 ┆ 1.1       ┆ prefix_a │
│ prefix_2 ┆ 2.2       ┆ prefix_b │
│ prefix_3 ┆ 3.3       ┆ prefix_c │
│ prefix_4 ┆ 4.4       ┆ prefix_d │
│ prefix_5 ┆ 5.5       ┆ prefix_e │
└──────────┴───────────┴──────────┘

Float pollution result:
shape: (5, 3)
┌─────────┬───────────┬─────────┐
│ int_col ┆ float_col ┆ str_col │
│ ---     ┆ ---       ┆ ---     │
│ f64 

Expr.map_elements is significantly slower than the native expressions API.
Only use if you absolutely CANNOT implement your logic otherwise.
Replace this expression...
  - pl.col("int_col").map_elements(lambda x: ...)
with this one instead:
  + pl.col("int_col") * 2

  pl.col(column)
Expr.map_elements is significantly slower than the native expressions API.
Only use if you absolutely CANNOT implement your logic otherwise.
Replace this expression...
  - pl.col("float_col").map_elements(lambda x: ...)
with this one instead:
  + pl.col("float_col") * 2

  pl.col(column)
Expr.map_elements is significantly slower than the native expressions API.
Only use if you absolutely CANNOT implement your logic otherwise.
Replace this expression...
  - pl.col("int_col").map_elements(lambda x: ...)
with this one instead:
  + pl.col("int_col").cast(pl.Float64) + 0.5

  pl.col(column)
Expr.map_elements is significantly slower than the native expressions API.
Only use if you absolutely CANNOT implement you

In [2]:
class ScientificNotationPolluter(BasePolluter):
    def __init__(self):
        super().__init__(self.scientific_notation_transform)

    @staticmethod
    def scientific_notation_transform(value):
        if isinstance(value, (int, float)):
            return f"{value:.2e}"
        return value


class GaussianNoisePolluter(BasePolluter):
    def __init__(self, mean: float = 0, std_dev: float = 1):
        self.mean = mean
        self.std_dev = std_dev
        super().__init__(self.add_gaussian_noise)

    def add_gaussian_noise(self, value):
        if isinstance(value, (int, float)):
            return value + np.random.normal(self.mean, self.std_dev)
        return value


class RoundingPolluter(BasePolluter):
    def __init__(self, decimal_places: int = 0):
        self.decimal_places = decimal_places
        super().__init__(self.round_transform)

    def round_transform(self, value):
        if isinstance(value, float):
            return round(value, self.decimal_places)
        return value

In [3]:
# Create a sample Polars DataFrame
df = pl.DataFrame(
    {
        "id": [1, 2, 3, 4],
        "value": [1234.56, 2345.67, 3456.78, 4567.89],
        "description": ["a", "b", "c", "d"],
    }
)

In [4]:
# Apply scientific notation pollution to the entire 'value' column
scientific_polluter = ScientificNotationPolluter()
polluted_df = scientific_polluter.apply(df, target_column="value", level="column")
print(polluted_df)

gaussian_polluter = GaussianNoisePolluter(mean=0, std_dev=0.1)
polluted_df = gaussian_polluter.apply(df, target_column="value", level="column")
print(polluted_df)

rounding_polluter = RoundingPolluter(decimal_places=1)
polluted_df = rounding_polluter.apply(df, target_column="value", level="column")
print(polluted_df)

shape: (4, 3)
┌─────┬──────────┬─────────────┐
│ id  ┆ value    ┆ description │
│ --- ┆ ---      ┆ ---         │
│ i64 ┆ str      ┆ str         │
╞═════╪══════════╪═════════════╡
│ 1   ┆ 1.23e+03 ┆ a           │
│ 2   ┆ 2.35e+03 ┆ b           │
│ 3   ┆ 3.46e+03 ┆ c           │
│ 4   ┆ 4.57e+03 ┆ d           │
└─────┴──────────┴─────────────┘
shape: (4, 3)
┌─────┬───────┬─────────────┐
│ id  ┆ value ┆ description │
│ --- ┆ ---   ┆ ---         │
│ i64 ┆ str   ┆ str         │
╞═════╪═══════╪═════════════╡
│ 1   ┆ null  ┆ a           │
│ 2   ┆ null  ┆ b           │
│ 3   ┆ null  ┆ c           │
│ 4   ┆ null  ┆ d           │
└─────┴───────┴─────────────┘
shape: (4, 3)
┌─────┬───────┬─────────────┐
│ id  ┆ value ┆ description │
│ --- ┆ ---   ┆ ---         │
│ i64 ┆ str   ┆ str         │
╞═════╪═══════╪═════════════╡
│ 1   ┆ null  ┆ a           │
│ 2   ┆ null  ┆ b           │
│ 3   ┆ null  ┆ c           │
│ 4   ┆ null  ┆ d           │
└─────┴───────┴─────────────┘


In [5]:
import polars as pl

# Create sample data
df = pl.DataFrame({"A": [1, 2, 3, 4], "B": [10, 20, 30, 40], "C": [100, 200, 300, 400]})


# Define a sample function
def multiply_by_two(x):
    return x * 2


# 1. Apply to a specific column
def apply_to_column():
    # Using with_columns
    result = df.with_columns(pl.col("A").map_elements(multiply_by_two).alias("A_doubled"))

    # Alternative using map_elements
    result2 = df.select([pl.col("A").map_elements(multiply_by_two).alias("A_doubled")])

    return result, result2


# 2. Apply to all columns
def apply_to_all_columns():
    # Using with_columns and wildcard selector
    result = df.with_columns(pl.col("A").map_elements(multiply_by_two).alias("A_doubled"))

    return result


# 3. Apply to a specific row
def apply_to_row(row_idx):
    # Using rows() to get specific row as series
    row = df.row(row_idx)
    modified_row = pl.Series([multiply_by_two(val) for val in row])

    # To modify a specific row in the dataframe
    mask = pl.arange(0, df.height) == row_idx
    result = df.with_columns(
        [pl.when(mask).then(pl.col("*").map_elements(multiply_by_two)).otherwise(pl.col("*"))]
    )

    return modified_row, result


# Example usage
print("Original DataFrame:")
print(df)

print("\nApply to column 'A':")
col_result, col_result2 = apply_to_column()
print(col_result)

print("\nApply to row 1:")
_, row_result = apply_to_row(1)
print(row_result)


Original DataFrame:
shape: (4, 3)
┌─────┬─────┬─────┐
│ A   ┆ B   ┆ C   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 1   ┆ 10  ┆ 100 │
│ 2   ┆ 20  ┆ 200 │
│ 3   ┆ 30  ┆ 300 │
│ 4   ┆ 40  ┆ 400 │
└─────┴─────┴─────┘

Apply to column 'A':
shape: (4, 4)
┌─────┬─────┬─────┬───────────┐
│ A   ┆ B   ┆ C   ┆ A_doubled │
│ --- ┆ --- ┆ --- ┆ ---       │
│ i64 ┆ i64 ┆ i64 ┆ i64       │
╞═════╪═════╪═════╪═══════════╡
│ 1   ┆ 10  ┆ 100 ┆ 2         │
│ 2   ┆ 20  ┆ 200 ┆ 4         │
│ 3   ┆ 30  ┆ 300 ┆ 6         │
│ 4   ┆ 40  ┆ 400 ┆ 8         │
└─────┴─────┴─────┴───────────┘

Apply to row 1:
shape: (4, 3)
┌─────┬─────┬─────┐
│ A   ┆ B   ┆ C   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 1   ┆ 10  ┆ 100 │
│ 4   ┆ 40  ┆ 400 │
│ 3   ┆ 30  ┆ 300 │
│ 4   ┆ 40  ┆ 400 │
└─────┴─────┴─────┘


Expr.map_elements is significantly slower than the native expressions API.
Only use if you absolutely CANNOT implement your logic otherwise.
Replace this expression...
  - pl.col("A").map_elements(multiply_by_two)
with this one instead:
  + pl.col("A") * 2

  result = df.with_columns(pl.col("A").map_elements(multiply_by_two).alias("A_doubled"))
  result = df.with_columns(pl.col("A").map_elements(multiply_by_two).alias("A_doubled"))
Expr.map_elements is significantly slower than the native expressions API.
Only use if you absolutely CANNOT implement your logic otherwise.
Replace this expression...
  - pl.col("A").map_elements(multiply_by_two)
with this one instead:
  + pl.col("A") * 2

  result2 = df.select([pl.col("A").map_elements(multiply_by_two).alias("A_doubled")])
  result2 = df.select([pl.col("A").map_elements(multiply_by_two).alias("A_doubled")])
