# Module 2: ETL Development with PySpark in Databricks
## Laboratory Exercises

Welcome to the hands-on laboratory exercises for Module 2! Today we'll build real ETL pipelines for GlobalMart's data platform.

### Prerequisites
- Completed Module 1
- Running Databricks cluster
- Access to sample datasets
- Approximately 4 hours 15 minutes (the total of the lab timings below)

### Lab Structure
1. **Lab 1**: DataFrame Fundamentals and Basic Transformations (45 minutes)
2. **Lab 2**: File Operations and Data Management (30 minutes)
3. **Lab 3**: Advanced Transformations with Spark SQL (45 minutes)
4. **Lab 4**: Delta Tables for Reliable ETL (45 minutes)
5. **Lab 5**: Creating and Using UDFs (30 minutes)
6. **Lab 6**: Building a Complete ETL Pipeline (60 minutes)

### GlobalMart Context
GlobalMart operates:
- 500+ physical stores across 30 countries
- E-commerce platform with 10M+ customers
- 50,000+ products across multiple categories
- Processing 1M+ transactions daily

Your mission: Build the ETL pipelines that transform raw data into analytics-ready datasets!

## Initial Setup

Run this cell first to set up your environment and create sample data:

In [0]:
# Import required libraries
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
from delta.tables import DeltaTable
import random
import datetime
import json

# Verify Spark session
print(f"Spark version: {spark.version}")
print(f"Python version: {sc.pythonVer}")

# Create working directory
working_dir = "/tmp/module2_etl"
dbutils.fs.mkdirs(working_dir)

# Set up database
spark.sql("CREATE DATABASE IF NOT EXISTS globalmart")
spark.sql("USE globalmart")
print(f"\nCurrent database: {spark.catalog.currentDatabase()}")
print("\nSetup complete! Ready for ETL development.")

## Data Generation

Let's create realistic sample data for GlobalMart:

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import round, rand
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType, DoubleType
import datetime
import random
# Generate sample customer data with realistic quality issues
def generate_customers(num_customers=10000):
    """Generate customer data with intentional quality issues"""
    countries = ['USA', 'UK', 'Canada', 'Germany', 'France', 'Japan', 'Australia']
    
    customer_schema = StructType([
        StructField("customer_id", StringType(), True),
        StructField("first_name", StringType(), True),
        StructField("last_name", StringType(), True),
        StructField("email", StringType(), True),
        StructField("phone", StringType(), True),
        StructField("country", StringType(), True),
        StructField("registration_date", TimestampType(), True),
        StructField("lifetime_value", DoubleType(), True),
    ])

    customers = []
    for i in range(num_customers):
        # Introduce data quality issues
        email = f"customer{i}@email.com" if random.random() > 0.05 else None  # 5% null emails
        # 10-digit North American number with country code: +1-555-XXX-XXXX
        phone = f"+1-555-{random.randint(100, 999)}-{random.randint(1000, 9999)}"
        if random.random() < 0.1:  # 10% have formatting issues
            phone = phone.replace("-", "") if random.random() > 0.5 else phone.replace("+1", "")
        
        customers.append({
            "customer_id": f"C{str(i).zfill(6)}",
            "first_name": f"First{i}" if random.random() > 0.02 else None,  # 2% null names
            "last_name": f"Last{i}",
            "email": email.upper() if email and random.random() > 0.7 else email,  # ~30% upper-case emails
            "phone": phone,
            "country": random.choice(countries),
            "registration_date": (datetime.datetime.now() - datetime.timedelta(days=random.randint(0, 1000))),
            "lifetime_value": random.uniform(10, 10000)
        })
    
    # Keep the nulls: they are the quality issues the later exercises detect
    return spark.createDataFrame(customers, schema=customer_schema)


# Generate sample product data
def generate_products(num_products=1000):
    """Generate product catalog data"""
    categories = ['Electronics', 'Clothing', 'Home', 'Sports', 'Books', 'Toys']
    
    product_schema = StructType([
        StructField("product_id", StringType(), True),
        StructField("product_name", StringType(), True),
        StructField("category", StringType(), True),
        StructField("price", DoubleType(), True),
        StructField("cost", DoubleType(), True),
        StructField("supplier_id", StringType(), True),
        StructField("weight_kg", DoubleType(), True)
    ])

    products = []
    for i in range(num_products):
        products.append({
            "product_id": f"P{str(i).zfill(5)}",
            "product_name": f"Product {i}",
            "category": random.choice(categories),
            "price": random.uniform(10, 500),
            "cost": random.uniform(5, 250),
            "supplier_id": f"S{str(random.randint(1, 50)).zfill(3)}",
            "weight_kg": random.uniform(0.1, 20) if random.random() > 0.2 else None
        })
    
    return spark.createDataFrame(products, schema=product_schema).fillna('')

# Generate sample transactions
def generate_transactions(num_transactions=50000):
    """Generate sales transaction data"""
    channels = ['Online', 'Store', 'Mobile']
    
    transaction_schema = StructType([
        StructField("transaction_id", StringType(), True),
        StructField("customer_id", StringType(), True),
        StructField("product_id", StringType(), True),
        StructField("quantity", IntegerType(), True),
        StructField("channel", StringType(), True),
        StructField("transaction_date", StringType(), True), # Keep as StringType for raw data
        StructField("transaction_timestamp", StringType(), True), # Keep as StringType for raw data
        StructField("store_id", StringType(), True)
    ])

    transactions = []
    for i in range(num_transactions):
        trans_date = datetime.datetime.now() - datetime.timedelta(days=random.randint(0, 90))
        
        transactions.append({
            "transaction_id": f"T{str(i).zfill(8)}",
            "customer_id": f"C{str(random.randint(0, 9999)).zfill(6)}",
            "product_id": f"P{str(random.randint(0, 999)).zfill(5)}",
            "quantity": random.randint(1, 10),
            "channel": random.choice(channels),
            "transaction_date": trans_date.strftime('%Y-%m-%d'),
            "transaction_timestamp": trans_date.strftime('%Y-%m-%d %H:%M:%S'),
            "store_id": f"ST{str(random.randint(1, 500)).zfill(3)}" if random.random() > 0.5 else None # Randomly assign store_id
        })
    
    return spark.createDataFrame(transactions, schema=transaction_schema).fillna('')

# Generate all datasets
print("Generating sample data...")
customers_df = generate_customers()
products_df = generate_products()
transactions_df = generate_transactions()

# Save as temporary views
customers_df.createOrReplaceTempView("raw_customers")
products_df.createOrReplaceTempView("raw_products")
transactions_df.createOrReplaceTempView("raw_transactions")

print(f"✓ Generated {customers_df.count():,} customers")
print(f"✓ Generated {products_df.count():,} products")
print(f"✓ Generated {transactions_df.count():,} transactions")
print("\nSample data ready for processing!")

---
## Lab 1: DataFrame Fundamentals and Basic Transformations (45 minutes)

### Objectives
- Load and explore data using DataFrames
- Apply basic transformations
- Handle data quality issues
- Create derived columns

### Exercise 1.1: Data Exploration

In [0]:
# Load data into DataFrames
customers = spark.table("raw_customers")
products = spark.table("raw_products")
transactions = spark.table("raw_transactions")

# Basic exploration
print("=== CUSTOMER DATA ===")
print("Schema:")
customers.printSchema()
print("\nSample records:")
display(customers.limit(5))

# Check for data quality issues
print("\n=== DATA QUALITY CHECK ===")
null_counts = customers.select([
    count(when(col(c).isNull(), c)).alias(c) 
    for c in customers.columns
])
print("Null counts per column:")
display(null_counts)
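The null check works because `when` without `otherwise` returns NULL for rows that do not match, and `count` skips NULLs, so only the null rows are counted. The plain-Python sketch below mirrors that logic on a few made-up rows; the helper names are hypothetical and it is an illustration, not Spark code.

```python
from typing import Any, Optional


def when_is_null(value: Any) -> Optional[bool]:
    # Mirrors when(col(c).isNull(), c) with no otherwise(): a non-null marker
    # for null inputs, None (SQL NULL) for everything else
    return True if value is None else None


def count_non_null(values) -> int:
    # Mirrors Spark's count(expr), which skips NULL results
    return sum(1 for v in values if v is not None)


def null_counts(rows):
    columns = list(rows[0].keys()) if rows else []
    return {c: count_non_null(when_is_null(r.get(c)) for r in rows) for c in columns}


rows = [
    {"customer_id": "C000000", "first_name": "First0", "email": None},
    {"customer_id": "C000001", "first_name": None, "email": "customer1@email.com"},
    {"customer_id": "C000002", "first_name": "First2", "email": None},
]
print(null_counts(rows))  # {'customer_id': 0, 'first_name': 1, 'email': 2}
```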

### Exercise 1.2: Basic Transformations

In [0]:
# Clean customer data
cleaned_customers = customers \
    .withColumn("email", lower(col("email"))) \
    .withColumn("full_name", concat_ws(" ", col("first_name"), col("last_name"))) \
    .withColumn("has_email", col("email").isNotNull()) \
    .withColumn("registration_year", year(col("registration_date"))) \
    .withColumn("customer_segment", 
        when(col("lifetime_value") > 5000, "High Value")
        .when(col("lifetime_value") > 1000, "Medium Value")
        .otherwise("Low Value")
    )

print("Cleaned customer data:")
display(cleaned_customers.select(
    "customer_id", "full_name", "email", "customer_segment", "registration_year"
).limit(10))

# Analyze customer segments
print("\nCustomer Segmentation:")
segment_analysis = cleaned_customers \
    .groupBy("customer_segment") \
    .agg(
        count("*").alias("customer_count"),
        avg("lifetime_value").alias("avg_lifetime_value"),
        min("registration_date").alias("earliest_registration")
    ) \
    .orderBy("avg_lifetime_value", ascending=False)

display(segment_analysis)
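The `when(...).when(...).otherwise(...)` chain is checked top to bottom and the first true branch wins. Because the thresholds use a strict `>`, a lifetime value of exactly 5000 lands in "Medium Value", and a NULL value falls through to "Low Value". A plain-Python mirror (hypothetical `segment_for` helper) makes the boundaries easy to check:

```python
def segment_for(lifetime_value):
    """Plain-Python mirror of the customer_segment when/otherwise chain."""
    # In Spark, NULL > 5000 evaluates to NULL, which when() treats as not true,
    # so NULL values fall through to otherwise()
    if lifetime_value is not None and lifetime_value > 5000:
        return "High Value"
    if lifetime_value is not None and lifetime_value > 1000:
        return "Medium Value"
    return "Low Value"


for v in (10_000, 5000, 1000.01, 1000, None):
    print(v, "->", segment_for(v))
```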

### Exercise 1.3: Working with Multiple DataFrames

In [0]:
# Join transactions with product information; cast the raw string date to DATE
enriched_transactions = transactions \
    .join(products, "product_id", "left") \
    .withColumn("transaction_date", to_date(col("transaction_date"))) \
    .withColumn("revenue", col("quantity") * col("price")) \
    .withColumn("profit", col("revenue") - (col("quantity") * col("cost")))

# Daily sales summary
daily_sales = enriched_transactions \
    .groupBy("transaction_date", "channel") \
    .agg(
        count("transaction_id").alias("num_transactions"),
        sum("quantity").alias("units_sold"),
        sum("revenue").alias("total_revenue"),
        sum("profit").alias("total_profit"),
        avg("revenue").alias("avg_transaction_value")
    ) \
    .orderBy("transaction_date", "channel")

print("Daily Sales Summary (last 7 days):")
display(daily_sales.filter(
    col("transaction_date") >= date_sub(current_date(), 7)
))

### Exercise 1.4: Data Quality Validation

In [0]:
# Create data quality validation function
def validate_dataframe(df, validations):
    """
    Validate a DataFrame against a set of rules.
    validations: dict mapping a rule name (here, the column it checks) to a
    boolean Column expression that is true for valid rows.
    """
    results = {}
    total_rows = df.count()
    
    for column, validation_expr in validations.items():
        valid_count = df.filter(validation_expr).count()
        invalid_count = total_rows - valid_count
        
        results[column] = {
            "total": total_rows,
            "valid": valid_count,
            "invalid": invalid_count,
            # Avoid ZeroDivisionError on an empty DataFrame
            "validity_rate": valid_count / total_rows * 100 if total_rows else 0.0
        }
    
    return results

# Define validation rules
customer_validations = {
    "email": col("email").isNotNull() & col("email").contains("@"),
    "phone": col("phone").isNotNull() & (length(col("phone")) >= 10),
    "lifetime_value": col("lifetime_value").isNotNull() & (col("lifetime_value") >= 0),
    "registration_date": col("registration_date").isNotNull()
}

# Run validation
validation_results = validate_dataframe(customers, customer_validations)

print("Customer Data Quality Report:")
print("=" * 50)
for column, stats in validation_results.items():
    print(f"{column}:")
    print(f"  - Valid: {stats['valid']:,} ({stats['validity_rate']:.2f}%)")
    print(f"  - Invalid: {stats['invalid']:,}")
    print()
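`validate_dataframe` divides by `total_rows`, so an empty DataFrame raises `ZeroDivisionError` unless the rate is guarded. The plain-Python analogue below (hypothetical `validate_rows` helper, made-up rows) applies the same rules and the guard. Note also that the Spark version runs one job per rule plus one for the total; with many rules, computing every count in a single `agg` over `when` expressions is cheaper.

```python
def validate_rows(rows, validations):
    """Plain-Python analogue of validate_dataframe.

    validations maps a rule name to a predicate that returns True for valid rows.
    """
    total = len(rows)
    results = {}
    for name, predicate in validations.items():
        valid = sum(1 for row in rows if predicate(row))
        results[name] = {
            "total": total,
            "valid": valid,
            "invalid": total - valid,
            # Guard against empty input instead of raising ZeroDivisionError
            "validity_rate": round(valid / total * 100, 2) if total else 0.0,
        }
    return results


rows = [
    {"email": "customer0@email.com", "phone": "+1-555-123-4567"},
    {"email": None, "phone": "5551234567"},
    {"email": "", "phone": "555-1234"},
]
rules = {
    "email": lambda r: r["email"] is not None and "@" in r["email"],
    "phone": lambda r: r["phone"] is not None and len(r["phone"]) >= 10,
}
report = validate_rows(rows, rules)
for name, stats in report.items():
    print(f"{name}: {stats['valid']}/{stats['total']} valid ({stats['validity_rate']:.2f}%)")
```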

### 💡 Lab 1 Key Takeaways
- DataFrames are the foundation of ETL in Spark
- Transformations are lazy - nothing executes until an action
- Always validate data quality early in your pipeline
- Prefer built-in functions over Python UDFs: they run in the JVM and benefit from Catalyst optimization

---
## Lab 2: File Operations and Data Management (30 minutes)

### Objectives
- Master dbutils.fs operations
- Implement file-based processing patterns
- Handle multiple file formats
- Archive processed files

### Exercise 2.1: File System Operations

In [0]:
# Create directory structure
base_path = "/tmp/module2_etl"
paths = {
    "landing": f"{base_path}/landing",
    "processing": f"{base_path}/processing",
    "processed": f"{base_path}/processed",
    "error": f"{base_path}/error",
    "archive": f"{base_path}/archive"
}

# Create all directories
for name, path in paths.items():
    dbutils.fs.mkdirs(path)
    print(f"Created: {path}")

# List directory contents
print("\nDirectory structure:")
display(dbutils.fs.ls(base_path))

### Exercise 2.2: Writing Data in Multiple Formats

In [0]:
# Prepare sample data for export
export_data = cleaned_customers.limit(1000)

# Write in different formats
formats = {
    "csv": {"path": f"{paths['landing']}/customers.csv", "options": {"header": "true"}},
    "json": {"path": f"{paths['landing']}/customers.json", "options": {}},
    "parquet": {"path": f"{paths['landing']}/customers.parquet", "options": {}}
}

for format_name, config in formats.items():
    export_data.coalesce(1).write \
        .mode("overwrite") \
        .options(**config["options"]) \
        .format(format_name) \
        .save(config["path"])
    print(f"✓ Wrote {format_name} to {config['path']}")

# Check file sizes
print("\nFile sizes comparison:")
for format_name, config in formats.items():
    files = dbutils.fs.ls(config["path"])
    data_files = [f for f in files if not f.name.startswith("_")]
    if data_files:
        size_mb = data_files[0].size / 1024 / 1024
        print(f"{format_name}: {size_mb:.2f} MB")
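JSON output is usually the largest of the three because every record repeats its field names, while CSV writes them once in the header and Parquet stores data by column with compression. The stdlib-only sketch below compares CSV and JSON Lines sizes for the same made-up rows; Parquet is left out because it needs a third-party library.

```python
import csv
import io
import json

rows = [
    {"customer_id": f"C{i:06d}", "country": "USA", "lifetime_value": round(100.0 + i, 2)}
    for i in range(1000)
]


def csv_bytes(records):
    # Header row once, then one comma-separated line per record
    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=list(records[0].keys()))
    writer.writeheader()
    writer.writerows(records)
    return len(buf.getvalue().encode("utf-8"))


def jsonl_bytes(records):
    # Spark's JSON writer emits one JSON object per line (JSON Lines)
    return len("\n".join(json.dumps(r) for r in records).encode("utf-8"))


print(f"CSV:  {csv_bytes(rows):,} bytes")
print(f"JSON: {jsonl_bytes(rows):,} bytes")
```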

### Exercise 2.3: File Processing Pipeline

In [0]:
# Simulate incoming files
for i in range(3):
    file_data = transactions.limit(100).withColumn("batch_id", lit(i))
    file_path = f"{paths['landing']}/transactions_batch_{i}.csv"
    
    file_data.coalesce(1).write \
        .mode("overwrite") \
        .option("header", "true") \
        .csv(file_path)
    
    print(f"Created: {file_path}")

# Process files
def process_file(file_path, processed_path, archive_path):
    """
    Process a single file: read it, add a processing timestamp, append the
    result to processed_path, then move the source to archive_path.
    """
    try:
        # Read file
        df = spark.read.option("header", "true").csv(file_path)
        record_count = df.count()
        
        # Process (simple transformation) and persist the output
        processed_df = df.withColumn("processed_timestamp", current_timestamp())
        processed_df.write.mode("append").parquet(f"{processed_path}/transactions")
        
        # Move to archive. dbutils.fs.ls() lists Spark output directories with a
        # trailing "/", so strip it before taking the last path component.
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        file_name = file_path.rstrip("/").split("/")[-1]
        archive_file = f"{archive_path}/{timestamp}_{file_name}"
        dbutils.fs.mv(file_path, archive_file, recurse=True)
        
        return {
            "status": "success",
            "records": record_count,
            "archived_to": archive_file
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e)
        }

# Process all transaction batches in landing
landing_files = dbutils.fs.ls(paths['landing'])
csv_files = [f.path for f in landing_files if f.name.startswith("transactions") and ".csv" in f.name]

print("Processing files:")
for file_path in csv_files:
    result = process_file(file_path, paths['processed'], paths['archive'])
    print(f"\nFile: {file_path}")
    print(f"Result: {result}")
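`dbutils.fs.ls()` reports directories, including the folder Spark writes for each `transactions_batch_<n>.csv`, with a trailing `/`. Taking `path.split('/')[-1]` directly would therefore return an empty string and could give every batch archived within the same second the same target name. The hypothetical helper below strips the slash first and formats the timestamp prefix; the paths and the fixed clock value are illustrative.

```python
import datetime
import posixpath


def archive_target(source_path: str, archive_dir: str, now: datetime.datetime) -> str:
    """Build a timestamped archive path for a file or a Spark output directory."""
    # Strip a trailing "/" so basename() returns the folder name, not ""
    name = posixpath.basename(source_path.rstrip("/"))
    return f"{archive_dir.rstrip('/')}/{now:%Y%m%d_%H%M%S}_{name}"


now = datetime.datetime(2024, 1, 15, 9, 30, 0)
print(archive_target("dbfs:/tmp/module2_etl/landing/transactions_batch_0.csv/",
                     "/tmp/module2_etl/archive", now))
# /tmp/module2_etl/archive/20240115_093000_transactions_batch_0.csv
```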

### Exercise 2.4: File Monitoring Pattern

In [0]:
# Create file monitoring function
def monitor_directory(path, pattern=None):
    """
    Monitor directory for files matching pattern
    """
    files = dbutils.fs.ls(path)
    
    # Filter by pattern if provided
    if pattern:
        files = [f for f in files if pattern in f.name]
    
    # Get file details
    file_info = []
    for file in files:
        if not file.name.startswith("_"):  # Skip metadata files
            file_info.append({
                "name": file.name,
                "path": file.path,
                "size": file.size,
                "modified_time": datetime.datetime.fromtimestamp(file.modificationTime / 1000)
            })
    
    return sorted(file_info, key=lambda x: x['modified_time'], reverse=True)

# Monitor archive directory
print("Archive Directory Contents:")
archived_files = monitor_directory(paths['archive'])

for file in archived_files:
    print(f"\nFile: {file['name']}")
    print(f"  Size: {file['size']} B")
    print(f"  Modified: {file['modified_time']}")

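# Create processing log (createDataFrame cannot infer a schema from an empty list)
if archived_files:
    processing_log = spark.createDataFrame(archived_files)
    processing_log.write.mode("overwrite").json(f"{paths['processed']}/processing_log")
    print("\n✓ Processing log saved")
else:
    print("\nNo archived files found; processing log not written")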
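To try the monitoring pattern without DBFS, here is a hypothetical local-filesystem version built on `os.scandir`. It applies the same rules: skip names starting with `_`, optionally filter by a substring, and sort newest first. One difference: `FileInfo.modificationTime` from `dbutils.fs.ls()` is in milliseconds (hence the `/ 1000` above), while `os.stat` reports `st_mtime` in seconds.

```python
import datetime
import os
import tempfile


def monitor_local_directory(path, pattern=None):
    """Local-filesystem analogue of monitor_directory, newest files first."""
    entries = []
    with os.scandir(path) as it:
        for entry in it:
            if entry.name.startswith("_"):
                continue  # skip metadata files such as _SUCCESS
            if pattern and pattern not in entry.name:
                continue
            stat = entry.stat()
            entries.append({
                "name": entry.name,
                "path": entry.path,
                "size": stat.st_size,
                # st_mtime is already in seconds
                "modified_time": datetime.datetime.fromtimestamp(stat.st_mtime),
            })
    return sorted(entries, key=lambda e: e["modified_time"], reverse=True)


# Demo on a throwaway directory with fixed modification times
with tempfile.TemporaryDirectory() as tmp:
    for i, name in enumerate(["old.csv", "new.csv", "_SUCCESS"]):
        file_path = os.path.join(tmp, name)
        with open(file_path, "w") as f:
            f.write("x" * (i + 1))
        mtime = 1_700_000_000 + i * 60
        os.utime(file_path, (mtime, mtime))
    listing = monitor_local_directory(tmp)

print([e["name"] for e in listing])  # ['new.csv', 'old.csv']
```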

### 💡 Lab 2 Key Takeaways
- dbutils.fs covers the core file operations (ls, mkdirs, cp, mv, rm) on DBFS and cloud storage paths
- Always implement proper file archiving
- Monitor directories for new files to process
- Different formats have different size/performance characteristics

---
## Lab 3: Advanced Transformations with Spark SQL (45 minutes)

### Objectives
- Use Spark SQL for complex transformations
- Combine SQL and DataFrame operations
- Implement window functions
- Optimize query performance

### Exercise 3.1: Creating and Using Temp Views

In [0]:
# Register DataFrames as temporary views
cleaned_customers.createOrReplaceTempView("customers")
products.createOrReplaceTempView("products")
enriched_transactions.createOrReplaceTempView("transactions")

# Show available tables
print("Available tables:")
display(spark.sql("SHOW TABLES"))

# Basic SQL query
top_customers_sql = spark.sql("""
    SELECT 
        customer_id,
        full_name,
        customer_segment,
        lifetime_value,
        registration_date
    FROM customers
    WHERE lifetime_value > 5000
    ORDER BY lifetime_value DESC
    LIMIT 10
""")

print("\nTop 10 High-Value Customers:")
display(top_customers_sql)

### Exercise 3.2: Complex Analytical Queries

In [0]:
# Monthly sales analysis with multiple metrics
monthly_analysis = spark.sql("""
    WITH monthly_sales AS (
        SELECT 
            DATE_TRUNC('month', transaction_date) as month,
            channel,
            category,
            COUNT(DISTINCT t.customer_id) as unique_customers,
            COUNT(*) as transaction_count,
            SUM(quantity) as units_sold,
            SUM(revenue) as total_revenue,
            SUM(profit) as total_profit,
            AVG(revenue) as avg_transaction_value
        FROM transactions t
        GROUP BY DATE_TRUNC('month', transaction_date), channel, category
    )
    SELECT 
        month,
        channel,
        category,
        unique_customers,
        transaction_count,
        units_sold,
        ROUND(total_revenue, 2) as total_revenue,
        ROUND(total_profit, 2) as total_profit,
        ROUND(avg_transaction_value, 2) as avg_transaction_value,
        ROUND(total_profit / total_revenue * 100, 2) as profit_margin_pct
    FROM monthly_sales
    WHERE month >= DATE_SUB(CURRENT_DATE(), 90)
    ORDER BY month DESC, total_revenue DESC
""")

print("Monthly Sales Analysis:")
display(monthly_analysis)
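If you want to experiment with the CTE pattern outside the cluster, Python's built-in `sqlite3` module can run a similar query on a tiny, made-up table. SQLite's dialect differs from Spark SQL: it has no `DATE_TRUNC`, so `strftime('%Y-%m', ...)` stands in for the month bucket here, and `NULLIF` guards the margin division against a zero revenue total.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE transactions (transaction_date TEXT, channel TEXT, revenue REAL, profit REAL)"
)
conn.executemany(
    "INSERT INTO transactions VALUES (?, ?, ?, ?)",
    [
        ("2024-01-05", "Online", 100.0, 20.0),
        ("2024-01-20", "Online", 50.0, 5.0),
        ("2024-02-03", "Store", 80.0, -8.0),
    ],
)

monthly = conn.execute("""
    WITH monthly_sales AS (
        SELECT strftime('%Y-%m', transaction_date) AS month,  -- stands in for DATE_TRUNC
               channel,
               COUNT(*) AS transaction_count,
               SUM(revenue) AS total_revenue,
               SUM(profit) AS total_profit
        FROM transactions
        GROUP BY month, channel
    )
    SELECT month,
           channel,
           transaction_count,
           ROUND(total_profit * 100.0 / NULLIF(total_revenue, 0), 2) AS profit_margin_pct
    FROM monthly_sales
    ORDER BY month DESC, total_revenue DESC
""").fetchall()

for row in monthly:
    print(row)
```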

### Exercise 3.3: Window Functions

In [0]:
# Customer purchase patterns with window functions
customer_patterns = spark.sql("""
    WITH customer_transactions AS (
        SELECT 
            t.customer_id,
            c.full_name,
            c.customer_segment,
            t.transaction_date,
            t.revenue,
            -- Running total per customer
            SUM(t.revenue) OVER (
                PARTITION BY t.customer_id 
                ORDER BY t.transaction_date 
                ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
            ) as running_total,
            -- Days since last purchase
            DATEDIFF(
                t.transaction_date,
                LAG(t.transaction_date, 1) OVER (
                    PARTITION BY t.customer_id 
                    ORDER BY t.transaction_date
                )
            ) as days_since_last_purchase,
            -- Customer's transaction rank
            ROW_NUMBER() OVER (
                PARTITION BY t.customer_id 
                ORDER BY t.transaction_date
            ) as transaction_number,
            -- Percentile rank by revenue
            PERCENT_RANK() OVER (
                PARTITION BY t.customer_id 
                ORDER BY t.revenue
            ) as revenue_percentile
        FROM transactions t
        JOIN customers c ON t.customer_id = c.customer_id
    )
    SELECT 
        customer_id,
        full_name,
        customer_segment,
        transaction_date,
        revenue,
        running_total,
        days_since_last_purchase,
        transaction_number,
        ROUND(revenue_percentile * 100, 2) as revenue_percentile_pct
    FROM customer_transactions
    WHERE transaction_number <= 5  -- First 5 transactions per customer
    ORDER BY customer_id, transaction_date
    LIMIT 20
""")

print("Customer Purchase Patterns (Window Functions):")
display(customer_patterns)
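The same window clauses also run in SQLite, provided Python's `sqlite3` is linked against SQLite 3.25 or newer (check `sqlite3.sqlite_version`). This sketch on four made-up rows lets you verify the running total, `LAG`, and `ROW_NUMBER` results by hand. SQLite has no `DATEDIFF`, so the day gap is computed from `julianday()` differences instead.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE transactions (customer_id TEXT, transaction_date TEXT, revenue REAL)")
conn.executemany(
    "INSERT INTO transactions VALUES (?, ?, ?)",
    [
        ("C000001", "2024-01-01", 10.0),
        ("C000001", "2024-01-04", 30.0),
        ("C000001", "2024-01-10", 20.0),
        ("C000002", "2024-01-02", 5.0),
    ],
)

rows = conn.execute("""
    SELECT customer_id,
           transaction_date,
           SUM(revenue) OVER (
               PARTITION BY customer_id ORDER BY transaction_date
               ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
           ) AS running_total,
           -- julianday() difference replaces Spark's DATEDIFF; NULL for the first purchase
           CAST(julianday(transaction_date)
                - julianday(LAG(transaction_date) OVER (
                      PARTITION BY customer_id ORDER BY transaction_date))
                AS INTEGER) AS days_since_last_purchase,
           ROW_NUMBER() OVER (
               PARTITION BY customer_id ORDER BY transaction_date
           ) AS transaction_number
    FROM transactions
    ORDER BY customer_id, transaction_date
""").fetchall()

for row in rows:
    print(row)
```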

### Exercise 3.4: Combining SQL and DataFrame Operations

In [0]:
# Start with SQL for complex logic
cohort_analysis_sql = spark.sql("""
    SELECT 
        DATE_TRUNC('month', c.registration_date) as cohort_month,
        c.customer_segment,
        COUNT(DISTINCT c.customer_id) as cohort_size,
        COUNT(DISTINCT t.customer_id) as active_customers,
        SUM(t.revenue) as cohort_revenue
    FROM customers c
    LEFT JOIN transactions t ON c.customer_id = t.customer_id
    WHERE c.registration_date >= DATE_SUB(CURRENT_DATE(), 180)
    GROUP BY DATE_TRUNC('month', c.registration_date), c.customer_segment
""")

# Continue with DataFrame API for additional processing
cohort_analysis_df = cohort_analysis_sql \
    .withColumn("activation_rate", 
        col("active_customers") / col("cohort_size") * 100) \
    .withColumn("avg_revenue_per_active", 
        when(col("active_customers") > 0, 
             col("cohort_revenue") / col("active_customers"))
        .otherwise(0)) \
    .orderBy("cohort_month", "customer_segment")

print("Cohort Analysis (SQL + DataFrame):")
display(cohort_analysis_df)

# Create a pivot table for better visualization
cohort_pivot = cohort_analysis_df \
    .groupBy("cohort_month") \
    .pivot("customer_segment") \
    .agg(first("activation_rate")) \
    .orderBy("cohort_month")

print("\nActivation Rate by Cohort and Segment:")
display(cohort_pivot)
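`pivot` turns each distinct `customer_segment` into a column, and `first` fills each cell; because `cohort_analysis_df` has one row per month and segment, `first` simply returns that row's rate. Here is a plain-Python sketch of the reshaping (hypothetical `pivot_first` helper, made-up rates). Missing month/segment combinations become `None`, just as Spark fills them with null.

```python
def pivot_first(rows, index, columns, values):
    """Plain-Python analogue of groupBy(index).pivot(columns).agg(first(values))."""
    # Without an explicit value list, Spark collects the distinct pivot values and sorts them
    pivot_values = sorted({row[columns] for row in rows})
    table = {}
    seen = set()
    for row in rows:
        cells = table.setdefault(row[index], {v: None for v in pivot_values})
        cell_key = (row[index], row[columns])
        if cell_key not in seen:  # first(): keep the first value seen for each cell
            seen.add(cell_key)
            cells[row[columns]] = row[values]
    return dict(sorted(table.items()))


rows = [
    {"cohort_month": "2024-01", "customer_segment": "High Value", "activation_rate": 90.0},
    {"cohort_month": "2024-01", "customer_segment": "Low Value", "activation_rate": 75.0},
    {"cohort_month": "2024-02", "customer_segment": "Low Value", "activation_rate": 80.0},
]
pivot = pivot_first(rows, "cohort_month", "customer_segment", "activation_rate")
for month, cells in pivot.items():
    print(month, cells)
```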

### 💡 Lab 3 Key Takeaways
- Spark SQL provides familiar syntax for complex transformations
- Window functions enable sophisticated analytics
- Combine SQL and DataFrame API for maximum flexibility
- CTEs (WITH clauses) improve query readability

---
## Lab 4: Delta Tables for Reliable ETL (45 minutes)

### Objectives
- Create and manage Delta tables
- Implement merge operations
- Use time travel for data recovery
- Optimize Delta table performance

### Exercise 4.1: Creating Delta Tables

In [0]:
# Create managed Delta tables
# Customer dimension table
cleaned_customers.write \
    .format("delta") \
    .mode("overwrite") \
    .option("overwriteSchema", "true") \
    .saveAsTable("globalmart.dim_customers")

print("✓ Created dim_customers table")

# Product dimension table
products.write \
    .format("delta") \
    .mode("overwrite") \
    .option("overwriteSchema", "true") \
    .saveAsTable("globalmart.dim_products")

print("✓ Created dim_products table")

# Transaction fact table - partitioned by date
enriched_transactions.write \
    .format("delta") \
    .mode("overwrite") \
    .partitionBy("transaction_date") \
    .option("overwriteSchema", "true") \
    .saveAsTable("globalmart.fact_transactions")

print("✓ Created fact_transactions table")

# Verify tables
print("\nDelta Tables:")
display(spark.sql("SHOW TABLES IN globalmart"))
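`partitionBy("transaction_date")` writes each distinct date under its own `transaction_date=<value>` directory by default, so queries that filter on the date can skip whole directories. With roughly 90 distinct dates and 50,000 transactions, each partition holds only a few hundred rows, which tends to produce small files. The plain-Python sketch below (hypothetical `plan_partitions` helper, made-up rows and table path) shows the directory layout such a write produces.

```python
from collections import defaultdict


def plan_partitions(rows, partition_col, table_path):
    """Group rows into the Hive-style directories a partitioned write would produce."""
    layout = defaultdict(list)
    for row in rows:
        layout[f"{table_path}/{partition_col}={row[partition_col]}"].append(row)
    return dict(layout)


rows = [
    {"transaction_id": "T00000000", "transaction_date": "2024-01-15"},
    {"transaction_id": "T00000001", "transaction_date": "2024-01-15"},
    {"transaction_id": "T00000002", "transaction_date": "2024-01-16"},
]
layout = plan_partitions(rows, "transaction_date", "fact_transactions")
for path, part_rows in sorted(layout.items()):
    print(path, len(part_rows))
```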

### Exercise 4.2: Delta Merge Operations

In [0]:
# Simulate customer updates: raise lifetime_value by 10% for customers whose ID
# ends in 0 (about 10% of customers) and recompute the segment from the new value
customer_updates = spark.sql("""
    WITH updated AS (
        SELECT 
            customer_id,
            first_name,
            last_name,
            LOWER(email) as email,
            phone,
            country,
            registration_date,
            lifetime_value * 1.1 as lifetime_value,  -- 10% increase
            full_name,
            has_email,
            registration_year
        FROM globalmart.dim_customers
        WHERE customer_id LIKE '%0'  -- about 10% of customers
    )
    SELECT 
        *,
        CASE 
            WHEN lifetime_value > 5000 THEN 'High Value'
            WHEN lifetime_value > 1000 THEN 'Medium Value'
            ELSE 'Low Value'
        END as customer_segment
    FROM updated
""")

# New customers to insert
new_customers = generate_customers(100) \
    .withColumn("customer_id", concat(lit("N"), col("customer_id"))) \
    .withColumn("email", lower(col("email"))) \
    .withColumn("full_name", concat_ws(" ", col("first_name"), col("last_name"))) \
    .withColumn("has_email", col("email").isNotNull()) \
    .withColumn("registration_year", year(col("registration_date"))) \
    .withColumn("customer_segment", 
        when(col("lifetime_value") > 5000, "High Value")
        .when(col("lifetime_value") > 1000, "Medium Value")
        .otherwise("Low Value")
    )

# Combine updates and inserts (unionByName matches columns by name, not position)
merge_data = customer_updates.unionByName(new_customers)

print(f"Updates: {customer_updates.count()}")
print(f"New: {new_customers.count()}")
print(f"Total merge records: {merge_data.count()}")

In [0]:
# Perform merge operation
from delta.tables import DeltaTable

# Get Delta table reference
customer_delta = DeltaTable.forName(spark, "globalmart.dim_customers")

# Record counts before merge
before_count = spark.table("globalmart.dim_customers").count()

# Perform merge
merge_result = customer_delta.alias("target").merge(
    merge_data.alias("source"),
    "target.customer_id = source.customer_id"
).whenMatchedUpdate(
    set={
        "lifetime_value": "source.lifetime_value",
        "customer_segment": "source.customer_segment",
        "email": "source.email"  # Update email in case it changed
    }
).whenNotMatchedInsert(
    values={
        "customer_id": "source.customer_id",
        "first_name": "source.first_name",
        "last_name": "source.last_name",
        "email": "source.email",
        "phone": "source.phone",
        "country": "source.country",
        "registration_date": "source.registration_date",
        "lifetime_value": "source.lifetime_value",
        "full_name": "source.full_name",
        "has_email": "source.has_email",
        "registration_year": "source.registration_year",
        "customer_segment": "source.customer_segment"
    }
).execute()

# Record counts after merge
after_count = spark.table("globalmart.dim_customers").count()

print(f"\nMerge Results:")
print(f"Records before: {before_count:,}")
print(f"Records after: {after_count:,}")
print(f"Net new records: {after_count - before_count:,}")
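MERGE evaluates each source row against the target: matched keys receive the `whenMatchedUpdate` columns and unmatched keys are inserted whole, so `after_count - before_count` equals the number of inserts. Below is a plain-Python sketch over a dict keyed by `customer_id` (hypothetical `merge_upsert` helper, made-up rows). It is stricter than Delta: it rejects any duplicate key in the source, whereas Delta raises an error only when several source rows match the same target row.

```python
def merge_upsert(target, source_rows, key, update_cols):
    """Plain-Python sketch of MERGE with WHEN MATCHED UPDATE / WHEN NOT MATCHED INSERT.

    target maps key -> row dict and is modified in place.
    """
    source_keys = [row[key] for row in source_rows]
    if len(source_keys) != len(set(source_keys)):
        # Stricter than Delta, which errors only when several source rows match one target row
        raise ValueError("duplicate keys in merge source")
    counts = {"updated": 0, "inserted": 0}
    for src in source_rows:
        k = src[key]
        if k in target:
            # whenMatchedUpdate: only the listed columns change
            for c in update_cols:
                target[k][c] = src[c]
            counts["updated"] += 1
        else:
            # whenNotMatchedInsert: the full source row is inserted
            target[k] = dict(src)
            counts["inserted"] += 1
    return counts


target = {
    "C000010": {"customer_id": "C000010", "first_name": "First10",
                "lifetime_value": 1000.0, "customer_segment": "Low Value"},
}
source = [
    {"customer_id": "C000010", "first_name": "Renamed",
     "lifetime_value": 1100.0, "customer_segment": "Medium Value"},
    {"customer_id": "NC000000", "first_name": "First0",
     "lifetime_value": 50.0, "customer_segment": "Low Value"},
]
result = merge_upsert(target, source, "customer_id", ["lifetime_value", "customer_segment"])
print(result)  # {'updated': 1, 'inserted': 1}
```

Because `first_name` is not in `update_cols`, the matched row keeps "First10", just as the Delta merge above leaves unlisted columns untouched.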

### Exercise 4.3: Time Travel and History

In [0]:
# View table history
history_df = spark.sql("DESCRIBE HISTORY globalmart.dim_customers")
print("Table History:")
display(history_df.select("version", "timestamp", "operation", "operationMetrics"))

# Get version numbers
versions = history_df.select("version").collect()
latest_version = versions[0][0]
previous_version = versions[1][0] if len(versions) > 1 else 0

print(f"\nLatest version: {latest_version}")
print(f"Previous version: {previous_version}")

In [0]:
# Compare versions using time travel
if previous_version != latest_version:
    # Read previous version
    customers_previous = spark.read \
        .format("delta") \
        .option("versionAsOf", previous_version) \
        .table("globalmart.dim_customers")
    
    # Read current version
    customers_current = spark.table("globalmart.dim_customers")
    
    # Find changes
    # New customers: rows in the current version with no match in the previous one
    # (joining on the column name avoids ambiguous references in this self-join)
    new_customers = customers_current.join(customers_previous, "customer_id", "left_anti")
    
    # Updated customers
    updated_customers = customers_current.alias("curr").join(
        customers_previous.alias("prev"),
        col("curr.customer_id") == col("prev.customer_id"),
        "inner"
    ).filter(
        col("curr.lifetime_value") != col("prev.lifetime_value")
    ).select(
        col("curr.customer_id"),
        col("curr.full_name"),
        col("prev.lifetime_value").alias("old_lifetime_value"),
        col("curr.lifetime_value").alias("new_lifetime_value"),
        (col("curr.lifetime_value") - col("prev.lifetime_value")).alias("change")
    )
    
    print(f"New customers: {new_customers.count()}")
    print(f"Updated customers: {updated_customers.count()}")
    
    print("\nSample of updated customers:")
    display(updated_customers.limit(10))
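The two joins above are set operations on the key: the `left_anti` join keeps IDs present only in the current version, and the inner join plus filter keeps IDs whose `lifetime_value` differs. A plain-Python sketch of the same comparison follows (hypothetical `diff_snapshots` helper, made-up snapshots). If a table has Delta's Change Data Feed enabled, row-level changes can be read directly instead of diffing versions.

```python
def diff_snapshots(previous, current, compare_col):
    """Compare two table versions keyed by ID.

    previous/current map key -> row dict. Returns (added_keys, changed), where
    changed holds (key, old_value, new_value) tuples.
    """
    # Like the left_anti join: keys that exist only in the current version
    added = sorted(set(current) - set(previous))
    # Like the inner join plus filter: shared keys whose value differs.
    # (Spark's != returns NULL when either side is NULL, so such rows would be
    # dropped there; Python's != would report them as changed.)
    changed = [
        (key, previous[key][compare_col], current[key][compare_col])
        for key in sorted(set(current) & set(previous))
        if previous[key][compare_col] != current[key][compare_col]
    ]
    return added, changed


previous = {
    "C000010": {"lifetime_value": 1000.0},
    "C000011": {"lifetime_value": 250.0},
}
current = {
    "C000010": {"lifetime_value": 1100.0},
    "C000011": {"lifetime_value": 250.0},
    "NC000000": {"lifetime_value": 50.0},
}
added, changed = diff_snapshots(previous, current, "lifetime_value")
print("added:", added)
print("changed:", changed)
```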

### Exercise 4.4: Delta Table Optimization

In [0]:
# Check table details before optimization
detail_before = spark.sql("DESCRIBE DETAIL globalmart.fact_transactions")
print("Table Details Before Optimization:")
display(detail_before.select("numFiles", "sizeInBytes", "properties"))

# Optimize table
print("\nRunning OPTIMIZE...")
optimize_result = spark.sql("""
    OPTIMIZE globalmart.fact_transactions
    ZORDER BY (customer_id, product_id)
""")

display(optimize_result)

# Check table details after optimization
detail_after = spark.sql("DESCRIBE DETAIL globalmart.fact_transactions")
print("\nTable Details After Optimization:")
display(detail_after.select("numFiles", "sizeInBytes", "properties"))
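`ZORDER BY (customer_id, product_id)` clusters rows so that values close in either column tend to share files, which lets per-file min/max statistics skip more files for filters on either column. The classic way to build such an ordering is bit interleaving (a Morton code), sketched below in plain Python for two small integer columns. This shows the idea only: Delta Lake's open-source implementation first maps each column's values to range IDs before interleaving, so it also works for strings such as `customer_id`, and it is not this exact function.

```python
def interleave_bits(x: int, y: int, bits: int = 16) -> int:
    """Morton (Z-order) code: bit i of x goes to position 2i, bit i of y to 2i + 1."""
    z = 0
    for i in range(bits):
        z |= ((x >> i) & 1) << (2 * i)
        z |= ((y >> i) & 1) << (2 * i + 1)
    return z


# Sorting a 4 x 4 grid by Z-value visits it in small 2 x 2 blocks,
# so points that are close in both dimensions end up near each other.
points = [(x, y) for x in range(4) for y in range(4)]
z_sorted = sorted(points, key=lambda p: interleave_bits(*p))
print(z_sorted[:8])  # [(0, 0), (1, 0), (0, 1), (1, 1), (2, 0), (3, 0), (2, 1), (3, 1)]
```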

### 💡 Lab 4 Key Takeaways
- Delta tables provide ACID transactions on data lakes
- Merge operations enable efficient upserts
- Time travel allows data recovery and auditing
- OPTIMIZE compacts small files, and ZORDER BY co-locates related values so queries can skip more files

---
## Lab 5: Creating and Using UDFs (30 minutes)

### Objectives
- Create custom UDFs for business logic
- Handle edge cases in UDFs
- Compare UDF performance
- Implement Pandas UDFs

### Exercise 5.1: Creating Basic UDFs

In [0]:
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType, FloatType, BooleanType
import re

# Email validation UDF
def validate_email(email):
    """Validate email format"""
    if email is None:
        return False
    
    # Simple email regex
    pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
    return bool(re.match(pattern, email.lower()))

# Register UDF
validate_email_udf = udf(validate_email, BooleanType())

# Phone formatting UDF
def format_phone(phone):
    """Format phone numbers consistently"""
    if phone is None:
        return None
    
    # Remove all non-digits
    digits = re.sub(r'\D', '', phone)
    
    # Format as (XXX) XXX-XXXX for 10 digits
    if len(digits) == 10:
        return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
    elif len(digits) == 11 and digits[0] == '1':
        return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
    else:
        return phone  # Return original if can't format

format_phone_udf = udf(format_phone, StringType())

# Customer score UDF (complex business logic)
def calculate_customer_score(lifetime_value, registration_date, transaction_count):
    """Calculate customer score based on multiple factors"""
    if lifetime_value is None or registration_date is None:
        return 0.0
    
    # Days since registration
    days_active = (datetime.datetime.now().date() - registration_date.date()).days
    
    # Score components
    value_score = lifetime_value / 1000 if lifetime_value / 1000 < 10 else 10
    tenure_score = days_active / 365 if days_active / 365 < 5 else 5
    frequency_score = 0 if transaction_count is None else (transaction_count / 10 if transaction_count / 10 < 5 else 5)
    
    total_score = value_score + tenure_score + frequency_score
    return float(total_score)

calculate_score_udf = udf(calculate_customer_score, FloatType())

print("✓ UDFs created successfully")
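A UDF wraps an ordinary Python function, so its logic can be unit-tested locally before it ever runs on a cluster. The sketch below repeats the `format_phone` body and runs it on made-up inputs: a dashed `+1` number, the same number with the dashes or the `+1` prefix removed (the kinds of noise the customer generator injects), and a 12-digit string that matches no branch and is returned unchanged.

```python
import re


def format_phone(phone):
    """Same logic as the format_phone UDF body above."""
    if phone is None:
        return None
    digits = re.sub(r"\D", "", phone)  # keep digits only
    if len(digits) == 10:
        return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
    if len(digits) == 11 and digits[0] == "1":
        return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
    return phone  # unrecognized shape: leave as-is


samples = ["+1-555-123-4567", "+15551234567", "-555-123-4567", "+1-555-1234-5678", None]
for s in samples:
    print(repr(s), "->", repr(format_phone(s)))
```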

### Exercise 5.2: Applying UDFs

In [0]:
# Load customer data and add transaction counts
customer_metrics = spark.sql(
    """
    SELECT 
        c.*,
        COALESCE(t.transaction_count, 0) as transaction_count
    FROM globalmart.dim_customers c
    LEFT JOIN (
        SELECT 
            customer_id,
            COUNT(*) as transaction_count
        FROM globalmart.fact_transactions
        GROUP BY customer_id
    ) t ON c.customer_id = t.customer_id
"""
)

# Apply UDFs
customers_enhanced = (
    customer_metrics.withColumn("email_valid", validate_email_udf(col("email")))
    .withColumn("phone_formatted", format_phone_udf(col("phone")))
    .withColumn(
        "customer_score",
        calculate_score_udf(
            col("lifetime_value"), col("registration_date"), col("transaction_count")
        ),
    )
)

# Show results
print("Enhanced Customer Data:")
display(
    customers_enhanced.select(
        "customer_id",
        "email",
        "email_valid",
        "phone",
        "phone_formatted",
        "lifetime_value",
        "transaction_count",
        "customer_score",
    ).limit(20)
)


### Exercise 5.3: Pandas UDFs for Better Performance

In [0]:
from pyspark.sql.functions import pandas_udf
import pandas as pd

# Pandas UDF for customer scoring (vectorized)
@pandas_udf(returnType=FloatType())
def calculate_customer_score_pandas(lifetime_value: pd.Series, 
                                   registration_date: pd.Series,
                                   transaction_count: pd.Series) -> pd.Series:
    """Vectorized customer score calculation"""
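    # Calculate days active; pd.to_datetime accepts both DateType values
    # (datetime.date objects) and TimestampType values (datetime64)
    today = pd.Timestamp.now().normalize()
    days_active = (today - pd.to_datetime(registration_date)).dt.days
    
    # Score components, capped the same way as the row-at-a-time UDF
    value_score = (lifetime_value.astype("float64") / 1000).clip(upper=10)
    tenure_score = (days_active / 365).clip(upper=5)
    frequency_score = (transaction_count / 10).clip(upper=5)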
    
    # Total score
    total_score = value_score + tenure_score + frequency_score
    
    # Handle nulls
    total_score = total_score.fillna(0.0).round(2)
    
    return total_score

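# Performance comparison (one timed run each; results vary with cluster state,
# Python worker start-up, and caching, so treat the numbers as indicative)
import time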

# Test with regular UDF
start_time = time.time()
regular_udf_result = customer_metrics \
    .withColumn("score_regular", 
        calculate_score_udf(
            col("lifetime_value"),
            col("registration_date"),
            col("transaction_count")
        )
    ) \
    .select(avg("score_regular")).collect()[0][0]
regular_time = time.time() - start_time

# Test with Pandas UDF
start_time = time.time()
pandas_udf_result = customer_metrics \
    .withColumn("score_pandas",
        calculate_customer_score_pandas(
            col("lifetime_value"),
            col("registration_date"),
            col("transaction_count")
        )
    ) \
    .select(avg("score_pandas")).collect()[0][0]
pandas_time = time.time() - start_time

print("Performance Comparison:")
print(f"Regular UDF: {regular_time:.2f} seconds (avg score: {regular_udf_result:.2f})")
print(f"Pandas UDF: {pandas_time:.2f} seconds (avg score: {pandas_udf_result:.2f})")
print(f"Speedup: {regular_time/pandas_time:.2f}x")
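Each UDF above is timed once, so a measurement can absorb one-off costs (for example, starting Python workers or reading uncached data) and the speedup figure can swing between runs. A steadier comparison repeats each measurement and keeps the fastest. A minimal sketch, with `best_of` as a hypothetical helper:

```python
import time

def best_of(fn, repeats=3):
    """Hypothetical timing helper: call fn() `repeats` times and return
    (fastest elapsed seconds, result of the last call)."""
    best = float("inf")
    result = None
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn()
        elapsed = time.perf_counter() - start
        if elapsed < best:  # avoids min(), which the wildcard import shadows
            best = elapsed
    return best, result

# Stand-in workload; in the notebook fn would wrap one of the .collect() calls above
elapsed, n = best_of(lambda: len([i * i for i in range(100_000)]))
print(f"best of 3: {elapsed:.4f} s, result: {n}")
```

In the notebook, the existing expressions can be wrapped in a lambda, e.g. `best_of(lambda: df.select(avg("score_regular")).collect()[0][0])`, where `df` stands for the DataFrame with the score column added.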

### Exercise 5.4: Using Python UDFs in SQL

In [0]:
# Register Python UDF for SQL use
spark.udf.register("validate_email_sql", validate_email, BooleanType())
spark.udf.register("format_phone_sql", format_phone, StringType())

# Use UDF in SQL
sql_udf_result = spark.sql("""
    SELECT 
        customer_id,
        email,
        validate_email_sql(email) as is_valid_email,
        phone,
        format_phone_sql(phone) as formatted_phone,
        CASE 
            WHEN validate_email_sql(email) = true THEN 'Valid'
            ELSE 'Invalid'
        END as email_status
    FROM globalmart.dim_customers
    WHERE customer_id LIKE 'C00000%'
    LIMIT 10
""")

print("SQL UDF Results:")
display(sql_udf_result)

### 💡 Lab 5 Key Takeaways
- UDFs enable custom business logic in Spark
- Always handle null values in UDFs
- Pandas UDFs process data in Arrow-backed batches, which usually cuts the per-row serialization overhead of row-at-a-time Python UDFs
- Register UDFs for SQL use when needed
- Use built-in functions when possible for best performance
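Null handling can be centralized instead of repeated at the top of every UDF. A minimal sketch, assuming a hypothetical `null_safe` decorator (not part of PySpark) that returns a default whenever any argument is `None`:

```python
import functools

def null_safe(default=None):
    """Hypothetical decorator: return `default` if any argument is None,
    otherwise call the wrapped function."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args):
            for arg in args:
                if arg is None:
                    return default
            return fn(*args)
        return wrapper
    return decorator

@null_safe(default=None)
def digits_only(phone):
    # No None check needed here: the decorator handles it
    return "".join(ch for ch in phone if ch.isdigit())

print(digits_only("(555) 123-4567"))  # 5551234567
print(digits_only(None))              # None
```

The decorated function is then wrapped as usual, e.g. `udf(digits_only, StringType())`. The pattern only fits functions that should return the same default for any null input; `calculate_customer_score`, which treats a null `transaction_count` as zero, still needs its own logic.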

---
## Lab 6: Building a Complete ETL Pipeline (60 minutes)

### Objectives
- Integrate all components into a production-ready pipeline
- Implement comprehensive error handling
- Add monitoring and logging
- Optimize end-to-end performance

### Exercise 6.1: Pipeline Framework

In [0]:
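# Create pipeline framework
import logging
import time
# Note: this rebinds the name `datetime` to the datetime class, shadowing the
# `datetime` module imported in the setup cell; use datetime.now() from here on
from datetime import datetime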

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("GlobalMartETL")

class ETLPipeline:
    """
    Complete ETL Pipeline for GlobalMart
    """
    
    def __init__(self, spark, config):
        self.spark = spark
        self.config = config
        self.metrics = {}
        self.start_time = None
        
    def log_metric(self, metric_name, value):
        """Log pipeline metrics"""
        self.metrics[metric_name] = value
        logger.info(f"Metric - {metric_name}: {value}")
    
    def validate_data(self, df, stage_name, required_columns):
        """Validate data at each stage"""
        logger.info(f"Validating {stage_name}...")
        
        # Check record count
        count = df.count()
        self.log_metric(f"{stage_name}_count", count)
        
        if count == 0:
            raise ValueError(f"No records found in {stage_name}")
        
        # Check required columns
        missing_cols = set(required_columns) - set(df.columns)
        if missing_cols:
            raise ValueError(f"Missing columns in {stage_name}: {missing_cols}")
        
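        # Check for nulls in key columns
        # (`sum` here is pyspark.sql.functions.sum, brought in by the wildcard import)
        null_counts = df.select([
            sum(col(c).isNull().cast("int")).alias(c) 
            for c in required_columns
        ]).collect()[0]
        
        # Warn when a column's null rate exceeds the configured threshold (default 10%)
        max_null_rate = self.config.get("max_nulls_percentage", 0.1)
        for col_name, null_count in null_counts.asDict().items():
            if null_count > count * max_null_rate:
                logger.warning(f"High null rate in {stage_name}.{col_name}: {null_count/count:.2%}")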
        
        return True
    
    def extract_data(self):
        """Extract data from sources"""
        logger.info("Starting data extraction...")
        
        try:
            # In real scenario, this would read from external sources
            customers = self.spark.table("raw_customers")
            products = self.spark.table("raw_products")
            transactions = self.spark.table("raw_transactions")
            
            # Validate extracted data
            self.validate_data(customers, "customers", ["customer_id", "registration_date"])
            self.validate_data(products, "products", ["product_id", "price"])
            self.validate_data(transactions, "transactions", ["transaction_id", "customer_id", "product_id"])
            
            return {
                "customers": customers,
                "products": products,
                "transactions": transactions
            }
            
        except Exception as e:
            logger.error(f"Extraction failed: {str(e)}")
            raise
    
    def transform_data(self, raw_data):
        """Transform data with business logic"""
        logger.info("Starting data transformation...")
        
        try:
            # Customer transformations
            customers_clean = raw_data["customers"] \
                .withColumn("email", lower(trim(col("email")))) \
                .withColumn("email_valid", validate_email_udf(col("email"))) \
                .withColumn("phone_formatted", format_phone_udf(col("phone"))) \
                .withColumn("full_name", concat_ws(" ", col("first_name"), col("last_name"))) \
                .withColumn("customer_segment",
                    when(col("lifetime_value") > 5000, "High Value")
                    .when(col("lifetime_value") > 1000, "Medium Value")
                    .otherwise("Low Value")
                ) \
                .filter(col("customer_id").isNotNull())
            
            # Product transformations
            products_clean = raw_data["products"] \
                .withColumn("profit_margin", 
                    round((col("price") - col("cost")) / col("price") * 100, 2)) \
                .withColumn("price_category",
                    when(col("price") > 100, "Premium")
                    .when(col("price") > 50, "Standard")
                    .otherwise("Budget")
                ) \
                .filter(col("product_id").isNotNull())
            
            # Transaction enrichment
            transactions_enriched = raw_data["transactions"] \
                .join(products_clean, "product_id", "left") \
                .join(customers_clean.select("customer_id", "customer_segment", "country"), 
                      "customer_id", "left") \
                .withColumn("revenue", col("quantity") * col("price")) \
                .withColumn("profit", col("quantity") * (col("price") - col("cost"))) \
                .withColumn("transaction_hour", hour(col("transaction_timestamp"))) \
                .withColumn("transaction_day_of_week", dayofweek(col("transaction_date"))) \
                .filter(col("transaction_id").isNotNull())
            
            # Validate transformations
            self.validate_data(customers_clean, "customers_transformed", 
                             ["customer_id", "customer_segment"])
            self.validate_data(transactions_enriched, "transactions_transformed", 
                             ["transaction_id", "revenue", "profit"])
            
            return {
                "dim_customers": customers_clean,
                "dim_products": products_clean,
                "fact_transactions": transactions_enriched
            }
            
        except Exception as e:
            logger.error(f"Transformation failed: {str(e)}")
            raise
    
    def load_data(self, transformed_data):
        """Load data to Delta tables"""
        logger.info("Starting data load...")
        
        try:
            # Load dimensions with a full overwrite: no history is kept, so the
            # effect resembles SCD Type 1 (current values only)
            for table_name in ["dim_customers", "dim_products"]:
                logger.info(f"Loading {table_name}...")
                
                transformed_data[table_name].write \
                    .format("delta") \
                    .mode("overwrite") \
                    .option("overwriteSchema", "true") \
                    .saveAsTable(f"globalmart.{table_name}_pipeline")
                
                count = self.spark.table(f"globalmart.{table_name}_pipeline").count()
                self.log_metric(f"{table_name}_loaded", count)
            
            # Load fact table with append (re-running the same batch appends duplicate rows)
            logger.info("Loading fact_transactions...")
            
            transformed_data["fact_transactions"] \
                .withColumn("load_timestamp", current_timestamp()) \
                .write \
                .format("delta") \
                .mode("append") \
                .partitionBy("transaction_date") \
                .saveAsTable("globalmart.fact_transactions_pipeline")
            
            count = transformed_data["fact_transactions"].count()
            self.log_metric("fact_transactions_loaded", count)
            
            return True
            
        except Exception as e:
            logger.error(f"Load failed: {str(e)}")
            raise
    
    def run(self):
        """Execute complete pipeline"""
        self.start_time = time.time()
        logger.info("="*50)
        logger.info("Starting GlobalMart ETL Pipeline")
        logger.info(f"Run timestamp: {datetime.now()}")
        
        try:
            # Extract
            raw_data = self.extract_data()
            
            # Transform
            transformed_data = self.transform_data(raw_data)
            
            # Load
            self.load_data(transformed_data)
            
            # Calculate total runtime and record success
            runtime = time.time() - self.start_time
            self.log_metric("total_runtime_seconds", runtime)
            self.log_metric("status", "SUCCESS")
            
            # Save metrics
            self.save_metrics()
            
            logger.info("Pipeline completed successfully!")
            logger.info("="*50)
            
            return True
            
        except Exception as e:
            runtime = time.time() - self.start_time
            self.log_metric("total_runtime_seconds", runtime)
            self.log_metric("status", "FAILED")
            self.log_metric("error", str(e))
            
            logger.error(f"Pipeline failed after {runtime:.2f} seconds")
            logger.error(f"Error: {str(e)}")
            logger.info("="*50)
            
            # Save metrics even on failure
            self.save_metrics()
            
            raise
    
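    def save_metrics(self):
        """Save pipeline metrics for monitoring"""
        # Use one timestamp per run so all of a run's metrics share the same key
        run_ts = datetime.now()
        metrics_data = [(run_ts, k, str(v)) for k, v in self.metrics.items()]
        
        # Explicit schema: avoids schema inference from dicts and also works
        # when the metrics list is empty
        metrics_schema = StructType([
            StructField("run_timestamp", TimestampType(), True),
            StructField("metric_name", StringType(), True),
            StructField("metric_value", StringType(), True),
        ])
        metrics_df = self.spark.createDataFrame(metrics_data, metrics_schema)
        
        metrics_df.write \
            .format("delta") \
            .mode("append") \
            .saveAsTable("globalmart.pipeline_metrics")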

# Create pipeline configuration
pipeline_config = {
    "max_nulls_percentage": 0.1,
    "batch_size": 10000,
    "error_tolerance": 0.05
}

print("✓ Pipeline framework created")
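The segmentation rules in `transform_data` use strict `>` comparisons, so boundary values fall into the lower bucket, and a null `lifetime_value` or `price` makes every `when` condition null, sending the row to `otherwise()`. A plain-Python mirror (hypothetical helpers, useful for unit-testing the thresholds without a cluster) makes those edge cases explicit:

```python
def customer_segment(lifetime_value):
    """Hypothetical mirror of the customer_segment when/otherwise chain.
    A null condition is not true in Spark, so nulls fall through to otherwise()."""
    if lifetime_value is None:
        return "Low Value"
    if lifetime_value > 5000:
        return "High Value"
    if lifetime_value > 1000:
        return "Medium Value"
    return "Low Value"

def price_category(price):
    """Hypothetical mirror of the price_category when/otherwise chain."""
    if price is None:
        return "Budget"
    if price > 100:
        return "Premium"
    if price > 50:
        return "Standard"
    return "Budget"

# Boundary values land in the lower bucket because the comparisons are strict (>)
for value in (5000, 5000.01, 1000, None):
    print(value, customer_segment(value))
for price in (100, 100.5, 50):
    print(price, price_category(price))
```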

### Exercise 6.2: Run the Pipeline

In [0]:
# Initialize and run pipeline
pipeline = ETLPipeline(spark, pipeline_config)

# Run the pipeline
pipeline.run()
print("\nPipeline Metrics:")
for metric, value in pipeline.metrics.items():
    print(f"  {metric}: {value}")
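Because `load_data` writes the fact table with `mode("append")`, running the pipeline twice on the same input stores every transaction twice. One common remedy is an insert-only merge keyed on `transaction_id`; on Delta tables this is usually expressed with `DeltaTable.merge(...)` followed by `whenNotMatchedInsertAll()`. The sketch below only models those semantics in plain Python (the function name `insert_new_only` is hypothetical):

```python
def insert_new_only(target_rows, source_rows, key="transaction_id"):
    """Model of an insert-only merge: append source rows whose key is not
    already present in the target. Like an insert-only Delta merge, it does
    not remove duplicates inside the source batch itself; dedupe first if needed."""
    existing = {row[key] for row in target_rows}
    return list(target_rows) + [row for row in source_rows if row[key] not in existing]

batch = [{"transaction_id": "T1", "quantity": 2}, {"transaction_id": "T2", "quantity": 1}]
first_run = insert_new_only([], batch)
second_run = insert_new_only(first_run, batch)  # re-running the same batch
print(len(first_run), len(second_run))          # 2 2
```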
        


### Exercise 6.3: Monitor Pipeline Results

In [0]:
# Check pipeline output
print("Pipeline Output Validation:")
print("=" * 50)

# Check loaded tables
tables = [
    "globalmart.dim_customers_pipeline",
    "globalmart.dim_products_pipeline", 
    "globalmart.fact_transactions_pipeline"
]

for table in tables:
    count = spark.table(table).count()
    print(f"{table}: {count:,} records")

# View metrics history
print("\nPipeline Metrics History:")
metrics_history = spark.sql("""
    SELECT 
        run_timestamp,
        metric_name,
        metric_value
    FROM globalmart.pipeline_metrics
    ORDER BY run_timestamp DESC, metric_name
    LIMIT 20
""")

display(metrics_history)

### Exercise 6.4: Create Monitoring Dashboard

In [0]:
# Build monitoring queries
# Data quality dashboard: each metric gets its own column (NULL where it does not apply)
data_quality_dashboard = spark.sql("""
    WITH quality_metrics AS (
        SELECT 
            'Customers' as dataset,
            COUNT(*) as total_records,
            SUM(CASE WHEN email_valid = true THEN 1 ELSE 0 END) as valid_emails,
            CAST(NULL AS BIGINT) as distinct_categories,
            COUNT(DISTINCT customer_segment) as segments,
            MIN(registration_date) as earliest_date,
            MAX(registration_date) as latest_date
        FROM globalmart.dim_customers_pipeline
        
        UNION ALL
        
        SELECT 
            'Products' as dataset,
            COUNT(*) as total_records,
            CAST(NULL AS BIGINT) as valid_emails,
            COUNT(DISTINCT category) as distinct_categories,
            COUNT(DISTINCT price_category) as segments,
            NULL as earliest_date,
            NULL as latest_date
        FROM globalmart.dim_products_pipeline
    )
    SELECT * FROM quality_metrics
""")

print("Data Quality Dashboard:")
display(data_quality_dashboard)

# Performance trends
performance_trends = spark.sql("""
    SELECT 
        DATE(run_timestamp) as run_date,
        MAX(CASE WHEN metric_name = 'total_runtime_seconds' 
            THEN CAST(metric_value AS FLOAT) END) as runtime_seconds,
        MAX(CASE WHEN metric_name = 'customers_count' 
            THEN CAST(metric_value AS INT) END) as customers_processed,
        MAX(CASE WHEN metric_name = 'fact_transactions_loaded' 
            THEN CAST(metric_value AS INT) END) as transactions_loaded
    FROM globalmart.pipeline_metrics
    GROUP BY DATE(run_timestamp)
    ORDER BY run_date DESC
""")

print("\nPerformance Trends:")
display(performance_trends)
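`save_metrics` stores every metric as a string in long format (one row per metric), so the trend query casts values back and pivots them with `MAX(CASE WHEN ...)` per day. One consequence worth knowing: when several runs happen on the same day, the row shows the slowest runtime and the largest counts. The same pivot in plain Python, using illustrative sample rows and a hypothetical `pivot_daily` helper:

```python
from collections import defaultdict
from datetime import datetime

def pivot_daily(rows, wanted):
    """rows: (run_timestamp, metric_name, metric_value_str) tuples.
    wanted: {metric_name: cast_function}. Keeps the maximum value per day,
    like MAX(CASE WHEN ...) ... GROUP BY DATE(run_timestamp)."""
    daily = defaultdict(dict)
    for ts, name, raw in rows:
        if name not in wanted:
            continue  # e.g. 'status' or 'error', which the SQL ignores too
        value = wanted[name](raw)
        current = daily[ts.date()].get(name)
        if current is None or value > current:
            daily[ts.date()][name] = value
    return dict(daily)

# Illustrative rows for two runs on the same day (not real pipeline output)
rows = [
    (datetime(2024, 1, 1, 9, 0), "total_runtime_seconds", "42.5"),
    (datetime(2024, 1, 1, 9, 0), "fact_transactions_loaded", "1000"),
    (datetime(2024, 1, 1, 18, 0), "total_runtime_seconds", "40.1"),
    (datetime(2024, 1, 1, 18, 0), "status", "SUCCESS"),
]
print(pivot_daily(rows, {"total_runtime_seconds": float, "fact_transactions_loaded": int}))
```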

### 💡 Lab 6 Key Takeaways
- Production pipelines need comprehensive error handling
- Validate data at each stage of processing
- Log metrics for monitoring and optimization
- Use Delta tables for reliable data storage
- Build reusable frameworks for consistency
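The validation rules in `validate_data` (reject empty stages, reject missing required columns, warn on high null rates) can be tested without Spark if the counts are passed in. A minimal sketch with a hypothetical `check_stage` helper that returns the offending null rates instead of logging them:

```python
def check_stage(stage_name, row_count, columns, required_columns, null_counts,
                max_null_rate=0.1):
    """Hypothetical Spark-free version of validate_data.
    null_counts: {column: number_of_nulls}. Returns {column: null_rate}
    for columns above max_null_rate; raises ValueError on hard failures."""
    if row_count == 0:
        raise ValueError(f"No records found in {stage_name}")
    missing = set(required_columns) - set(columns)
    if missing:
        raise ValueError(f"Missing columns in {stage_name}: {sorted(missing)}")
    return {
        col_name: nulls / row_count
        for col_name, nulls in null_counts.items()
        if nulls > row_count * max_null_rate
    }

# 150 nulls out of 1,000 rows is a 15% null rate, above the 10% threshold
print(check_stage("customers", 1000,
                  ["customer_id", "registration_date", "email"],
                  ["customer_id", "registration_date"],
                  {"customer_id": 0, "registration_date": 150}))
```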

---
## 🎉 Module 2 Complete!

### Your Achievements
You've successfully completed Module 2 and built production-grade ETL pipelines! Here's what you've mastered:

1. **DataFrame Operations**: Complex transformations, joins, and aggregations
2. **File Management**: Working with DBFS and multiple file formats
3. **Spark SQL**: Advanced analytics with window functions
4. **Delta Tables**: ACID transactions, merges, and time travel
5. **UDFs**: Custom business logic with performance optimization
6. **Complete Pipeline**: End-to-end ETL with monitoring and error handling

### Final Summary

In [0]:
# Generate learning summary
print("=" * 60)
print("MODULE 2 LEARNING SUMMARY")
print("=" * 60)

# Count the tables we've created (temporary views excluded)
permanent_tables = [
    t for t in spark.sql("SHOW TABLES IN globalmart").collect()
    if not t.isTemporary
]
print(f"\n📊 Tables created: {len(permanent_tables)}")

# Total records currently stored across those tables
total_records = 0
for table in permanent_tables:
    total_records += spark.table(f"globalmart.{table.tableName}").count()

print(f"📈 Total records across tables: {total_records:,}")

# Skills checklist
skills = [
    "PySpark DataFrame transformations",
    "File system operations with dbutils",
    "Spark SQL and window functions",
    "Delta Lake operations",
    "UDF creation and optimization",
    "Production pipeline development",
    "Error handling and monitoring",
    "Performance optimization"
]

print("\n✅ Skills Mastered:")
for skill in skills:
    print(f"   - {skill}")

print("\n🚀 You're now ready for Module 3: Incremental Processing with Delta Lake!")
print("=" * 60)

### Before You Go

**Remember to:**
1. Save this notebook for future reference
2. Review any sections that were challenging
3. Practice the patterns with your own data
4. Prepare questions for the Wednesday session

**Optional Challenges:**
- Optimize the pipeline to run 50% faster
- Add data quality rules using Delta Live Tables expectations
- Implement streaming ingestion for real-time data
- Create a data lineage visualization

Great work today! See you in Module 3! 🎓