## Count M&Ms

In [1]:
# Prerequisites
from pyspark.sql import SparkSession
from pyspark.sql.functions import * 

In [2]:
# Reuse the already-running SparkSession if one exists; otherwise start one
# with the "local" master (Spark runs inside this process, no cluster needed).
spark = (SparkSession
         .builder
         .master("local")
         .getOrCreate())
print("Spark Version: ", spark.version)

Spark Version:  3.5.0


In [8]:
# Load the M&M dataset: the first line of the file provides the column names,
# and Spark inspects the data to choose each column's type.
df_mnm = spark.read.csv(
    "data/mnm_dataset.csv",
    header=True,
    inferSchema=True,
)

display(df_mnm)

DataFrame[State: string, Color: string, Count: int]

In [9]:
# Preview the first 20 rows (the default row limit of show()).
df_mnm.show(n=20)

+-----+------+-----+
|State| Color|Count|
+-----+------+-----+
|   TX|   Red|   20|
|   NV|  Blue|   66|
|   CO|  Blue|   79|
|   OR|  Blue|   71|
|   WA|Yellow|   93|
|   WY|  Blue|   16|
|   CA|Yellow|   53|
|   WA| Green|   60|
|   OR| Green|   71|
|   TX| Green|   68|
|   NV| Green|   59|
|   AZ| Brown|   95|
|   WA|Yellow|   20|
|   AZ|  Blue|   75|
|   OR| Brown|   72|
|   NV|   Red|   98|
|   WY|Orange|   45|
|   CO|  Blue|   52|
|   TX| Brown|   94|
|   CO|   Red|   82|
+-----+------+-----+
only showing top 20 rows



In [10]:
# Size of the full dataset; count() is an action, so this runs a Spark job.
n_rows = df_mnm.count()
print("Total number of rows: ", n_rows)

Total number of rows:  99999


In [11]:
# Print the inferred schema as a tree: column name, type and nullability.
df_mnm.printSchema()

root
 |-- State: string (nullable = true)
 |-- Color: string (nullable = true)
 |-- Count: integer (nullable = true)



In [14]:
# Number of records per color, most frequent color first.
# Note: count("Count") tallies the rows whose Count is non-null in each group;
# it does not add up the Count values themselves.
rows_by_color = df_mnm.select("State", "Color", "Count").groupBy("Color")
df_mnm_count = (rows_by_color
                .agg(count("Count").alias("Total"))
                .orderBy(desc("Total")))

df_mnm_count.show()
# One output row per distinct color, so the row count is the number of colors.
print("# of different colors = ", df_mnm_count.count())

+------+-----+
| Color|Total|
+------+-----+
| Green|16928|
|Yellow|16796|
|Orange|16697|
|   Red|16619|
| Brown|16510|
|  Blue|16449|
+------+-----+

# of different colors =  6


In [16]:
# Per-color record counts for Washington only: keep the "WA" rows, then group.
# As above, count("Count") counts rows rather than summing the Count column.
wa_rows = df_mnm.select("State", "Color", "Count").where(col("State") == "WA")

df_mnm_wa = (wa_rows
             .groupBy("State", "Color")
             .agg(count("Count").alias("Total WA"))
             .orderBy(desc("Total WA")))

df_mnm_wa.show()

+-----+------+--------+
|State| Color|Total WA|
+-----+------+--------+
|   WA| Green|    1779|
|   WA|   Red|    1671|
|   WA| Brown|    1669|
|   WA|Yellow|    1663|
|   WA|Orange|    1658|
|   WA|  Blue|    1625|
+-----+------+--------+



In [18]:
# Bring the first five rows back to the driver as a list of Row objects.
df_mnm.take(5)

[Row(State='TX', Color='Red', Count=20),
 Row(State='NV', Color='Blue', Count=66),
 Row(State='CO', Color='Blue', Count=79),
 Row(State='OR', Color='Blue', Count=71),
 Row(State='WA', Color='Yellow', Count=93)]