In [1]:
// Find total number of people living in the US and India
// Display population in below format

// US   | India  | Total
// ---------------------------
//      |       |

val countryDF=Seq(
    (1, "India"),
    (2, "US"),
    (3, "AU"),
    (4, "UK")
).toDF("country_id", "country_name")

val populationDF=Seq(
    (1, 1, 230000),
    (1, 2, 470000),
    (1, 3, 650000),
    (2, 1, 247000),
    (3, 1, 153000),
    (2, 2, 212000),
    (3, 2, 517000),
    (4, 1, 820000),
).toDF("country_id", "city_id", "population")

countryDF.show()
populationDF.show()


Intitializing Scala interpreter ...

Spark Web UI available at http://192.168.0.100:4040
SparkContext available as 'sc' (version = 3.4.1, master = local[*], app id = local-1704953514742)
SparkSession available as 'spark'


+----------+------------+
|country_id|country_name|
+----------+------------+
|         1|       India|
|         2|          US|
|         3|          AU|
|         4|          UK|
+----------+------------+

+----------+-------+----------+
|country_id|city_id|population|
+----------+-------+----------+
|         1|      1|    230000|
|         1|      2|    470000|
|         1|      3|    650000|
|         2|      1|    247000|
|         3|      1|    153000|
|         2|      2|    212000|
|         3|      2|    517000|
|         4|      1|    820000|
+----------+-------+----------+



countryDF: org.apache.spark.sql.DataFrame = [country_id: int, country_name: string]
populationDF: org.apache.spark.sql.DataFrame = [country_id: int, city_id: int ... 1 more field]


In [26]:
// DF Approach

val resultDF = countryDF.as("c").join(
        populationDF.filter($"country_id".isin(1,2)).groupBy($"country_id").sum("population"
         ).as("p"), $"c.country_id" === $"p.country_id", "inner"
).withColumnRenamed("sum(population)", "total_population"
).withColumn("US", when($"country_name"===lit("US"), $"total_population")
).withColumn("India", when($"country_name"===lit("India"), $"total_population")
).drop("country_id","country_name","country_id","total_population"
).agg(sum("US").as("US"),sum("India").as("India")
).withColumn("Total", $"US"+$"India")


resultDF.show(false)

+------+-------+-------+
|US    |India  |Total  |
+------+-------+-------+
|459000|1350000|1809000|
+------+-------+-------+



resultDF: org.apache.spark.sql.DataFrame = [US: bigint, India: bigint ... 1 more field]


In [40]:
// Spark SQL Approach

countryDF.createOrReplaceTempView("country")
populationDF.createOrReplaceTempView("population")

spark.sql("""
WITH population_cte AS (
    SELECT 
        country_id, 
        SUM(population) total_population
    FROM population 
    WHERE country_id in (1,2)
    GROUP BY country_id
   ),
   joined_cte AS (
       SELECT 
            SUM(CASE WHEN country_name='India' THEN total_population ELSE 0 END) as India,
            SUM(CASE WHEN country_name='US' THEN total_population ELSE 0 END) as US,
            India + US as Total
       FROM population_cte p INNER JOIN country c
       ON p.country_id=c.country_id
   )
   SELECT * FROM joined_cte
""").show(false)

+-------+------+-------+
|India  |US    |Total  |
+-------+------+-------+
|1350000|459000|1809000|
+-------+------+-------+

