# Setup: install and import libraries

In [3]:
!pip install pyspark -q

[K     |████████████████████████████████| 281.3 MB 45 kB/s 
[K     |████████████████████████████████| 199 kB 56.6 MB/s 
[?25h  Building wheel for pyspark (setup.py) ... [?25l[?25hdone


In [4]:
import sys

from pyspark.sql import SparkSession
from pyspark.sql.functions import count

# Build a Spark session using the SparkSession builder API

In [5]:
# Create this notebook's SparkSession, or reuse one already active in the kernel.
spark = (
    SparkSession.builder
    .appName("PythonMnMCount")
    .getOrCreate()
)

# Load M&M dataset

In [10]:
# Location of the M&M CSV, passed to Spark's CSV reader below.
# A relative path is resolved against the notebook's working directory.
dataset_path: str = "mncount.csv"

In [11]:
# Load the CSV into a Spark DataFrame. The first line supplies the column names,
# and Spark infers the column types from the data.
mnm_df = spark.read.options(header="true", inferSchema="true").csv(dataset_path)

In [13]:
# Print the first five records as a text table.
mnm_df.show(n=5)

+-----+------+-----+
|State| Color|Count|
+-----+------+-----+
|   TX|   Red|   20|
|   NV|  Blue|   66|
|   CO|  Blue|   79|
|   OR|  Blue|   71|
|   WA|Yellow|   93|
+-----+------+-----+
only showing top 5 rows



In [14]:
# Bring only the first five rows to the driver and render them as a pandas table.
mnm_df.limit(num=5).toPandas()

Unnamed: 0,State,Color,Count
0,TX,Red,20
1,NV,Blue,66
2,CO,Blue,79
3,OR,Blue,71
4,WA,Yellow,93


# Count M&M records for each state and color, ordered by total in descending order

In [15]:
# Aggregate per (State, Color) group and list the largest groups first.
# NOTE(review): count("Count") tallies the non-null `Count` values in each group,
# i.e. the number of records. It does not add them up. If "Total" is meant to be
# the number of M&Ms, sum("Count") is the aggregation you want.
count_mnm_df = (
    mnm_df[["State", "Color", "Count"]]
    .groupby("State", "Color")
    .agg(count("Count").alias("Total"))
    .sort("Total", ascending=False)
)

In [17]:
# Show all groups without truncating cell values.
# 60 matches count_mnm_df.count() as reported in the following cell.
count_mnm_df.show(60, False)

+-----+------+-----+
|State|Color |Total|
+-----+------+-----+
|CA   |Yellow|1807 |
|WA   |Green |1779 |
|OR   |Orange|1743 |
|TX   |Green |1737 |
|TX   |Red   |1725 |
|CA   |Green |1723 |
|CO   |Yellow|1721 |
|CA   |Brown |1718 |
|CO   |Green |1713 |
|NV   |Orange|1712 |
|TX   |Yellow|1703 |
|NV   |Green |1698 |
|AZ   |Brown |1698 |
|WY   |Green |1695 |
|CO   |Blue  |1695 |
|NM   |Red   |1690 |
|AZ   |Orange|1689 |
|NM   |Yellow|1688 |
|NM   |Brown |1687 |
|UT   |Orange|1684 |
|NM   |Green |1682 |
|UT   |Red   |1680 |
|AZ   |Green |1676 |
|NV   |Yellow|1675 |
|NV   |Blue  |1673 |
|WA   |Red   |1671 |
|WY   |Red   |1670 |
|WA   |Brown |1669 |
|NM   |Orange|1665 |
|WY   |Blue  |1664 |
|WA   |Yellow|1663 |
|WA   |Orange|1658 |
|CA   |Orange|1657 |
|NV   |Brown |1657 |
|CA   |Red   |1656 |
|CO   |Brown |1656 |
|UT   |Blue  |1655 |
|AZ   |Yellow|1654 |
|TX   |Orange|1652 |
|AZ   |Red   |1648 |
|OR   |Blue  |1646 |
|UT   |Yellow|1645 |
|OR   |Red   |1645 |
|CO   |Orange|1642 |
|TX   |Brown 

In [20]:
# count_mnm_df has one row per (State, Color) pair, not one per state, so label it accordingly.
print(f"Total rows (state, color pairs): {count_mnm_df.count()}")

Total Rows (state): 60


# Aggregate for a single state (here, CA)

In [21]:
# Per-color record counts for California only, largest first.
# Filtering on mnm_df before projecting selects exactly the same rows and columns.
ca_count_mnm_df = (
    mnm_df
    .filter(mnm_df["State"] == "CA")
    .select("State", "Color", "Count")
    .groupBy("State", "Color")
    .agg(count("Count").alias("Total"))
    .sort("Total", ascending=False)
)

In [23]:
# Display up to ten groups with full cell values.
# CA has fewer groups than that, so every group is shown.
ca_count_mnm_df.show(10, truncate=False)

+-----+------+-----+
|State|Color |Total|
+-----+------+-----+
|CA   |Yellow|1807 |
|CA   |Green |1723 |
|CA   |Brown |1718 |
|CA   |Orange|1657 |
|CA   |Red   |1656 |
|CA   |Blue  |1603 |
+-----+------+-----+



In [24]:
# Shut down the SparkSession created at the top of the notebook.
# `spark` and the DataFrames derived from it should not be used after this cell.
spark.stop()