# 5.1 Strategic Data Handling: Caching, Broadcast Joins, and Efficient Formats

This notebook demonstrates strategic data handling techniques for optimizing PySpark performance while maintaining functional programming principles.

## Learning Objectives
- Understand when and how to use caching effectively
- Master broadcast joins for performance optimization
- Choose optimal file formats for different use cases
- Apply strategic caching patterns in functional pipelines
- Balance performance optimization with functional purity

## Understanding Spark Caching Strategies

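Spark evaluates DataFrames lazily. Every action re-executes the full chain of transformations behind a DataFrame unless an intermediate result has been cached. Caching is therefore a controlled departure from pure statelessness, and it can substantially improve performance for iterative algorithms and other workloads that reuse the same data.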
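Reuse is what makes caching worthwhile. The pure-Python toy model below shows why. `LazyPipeline` is a hypothetical class for illustration, not a Spark API. Each "action" replays the recorded transformations. If the pipeline has been marked with `cache()`, the first action materializes the rows and later actions reuse them.

```python
class LazyPipeline:
    """Hypothetical toy model of lazy evaluation (not a Spark API)."""

    def __init__(self, source, steps=()):
        self.source = source          # callable returning the input rows
        self.steps = list(steps)      # transformations, applied in order
        self.cache_enabled = False
        self._cached_rows = None
        self.executions = 0           # how many times the lineage was replayed

    def map(self, fn):
        # Transformations are lazy: they only extend the lineage
        return LazyPipeline(self.source, self.steps + [fn])

    def cache(self):
        # Like DataFrame.cache(): only marks the pipeline; the first action materializes it
        self.cache_enabled = True
        return self

    def _rows(self):
        if self._cached_rows is not None:
            return self._cached_rows  # served from the cache, no recomputation
        self.executions += 1
        rows = list(self.source())
        for fn in self.steps:
            rows = [fn(r) for r in rows]
        if self.cache_enabled:
            self._cached_rows = rows
        return rows

    # Actions trigger execution
    def count(self):
        return len(self._rows())

    def total(self):
        return sum(self._rows())


def run_three_actions(pipeline):
    pipeline.count()
    pipeline.total()
    pipeline.count()
    return pipeline.executions


uncached = LazyPipeline(lambda: range(1_000)).map(lambda x: x * 2)
cached = LazyPipeline(lambda: range(1_000)).map(lambda x: x * 2).cache()

print(run_three_actions(uncached))  # 3: each action replays the lineage
print(run_three_actions(cached))    # 1: later actions read the cached rows
```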

In [1]:
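from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DoubleType
import time
import random

# Reuse the active session (Databricks notebooks provide `spark`) or start a local one
spark = SparkSession.builder.appName("strategic-data-handling").getOrCreate()
random.seed(42)  # reproducible synthetic data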

# Create larger dataset for meaningful caching demonstrations
def generate_large_sales_data(num_records=100000):
    """Generate large sales dataset for caching demonstrations"""
    
    categories = ['Electronics', 'Clothing', 'Books', 'Home', 'Sports', 'Beauty', 'Automotive']
    regions = ['North', 'South', 'East', 'West', 'Central']
    
    data = []
    for i in range(num_records):
        data.append((
            i + 1,  # transaction_id
            f"Customer_{random.randint(1, num_records // 10)}",  # customer_id
            f"Product_{random.randint(1, 1000)}",  # product_name
            random.choice(categories),  # category
            random.choice(regions),  # region
            round(random.uniform(10, 2000), 2),  # amount
            random.randint(1, 5),  # quantity
            f"2023-{random.randint(1, 12):02d}-{random.randint(1, 28):02d}"  # sale_date
        ))
    
    schema = StructType([
        StructField("transaction_id", IntegerType(), False),
        StructField("customer_id", StringType(), False),
        StructField("product_name", StringType(), False),
        StructField("category", StringType(), False),
        StructField("region", StringType(), False),
        StructField("amount", DoubleType(), False),
        StructField("quantity", IntegerType(), False),
        StructField("sale_date", StringType(), False)
    ])
    
    return spark.createDataFrame(data, schema)

print("Generating large sales dataset...")
sales_df = generate_large_sales_data(50000)  # 50K records for demo
sales_df = sales_df.withColumn("sale_date", F.to_date("sale_date"))

print(f"Generated {sales_df.count():,} sales records")
print("\nSample data:")
sales_df.show(5)
sales_df.printSchema()


## Caching Performance Comparison

Let's demonstrate the performance impact of caching on iterative operations:

In [None]:
print("=== Caching Performance Comparison ===")

def complex_transformation(df):
    """Complex transformation that we'll reuse multiple times"""
    return (df
            .withColumn("total_value", F.col("amount") * F.col("quantity"))
            .withColumn("year", F.year("sale_date"))
            .withColumn("month", F.month("sale_date"))
            .withColumn("is_high_value", F.when(F.col("total_value") > 1000, True).otherwise(False))
            .filter(F.col("total_value") > 50)  # Filter out very small transactions
           )

def multiple_analyses_without_caching(df):
    """Perform multiple analyses without caching"""
    transformed_df = complex_transformation(df)
    
    start_time = time.time()
    
    # Analysis 1: Total revenue by category
    category_revenue = (transformed_df
                       .groupBy("category")
                       .agg(F.sum("total_value").alias("total_revenue"))
                       .collect())
    
    # Analysis 2: High-value transaction count by region
    high_value_by_region = (transformed_df
                           .filter(F.col("is_high_value"))
                           .groupBy("region")
                           .count()
                           .collect())
    
    # Analysis 3: Monthly trends
    monthly_trends = (transformed_df
                     .groupBy("year", "month")
                     .agg(F.avg("total_value").alias("avg_value"),
                          F.count("*").alias("transaction_count"))
                     .collect())
    
    end_time = time.time()
    
    return end_time - start_time, len(category_revenue), len(high_value_by_region), len(monthly_trends)

def multiple_analyses_with_caching(df):
    """Perform multiple analyses with caching"""
    transformed_df = complex_transformation(df)
    
    # Mark for caching. cache() is lazy, so the first action below materializes it
    # (that one-time cost is included in the timing)
    transformed_df.cache()
    
    start_time = time.time()
    
    # Analysis 1: Total revenue by category
    category_revenue = (transformed_df
                       .groupBy("category")
                       .agg(F.sum("total_value").alias("total_revenue"))
                       .collect())
    
    # Analysis 2: High-value transaction count by region
    high_value_by_region = (transformed_df
                           .filter(F.col("is_high_value"))
                           .groupBy("region")
                           .count()
                           .collect())
    
    # Analysis 3: Monthly trends
    monthly_trends = (transformed_df
                     .groupBy("year", "month")
                     .agg(F.avg("total_value").alias("avg_value"),
                          F.count("*").alias("transaction_count"))
                     .collect())
    
    end_time = time.time()
    
    # Unpersist to free memory
    transformed_df.unpersist()
    
    return end_time - start_time, len(category_revenue), len(high_value_by_region), len(monthly_trends)

print("\nRunning analysis without caching...")
time_without_cache, cat_count, region_count, month_count = multiple_analyses_without_caching(sales_df)

print("\nRunning analysis with caching...")
time_with_cache, cat_count_cached, region_count_cached, month_count_cached = multiple_analyses_with_caching(sales_df)

print(f"\n=== Performance Results ===")
print(f"Without caching: {time_without_cache:.2f} seconds")
print(f"With caching:    {time_with_cache:.2f} seconds")
print(f"Speedup: {time_without_cache/time_with_cache:.1f}x (below 1.0 means caching was slower)")
print(f"Time saved: {time_without_cache - time_with_cache:.2f} seconds")

# Verify results are identical
assert cat_count == cat_count_cached, "Category counts should match"
assert region_count == region_count_cached, "Region counts should match"
assert month_count == month_count_cached, "Monthly counts should match"
print("\n✅ Result sizes match with and without caching")
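
The comparison above times each variant once. The numbers are therefore sensitive to run order, because the first Spark job in a session pays one-off start-up costs, and to general noise. A steadier measurement discards warm-up runs and reports the median of several repeats. Below is a minimal sketch of that approach. `median_runtime` is a hypothetical helper, not part of PySpark.

```python
import statistics
import time


def median_runtime(fn, repeats=5, warmup=1):
    """Hypothetical helper: median wall-clock seconds of fn over several runs.

    Warm-up runs are discarded so one-off costs (e.g. JIT compilation or a
    cold file cache) do not dominate the result.
    """
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


# Plain Python work stands in for a Spark action here; in the notebook, fn could be
# lambda: transformed_df.groupBy("category").count().collect()
calls = []
elapsed = median_runtime(lambda: calls.append(sum(range(10_000))), repeats=3, warmup=1)
print(f"median: {elapsed:.6f}s over {len(calls)} calls")
```

When you time a cached DataFrame, the warm-up run is what populates the cache. The median then reflects cache hits only, so decide whether that is the cost you want to compare.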

## Strategic Caching Patterns

Different caching strategies for different use cases:

In [None]:
print("=== Strategic Caching Patterns ===")

from pyspark import StorageLevel

class CachingStrategies:
    """Functional caching strategies for different scenarios"""
    
    @staticmethod
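    def cache_for_iterative_ml(df, storage_level=StorageLevel.MEMORY_AND_DISK):
        """
        Caching strategy for iterative ML algorithms
        MEMORY_AND_DISK spills partitions that don't fit in memory to local disk
        instead of recomputing them on every iteration
        """
        return df.persist(storage_level)
    
    @staticmethod
    def cache_for_interactive_analysis(df):
        """
        Caching strategy for interactive data exploration
        Keeps data in memory only; evicted partitions are recomputed when needed
        """
        # DataFrame.cache() defaults to MEMORY_AND_DISK, not MEMORY_ONLY,
        # so request memory-only storage explicitly
        return df.persist(StorageLevel.MEMORY_ONLY)
    
    @staticmethod
    def cache_for_batch_processing(df):
        """
        Caching strategy for batch processing pipelines
        Stores data serialized and spills to disk for large datasets
        """
        # PySpark has no *_SER constants (e.g. MEMORY_AND_DISK_SER raises AttributeError);
        # its StorageLevel.MEMORY_AND_DISK is already serialized (deserialized=False)
        return df.persist(StorageLevel.MEMORY_AND_DISK)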
    
    @staticmethod
    def conditional_caching(df, cache_condition_func, threshold=1000000):
        """
        Conditionally cache based on dataset characteristics
        """
        row_count = df.count()
        
        if cache_condition_func(row_count, threshold):
            print(f"Dataset size ({row_count:,} rows) meets caching criteria")
            return df.cache()
        else:
            print(f"Dataset size ({row_count:,} rows) too small for caching benefits")
            return df

# Demonstrate different caching strategies.
# persist() returns the same DataFrame, and Spark keeps the first storage level it
# was given for that plan, so each example unpersists before the next one runs.
print("\n1. Memory-only caching for interactive analysis:")
interactive_df = CachingStrategies.cache_for_interactive_analysis(sales_df)
print(f"Storage level: {interactive_df.storageLevel}")
interactive_df.unpersist()

print("\n2. Memory and disk caching for ML:")
ml_df = CachingStrategies.cache_for_iterative_ml(sales_df)
print(f"Storage level: {ml_df.storageLevel}")
ml_df.unpersist()

print("\n3. Serialized caching for batch processing:")
batch_df = CachingStrategies.cache_for_batch_processing(sales_df)
print(f"Storage level: {batch_df.storageLevel}")
batch_df.unpersist()

print("\n4. Conditional caching:")
# Cache if dataset is large enough
cache_condition = lambda count, threshold: count > threshold
conditional_df = CachingStrategies.conditional_caching(sales_df, cache_condition, threshold=10000)

# Clean up (unpersist is a no-op if the DataFrame was not cached)
conditional_df.unpersist()

print("\n✅ Cache cleanup completed")
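
The strategies above hard-code one storage level per scenario. In practice the choice turns on two questions: will the DataFrame be read more than once, and will it fit in the storage memory you have? The sketch below encodes that logic as a plain function. The function, its parameter names, and the 50% headroom are assumptions for illustration, not Spark defaults.

```python
def choose_storage_level(estimated_mb, free_storage_mb, reuse_count,
                         recompute_is_expensive=True, headroom=0.5):
    """Hypothetical heuristic: return a StorageLevel constant name, or None.

    estimated_mb            rough in-memory size of the DataFrame
    free_storage_mb         storage memory you expect to have across executors
    reuse_count             number of actions that will read the DataFrame
    recompute_is_expensive  whether rebuilding evicted partitions is costly
    headroom                fraction of free storage you are willing to fill
    """
    if reuse_count < 2:
        # Read once: caching only adds the cost of materializing the data
        return None
    if estimated_mb <= headroom * free_storage_mb:
        # Fits comfortably, so evictions (and recomputation) are unlikely
        return "MEMORY_ONLY"
    if recompute_is_expensive:
        # Too big for memory alone: spill to local disk instead of recomputing
        return "MEMORY_AND_DISK"
    # Large but cheap to rebuild: let Spark recompute from the lineage
    return None


for args in [(50, 1000, 1), (50, 1000, 3), (800, 1000, 3), (800, 1000, 3, False)]:
    print(args, "->", choose_storage_level(*args))
```

In PySpark, a returned name can be turned into the constant with `getattr(StorageLevel, name)` and passed to `df.persist(...)`.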

## Delta Cache vs Spark Cache

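In Databricks, the disk cache (formerly called Delta Cache) complements the standard Spark cache. Spark's cache stores the computed results of a DataFrame after an explicit `cache()`/`persist()` call. The disk cache, by contrast, automatically keeps copies of remote Parquet and Delta data files on the workers' local storage, so repeated scans avoid downloading them again: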

In [None]:
print("=== Delta Cache vs Spark Cache ===")

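# Note: the disk cache (formerly Delta Cache) needs no code changes. When it is enabled
# on the cluster, it caches remote Parquet/Delta data files automatically.
# This section demonstrates access patterns that make good use of it.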

class DeltaCachingPatterns:
    """Best practices for Delta Cache utilization"""
    
    @staticmethod
    def write_for_delta_cache_optimization(df, path, partition_columns=None):
        """
        Write DataFrame to Delta format optimized for Delta Cache
        """
        writer = df.write.format("delta").mode("overwrite")
        
        if partition_columns:
            writer = writer.partitionBy(*partition_columns)
        
        # Optimize file size for Delta Cache (128MB - 1GB per file)
        writer = writer.option("delta.targetFileSize", "268435456")  # 256MB
        
        writer.save(path)
        return path
    
    @staticmethod
    def read_with_delta_cache_awareness(path, select_columns=None, filter_condition=None):
        """
        Read Delta table with patterns that maximize Delta Cache benefits
        """
        df = spark.read.format("delta").load(path)
        
        # Column pruning helps Delta Cache efficiency
        if select_columns:
            df = df.select(*select_columns)
        
        # Filters are pushed down so Delta can skip files using per-file statistics,
        # which means less data is read (and cached)
        if filter_condition:
            df = df.filter(filter_condition)
        
        return df

# Create optimized Delta table for caching
delta_path = "/tmp/delta_sales_optimized"

try:
    dbutils.fs.rm(delta_path, True)
except Exception:  # dbutils is only available on Databricks
    pass

print("\nCreating optimized Delta table...")
DeltaCachingPatterns.write_for_delta_cache_optimization(
    sales_df, 
    delta_path, 
    partition_columns=["category"]  # Partition by category for better cache locality
)

print("✅ Delta table created with cache optimization")

# Demonstrate cache-aware reading patterns
print("\nTesting cache-aware reading patterns:")

# Pattern 1: Column pruning
pruned_df = DeltaCachingPatterns.read_with_delta_cache_awareness(
    delta_path,
    select_columns=["category", "region", "amount", "quantity"],
    filter_condition=F.col("amount") > 100
)

print(f"Pruned DataFrame columns: {pruned_df.columns}")
print(f"Filtered records: {pruned_df.count():,}")

# Pattern 2: Repeated access (can benefit from the disk cache when it is enabled)
start_time = time.time()
for i in range(3):
    iteration_start = time.time()
    result = pruned_df.agg(F.avg("amount")).collect()[0][0]
    print(f"Iteration {i+1}: Average amount = ${result:.2f} "
          f"({time.time() - iteration_start:.2f}s)")

total_time = time.time() - start_time
print(f"Total time for 3 iterations: {total_time:.2f} seconds")
print("(With the disk cache enabled, iterations after the first may be faster)")
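
The disk cache works at the level of data files, not query results. The pure-Python toy model below illustrates only that file-level idea: repeated scans of the same files are served locally, and a file rewritten to a new version is fetched again. `LocalFileCache`, its `(path, version)` key, and the LRU eviction policy are all assumptions of this sketch. This is not how Databricks implements its disk cache.

```python
from collections import OrderedDict


class LocalFileCache:
    """Hypothetical toy LRU cache of remote data files on worker-local disk."""

    def __init__(self, capacity_bytes):
        self.capacity_bytes = capacity_bytes
        self.entries = OrderedDict()   # (path, version) -> size in bytes
        self.hits = 0
        self.misses = 0

    def read(self, path, version, size, fetch_remote):
        key = (path, version)          # a rewritten file gets a new key, so no stale reads
        if key in self.entries:
            self.entries.move_to_end(key)  # mark as most recently used
            self.hits += 1
            return "local"
        self.misses += 1
        fetch_remote(path)             # stands in for a cloud object-store download
        self.entries[key] = size
        while sum(self.entries.values()) > self.capacity_bytes:
            self.entries.popitem(last=False)  # evict the least recently used file
        return "remote"


remote_reads = []
cache = LocalFileCache(capacity_bytes=300)
files = [("part-0.parquet", 1, 100), ("part-1.parquet", 1, 100)]
for _ in range(3):                     # three queries scanning the same files
    for path, version, size in files:
        cache.read(path, version, size, remote_reads.append)
print(cache.hits, cache.misses, len(remote_reads))  # 4 2 2
```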

## Broadcast Joins for Performance

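Broadcast joins can substantially improve performance when a large DataFrame is joined with a small lookup table. Spark sends a copy of the small table to every executor, so each partition of the large table is joined locally and the large table is never shuffled: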

In [None]:
print("=== Broadcast Joins Demonstration ===")

# Create lookup tables (small DataFrames perfect for broadcasting)
category_metadata = [
    ("Electronics", "Technology", 0.15, "High-tech consumer goods"),
    ("Clothing", "Fashion", 0.25, "Apparel and accessories"),
    ("Books", "Education", 0.08, "Educational and recreational reading"),
    ("Home", "Lifestyle", 0.12, "Home improvement and furniture"),
    ("Sports", "Recreation", 0.18, "Sports equipment and gear"),
    ("Beauty", "Personal Care", 0.22, "Cosmetics and personal care items"),
    ("Automotive", "Transportation", 0.10, "Vehicle parts and accessories")
]

category_schema = StructType([
    StructField("category", StringType(), False),
    StructField("category_group", StringType(), False),
    StructField("commission_rate", DoubleType(), False),
    StructField("description", StringType(), True)
])

category_lookup_df = spark.createDataFrame(category_metadata, category_schema)

region_metadata = [
    ("North", "Northern Region", "Chicago", 1.05),
    ("South", "Southern Region", "Atlanta", 0.98),
    ("East", "Eastern Region", "New York", 1.12),
    ("West", "Western Region", "Los Angeles", 1.08),
    ("Central", "Central Region", "Dallas", 1.02)
]

region_schema = StructType([
    StructField("region", StringType(), False),
    StructField("region_name", StringType(), False),
    StructField("headquarters", StringType(), False),
    StructField("cost_multiplier", DoubleType(), False)
])

region_lookup_df = spark.createDataFrame(region_metadata, region_schema)

print("Lookup tables created:")
print("\nCategory lookup:")
category_lookup_df.show()
print("\nRegion lookup:")
region_lookup_df.show()

print(f"\nDataset sizes:")
print(f"Sales data: {sales_df.count():,} rows")
print(f"Category lookup: {category_lookup_df.count()} rows")
print(f"Region lookup: {region_lookup_df.count()} rows")

In [None]:
# Performance comparison: Regular join vs Broadcast join
print("\n=== Join Performance Comparison ===")

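def regular_joins(sales_df, category_df, region_df):
    """Perform joins with automatic broadcasting disabled"""
    # Lookup tables this small fall under spark.sql.autoBroadcastJoinThreshold
    # (10 MB by default), so Spark would broadcast them anyway; disable it for a real baseline
    previous_threshold = spark.conf.get("spark.sql.autoBroadcastJoinThreshold")
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
    try:
        start_time = time.time()
        
        result = (sales_df
                 .join(category_df, "category")
                 .join(region_df, "region")
                 .withColumn("total_value", F.col("amount") * F.col("quantity"))
                 .withColumn("commission", F.col("total_value") * F.col("commission_rate"))
                 .withColumn("adjusted_value", F.col("total_value") * F.col("cost_multiplier"))
                )
        
        # Force execution (the join strategy is chosen when this action is planned)
        count = result.count()
        
        end_time = time.time()
    finally:
        spark.conf.set("spark.sql.autoBroadcastJoinThreshold", previous_threshold)
    return end_time - start_time, count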

def broadcast_joins(sales_df, category_df, region_df):
    """Perform joins with broadcast hints"""
    start_time = time.time()
    
    result = (sales_df
             .join(F.broadcast(category_df), "category")  # Broadcast hint
             .join(F.broadcast(region_df), "region")      # Broadcast hint
             .withColumn("total_value", F.col("amount") * F.col("quantity"))
             .withColumn("commission", F.col("total_value") * F.col("commission_rate"))
             .withColumn("adjusted_value", F.col("total_value") * F.col("cost_multiplier"))
            )
    
    # Force execution
    count = result.count()
    
    end_time = time.time()
    return end_time - start_time, count

print("Running regular joins...")
regular_time, regular_count = regular_joins(sales_df, category_lookup_df, region_lookup_df)

print("Running broadcast joins...")
broadcast_time, broadcast_count = broadcast_joins(sales_df, category_lookup_df, region_lookup_df)

print(f"\n=== Join Performance Results ===")
print(f"Regular joins:    {regular_time:.2f} seconds ({regular_count:,} rows)")
print(f"Broadcast joins:  {broadcast_time:.2f} seconds ({broadcast_count:,} rows)")
print(f"Speedup: {regular_time/broadcast_time:.1f}x (below 1.0 means broadcast was slower)")
print(f"Time saved: {regular_time - broadcast_time:.2f} seconds")

# Verify results are identical
assert regular_count == broadcast_count, "Row counts should match"
print("\n✅ Row counts match for both join strategies")
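
The broadcast speedup comes from not shuffling the large table. The pure-Python toy model below joins a partitioned "large" table to a small lookup table both ways and checks that the results agree. The function names are hypothetical. Counting rows moved is a simplification of real network and disk cost.

```python
from collections import defaultdict


def _build_lookup(rows, key):
    # Hash table from join key to matching rows (the "build" side)
    lookup = defaultdict(list)
    for row in rows:
        lookup[row[key]].append(row)
    return lookup


def broadcast_hash_join(large_partitions, small_rows, key):
    """Copy the small side to every partition; large-side rows never move."""
    lookup = _build_lookup(small_rows, key)
    out = [{**row, **match}
           for part in large_partitions          # each partition probes its own copy
           for row in part
           for match in lookup.get(row[key], [])]
    # Spark sends one copy per executor; this toy counts one copy per partition
    rows_shipped = len(small_rows) * len(large_partitions)
    return out, rows_shipped


def shuffle_hash_join(large_partitions, small_rows, key, num_partitions):
    """Repartition BOTH sides by hash(key), then join partition by partition."""
    large_buckets = defaultdict(list)
    small_buckets = defaultdict(list)
    rows_shuffled = 0
    for part in large_partitions:
        for row in part:
            large_buckets[hash(row[key]) % num_partitions].append(row)
            rows_shuffled += 1                   # every large-side row is shuffled
    for row in small_rows:
        small_buckets[hash(row[key]) % num_partitions].append(row)
        rows_shuffled += 1
    out = []
    for p in range(num_partitions):
        lookup = _build_lookup(small_buckets[p], key)
        out.extend({**row, **match}
                   for row in large_buckets[p]
                   for match in lookup.get(row[key], []))
    return out, rows_shuffled


large = [
    [{"id": 1, "category": "Books"}, {"id": 2, "category": "Home"}],
    [{"id": 3, "category": "Books"}, {"id": 4, "category": "Toys"}],
]
small = [{"category": "Books", "group": "Education"},
         {"category": "Home", "group": "Lifestyle"}]

bhj, shipped = broadcast_hash_join(large, small, "category")
shj, shuffled = shuffle_hash_join(large, small, "category", num_partitions=2)
print(len(bhj), shipped, shuffled)  # 3 4 6
```

At the notebook's scale the asymmetry grows. A shuffle join moves all 50,000 sales rows, while a broadcast join ships only the 7-row category table to each executor.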

## Intelligent Broadcast Join Strategy

Let's create a functional approach to automatically determine when to use broadcast joins:

In [None]:
print("=== Intelligent Broadcast Join Strategy ===")

class BroadcastJoinOptimizer:
    """Functional utilities for optimizing join strategies"""
    
    @staticmethod
    def estimate_dataframe_size(df, bytes_per_column=50):
        """
        Roughly estimate DataFrame size in MB for the broadcast decision
        Does not modify the DataFrame, but runs a count() job
        """
        # Simplified heuristic: assume ~50 bytes per column value.
        # Spark's optimizer keeps its own size statistics; this is only a rough proxy.
        total_rows = df.count()
        estimated_bytes = total_rows * len(df.columns) * bytes_per_column
        return estimated_bytes / (1024 * 1024)
    
    @staticmethod
    def should_broadcast(df, broadcast_threshold_mb=200):
        """
        Determine if a DataFrame should be broadcasted
        Pure function for broadcast decision logic
        """
        estimated_size = BroadcastJoinOptimizer.estimate_dataframe_size(df)
        return estimated_size < broadcast_threshold_mb, estimated_size
    
    @staticmethod
    def smart_join(left_df, right_df, join_keys, join_type="inner", 
                   auto_broadcast=True, broadcast_threshold_mb=200):
        """
        Intelligent join function that automatically decides on broadcast strategy
        """
        if not auto_broadcast:
            return left_df.join(right_df, join_keys, join_type)
        
        # Check if either DataFrame should be broadcasted
        left_should_broadcast, left_size = BroadcastJoinOptimizer.should_broadcast(
            left_df, broadcast_threshold_mb)
        right_should_broadcast, right_size = BroadcastJoinOptimizer.should_broadcast(
            right_df, broadcast_threshold_mb)
        
        print(f"Left DataFrame size estimate: {left_size:.1f} MB")
        print(f"Right DataFrame size estimate: {right_size:.1f} MB")
        
        if right_should_broadcast and not left_should_broadcast:
            print("🚀 Broadcasting right DataFrame")
            return left_df.join(F.broadcast(right_df), join_keys, join_type)
        elif left_should_broadcast and not right_should_broadcast:
            print("🚀 Broadcasting left DataFrame")
            return F.broadcast(left_df).join(right_df, join_keys, join_type)
        elif left_should_broadcast and right_should_broadcast:
            # Both are small, broadcast the smaller one
            if left_size <= right_size:
                print("🚀 Broadcasting left DataFrame (smaller of two small DataFrames)")
                return F.broadcast(left_df).join(right_df, join_keys, join_type)
            else:
                print("🚀 Broadcasting right DataFrame (smaller of two small DataFrames)")
                return left_df.join(F.broadcast(right_df), join_keys, join_type)
        else:
            print("📊 Using regular join (both DataFrames too large for broadcast)")
            return left_df.join(right_df, join_keys, join_type)

# Test the intelligent join optimizer
print("\nTesting intelligent join strategy:")

# Join with category lookup (should broadcast)
print("\n1. Joining with category lookup:")
result1 = BroadcastJoinOptimizer.smart_join(
    sales_df, 
    category_lookup_df, 
    "category"
)

# Chain with region lookup (should also broadcast)
print("\n2. Chaining with region lookup:")
final_result = BroadcastJoinOptimizer.smart_join(
    result1,
    region_lookup_df,
    "region"
)

print(f"\nFinal result count: {final_result.count():,} rows")
print("\nSample of enriched data:")
(final_result
 .select("transaction_id", "category", "category_group", "region", "region_name", 
         "amount", "commission_rate")
 .show(5))
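
`smart_join` combines the broadcast decision with the join itself, so the decision is hard to test without a SparkSession. Below, the same rules are pulled out into pure functions. The names `estimate_size_mb` and `choose_broadcast_side` are hypothetical, and the 50-bytes-per-column figure is the notebook's own rough heuristic. By that estimate, the 50,000-row, 8-column sales table is about 19 MB. That is above Spark's 10 MB default auto-broadcast threshold but well under the 200 MB threshold used here, so `smart_join` treats both sides as broadcastable and picks the smaller one. This sketch does not model join type. Spark generally cannot use a broadcast hash join that broadcasts the preserved side of an outer join (for example, the left side of a left outer join).

```python
def estimate_size_mb(row_count, column_count, bytes_per_column=50):
    """Same rough heuristic as estimate_dataframe_size: rows x columns x 50 bytes."""
    return row_count * column_count * bytes_per_column / (1024 * 1024)


def choose_broadcast_side(left_mb, right_mb, threshold_mb=200):
    """Return "left", "right", or None, mirroring smart_join's branches."""
    left_ok = left_mb < threshold_mb
    right_ok = right_mb < threshold_mb
    if left_ok and right_ok:
        return "left" if left_mb <= right_mb else "right"  # broadcast the smaller
    if right_ok:
        return "right"
    if left_ok:
        return "left"
    return None  # both too large: fall back to a regular (sort-merge) join


sales_mb = estimate_size_mb(50_000, 8)     # demo sales table: 8 columns
category_mb = estimate_size_mb(7, 4)       # category lookup: 7 rows, 4 columns
print(f"{sales_mb:.2f} MB vs {category_mb:.4f} MB")          # 19.07 MB vs 0.0013 MB
print(choose_broadcast_side(sales_mb, category_mb))          # right
# A hypothetical 30 MB right-hand table shows the threshold's effect
print(choose_broadcast_side(sales_mb, 30.0))                 # left
print(choose_broadcast_side(sales_mb, 30.0, threshold_mb=10))  # None
```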

## File Format Optimization

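The file format determines storage size, write cost, and how much data a query has to read. Columnar formats such as Parquet, and Delta, which stores its data as Parquet files, let Spark read only the columns a query needs. Let's compare different formats: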

In [None]:
print("=== File Format Performance Comparison ===")

import os

class FileFormatOptimizer:
    """Utilities for file format optimization"""
    
    @staticmethod
    def write_and_measure(df, path, format_type, **options):
        """
        Write DataFrame in specified format and measure performance
        """
        try:
            dbutils.fs.rm(path, True)
        except Exception:  # dbutils is only available on Databricks
            pass
        
        start_time = time.time()
        
        writer = df.write.format(format_type).mode("overwrite")
        
        # Apply any format-specific options
        for key, value in options.items():
            writer = writer.option(key, value)
        
        writer.save(path)
        
        write_time = time.time() - start_time
        
        # Measure file size
        try:
            file_info = dbutils.fs.ls(path)
            total_size = sum([f.size for f in file_info if f.name.endswith('.parquet') 
                             or f.name.endswith('.json') or f.name.endswith('.csv')]) / (1024*1024)  # MB
        except Exception:
            total_size = 0
        
        return write_time, total_size
    
    @staticmethod
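    def read_and_measure(path, format_type, select_columns=None, filter_condition=None, **options):
        """
        Read DataFrame and measure performance
        Format-specific read options (e.g. CSV header handling) are passed to the reader
        """
        start_time = time.time()
        
        df = spark.read.format(format_type).options(**options).load(path)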
        
        if select_columns:
            df = df.select(*select_columns)
        
        if filter_condition:
            df = df.filter(filter_condition)
        
        # Force execution
        count = df.count()
        
        read_time = time.time() - start_time
        
        return read_time, count

# Test different file formats: (name, path, write options, read options)
formats_to_test = [
    ("parquet", "/tmp/sales_parquet", {}, {}),
    ("delta", "/tmp/sales_delta", {}, {}),
    # JSON infers its schema on read, which adds an extra pass over the data
    ("json", "/tmp/sales_json", {}, {}),
    # CSV must be read back with header=true, otherwise the column names are lost and
    # select("category", ...) fails; inferSchema restores numeric types (extra pass)
    ("csv", "/tmp/sales_csv", {"header": "true"}, {"header": "true", "inferSchema": "true"})
]

results = []

print("\nTesting file format performance...")
for format_name, path, write_options, read_options in formats_to_test:
    print(f"\nTesting {format_name.upper()} format...")
    
    # Write performance
    write_time, file_size = FileFormatOptimizer.write_and_measure(
        sales_df, path, format_name, **write_options)
    
    # Read performance (full read)
    read_time_full, record_count = FileFormatOptimizer.read_and_measure(
        path, format_name, **read_options)
    
    # Read performance (with column pruning)
    read_time_pruned, pruned_count = FileFormatOptimizer.read_and_measure(
        path, format_name, 
        select_columns=["category", "region", "amount"],
        filter_condition=F.col("amount") > 500,
        **read_options
    )
    
    results.append({
        'format': format_name.upper(),
        'write_time': write_time,
        'file_size_mb': file_size,
        'read_time_full': read_time_full,
        'read_time_pruned': read_time_pruned,
        'record_count': record_count,
        'pruned_count': pruned_count
    })
    
    print(f"  Write time: {write_time:.2f}s")
    print(f"  File size: {file_size:.1f} MB")
    print(f"  Full read time: {read_time_full:.2f}s")
    print(f"  Pruned read time: {read_time_pruned:.2f}s")

# Display comparison table
print("\n" + "="*80)
print("FILE FORMAT PERFORMANCE COMPARISON")
print("="*80)
print(f"{'Format':<8} {'Write(s)':<10} {'Size(MB)':<10} {'Read Full(s)':<12} {'Read Pruned(s)':<15}")
print("-"*80)

for result in results:
    print(f"{result['format']:<8} {result['write_time']:<10.2f} {result['file_size_mb']:<10.1f} "
          f"{result['read_time_full']:<12.2f} {result['read_time_pruned']:<15.2f}")

print("\n📊 Performance Analysis:")
parquet_result = next(r for r in results if r['format'] == 'PARQUET')
json_result = next(r for r in results if r['format'] == 'JSON')

if parquet_result['file_size_mb'] > 0:  # size is 0 when dbutils was unavailable
    print(f"• Parquet vs JSON size: Parquet is {json_result['file_size_mb']/parquet_result['file_size_mb']:.1f}x smaller")
print(f"• Parquet vs JSON read speed: Parquet is {json_result['read_time_full']/parquet_result['read_time_full']:.1f}x faster")
print(f"• Column pruning benefit (Parquet): {parquet_result['read_time_full']/parquet_result['read_time_pruned']:.1f}x faster")
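
Column pruning favors columnar formats for a structural reason. A row-oriented file stores whole records together, so reading one column means scanning every record. A columnar file stores each column contiguously, so the reader can skip the columns it does not need. The toy model below uses JSON strings as a stand-in for both layouts. Real Parquet additionally encodes and compresses each column chunk and stores statistics that allow skipping, none of which is modeled here.

```python
import json

rows = [
    {"category": "Books", "region": "East", "amount": 12.5, "sale_date": "2023-01-05"},
    {"category": "Home", "region": "West", "amount": 830.0, "sale_date": "2023-02-11"},
    {"category": "Sports", "region": "North", "amount": 99.9, "sale_date": "2023-03-20"},
]

# Row-oriented layout (like JSON/CSV): one record per line, all fields together
row_store = [json.dumps(r).encode() for r in rows]

# Column-oriented layout (like Parquet): each column's values stored contiguously
column_store = {col: json.dumps([r[col] for r in rows]).encode() for col in rows[0]}


def bytes_scanned_row_store(store):
    # Any query must scan every full record, whatever columns it needs
    return sum(len(line) for line in store)


def bytes_scanned_column_store(store, columns):
    # Only the requested column chunks are read (column pruning)
    return sum(len(store[c]) for c in columns)


full = bytes_scanned_row_store(row_store)
pruned = bytes_scanned_column_store(column_store, ["amount"])
print(f"row store: {full} bytes scanned; column store, 'amount' only: {pruned} bytes")
```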

## Strategic Data Handling Best Practices

Let's create a comprehensive framework for strategic data handling decisions:

In [None]:
print("=== Strategic Data Handling Framework ===")

class DataHandlingStrategy:
    """Comprehensive framework for data handling decisions"""
    
    @staticmethod
    def analyze_dataset_characteristics(df):
        """
        Analyze DataFrame characteristics to inform handling strategy
        """
        characteristics = {}
        
        # Basic metrics
        characteristics['row_count'] = df.count()
        characteristics['column_count'] = len(df.columns)
        
        # Estimate size
        characteristics['estimated_size_mb'] = BroadcastJoinOptimizer.estimate_dataframe_size(df)
        
        # Check for skew (simplified)
        if characteristics['row_count'] > 1000:  # Only for reasonably sized datasets
            try:
                # Check partition count distribution
                partition_count = df.rdd.getNumPartitions()
                characteristics['partition_count'] = partition_count
                
                # Estimate skew by checking partition size variation
                partition_sizes = df.rdd.mapPartitions(lambda iterator: [sum(1 for _ in iterator)]).collect()
                if partition_sizes:
                    avg_partition_size = sum(partition_sizes) / len(partition_sizes)
                    max_partition_size = max(partition_sizes)
                    characteristics['skew_ratio'] = max_partition_size / avg_partition_size if avg_partition_size > 0 else 1
                else:
                    characteristics['skew_ratio'] = 1
            except Exception:
                characteristics['partition_count'] = 1
                characteristics['skew_ratio'] = 1
        else:
            characteristics['partition_count'] = 1
            characteristics['skew_ratio'] = 1
        
        return characteristics
    
    @staticmethod
    def recommend_caching_strategy(characteristics, access_pattern="multiple"):
        """
        Recommend caching strategy based on dataset characteristics
        """
        recommendations = []
        
        estimated_size = characteristics['estimated_size_mb']
        
        # Note: PySpark has no *_SER storage levels; its MEMORY_AND_DISK is already serialized
        # Size-based recommendations
        if estimated_size < 100:  # Small datasets
            if access_pattern in ("multiple", "iterative"):
                recommendations.append("✅ CACHE (MEMORY_ONLY) - Small dataset, repeated access")
            else:
                recommendations.append("❌ No caching needed - Single access of small dataset")
        
        elif estimated_size < 1000:  # Medium datasets
            if access_pattern == "multiple":
                recommendations.append("✅ CACHE (MEMORY_AND_DISK) - Medium dataset, spill to disk rather than recompute")
            elif access_pattern == "iterative":
                recommendations.append("✅ PERSIST (MEMORY_AND_DISK) - Iterative access pattern")
            else:
                recommendations.append("⚠️  Consider caching - Medium dataset, depends on access pattern")
        
        else:  # Large datasets
            if access_pattern == "iterative":
                recommendations.append("✅ PERSIST (MEMORY_AND_DISK) - Large dataset, iterative ML")
            else:
                recommendations.append("⚠️  Selective caching - Cache only frequently accessed subsets")
        
        return recommendations
    
    @staticmethod
    def recommend_join_strategy(left_chars, right_chars):
        """
        Recommend join strategy based on DataFrame characteristics
        """
        recommendations = []
        
        left_size = left_chars['estimated_size_mb']
        right_size = right_chars['estimated_size_mb']
        
        broadcast_threshold = 200  # MB
        
        if right_size < broadcast_threshold:
            recommendations.append(f"✅ BROADCAST right table ({right_size:.1f} MB < {broadcast_threshold} MB)")
        elif left_size < broadcast_threshold:
            recommendations.append(f"✅ BROADCAST left table ({left_size:.1f} MB < {broadcast_threshold} MB)")
        else:
            # Check for skew
            if left_chars['skew_ratio'] > 3 or right_chars['skew_ratio'] > 3:
                recommendations.append("⚠️  SORT-MERGE JOIN with skew handling (data skew detected)")
            else:
                recommendations.append("📊 SORT-MERGE JOIN (both tables too large for broadcast)")
        
        return recommendations
    
    @staticmethod
    def recommend_file_format(use_case, data_characteristics):
        """
        Recommend file format based on use case and data characteristics
        """
        recommendations = []
        
        if use_case == "analytics":
            recommendations.append("✅ DELTA LAKE - Best for analytics with ACID transactions")
            recommendations.append("✅ PARQUET - Good alternative for read-heavy analytics")
        
        elif use_case == "ml_training":
            recommendations.append("✅ PARQUET - Optimal for ML feature stores")
            recommendations.append("✅ DELTA LAKE - Good for versioned ML datasets")
        
        elif use_case == "streaming":
            recommendations.append("✅ DELTA LAKE - ACID streaming source and sink (other sinks also work)")
        
        elif use_case == "data_exchange":
            recommendations.append("✅ PARQUET - Standard for data interchange")
            recommendations.append("⚠️  JSON - For schema flexibility (with compression)")
        
        return recommendations

# Demonstrate the comprehensive strategy framework
print("\nAnalyzing sales dataset characteristics...")
sales_characteristics = DataHandlingStrategy.analyze_dataset_characteristics(sales_df)
category_characteristics = DataHandlingStrategy.analyze_dataset_characteristics(category_lookup_df)

print("\n📊 Dataset Analysis Results:")
print(f"Sales Dataset:")
print(f"  - Rows: {sales_characteristics['row_count']:,}")
print(f"  - Columns: {sales_characteristics['column_count']}")
print(f"  - Estimated size: {sales_characteristics['estimated_size_mb']:.1f} MB")
print(f"  - Partitions: {sales_characteristics['partition_count']}")
print(f"  - Skew ratio: {sales_characteristics['skew_ratio']:.2f}")

print(f"\nCategory Lookup:")
print(f"  - Rows: {category_characteristics['row_count']}")
print(f"  - Estimated size: {category_characteristics['estimated_size_mb']:.1f} MB")

# Get recommendations
print("\n🎯 Strategic Recommendations:")

print("\nCaching Strategy:")
caching_recs = DataHandlingStrategy.recommend_caching_strategy(sales_characteristics, "multiple")
for rec in caching_recs:
    print(f"  {rec}")

print("\nJoin Strategy:")
join_recs = DataHandlingStrategy.recommend_join_strategy(sales_characteristics, category_characteristics)
for rec in join_recs:
    print(f"  {rec}")

print("\nFile Format Recommendations:")
format_recs = DataHandlingStrategy.recommend_file_format("analytics", sales_characteristics)
for rec in format_recs:
    print(f"  {rec}")
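
The skew check in `analyze_dataset_characteristics` boils down to one ratio: the largest partition's row count divided by the mean. The pure-Python sketch below (`skew_ratio` is a hypothetical name) isolates that calculation. In PySpark, the input list would come from `df.rdd.mapPartitions(...)`, as in the framework. This ratio reflects join-key skew only after the data has been shuffled by that key. On freshly created partitions, it mostly reflects how the input was split.

```python
def skew_ratio(partition_sizes):
    """max / mean of per-partition row counts; 1.0 means perfectly even.

    The framework above treats a ratio > 3 as skewed.
    """
    if not partition_sizes:
        return 1.0
    mean = sum(partition_sizes) / len(partition_sizes)
    return max(partition_sizes) / mean if mean > 0 else 1.0


print(skew_ratio([100, 100, 100, 100]))  # 1.0
print(skew_ratio([10, 10, 10, 370]))     # 3.7 (mean is 100, largest is 370)
```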

## Summary

**Key Takeaways:**

1. **Strategic Caching**:
   - Cache DataFrames that are accessed multiple times
   - Choose appropriate storage levels based on use case
   - Use conditional caching based on dataset characteristics
   - Always unpersist when done to free resources

2. **Broadcast Join Benefits**:
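   - Avoid shuffling the large table when the lookup table is small enough to send to every executor
   - Automatic broadcasting of tables below `spark.sql.autoBroadcastJoinThreshold` (10 MB by default), plus runtime join conversion by AQE (enabled by default since Spark 3.2)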
   - Manual broadcast hints for explicit control
   - Intelligent decision making based on dataset size

3. **File Format Optimization**:
   - Parquet/Delta for analytics workloads
   - Columnar formats provide better compression and read performance
   - Column pruning benefits with columnar formats
   - Consider use case when choosing formats

4. **Functional Programming Alignment**:
   - Caching as managed "state" within the framework
   - Pure functions for optimization decision logic
   - Immutable DataFrames with strategic persistence
   - Composable optimization strategies

5. **Performance Principles**:
   - Measure before optimizing
   - Understand your data characteristics
   - Use framework-provided optimizations (AQE, Delta Cache)
   - Balance optimization with code maintainability

**Best Practices for Strategic Data Handling**:
- Analyze dataset characteristics before choosing strategies
- Use intelligent decision frameworks rather than hard-coded rules
- Leverage platform optimizations (Photon, AQE, Delta Cache)
- Monitor and measure the impact of optimizations
- Keep optimization logic separate and testable

**Next Steps**: In the next notebook, we'll explore techniques for minimizing data shuffling and handling data skew, which are critical for large-scale PySpark performance.

## Exercise

Apply strategic data handling to your own use case:

1. Create a dataset with lookup tables for your domain
2. Implement caching performance comparison
3. Test broadcast join optimizations
4. Compare file format performance for your data
5. Build a decision framework for your specific use case
6. Create intelligent optimization functions

In [None]:
# Your exercise code here

def create_your_dataset_with_lookups():
    """
    Create a dataset relevant to your domain with appropriate lookup tables
    """
    # Your data generation logic
    pass

def test_caching_performance(df):
    """
    Test caching performance for your specific access patterns
    """
    # Your caching performance test
    pass

def optimize_joins_for_your_data(main_df, lookup_df):
    """
    Apply intelligent join optimization for your data
    """
    # Your join optimization logic
    pass

def build_your_optimization_framework():
    """
    Build optimization decision framework for your domain
    """
    # Your optimization framework
    pass

# Test your implementations
# your_data = create_your_dataset_with_lookups()
# test_caching_performance(your_data)
# build_your_optimization_framework()