# 5.3 Minimizing Data Shuffling and Addressing Data Skew

This notebook explores strategies for minimizing expensive shuffle operations and handling data skew in PySpark applications while maintaining functional programming principles.

## Learning Objectives

By the end of this notebook, you will understand how to:
- Identify and understand shuffle operations in Spark
- Minimize shuffles through strategic partitioning and join strategies
- Detect and diagnose data skew in distributed processing
- Apply skew remediation techniques (filtering, AQE, salting)
- Use partitioning strategies (`repartition` vs `coalesce`) effectively
- Optimize functional pipelines for distributed performance
- Monitor and tune shuffle behavior with the Spark UI

## Prerequisites

- Understanding of PySpark DataFrames and transformations
- Knowledge of functional programming concepts
- Familiarity with Spark's execution model
- Basic understanding of distributed computing

In [None]:
# Essential imports
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.types import (
    StructType, StructField, IntegerType, StringType, DoubleType
)
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from typing import Dict, List, Tuple, Optional
import random
import time
from functools import reduce

# Initialize Spark session
try:
    spark
except NameError:
    spark = SparkSession.builder \
        .appName("ShuffleOptimization") \
        .config("spark.sql.adaptive.enabled", "true") \
        .config("spark.sql.adaptive.skewJoin.enabled", "true") \
        .getOrCreate()

print("✅ Setup complete - Ready for shuffle optimization!")
print(f"Spark Version: {spark.version}")
print(f"AQE Enabled: {spark.conf.get('spark.sql.adaptive.enabled')}")
print(f"Skew Join Optimization: {spark.conf.get('spark.sql.adaptive.skewJoin.enabled')}")

## 1. Understanding Shuffle Operations

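**Shuffle** is the process of redistributing data across partitions so that rows that belong together (for example, all rows with the same key) land in the same partition. It usually moves data between executors over the network, although on a single node the transfer can stay local. It is one of the most expensive operations in Spark.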

### When Shuffles Occur

Shuffles happen during "wide" transformations:
- **Joins**: `join()`, `crossJoin()`
- **Aggregations**: `groupBy()`, `agg()`, `reduceByKey()`
- **Window Functions**: Operations with `partitionBy`
- **Repartitioning**: `repartition()` always; DataFrame `coalesce()` merges partitions without a full shuffle (RDD `coalesce(n, shuffle=True)` does shuffle)
- **Sorting**: `orderBy()`, `sortBy()`
- **Distinct**: `distinct()`, `dropDuplicates()`

### Why Shuffles Are Expensive

1. **Network I/O**: Data must be transferred between executors
2. **Disk I/O**: Intermediate data is written to disk during the shuffle
3. **Serialization/Deserialization**: Data must be serialized for network transfer
4. **Memory Pressure**: Can cause spills to disk if memory is insufficient
5. **Synchronization**: All partitions must complete before next stage
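
To make the cost concrete, here is a minimal pure-Python sketch (no Spark required) that models partitions as lists and counts how many records would have to change partitions before a per-key aggregation. The helpers `hash_partition` and `count_moved_records` are hypothetical names for this illustration, and the modulo placement is a simplification of Spark's hash partitioning.

```python
from typing import List, Tuple

Record = Tuple[int, float]  # (customer_id, amount)

def hash_partition(key: int, num_partitions: int) -> int:
    """Toy partitioner (assumption): place a key by modulo instead of a real hash."""
    return key % num_partitions

def count_moved_records(partitions: List[List[Record]]) -> int:
    """Count records whose key maps to a different partition than the one they sit in."""
    n = len(partitions)
    return sum(
        1
        for idx, part in enumerate(partitions)
        for key, _ in part
        if hash_partition(key, n) != idx
    )

# Input partitions as they might arrive from a file scan: keys are mixed.
partitions = [
    [(1, 10.0), (2, 5.0), (3, 7.5)],
    [(1, 2.0), (3, 1.0), (2, 4.0)],
    [(2, 8.0), (1, 3.0), (3, 6.0)],
]

moved = count_moved_records(partitions)
total = sum(len(p) for p in partitions)
print(f"{moved} of {total} records must move before a per-key aggregation")
```

In this toy layout two thirds of the records change partitions. A narrow transformation such as a filter would move none, because each partition can be processed on its own.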

In [None]:
# Create sample datasets to demonstrate shuffle operations

def create_orders_data(num_orders: int = 10000) -> DataFrame:
    """
    Create an orders dataset with a deliberately skewed customer distribution.
    Not strictly pure: it relies on `random` and the global `spark` session.
    """
    
    # Create skewed distribution - 70% of orders from top 3 customers
    skewed_customer_ids = [1, 2, 3]  # Heavy customers
    normal_customer_ids = list(range(4, 101))  # Normal customers
    
    data = []
    for i in range(num_orders):
        # 70% of orders go to top 3 customers (skewed)
        if random.random() < 0.7:
            customer_id = random.choice(skewed_customer_ids)
        else:
            customer_id = random.choice(normal_customer_ids)
        
        data.append((
            i + 1,  # order_id
            customer_id,
            f"Product_{random.randint(1, 50)}",
            random.randint(1, 10),  # quantity
            round(random.uniform(10, 500), 2),  # price
            f"2024-{random.randint(1, 12):02d}-{random.randint(1, 28):02d}"
        ))
    
    schema = StructType([
        StructField("order_id", IntegerType(), False),
        StructField("customer_id", IntegerType(), False),
        StructField("product", StringType(), False),
        StructField("quantity", IntegerType(), False),
        StructField("price", DoubleType(), False),
        StructField("order_date", StringType(), False)
    ])
    
    return spark.createDataFrame(data, schema)

def create_customers_data(num_customers: int = 100) -> DataFrame:
    """
    Create a customers dataset (relies on `random` and the global `spark` session).
    """
    data = [
        (
            i,
            f"Customer_{i}",
            f"customer{i}@example.com",
            random.choice(["Bronze", "Silver", "Gold", "Platinum"]),
            random.choice(["US", "CA", "UK", "DE", "FR", "JP"])
        )
        for i in range(1, num_customers + 1)
    ]
    
    schema = StructType([
        StructField("customer_id", IntegerType(), False),
        StructField("name", StringType(), False),
        StructField("email", StringType(), False),
        StructField("tier", StringType(), False),
        StructField("country", StringType(), False)
    ])
    
    return spark.createDataFrame(data, schema)

# Create datasets
orders_df = create_orders_data(10000)
customers_df = create_customers_data(100)

print(f"Orders: {orders_df.count():,} records, {orders_df.rdd.getNumPartitions()} partitions")
print(f"Customers: {customers_df.count():,} records, {customers_df.rdd.getNumPartitions()} partitions")

print("\nSample orders:")
orders_df.show(5)

print("Sample customers:")
customers_df.show(5)

## 2. Visualizing Shuffle Behavior

Let's examine how different operations cause shuffles and their impact on performance.

In [None]:
print("="*80)
print("OPERATIONS THAT CAUSE SHUFFLES")
print("="*80)

# Operation 1: GroupBy Aggregation (Wide Transformation - Shuffles)
print("\n1. GroupBy Aggregation (SHUFFLES):")
start = time.time()
customer_totals = (
    orders_df
    .groupBy("customer_id")
    .agg(
        F.sum(F.col("quantity") * F.col("price")).alias("total_spent"),
        F.count("*").alias("order_count")
    )
)
result_count = customer_totals.count()
elapsed = time.time() - start

print(f"   Result: {result_count} customers")
print(f"   Time: {elapsed:.2f}s")
print(f"   Partitions: {customer_totals.rdd.getNumPartitions()}")
print(f"   ⚠️  Shuffle required: Data redistributed by customer_id")

# Operation 2: Join (Wide Transformation - Shuffles)
print("\n2. Join Operation (SHUFFLES):")
start = time.time()
enriched_orders = orders_df.join(customers_df, "customer_id")
result_count = enriched_orders.count()
elapsed = time.time() - start

print(f"   Result: {result_count:,} enriched orders")
print(f"   Time: {elapsed:.2f}s")
print(f"   Partitions: {enriched_orders.rdd.getNumPartitions()}")
print(f"   ⚠️  Shuffle required: Both sides may shuffle unless broadcasted")

# Operation 3: OrderBy (Wide Transformation - Shuffles)
print("\n3. OrderBy (SHUFFLES):")
start = time.time()
sorted_orders = orders_df.orderBy(F.desc("price"))
result_count = sorted_orders.count()
elapsed = time.time() - start

print(f"   Result: {result_count:,} sorted orders")
print(f"   Time: {elapsed:.2f}s")
print(f"   Partitions: {sorted_orders.rdd.getNumPartitions()}")
print(f"   ⚠️  Shuffle required: Data must be sorted globally")

# Operation 4: Filter (Narrow Transformation - No Shuffle)
print("\n4. Filter Operation (NO SHUFFLE):")
start = time.time()
high_value_orders = orders_df.filter(F.col("price") > 400)
result_count = high_value_orders.count()
elapsed = time.time() - start

print(f"   Result: {result_count:,} high-value orders")
print(f"   Time: {elapsed:.2f}s")
print(f"   Partitions: {high_value_orders.rdd.getNumPartitions()}")
print(f"   ✅ No shuffle: Narrow transformation, processed in-place")

# Operation 5: Select (Narrow Transformation - No Shuffle)
print("\n5. Select/WithColumn (NO SHUFFLE):")
start = time.time()
transformed = orders_df.select(
    "order_id",
    "customer_id",
    (F.col("quantity") * F.col("price")).alias("total")
)
result_count = transformed.count()
elapsed = time.time() - start

print(f"   Result: {result_count:,} transformed orders")
print(f"   Time: {elapsed:.2f}s")
print(f"   Partitions: {transformed.rdd.getNumPartitions()}")
print(f"   ✅ No shuffle: Narrow transformation, processed in-place")

print("\n" + "="*80)
print("💡 Key Insight: Wide transformations (groupBy, join, orderBy) cause shuffles")
print("   Narrow transformations (filter, select) process data in-place")
print("="*80)

## 3. Partitioning Strategies: `repartition` vs `coalesce`

Understanding when to use `repartition()` vs `coalesce()` is critical for performance.
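
The difference can be sketched without Spark. In the toy model below, `coalesce` merges whole input partitions while `repartition` redistributes every record. Both functions are hypothetical stand-ins: Spark's real `coalesce` groups partitions by locality rather than by index, and round-robin is only one of the ways `repartition` can distribute rows.

```python
from typing import List

def coalesce(partitions: List[List[int]], n: int) -> List[List[int]]:
    """Merge whole partitions into n groups; records stay with their original partition."""
    merged: List[List[int]] = [[] for _ in range(n)]
    for idx, part in enumerate(partitions):
        merged[idx % n].extend(part)
    return merged

def repartition(partitions: List[List[int]], n: int) -> List[List[int]]:
    """Redistribute every record round-robin, modelling a full shuffle."""
    flat = [rec for part in partitions for rec in part]
    out: List[List[int]] = [[] for _ in range(n)]
    for i, rec in enumerate(flat):
        out[i % n].append(rec)
    return out

# Four uneven input partitions (sizes 8, 1, 1, 2).
parts = [list(range(8)), [8], [9], [10, 11]]

coalesced_sizes = [len(p) for p in coalesce(parts, 2)]
repartitioned_sizes = [len(p) for p in repartition(parts, 2)]

print("coalesce    ->", coalesced_sizes)      # [9, 3]
print("repartition ->", repartitioned_sizes)  # [6, 6]
```

The sketch shows the trade-off: merging is cheap but inherits any imbalance in the input, while a full redistribution costs a shuffle and yields even partitions.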

In [None]:
print("="*80)
print("REPARTITION vs COALESCE")
print("="*80)

# Create a dataset with many partitions
many_partitions_df = orders_df.repartition(50)
print(f"\nStarting partitions: {many_partitions_df.rdd.getNumPartitions()}")

# Strategy 1: Coalesce (Efficient for reducing partitions)
print("\n1. Using coalesce() to reduce partitions:")
start = time.time()
coalesced_df = many_partitions_df.coalesce(5)
coalesced_df.write.mode("overwrite").format("noop").save()  # Trigger action
coalesce_time = time.time() - start

print(f"   Final partitions: {coalesced_df.rdd.getNumPartitions()}")
print(f"   Time: {coalesce_time:.2f}s")
print(f"   ✅ No full shuffle: Combines existing partitions efficiently")
print(f"   Use case: Reducing partitions before writing to disk")

# Strategy 2: Repartition (Full shuffle for even distribution)
print("\n2. Using repartition() to reduce partitions:")
start = time.time()
repartitioned_df = many_partitions_df.repartition(5)
repartitioned_df.write.mode("overwrite").format("noop").save()  # Trigger action
repartition_time = time.time() - start

print(f"   Final partitions: {repartitioned_df.rdd.getNumPartitions()}")
print(f"   Time: {repartition_time:.2f}s")
print(f"   ⚠️  Full shuffle: Evenly redistributes all data")
print(f"   Use case: Balancing skewed data or increasing parallelism")

# Performance comparison
print(f"\n📊 Performance Comparison:")
print(f"   coalesce() time: {coalesce_time:.2f}s")
print(f"   repartition() time: {repartition_time:.2f}s")
print(f"   Ratio (repartition/coalesce): {repartition_time/coalesce_time:.1f}x (timings on small data are noisy)")

# Partition-based repartitioning (for skew handling)
print("\n3. Repartitioning by column (for co-location):")
partitioned_by_customer = orders_df.repartition("customer_id")
print(f"   Partitions: {partitioned_by_customer.rdd.getNumPartitions()}")
print(f"   ✅ Co-locates data by customer_id for efficient joins")
print(f"   Use case: Pre-partitioning before joins or groupBy")

print("\n" + "="*80)
print("💡 Decision Guide:")
print("   • coalesce(): Reducing partitions, no data balancing needed")
print("   • repartition(n): Need even distribution or increasing partitions")
print("   • repartition(col): Co-locate data for joins/aggregations")
print("="*80)

## 4. Detecting Data Skew

Data skew occurs when some partitions have significantly more data than others, leading to uneven workload distribution and slow tasks.
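
A simple way to quantify skew is the ratio of the largest key frequency to the mean key frequency. The sketch below computes it in plain Python; `skew_ratio` is a hypothetical helper for illustration, and the threshold for "significant" skew is a judgment call rather than a fixed rule.

```python
from collections import Counter
from typing import Hashable, Iterable

def skew_ratio(keys: Iterable[Hashable]) -> float:
    """Return max key frequency divided by mean key frequency (1.0 = perfectly even)."""
    counts = Counter(keys)
    if not counts:
        return 0.0
    mean = sum(counts.values()) / len(counts)
    return max(counts.values()) / mean

# 70 rows for one hot key, 10 rows each for three others.
keys = [1] * 70 + [2] * 10 + [3] * 10 + [4] * 10
ratio = skew_ratio(keys)
print(f"skew ratio: {ratio:.1f}x")  # skew ratio: 2.8x
```

The same max/avg calculation applies to partition sizes, where it measures how much longer the slowest task is likely to run than a typical one.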

In [None]:
print("="*80)
print("DATA SKEW DETECTION")
print("="*80)

# Analyze data distribution by customer
print("\n1. Analyzing order distribution by customer:")
customer_order_counts = (
    orders_df
    .groupBy("customer_id")
    .agg(F.count("*").alias("order_count"))
    .orderBy(F.desc("order_count"))
)

# Get statistics
stats = customer_order_counts.select(
    F.min("order_count").alias("min"),
    F.max("order_count").alias("max"),
    F.avg("order_count").alias("avg"),
    F.stddev("order_count").alias("stddev")
).collect()[0]

print(f"   Min orders per customer: {stats['min']}")
print(f"   Max orders per customer: {stats['max']}")
print(f"   Avg orders per customer: {stats['avg']:.1f}")
print(f"   Stddev: {stats['stddev']:.1f}")
print(f"   Skew ratio (max/avg): {stats['max']/stats['avg']:.1f}x")

if stats['max'] / stats['avg'] > 3:
    print(f"   ⚠️  SIGNIFICANT SKEW DETECTED!")
else:
    print(f"   ✅ Relatively balanced distribution")

# Show top skewed customers
print("\n   Top 10 customers by order count:")
customer_order_counts.show(10)

# Partition size analysis
print("\n2. Analyzing partition sizes after groupBy:")

def analyze_partition_sizes(df: DataFrame, operation_name: str) -> None:
    """
    Pure function to analyze partition size distribution.
    """
    # Get partition sizes
    partition_sizes = df.rdd.mapPartitions(
        lambda iterator: [sum(1 for _ in iterator)]
    ).collect()
    
    if partition_sizes:
        min_size = min(partition_sizes)
        max_size = max(partition_sizes)
        avg_size = sum(partition_sizes) / len(partition_sizes)
        
        print(f"\n   {operation_name}:")
        print(f"   Total partitions: {len(partition_sizes)}")
        print(f"   Min partition size: {min_size} records")
        print(f"   Max partition size: {max_size} records")
        print(f"   Avg partition size: {avg_size:.1f} records")
        # Guard against division by zero when every partition is empty
        skew = max_size / avg_size if avg_size > 0 else 0.0
        print(f"   Skew ratio (max/avg): {skew:.1f}x")
        
        if skew > 2:
            print(f"   ⚠️  Partition skew detected!")
        else:
            print(f"   ✅ Balanced partitions")

# Analyze original data
analyze_partition_sizes(orders_df, "Original orders data")

# Analyze after groupBy
grouped_df = orders_df.groupBy("customer_id").agg(F.count("*").alias("count"))
analyze_partition_sizes(grouped_df, "After groupBy customer_id")

print("\n" + "="*80)
print("💡 Skew Detection Techniques:")
print("   1. Analyze key distribution with groupBy + count")
print("   2. Monitor Spark UI for long-running tasks")
print("   3. Check partition size distribution")
print("   4. Look for skew ratio > 2-3x average")
print("="*80)

## 5. Skew Remediation Strategies

Multiple strategies exist for handling data skew, each with different trade-offs.
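
Salting is the least intuitive of these strategies, so here is a minimal pure-Python sketch of the idea. It assumes a toy partitioner (`toy_partition`, a hypothetical stand-in for a real hash) and assigns salts deterministically from the row index so the output is reproducible; in Spark the salt would typically come from `F.rand()`.

```python
from collections import Counter
from typing import List

def toy_partition(key: int, salt: int, num_partitions: int) -> int:
    """Toy stand-in (assumption) for a hash partitioner over the (key, salt) pair."""
    return (key * 31 + salt) % num_partitions

def rows_per_partition(keys: List[int], salt_range: int, num_partitions: int) -> List[int]:
    """Count rows per partition; salt_range=1 means no salting."""
    counts = Counter(
        toy_partition(key, row_idx % salt_range, num_partitions)
        for row_idx, key in enumerate(keys)
    )
    return [counts.get(p, 0) for p in range(num_partitions)]

# 1,000 rows that all share the hot key 1.
hot_keys = [1] * 1000

unsalted = rows_per_partition(hot_keys, salt_range=1, num_partitions=5)
salted = rows_per_partition(hot_keys, salt_range=5, num_partitions=5)

print("unsalted:", unsalted)  # [0, 1000, 0, 0, 0]
print("salted:  ", salted)    # [200, 200, 200, 200, 200]
```

Without a salt, every row for the hot key lands in one partition. With a salt, the same rows spread across partitions, at the cost of having to replicate the other side of a join once per salt value.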

In [None]:
print("="*80)
print("SKEW REMEDIATION STRATEGIES")
print("="*80)

# Strategy 1: Filter out skewed values
print("\n1. Strategy: Filter Out Skewed Values")
print("   Use case: Null or special values causing skew\n")

# Identify top skewed customers
top_customers = (
    orders_df
    .groupBy("customer_id")
    .agg(F.count("*").alias("order_count"))
    .filter(F.col("order_count") > 1000)  # Threshold for "heavy" customers
    .select("customer_id")
).collect()

skewed_customer_ids = [row.customer_id for row in top_customers]
print(f"   Identified {len(skewed_customer_ids)} skewed customers: {skewed_customer_ids}")

# Process normal and skewed separately
normal_orders = orders_df.filter(~F.col("customer_id").isin(skewed_customer_ids))
skewed_orders = orders_df.filter(F.col("customer_id").isin(skewed_customer_ids))

print(f"   Normal orders: {normal_orders.count():,}")
print(f"   Skewed orders: {skewed_orders.count():,}")
print(f"   ✅ Process separately with different strategies")

# Strategy 2: Salting (Advanced technique)
print("\n2. Strategy: Salting Keys")
print("   Use case: Cannot filter skewed values, need even distribution\n")

def salt_dataframe(df: DataFrame, key_col: str, salt_range: int = 10) -> DataFrame:
    """
    Pure function to add salt to skewed keys.
    Distributes skewed keys across multiple partitions.
    """
    return df.withColumn(
        "salted_key",
        F.concat(
            F.col(key_col).cast("string"),
            F.lit("_"),
            (F.rand() * salt_range).cast("int").cast("string")
        )
    )

# Apply salting to skewed data
salted_df = salt_dataframe(skewed_orders, "customer_id", salt_range=10)
print(f"   Original customer_id: 1")
print(f"   Salted keys: 1_0, 1_1, 1_2, ..., 1_9")
print(f"   ✅ Distributes one hot key across up to 10 partitions")

# For joins, need to explode the other side
def explode_for_salt_join(df: DataFrame, key_col: str, salt_range: int = 10) -> DataFrame:
    """
    Pure function to explode DataFrame for salted join.
    Creates copies with all possible salt values.
    """
    salt_values = list(range(salt_range))
    return df.withColumn(
        "salt",
        F.explode(F.array(*[F.lit(i) for i in salt_values]))
    ).withColumn(
        "salted_key",
        F.concat(
            F.col(key_col).cast("string"),
            F.lit("_"),
            F.col("salt").cast("string")
        )
    ).drop("salt")

print(f"\n   For joins with salted data:")
print(f"   • Salt the large skewed table")
print(f"   • Explode the smaller table with all salt values")
print(f"   • Join on salted_key")
print(f"   ⚠️  Trade-off: Increases data size but distributes load")

# Strategy 3: Adaptive Query Execution (AQE)
print("\n3. Strategy: Adaptive Query Execution (AQE)")
print("   Use case: Automatic skew handling (Spark 3.0+)\n")

print(f"   AQE Status: {spark.conf.get('spark.sql.adaptive.enabled')}")
print(f"   Skew Join: {spark.conf.get('spark.sql.adaptive.skewJoin.enabled')}")
print(f"   ✅ Automatically splits skewed partitions during joins")
print(f"   No code changes required - handled by Spark")

# Strategy 4: Increase Parallelism
print("\n4. Strategy: Increase Parallelism")
print("   Use case: Skew is moderate, just need more tasks\n")

current_partitions = int(spark.conf.get("spark.sql.shuffle.partitions"))
print(f"   Current shuffle partitions: {current_partitions}")
print(f"   Recommendation: Increase to 2x-4x for skewed workloads")
print(f"   Example: spark.conf.set('spark.sql.shuffle.partitions', '400')")
print(f"   ✅ More partitions = smaller chunks, less impact from skew")

print("\n" + "="*80)
print("💡 Strategy Selection Guide:")
print("   1. AQE (Default): Enable and let Spark handle it automatically")
print("   2. Filter: Remove skewed values if they can be processed separately")
print("   3. Increase Partitions: Simple fix for moderate skew")
print("   4. Salting: Last resort for severe skew that can't be filtered")
print("="*80)

## 6. Join Optimization Strategies

Different join strategies have very different shuffle characteristics.
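
The reason a broadcast join avoids shuffling the large table can be sketched in plain Python: the small table becomes a lookup dictionary that every partition of the large table probes locally. The function `broadcast_hash_join` and the sample rows are hypothetical and only model an inner join; Spark's actual implementation is considerably more involved.

```python
from typing import Dict, List, Tuple

Order = Tuple[int, int, float]             # (order_id, customer_id, price)
Customer = Tuple[int, str]                 # (customer_id, tier)
JoinedRow = Tuple[int, int, float, str]    # order columns + tier

def broadcast_hash_join(
    order_partitions: List[List[Order]],
    customers: List[Customer],
) -> List[List[JoinedRow]]:
    """Inner-join each partition against a lookup built once from the small table."""
    lookup: Dict[int, str] = {cid: tier for cid, tier in customers}
    return [
        [(oid, cid, price, lookup[cid]) for oid, cid, price in part if cid in lookup]
        for part in order_partitions
    ]

order_partitions = [
    [(1, 1, 20.0), (2, 2, 35.0)],
    [(3, 1, 12.5), (4, 9, 99.0)],  # customer 9 has no match and is dropped
]
customers = [(1, "Gold"), (2, "Silver")]

joined = broadcast_hash_join(order_partitions, customers)
print(joined)
```

Each output partition is derived from exactly one input partition, so the orders never move. Only the small table is copied, which is why the approach stops paying off once that table no longer fits comfortably in memory.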

In [None]:
print("="*80)
print("JOIN OPTIMIZATION STRATEGIES")
print("="*80)

# Strategy 1: Sort-Merge Join (Default for large tables)
print("\n1. Sort-Merge Join (Default):")
print("   Both sides shuffled and sorted by join key\n")

start = time.time()
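# The "merge" hint requests a sort-merge join; without it Spark or AQE may broadcast the small side
sort_merge_join = orders_df.join(customers_df.hint("merge"), "customer_id")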
result_count = sort_merge_join.count()
sort_merge_time = time.time() - start

print(f"   Result: {result_count:,} records")
print(f"   Time: {sort_merge_time:.2f}s")
print(f"   Shuffle: Both sides")
print(f"   Use case: Both tables large, can't broadcast")

# Strategy 2: Broadcast Join (No shuffle for large table)
print("\n2. Broadcast Hash Join:")
print("   Small table broadcasted, large table not shuffled\n")

start = time.time()
broadcast_join = orders_df.join(F.broadcast(customers_df), "customer_id")
result_count = broadcast_join.count()
broadcast_time = time.time() - start

print(f"   Result: {result_count:,} records")
print(f"   Time: {broadcast_time:.2f}s")
print(f"   Speedup: {sort_merge_time/broadcast_time:.1f}x faster")
print(f"   Shuffle: None (small table is copied to every executor instead)")
print(f"   Use case: One table < 10MB (configurable)")

# Strategy 3: Bucketed Join (Pre-partitioned tables)
print("\n3. Bucketed Join (Spark bucketed tables):")
print("   Pre-partitioned tables, no shuffle needed\n")

print("   Setup:")
print("   • Write orders bucketed by customer_id")
print("   • Write customers bucketed by customer_id (same bucket count)")
print("   • Join on customer_id = no shuffle!")
print("   ")
print("   Example (bucketBy requires saveAsTable; Delta Lake does not support it):")
print("   orders_df.write.format('parquet')")
print("       .bucketBy(20, 'customer_id')")
print("       .saveAsTable('orders_bucketed')")
print("   ")
print("   ✅ Zero shuffle for joins on bucket column")
print("   Use case: Repeated joins on same key")

# Strategy 4: Partitioned Join (Co-located data)
print("\n4. Partitioned Join:")
print("   Repartition both sides by join key\n")

# Pre-partition both DataFrames
orders_partitioned = orders_df.repartition("customer_id")
customers_partitioned = customers_df.repartition("customer_id")

start = time.time()
partitioned_join = orders_partitioned.join(customers_partitioned, "customer_id")
result_count = partitioned_join.count()
partitioned_time = time.time() - start

print(f"   Result: {result_count:,} records")
print(f"   Time: {partitioned_time:.2f}s")
print(f"   Shuffle: Both sides (once, during repartition)")
print(f"   Use case: Multiple operations on same key")

print("\n" + "="*80)
print("💡 Join Strategy Decision Tree:")
print("   1. Small table (<10MB)? → Broadcast Join")
print("   2. Bucketed tables? → Bucketed Join")
print("   3. Multiple operations on same key? → Repartition first")
print("   4. Otherwise → Let AQE choose (usually sort-merge)")
print("="*80)

## 7. Functional Patterns for Shuffle Optimization

Let's create reusable functional utilities for shuffle optimization.
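
One pattern that fits this style is to separate the decision from the action: a pure function computes a partitioning plan, and applying it to a DataFrame happens elsewhere. The sketch below is a hypothetical illustration (the names `PartitionPlan` and `plan_partitions` and the 128 MB default are assumptions) and needs no Spark session.

```python
import math
from dataclasses import dataclass

@dataclass(frozen=True)
class PartitionPlan:
    """Immutable description of a partitioning decision."""
    current: int
    target: int
    action: str  # "coalesce", "repartition", or "keep"

def plan_partitions(size_mb: float, current: int, target_mb: int = 128) -> PartitionPlan:
    """Pure function: derive a target partition count from an estimated data size."""
    target = max(1, math.ceil(size_mb / target_mb))
    if target < current:
        action = "coalesce"
    elif target > current:
        action = "repartition"
    else:
        action = "keep"
    return PartitionPlan(current=current, target=target, action=action)

# 10 GB of data currently spread over 1,000 partitions.
plan = plan_partitions(size_mb=10_240, current=1000)
print(plan)  # PartitionPlan(current=1000, target=80, action='coalesce')
```

Because the function has no side effects, it can be unit tested with plain numbers, and the frozen dataclass prevents the plan from being modified after it is computed.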

In [None]:
from dataclasses import dataclass
from typing import Callable

@dataclass(frozen=True)  # frozen=True makes instances actually immutable
class ShuffleMetrics:
    """Immutable metrics for shuffle analysis"""
    operation: str
    input_partitions: int
    output_partitions: int
    execution_time: float
    record_count: int
    shuffle_occurred: bool
    
    def __str__(self) -> str:
        shuffle_status = "✅ No Shuffle" if not self.shuffle_occurred else "⚠️  Shuffle"
        return (
            f"{self.operation}:\n"
            f"  Records: {self.record_count:,}\n"
            f"  Partitions: {self.input_partitions} → {self.output_partitions}\n"
            f"  Time: {self.execution_time:.2f}s\n"
            f"  {shuffle_status}"
        )

def measure_operation(
    df: DataFrame,
    operation: Callable[[DataFrame], DataFrame],
    operation_name: str
) -> Tuple[DataFrame, ShuffleMetrics]:
    """
    Pure function to measure shuffle behavior of an operation.
    """
    input_partitions = df.rdd.getNumPartitions()
    
    start = time.time()
    result_df = operation(df)
    record_count = result_df.count()  # Trigger execution
    execution_time = time.time() - start
    
    output_partitions = result_df.rdd.getNumPartitions()
    
    # Rough heuristic: a changed partition count suggests a shuffle, but a shuffle
    # can also leave the count unchanged; check explain() for Exchange nodes to confirm
    shuffle_occurred = input_partitions != output_partitions
    
    metrics = ShuffleMetrics(
        operation=operation_name,
        input_partitions=input_partitions,
        output_partitions=output_partitions,
        execution_time=execution_time,
        record_count=record_count,
        shuffle_occurred=shuffle_occurred
    )
    
    return result_df, metrics

def optimize_partitioning(
    df: DataFrame,
    target_partition_size_mb: int = 128,
    avg_row_size_bytes: int = 500
) -> DataFrame:
    """
    Pure function to calculate optimal partition count.
    Returns optimally repartitioned DataFrame.
    """
    row_count = df.count()
    estimated_size_mb = (row_count * avg_row_size_bytes) / (1024 * 1024)
    
    target_partitions = max(1, int(estimated_size_mb / target_partition_size_mb))
    current_partitions = df.rdd.getNumPartitions()
    
    print(f"Partition Optimization Analysis:")
    print(f"  Estimated size: {estimated_size_mb:.1f} MB")
    print(f"  Current partitions: {current_partitions}")
    print(f"  Target partitions: {target_partitions}")
    print(f"  Target size per partition: {target_partition_size_mb} MB")
    
    if target_partitions < current_partitions:
        print(f"  → Using coalesce() to reduce partitions")
        return df.coalesce(target_partitions)
    elif target_partitions > current_partitions:
        print(f"  → Using repartition() to increase partitions")
        return df.repartition(target_partitions)
    else:
        print(f"  → Current partitioning is optimal")
        return df

# Test the utilities
print("="*80)
print("FUNCTIONAL SHUFFLE OPTIMIZATION UTILITIES")
print("="*80)

# Test 1: Measure groupBy operation
print("\n1. Measuring groupBy operation:")
result_df, metrics = measure_operation(
    orders_df,
    lambda df: df.groupBy("customer_id").agg(F.sum("quantity").alias("total_qty")),
    "GroupBy Aggregation"
)
print(metrics)

# Test 2: Optimize partitioning
print("\n2. Optimizing partition count:")
optimized_df = optimize_partitioning(orders_df, target_partition_size_mb=128)

# Test 3: Compare operations
print("\n3. Comparing narrow vs wide transformations:")

filter_df, filter_metrics = measure_operation(
    orders_df,
    lambda df: df.filter(F.col("price") > 100),
    "Filter (Narrow)"
)
print(filter_metrics)

print()

sort_df, sort_metrics = measure_operation(
    orders_df,
    lambda df: df.orderBy("price"),
    "OrderBy (Wide)"
)
print(sort_metrics)

print("\n" + "="*80)

## 8. Best Practices and Anti-Patterns

In [None]:
print("="*80)
print("BEST PRACTICES FOR SHUFFLE OPTIMIZATION")
print("="*80)

print("""
✅ BEST PRACTICES:

1. Minimize Wide Transformations
   • Filter data early to reduce shuffle volume
   • Use broadcast joins for small tables (<10MB)
   • Combine multiple aggregations in single groupBy

2. Optimize Partitioning
   • Target 128-200MB per partition
   • Use coalesce() when reducing partitions
   • Use repartition() for even distribution or increasing partitions
   • Repartition by join key before multiple operations

3. Handle Data Skew
   • Enable AQE (available since Spark 3.0, on by default since 3.2)
   • Monitor Spark UI for long-running tasks
   • Filter skewed values and process separately
   • Use salting as last resort for severe skew

4. Join Strategy Selection
   • Broadcast join: One table <10MB
   • Bucketed join: Repeated joins on same key
   • Sort-merge join: Both tables large
   • Let AQE optimize automatically when possible

5. Configuration Tuning
   • spark.sql.shuffle.partitions: Adjust based on data size
   • spark.sql.autoBroadcastJoinThreshold: Increase if safe
   • spark.sql.adaptive.enabled: Keep enabled (on by default since Spark 3.2)
   • Let AQE auto-tune when possible

""")

print("="*80)
print("ANTI-PATTERNS TO AVOID")
print("="*80)

print("""
❌ ANTI-PATTERN 1: Unnecessary Shuffles

Bad:
df.repartition(100).filter(condition)  # Shuffle before filter

Good:
df.filter(condition).repartition(100)  # Filter first, reduce data

---

❌ ANTI-PATTERN 2: Using repartition() When coalesce() Suffices

Bad:
df.repartition(5)  # Full shuffle to reduce partitions

Good:
df.coalesce(5)  # Efficient partition combining

---

❌ ANTI-PATTERN 3: Multiple Shuffles on Same Key

Bad:
result = df.groupBy("key").agg(...)  # Shuffle 1
result.join(other_df, "key")  # May shuffle again if partitioning does not line up

Good:
df_partitioned = df.repartition("key")  # Shuffle once
agg_result = df_partitioned.groupBy("key").agg(...)  # Reuses the existing partitioning
joined = agg_result.join(other_df.repartition("key"), "key")  # Only other_df is shuffled

---

❌ ANTI-PATTERN 4: Ignoring Data Skew

Bad:
# Let one task process 90% of data
skewed_df.groupBy("skewed_key").agg(...)

Good:
# For skewed joins, let AQE split oversized partitions
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
# For skewed aggregations, split out the hot keys (or salt them)
normal_data = df.filter(~is_skewed)
skewed_data = df.filter(is_skewed)
# Process separately

---

❌ ANTI-PATTERN 5: collect() on Large Datasets

Bad:
data = df.collect()  # Pulls every row to the driver - OOM risk!
for row in data:
    process(row)

Good:
df.foreach(lambda row: process(row))  # Process on executors
# or
df.write.format(...).save(...)  # Write directly

---

❌ ANTI-PATTERN 6: Too Many Small Partitions

Bad:
df.repartition(10000)  # 10MB per partition - too many tasks!

Good:
# Target 128-200MB per partition
optimal_partitions = max(1, int(data_size_mb / 128))  # repartition() needs an int
df.repartition(optimal_partitions)

""")

## Summary

In this notebook, we explored strategies for minimizing data shuffling and handling data skew in PySpark:

### Key Concepts Covered

1. **Understanding Shuffles**
   - Wide vs narrow transformations
   - Performance impact of network and disk I/O
   - Identifying shuffle operations in code

2. **Partitioning Strategies**
   - `repartition()` for even distribution and increasing partitions
   - `coalesce()` for efficient partition reduction
   - Repartitioning by column for co-location
   - Optimal partition sizing (128-200MB)

3. **Data Skew Detection**
   - Analyzing key distribution with groupBy
   - Monitoring Spark UI for long-running tasks
   - Partition size distribution analysis
   - Skew ratio calculation

4. **Skew Remediation**
   - Adaptive Query Execution (AQE) automatic handling
   - Filtering skewed values for separate processing
   - Increasing parallelism with more partitions
   - Salting technique for severe skew

5. **Join Optimization**
   - Broadcast joins for small tables
   - Bucketed joins for pre-partitioned data
   - Partitioned joins for multiple operations
   - Sort-merge joins with AQE optimization

### Functional Programming Integration

- Pure functions for metrics collection and analysis
- Immutable data structures for shuffle metrics
- Composable optimization utilities
- Declarative partition sizing functions

### Performance Principles

- **Minimize Shuffles**: Filter early, broadcast small tables, combine operations
- **Optimize Partitions**: Target 128-200MB per partition
- **Handle Skew**: Enable AQE, filter outliers, use salting as last resort
- **Monitor Performance**: Use Spark UI and custom metrics
- **Let AQE Help**: Leverage automatic optimizations in Spark 3.0+

### Next Steps

- Monitor your Spark applications in Spark UI
- Profile partition distributions in production workloads
- Experiment with different join strategies
- Enable and tune AQE for automatic optimization
- Build custom monitoring for skew detection

## Exercises

Practice implementing shuffle optimization techniques.

In [None]:
print("="*80)
print("EXERCISES: Practice Shuffle Optimization")
print("="*80)

print("""
Exercise 1: Identify Shuffle Operations
----------------------------------------
Analyze the following pipeline and identify all shuffle operations:

result = (df
    .filter(F.col("status") == "active")
    .select("user_id", "amount", "date")
    .groupBy("user_id")
    .agg(F.sum("amount").alias("total"))
    .join(users_df, "user_id")
    .orderBy(F.desc("total"))
)

List the shuffle operations and suggest optimizations.

---

Exercise 2: Optimize Partition Count
-------------------------------------
You have a 50GB dataset with 1000 partitions.

Questions:
1. What is the average partition size?
2. Is this optimal? (Target: 128-200MB per partition)
3. What operation would you use to optimize?
4. How many partitions should you target?

---

Exercise 3: Detect and Handle Skew
-----------------------------------
Create a function that:
1. Detects skew in a DataFrame by a given key
2. Returns separate DataFrames for skewed and normal data
3. Recommends a remediation strategy

def detect_and_split_skew(
    df: DataFrame,
    key_col: str,
    skew_threshold: float = 3.0
) -> Tuple[DataFrame, DataFrame, str]:
    # Your implementation
    pass

---

Exercise 4: Optimize Join Strategy
-----------------------------------
Given:
- orders_df: 100GB, 1M unique customer_ids
- customers_df: 5MB, 100K customers
- products_df: 50MB, 10K products

Task: Optimize this pipeline:

result = (orders_df
    .join(customers_df, "customer_id")
    .join(products_df, "product_id")
    .groupBy("customer_tier", "product_category")
    .agg(F.sum("amount"))
)

Questions:
1. Which joins should be broadcast?
2. Should you repartition? On which column?
3. Rewrite the optimized version

---

Exercise 5: Implement Salting
------------------------------
Implement a complete salted join for skewed data:

def salted_join(
    large_df: DataFrame,
    small_df: DataFrame,
    join_key: str,
    salt_range: int = 10
) -> DataFrame:
    """
    Perform a salted join to handle skew.
    
    Steps:
    1. Add salt to large_df's join key
    2. Explode small_df with all salt values
    3. Join on salted key
    4. Clean up salt columns
    """
    # Your implementation
    pass

---

Exercise 6: Performance Monitoring
-----------------------------------
Create a decorator that measures and reports shuffle metrics:

def monitor_shuffles(func):
    """
    Decorator to monitor shuffle behavior of DataFrame operations.
    Should report:
    - Execution time
    - Input/output partition counts
    - Estimated shuffle occurred
    """
    # Your implementation
    pass

@monitor_shuffles
def my_transformation(df: DataFrame) -> DataFrame:
    return df.groupBy("key").agg(F.sum("value"))

""")

print("\n📝 Complete these exercises to master shuffle optimization!")