# PySpark Playground

A hands-on reference notebook covering the most common PySpark patterns.

| Section | Topics |
|---------|--------|
| 0 | Setup & SparkSession |
| 1 | DataFrame Basics |
| 2 | Selecting & Filtering |
| 3 | Aggregations & GroupBy |
| 4 | Joins |
| 5 | Window Functions |
| 6 | Spark SQL |
| 7 | User-Defined Functions (UDFs) |
| 8 | Null Handling & Data Quality |
| 9 | String & Date Operations |
| 10 | Reading & Writing Data |

## 0. Setup & SparkSession

In [None]:
%pip install pyspark==3.5.8 python-dotenv --quiet

In [None]:
import os
from pyspark.sql import SparkSession
from dotenv import load_dotenv

# Load variables from a local .env file. By default python-dotenv does not
# override variables that are already set in the process environment.
load_dotenv()

# Machine-specific fallbacks (Windows dev box). setdefault keeps any value that
# is already present in the environment or was loaded from .env above, so the
# notebook also runs on machines with a different layout.
os.environ.setdefault("JAVA_HOME",   r"C:\devhome\tools\Java\jdk-17.0.2")
os.environ.setdefault("SPARK_HOME",  r"C:\devhome\tools\spark-3.5.8-bin-hadoop3")
os.environ.setdefault("HADOOP_HOME", r"C:\devhome\tools\hadoop-3.3.6")

builder = (
    SparkSession.builder
    .master("local[*]")
    .appName("PySpark Playground")
    .config("spark.sql.shuffle.partitions", "4")  # keeps local runs fast
)

# Only pin the Python interpreters when they are explicitly configured; an
# empty value is not a usable interpreter path, so leave the setting unset and
# let Spark fall back to its own default.
for conf_key, env_var in (
    ("spark.pyspark.python", "PYSPARK_PYTHON"),
    ("spark.pyspark.driver.python", "PYSPARK_DRIVER_PYTHON"),
):
    interpreter = os.environ.get(env_var)
    if interpreter:
        builder = builder.config(conf_key, interpreter)

spark = builder.getOrCreate()

spark.sparkContext.setLogLevel("WARN")
print(f"Spark version: {spark.version}")
# uiWebUrl reports the port Spark actually bound (it moves past 4040 when that
# port is already taken, e.g. by another running session).
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")

---
## 1. DataFrame Basics

Create DataFrames from Python collections, then inspect their schema and data.

In [None]:
from pyspark.sql import Row
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, DateType
from datetime import date

# --- Option A: list of tuples + an explicit StructType schema ---
# Each tuple supplies the fields in schema order: id, name, department, salary, hire_date.
employees_data = [
    (1, "Alice",   "Engineering", 95000.0, date(2019, 3, 15)),
    (2, "Bob",     "Marketing",   72000.0, date(2020, 7, 1)),
    (3, "Carol",   "Engineering", 110000.0, date(2017, 11, 20)),
    (4, "Dave",    "Marketing",   68000.0, date(2021, 1, 10)),
    (5, "Eve",     "Engineering", 105000.0, date(2018, 5, 5)),
    (6, "Frank",   "HR",          60000.0, date(2022, 2, 28)),
    (7, "Grace",   "HR",          63000.0, date(2020, 9, 14)),
    (8, "Hank",    "Engineering", 98000.0, date(2019, 6, 1)),
]

# Explicit schema: fixes column names and types, and marks id/name as non-nullable.
schema = StructType([
    StructField("id",         IntegerType(), False),
    StructField("name",       StringType(),  False),
    StructField("department", StringType(),  True),
    StructField("salary",     DoubleType(),  True),
    StructField("hire_date",  DateType(),    True),
])

employees = spark.createDataFrame(employees_data, schema)

# --- Option B: list of Row objects, schema inferred from the Row fields ---
departments = spark.createDataFrame([
    Row(dept_id="Engineering", budget=500000, location="San Francisco"),
    Row(dept_id="Marketing",   budget=200000, location="New York"),
    Row(dept_id="HR",          budget=150000, location="Chicago"),
])

print("=== employees schema ===")
employees.printSchema()

print("=== employees sample ===")
employees.show()

print(f"Row count: {employees.count()}")
print(f"Columns:   {employees.columns}")

In [None]:
# Summary statistics (count, mean, stddev, min, max) for the salary column
salary_summary = employees.describe("salary")
salary_summary.show()

---
## 2. Selecting & Filtering

In [None]:
from pyspark.sql.functions import col, lit, expr, when

# Select specific columns (three equivalent ways)
by_name = employees.select("name", "department", "salary")                 # 1. column-name strings
by_col = employees.select(col("name"), col("department"), col("salary"))  # 2. col() expressions
by_attr = employees.select(                                                # 3. attribute / item access
    employees.name, employees["department"], employees.salary
)
by_name.show()  # all three select the same columns, so one table is enough

# Add derived columns: arithmetic on a column and a conditional label
employees.select(
    col("name"),
    col("salary"),
    (col("salary") * 1.1).alias("salary_with_raise"),  # 10% raise
    when(col("salary") >= 90000, "senior").otherwise("junior").alias("level"),
).show()

In [None]:
# filter() and where() are aliases; pick whichever reads better.
# Combine conditions with & / |, and parenthesise each comparison.
high_paid_engineers = (col("department") == "Engineering") & (col("salary") > 95000)
print("Engineering employees earning > 95k:")
employees.filter(high_paid_engineers).show()

# isin: membership test against a fixed set of values
target_depts = ("Marketing", "HR")
print("Marketing or HR employees:")
employees.where(col("department").isin(*target_depts)).show()

# rlike takes a regular expression; the (?i) flag makes the match case-insensitive
contains_a = "(?i)a"
print("Names containing 'a' (case-insensitive):")
employees.filter(col("name").rlike(contains_a)).show()

---
## 3. Aggregations & GroupBy

In [None]:
# Import Spark's min/max under aliases so they do not shadow Python's built-in
# min/max (or sum) for every later cell in this notebook.
from pyspark.sql.functions import (
    avg,
    count,
    max as spark_max,
    min as spark_min,
    round as spark_round,
    stddev,
)

# Per-department headcount and salary statistics, rounded to whole units.
dept_stats = (
    employees
    .groupBy("department")
    .agg(
        count("*").alias("headcount"),
        spark_round(avg("salary"), 0).alias("avg_salary"),
        spark_min("salary").alias("min_salary"),
        spark_max("salary").alias("max_salary"),
        spark_round(stddev("salary"), 0).alias("stddev_salary"),
    )
    .orderBy("department")
)

dept_stats.show()

In [None]:
# Pivot: a single output row with one column per department, holding that
# department's average salary. Each distinct department value becomes a column,
# so this only suits low-cardinality pivot columns.
from pyspark.sql.functions import avg as avg_fn

avg_salary_rounded = spark_round(avg_fn("salary"), 0)
(
    employees
    .groupBy()
    .pivot("department")
    .agg(avg_salary_rounded)
    .show()
)

---
## 4. Joins

In [None]:
# Inner join - enrich employees with department metadata.
# The condition compares columns from both DataFrames, so both `department`
# and `dept_id` are present in the joined result.
dept_match = employees.department == departments.dept_id
enriched = employees.join(departments, dept_match, how="inner")
enriched.select("name", "department", "salary", "budget", "location").show()

# Left join - every employee row is kept; department columns are null for
# employees without a matching departments row.
employees_left = employees.join(departments, dept_match, how="left")
employees_left.select("name", "department", "location").show()

In [None]:
# Self-join: list each pair of colleagues in the same department exactly once.
# The aliases "a"/"b" qualify the two sides so their columns can be told apart;
# requiring a.id < b.id drops self-pairs and mirrored duplicates.
emp_a = employees.alias("a")
emp_b = employees.alias("b")

same_dept_pair = (col("a.department") == col("b.department")) & (col("a.id") < col("b.id"))

colleague_pairs = emp_a.join(emp_b, same_dept_pair, how="inner").select(
    col("a.name").alias("employee_1"),
    col("b.name").alias("employee_2"),
    col("a.department"),
)
colleague_pairs.orderBy("department", "employee_1").show()

---
## 5. Window Functions

Window functions compute a value for each row based on a group of related rows (the *window*).

In [None]:
from pyspark.sql.window import Window
from pyspark.sql.functions import rank, dense_rank, row_number, lag, lead, sum as sum_fn, avg as avg_fn2

# One window per department, highest salary first (reused by the next cell).
dept_window = Window.partitionBy("department").orderBy(col("salary").desc())

# rank: ties share a rank and leave gaps; dense_rank: ties share a rank with no
# gaps; row_number: a distinct number for every row.
ranking_cols = [
    rank().over(dept_window).alias("rank"),
    dense_rank().over(dept_window).alias("dense_rank"),
    row_number().over(dept_window).alias("row_num"),
]

(
    employees
    .select("name", "department", "salary", *ranking_cols)
    .orderBy("department", col("salary").desc())
    .show()
)

In [None]:
# lag/lead look one row back/forward inside dept_window (salary descending), so
# prev_salary is the next-higher earner and next_salary the next-lower one.
# Rows at either end of a department get null.
neighbour_cols = [
    lag("salary", 1).over(dept_window).alias("prev_salary"),
    lead("salary", 1).over(dept_window).alias("next_salary"),
]

(
    employees
    .select("name", "department", "salary", *neighbour_cols)
    .orderBy("department", col("salary").desc())
    .show()
)

In [None]:
# Running total of salary within each department, in hire order.
# The explicit ROWS frame (partition start .. current row) accumulates row by row.
running_window = (
    Window.partitionBy("department")
    .orderBy("hire_date")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

cumulative_salary = sum_fn("salary").over(running_window).alias("cumulative_salary")

(
    employees
    .select("name", "department", "hire_date", "salary", cumulative_salary)
    .orderBy("department", "hire_date")
    .show()
)

---
## 6. Spark SQL

Register DataFrames as temporary views and query them with SQL.

In [None]:
employees.createOrReplaceTempView("employees")
departments.createOrReplaceTempView("departments")

# Top earner per department. RANK() gives tied salaries the same rank, so a
# department with a tie at the top returns more than one row.
spark.sql("""
    SELECT department, name, salary
    FROM (
        SELECT department, name, salary,
               RANK() OVER (PARTITION BY department ORDER BY salary DESC) AS rnk
        FROM employees
    ) AS ranked
    WHERE rnk = 1
    ORDER BY department
""").show()

# Department budget vs total payroll.
# The LEFT JOIN keeps departments that have no employees; for those SUM() is
# NULL, so COALESCE it to 0 (payroll 0, whole budget remaining) rather than
# reporting NULL in both computed columns.
spark.sql("""
    SELECT d.dept_id, d.budget, d.location,
           COUNT(e.id)                                      AS headcount,
           ROUND(COALESCE(SUM(e.salary), 0), 0)             AS total_payroll,
           ROUND(d.budget - COALESCE(SUM(e.salary), 0), 0)  AS remaining_budget
    FROM departments d
    LEFT JOIN employees e ON e.department = d.dept_id
    GROUP BY d.dept_id, d.budget, d.location
    ORDER BY d.dept_id
""").show()

---
## 7. User-Defined Functions (UDFs)

In [None]:
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# --- Row-level Python UDF (slower, but flexible) ---
def salary_band(salary: float) -> str:
    """Map a salary to a coarse band label.

    Bands: below 70000 -> "band-A", below 90000 -> "band-B", otherwise "band-C".
    Missing salaries -- None (how a SQL NULL arrives in a Python UDF) or NaN --
    map to "unknown" instead of falling through to the top band, because every
    comparison with NaN is False.
    """
    import math  # local import keeps this notebook cell self-contained

    if salary is None or math.isnan(salary):
        return "unknown"
    if salary < 70000:
        return "band-A"
    if salary < 90000:
        return "band-B"
    return "band-C"

# Wrap the plain Python function as a row-at-a-time Spark UDF returning strings.
salary_band_udf = udf(salary_band, StringType())

band_col = salary_band_udf(col("salary")).alias("band")
employees.select("name", "salary", band_col).show()

In [None]:
# --- Vectorised Pandas UDF (much faster for large data) ---
from pyspark.sql.functions import pandas_udf
import pandas as pd

@pandas_udf(DoubleType())
def normalise_salary(series: pd.Series) -> pd.Series:
    """Min-max scale salaries to the 0-1 range within each batch.

    A Series-to-Series pandas UDF only sees one batch at a time, so min and max
    are per batch, not global. When every value in a batch is equal (e.g. a
    single-row batch) the range is 0 and plain division would give 0/0 = NaN;
    such values map to 0.0 instead. NaN inputs stay NaN.
    """
    lo = series.min()
    span = series.max() - lo
    if span == 0:
        # Keep NaN where the input was NaN, 0.0 everywhere else.
        return series.where(series.isna(), 0.0)
    return (series - lo) / span

# normalise_salary scales per Arrow batch, and a batch never spans partitions.
# Under local[*] this small DataFrame is likely split across several partitions
# (possibly one row each), which would make the min/max per partition. Coalesce
# to a single partition so all 8 rows arrive in one batch and the scaling is
# global. For large data, compute the global min/max with an aggregation instead.
(
    employees
    .coalesce(1)
    .select(
        "name", "salary",
        spark_round(normalise_salary(col("salary")), 4).alias("norm_salary"),
    )
    .orderBy("salary")
    .show()
)

---
## 8. Null Handling & Data Quality

In [None]:
from pyspark.sql.functions import coalesce, isnull, isnan, count as count_fn

# Introduce some nulls (and a NaN) for demonstration
dirty_data = [
    (1, "Alice",  95000.0),
    (2, "Bob",    None),
    (3, None,     72000.0),
    (4, "Dave",   float("nan")),
    (5, "Eve",    105000.0),
]
dirty_df = spark.createDataFrame(dirty_data, ["id", "name", "salary"])

print("=== Raw data ===")
dirty_df.show()

# isnan() is only meaningful for floating-point columns. Applying it to the
# string `name` column relies on an implicit cast to double, which is fragile
# (e.g. under ANSI SQL mode), so only test NaN on float/double columns.
float_cols = {name for name, dtype in dirty_df.dtypes if dtype in ("float", "double")}


def _missing_count(c):
    """Count the rows where column `c` is null (or NaN, for float/double columns)."""
    is_missing = isnull(c)
    if c in float_cols:
        is_missing = is_missing | isnan(c)
    # when() without otherwise() is null for non-matching rows, and count() skips nulls.
    return count_fn(when(is_missing, lit(1))).alias(c)


# Count missing values (null, plus NaN for floating-point columns) per column
print("=== Null counts ===")
dirty_df.select([_missing_count(c) for c in dirty_df.columns]).show()

# Fill / drop (for numeric columns, fillna and dropna treat NaN like null)
print("=== After fillna ===")
dirty_df.fillna({"name": "Unknown", "salary": 0.0}).show()

print("=== After dropna ===")
dirty_df.dropna().show()

# coalesce: first non-null wins. NaN is not null, so a NaN salary is kept as-is.
print("=== coalesce salary with default ===")
dirty_df.select(
    "name",
    coalesce(col("salary"), lit(50000.0)).alias("salary")
).show()

---
## 9. String & Date Operations

In [None]:
from pyspark.sql.functions import (
    upper, lower, trim, length, concat, concat_ws,
    regexp_replace, split, substring, initcap
)

# Common string transformations on the name and department columns.
employees.select(
    upper(col("name")).alias("name_upper"),
    length(col("name")).alias("name_len"),
    concat_ws(" @ ", col("name"), col("department")).alias("name_dept"),  # join with a separator
    regexp_replace(col("department"), "Engineering", "Eng").alias("dept_short"),
    # substring() positions are 1-based: (1, 3) takes the first three characters.
    substring(col("name"), 1, 3).alias("name_prefix"),
).show()

In [None]:
from pyspark.sql.functions import (
    year, month, dayofweek, datediff, months_between,
    current_date, date_add, date_format, to_date
)

# Date parts, elapsed time since hiring, and custom formatting for hire_date.
hire = col("hire_date")
today_col = current_date()

date_cols = [
    year("hire_date").alias("hire_year"),
    month("hire_date").alias("hire_month"),
    dayofweek("hire_date").alias("day_of_week"),          # 1=Sunday
    datediff(today_col, hire).alias("days_employed"),
    spark_round(months_between(today_col, hire), 1).alias("months_employed"),
    date_format(hire, "MMM dd, yyyy").alias("formatted"),
]

employees.select("name", "hire_date", *date_cols).show(truncate=False)

---
## 10. Reading & Writing Data

Write to common formats and read them back.

In [None]:
import os
import tempfile

# Fresh scratch directory shared by every write in this section.
tmp = tempfile.mkdtemp()

# --- Parquet (columnar, compressed, schema-preserving) ---
parquet_path = os.path.join(tmp, "employees.parquet")
parquet_writer = employees.write.mode("overwrite")
parquet_writer.parquet(parquet_path)

df_from_parquet = spark.read.parquet(parquet_path)
print("Read back from Parquet:")
df_from_parquet.printSchema()
df_from_parquet.show(3)

In [None]:
# --- CSV ---
# Written with a header row; on read the column types are re-inferred from the text.
csv_path = os.path.join(tmp, "employees.csv")
employees.write.mode("overwrite").option("header", True).csv(csv_path)

csv_reader = spark.read.options(header=True, inferSchema=True)
df_from_csv = csv_reader.csv(csv_path)
print("Read back from CSV:")
df_from_csv.show(3)

In [None]:
# --- Partitioned write (Hive-style: one department=<value> sub-directory per value) ---
partitioned_path = os.path.join(tmp, "employees_by_dept")
employees.write.partitionBy("department").mode("overwrite").parquet(partitioned_path)

# A filter on the partition column lets Spark skip non-matching directories.
engineering_only = col("department") == "Engineering"
df_eng = spark.read.parquet(partitioned_path).where(engineering_only)
print("Engineering employees (partition-pruned read):")
df_eng.show()

print(f"\nTemp files written to: {tmp}")

---
## Teardown

In [None]:
# Teardown: shut the Spark session down so its resources are released.
session = spark
session.stop()
print("Spark session stopped.")