# 🚀 PySpark Complete Tutorial: Basics → Intermediate → Advanced
## For Databricks Serverless Free Account

This notebook provides a comprehensive guide to PySpark on Databricks Serverless Compute. All examples use inline sample data, so no external files are required.

**Prerequisites:**
- Databricks Community Edition (free) or Databricks Serverless account
- `spark` session is pre-configured in Databricks notebooks

---

### Table of Contents

**PART 1: BASICS**
1. Install & Import Libraries
2. Initialize Spark Session
3. Creating DataFrames
4. Basic DataFrame Operations
5. Column Manipulations & Type Casting
6. Aggregations & GroupBy
7. Working with Null Values & Data Cleaning

**PART 2: INTERMEDIATE**
8. Joins
9. Window Functions
10. User Defined Functions (UDFs) & Pandas UDFs
11. Complex Data Types: Arrays, Maps, Structs
12. Spark SQL: Temp Views & SQL Queries

**PART 3: ADVANCED**
13. Reading & Writing Data in Delta Lake
14. Caching, Persistence & Performance Tuning
15. Higher-Order Functions
16. Broadcast Variables & Accumulators
17. Partitioning, Bucketing & AQE
18. Structured Streaming Basics

---
# PART 1: BASICS
---
## 1. Install & Import PySpark Libraries

In Databricks, PySpark is pre-installed. On local environments, install with `pip install pyspark`.
Below we import the essential modules.

In [None]:
# In Databricks, these are pre-installed. For local, run: pip install pyspark
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType,
    FloatType, DoubleType, DateType, TimestampType,
    BooleanType, ArrayType, MapType, LongType
)
from pyspark.sql.window import Window

print("✅ All imports successful!")

## 2. Initialize Spark Session for Databricks Serverless

In **Databricks**, a `spark` session is automatically available, so you don't need to create one.  
Databricks Serverless Compute manages cluster resources automatically (no cluster configuration needed).

The code below is for **local/non-Databricks** environments. In Databricks, you can skip this cell.

In [None]:
# In Databricks, `spark` is already available. This is for local environments only.
# spark = SparkSession.builder \
#     .appName("PySpark Tutorial") \
#     .master("local[*]") \
#     .getOrCreate()

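# Verify the session
print(f"Spark Version: {spark.version}")
# spark.sparkContext is not exposed on Spark Connect-based compute (including
# Databricks serverless); uncomment the next two lines on a classic cluster or locally.
# print(f"App Name: {spark.sparkContext.appName}")
# print(f"Spark UI: {spark.sparkContext.uiWebUrl}")
spark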

## 3. Creating DataFrames from Various Sources

DataFrames are the primary data structure in PySpark. You can create them from:
- Python lists of tuples/rows
- Python dictionaries (via Pandas)
- Explicit schemas with `StructType`
- External files (CSV, JSON, Parquet, Delta)

In [None]:
# --- Method 1: From a list of tuples with column names ---
data = [
    ("Alice", "Engineering", 95000, 30, "2020-01-15"),
    ("Bob", "Marketing", 72000, 28, "2021-03-22"),
    ("Charlie", "Engineering", 110000, 35, "2018-07-10"),
    ("Diana", "HR", 68000, 26, "2022-06-01"),
    ("Eve", "Marketing", 85000, 32, "2019-11-30"),
    ("Frank", "Engineering", 102000, 40, "2017-04-18"),
    ("Grace", "HR", 71000, 29, "2021-09-05"),
    ("Hank", "Sales", 78000, 34, "2020-02-14"),
    ("Ivy", "Sales", 92000, 38, "2018-12-20"),
    ("Jack", "Engineering", 115000, 45, "2016-08-25"),
]
columns = ["name", "department", "salary", "age", "hire_date"]

df_employees = spark.createDataFrame(data, columns)
df_employees.show()
print(f"Row count: {df_employees.count()}")
print(f"Columns: {df_employees.columns}")

In [None]:
# --- Method 2: With an explicit schema using StructType ---
schema = StructType([
    StructField("name", StringType(), True),
    StructField("department", StringType(), True),
    StructField("salary", IntegerType(), True),
    StructField("age", IntegerType(), True),
    StructField("hire_date", StringType(), True),
])

df_typed = spark.createDataFrame(data, schema=schema)
df_typed.printSchema()
df_typed.show(5)

In [None]:
# --- Method 3: From a Pandas DataFrame ---
import pandas as pd

pandas_df = pd.DataFrame({
    "product": ["Laptop", "Phone", "Tablet", "Monitor", "Keyboard"],
    "price": [999.99, 699.99, 449.99, 329.99, 79.99],
    "quantity": [50, 200, 150, 80, 500]
})

df_products = spark.createDataFrame(pandas_df)
df_products.show()
df_products.printSchema()

## 4. Basic DataFrame Operations: Select, Filter, and Sort

Core operations for exploring and transforming DataFrames.

In [None]:
# --- SELECT: Different ways to select columns ---
# Method 1: Column names as strings
df_employees.select("name", "salary").show(5)

# Method 2: Using col()
df_employees.select(F.col("name"), F.col("salary")).show(5)

# Method 3: Using DataFrame column reference
df_employees.select(df_employees.name, df_employees["salary"]).show(5)

# --- describe() for summary statistics ---
df_employees.describe().show()

In [None]:
# --- FILTER / WHERE: Filter rows based on conditions ---
# Employees with salary > 90000
df_employees.filter(F.col("salary") > 90000).show()

# Equivalent using where()
df_employees.where("salary > 90000").show()

# Multiple conditions with AND (&) and OR (|)
df_employees.filter(
    (F.col("department") == "Engineering") & (F.col("age") > 30)
).show()

# Using isin() for multiple values
df_employees.filter(F.col("department").isin("Engineering", "Sales")).show()

# Using like() for pattern matching
df_employees.filter(F.col("name").like("A%")).show()

# Using between()
df_employees.filter(F.col("salary").between(70000, 100000)).show()

In [None]:
# --- SORT / ORDER BY ---
# Sort by salary descending
df_employees.orderBy(F.col("salary").desc()).show()

# Sort by multiple columns
df_employees.orderBy("department", F.col("salary").desc()).show()

# --- DISTINCT and LIMIT ---
df_employees.select("department").distinct().show()
df_employees.limit(3).show()

# --- dtypes: Quick view of column types ---
print("Column types:", df_employees.dtypes)

## 5. Column Manipulations & Type Casting

Add, rename, transform, and cast columns. Includes common string and date functions.

In [None]:
# --- withColumn(): Add or transform columns ---
df_transformed = (
    df_employees
    # Add a new column: annual bonus (10% of salary)
    .withColumn("bonus", F.col("salary") * 0.10)
    # Add a constant column
    .withColumn("country", F.lit("USA"))
    # Cast hire_date string to DateType
    .withColumn("hire_date", F.to_date(F.col("hire_date"), "yyyy-MM-dd"))
    # Calculate years of experience from hire_date
    .withColumn("years_exp", F.round(F.datediff(F.current_date(), F.col("hire_date")) / 365, 1))
)

df_transformed.show()
df_transformed.printSchema()

In [None]:
# --- Rename and Drop columns ---
df_renamed = df_employees.withColumnRenamed("name", "employee_name")
df_renamed.show(3)

df_dropped = df_employees.drop("age", "hire_date")
df_dropped.show(3)

# --- String Functions ---
df_string = df_employees.select(
    F.col("name"),
    F.upper(F.col("name")).alias("name_upper"),
    F.lower(F.col("name")).alias("name_lower"),
    F.length(F.col("name")).alias("name_length"),
    F.substring(F.col("name"), 1, 3).alias("first_3_chars"),
    F.concat(F.col("name"), F.lit(" - "), F.col("department")).alias("name_dept"),
    F.regexp_replace(F.col("department"), "Engineering", "Eng").alias("dept_short"),
    F.trim(F.col("name")).alias("name_trimmed"),
)
df_string.show(truncate=False)

# --- Type Casting ---
df_cast = df_employees.select(
    F.col("salary").cast("double").alias("salary_double"),
    F.col("salary").cast(StringType()).alias("salary_string"),
    F.col("age").cast("string").alias("age_string"),
)
df_cast.show(5)
df_cast.printSchema()

## 6. Aggregations & GroupBy Operations

Perform grouped computations using `groupBy()` with aggregate functions.
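
As a cross-check for the Spark output, the sketch below recomputes the same per-department aggregates in plain Python (standard library only, not Spark code). It copies the name, department, and salary values from the sample data in section 3; the variable names `salaries_by_dept` and `summary` are illustrative choices, not part of any API.

```python
# Plain-Python cross-check of the per-department aggregates; not Spark code.
from collections import defaultdict

rows = [
    ("Alice", "Engineering", 95000), ("Bob", "Marketing", 72000),
    ("Charlie", "Engineering", 110000), ("Diana", "HR", 68000),
    ("Eve", "Marketing", 85000), ("Frank", "Engineering", 102000),
    ("Grace", "HR", 71000), ("Hank", "Sales", 78000),
    ("Ivy", "Sales", 92000), ("Jack", "Engineering", 115000),
]

# The "groupBy" step: collect the salaries that belong to each department
salaries_by_dept = defaultdict(list)
for _, department, salary in rows:
    salaries_by_dept[department].append(salary)

# The "agg" step: reduce each group to one summary row
summary = {
    department: {
        "emp_count": len(salaries),
        "total_salary": sum(salaries),
        "avg_salary": sum(salaries) / len(salaries),
        "min_salary": min(salaries),
        "max_salary": max(salaries),
    }
    for department, salaries in salaries_by_dept.items()
}

for department in sorted(summary):
    print(department, summary[department])
```

The numbers printed here should agree with the `df_agg` output in the cell below, which is a quick way to confirm the aggregation is doing what you expect.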

In [None]:
# --- Basic GroupBy with single aggregate ---
df_employees.groupBy("department").count().show()

# --- GroupBy with multiple aggregations using agg() ---
df_agg = df_employees.groupBy("department").agg(
    F.count("*").alias("emp_count"),
    F.sum("salary").alias("total_salary"),
    F.avg("salary").alias("avg_salary"),
    F.min("salary").alias("min_salary"),
    F.max("salary").alias("max_salary"),
    F.round(F.stddev("salary"), 2).alias("stddev_salary"),
)
df_agg.orderBy("department").show()

# --- Pivot Table: Department salary by age group ---
df_pivot = (
    df_employees
    .withColumn("age_group",
        F.when(F.col("age") < 30, "Under 30")
         .when(F.col("age") < 40, "30-39")
         .otherwise("40+")
    )
    .groupBy("department")
    .pivot("age_group")
    .agg(F.round(F.avg("salary"), 0))
)
df_pivot.show()

# --- Global Aggregations (without groupBy) ---
df_employees.agg(
    F.sum("salary").alias("total_payroll"),
    F.avg("age").alias("avg_age"),
    F.countDistinct("department").alias("num_departments"),
).show()

## 7. Working with Null Values & Data Cleaning

Handle missing data, duplicates, and data quality issues.

In [None]:
# Create a DataFrame with null values for demonstration
data_nulls = [
    ("Alice", "Engineering", 95000, "alice@co.com"),
    ("Bob", None, 72000, "bob@co.com"),
    ("Charlie", "Engineering", None, None),
    ("Diana", "HR", 68000, "diana@co.com"),
    ("Eve", None, None, "eve@co.com"),
    ("Alice", "Engineering", 95000, "alice@co.com"),  # duplicate
]
df_dirty = spark.createDataFrame(data_nulls, ["name", "department", "salary", "email"])
print("=== Original (with nulls & duplicates) ===")
df_dirty.show()

# --- Check for nulls ---
df_dirty.select([F.sum(F.col(c).isNull().cast("int")).alias(c) for c in df_dirty.columns]).show()

# --- Drop rows with ANY null ---
print("=== Drop rows with any null ===")
df_dirty.na.drop("any").show()

# --- Drop rows where specific columns are null ---
print("=== Drop rows where department is null ===")
df_dirty.na.drop(subset=["department"]).show()

# --- Fill nulls with defaults ---
print("=== Fill nulls ===")
df_dirty.na.fill({"department": "Unknown", "salary": 0, "email": "N/A"}).show()

# --- Using coalesce() to pick first non-null value ---
df_dirty.select(
    "name",
    F.coalesce(F.col("department"), F.lit("Unassigned")).alias("department")
).show()

# --- Filter nulls ---
df_dirty.filter(F.col("department").isNull()).show()
df_dirty.filter(F.col("department").isNotNull()).show()

# --- Remove duplicates ---
print("=== After dropDuplicates ===")
df_dirty.dropDuplicates().show()
df_dirty.dropDuplicates(["name"]).show()  # based on specific columns

---
# PART 2: INTERMEDIATE
---
## 8. Joins: Inner, Outer, Left, Right, and Cross

Combine DataFrames using various join types.
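
To make the join types concrete before running them on Spark, here is a plain-Python sketch (standard library only, not Spark code) that mimics inner, left semi, and left anti joins on the same sample rows used in the cells below. The helper names `keys_match`, `inner_join`, `left_semi`, and `left_anti` are hypothetical. The point it illustrates is that a `None` key never equals anything, so Eve drops out of the inner and semi joins and appears in the anti join.

```python
# Plain-Python model of join semantics; illustrative only, not how Spark executes joins.
employees = [(1, "Alice", 101), (2, "Bob", 102), (3, "Charlie", 101),
             (4, "Diana", 103), (5, "Eve", None)]
departments = [(101, "Engineering"), (102, "Marketing"),
               (103, "HR"), (104, "Finance")]

def keys_match(left_key, right_key):
    # SQL-style equality: a NULL key never matches anything
    return left_key is not None and right_key is not None and left_key == right_key

def inner_join(left, right):
    # One output row per matching (left, right) pair
    return [(name, dept_name)
            for _, name, dept_id in left
            for right_id, dept_name in right
            if keys_match(dept_id, right_id)]

def left_semi(left, right):
    # Left rows that have at least one match; no right-side columns are returned
    return [name for _, name, dept_id in left
            if any(keys_match(dept_id, right_id) for right_id, _ in right)]

def left_anti(left, right):
    # Left rows that have no match at all
    return [name for _, name, dept_id in left
            if not any(keys_match(dept_id, right_id) for right_id, _ in right)]

inner_rows = inner_join(employees, departments)
semi_rows = left_semi(employees, departments)
anti_rows = left_anti(employees, departments)
print(inner_rows)
print(semi_rows)
print(anti_rows)
```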

In [None]:
# --- Create two DataFrames for join examples ---
employees_data = [
    (1, "Alice", 101), (2, "Bob", 102), (3, "Charlie", 101),
    (4, "Diana", 103), (5, "Eve", None),
]
departments_data = [
    (101, "Engineering", "Building A"),
    (102, "Marketing", "Building B"),
    (103, "HR", "Building C"),
    (104, "Finance", "Building D"),
]

df_emp = spark.createDataFrame(employees_data, ["emp_id", "name", "dept_id"])
df_dept = spark.createDataFrame(departments_data, ["dept_id", "dept_name", "location"])

print("=== Employees ===")
df_emp.show()
print("=== Departments ===")
df_dept.show()

In [None]:
# --- INNER JOIN: Only matching rows ---
print("=== INNER JOIN ===")
df_emp.join(df_dept, df_emp.dept_id == df_dept.dept_id, "inner") \
    .drop(df_dept.dept_id).show()

# --- LEFT JOIN: All from left + matching from right ---
print("=== LEFT JOIN ===")
df_emp.join(df_dept, df_emp.dept_id == df_dept.dept_id, "left") \
    .drop(df_dept.dept_id).show()

# --- RIGHT JOIN: All from right + matching from left ---
print("=== RIGHT JOIN ===")
df_emp.join(df_dept, df_emp.dept_id == df_dept.dept_id, "right") \
    .drop(df_emp.dept_id).show()

# --- FULL OUTER JOIN: All rows from both ---
print("=== FULL OUTER JOIN ===")
df_emp.join(df_dept, df_emp.dept_id == df_dept.dept_id, "full") \
    .show()

# --- LEFT SEMI JOIN: Rows from left that have a match (like EXISTS) ---
print("=== LEFT SEMI JOIN ===")
df_emp.join(df_dept, df_emp.dept_id == df_dept.dept_id, "left_semi").show()

# --- LEFT ANTI JOIN: Rows from left that do NOT match (like NOT EXISTS) ---
print("=== LEFT ANTI JOIN ===")
df_emp.join(df_dept, df_emp.dept_id == df_dept.dept_id, "left_anti").show()

# --- CROSS JOIN: Cartesian product ---
print(f"=== CROSS JOIN ({df_emp.count()} x {df_dept.count()} = {df_emp.count() * df_dept.count()} rows) ===")
df_emp.crossJoin(df_dept).show(10)

## 9. Window Functions: Rank, Row Number, and Running Totals

Window functions perform calculations across a set of rows related to the current row, without collapsing them into one row per group the way `groupBy` does.
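
The three ranking functions only differ when the ordering column has ties, and the sample employee data has none within a department. The plain-Python sketch below (not Spark code) models `row_number`, `rank`, and `dense_rank` over one already-sorted partition, using hypothetical salaries that include a tie. The `ranking` helper is an illustrative name.

```python
# Plain-Python model of row_number / rank / dense_rank over one ordered partition.
def ranking(values_in_order):
    """values_in_order must already be sorted in the window's order."""
    rows = []
    rank = dense = 0
    previous = object()  # sentinel that compares unequal to every value
    for position, value in enumerate(values_in_order, start=1):
        if value != previous:
            rank = position   # rank jumps to the current position after a tie
            dense += 1        # dense_rank increases by exactly one
            previous = value
        rows.append((value, position, rank, dense))
    return rows

# Hypothetical salaries with a tie, sorted descending like window_dept below
salaries = sorted([110000, 102000, 102000, 95000], reverse=True)
ranked = ranking(salaries)
for value, row_number, rank, dense_rank in ranked:
    print(value, row_number, rank, dense_rank)
```

In this model the tied rows get distinct `row_number` values but share a `rank`, and `rank` leaves a gap after the tie while `dense_rank` does not.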

In [None]:
# --- Window Specifications ---
window_dept = Window.partitionBy("department").orderBy(F.col("salary").desc())
window_dept_rows = Window.partitionBy("department").orderBy("salary") \
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)

# --- Ranking Functions ---
df_window = df_employees.select(
    "name", "department", "salary",
    F.row_number().over(window_dept).alias("row_num"),
    F.rank().over(window_dept).alias("rank"),
    F.dense_rank().over(window_dept).alias("dense_rank"),
    F.ntile(2).over(window_dept).alias("ntile_2"),
)
df_window.show()

# --- Lag and Lead: Access previous/next row ---
window_by_salary = Window.partitionBy("department").orderBy("salary")
df_lag_lead = df_employees.select(
    "name", "department", "salary",
    F.lag("salary", 1).over(window_by_salary).alias("prev_salary"),
    F.lead("salary", 1).over(window_by_salary).alias("next_salary"),
    (F.col("salary") - F.lag("salary", 1).over(window_by_salary)).alias("salary_diff"),
)
df_lag_lead.show()

# --- Running Totals and Moving Averages ---
df_running = df_employees.select(
    "name", "department", "salary",
    F.sum("salary").over(window_dept_rows).alias("running_total"),
    F.avg("salary").over(window_dept_rows).alias("running_avg"),
    F.count("*").over(Window.partitionBy("department")).alias("dept_total_count"),
)
df_running.show()

## 10. User Defined Functions (UDFs) & Pandas UDFs

Create custom functions when built-in functions aren't enough. **Pandas UDFs** (vectorized) are usually much faster than regular UDFs because they process whole batches through Apache Arrow instead of one row at a time.

In [None]:
from pyspark.sql.functions import udf, pandas_udf

# === Method 1: Regular UDF with decorator ===
@udf(returnType=StringType())
def salary_band(salary):
    if salary is None:
        return "Unknown"
    elif salary < 75000:
        return "Junior"
    elif salary < 100000:
        return "Mid"
    else:
        return "Senior"

df_employees.select("name", "salary", salary_band("salary").alias("band")).show()

# === Method 2: Register UDF for use in SQL ===
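# Check for None explicitly so nulls map to "Unknown" (as in Method 1) and a salary of 0 is not treated as missing
spark.udf.register(
    "salary_band_sql",
    lambda s: "Unknown" if s is None else ("Junior" if s < 75000 else ("Mid" if s < 100000 else "Senior")),
    StringType(),
)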

# === Method 3: Pandas UDF (vectorized; usually much faster) ===
@pandas_udf(DoubleType())
def tax_amount(salary: pd.Series) -> pd.Series:
    """Calculate tax: 20% for salary > 90K, else 15%"""
    # Build the rate with vectorized pandas operations instead of a per-row apply()
    rate = pd.Series(0.15, index=salary.index).mask(salary > 90000, 0.20)
    return salary * rate

df_employees.select(
    "name", "salary",
    tax_amount(F.col("salary")).alias("tax"),
    (F.col("salary") - tax_amount(F.col("salary"))).alias("net_salary"),
).show()

# === Performance note ===
# Regular UDFs: row-by-row processing (slow, serialization overhead)
# Pandas UDFs: batch processing with Arrow (fast, vectorized)

## 11. Working with Complex Data Types: Arrays, Maps, and Structs

PySpark supports nested and complex types for semi-structured data.

In [None]:
# --- Arrays ---
df_skills = spark.createDataFrame([
    ("Alice", ["Python", "Spark", "SQL"]),
    ("Bob", ["Java", "Scala"]),
    ("Charlie", ["Python", "R", "Spark", "ML"]),
], ["name", "skills"])

df_skills.show(truncate=False)
df_skills.printSchema()

# explode: one row per array element
df_skills.select("name", F.explode("skills").alias("skill")).show()

# posexplode: with position index
df_skills.select("name", F.posexplode("skills").alias("pos", "skill")).show()

# Array functions
df_skills.select(
    "name",
    F.size("skills").alias("num_skills"),
    F.array_contains("skills", "Python").alias("knows_python"),
    F.array_sort("skills").alias("skills_sorted"),
    F.array_distinct("skills").alias("skills_unique"),
).show(truncate=False)

# collect_list and collect_set (reverse of explode)
df_exploded = df_skills.select("name", F.explode("skills").alias("skill"))
df_exploded.groupBy("name").agg(
    F.collect_list("skill").alias("skills_list"),
    F.collect_set("skill").alias("skills_set"),
).show(truncate=False)

In [None]:
# --- Maps (Key-Value pairs) ---
df_maps = spark.createDataFrame([
    ("Alice", {"Python": 5, "SQL": 4, "Spark": 3}),
    ("Bob", {"Java": 4, "Scala": 3}),
], ["name", "skill_ratings"])

df_maps.show(truncate=False)

df_maps.select(
    "name",
    F.map_keys("skill_ratings").alias("skills"),
    F.map_values("skill_ratings").alias("ratings"),
    F.col("skill_ratings")["Python"].alias("python_rating"),
).show(truncate=False)

# Explode map into key-value rows
df_maps.select("name", F.explode("skill_ratings").alias("skill", "rating")).show()

# --- Structs (Nested objects) ---
df_struct = spark.createDataFrame([
    ("Alice", ("123 Main St", "NYC", "NY")),
    ("Bob", ("456 Oak Ave", "LA", "CA")),
], ["name", "address"])

# Access struct fields
df_struct.select("name", "address.*").show()
df_struct.select("name", F.col("address._1").alias("street")).show()

# Create struct from columns
df_employees.select(
    "name",
    F.struct("department", "salary").alias("job_info")
).show(truncate=False)

## 12. Spark SQL: Temporary Views and SQL Queries

Mix DataFrame API with SQL seamlessly. Register DataFrames as views and query with standard SQL.

In [None]:
# --- Register DataFrame as a temporary view ---
df_employees.createOrReplaceTempView("employees")

# --- Simple SQL queries ---
spark.sql("SELECT * FROM employees WHERE salary > 90000 ORDER BY salary DESC").show()

# --- Aggregations in SQL ---
spark.sql("""
    SELECT department,
           COUNT(*) as emp_count,
           ROUND(AVG(salary), 2) as avg_salary,
           MAX(salary) as max_salary
    FROM employees
    GROUP BY department
    ORDER BY avg_salary DESC
""").show()

# --- CTE (Common Table Expression) ---
spark.sql("""
    WITH ranked_employees AS (
        SELECT name, department, salary,
               ROW_NUMBER() OVER (PARTITION BY department ORDER BY salary DESC) as rn
        FROM employees
    )
    SELECT * FROM ranked_employees WHERE rn = 1
""").show()

# --- Subquery ---
spark.sql("""
    SELECT name, salary, department
    FROM employees
    WHERE salary > (SELECT AVG(salary) FROM employees)
    ORDER BY salary DESC
""").show()

# --- Mix SQL results with DataFrame API ---
top_earners = spark.sql("SELECT * FROM employees WHERE salary > 90000")
top_earners.groupBy("department").count().show()

# --- Using the UDF we registered earlier in SQL ---
spark.sql("""
    SELECT name, salary, salary_band_sql(salary) as band
    FROM employees
""").show()

---
# PART 3: ADVANCED
---
## 13. Reading & Writing Data in Delta Lake Format

Delta Lake is the default storage format on Databricks. It provides ACID transactions, time travel, schema enforcement, and more.
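
One Delta feature used later in this section is MERGE (upsert). As a rough mental model only, the plain-Python sketch below treats a table as a dictionary keyed by the merge column and applies "update all when matched, insert all when not matched". It says nothing about how Delta Lake implements MERGE; `merge_upsert` and the two sample dictionaries are hypothetical.

```python
# Plain-Python model of "WHEN MATCHED UPDATE ALL / WHEN NOT MATCHED INSERT ALL".
target = {
    "Alice": {"department": "Engineering", "salary": 95000},
    "Bob": {"department": "Marketing", "salary": 72000},
}
source = {
    "Alice": {"department": "Engineering", "salary": 100000},  # matches: update
    "Liam": {"department": "Sales", "salary": 88000},          # no match: insert
}

def merge_upsert(target_rows, source_rows):
    merged = dict(target_rows)       # target rows absent from the source stay unchanged
    for key, row in source_rows.items():
        merged[key] = dict(row)      # the same assignment covers update and insert
    return merged

merged = merge_upsert(target, source)
print(sorted(merged))
print(merged["Alice"]["salary"])
```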

In [None]:
# --- Write DataFrame as Delta table ---
delta_path = "/tmp/pyspark_tutorial/employees_delta"

# Write in Delta format (overwrite if exists)
df_employees.write.format("delta").mode("overwrite").save(delta_path)
print(f"✅ Written to {delta_path}")

# --- Read Delta table ---
df_delta = spark.read.format("delta").load(delta_path)
df_delta.show()

# --- Write as a managed table ---
df_employees.write.format("delta").mode("overwrite").saveAsTable("tutorial_employees")
spark.sql("SELECT * FROM tutorial_employees LIMIT 5").show()

# --- Read/Write other formats ---
# CSV
csv_path = "/tmp/pyspark_tutorial/employees_csv"
df_employees.write.mode("overwrite").option("header", True).csv(csv_path)
df_csv = spark.read.option("header", True).option("inferSchema", True).csv(csv_path)
df_csv.show(3)

# JSON
json_path = "/tmp/pyspark_tutorial/employees_json"
df_employees.write.mode("overwrite").json(json_path)
df_json = spark.read.json(json_path)
df_json.show(3)

# Parquet
parquet_path = "/tmp/pyspark_tutorial/employees_parquet"
df_employees.write.mode("overwrite").parquet(parquet_path)
df_parquet = spark.read.parquet(parquet_path)
df_parquet.show(3)

In [None]:
# --- Delta Lake Advanced Features ---
from delta.tables import DeltaTable

# --- MERGE / UPSERT: Update existing + insert new rows ---
# New data with updates and new employees
new_data = [
    ("Alice", "Engineering", 100000, 31, "2020-01-15"),    # updated salary & age
    ("Liam", "Sales", 88000, 27, "2025-01-10"),            # new employee
    ("Mia", "Engineering", 97000, 29, "2024-06-15"),       # new employee
]
df_updates = spark.createDataFrame(new_data, columns)

# Perform MERGE
delta_table = DeltaTable.forPath(spark, delta_path)

delta_table.alias("target").merge(
    df_updates.alias("source"),
    "target.name = source.name"
).whenMatchedUpdateAll() \
 .whenNotMatchedInsertAll() \
 .execute()

print("=== After MERGE ===")
spark.read.format("delta").load(delta_path).show()

# --- Time Travel: Query previous versions ---
print("=== Version 0 (original data) ===")
spark.read.format("delta").option("versionAsOf", 0).load(delta_path).show()

# View Delta table history
delta_table.history().select("version", "timestamp", "operation", "operationMetrics").show(truncate=False)

# --- Schema Evolution: Add new columns automatically ---
df_with_bonus = df_employees.withColumn("bonus", F.lit(5000))
df_with_bonus.write.format("delta") \
    .mode("append") \
    .option("mergeSchema", "true") \
    .save(delta_path)

print("=== After Schema Evolution ===")
spark.read.format("delta").load(delta_path).printSchema()

## 14. DataFrame Caching, Persistence & Performance Tuning

Caching stores DataFrames in memory/disk to avoid recomputation. Use `explain()` to understand execution plans.

In [None]:
from pyspark import StorageLevel

# --- cache(): persist with the default storage level (memory, spilling to disk) ---
df_cached = df_employees.cache()
df_cached.count()  # Trigger caching (lazy evaluation)
print(f"Is cached: {df_cached.is_cached}")

# --- unpersist(): Free the cached data ---
# cache() returns the same DataFrame, so release it before choosing another storage level
df_cached.unpersist()

# --- persist(): Choose the storage level explicitly ---
df_persisted = df_employees.persist(StorageLevel.MEMORY_AND_DISK)
df_persisted.count()
df_persisted.unpersist()
print("Cache cleared")

# --- explain(): View execution plan ---
print("=== Simple Plan ===")
df_employees.filter(F.col("salary") > 90000).select("name", "salary").explain()

print("\n=== Extended Plan (Parsed → Analyzed → Optimized → Physical) ===")
df_employees.filter(F.col("salary") > 90000) \
    .groupBy("department") \
    .agg(F.avg("salary").alias("avg_salary")) \
    .explain(mode="extended")

# --- Formatted plan (most readable in Databricks) ---
print("\n=== Formatted Plan ===")
df_employees.join(
    df_employees.groupBy("department").agg(F.avg("salary").alias("dept_avg")),
    "department"
).filter(F.col("salary") > F.col("dept_avg")).explain(mode="formatted")

## 15. Advanced Transformations with Higher-Order Functions

Higher-order functions apply transformations directly on array columns without needing `explode()`.
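
If the function names are unfamiliar, it may help to see their closest plain-Python counterparts first. The sketch below (standard library only, not Spark code) applies list comprehensions, `functools.reduce`, `any`, and `all` to the same scores used in the cell below; the `results` dictionary is just an illustrative container.

```python
# Plain-Python analogues of transform / filter / aggregate / exists / forall.
from functools import reduce

scores = {
    "Alice": [85, 92, 78, 95],
    "Bob": [70, 65, 80, 72],
    "Charlie": [90, 88, 95, 100],
}

results = {}
for name, values in scores.items():
    results[name] = {
        "scores_curved": [x + 5 for x in values],                      # transform
        "passing_scores": [x for x in values if x >= 80],              # filter
        "total_score": reduce(lambda acc, x: acc + x, values, 0),      # aggregate
        "has_perfect": any(x == 100 for x in values),                  # exists
        "all_passing": all(x >= 70 for x in values),                   # forall
    }
    print(name, results[name])
```

The Spark versions take the same kind of lambda, but the lambda builds a column expression that is evaluated on the executors rather than running Python code per element.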

In [None]:
# Sample data with array columns
df_scores = spark.createDataFrame([
    ("Alice", [85, 92, 78, 95]),
    ("Bob", [70, 65, 80, 72]),
    ("Charlie", [90, 88, 95, 100]),
], ["name", "scores"])

# --- transform(): Apply function to each array element ---
df_scores.select(
    "name",
    "scores",
    F.transform("scores", lambda x: x + 5).alias("scores_curved"),  # Add 5 to each score
    F.transform("scores", lambda x: F.round(x / 100.0 * 4.0, 2)).alias("gpa_scale"),
).show(truncate=False)

# --- filter(): Keep only elements matching condition ---
df_scores.select(
    "name",
    "scores",
    F.filter("scores", lambda x: x >= 80).alias("passing_scores"),
).show(truncate=False)

# --- aggregate(): Reduce array to single value ---
df_scores.select(
    "name",
    "scores",
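    # Cast the zero value to long so its type matches the array<bigint> elements
    F.aggregate("scores", F.lit(0).cast("long"), lambda acc, x: acc + x).alias("total_score"),
    F.aggregate(
        "scores", F.lit(0).cast("long"),
        lambda acc, x: acc + x,
        lambda acc: F.round(acc / F.lit(4), 2)  # finalize: compute average (each array has 4 scores)
    ).alias("avg_score"),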
).show(truncate=False)

# --- exists(): Check if any element matches ---
df_scores.select(
    "name",
    F.exists("scores", lambda x: x == 100).alias("has_perfect"),
    F.exists("scores", lambda x: x < 70).alias("has_failing"),
).show()

# --- forall(): Check if ALL elements match ---
df_scores.select(
    "name",
    F.forall("scores", lambda x: x >= 70).alias("all_passing"),
).show()

# --- Chaining DataFrame transformations with .transform() ---
def add_salary_band(df):
    return df.withColumn("band",
        F.when(F.col("salary") < 75000, "Junior")
         .when(F.col("salary") < 100000, "Mid")
         .otherwise("Senior"))

def add_tax(df):
    return df.withColumn("tax", F.col("salary") * 0.2)

# Pipeline-style processing
df_employees.transform(add_salary_band).transform(add_tax).show()

## 16. Broadcast Variables and Accumulators

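**Broadcast variables** efficiently share read-only data with all worker nodes. **Accumulators** are shared counters that tasks can only add to; only the driver can read their value. Both are created through `spark.sparkContext`, so this section needs a classic cluster or a local Spark session: `sparkContext` is not exposed on Spark Connect-based compute such as Databricks serverless.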

In [None]:
# === Broadcast Variables ===
# Useful for small lookup tables that every node needs
dept_budgets = {"Engineering": 500000, "Marketing": 200000, "HR": 150000, "Sales": 300000}
broadcast_budgets = spark.sparkContext.broadcast(dept_budgets)

# Use broadcast variable in a UDF
@udf(IntegerType())
def get_budget(dept):
    return broadcast_budgets.value.get(dept, 0)

df_employees.select("name", "department", get_budget("department").alias("dept_budget")).show()

# === Broadcast Join (hint for small tables) ===
# Hints Spark to broadcast the smaller table so the larger side is not shuffled
df_small_dept = spark.createDataFrame([
    ("Engineering", "Tech"), ("Marketing", "Business"),
    ("HR", "Support"), ("Sales", "Business")
], ["department", "category"])

df_broadcast_join = df_employees.join(
    F.broadcast(df_small_dept),  # Hint: broadcast this small table
    "department"
)
print("=== Broadcast Join ===")
df_broadcast_join.show()
# Verify broadcast join in the plan
df_broadcast_join.explain()

# === Accumulators ===
high_salary_count = spark.sparkContext.accumulator(0)
total_salary_acc = spark.sparkContext.accumulator(0)

def process_row(row):
    total_salary_acc.add(row.salary)
    if row.salary > 90000:
        high_salary_count.add(1)

df_employees.foreach(process_row)
print(f"High salary employees (>90K): {high_salary_count.value}")
print(f"Total salary sum: {total_salary_acc.value}")

## 17. Optimizing Spark Jobs: Partitioning, Bucketing & AQE

Control data distribution for better performance. **Adaptive Query Execution (AQE)** is enabled by default in Databricks.
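
The cell below ends with a salting step that only adds the salt column. To show the whole idea, here is a plain-Python sketch (standard library only, not Spark code) of a salted join: the skewed side gets a random salt, the small side is replicated once per salt value, and the join runs on the (key, salt) pair. The data, `NUM_SALTS`, and all variable names are hypothetical.

```python
# Plain-Python model of a salted join; Spark would express the same steps with columns.
import random
from collections import Counter

NUM_SALTS = 4  # hypothetical salt range, matching the "* 4" used in the cell below
random.seed(0)

# Skewed fact side: one hot key dominates
facts = [("Engineering", i) for i in range(1000)] + [("HR", i) for i in range(10)]
dims = [("Engineering", "Tech"), ("HR", "Support")]

# 1. Fact side: attach a random salt to each row's key
salted_facts = [((dept, random.randrange(NUM_SALTS)), value) for dept, value in facts]

# 2. Dimension side: replicate every row once per salt so each salted key finds its match
salted_dims = {(dept, salt): category
               for dept, category in dims
               for salt in range(NUM_SALTS)}

# 3. Join on the (key, salt) pair, then drop the salt from the output
joined = [(dept, value, salted_dims[(dept, salt)])
          for (dept, salt), value in salted_facts]

# The hot key is now spread over up to NUM_SALTS distinct join keys
rows_per_salted_key = Counter(key for key, _ in salted_facts if key[0] == "Engineering")
print(len(joined), dict(rows_per_salted_key))
```

The trade-off is that the small side grows by a factor of `NUM_SALTS`, so salting tends to pay off only when one side is small and the skew is severe.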

In [None]:
# === Partitioning ===
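# Note: df.rdd relies on the RDD API, which needs a classic cluster or local Spark
print(f"Default partitions: {df_employees.rdd.getNumPartitions()}")

# repartition(): Set the number of partitions, up or down (full shuffle)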
df_repart = df_employees.repartition(4)
print(f"After repartition(4): {df_repart.rdd.getNumPartitions()}")

# repartition by column (great for joins/groupBy on that column)
df_repart_col = df_employees.repartition("department")
print(f"After repartition('department'): {df_repart_col.rdd.getNumPartitions()}")

# coalesce(): Reduce partitions without a full shuffle (cheaper than repartition() when reducing)
df_coalesced = df_repart.coalesce(2)
print(f"After coalesce(2): {df_coalesced.rdd.getNumPartitions()}")

# === Write with partitioning (partition pruning for queries) ===
partitioned_path = "/tmp/pyspark_tutorial/employees_partitioned"
df_employees.write.format("delta") \
    .mode("overwrite") \
    .partitionBy("department") \
    .save(partitioned_path)

# Reading with partition pruning: only the relevant partitions are read
spark.read.format("delta").load(partitioned_path) \
    .filter(F.col("department") == "Engineering") \
    .explain()

# === Adaptive Query Execution (AQE): enabled by default in Databricks ===
print("\n=== AQE Configuration ===")
print(f"AQE enabled: {spark.conf.get('spark.sql.adaptive.enabled', 'not set')}")

# Key AQE features (auto-managed in Databricks Serverless):
# 1. Dynamically coalesces shuffle partitions
# 2. Converts sort-merge joins to broadcast joins when data is small
# 3. Optimizes skewed joins automatically

# === Salting for skew joins (manual technique) ===
# When one join key has disproportionately many rows, add a random salt (0-3 here)
# so the hot key is spread across several partitions. The other side of the join
# must then be replicated once per salt value and joined on (key, salt).
df_salted = df_employees.withColumn("salt", (F.rand() * 4).cast("int"))
df_salted.select("name", "department", "salt").show()

## 18. Structured Streaming Basics

Structured Streaming processes data incrementally. It uses the same DataFrame API as batch jobs, with `readStream`/`writeStream` in place of `read`/`write`.
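
The difference between the `append` and `complete` output modes used below can be modeled in a few lines of plain Python (standard library only, not Spark code). Each inner list stands for the rows that arrive in one trigger; the batch contents and variable names are hypothetical.

```python
# Plain-Python model of micro-batches and output modes; not Spark code.
from collections import Counter

micro_batches = [[0, 1, 2], [3, 4], [5, 6, 7]]  # hypothetical 'value' column per trigger

append_output = []      # append mode: each trigger emits only the new rows
complete_output = []    # complete mode: each trigger re-emits the whole aggregate
running_counts = Counter()

for batch in micro_batches:
    # Stateless transformation: only this batch's rows are written
    append_output.append([(v, v % 2 == 0) for v in batch])
    # Stateful aggregation: the running result is updated, then emitted in full
    running_counts.update("even" if v % 2 == 0 else "odd" for v in batch)
    complete_output.append(dict(running_counts))

print(append_output[-1])
print(complete_output)
```

The aggregated stream has to carry state across triggers, which is why its result table is rewritten each time instead of receiving appended rows.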

> **Note:** On Databricks free tier, streaming capabilities may be limited. The examples below use the built-in `rate` source for testing.

In [None]:
# === Rate Source: Generates incrementing data for testing ===
# Creates a stream with 'timestamp' and 'value' columns
df_stream = (
    spark.readStream
    .format("rate")               # Built-in test source
    .option("rowsPerSecond", 5)   # 5 rows per second
    .load()
)

# Apply transformations on the stream (same API as batch)
df_stream_transformed = (
    df_stream
    .withColumn("value_doubled", F.col("value") * 2)
    .withColumn("is_even", F.col("value") % 2 == 0)
    .withColumn("processed_at", F.current_timestamp())
)

# === Write stream to memory (for testing in notebooks) ===
query = (
    df_stream_transformed
    .writeStream
    .format("memory")             # Output to in-memory table
    .queryName("rate_stream")     # Table name to query
    .outputMode("append")         # append | complete | update
    .trigger(processingTime="5 seconds")  # Process every 5 seconds
    .start()
)

print(f"Stream is active: {query.isActive}")
print(f"Stream status: {query.status}")

In [None]:
# Wait a few seconds then query the in-memory stream table
import time
time.sleep(10)

# Query the streaming data like a regular table
spark.sql("SELECT * FROM rate_stream ORDER BY timestamp DESC LIMIT 10").show()

# === Aggregation on streaming data (append mode is rejected without a watermark, so use 'complete') ===
query_agg = (
    df_stream_transformed
    .groupBy("is_even")
    .agg(
        F.count("*").alias("count"),
        F.avg("value").alias("avg_value"),
    )
    .writeStream
    .format("memory")
    .queryName("rate_stream_agg")
    .outputMode("complete")       # Re-emit the full result table on every trigger
    .trigger(processingTime="5 seconds")
    .start()
)

time.sleep(10)
spark.sql("SELECT * FROM rate_stream_agg").show()

# === Monitor active streams ===
print(f"Active streams: {len(spark.streams.active)}")
for stream in spark.streams.active:
    print(f"  - {stream.name}: {stream.status}")

# === Stop all streams (cleanup) ===
for stream in spark.streams.active:
    stream.stop()
print("All streams stopped.")

## 🧹 Cleanup

Remove temporary data created during this tutorial.

In [None]:
# --- Cleanup temporary files and tables ---
dbutils.fs.rm("/tmp/pyspark_tutorial", recurse=True)
spark.sql("DROP TABLE IF EXISTS tutorial_employees")
print("✅ Cleanup complete!")

---
## 📚 Quick Reference Cheat Sheet

| **Category** | **Operation** | **Syntax** |
|---|---|---|
| **Create** | From list | `spark.createDataFrame(data, columns)` |
| **Create** | With schema | `spark.createDataFrame(data, schema)` |
| **Select** | Columns | `df.select("col1", "col2")` |
| **Filter** | Rows | `df.filter(F.col("x") > 5)` |
| **Sort** | Order | `df.orderBy(F.col("x").desc())` |
| **Aggregate** | GroupBy | `df.groupBy("col").agg(F.sum("x"))` |
| **Join** | Inner | `df1.join(df2, "key", "inner")` |
| **Window** | Rank | `F.row_number().over(window_spec)` |
| **Null** | Fill | `df.na.fill({"col": "default"})` |
| **Delta** | Write | `df.write.format("delta").save(path)` |
| **Delta** | Merge | `DeltaTable.forPath(spark, path).merge(...)` |
| **SQL** | Query | `spark.sql("SELECT * FROM view")` |
| **Cache** | Memory | `df.cache()` |
| **Stream** | Read | `spark.readStream.format("rate").load()` |
| **Debug** | Plan | `df.explain(mode="formatted")` |

---
**Happy Learning! 🎉**  
*This notebook is compatible with Databricks Community Edition (Free) and Databricks Serverless Compute.*