- Author: Ben Du
- Date: 2020-06-17
- Title: The Case Statement and the when Function in Spark
- Slug: spark-dataframe-case-when
- Category: Computer Science
- Tags: programming, Scala, Spark, DataFrame, case, when, Spark SQL, functions

```python
import pandas as pd
import findspark

# Locate the Spark installation before importing pyspark.
findspark.init("/opt/spark")

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when

spark = SparkSession.builder.appName("Case/When") \
    .enableHiveSupport().getOrCreate()
```

```python
# The age column is numeric, so pandas stores the missing age as NaN, not None.
df_p = pd.DataFrame({
    "age": [None, 30, 19],
    "name": ["Michael", "Andy", "Justin"]
})
```

```python
df = spark.createDataFrame(df_p)
df.show()
```

```
+----+-------+
| age|   name|
+----+-------+
| NaN|Michael|
|30.0|   Andy|
|19.0| Justin|
+----+-------+
```



```python
df.schema
```

```
StructType(List(StructField(age,DoubleType,true),StructField(name,StringType,true)))
```

```python
# Michael's age is NaN rather than null, so no row matches.
df.filter(col("age").isNull()).show()
```

```
+---+----+
|age|name|
+---+----+
+---+----+
```

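The empty result above comes from the pandas round trip: pandas stores the missing age as `NaN` in a float column, and Spark keeps `NaN` as an ordinary double value, so `isNull()` does not match it (PySpark's `isnan` function would). Spark also orders `NaN` above every other number, which is why comparisons such as `age > 20` turn out true for Michael. The sketch below is a pure-Python model of those two rules, not Spark code; the helpers `spark_is_null` and `spark_gt` are hypothetical names chosen for illustration, with `None` standing in for SQL `NULL`.

```python
import math

# Hypothetical helpers modeling Spark's rules for a double column.
# None plays the role of SQL NULL; float("nan") plays the role of NaN.

def spark_is_null(v):
    """isNull() is true only for NULL, never for NaN."""
    return v is None

def spark_gt(a, b):
    """Model of a > b under Spark's NaN semantics.

    NULL on either side gives NULL (None). NaN is larger than any other
    number, and NaN is not greater than NaN (they compare as equal).
    """
    if a is None or b is None:
        return None
    if math.isnan(a) or math.isnan(b):
        return math.isnan(a) and not math.isnan(b)
    return a > b

ages = [float("nan"), 30.0, 19.0]  # what pandas hands to Spark
print([spark_is_null(a) for a in ages])   # [False, False, False]
print([spark_gt(a, 20.0) for a in ages])  # [True, True, False]
```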


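Spark SQL accepts both a column alias and a column position in `group by`. The first query below groups by the alias `age_group`; the second uses `1`, the position of the first column in the select list. These behaviors are controlled by the configuration options `spark.sql.groupByAliases` and `spark.sql.groupByOrdinal`, both of which default to `true`. Note that Michael, whose age is `NaN`, lands in group 1: Spark treats `NaN` as larger than any other number.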

```python
df.createOrReplaceTempView("df")
```

```python
spark.sql("""
    select
        case
            when age > 20 then 1
            else 0
        end as age_group,
        count(*) as n
    from
        df
    group by
        age_group
    """).show()
```

```
+---------+---+
|age_group|  n|
+---------+---+
|        1|  2|
|        0|  1|
+---------+---+
```



```python
# "group by 1" refers to the first column of the select list, i.e. age_group.
spark.sql("""
    select
        case
            when age > 20 then 1
            else 0
        end as age_group,
        count(*) as n
    from
        df
    group by
        1
    """).show()
```

```
+---------+---+
|age_group|  n|
+---------+---+
|        1|  2|
|        0|  1|
+---------+---+
```

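Both queries group on the value of the CASE expression: each row is first mapped to a key, and rows are then counted per key. A minimal pure-Python model of that, assuming Spark's rule that `NaN` compares greater than any other number, reproduces the counts shown. The function `age_group` is a hypothetical stand-in for the SQL expression, not part of any API.

```python
import math
from collections import Counter

def age_group(age):
    """Model of `case when age > 20 then 1 else 0 end` on a double column.

    A NULL age makes the condition NULL, which falls through to else.
    NaN counts as greater than 20 because Spark orders NaN above all numbers.
    """
    if age is None:
        return 0
    if math.isnan(age) or age > 20:
        return 1
    return 0

ages = [float("nan"), 30.0, 19.0]  # Michael, Andy, Justin
counts = Counter(age_group(a) for a in ages)
print(sorted(counts.items()))  # [(0, 1), (1, 2)]
```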


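```python
# Michael's age is NaN, which Spark treats as greater than any number,
# so `age >= 0` is true for him as well.
df.withColumn("null_gt",
    when(col("age") >= 0, 1).otherwise(None)
).show()
```

```
+----+-------+-------+
| age|   name|null_gt|
+----+-------+-------+
| NaN|Michael|      1|
|30.0|   Andy|      1|
|19.0| Justin|      1|
+----+-------+-------+
```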



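The remaining cells use the Scala API on a DataFrame `df` whose missing age is a genuine `null` (note the integer ages and the `null` in the output). A comparison such as `age <= 1000` therefore evaluates to `null` for Michael, and `when` treats a `null` condition as not satisfied.

```scala
import org.apache.spark.sql.functions.when
import spark.implicits._  // enables the $"col" syntax

df.withColumn("null_lt",
    when($"age" <= 1000, 1).otherwise(null)
).show
```

```
+----+-------+-------+
| age|   name|null_lt|
+----+-------+-------+
|null|Michael|   null|
|  30|   Andy|      1|
|  19| Justin|      1|
+----+-------+-------+
```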

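The `null_lt` column follows from SQL's three-valued logic. In the pure-Python sketch below, `None` stands in for SQL `NULL`; `sql_le` and `when_otherwise` are hypothetical helpers, not Spark functions. A comparison involving `NULL` yields `NULL`, and only a condition that is exactly true selects the `when` value, so Michael falls through to `otherwise(null)`.

```python
def sql_le(a, b):
    """a <= b in SQL: NULL (None) if either operand is NULL."""
    if a is None or b is None:
        return None
    return a <= b

def when_otherwise(cond, value, default):
    """Only a condition that is exactly True selects `value`;
    False and NULL (None) both fall through to `default`."""
    return value if cond is True else default

ages = [None, 30, 19]  # Michael, Andy, Justin
null_lt = [when_otherwise(sql_le(a, 1000), 1, None) for a in ages]
print(null_lt)  # [None, 1, 1]
```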


```scala
df.select(when($"age".isNull, 0).when($"age" > 20, 100).otherwise(10).alias("age")).show
```

```
+---+
|age|
+---+
|  0|
|100|
| 10|
+---+
```



```scala
// No otherwise: rows that match no condition get null.
df.select(when($"age".isNull, 0).alias("age")).show
```

```
+----+
| age|
+----+
|   0|
|null|
|null|
+----+
```



```scala
val df = Range(0, 10).toDF
df.show
```

```
+-----+
|value|
+-----+
|    0|
|    1|
|    2|
|    3|
|    4|
|    5|
|    6|
|    7|
|    8|
|    9|
+-----+
```




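Chained `when` calls behave like an `if` / `else if` chain: the conditions are checked in order and the first one that holds determines the value. In the example below, 3 satisfies both `value <= 3` and `value <= 100`, yet it lands in group 0. Because there is no `otherwise`, any value above 100 would get `null`.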

```scala
df.withColumn("group",
    when($"value" <= 3, 0)
    .when($"value" <= 100, 1)
).show
```

```
+-----+-----+
|value|group|
+-----+-----+
|    0|    0|
|    1|    0|
|    2|    0|
|    3|    0|
|    4|    1|
|    5|    1|
|    6|    1|
|    7|    1|
|    8|    1|
|    9|    1|
+-----+-----+
```

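The if / else-if behavior of chained `when` calls can be modeled in a few lines of plain Python. This is a sketch, not Spark's implementation: `case_when` is a hypothetical helper that takes (condition, value) pairs in order, returns the value of the first condition that is exactly `True`, treats `None` (SQL `NULL`) as not matching, and falls back to `otherwise`, which defaults to `None` just as a missing `.otherwise(...)` yields `null`.

```python
def case_when(branches, otherwise=None):
    """Return the value of the first branch whose condition is True.

    `branches` is a list of (condition, value) pairs checked in order.
    A condition of False or None (SQL NULL) does not match. If nothing
    matches, `otherwise` is returned (None mirrors a missing .otherwise()).
    """
    for condition, value in branches:
        if condition is True:
            return value
    return otherwise

# Mirror of: when($"value" <= 3, 0).when($"value" <= 100, 1)
groups = [case_when([(v <= 3, 0), (v <= 100, 1)]) for v in range(10)]
print(groups)  # [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]

# Mirror of: when($"age".isNull, 0) with no otherwise
ages = [None, 30, 19]
print([case_when([(a is None, 0)]) for a in ages])  # [0, None, None]
```

Checking the conditions in list order and returning early is what makes the first match win, which is exactly why 3 ends up in group 0.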
