In [1]:
from pyspark import SparkConf
from pyspark.sql import SparkSession

# Local Spark session using every available core. A single shuffle partition
# is enough here: the demo DataFrame below holds only a handful of rows.
conf = SparkConf()
conf.setMaster("local[*]")
conf.setAppName("testing")
conf.set("spark.sql.shuffle.partitions", "1")  # SparkConf stores values as strings
spark = SparkSession.builder.config(conf=conf).getOrCreate()
spark

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/07/15 08:36:07 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
from pyspark.sql import Row

# Toy dataset of (age, name) pairs; Spark infers the schema as
# `age: bigint, name: string` (see the displayed output below).
people = [
    (12, "Damian"),
    (15, "Jake"),
    (18, "Dominic"),
    (20, "John"),
    (27, "Jerry"),
    (101, "Jerry's Grandpa"),
]
df = spark.createDataFrame([Row(age=age, name=name) for age, name in people])
df

DataFrame[age: bigint, name: string]

In [4]:
from pyspark.ml.feature import Bucketizer
from pyspark.sql import DataFrame as SparkDataFrame
from typing import Dict, Optional, Any
from pyspark.sql import functions as sf
from pyspark.sql.column import Column
from itertools import chain

class AgeToRange(Bucketizer):
    """Bucketize an age column and replace each bucket index with a life-phase label.

    The splits are fixed in ``__init__``; any ``splits`` passed to the
    constructor are overwritten. After the inherited Bucketizer assigns a
    bucket index to every row, ``outputCol`` is rewritten to the matching
    label from ``_mapping`` via a literal Spark map expression (no UDF, no join).

    Usage::

        AgeToRange(inputCol="age", outputCol="phase").transform(df)
    """

    def __init__(self, *args, **kwargs):
        # Forwards inputCol/outputCol/etc. to Bucketizer. NOTE(review): Bucketizer's
        # constructor is presumably keyword-only, so *args should stay empty.
        super(AgeToRange, self).__init__(*args, **kwargs)

        # Bucket i covers [splits[i], splits[i + 1]), e.g. age 12 -> bucket 2 ("Teenager").
        self.setSplits([-float("inf"), 0, 12, 18, 25, 70, float("inf")])

        # Bucket index -> label. Plain strings rather than pre-built `sf.lit`
        # Columns, so no Spark expressions are created until transform time.
        self._mapping: Dict[int, str] = {
            0: "Not yet born",
            1: "Child",
            2: "Teenager",
            3: "Young adulthood",
            4: "Adult",
            # NOTE(review): same label as bucket 4, so the split at 70 has no
            # visible effect — confirm whether a distinct label (e.g. "Senior")
            # was intended.
            5: "Adult",
        }

        # Every bucket index the splits can produce needs exactly one label.
        # Raised explicitly because `assert` is stripped under `python -O`.
        n_buckets = len(self.getSplits()) - 1
        if set(self._mapping) != set(range(n_buckets)):
            raise ValueError(
                f"AgeToRange labels cover buckets {sorted(self._mapping)}, "
                f"but the splits define buckets 0..{n_buckets - 1}"
            )

    def _transform(self, dataset: SparkDataFrame) -> SparkDataFrame:
        """Bucketize ``inputCol``, then replace each bucket index in ``outputCol`` with its label.

        The public ``transform(dataset, params=None)`` is inherited from pyspark's
        ``Transformer``, which dispatches here — on a param-updated copy when
        ``params`` is given. Overriding ``_transform`` (instead of ``transform``)
        therefore makes ``getOutputCol()`` below honor an ``outputCol`` override
        passed through ``params``.

        :param dataset: DataFrame containing ``inputCol``.
        :returns: ``dataset`` with ``outputCol`` holding the phase label string.
        """
        bucketed: SparkDataFrame = super()._transform(dataset)
        output_col = self.getOutputCol()
        buckets: Column = bucketed[output_col]

        # create_map expects alternating key/value columns: k0, v0, k1, v1, ...
        range_map = chain.from_iterable(self._mapping.items())
        range_mapper: Column = sf.create_map([sf.lit(x) for x in range_map])
        return bucketed.withColumn(output_col, range_mapper[buckets])

In [5]:
# Map each person's age to a life-phase label in a new "phase" column.
age_to_range = AgeToRange(inputCol="age", outputCol="phase")
age_to_range.transform(df).show()

+---+---------------+---------------+
|age|           name|          phase|
+---+---------------+---------------+
| 12|         Damian|       Teenager|
| 15|           Jake|       Teenager|
| 18|        Dominic|Young adulthood|
| 20|           John|Young adulthood|
| 27|          Jerry|          Adult|
|101|Jerry's Grandpa|          Adult|
+---+---------------+---------------+



In [None]:
# TODO: Benchmark this map-based approach against counterexamples built with UDFs and JOINs