- Author: Ben Du
- Date: 2020-06-18 08:41:46
- Title: The Case Statement and the when Function in Spark
- Slug: spark-dataframe-case-when
- Category: Computer Science
- Tags: programming, Scala, Spark, DataFrame, case, when, Spark SQL, functions

In [None]:
import pandas as pd
import findspark
findspark.init("/opt/spark")

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import *
from pyspark.sql.types import StructType
# Create (or reuse) a Hive-enabled SparkSession for the PySpark variant of the examples.
spark = (
    SparkSession.builder
    .appName("PySpark Example")
    .enableHiveSupport()
    .getOrCreate()
)

In [2]:
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Local SparkSession with two worker threads; its implicits (imported below)
// provide the `$"col"` syntax and `.toDF()` used by the following cells.
// The placeholder `spark.some.config.option` setting copied from the Spark
// getting-started guide was dropped: nothing reads it.
val spark = SparkSession.builder()
    .master("local[2]")
    .appName("Spark Column Example")
    .getOrCreate()

import spark.implicits._

org.apache.spark.sql.SparkSession$implicits$@561a187f

In [8]:
// Load the sample people records: columns `age` and `name`, where Michael's `age` is null.
// `show()` keeps its parentheses because it is side-effecting (auto-application is
// deprecated in Scala 2.13 and removed in Scala 3).
val df = spark.read.json("../../data/people.json")
df.show()

+----+-------+
| age|   name|
+----+-------+
|null|Michael|
|  30|   Andy|
|  19| Justin|
+----+-------+



null

In [13]:
// Comparing a null `age` with `>=` yields null, which `when` treats as "not matched",
// so Michael's row falls through to `otherwise(null)`.
df.withColumn("null_gt",
    when($"age" >= 0, 1).otherwise(null)
).show()

+----+-------+-------+
| age|   name|null_gt|
+----+-------+-------+
|null|Michael|   null|
|  30|   Andy|      1|
|  19| Justin|      1|
+----+-------+-------+



In [14]:
// The same holds for `<=`: a null `age` compared with 1000 is null (not true),
// so Michael's row again gets the `otherwise(null)` value.
df.withColumn("null_lt",
    when($"age" <= 1000, 1).otherwise(null)
).show()

+----+-------+-------+
| age|   name|null_lt|
+----+-------+-------+
|null|Michael|   null|
|  30|   Andy|      1|
|  19| Justin|      1|
+----+-------+-------+



In [4]:
// A CASE WHEN chain: conditions are tried in order and the first match wins;
// rows that match none of them (Justin, age 19) get the `otherwise` value.
df.select(
    when($"age".isNull, 0)
        .when($"age" > 20, 100)
        .otherwise(10)
        .alias("age")
).show()

+---+
|age|
+---+
|  0|
|100|
| 10|
+---+



In [5]:
// Without `otherwise`, rows that match no `when` condition produce null.
df.select(when($"age".isNull, 0).alias("age")).show()

+----+
| age|
+----+
|   0|
|null|
|null|
+----+



In [6]:
// A single-column DataFrame (column `value`) holding the integers 0 through 9.
// `toDF()`/`show()` are written with parentheses rather than via auto-application.
val df = Range(0, 10).toDF()
df.show()

+-----+
|value|
+-----+
|    0|
|    1|
|    2|
|    3|
|    4|
|    5|
|    6|
|    7|
|    8|
|    9|
+-----+



null

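Notice that chained `when` calls behave like an `if` / `else if` chain: the conditions are checked in order and the first one that matches determines the value. As the previous example shows, rows that match no condition get `null` unless `otherwise` is supplied.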

In [7]:
// The first matching `when` wins: 0-3 satisfy `<= 3` and land in group 0;
// 4-9 fail that test but satisfy `<= 100`, so they land in group 1.
df.withColumn("group",
    when($"value" <= 3, 0)
    .when($"value" <= 100, 1)
).show()

+-----+-----+
|value|group|
+-----+-----+
|    0|    0|
|    1|    0|
|    2|    0|
|    3|    0|
|    4|    1|
|    5|    1|
|    6|    1|
|    7|    1|
|    8|    1|
|    9|    1|
+-----+-----+

