# PySpark Essentials

- Spark vs Pandas: Why and when to use Spark
- Introduction to Spark DataFrames: Schema, loading data, inspecting DataFrames
- Distributed computing basics in Spark: Partitions and transformations
- Advanced DataFrame operations:
    - Column and row manipulations
    - Complex filtering and conditional logic
    - GroupBy, aggregations
    - Joins and window functions
- Handling missing data and data types
- Integrating with SQL: Using Spark SQL queries with Python
  

## Setup
Create a Spark session

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.types import (
    StructType,
    StructField,
    IntegerType,
    FloatType,
    StringType,
    DateType,
)
import pyspark.sql.functions as F

In [2]:
### Run locally with four worker threads. Use this for code development and small problems
spark = SparkSession.builder.master("local[4]").getOrCreate()

### Connect to a Spark cluster (e.g. start-connect-server.sh)
# SparkSession.builder.master("local[*]").getOrCreate().stop()
# spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/09/17 19:09:15 WARN Utils: Your hostname, RCBM8368-DIII.local, resolves to a loopback address: 127.0.0.1; using 10.250.4.136 instead (on interface en0)
25/09/17 19:09:15 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/09/17 19:09:16 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable

In [None]:
# spark.stop()

In [3]:
spark.getActiveSession()

In [4]:
from datetime import datetime, date
import pandas as pd
from pyspark.sql import Row

df = spark.createDataFrame([
    Row(a=1, b=2., c='string1', d=date(2000, 1, 1), e=datetime(2000, 1, 1, 12, 0)),
    Row(a=2, b=3., c='string2', d=date(2000, 2, 1), e=datetime(2000, 1, 2, 12, 0)),
    Row(a=4, b=5., c='string3', d=date(2000, 3, 1), e=datetime(2000, 1, 3, 12, 0))
])
df

DataFrame[a: bigint, b: double, c: string, d: date, e: timestamp]

In [5]:
df.count()

                                                                                

3

## Spark vs Pandas: Why and when to use Spark

## Distributed computing basics in Spark: Partitions and transformations

## Data

![](northwind-schema.jpeg)

In [6]:
! ls -l data/

total 648
-rw-r--r--@ 1 pmolnar  GSUAD\Domain Users     425 Sep 17 12:33 categories.csv
-rw-r--r--@ 1 pmolnar  GSUAD\Domain Users   11558 Sep 17 12:33 customers.csv
-rw-r--r--@ 1 pmolnar  GSUAD\Domain Users    4085 Sep 17 12:33 employees.csv
-rw-r--r--@ 1 pmolnar  GSUAD\Domain Users  295065 Sep 17 12:33 orders.csv
-rw-r--r--@ 1 pmolnar  GSUAD\Domain Users    4327 Sep 17 12:33 products.csv
-rw-r--r--@ 1 pmolnar  GSUAD\Domain Users    4007 Sep 17 12:33 suppliers.csv


In [7]:
data_path = "./data/"

# Method 1: Auto-inference (simple but slower)
print("=== Method 1: Auto-inference ===")
categories_auto = spark.read.csv(f"{data_path}categories.csv", header=True, inferSchema=True)
categories_auto.printSchema()
categories_auto.show(5)

=== Method 1: Auto-inference ===
root
 |-- CategoryID: integer (nullable = true)
 |-- CategoryName: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Picture: string (nullable = true)

+----------+--------------+--------------------+-------+
|CategoryID|  CategoryName|         Description|Picture|
+----------+--------------+--------------------+-------+
|         1|     Beverages|Soft drinks, coff...|     \x|
|         2|    Condiments|Sweet and savory ...|     \x|
|         3|   Confections|Desserts, candies...|     \x|
|         4|Dairy Products|             Cheeses|     \x|
|         5|Grains/Cereals|Breads, crackers,...|     \x|
+----------+--------------+--------------------+-------+
only showing top 5 rows


In [8]:
# Method 2: Define schema (faster and more control)
print("\n=== Method 2: Predefined Schema ===")

# Categories schema
categories_schema = StructType([
    StructField("CategoryID", IntegerType(), True),
    StructField("CategoryName", StringType(), True),
    StructField("Description", StringType(), True),
    StructField("Picture", StringType(), True)
])

categories_schema_df = spark.read.csv(f"{data_path}categories.csv", header=True, schema=categories_schema)
categories_schema_df.printSchema()
categories_schema_df.show(5)


=== Method 2: Predefined Schema ===
root
 |-- CategoryID: integer (nullable = true)
 |-- CategoryName: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Picture: string (nullable = true)

+----------+--------------+--------------------+-------+
|CategoryID|  CategoryName|         Description|Picture|
+----------+--------------+--------------------+-------+
|         1|     Beverages|Soft drinks, coff...|     \x|
|         2|    Condiments|Sweet and savory ...|     \x|
|         3|   Confections|Desserts, candies...|     \x|
|         4|Dairy Products|             Cheeses|     \x|
|         5|Grains/Cereals|Breads, crackers,...|     \x|
+----------+--------------+--------------------+-------+
only showing top 5 rows


In [9]:
# Method 3: Load as Pandas Dataframe 
print("\n=== Method 3: Load as Pandas Dataframe  ===")

cat_df = pd.read_csv(f"{data_path}categories.csv")
categories_pandas_df = spark.createDataFrame(cat_df)
categories_pandas_df.printSchema()
categories_pandas_df.show(5)


=== Method 3: Load as Pandas Dataframe  ===
root
 |-- CategoryID: long (nullable = true)
 |-- CategoryName: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Picture: string (nullable = true)

+----------+--------------+--------------------+-------+
|CategoryID|  CategoryName|         Description|Picture|
+----------+--------------+--------------------+-------+
|         1|     Beverages|Soft drinks, coff...|     \x|
|         2|    Condiments|Sweet and savory ...|     \x|
|         3|   Confections|Desserts, candies...|     \x|
|         4|Dairy Products|             Cheeses|     \x|
|         5|Grains/Cereals|Breads, crackers,...|     \x|
+----------+--------------+--------------------+-------+
only showing top 5 rows


### Define Schema

In [10]:
# Categories schema
categories_schema = StructType([
    StructField("CategoryID", IntegerType(), True),
    StructField("CategoryName", StringType(), True),
    StructField("Description", StringType(), True),
    StructField("Picture", StringType(), True)
])

# categories_schema_df = spark.read.csv(f"{data_path}categories.csv", header=True, schema=categories_schema)
# categories_schema_df.printSchema()
# categories_schema_df.show(5)

# Customers schema
customers_schema = StructType([
    StructField("CustomerID", StringType(), True),
    StructField("CompanyName", StringType(), True),
    StructField("ContactName", StringType(), True),
    StructField("ContactTitle", StringType(), True),
    StructField("Address", StringType(), True),
    StructField("City", StringType(), True),
    StructField("Region", StringType(), True),
    StructField("PostalCode", StringType(), True),
    StructField("Country", StringType(), True),
    StructField("Phone", StringType(), True),
    StructField("Fax", StringType(), True)
])

# Orders schema
orders_schema = StructType([
    StructField("OrderID", IntegerType(), True),
    StructField("CustomerID", StringType(), True),
    StructField("EmployeeID", IntegerType(), True),
    StructField("OrderDate", DateType(), True),
    StructField("RequiredDate", DateType(), True),
    StructField("ShippedDate", DateType(), True),
    StructField("ShipVia", IntegerType(), True),
    StructField("Freight", FloatType(), True),
    StructField("ShipName", StringType(), True),
    StructField("ShipAddress", StringType(), True),
    StructField("ShipCity", StringType(), True),
    StructField("ShipRegion", StringType(), True),
    StructField("ShipPostalCode", StringType(), True),
    StructField("ShipCountry", StringType(), True)
])

# Products schema
products_schema = StructType([
    StructField("ProductID", IntegerType(), True),
    StructField("ProductName", StringType(), True),
    StructField("SupplierID", IntegerType(), True),
    StructField("CategoryID", IntegerType(), True),
    StructField("QuantityPerUnit", StringType(), True),
    StructField("UnitPrice", FloatType(), True),
    StructField("UnitsInStock", IntegerType(), True),
    StructField("UnitsOnOrder", IntegerType(), True),
    StructField("ReorderLevel", IntegerType(), True),
    StructField("Discontinued", IntegerType(), True)
])


# Employees schema
employees_schema = StructType([
    StructField("EmployeeID", IntegerType(), True),
    StructField("LastName", StringType(), True),
    StructField("FirstName", StringType(), True),
    StructField("Title", StringType(), True),
    StructField("TitleOfCourtesy", StringType(), True),
    StructField("BirthDate", DateType(), True),
    StructField("HireDate", DateType(), True),
    StructField("Address", StringType(), True),
    StructField("City", StringType(), True),
    StructField("Region", StringType(), True),
    StructField("PostalCode", StringType(), True),
    StructField("Country", StringType(), True),
    StructField("HomePhone", StringType(), True),
    StructField("Extension", StringType(), True),
    StructField("Photo", StringType(), True),
    StructField("Notes", StringType(), True),
    StructField("ReportsTo", IntegerType(), True),
    StructField("PhotoPath", StringType(), True)
])
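
The schemas above store `Freight` and `UnitPrice` as `FloatType`, a 32-bit float. The sketch below uses only the standard library (no Spark) to show what a 32-bit round trip does to a value such as 29.46; this is probably why some outputs later in this notebook show values like 29.459999. `DoubleType` or `DecimalType` from `pyspark.sql.types` are the usual alternatives when that matters. The helper name `to_float32` is made up for this illustration.

```python
import struct

def to_float32(value: float) -> float:
    """Round-trip a 64-bit Python float through a 32-bit float."""
    return struct.unpack("f", struct.pack("f", value))[0]

freight = 29.46
as_float32 = to_float32(freight)

print(freight)                            # the value as a 64-bit double
print(as_float32)                         # close to, but not exactly, 29.46
print(abs(as_float32 - freight) < 1e-5)   # the error is tiny but non-zero
```

Values that are exact in binary, such as 0.5, survive the round trip unchanged, so the effect only shows up for some numbers.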


### Load Data

In [11]:
# Load all tables with schemas
print("\n=== Loading all tables with schemas ===")
categories = spark.read.csv(f"{data_path}categories.csv", header=True, schema=categories_schema)
customers = spark.read.csv(f"{data_path}customers.csv", header=True, schema=customers_schema)
orders = spark.read.csv(f"{data_path}orders.csv", header=True, schema=orders_schema)
products = spark.read.csv(f"{data_path}products.csv", header=True, schema=products_schema)
employees = spark.read.csv(f"{data_path}employees.csv", header=True, schema=employees_schema)


print("Categories:")
categories.printSchema()
print(f"Count: {categories.count()}")

print("Customers:")
customers.printSchema()
print(f"Count: {customers.count()}")

print("\nOrders:")
orders.printSchema()
print(f"Count: {orders.count()}")

print("\nProducts:")
products.printSchema()
print(f"Count: {products.count()}")

print("\nEmployees:")
employees.printSchema()
print(f"Count: {employees.count()}")


=== Loading all tables with schemas ===
Categories:
root
 |-- CategoryID: integer (nullable = true)
 |-- CategoryName: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Picture: string (nullable = true)

Count: 8
Customers:
root
 |-- CustomerID: string (nullable = true)
 |-- CompanyName: string (nullable = true)
 |-- ContactName: string (nullable = true)
 |-- ContactTitle: string (nullable = true)
 |-- Address: string (nullable = true)
 |-- City: string (nullable = true)
 |-- Region: string (nullable = true)
 |-- PostalCode: string (nullable = true)
 |-- Country: string (nullable = true)
 |-- Phone: string (nullable = true)
 |-- Fax: string (nullable = true)

Count: 91

Orders:
root
 |-- OrderID: integer (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- EmployeeID: integer (nullable = true)
 |-- OrderDate: date (nullable = true)
 |-- RequiredDate: date (nullable = true)
 |-- ShippedDate: date (nullable = true)
 |-- ShipVia: integer (nullable = tru

## Advanced DataFrame operations

### Column and row manipulations

In [None]:
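# Note: importing round this way shadows Python's built-in round()
from pyspark.sql.functions import col, when, upper, concat, lit, year, month, round, desc
# Alternatively, use the F. prefix (F.col, F.round, ...) to avoid the name clash
import pyspark.sql.functions as F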

Select Specific Columns

In [12]:
# Extract only customer ID, company name, and country
customers.select("CustomerID", "CompanyName", "Country").show(5)

+----------+--------------------+-------+
|CustomerID|         CompanyName|Country|
+----------+--------------------+-------+
|     ALFKI| Alfreds Futterkiste|Germany|
|     ANATR|Ana Trujillo Empa...| Mexico|
|     ANTON|Antonio Moreno Ta...| Mexico|
|     AROUT|     Around the Horn|     UK|
|     BERGS|  Berglunds snabbköp| Sweden|
+----------+--------------------+-------+
only showing top 5 rows


Add New Column with Literal Value

In [14]:
# Add a constant 'Active' status column to all customers
customers_with_status = customers.withColumn("Status", F.lit("Active"))
customers_with_status.select("CustomerID", "CompanyName", "Status").show(5)


+----------+--------------------+------+
|CustomerID|         CompanyName|Status|
+----------+--------------------+------+
|     ALFKI| Alfreds Futterkiste|Active|
|     ANATR|Ana Trujillo Empa...|Active|
|     ANTON|Antonio Moreno Ta...|Active|
|     AROUT|     Around the Horn|Active|
|     BERGS|  Berglunds snabbköp|Active|
+----------+--------------------+------+
only showing top 5 rows


Rename Columns

In [15]:
# Rename CustomerID to ID and CompanyName to Company
customers_renamed = customers.withColumnRenamed("CustomerID", "ID").withColumnRenamed("CompanyName", "Company")
customers_renamed.select("ID", "Company", "Country").show(5)


+-----+--------------------+-------+
|   ID|             Company|Country|
+-----+--------------------+-------+
|ALFKI| Alfreds Futterkiste|Germany|
|ANATR|Ana Trujillo Empa...| Mexico|
|ANTON|Antonio Moreno Ta...| Mexico|
|AROUT|     Around the Horn|     UK|
|BERGS|  Berglunds snabbköp| Sweden|
+-----+--------------------+-------+
only showing top 5 rows


Conditional Column Creation

In [19]:
# Categorize products as 'Expensive' if price > 20, else 'Affordable'
products_categorized = products.withColumn("PriceCategory", 
    F.when(F.col("UnitPrice") > 20, "Expensive").otherwise("Affordable"))
products_categorized.select("ProductName", "UnitPrice", "PriceCategory").show(5)


+--------------------+---------+-------------+
|         ProductName|UnitPrice|PriceCategory|
+--------------------+---------+-------------+
|                Chai|     18.0|   Affordable|
|               Chang|     19.0|   Affordable|
|       Aniseed Syrup|     10.0|   Affordable|
|Chef Anton's Caju...|     22.0|    Expensive|
|Chef Anton's Gumb...|    21.35|    Expensive|
+--------------------+---------+-------------+
only showing top 5 rows


String Manipulation

In [23]:
# Convert company names to uppercase and create full address
customers_formatted = customers.withColumn("CompanyUpper", F.upper(F.col("CompanyName"))) \
    .withColumn("FullAddress", F.concat(F.col("Address"), F.lit(", "), F.col("City"), F.lit(", "), F.col("Country")))
customers_formatted.select("CompanyUpper", "FullAddress").show(5, truncate=False)


+----------------------------------+--------------------------------------------------+
|CompanyUpper                      |FullAddress                                       |
+----------------------------------+--------------------------------------------------+
|ALFREDS FUTTERKISTE               |Obere Str. 57, Berlin, Germany                    |
|ANA TRUJILLO EMPAREDADOS Y HELADOS|Avda. de la Constitución 2222, México D.F., Mexico|
|ANTONIO MORENO TAQUERÍA           |Mataderos  2312, México D.F., Mexico              |
|AROUND THE HORN                   |120 Hanover Sq., London, UK                       |
|BERGLUNDS SNABBKÖP                |Berguvsvägen  8, Luleå, Sweden                    |
+----------------------------------+--------------------------------------------------+
only showing top 5 rows


Filter Rows Based on Conditions

In [None]:
# Show only customers from USA or Germany
filtered_customers = customers.filter((col("Country") == "USA") | (col("Country") == "Germany"))
filtered_customers.select("CompanyName", "Country").show(5)


Extract Date Components

In [None]:
# Extract year and month from order dates
orders_with_date_parts = orders.withColumn("OrderYear", year(col("OrderDate"))) \
    .withColumn("OrderMonth", month(col("OrderDate")))
orders_with_date_parts.select("OrderID", "OrderDate", "OrderYear", "OrderMonth").show(5)


Round Numeric Values

In [None]:
# Round freight costs to 2 decimal places and create freight categories
orders_rounded = orders.withColumn("FreightRounded", round(col("Freight"), 2)) \
    .withColumn("FreightCategory", 
        when(col("Freight") < 10, "Low")
        .when(col("Freight") < 50, "Medium")
        .otherwise("High"))
orders_rounded.select("OrderID", "Freight", "FreightRounded", "FreightCategory").show(5)


Sort and Limit Rows

In [None]:
# Show top 5 most expensive products, sorted by price descending
expensive_products = products.orderBy(desc("UnitPrice")).limit(5)
expensive_products.select("ProductName", "UnitPrice", "CategoryID").show()


### Complex filtering and conditional logic

**Multiple AND/OR Conditions with Null Handling**

This combines multiple conditions using the AND (`&`) and OR (`|`) operators. It finds customers who either (1) are from specific European countries AND have both phone and fax numbers, or (2) are from the USA. Each comparison needs its own parentheses because `&` and `|` bind more tightly than `==`. The `isNotNull()` checks keep only rows where the phone and fax values are present.

In [None]:
# The country list is an assumption, borrowed from the EU segment in the next example
contactable_customers = customers.filter(
    (col("Country").isin("Germany", "France", "UK")
     & col("Phone").isNotNull()
     & col("Fax").isNotNull())
    | (col("Country") == "USA")
)
contactable_customers.select("CompanyName", "Country", "Phone", "Fax").show(5)

**Nested Conditional Logic (CASE-WHEN)**

This creates nested conditional logic similar to SQL's CASE-WHEN. It segments customers into tiers based on country and region, with nested conditions for US customers (regional vs national) and different categories for other geographic regions.

In [24]:
from pyspark.sql.functions import col, when

customer_segments = customers.withColumn("Segment",
    when(col("Country") == "USA", 
         when(col("Region").isNotNull(), "US-Regional")
         .otherwise("US-National"))
    .when(col("Country").isin("Germany", "France", "UK"), "EU-Premium")
    .when(col("Country").isin("Brazil", "Argentina", "Mexico"), "LATAM")
    .otherwise("Other")
)
customer_segments.limit(3).show()

+----------+--------------------+--------------+--------------------+--------------------+-----------+------+----------+-------+------------+------------+----------+
|CustomerID|         CompanyName|   ContactName|        ContactTitle|             Address|       City|Region|PostalCode|Country|       Phone|         Fax|   Segment|
+----------+--------------------+--------------+--------------------+--------------------+-----------+------+----------+-------+------------+------------+----------+
|     ALFKI| Alfreds Futterkiste|  Maria Anders|Sales Representative|       Obere Str. 57|     Berlin|  NULL|     12209|Germany| 030-0074321| 030-0076545|EU-Premium|
|     ANATR|Ana Trujillo Empa...|  Ana Trujillo|               Owner|Avda. de la Const...|México D.F.|  NULL|     05021| Mexico|(5) 555-4729|(5) 555-3745|     LATAM|
|     ANTON|Antonio Moreno Ta...|Antonio Moreno|               Owner|     Mataderos  2312|México D.F.|  NULL|     05023| Mexico|(5) 555-3932|        NULL|     LATAM|
+---

**Date-Based Filtering with Calculations**

This identifies problematic orders using date calculations. It finds orders that are: (1) shipped more than 7 days late, (2) have high freight costs for international shipments, or (3) remain unshipped after 30 days. The datediff() function calculates differences between dates.

In [None]:
from pyspark.sql.functions import datediff, current_date

problematic_orders = orders.filter(
    (datediff(col("ShippedDate"), col("RequiredDate")) > 7) |
    ((col("Freight") > 100) & (col("ShipCountry") != "USA")) |
    (col("ShippedDate").isNull() & (datediff(current_date(), col("OrderDate")) > 30))
)
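
The argument order of `datediff(end, start)` is easy to get backwards. The following plain-Python sketch (no Spark needed) mimics it with `datetime.date` arithmetic, using the dates of order 10643 from this dataset. The function name `days_between` is a stand-in chosen for this illustration, not a Spark function.

```python
from datetime import date

def days_between(end: date, start: date) -> int:
    """Plain-Python stand-in for datediff(end, start): days from start to end."""
    return (end - start).days

# Dates of order 10643 in this dataset
required = date(1997, 9, 22)
shipped = date(1997, 9, 2)

days_late = days_between(shipped, required)
print(days_late)        # negative: shipped before the required date
print(days_late > 7)    # False, so this order is not flagged as late
```

In Spark, a null `ShippedDate` makes `datediff` return null, and a null condition does not pass a filter; that is why the unshipped case is handled by a separate `isNull()` branch.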


**String Pattern Matching with Conditional Transformations**

This uses regular expressions to categorize products based on name patterns. The `(?i)` flag makes the search case-insensitive, and the pipe `|` acts as OR within the regex. Because `rlike()` matches anywhere in the string, short alternatives such as `ale` or `tea` also match inside longer words. The conditions are checked in order, so a product gets the category of the first pattern that matches.

In [None]:
product_categories = products.withColumn("ProductCategory",
    when(col("ProductName").rlike("(?i).*(cheese|dairy).*"), "Dairy")
    .when(col("ProductName").rlike("(?i).*(wine|beer|ale).*"), "Alcoholic")
    .when(col("ProductName").rlike("(?i).*(sauce|syrup|spread).*"), "Condiments")
    .when(col("ProductName").rlike("(?i).*(tea|coffee).*"), "Beverages")
    .otherwise("Other")
)
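
The patterns above can be tried out without Spark. The sketch below uses Python's `re` module as a stand-in for `rlike()`; Spark uses Java regular expressions, which behave the same way for patterns this simple. "Aniseed Syrup" comes from the dataset, while the other two product names are made up to show the substring pitfall. The helper `categorize` is hypothetical.

```python
import re

# Same patterns as the cell above, in the same order
patterns = [
    ("Dairy", r"(?i).*(cheese|dairy).*"),
    ("Alcoholic", r"(?i).*(wine|beer|ale).*"),
    ("Condiments", r"(?i).*(sauce|syrup|spread).*"),
    ("Beverages", r"(?i).*(tea|coffee).*"),
]

def categorize(name: str) -> str:
    """Return the label of the first pattern found anywhere in the name."""
    for label, pattern in patterns:
        if re.search(pattern, name):
            return label
    return "Other"

# "Aniseed Syrup" is in the dataset; the other two names are made up
for name in ["Aniseed Syrup", "Yard Sale Mix", "Steak Rub"]:
    print(name, "->", categorize(name))

# Word boundaries (\b) restrict the match to whole words
whole_word = r"(?i)\b(wine|beer|ale)\b"
print(bool(re.search(whole_word, "Yard Sale Mix")))  # no whole word "ale"
print(bool(re.search(whole_word, "Pale Ale")))       # whole word "Ale"
```

"Yard Sale Mix" is labelled Alcoholic because `ale` occurs inside "Sale", and "Steak Rub" is labelled Beverages because `tea` occurs inside "Steak". Adding `\b` on both sides is one way to avoid this.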



The `.rlike()` method used in the pattern-matching example performs regex matching on column values. The pattern `(?i).*(cheese|dairy).*` means:

- `(?i)` - Case insensitive flag
- `.*` - Match any characters before
- `(cheese|dairy)` - Match either "cheese" or "dairy"
- `.*` - Match any characters after

**Complex Inventory Management Logic**

This implements more complex business logic for inventory management. It creates alerts for different scenarios: active products that are out of stock, active products with stock below the reorder level, expensive overstocked items, and active low-stock products with nothing on order. The conditions are evaluated in order and the first match wins. The final filter removes "Normal" items to focus only on products requiring attention.

In [26]:
inventory_alerts = products.withColumn("AlertType",
    when((col("UnitsInStock") == 0) & (col("Discontinued") == 0), "Out of Stock")
    .when((col("UnitsInStock") < col("ReorderLevel")) & (col("Discontinued") == 0), "Low Stock")
    .when((col("UnitsInStock") > 100) & (col("UnitPrice") > 20), "Overstock Expensive")
    .when((col("UnitsOnOrder") == 0) & (col("UnitsInStock") < 20) & (col("Discontinued") == 0), "No Reorder")
    .otherwise("Normal")
).filter(col("AlertType") != "Normal")
inventory_alerts.printSchema()

root
 |-- ProductID: integer (nullable = true)
 |-- ProductName: string (nullable = true)
 |-- SupplierID: integer (nullable = true)
 |-- CategoryID: integer (nullable = true)
 |-- QuantityPerUnit: string (nullable = true)
 |-- UnitPrice: float (nullable = true)
 |-- UnitsInStock: integer (nullable = true)
 |-- UnitsOnOrder: integer (nullable = true)
 |-- ReorderLevel: integer (nullable = true)
 |-- Discontinued: integer (nullable = true)
 |-- AlertType: string (nullable = false)



In [27]:
inventory_alerts.show(3)

+---------+--------------------+----------+----------+-------------------+---------+------------+------------+------------+------------+-------------------+
|ProductID|         ProductName|SupplierID|CategoryID|    QuantityPerUnit|UnitPrice|UnitsInStock|UnitsOnOrder|ReorderLevel|Discontinued|          AlertType|
+---------+--------------------+----------+----------+-------------------+---------+------------+------------+------------+------------+-------------------+
|        3|       Aniseed Syrup|         1|         2|12 - 550 ml bottles|     10.0|          13|          70|          25|           0|          Low Stock|
|        6|Grandma's Boysenb...|         3|         2|     12 - 8 oz jars|     25.0|         120|           0|          25|           0|Overstock Expensive|
|        7|Uncle Bob's Organ...|         3|         7|    12 - 1 lb pkgs.|     30.0|          15|           0|          10|           0|         No Reorder|
+---------+--------------------+----------+----------+----
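
The alert rules can be checked by hand against the three rows shown above. The sketch below restates the `when(...).when(...).otherwise(...)` chain as an ordinary Python function (the name `alert_type` is made up for this illustration) and ignores null handling, which Spark treats differently from Python.

```python
def alert_type(units_in_stock, units_on_order, reorder_level, discontinued, unit_price):
    """First matching rule wins, mirroring the order of the when() chain."""
    active = discontinued == 0
    if units_in_stock == 0 and active:
        return "Out of Stock"
    if units_in_stock < reorder_level and active:
        return "Low Stock"
    if units_in_stock > 100 and unit_price > 20:
        return "Overstock Expensive"
    if units_on_order == 0 and units_in_stock < 20 and active:
        return "No Reorder"
    return "Normal"

# ProductID -> (UnitsInStock, UnitsOnOrder, ReorderLevel, Discontinued, UnitPrice)
# Values copied from the three rows shown above
rows = {
    3: (13, 70, 25, 0, 10.0),
    6: (120, 0, 25, 0, 25.0),
    7: (15, 0, 10, 0, 30.0),
}
alerts = {product_id: alert_type(*values) for product_id, values in rows.items()}
print(alerts)
```

Rule order matters: a hypothetical active product with zero stock also satisfies the "Low Stock" condition, but it is reported as "Out of Stock" because that rule comes first.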

In [28]:
inventory_alerts.count()

28

In [29]:
inventory_alerts.cache()
inventory_alerts.count()

28

In [30]:
inventory_alerts.count()

28

### Joins


**Inner Join - Basic Relationship**

This performs a standard inner join between the orders and customers tables on the CustomerID column. Only orders with a matching customer are returned. Passing the column name as a string keeps a single `CustomerID` column in the result. This is the most common join type; because non-matching rows are dropped, its result is never larger than that of an outer join on the same keys.

In [None]:
orders_customers = orders.join(customers, "CustomerID", "inner")


**Left Join - Preserve All Records**

A left join preserves all customers even if they have no orders. The aggregation counts orders per customer, showing 0 for customers without orders: `count("OrderID")` counts only non-null values, and unmatched customers have a null `OrderID` after the join. This is useful for finding inactive customers or getting complete customer metrics.

In [31]:
from pyspark.sql.functions import count

customer_orders = customers.join(orders, "CustomerID", "left") \
    .groupBy("CustomerID", "CompanyName") \
    .agg(count("OrderID").alias("OrderCount"))


In [32]:
customer_orders.printSchema()

root
 |-- CustomerID: string (nullable = true)
 |-- CompanyName: string (nullable = true)
 |-- OrderCount: long (nullable = false)



In [33]:
customer_orders.show(10)

+----------+--------------------+----------+
|CustomerID|         CompanyName|OrderCount|
+----------+--------------------+----------+
|     CENTC|Centro comercial ...|         2|
|     COMMI|    Comércio Mineiro|        10|
|     OCEAN|Océano Atlántico ...|        11|
|     ANATR|Ana Trujillo Empa...|        10|
|     LACOR|La corne d'abondance|        11|
|     ERNSH|        Ernst Handel|       102|
|     FRANS|      Franchi S.p.A.|        10|
|     GROSR|GROSELLA-Restaurante|         4|
|     QUEDE|         Que Delícia|        24|
|     TOMSP|  Toms Spezialitäten|        14|
+----------+--------------------+----------+
only showing top 10 rows
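
Why `count("OrderID")` yields 0 rather than 1 for a customer without orders can be seen in a small plain-Python sketch of the left join (no Spark needed). The customer ID "NOORD" and the variable names are hypothetical; the two ALFKI order IDs appear in this dataset.

```python
customer_ids = ["ALFKI", "NOORD"]               # "NOORD" is a made-up customer with no orders
order_rows = [("ALFKI", 10643), ("ALFKI", 10692)]

# Left join: every customer appears at least once; unmatched ones get OrderID None
joined = []
for customer_id in customer_ids:
    matches = [order_id for cid, order_id in order_rows if cid == customer_id]
    if matches:
        joined.extend((customer_id, order_id) for order_id in matches)
    else:
        joined.append((customer_id, None))

count_order_id = {cid: 0 for cid in customer_ids}  # like count("OrderID"): non-null only
count_rows = {cid: 0 for cid in customer_ids}      # like count("*"): every joined row
for customer_id, order_id in joined:
    count_rows[customer_id] += 1
    if order_id is not None:
        count_order_id[customer_id] += 1

print(count_order_id)
print(count_rows)
```

Counting rows instead of a column from the right-hand table would report one order for a customer that has none.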


**Broadcast Join - Small Table Optimization**

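The `broadcast()` function hints that Spark should send a copy of the small categories table to every executor, so the larger table can be joined without an expensive shuffle. This is efficient when one table comfortably fits in memory (roughly under 200MB as a rule of thumb). By default Spark automatically broadcasts tables whose estimated size is under 10MB (the `spark.sql.autoBroadcastJoinThreshold` setting); an explicit hint covers small tables above that threshold.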


In [None]:
from pyspark.sql.functions import broadcast

products_categories = products.join(broadcast(categories), "CategoryID", "inner")


**Multiple Table Join - Complex Business Query**

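This chains multiple joins to create a comprehensive view. Note the explicit column reference `orders.EmployeeID == employees.EmployeeID` in the second join: joining on an expression rather than a column name keeps both `EmployeeID` columns in the result, so later references to it must be qualified. Columns that exist in more than one table (for example `City` in customers and employees) are duplicated in the same way. This pattern is common in data warehousing for creating denormalized views.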

In [None]:
complete_orders = orders \
    .join(customers, "CustomerID", "inner") \
    .join(employees, orders.EmployeeID == employees.EmployeeID, "inner")


**Inequality Join - Range-Based Matching**

Unlike equality joins, this uses range conditions to match products to price tiers. The join condition `(UnitPrice >= MinPrice) & (UnitPrice < MaxPrice)` assigns each product to its appropri

In [None]:
# Create price tiers DataFrame
price_tiers = spark.createDataFrame([
    (1, "Budget", 0.0, 10.0),
    (2, "Standard", 10.0, 25.0),
    (3, "Premium", 25.0, 50.0),
    (4, "Luxury", 50.0, 999.0)
], ["TierID", "TierName", "MinPrice", "MaxPrice"])

products_tiers = products.join(
    price_tiers,
    (col("UnitPrice") >= col("MinPrice")) & (col("UnitPrice") < col("MaxPrice")),
    "inner"
)
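
The range condition treats each tier as a half-open interval: `MinPrice` is included and `MaxPrice` is excluded. A plain-Python sketch of the same matching rule (no Spark; the names `price_tier_rows` and `tiers_for` are made up for this illustration) shows what happens at the boundaries.

```python
# (TierName, MinPrice, MaxPrice), same values as the price_tiers DataFrame
price_tier_rows = [
    ("Budget", 0.0, 10.0),
    ("Standard", 10.0, 25.0),
    ("Premium", 25.0, 50.0),
    ("Luxury", 50.0, 999.0),
]

def tiers_for(unit_price):
    """Return every tier whose range [MinPrice, MaxPrice) contains the price."""
    return [name for name, low, high in price_tier_rows if low <= unit_price < high]

for price in (10.0, 18.0, 25.0, 999.0):
    print(price, tiers_for(price))
```

A price of exactly 10.0 lands in "Standard" and 25.0 in "Premium". A price of 999.0 or more matches no tier, so with an inner join such a product would be dropped from the result.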


#### Attributes vs Col

This FAILS - no column names were given, so Spark names the columns `_1` and `_2`
    df = spark.createDataFrame([(1, "John"), (2, "Jane")])
    result = df.CustomerID  # AttributeError: 'DataFrame' object has no attribute 'CustomerID'


This WORKS - the schema names the columns
    schema = StructType([
        StructField("CustomerID", IntegerType(), True),
        StructField("Name", StringType(), True)
    ])
    df = spark.createDataFrame([(1, "John"), (2, "Jane")], schema)
    result = df.CustomerID  # Works!

Why This Happens
- Python's Dynamic Nature: `df.CustomerID` is resolved at runtime through `__getattr__`, which looks the name up among the DataFrame's columns
- Default Column Names: Without a schema or column names, `createDataFrame` assigns `_1`, `_2`, ... to tuple fields, so no `CustomerID` column exists
- Name Clashes: Attribute access cannot reach a column whose name matches a DataFrame method or attribute (e.g. `df.count` is the method) or is not a valid Python identifier

Solutions
- Define schema explicitly (or pass a list of column names):
        ```df = spark.createDataFrame(data, schema)
        df.CustomerID  # Now works```
- Use bracket notation:
        ```df["CustomerID"]  # Works for any existing column, including names that clash with methods```
- Use `col()` function:
        ```df.select(col("CustomerID"))  # Resolved against df when the query is analyzed; also works for columns created earlier in the same chain```

The attribute notation (df.ColumnName) is syntactic sugar for df["ColumnName"]. None of the three forms can reference a column that does not exist in the DataFrame.
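
How attribute access and bracket access relate can be illustrated with a toy class. `ToyFrame` below is a hypothetical stand-in, not PySpark's implementation; it only models the lookup rule that Python calls `__getattr__` when normal attribute lookup fails.

```python
class ToyFrame:
    """Hypothetical stand-in for a DataFrame; only models column lookup."""

    def __init__(self, columns):
        self.columns = list(columns)

    def count(self):
        """A regular method, like DataFrame.count()."""
        return 0

    def __getitem__(self, name):
        if name not in self.columns:
            raise KeyError(name)
        return f"Column<{name}>"

    def __getattr__(self, name):
        # Python calls this only when normal attribute lookup has failed
        if name not in self.columns:
            raise AttributeError(f"'ToyFrame' object has no attribute '{name}'")
        return self[name]


toy = ToyFrame(["CustomerID", "count"])
print(toy.CustomerID)        # found via __getattr__
print(callable(toy.count))   # the method wins over the column named "count"
print(toy["count"])          # bracket access still reaches the column
```

Asking for `toy.Name` raises `AttributeError` because no such column exists, which mirrors the failing example above.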

### GroupBy, aggregations and window functions

Key Concepts:

- Grouping: Collapses rows into groups for aggregation
- Window Functions: Perform calculations across related rows without collapsing
- Partitioning: Divides data into logical groups for separate processing
- Ordering: Determines sequence for ranking and frame-based operations
- Frames: Define which rows to include in window calculations

**Basic Grouping and Aggregation**

Groups orders by shipping country and calculates basic statistics. The agg() function allows multiple aggregations in one operation, providing count, average, and sum of freight costs per country.

In [38]:
from pyspark.sql.functions import count, avg, sum  # note: this sum shadows Python's built-in sum()
# from pyspark.sql.functions import sum as spark_sum ...or use F.sum()

country_stats = orders.groupBy("ShipCountry") \
    .agg(
        count("OrderID").alias("OrderCount"),
        avg("Freight").alias("AvgFreight"),
        sum("Freight").alias("TotalFreight")
    )
country_stats.printSchema()
number_of_rec = country_stats.count()
print(f"Number of records: {number_of_rec:,}")
country_stats.toPandas()

root
 |-- ShipCountry: string (nullable = true)
 |-- OrderCount: long (nullable = false)
 |-- AvgFreight: double (nullable = true)
 |-- TotalFreight: double (nullable = true)

Number of records: 21


| | ShipCountry | OrderCount | AvgFreight | TotalFreight |
|---|---|---|---|---|
| 0 | Sweden | 97 | 104.599175 | 10146.120006 |
| 1 | Germany | 328 | 116.374238 | 38170.750006 |
| 2 | France | 184 | 68.622718 | 12626.580033 |
| 3 | Argentina | 34 | 52.137353 | 1772.670002 |
| 4 | Belgium | 56 | 76.268392 | 4271.029926 |
| 5 | Finland | 54 | 53.223889 | 2874.09002 |
| 6 | Italy | 53 | 39.790377 | 2108.89 |
| 7 | Norway | 16 | 56.065625 | 897.049995 |
| 8 | Spain | 54 | 44.800926 | 2419.250022 |
| 9 | Denmark | 46 | 93.184347 | 4286.479981 |


In [39]:
country_stats.collect()

[Row(ShipCountry='Sweden', OrderCount=97, AvgFreight=104.59917531554233, TotalFreight=10146.120005607605),
 Row(ShipCountry='Germany', OrderCount=328, AvgFreight=116.374237821869, TotalFreight=38170.750005573034),
 Row(ShipCountry='France', OrderCount=184, AvgFreight=68.6227175715258, TotalFreight=12626.580033160746),
 Row(ShipCountry='Argentina', OrderCount=34, AvgFreight=52.137353001271975, TotalFreight=1772.6700020432472),
 Row(ShipCountry='Belgium', OrderCount=56, AvgFreight=76.26839154213667, TotalFreight=4271.0299263596535),
 Row(ShipCountry='Finland', OrderCount=54, AvgFreight=53.22388925927657, TotalFreight=2874.0900200009346),
 Row(ShipCountry='Italy', OrderCount=53, AvgFreight=39.79037735315988, TotalFreight=2108.889999717474),
 Row(ShipCountry='Norway', OrderCount=16, AvgFreight=56.065624713897705, TotalFreight=897.0499954223633),
 Row(ShipCountry='Spain', OrderCount=54, AvgFreight=44.80092633212054, TotalFreight=2419.2500219345093),
 Row(ShipCountry='Denmark', OrderCount=46

**Multiple Column Grouping**

Groups by multiple columns (category and supplier) to create a cross-tabulation analysis. This shows how products are distributed across category-supplier combinations with their pricing and inventory metrics.

In [None]:
category_supplier_stats = products.groupBy("CategoryID", "SupplierID") \
    .agg(
        count("ProductID").alias("ProductCount"),
        avg("UnitPrice").alias("AvgPrice"),
        sum("UnitsInStock").alias("TotalStock")
    )


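**Window Function - Running Total**

Creates cumulative sums using window frames. `rowsBetween(Window.unboundedPreceding, Window.currentRow)` includes all rows from the start of the partition to the current row, creating a running total of freight costs per customer. Rows within a partition are ordered by `OrderDate`; orders that share the same date have no guaranteed relative order.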

In [40]:
from pyspark.sql import Window

customer_window = Window.partitionBy("CustomerID").orderBy("OrderDate")
running_totals = orders.withColumn("RunningFreight", 
    sum("Freight").over(customer_window.rowsBetween(Window.unboundedPreceding, Window.currentRow))
)


In [42]:
running_totals.printSchema()

root
 |-- OrderID: integer (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- EmployeeID: integer (nullable = true)
 |-- OrderDate: date (nullable = true)
 |-- RequiredDate: date (nullable = true)
 |-- ShippedDate: date (nullable = true)
 |-- ShipVia: integer (nullable = true)
 |-- Freight: float (nullable = true)
 |-- ShipName: string (nullable = true)
 |-- ShipAddress: string (nullable = true)
 |-- ShipCity: string (nullable = true)
 |-- ShipRegion: string (nullable = true)
 |-- ShipPostalCode: string (nullable = true)
 |-- ShipCountry: string (nullable = true)
 |-- RunningFreight: double (nullable = true)



In [43]:
running_totals.toPandas()

25/09/17 20:10:48 WARN CSVHeaderChecker: Number of column in CSV header is not equal to number of fields in the schema:
 Header length: 19, schema size: 14
CSV file: file:///Users/pmolnar/Classes-Workshops/MSA8395_Special_Topics_in_Analytics/SpecialTopicsInAnalytics/apache_spark/notebooks/data/orders.csv


Unnamed: 0,OrderID,CustomerID,EmployeeID,OrderDate,RequiredDate,ShippedDate,ShipVia,Freight,ShipName,ShipAddress,ShipCity,ShipRegion,ShipPostalCode,ShipCountry,RunningFreight
0,10643,ALFKI,6,1997-08-25,1997-09-22,1997-09-02,1,29.459999,Alfreds Futterkiste,Obere Str. 57,Berlin,,12209,Germany,29.459999
1,10643,ALFKI,6,1997-08-25,1997-09-22,1997-09-02,1,29.459999,Alfreds Futterkiste,Obere Str. 57,Berlin,,12209,Germany,58.919998
2,10643,ALFKI,6,1997-08-25,1997-09-22,1997-09-02,1,29.459999,Alfreds Futterkiste,Obere Str. 57,Berlin,,12209,Germany,88.379997
3,10692,ALFKI,4,1997-10-03,1997-10-31,1997-10-13,2,61.020000,Alfred's Futterkiste,Obere Str. 57,Berlin,,12209,Germany,149.399998
4,10702,ALFKI,4,1997-10-13,1997-11-24,1997-10-21,1,23.940001,Alfred's Futterkiste,Obere Str. 57,Berlin,,12209,Germany,173.339998
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2150,10998,WOLZA,8,1998-04-03,1998-04-17,1998-04-17,2,20.309999,Wolski Zajazd,ul. Filtrowa 68,Warszawa,,01-012,Poland,391.880008
2151,10998,WOLZA,8,1998-04-03,1998-04-17,1998-04-17,2,20.309999,Wolski Zajazd,ul. Filtrowa 68,Warszawa,,01-012,Poland,412.190007
2152,10998,WOLZA,8,1998-04-03,1998-04-17,1998-04-17,2,20.309999,Wolski Zajazd,ul. Filtrowa 68,Warszawa,,01-012,Poland,432.500007
2153,10998,WOLZA,8,1998-04-03,1998-04-17,1998-04-17,2,20.309999,Wolski Zajazd,ul. Filtrowa 68,Warszawa,,01-012,Poland,452.810006


**Window Function - Lag and Lead**

`lag()` and `lead()` access previous and next rows within a partition. This enables comparison between consecutive orders for the same customer, useful for trend analysis and change detection.



In [44]:
from pyspark.sql.functions import lead, lag

comparison_window = Window.partitionBy("CustomerID").orderBy("OrderDate")
order_comparison = orders.withColumn("PrevFreight", lag("Freight", 1).over(comparison_window))


**Advanced Aggregation with Conditional Logic**

Combines aggregation with conditional logic using `when()`. This counts total products and conditionally counts discontinued products, enabling calculation of percentages and business metrics within the same aggregation.

In [None]:
product_performance = products.groupBy("CategoryID") \
    .agg(
        count("ProductID").alias("TotalProducts"),
        count(when(col("Discontinued") == 1, col("ProductID"))).alias("DiscontinuedCount")
    )

In [None]:
# Alternative
product_performance = products.groupBy("CategoryID") \
    .agg(
        count("ProductID").alias("TotalProducts"),
        sum(when(col("Discontinued") == 1, 1).otherwise(0)).alias("DiscontinuedCount")
    )


**Window Function - Statistical Measures**

Applies statistical functions across window partitions without collapsing rows (unlike `groupBy`). Each product retains its row while gaining category-level statistics, enabling comparison of individual items against group averages.

In [None]:
stats_window = Window.partitionBy("CategoryID")
product_stats = products.withColumn("AvgCategoryPrice", avg("UnitPrice").over(stats_window))
product_stats.limit(10).show()

## Data Types and Handling Missing Data

Spark DataFrames support a rich type system that mirrors SQL data types while providing strong schema enforcement and optimization capabilities. The core data types include
- primitive types (IntegerType, StringType, FloatType, DoubleType, BooleanType, DateType, TimestampType),
- decimal types (DecimalType for precise numeric calculations), and
- complex types (ArrayType, MapType, StructType for nested data).

Each column in a DataFrame has a defined data type that determines how Spark stores, processes, and optimizes operations on that data. Unlike Python's dynamic typing, Spark's explicit schema lets the Catalyst optimizer generate efficient execution plans and catch many type-related errors when the query plan is analyzed, before any data is processed.

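The nullable property in Spark schemas plays an important role in data integrity and query optimization. When a column is marked as `nullable=False`, Spark treats it as containing no null values, which enables optimizations such as eliminating null checks in generated code. The flag is mainly a promise to the optimizer rather than a constraint enforced on every data source, so data that violates it can cause errors or incorrect results.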

Conversely, `nullable=True` columns require null-safe operations and additional runtime checks. The `NullType` represents columns that contain only null values, which can occur during data loading or as intermediate results in transformations. 

Proper null handling is essential because null values propagate through most operations (e.g., `null + 5 = null`) and require explicit handling using functions like `isNull()`, `isNotNull()`, `coalesce()`, or `when().otherwise()` constructs.


In [None]:
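# Schema definition with nullable control
customer_fields = [
    StructField("CustomerID", StringType(), False),  # Cannot be null
    StructField("Region", StringType(), True),       # Can be null
]

# Null value analysis: count the nulls in every column
null_counts = customers.select([count(when(col(c).isNull(), c)).alias(f"{c}_nulls") for c in customers.columns])

# Null-safe operations: substitute a label for missing regions
region_filled = customers.withColumn("Region", when(col("Region").isNull(), "No Region").otherwise(col("Region")))

# Type conversions: cast the price column to integer
products_int_price = products.withColumn("UnitPrice", col("UnitPrice").cast(IntegerType()))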


PySpark provides comprehensive tools for handling missing data through its DataFrame API and DataFrameNaFunctions class. Missing data in Spark is primarily represented as NULL values, which are distinct from empty strings, zeros, or NaN (Not a Number) values.

Spark's approach to missing data is SQL-compliant: NULL values propagate through most expressions (e.g., `NULL + 5 = NULL`) and are ignored by aggregate functions such as `sum()`, `avg()`, and `count(column)`, while `count(*)` still counts every row. The framework offers both automatic handling (like ignoring NULLs in aggregations) and explicit control through functions like `isNull()`, `isNotNull()`, `na.drop()`, and `na.fill()`.

The `DataFrameNaFunctions` class, accessed via `df.na`, provides the primary interface for missing data operations. Key strategies include
- dropping rows with `na.drop()` (with options for thresholds and specific columns),
- filling values with `na.fill()` (supporting different values per column), and
- replacing values with `na.replace()`.

Advanced techniques involve conditional imputation using
- `when().otherwise()` constructs,
- forward/backward filling with window functions, and
- statistical imputation using calculated means or medians.

Spark also provides null-safe operations like `eqNullSafe()` for comparisons and `coalesce()` for selecting the first non-null value from multiple columns.

## User Defined Functions (UDFs) and User Defined Aggregate Functions (UDAFs)

**User Defined Functions (UDFs)** in PySpark allow you to extend Spark's built-in function library with custom Python logic that can be applied to DataFrame columns. UDFs are essential when built-in functions cannot handle specific business logic, complex string processing, or domain-specific calculations. 

There are two main types:
1. Standard UDFs that process one row at a time, and
2. Pandas UDFs (vectorized UDFs) that leverage Apache Arrow for better performance by processing entire columns as pandas Series.

UDFs are registered using the `udf()` function with a specified return type, and they can handle complex operations like regex processing, mathematical calculations, or external API calls.

**User Defined Aggregate Functions (UDAFs)** enable custom aggregation logic across groups of rows, though PySpark has no direct equivalent of the UDAF classes available in Scala Spark. Instead, you can achieve similar functionality by combining `collect_list()` with UDFs, by using a grouped-aggregate `pandas_udf` inside `groupBy().agg()`, or by using `groupBy().applyInPandas()` for more complex aggregations. These are useful for implementing custom statistical measures, weighted calculations, or business-specific metrics that aren't available in Spark's standard aggregation functions.

**Standard UDF - Phone Number Standardization**

In [None]:
import re

from pyspark.sql.functions import col, udf

def standardize_phone(phone_str):
    if not phone_str:
        return None
    digits = re.sub(r'\D', '', phone_str)
    if len(digits) == 10:
        return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
    # ... complex formatting logic
    
standardize_phone_udf = udf(standardize_phone, StringType())
customers.withColumn("StandardizedPhone", standardize_phone_udf(col("Phone")))


**Custom Aggregation - Weighted Average**

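Performance considerations: Pandas UDFs are usually much faster than standard UDFs because they process data in vectorized batches and exchange it with the JVM through Apache Arrow, which reduces serialization overhead. UDFs of either kind should still be used sparingly: they are opaque to the Catalyst optimizer and require moving data between the JVM and Python processes, so prefer built-in functions whenever they can express the logic.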

In [None]:
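import builtins
from pyspark.sql.functions import collect_list

def weighted_average_freight(freight_values, order_counts):
    # Call Python's built-in sum explicitly: a bare `sum` refers to
    # pyspark.sql.functions.sum if that was imported earlier in the notebook.
    # Keep only pairs with a freight value and a non-zero weight.
    pairs = [(f, c) for f, c in zip(freight_values, order_counts) if f is not None and c]
    total_weighted = builtins.sum(f * c for f, c in pairs)
    total_weights = builtins.sum(c for _, c in pairs)
    return total_weighted / total_weights if total_weights > 0 else 0.0

# Collect the freight values of each country into a list that a UDF can consume
orders.groupBy("ShipCountry").agg(collect_list("Freight").alias("freight_list"))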


**Pandas UDF - Vectorized Z-Score Calculation**



In [None]:
from pyspark.sql.functions import col
import pandas as pd

def calculate_zscore_group(pdf: pd.DataFrame) -> pd.DataFrame:
    """Calculate the Z-score of UnitPrice within one group (one CategoryID)."""
    mean_price = pdf["UnitPrice"].mean()
    std_price = pdf["UnitPrice"].std()

    # std is NaN for single-row groups and 0 for constant prices; fall back to 0.0
    if std_price > 0:
        pdf["PriceZScore"] = (pdf["UnitPrice"] - mean_price) / std_price
    else:
        pdf["PriceZScore"] = 0.0

    return pdf

# Output schema: the columns of products plus the new PriceZScore column
schema_str = "ProductID int, ProductName string, SupplierID int, CategoryID int, QuantityPerUnit string, UnitPrice float, UnitsInStock int, UnitsOnOrder int, ReorderLevel int, Discontinued int, PriceZScore float"

# Apply the function to each CategoryID group
products_zscore = products.groupBy("CategoryID").applyInPandas(
    calculate_zscore_group,
    schema=schema_str
)
products_zscore.printSchema()
products_zscore.select("CategoryID", "ProductName", "UnitPrice", "PriceZScore").show(10)


Alternatively

In [None]:
from pyspark.sql.functions import avg, col, stddev
from pyspark.sql.window import Window

# Define window
window = Window.partitionBy("CategoryID")

# Calculate Z-score using window functions
products_zscore = products \
    .withColumn("AvgPrice", avg("UnitPrice").over(window)) \
    .withColumn("StdPrice", stddev("UnitPrice").over(window)) \
    .withColumn("PriceZScore", 
        (col("UnitPrice") - col("AvgPrice")) / col("StdPrice")
    )

# Now you can select the column
products_zscore.select("CategoryID", "ProductName", "UnitPrice", "PriceZScore").show(10)


## Integrating with SQL: Using Spark SQL queries with Python

Spark SQL provides a powerful way to integrate SQL queries directly into PySpark applications, allowing you to leverage familiar SQL syntax while maintaining the performance benefits of Spark's distributed computing. The integration works through temporary views that register DataFrames as SQL tables, enabling seamless switching between DataFrame API and SQL syntax within the same application. This approach is particularly valuable for teams with strong SQL backgrounds or when working with complex analytical queries that are more naturally expressed in SQL.

The core integration mechanism involves registering DataFrames as temporary views using `createOrReplaceTempView()`, then executing SQL queries with `spark.sql()`, which returns DataFrames that can be further processed using the DataFrame API. This bidirectional integration allows you to use SQL for complex joins, window functions, and analytical queries while leveraging Python's DataFrame API for data manipulation, machine learning pipelines, and custom transformations. Spark SQL supports much of the ANSI SQL standard, including CTEs, subqueries, window functions, and advanced analytical functions, all optimized by the same Catalyst optimizer that powers the DataFrame API.

**Register DataFrames as Views**

In [45]:
customers.createOrReplaceTempView("customers")
orders.createOrReplaceTempView("orders")


**Execute SQL Queries**

In [47]:
result = spark.sql("""
    SELECT c.CompanyName, COUNT(o.OrderID) as OrderCount
    FROM customers c
    LEFT JOIN orders o ON c.CustomerID = o.CustomerID
    GROUP BY c.CompanyName
    ORDER BY OrderCount DESC
""")
result.printSchema()
result.toPandas()

root
 |-- CompanyName: string (nullable = true)
 |-- OrderCount: long (nullable = false)



Unnamed: 0,CompanyName,OrderCount
0,Save-a-lot Markets,116
1,Ernst Handel,102
2,QUICK-Stop,86
3,Rattlesnake Canyon Grocery,71
4,Hungry Owl All-Night Grocers,55
...,...,...
86,GROSELLA-Restaurante,4
87,Lazy K Kountry Store,2
88,Centro comercial Moctezuma,2
89,Paris spécialités,0


**Mix SQL and DataFrame API**

In [None]:
# Start with SQL
sql_result = spark.sql("SELECT * FROM customers WHERE Country = 'USA'")
# Continue with DataFrame API
final_result = sql_result.filter(col("Region").isNotNull()).orderBy("CompanyName")


**Complex Analytics with CTEs**

In [None]:
spark.sql("""
    WITH monthly_sales AS (
        SELECT YEAR(OrderDate) as year, MONTH(OrderDate) as month, SUM(Freight) as total
        FROM orders GROUP BY YEAR(OrderDate), MONTH(OrderDate)
    )
    SELECT *, LAG(total) OVER (ORDER BY year, month) as prev_month
    FROM monthly_sales
""")


## End

In [None]:
spark.stop()