# Transaction DP System - Production Test Notebook

This notebook provides comprehensive testing of the differential privacy pipeline for transaction data.

**Test Coverage:**
1. ✅ Data loading and validation
2. ✅ Privacy configuration and budget allocation
3. ✅ User-level DP parameters (D_max, K, sensitivities)
4. ✅ Pipeline execution with top-down algorithm
5. ✅ Privacy guarantee verification
6. ✅ Utility evaluation metrics

**Key Privacy Concepts:**
- **zCDP (ρ-zCDP)**: Privacy budget measured in rho, converts to (ε,δ)-DP
- **User-level DP**: Protects entire card's transaction history (not just single transactions)
- **Global Sensitivity**: sqrt(M × D_max) × K where M=max cells per card, D_max=max distinct days
- **Sequential Composition**: Budget accumulates across days within a month

**Input Data Schema:**
Your data should have these columns:
- `pspiin`: PSP identifier (optional)
- `acceptorid`: Acceptor/merchant identifier
- `card_number`: Card identifier
- `transaction_date`: Date of transaction
- `transaction_amount`: Transaction amount
- `city`: City of the acceptor
- `mcc`: Merchant Category Code

**Output (with DP):**
Aggregated at `(transaction_date, city, mcc)` level with:
- `transaction_count`: Count of transactions
- `unique_cards`: Count of distinct cards
- `transaction_amount_sum`: Sum of transaction amounts


---
## 1. Setup & Environment Configuration

Configure logging, imports, and verify environment.


In [None]:
import sys
import os
import logging
import math
from datetime import datetime
from fractions import Fraction

# Make the notebook's working directory importable so that project packages
# resolve when the kernel is started from the project root.
PROJECT_ROOT = os.getcwd()
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
    print(f"Added to sys.path: {PROJECT_ROOT}")

# Send log records to stdout so they show up in the notebook/terminal output.
# force=True replaces any logging configuration installed earlier.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s [%(levelname)s] %(name)s - %(message)s',
    datefmt='%H:%M:%S',
    level=logging.INFO,
    force=True,
)

# Per-logger levels: INFO for the transaction_dp loggers, WARNING for py4j
# to reduce Spark noise.
for _logger_name, _logger_level in (
    ('transaction_dp', logging.INFO),
    ('py4j', logging.WARNING),
):
    logging.getLogger(_logger_name).setLevel(_logger_level)

logger = logging.getLogger('demo_notebook')

# Print environment info
print("="*70)
print("ENVIRONMENT INFORMATION")
print("="*70)
print(f"Python Version: {sys.version}")
print(f"Working Directory: {os.getcwd()}")
print(f"Timestamp: {datetime.now().isoformat()}")

# Project files that must be present. The paths are relative, so they are
# resolved against the current working directory printed above.
required_files = [
    'data/city_province.csv',
    'core/config.py',
    'core/pipeline.py',
    'core/sensitivity.py',
    'engine/topdown.py'
]
print(f"\nRequired Files Check:")
# Check every file before failing so that a single run reports the complete
# list of missing files instead of stopping at the first one.
missing_files = []
for f in required_files:
    exists = os.path.exists(f)
    status = "‚úÖ" if exists else "‚ùå"
    print(f"  {status} {f}")
    if not exists:
        missing_files.append(f)

if missing_files:
    raise FileNotFoundError(f"Required file missing: {', '.join(missing_files)}")

print("\n‚úÖ Environment setup complete!")


---
## 2. Spark Configuration & Initialization


In [None]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Spark configuration - optimized for 38 cores and 220GB RAM
SPARK_MASTER = "local[38]"  # Use all 38 available cores
SPARK_APP_NAME = "TransactionDP-Test"
# Memory allocation: 220GB total, leaving ~10GB for OS
# - Executor: 170GB (heap memory for processing)
# - Driver: 10GB (coordination, not data processing)
# - Overhead: 30GB (off-heap, network buffers, etc.)
# Total: 170 + 10 + 30 = 210GB, leaving 10GB for OS
#
# NOTE(review): with a `local[N]` master Spark runs the driver and the
# executor threads inside a single JVM, so presumably
# `spark.executor.memory` / `spark.executor.memoryOverhead` have no effect
# here and the usable heap is bounded by `spark.driver.memory` (10g) --
# confirm against the Spark configuration docs before relying on the
# 170GB figure above.
SPARK_EXECUTOR_MEMORY = "170g"
SPARK_DRIVER_MEMORY = "10g"

print("="*70)
print("SPARK CONFIGURATION")
print("="*70)
print(f"  Master: {SPARK_MASTER}")
print(f"  App Name: {SPARK_APP_NAME}")
print(f"  Executor Memory: {SPARK_EXECUTOR_MEMORY}")
print(f"  Driver Memory: {SPARK_DRIVER_MEMORY}")

# Stop any existing Spark session
# NOTE(review): stopping the session presumably does not restart the
# underlying JVM, so JVM-level settings such as `spark.driver.memory` may
# keep their previous values when this cell is re-run -- confirm; a kernel
# restart may be required for memory changes to apply.
existing_session = SparkSession.getActiveSession()
if existing_session:
    print("\nStopping existing Spark session...")
    existing_session.stop()
    import time
    time.sleep(0.5)  # brief pause after stop() before building a new session

# Create Spark session with optimizations to reduce RowBasedKeyValueBatch spill warnings
# These settings improve memory management during aggregations and joins
#
# NOTE(review): `spark.sql.adaptive.maxNumPostShufflePartitions` looks like a
# legacy (pre-3.0) AQE key -- confirm it is recognised by the deployed Spark
# version; unknown keys are typically accepted silently.
# Shuffle partitions / default parallelism are 228 = 6 x 38 cores (see the
# summary printed below).
spark = SparkSession.builder \
    .appName(SPARK_APP_NAME) \
    .master(SPARK_MASTER) \
    .config("spark.executor.memory", SPARK_EXECUTOR_MEMORY) \
    .config("spark.driver.memory", SPARK_DRIVER_MEMORY) \
    .config("spark.sql.shuffle.partitions", "228") \
    .config("spark.default.parallelism", "228") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .config("spark.sql.adaptive.skewJoin.enabled", "true") \
    .config("spark.sql.adaptive.localShuffleReader.enabled", "true") \
    .config("spark.sql.adaptive.advisoryPartitionSizeInBytes", "256MB") \
    .config("spark.sql.adaptive.maxNumPostShufflePartitions", "500") \
    .config("spark.memory.fraction", "0.75") \
    .config("spark.memory.storageFraction", "0.4") \
    .config("spark.executor.memoryOverhead", "30g") \
    .config("spark.driver.maxResultSize", "4g") \
    .config("spark.shuffle.spill.compress", "true") \
    .config("spark.shuffle.compress", "true") \
    .config("spark.shuffle.spill.numElementsForceSpillThreshold", "1000000") \
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
    .config("spark.network.timeout", "600s") \
    .config("spark.sql.broadcastTimeout", "600s") \
    .config("spark.sql.autoBroadcastJoinThreshold", "200MB") \
    .config("spark.sql.execution.arrow.pyspark.enabled", "false") \
    .getOrCreate()

# Verify Spark session: read back what the running context actually uses,
# which may differ from the requested values above.
actual_master = spark.sparkContext.master
actual_parallelism = spark.sparkContext.defaultParallelism

print(f"\n‚úÖ Spark session initialized!")
print(f"  Actual Master: {actual_master}")
print(f"  Default Parallelism: {actual_parallelism}")
print(f"  Spark Version: {spark.version}")

# Note about RowBasedKeyValueBatch warnings
print(f"\nüìù Note: If you see 'RowBasedKeyValueBatch: calling spill()' warnings,")
print(f"   this is a known Spark behavior during large aggregations.")
print(f"   The optimizations above help reduce memory pressure, but the warning")
print(f"   itself is harmless and doesn't affect correctness.")
# The figures below are hard-coded descriptions of the intended hardware
# layout; they are not read back from the Spark session.
print(f"\nüíª Hardware Configuration:")
print(f"   - CPU Cores: 38 (fully utilized)")
print(f"   - Total RAM: 220GB")
print(f"   - Executor Memory: 170GB (heap)")
print(f"   - Driver Memory: 10GB")
print(f"   - Memory Overhead: 30GB (off-heap)")
print(f"   - Reserved for OS: ~10GB")
print(f"   - Shuffle Partitions: 228 (6x cores for optimal parallelism)")
print(f"\n‚öôÔ∏è Memory Management:")
print(f"   - Memory Fraction: 75% (balanced heap usage)")
print(f"   - Storage Fraction: 40% (cache vs execution)")
print(f"   - Max Result Size: 4GB (prevents driver OOM)")
# Helper functions
def show_df(df, n=20, truncate=True):
    """Print the first ``n`` rows of a Spark DataFrame via ``df.show``.

    ``truncate`` is forwarded unchanged to ``df.show``.
    """
    display_options = {'n': n, 'truncate': truncate}
    df.show(**display_options)
    
def to_pandas_safe(df, max_rows=100000):
    """Return ``df`` as a pandas DataFrame, refusing oversized inputs.

    The row count is computed first; if it exceeds ``max_rows`` a
    ``ValueError`` is raised, otherwise ``df.toPandas()`` is returned.
    """
    row_total = df.count()
    if row_total <= max_rows:
        return df.toPandas()
    raise ValueError(f"DataFrame too large ({row_total:,} rows). Use Spark operations.")

print("\nüìù Helper functions available: show_df(), to_pandas_safe()")


---
## 3. Configure Data Paths

**Update the `DATA_INPUT_PATH` below to point to your data file.**

Expected columns:
- `pspiin`: PSP identifier (optional, not used in DP)
- `acceptorid`: Acceptor/merchant identifier  
- `card_number`: Card identifier
- `transaction_date`: Date of transaction (format: YYYY-MM-DD)
- `transaction_amount`: Transaction amount
- `city`: City of the acceptor
- `mcc`: Merchant Category Code


In [None]:
# ============================================
# UPDATE THIS PATH TO YOUR DATA FILE
# ============================================
DATA_INPUT_PATH = 'data/your_transactions.parquet'  # <-- CHANGE THIS

# Fixed project paths: city/province lookup table and DP output directory
CITY_PROVINCE_PATH = 'data/city_province.csv'
OUTPUT_PATH = 'data/dp_results'

print("="*70)
print("DATA CONFIGURATION")
print("="*70)
for _label, _value in (
    ("Input Path", DATA_INPUT_PATH),
    ("City-Province Mapping", CITY_PROVINCE_PATH),
    ("Output Path", OUTPUT_PATH),
):
    print(f"  {_label}: {_value}")

# Stop here if the input is missing, before any Spark work is attempted.
if os.path.exists(DATA_INPUT_PATH):
    print(f"\n‚úÖ Input file found!")
else:
    print(f"\n‚ùå ERROR: Input file not found: {DATA_INPUT_PATH}")
    print("   Please update DATA_INPUT_PATH to point to your data file.")
    raise FileNotFoundError(f"Input file not found: {DATA_INPUT_PATH}")


---
## 4. Load and Analyze Raw Data

Understand data characteristics for privacy parameter tuning.


In [None]:
# Load data
print("Loading data...")
df_spark = spark.read.parquet(DATA_INPUT_PATH)

# Basic statistics
total_count = df_spark.count()

print("="*70)
print("RAW DATA ANALYSIS")
print("="*70)
print(f"\nTotal records: {total_count:,}")
print(f"\nSchema:")
df_spark.printSchema()

# Verify required columns exist (the schema is printed first so the user can
# see which names are actually present when this check fails).
required_cols = ['card_number', 'transaction_date', 'transaction_amount', 'city', 'mcc']
missing_cols = [name for name in required_cols if name not in df_spark.columns]
if missing_cols:
    print(f"\n‚ùå ERROR: Missing required columns: {missing_cols}")
    raise ValueError(f"Missing required columns: {missing_cols}")

# With zero rows every aggregate below is None and the numeric formatting
# would fail with an opaque TypeError; fail with a clear message instead.
if total_count == 0:
    raise ValueError(f"Input dataset is empty: {DATA_INPUT_PATH}")

# Unique counts
unique_cards = df_spark.select('card_number').distinct().count()
unique_cities = df_spark.select('city').distinct().count()
unique_mccs = df_spark.select('mcc').distinct().count()

print(f"\nüìä Unique Counts:")
print(f"  Cards: {unique_cards:,}")
print(f"  Cities: {unique_cities:,}")
print(f"  MCCs: {unique_mccs:,}")

# Date and amount ranges, computed in a single aggregation pass.
# `stats` is reused by later cells (min_date / max_date).
stats = df_spark.agg(
    F.min('transaction_date').alias('min_date'),
    F.max('transaction_date').alias('max_date'),
    F.min('transaction_amount').alias('min_amount'),
    F.max('transaction_amount').alias('max_amount'),
    F.avg('transaction_amount').alias('avg_amount'),
    F.stddev('transaction_amount').alias('std_amount'),
    F.percentile_approx('transaction_amount', 0.99).alias('p99_amount')
).collect()[0]

print(f"\nüìÖ Date Range: {stats['min_date']} to {stats['max_date']}")
print(f"\nüí∞ Amount Statistics:")
print(f"  Min: {stats['min_amount']:,.0f}")
print(f"  Max: {stats['max_amount']:,.0f}")
print(f"  Mean: {stats['avg_amount']:,.0f}")
print(f"  Std Dev: {stats['std_amount']:,.0f}")
print(f"  99th Percentile: {stats['p99_amount']:,.0f}")

# Sample data
print(f"\nüìù Sample rows:")
show_df(df_spark, n=5)


### 4.1 User-Level DP Parameters Analysis

Compute critical parameters for user-level differential privacy:
- **M**: Max cells (city×MCC×day combinations) a single card appears in
- **D_max**: Max distinct days a single card makes transactions
- **K**: Per-cell contribution bound


In [None]:
print("="*70)
print("USER-LEVEL DP PARAMETER ANALYSIS")
print("="*70)

# One row per (card, cell), where a cell is a (city, mcc, day) combination;
# txns_in_cell is how many transactions the card made in that cell.
# Both M and K are derived from this grouping.
txns_per_cell = df_spark.groupBy('card_number', 'city', 'mcc', 'transaction_date') \
    .agg(F.count('*').alias('txns_in_cell'))

# M: number of distinct cells each card appears in
cells_per_card = txns_per_cell.groupBy('card_number') \
    .agg(F.count('*').alias('num_cells'))

M_stats = cells_per_card.agg(
    F.max('num_cells').alias('max_M'),
    F.avg('num_cells').alias('avg_M'),
    F.percentile_approx('num_cells', 0.99).alias('p99_M'),
    F.percentile_approx('num_cells', 0.95).alias('p95_M')
).first()

print(f"\nüìä M (Max Cells per Card):")
print(f"  Max: {M_stats['max_M']}")
print(f"  99th Percentile: {M_stats['p99_M']}")
print(f"  95th Percentile: {M_stats['p95_M']}")
print(f"  Mean: {M_stats['avg_M']:.2f}")

# D_max: number of distinct transaction dates per card
days_per_card = df_spark.groupBy('card_number') \
    .agg(F.countDistinct('transaction_date').alias('num_days'))

D_stats = days_per_card.agg(
    F.max('num_days').alias('max_D'),
    F.avg('num_days').alias('avg_D'),
    F.percentile_approx('num_days', 0.99).alias('p99_D')
).first()

print(f"\nüìÖ D_max (Max Distinct Days per Card):")
print(f"  Max: {D_stats['max_D']}")
print(f"  99th Percentile: {D_stats['p99_D']}")
print(f"  Mean: {D_stats['avg_D']:.2f}")

# K: distribution of per-card transaction counts within a single cell
K_stats = txns_per_cell.agg(
    F.max('txns_in_cell').alias('max_K'),
    F.avg('txns_in_cell').alias('avg_K'),
    F.percentile_approx('txns_in_cell', 0.99).alias('p99_K'),
    F.percentile_approx('txns_in_cell', 0.75).alias('p75_K')
).first()

print(f"\nüî¢ K (Transactions per Card per Cell):")
print(f"  Max: {K_stats['max_K']}")
print(f"  99th Percentile: {K_stats['p99_K']}")
print(f"  75th Percentile: {K_stats['p75_K']}")
print(f"  Mean: {K_stats['avg_K']:.2f}")

# Values kept for later cells. M and D_max use the observed maxima, while K
# uses the 99th percentile as the bounded-contribution cap.
# NOTE(review): M already counts (city, mcc, day) cells, so multiplying by
# D_max again in sqrt(M x D_max) may be conservative -- confirm against the
# pipeline's sensitivity definition.
COMPUTED_M = int(M_stats['max_M'])
COMPUTED_D_MAX = int(D_stats['max_D'])
COMPUTED_K = int(K_stats['p99_K'])

# Inclusive calendar span of the data; string dates are parsed as YYYY-MM-DD,
# any other type is used as returned by Spark.
min_date, max_date = (
    datetime.strptime(bound, '%Y-%m-%d').date() if isinstance(bound, str) else bound
    for bound in (stats['min_date'], stats['max_date'])
)
NUM_DAYS = (max_date - min_date).days + 1

print(f"\n" + "="*70)
print(f"COMPUTED PARAMETERS FOR DP:")
print(f"  M (max cells per card): {COMPUTED_M}")
print(f"  D_max (max days per card): {COMPUTED_D_MAX}")
print(f"  K (contribution bound): {COMPUTED_K}")
print(f"  NUM_DAYS (total days in data): {NUM_DAYS}")
print(f"  sqrt(M √ó D_max) √ó K = {math.sqrt(COMPUTED_M * COMPUTED_D_MAX) * COMPUTED_K:.2f}")
print("="*70)


---
## 5. Configure DP Pipeline

Set up differential privacy configuration with all parameters.


In [None]:
from core.config import Config

# Create configuration
# Builds the project Config object, fills in the data, column, privacy and
# Spark sections from the values defined in earlier cells, validates it and
# prints a summary of the resulting privacy budget.
config = Config()

# === DATA SETTINGS ===
config.data.input_path = DATA_INPUT_PATH
config.data.output_path = OUTPUT_PATH
config.data.city_province_path = CITY_PROVINCE_PATH
config.data.input_format = 'parquet'
config.data.num_days = NUM_DAYS  # inclusive day span computed in section 4.1
config.data.winsorize_percentile = 99.0  # Cap amounts at 99th percentile

# === COLUMN MAPPINGS ===
# Map your column names to internal names used by the pipeline
# (keys: internal names, values: column names in the input data)
config.columns = {
    'amount': 'transaction_amount',        # Your amount column
    'transaction_date': 'transaction_date', # Your date column
    'card_number': 'card_number',          # Your card identifier column
    'acceptor_id': 'acceptorid',           # Your acceptor/merchant column
    'acceptor_city': 'city',               # Your city column
    'mcc': 'mcc'                           # Your MCC column
}

# === PRIVACY SETTINGS ===
# Total privacy budget (rho for zCDP)
# Rule of thumb: rho=1 gives strong utility, rho=0.25 gives strong privacy
# Stored as an exact Fraction; converted to float only for display below.
config.privacy.total_rho = Fraction(1, 2)  # rho = 0.5
config.privacy.delta = 1e-10

# Geographic budget split (Province vs City level)
config.privacy.geographic_split = {
    'province': 0.2,  # 20% for province-level aggregates
    'city': 0.8       # 80% for city-level aggregates
}

# Query budget split - 3 queries now
# NOTE(review): the key 'total_amount' differs from the output column name
# 'transaction_amount_sum' used when reading the results -- confirm this is
# the key the pipeline expects.
config.privacy.query_split = {
    'transaction_count': 0.34,
    'unique_cards': 0.33,
    'total_amount': 0.33
}

# Bounded contribution settings
config.privacy.contribution_bound_method = 'percentile'
config.privacy.contribution_bound_percentile = 99.0

# Suppression settings
config.privacy.suppression_threshold = 5

# Sensitivity method
config.privacy.sensitivity_method = 'global'

# MCC grouping for stratified sensitivity
config.privacy.mcc_grouping_enabled = True
config.privacy.mcc_num_groups = 5

# Confidence intervals
config.privacy.confidence_levels = [0.90, 0.95]

# === SPARK SETTINGS ===
# Reuse the constants from the Spark setup cell so the pipeline's settings
# match the session created there.
config.spark.app_name = SPARK_APP_NAME
config.spark.master = SPARK_MASTER
config.spark.executor_memory = SPARK_EXECUTOR_MEMORY
config.spark.driver_memory = SPARK_DRIVER_MEMORY

# Validate configuration
config.validate()

print("="*70)
print("DP CONFIGURATION SUMMARY")
print("="*70)
print(f"\nüìä Privacy Budget:")
print(f"  Total œÅ (rho): {config.privacy.total_rho} = {float(config.privacy.total_rho):.4f}")
print(f"  Œ¥ (delta): {config.privacy.delta}")

# Convert zCDP to (Œµ,Œ¥)-DP for reference
# Uses epsilon = rho + 2 * sqrt(rho * ln(1/delta)).
# `rho`, `delta` and `epsilon` remain defined for later cells.
rho = float(config.privacy.total_rho)
delta = config.privacy.delta
epsilon = rho + 2 * math.sqrt(rho * math.log(1/delta))
print(f"  Equivalent (Œµ,Œ¥)-DP: Œµ ‚âà {epsilon:.2f}, Œ¥ = {delta}")

# Each split weight is applied to the total rho independently, i.e. the
# figures printed below are weight * total_rho for that split alone.
print(f"\nüó∫Ô∏è Geographic Budget Split:")
for level, weight in config.privacy.geographic_split.items():
    level_rho = rho * weight
    print(f"  {level.capitalize()}: {weight:.0%} ‚Üí œÅ = {level_rho:.4f}")

print(f"\nüìã Query Budget Split:")
for query, weight in config.privacy.query_split.items():
    query_rho = rho * weight
    print(f"  {query}: {weight:.0%} ‚Üí œÅ = {query_rho:.4f}")

print(f"\nüîß Other Settings:")
print(f"  Contribution Bound Method: {config.privacy.contribution_bound_method}")
print(f"  Suppression Threshold: {config.privacy.suppression_threshold}")
print(f"  Sensitivity Method: {config.privacy.sensitivity_method}")
print(f"  MCC Grouping: {'Enabled' if config.privacy.mcc_grouping_enabled else 'Disabled'}")

print(f"\n‚úÖ Configuration validated!")


---
## 6. Run DP Pipeline

Execute the differential privacy pipeline with Top-Down Algorithm.


In [None]:
from core.pipeline import DPPipeline

print("="*70)
print("EXECUTING DP PIPELINE")
print("="*70)

start_time = datetime.now()

# Create and run pipeline
pipeline = DPPipeline(config)
result = pipeline.run()

end_time = datetime.now()
duration = (end_time - start_time).total_seconds()

print("\n" + "="*70)
print("PIPELINE RESULTS")
print("="*70)

if result['success']:
    print(f"\n‚úÖ SUCCESS!")
else:
    print(f"\n‚ùå FAILED!")

# The thousands-separator format spec is only valid for numbers. Applying it
# to the 'N/A' placeholder raises ValueError, which would hide the real
# pipeline errors printed further down, so format conditionally.
total_records = result.get('total_records')
if isinstance(total_records, (int, float)):
    records_display = f"{total_records:,}"
else:
    records_display = 'N/A'

print(f"\nüìä Execution Summary:")
print(f"  Records Processed: {records_display}")
print(f"  Privacy Budget Used: œÅ = {result.get('budget_used', 'N/A')}")
print(f"  Duration: {duration:.2f} seconds")
print(f"  Output Path: {result.get('output_path', 'N/A')}")

if result.get('errors'):
    print(f"\n‚ö†Ô∏è Errors:")
    for error in result['errors']:
        print(f"    - {error}")


---
## 7. Privacy Verification

Verify that privacy guarantees are correctly implemented.


In [None]:
print("="*70)
print("PRIVACY GUARANTEE VERIFICATION")
print("="*70)

if not result['success']:
    print("‚ö†Ô∏è Pipeline failed - skipping privacy verification")
else:
    import json

    # Load metadata (printed only when the pipeline wrote a metadata file)
    output_path = config.data.output_path
    metadata_path = os.path.join(output_path, 'metadata.json')

    if os.path.exists(metadata_path):
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        print(f"\nüìã Pipeline Metadata:")
        print(json.dumps(metadata, indent=2))

    # Verify budget composition
    print(f"\nüîê Budget Composition Verification:")
    total_rho = float(config.privacy.total_rho)
    print(f"  Total Budget: œÅ = {total_rho}")

    # Geographic composition: the per-level budgets must add up to the total
    geo_rho_sum = sum(total_rho * w for w in config.privacy.geographic_split.values())
    print(f"  Geographic Split Sum: {geo_rho_sum:.4f} (should = {total_rho})")
    geo_check = abs(geo_rho_sum - total_rho) < 1e-6
    print(f"  Geographic Composition: {'‚úÖ VALID' if geo_check else '‚ùå INVALID'}")

    # Query composition: the query weights must add up to 1
    query_sum = sum(config.privacy.query_split.values())
    print(f"  Query Split Sum: {query_sum:.4f} (should = 1.0)")
    query_check = abs(query_sum - 1.0) < 1e-6
    print(f"  Query Composition: {'‚úÖ VALID' if query_check else '‚ùå INVALID'}")

    # Sensitivity verification
    print(f"\nüéØ Sensitivity Verification:")
    # Prefer the values computed in the parameter-analysis cell; if that cell
    # was not run in this kernel, fall back to the config, then to 1.
    try:
        computed_m = COMPUTED_M
        computed_d_max = COMPUTED_D_MAX
        computed_k = COMPUTED_K
    except NameError:
        print("  ‚ö†Ô∏è Warning: COMPUTED_M, COMPUTED_D_MAX, or COMPUTED_K not found.")
        print("  Using values from config or defaults.")
        computed_m = getattr(config.privacy, 'computed_m', None)
        computed_d_max = config.privacy.computed_d_max
        computed_k = config.privacy.computed_contribution_bound
        if computed_m is None or computed_d_max is None or computed_k is None:
            print("  ‚ö†Ô∏è Cannot compute sensitivities - missing required parameters.")
            computed_m = 1  # Default fallback
            computed_d_max = 1
            computed_k = 1

    # Values stored on the config take precedence over the notebook's own.
    d_max = config.privacy.computed_d_max or computed_d_max
    k_bound = config.privacy.computed_contribution_bound or computed_k

    print(f"  D_max (max days): {d_max}")
    print(f"  K (contribution bound): {k_bound}")
    print(f"  M (max cells): {computed_m}")

    sqrt_md = math.sqrt(computed_m * d_max)
    sens_count = sqrt_md * k_bound
    sens_unique = sqrt_md * 1

    print(f"\n  Expected Sensitivities (L2):")
    print(f"    transaction_count: ‚àö(M√óD_max)√óK = {sens_count:.2f}")
    print(f"    unique_cards: ‚àö(M√óD_max)√ó1 = {sens_unique:.2f}")

    # Privacy guarantee summary
    # Always derive epsilon and delta from the current config instead of
    # reusing whatever an earlier cell left in the kernel: the summary then
    # stays consistent with the total_rho printed here, and Jupyter's `_`
    # (last result) is not overwritten by an existence probe.
    delta = config.privacy.delta
    rho = total_rho
    epsilon = rho + 2 * math.sqrt(rho * math.log(1/delta))

    print(f"\nüìú PRIVACY GUARANTEE SUMMARY:")
    print(f"  Mechanism: Discrete Gaussian (zCDP)")
    print(f"  Privacy Unit: (card_number, month)")
    print(f"  Composition: Sequential across days, Parallel across cells")
    print(f"  Total Budget: œÅ = {total_rho} zCDP")
    print(f"  Equivalent (Œµ,Œ¥)-DP: Œµ ‚âà {epsilon:.2f}, Œ¥ = {delta}")

    # The verdict reflects only the two budget-split checks above.
    if geo_check and query_check:
        print(f"\n‚úÖ Privacy verification PASSED!")
    else:
        print(f"\n‚ùå Privacy verification FAILED!")


---
## 8. View Results

Load and examine the DP-protected output.


In [None]:
import json

print("="*70)
print("DP-PROTECTED OUTPUT")
print("="*70)

output_path = config.data.output_path

if not os.path.exists(output_path):
    print(f"‚ùå Output directory not found: {output_path}")
else:
    print(f"\nüìÅ Output directory: {output_path}")
    print(f"\nContents:")
    # List the directory: files with their size, sub-directories with a
    # trailing slash.
    for entry in os.listdir(output_path):
        entry_path = os.path.join(output_path, entry)
        if not os.path.isfile(entry_path):
            print(f"  - {entry}/")
            continue
        print(f"  - {entry} ({os.path.getsize(entry_path):,} bytes)")

    # Pipeline metadata, if a metadata file is present
    metadata_path = os.path.join(output_path, 'metadata.json')
    if os.path.exists(metadata_path):
        print("\nüìã Metadata:")
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        print(json.dumps(metadata, indent=2))

    # DP-protected cells, if the protected_data path is present
    protected_data_path = os.path.join(output_path, "protected_data")
    if os.path.exists(protected_data_path):
        print(f"\nüìä Loading protected data...")
        dp_df = spark.read.parquet(protected_data_path)
        dp_count = dp_df.count()
        print(f"  Protected cells: {dp_count:,}")
        print(f"\n  Schema:")
        dp_df.printSchema()
        print(f"\n  Sample:")
        show_df(dp_df, n=10)


---
## 9. Utility Evaluation

Compare original vs DP-protected data to measure utility loss.


In [None]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
from pyspark.sql.functions import col, count, countDistinct, sum as spark_sum

print("="*70)
print("UTILITY EVALUATION")
print("="*70)

if not result['success']:
    print("‚ö†Ô∏è Pipeline failed - skipping utility evaluation")
else:
    # Resolve the output location from the config so this cell does not
    # depend on another cell having defined `output_path` first.
    output_path = config.data.output_path

    # Aggregate original data to same granularity
    print("\nüìä Aggregating original data...")
    original_agg = df_spark.groupBy('city', 'mcc', 'transaction_date').agg(
        count('*').alias('transaction_count'),
        countDistinct('card_number').alias('unique_cards'),
        spark_sum('transaction_amount').alias('transaction_amount_sum')
    )

    orig_count = original_agg.count()
    print(f"  Original cells: {orig_count:,}")

    # Load DP data
    protected_data_path = os.path.join(output_path, "protected_data")
    dp_agg = spark.read.parquet(protected_data_path)
    dp_count = dp_agg.count()
    print(f"  DP-protected cells: {dp_count:,}")

    # Compare totals using Spark
    print(f"\n" + "="*70)
    print("AGGREGATE LEVEL COMPARISON")
    print("="*70)

    NUMERIC_COLS = ['transaction_count', 'unique_cards', 'transaction_amount_sum']

    # One aggregation job per dataset instead of one per column.
    orig_totals = original_agg.agg(
        *[spark_sum(name).alias(name) for name in NUMERIC_COLS]
    ).collect()[0]
    dp_totals = dp_agg.agg(
        *[spark_sum(name).alias(name) for name in NUMERIC_COLS]
    ).collect()[0]

    for col_name in NUMERIC_COLS:
        # A sum over no rows is None -> treat as 0. Converting to float keeps
        # the arithmetic valid when one side comes back as decimal.Decimal
        # and the other as int/float.
        orig_total = float(orig_totals[col_name] or 0)
        dp_total = float(dp_totals[col_name] or 0)

        if orig_total > 0:
            error_pct = abs(dp_total - orig_total) / orig_total * 100
            # Status thresholds: below 5%, below 15%, otherwise the failure mark
            status = "‚úÖ" if error_pct < 5 else ("‚ö†Ô∏è" if error_pct < 15 else "‚ùå")
        else:
            # Relative error is undefined for a zero baseline
            error_pct = 0
            status = "‚ö†Ô∏è"

        print(f"\n{col_name}:")
        print(f"  Original Total: {orig_total:,.0f}")
        print(f"  DP Total: {dp_total:,.0f}")
        print(f"  Error: {error_pct:.2f}% {status}")


---
## 10. Production Readiness Checklist

Verify the system is ready for production deployment.


---
## 10.1 Research-Grade DP Validation (Census 2020 Methodology)

The following tests validate the DP implementation according to standards used in the US Census 2020 Disclosure Avoidance System:

**A. Statistical Accuracy Tests**
- Per-cell error distribution analysis
- Bias verification (should be ~0 for unbiased mechanisms)
- Variance verification against theoretical σ²

**B. Privacy Guarantee Verification**
- Sensitivity computation validation
- Noise calibration verification
- Composition theorem verification

**C. Utility Metrics (Census 2020 Standard)**
- Root Mean Square Error (RMSE)
- Mean Absolute Error (MAE)
- Coefficient of Variation (CV)
- Coverage probability for confidence intervals


In [None]:
"""
RESEARCH-GRADE DP VALIDATION
Following US Census 2020 DAS methodology
"""
import numpy as np
from scipy import stats as scipy_stats

print("="*70)
print("RESEARCH-GRADE DP VALIDATION (Census 2020 Methodology)")
print("="*70)

if not result['success']:
    print("‚ö†Ô∏è Pipeline failed - skipping research validation")
else:
    # Resolve the output location from the config so this cell does not
    # depend on another cell having defined `output_path` first.
    output_path = config.data.output_path

    # =========================================================================
    # A. PER-CELL ERROR ANALYSIS
    # =========================================================================
    print("\n" + "="*70)
    print("A. PER-CELL ERROR ANALYSIS")
    print("="*70)

    # Join original and DP data at cell level
    # First, prepare original aggregates with matching schema
    original_cells = df_spark.groupBy('city', 'mcc', 'transaction_date').agg(
        F.count('*').alias('orig_count'),
        F.countDistinct('card_number').alias('orig_unique'),
        F.sum('transaction_amount').alias('orig_amount')
    )

    # Load DP data
    dp_cells = spark.read.parquet(os.path.join(output_path, "protected_data"))

    # Rename DP columns for join
    dp_renamed = dp_cells.select(
        F.col('city').alias('dp_city'),
        F.col('mcc').alias('dp_mcc'),
        F.col('transaction_date').alias('dp_date'),
        F.col('transaction_count').alias('dp_count'),
        F.col('unique_cards').alias('dp_unique'),
        F.col('transaction_amount_sum').alias('dp_amount')
    )

    # Join on cell key.
    # The outer join keeps cells present on only one side; fillna(0) then
    # makes the missing side count as 0, so such cells contribute their full
    # value to the error statistics below.
    joined = original_cells.join(
        dp_renamed,
        (original_cells.city == dp_renamed.dp_city) &
        (original_cells.mcc == dp_renamed.dp_mcc) &
        (original_cells.transaction_date == dp_renamed.dp_date),
        "outer"
    ).fillna(0)

    # Compute errors (signed: DP value minus original value)
    errors_df = joined.select(
        'city', 'mcc', 'transaction_date',
        'orig_count', 'dp_count',
        (F.col('dp_count') - F.col('orig_count')).alias('error_count'),
        'orig_unique', 'dp_unique',
        (F.col('dp_unique') - F.col('orig_unique')).alias('error_unique'),
        'orig_amount', 'dp_amount',
        (F.col('dp_amount') - F.col('orig_amount')).alias('error_amount')
    )

    # Compute error statistics
    # MAE is the *mean* absolute error, matching the metric list in section
    # 10.1 (a 0.5-quantile of |error| would be the median instead).
    error_stats = errors_df.agg(
        # Count errors
        F.count('*').alias('n_cells'),
        F.mean('error_count').alias('bias_count'),
        F.stddev('error_count').alias('std_count'),
        F.mean(F.abs(F.col('error_count'))).alias('mae_count'),
        F.sqrt(F.mean(F.pow('error_count', 2))).alias('rmse_count'),
        # Unique card errors
        F.mean('error_unique').alias('bias_unique'),
        F.stddev('error_unique').alias('std_unique'),
        F.mean(F.abs(F.col('error_unique'))).alias('mae_unique'),
        F.sqrt(F.mean(F.pow('error_unique', 2))).alias('rmse_unique'),
        # Amount errors
        F.mean('error_amount').alias('bias_amount'),
        F.stddev('error_amount').alias('std_amount'),
        F.sqrt(F.mean(F.pow('error_amount', 2))).alias('rmse_amount')
    ).collect()[0]

    print(f"\nüìä Error Statistics Across {error_stats['n_cells']:,} Cells:")
    print(f"\n  TRANSACTION COUNT:")
    print(f"    Bias (should be ‚âà0): {error_stats['bias_count']:.4f}")
    print(f"    Std Dev: {error_stats['std_count']:.2f}")
    print(f"    MAE: {error_stats['mae_count']:.2f}")
    print(f"    RMSE: {error_stats['rmse_count']:.2f}")

    print(f"\n  UNIQUE CARDS:")
    print(f"    Bias (should be ‚âà0): {error_stats['bias_unique']:.4f}")
    print(f"    Std Dev: {error_stats['std_unique']:.2f}")
    print(f"    MAE: {error_stats['mae_unique']:.2f}")
    print(f"    RMSE: {error_stats['rmse_unique']:.2f}")

    print(f"\n  TRANSACTION AMOUNT:")
    print(f"    Bias (should be ‚âà0): {error_stats['bias_amount']:.2f}")
    print(f"    Std Dev: {error_stats['std_amount']:.2f}")
    print(f"    RMSE: {error_stats['rmse_amount']:.2f}")

    # Bias test (should not reject H0: bias=0)
    # One-sample, two-sided t-test on the transaction-count errors.
    n_cells = error_stats['n_cells']
    bias_count = error_stats['bias_count']
    std_count = error_stats['std_count']

    if std_count > 0 and n_cells > 30:
        t_stat = bias_count / (std_count / np.sqrt(n_cells))
        # Use the survival function rather than 1 - cdf: it keeps precision
        # in the far tail, where 1 - cdf rounds to exactly 0.
        p_value = 2 * scipy_stats.t.sf(abs(t_stat), n_cells - 1)
        bias_test_pass = p_value > 0.05
        print(f"\n  üìà Bias Test (H0: bias=0):")
        print(f"    t-statistic: {t_stat:.4f}")
        print(f"    p-value: {p_value:.4f}")
        print(f"    Result: {'‚úÖ PASS (unbiased)' if bias_test_pass else '‚ùå FAIL (biased)'}")


---
## 12. Interactive 3D Visualization of DP Noise

Scientific-level visualization of differential privacy noise effects using interactive 3D surface plots.

**Features:**
- **Dual Surface Plots**: Side-by-side comparison of original vs DP-protected data
- **Dynamic Axes**: User-configurable X/Y axes from (City, MCC, Day)
- **Province/Month Filtering**: Select specific province and month for analysis
- **Metric Selection**: Visualize any of the three queries
- **Statistical Metrics**: RMSE, MAE, and maximum error displayed
- **Publication Quality**: Suitable for research papers and presentations


In [None]:
"""
INTERACTIVE 3D VISUALIZATION OF DIFFERENTIAL PRIVACY NOISE
===========================================================
Scientific-level visualization using Plotly for publication-quality figures.
"""

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import pandas as pd
from ipywidgets import widgets, interactive
from IPython.display import display, HTML

print("="*70)
print("3D VISUALIZATION: ORIGINAL vs DP-PROTECTED DATA")
print("="*70)

if not result['success']:
    print("‚ö†Ô∏è Pipeline failed - cannot create visualization")
else:
    # =========================================================================
    # STEP 1: DATA PREPARATION
    # =========================================================================
    # Builds `viz_df`, a pandas frame with one row per
    # (province_code, city_code, mcc, day_idx) cell holding both the original
    # (orig_*) and the DP-protected (dp_*) value of each metric.
    print("\nüìä Preparing data for visualization...")
    
    # Local imports: only this cell needs window specs and the geography lookup.
    # `Window` lives in pyspark.sql; it is not an attribute of pyspark.sql.functions.
    from pyspark.sql import Window
    from schema.geography import Geography
    
    geo = Geography.from_csv(CITY_PROVINCE_PATH)
    
    # City -> province lookup used by the UDFs below; entries are indexed as
    # [0] = province_code and [2] = city_code.
    # NOTE(review): used here like a plain mapping despite the "broadcast" name -
    # confirm against Geography.city_to_province_broadcast().
    city_to_province_map = geo.city_to_province_broadcast()
    
    # Aggregate the original (non-private) data to cell level
    original_cells = df_spark.groupBy('city', 'mcc', 'transaction_date').agg(
        F.count('*').alias('orig_transaction_count'),
        F.countDistinct('card_number').alias('orig_unique_cards'),
        F.sum('transaction_amount').alias('orig_total_amount')
    )
    
    # Map each city name to its numeric codes, falling back to the "unknown" codes
    @F.udf('int')
    def get_province_code(city):
        if city in city_to_province_map:
            return city_to_province_map[city][0]  # province_code
        return geo.UNKNOWN_PROVINCE_CODE  # Unknown
    
    @F.udf('int')
    def get_city_code(city):
        if city in city_to_province_map:
            return city_to_province_map[city][2]  # city_code
        return geo.UNKNOWN_CITY_CODE  # Unknown
    
    original_cells = original_cells.withColumn('province_code', get_province_code('city'))
    original_cells = original_cells.withColumn('city_code', get_city_code('city'))
    
    # Load DP-protected data
    dp_cells = spark.read.parquet(os.path.join(output_path, "protected_data"))
    
    # Day index for the original cells: use transaction_date directly when it
    # casts to an int, otherwise the number of days since the earliest date.
    # An empty partitionBy() makes the window span the whole dataset.
    whole_dataset = Window.partitionBy()
    original_cells = original_cells.withColumn(
        'day_idx',
        F.when(F.col('transaction_date').cast('int').isNotNull(), 
               F.col('transaction_date').cast('int'))
        .otherwise(F.datediff(F.col('transaction_date'),
                              F.min('transaction_date').over(whole_dataset)))
    )
    
    # Join original and DP data. Joining on column *names* makes Spark emit a
    # single key column per name (coalesced across both sides for an outer
    # join), so cells present on only one side keep their real keys.
    join_keys = ['province_code', 'city_code', 'mcc', 'day_idx']
    orig_metric_cols = ['orig_transaction_count', 'orig_unique_cards', 'orig_total_amount']
    dp_metric_cols = ['transaction_count', 'unique_cards', 'transaction_amount_sum']
    
    joined_data = original_cells.join(dp_cells, on=join_keys, how='outer')
    
    # Convert to Pandas for visualization (should be manageable size after aggregation).
    # Only the metric columns are zero-filled; filling the keys as well would
    # collapse one-sided cells onto key 0.
    viz_df = (
        joined_data
        .select(*join_keys, *orig_metric_cols, *dp_metric_cols)
        .fillna(0, subset=orig_metric_cols + dp_metric_cols)
        .toPandas()
    )
    
    # Rename DP columns so every metric has an orig_*/dp_* pair
    viz_df.rename(columns={
        'transaction_count': 'dp_transaction_count',
        'unique_cards': 'dp_unique_cards',
        'transaction_amount_sum': 'dp_total_amount'
    }, inplace=True)
    
    print(f"‚úÖ Data prepared: {len(viz_df):,} cells loaded")
    
    # Get unique values for filters
    provinces = sorted(viz_df['province_code'].unique())
    cities = sorted(viz_df['city_code'].unique())
    mccs = sorted(viz_df['mcc'].unique())
    days = sorted(viz_df['day_idx'].unique())
    
    print(f"  Provinces: {len(provinces)}")
    print(f"  Cities: {len(cities)}")
    print(f"  MCCs: {len(mccs)}")
    print(f"  Days: {len(days)}")
    
    # =========================================================================
    # STEP 2: VISUALIZATION FUNCTION
    # =========================================================================
    
    def create_3d_surface_comparison(
        df, 
        x_axis='day_idx', 
        y_axis='mcc', 
        metric='transaction_count',
        province_filter=None,
        month_filter=None
    ):
        """
        Create side-by-side 3D surface plots comparing original and DP-protected data.
        
        The frame is optionally restricted to one province, summed over the
        dimension not used as an axis, pivoted into a (y, x) grid with missing
        cells set to 0, and drawn as two surfaces sharing one colour range.
        Error statistics (DP minus original) are shown in an annotation.
        
        Args:
            df: DataFrame with joined original and DP data; must contain the
                axis columns plus 'orig_<metric>' and 'dp_<metric>'
            x_axis: Column name for X-axis ('city_code', 'mcc', 'day_idx')
            y_axis: Column name for Y-axis ('city_code', 'mcc', 'day_idx')
            metric: Metric to visualize ('transaction_count', 'unique_cards', 'total_amount')
            province_filter: Province code to filter (None = all provinces)
            month_filter: Accepted for interface compatibility but currently
                NOT applied - no month column is read by this function
        
        Returns:
            Plotly figure object
        
        Raises:
            ValueError: If x_axis and y_axis are the same column, or if no
                rows remain to plot after filtering.
        """
        if x_axis == y_axis:
            raise ValueError(f"x_axis and y_axis must be different columns (both are {x_axis!r})")
        
        # Boolean indexing already returns a new frame, and nothing below
        # mutates the input, so no defensive copy is needed.
        filtered_df = df
        if province_filter is not None:
            filtered_df = df[df['province_code'] == province_filter]
        
        # Dimension(s) not used as an axis are summed away; sorted so the
        # title is deterministic even for non-standard axis names.
        all_dims = {'city_code', 'mcc', 'day_idx'}
        agg_dims = sorted(all_dims - {x_axis, y_axis})
        agg_label = ', '.join(dim.replace("_", " ") for dim in agg_dims)
        
        orig_col = f'orig_{metric}'
        dp_col = f'dp_{metric}'
        
        grouped = filtered_df.groupby([x_axis, y_axis]).agg({
            orig_col: 'sum',
            dp_col: 'sum'
        }).reset_index()
        
        # Fail with a clear message instead of the opaque "zero-size array"
        # error that min()/max() would raise further down.
        if grouped.empty:
            raise ValueError(
                f"No data to plot for province_filter={province_filter!r}"
            )
        
        # Pivot to (y, x) grids; combinations absent from the data become 0
        pivot_orig = grouped.pivot(index=y_axis, columns=x_axis, values=orig_col).fillna(0)
        pivot_dp = grouped.pivot(index=y_axis, columns=x_axis, values=dp_col).fillna(0)
        
        z_orig = pivot_orig.values
        z_dp = pivot_dp.values
        x_vals = pivot_orig.columns.values
        y_vals = pivot_orig.index.values
        
        # Error metrics over cells where either surface is positive
        valid_mask = (z_orig > 0) | (z_dp > 0)
        if valid_mask.any():
            errors = z_dp[valid_mask] - z_orig[valid_mask]
            rmse = np.sqrt(np.mean(errors**2))
            mae = np.mean(np.abs(errors))
            max_error = np.max(np.abs(errors))
            bias = np.mean(errors)
            
            # Relative error is only defined where the original value is non-zero
            orig_valid = z_orig[valid_mask]
            nonzero_mask = orig_valid > 0
            if nonzero_mask.any():
                rel_errors = np.abs(errors[nonzero_mask]) / orig_valid[nonzero_mask] * 100
                mean_rel_error = np.mean(rel_errors)
            else:
                mean_rel_error = 0
        else:
            rmse = mae = max_error = bias = mean_rel_error = 0
        
        # Create subplot figure (1 row, 2 columns)
        fig = make_subplots(
            rows=1, cols=2,
            subplot_titles=('Original Data', 'DP-Protected Data'),
            specs=[[{'type': 'surface'}, {'type': 'surface'}]],
            horizontal_spacing=0.05
        )
        
        # Shared colour range so the two surfaces are directly comparable
        vmin = min(z_orig.min(), z_dp.min())
        vmax = max(z_orig.max(), z_dp.max())
        
        # Color scale: professional scientific palette
        colorscale = 'Viridis'  # or 'Plasma', 'Inferno', 'Turbo'
        
        hover = (
            f'{x_axis}: %{{x}}<br>'
            f'{y_axis}: %{{y}}<br>'
            'Value: %{z:,.0f}<br>'
            '<extra></extra>'
        )
        
        # (subplot column, trace name, z grid, trace-specific options);
        # only the right-hand surface carries the shared colour bar.
        surfaces = (
            (1, 'Original', z_orig, {'showscale': False}),
            (2, 'DP-Protected', z_dp, {
                'showscale': True,
                'colorbar': dict(
                    title=metric.replace('_', ' ').title(),
                    x=1.02
                ),
            }),
        )
        for subplot_col, trace_name, z_grid, extra in surfaces:
            fig.add_trace(
                go.Surface(
                    z=z_grid,
                    x=x_vals,
                    y=y_vals,
                    colorscale=colorscale,
                    cmin=vmin,
                    cmax=vmax,
                    hovertemplate=hover,
                    name=trace_name,
                    **extra
                ),
                row=1, col=subplot_col
            )
        
        # Update layout
        axis_labels = {
            'city_code': 'City Code',
            'mcc': 'MCC',
            'day_idx': 'Day Index'
        }
        
        title_parts = [f'3D Surface: {metric.replace("_", " ").title()}']
        if province_filter is not None:
            title_parts.append(f'Province {province_filter}')
        title_parts.append(f'Aggregated over {agg_label}')
        title_parts.append(f'(œÅ={float(config.privacy.total_rho):.3f})')
        
        def scene_layout():
            # Fresh dict per scene so the two subplots never share state
            return dict(
                xaxis_title=axis_labels.get(x_axis, x_axis),
                yaxis_title=axis_labels.get(y_axis, y_axis),
                zaxis_title='Value',
                camera=dict(eye=dict(x=1.5, y=1.5, z=1.3))
            )
        
        fig.update_layout(
            title=dict(
                text=' | '.join(title_parts),
                font=dict(size=14)
            ),
            scene=scene_layout(),
            scene2=scene_layout(),
            height=600,
            width=1400,
            showlegend=False
        )
        
        # Add statistical annotation
        annotation_text = (
            f'<b>Statistical Metrics:</b><br>'
            f'RMSE: {rmse:,.2f} | '
            f'MAE: {mae:,.2f} | '
            f'Max Error: {max_error:,.0f} | '
            f'Bias: {bias:,.2f} | '
            f'Mean Rel. Error: {mean_rel_error:.1f}%<br>'
            f'Cells: {valid_mask.sum():,} | '
            f'Province-Month Total: EXACT (public invariant)'
        )
        
        fig.add_annotation(
            text=annotation_text,
            xref="paper", yref="paper",
            x=0.5, y=-0.05,
            showarrow=False,
            font=dict(size=11),
            align='center',
            xanchor='center'
        )
        
        return fig
    
    # =========================================================================
    # STEP 3: INTERACTIVE CONTROLS
    # =========================================================================
    
    print("\nüéõÔ∏è Creating interactive controls...")
    
    # Metric mapping
    metric_options = {
        'Transaction Count': 'transaction_count',
        'Unique Cards': 'unique_cards',
        'Total Amount': 'total_amount'
    }
    
    # Axis options
    axis_options = {
        'Day Index': 'day_idx',
        'MCC (Merchant Category)': 'mcc',
        'City Code': 'city_code'
    }
    
    # Create widgets
    x_axis_widget = widgets.Dropdown(
        options=list(axis_options.keys()),
        value='Day Index',
        description='X-Axis:',
        style={'description_width': '120px'}
    )
    
    y_axis_widget = widgets.Dropdown(
        options=list(axis_options.keys()),
        value='MCC (Merchant Category)',
        description='Y-Axis:',
        style={'description_width': '120px'}
    )
    
    metric_widget = widgets.Dropdown(
        options=list(metric_options.keys()),
        value='Transaction Count',
        description='Metric:',
        style={'description_width': '120px'}
    )
    
    province_widget = widgets.Dropdown(
        options=[('All Provinces', None)] + [(f'Province {p}', p) for p in provinces],
        value=None,
        description='Province:',
        style={'description_width': '120px'}
    )
    
    # Update button
    update_button = widgets.Button(
        description='Update Visualization',
        button_style='primary',
        icon='refresh'
    )
    
    # Output widget for the plot
    output_widget = widgets.Output()
    
    # =========================================================================
    # STEP 4: UPDATE FUNCTION
    # =========================================================================
    
    def update_plot(b=None):
        """Redraw the 3D comparison inside the output widget.

        Reads the current dropdown selections, translates the display labels
        to column names, and renders a new figure. `b` is the button instance
        passed by `on_click`; it is not used.
        """
        with output_widget:
            output_widget.clear_output(wait=True)
            
            # Current selections (display labels, except the province value)
            x_label = x_axis_widget.value
            y_label = y_axis_widget.value
            metric_label = metric_widget.value
            province_val = province_widget.value
            
            # Translate labels to column names / metric suffix
            x_axis = axis_options[x_label]
            y_axis = axis_options[y_label]
            metric = metric_options[metric_label]
            
            # A surface needs two distinct axes
            if x_axis == y_axis:
                print("‚ö†Ô∏è X-axis and Y-axis must be different. Please select different axes.")
                return
            
            # Any failure is reported inside the output area rather than raised
            try:
                create_3d_surface_comparison(
                    viz_df,
                    x_axis=x_axis,
                    y_axis=y_axis,
                    metric=metric,
                    province_filter=province_val
                ).show()
            except Exception as e:
                print(f"‚ùå Error creating visualization: {e}")
                import traceback
                traceback.print_exc()
    # Redraw whenever the button is clicked
    update_button.on_click(update_plot)
    
    # =========================================================================
    # STEP 5: DISPLAY INTERFACE
    # =========================================================================
    
    # Usage banner, emitted as one newline-joined block
    print("\n".join([
        "‚úÖ Visualization ready!",
        "\n" + "="*70,
        "INTERACTIVE CONTROLS",
        "="*70,
        "Configure the visualization parameters below and click 'Update Visualization'",
        "\nNote: The third dimension (not selected for X or Y) will be aggregated.",
        "Province-month totals are EXACT (public data) - noise is at cell level.",
        "="*70,
    ]))
    
    # Controls: two rows of dropdowns with the update button underneath
    display(HTML("<h3>Visualization Configuration</h3>"))
    axis_row = widgets.HBox([x_axis_widget, y_axis_widget])
    filter_row = widgets.HBox([metric_widget, province_widget])
    display(widgets.VBox([axis_row, filter_row, update_button]))
    
    # Area the figure is rendered into
    display(output_widget)
    
    # Draw once with the default selections
    print("\nüìä Generating initial visualization...")
    update_plot()


In [None]:
    # =========================================================================
    # B. PRIVACY GUARANTEE VERIFICATION
    # =========================================================================
    # Compares the observed noise standard deviation of the count query with
    # the value predicted for the Gaussian mechanism under rho-zCDP,
    # sigma^2 = sensitivity^2 / (2 * rho).
    print("\n" + "="*70, "B. PRIVACY GUARANTEE VERIFICATION", "="*70, sep="\n")
    
    privacy_cfg = config.privacy
    
    # Privacy parameters
    total_rho = float(privacy_cfg.total_rho)
    delta = privacy_cfg.delta
    
    # Sensitivity inputs: prefer values stored on the config, then the
    # notebook-level COMPUTED_* globals; if any referenced global is undefined,
    # fall back to 1 for everything not stored on the config.
    try:
        d_max_val = privacy_cfg.computed_d_max or COMPUTED_D_MAX
        k_val = privacy_cfg.computed_contribution_bound or COMPUTED_K
        m_val = COMPUTED_M
    except NameError:
        m_val = 1
        d_max_val = privacy_cfg.computed_d_max or 1
        k_val = privacy_cfg.computed_contribution_bound or 1
    
    # L2 sensitivity: sqrt(M * D_max) * K for counts; K = 1 for unique cards
    # because each card contributes at most 1 per cell.
    cell_day_root = np.sqrt(m_val * d_max_val)
    l2_sens_count = cell_day_root * k_val
    l2_sens_unique = cell_day_root * 1
    
    # Share of the budget spent on the count query at city level
    rho_per_query = total_rho * privacy_cfg.query_split.get('transaction_count', 0.34)
    rho_per_query_city = rho_per_query * privacy_cfg.geographic_split.get('city', 0.8)
    
    # Theoretical noise: sigma^2 = sensitivity^2 / (2 * rho)
    theoretical_var_count = (l2_sens_count ** 2) / (2 * rho_per_query_city)
    theoretical_std_count = np.sqrt(theoretical_var_count)
    
    print("\nüîê Privacy Parameters:")
    print(f"  Total œÅ (zCDP): {total_rho}")
    print(f"  Œ¥: {delta}")
    # rho-zCDP -> (epsilon, delta)-DP: epsilon = rho + 2 * sqrt(rho * ln(1/delta))
    print(f"  Œµ (converted): {total_rho + 2 * np.sqrt(total_rho * np.log(1/delta)):.2f}")
    
    print(
        "\nüéØ Sensitivity Analysis:",
        f"  M (max cells per card): {m_val}",
        f"  D_max (max days per card): {d_max_val}",
        f"  K (contribution bound): {k_val}",
        f"  L2 Sensitivity (count): {l2_sens_count:.2f}",
        f"  L2 Sensitivity (unique): {l2_sens_unique:.2f}",
        sep="\n"
    )
    
    print("\nüìä Noise Calibration Verification:")
    print(f"  œÅ per query (city level): {rho_per_query_city:.6f}")
    print(f"  Theoretical œÉ (count): {theoretical_std_count:.2f}")
    print(f"  Observed œÉ (count): {error_stats['std_count']:.2f}")
    
    # Ratio of observed to theoretical standard deviation; accepted within a
    # factor of 2 either way (bounds 0.5 and 2.0) to leave room for the
    # post-processing steps mentioned in the warning text.
    if theoretical_std_count > 0:
        var_ratio = error_stats['std_count'] / theoretical_std_count
    else:
        var_ratio = float('inf')
    var_check = 0.5 <= var_ratio <= 2.0
    
    print(f"  Ratio (observed/theoretical): {var_ratio:.2f}")
    print(f"  Variance Check: {'‚úÖ PASS' if var_check else '‚ö†Ô∏è WARNING (post-processing may affect variance)'}")
    
    # =========================================================================
    # C. COMPOSITION VERIFICATION
    # =========================================================================
    # Checks that the per-(geographic level, query) budgets add back up to the
    # total: under zCDP composition the rho values are additive.
    print("\n" + "="*70)
    print("C. COMPOSITION THEOREM VERIFICATION")
    print("="*70)
    
    geo_weights = config.privacy.geographic_split
    query_weights = config.privacy.query_split
    
    print(f"\nüìê Budget Composition:")
    print(f"  Geographic levels: {list(geo_weights.keys())}")
    print(f"  Queries: {list(query_weights.keys())}")
    
    # Each weight family should sum to 1
    geo_sum = sum(geo_weights.values())
    query_sum = sum(query_weights.values())
    
    print(f"\n  Geographic weights sum: {geo_sum:.4f} (should = 1.0)")
    print(f"  Query weights sum: {query_sum:.4f} (should = 1.0)")
    
    # Compute every allocation once, then reuse it for the listing and the total
    allocations = {
        (geo_level, query): total_rho * geo_w * query_w
        for geo_level, geo_w in geo_weights.items()
        for query, query_w in query_weights.items()
    }
    
    print(f"\n  Budget Allocation:")
    for (geo_level, query), allocated_rho in allocations.items():
        print(f"    {geo_level}/{query}: œÅ = {allocated_rho:.6f}")
    
    # Verify total. The tolerance is relative to the budget size so that
    # floating-point rounding on a large rho is not reported as an
    # invalid composition.
    total_allocated = sum(allocations.values())
    composition_tolerance = 1e-9 * max(1.0, abs(total_rho))
    composition_valid = abs(total_allocated - total_rho) <= composition_tolerance
    
    print(f"\n  Total allocated: œÅ = {total_allocated:.6f}")
    print(f"  Original budget: œÅ = {total_rho:.6f}")
    print(f"  Composition: {'‚úÖ VALID' if composition_valid else '‚ùå INVALID'}")


In [None]:
    # =========================================================================
    # D. UTILITY BY COUNT SIZE (Census 2020 Style Analysis)
    # =========================================================================
    # Groups cells by the size of their original count and reports the error
    # statistics of the DP count within each size bucket.
    print("\n" + "="*70)
    print("D. UTILITY BY COUNT SIZE (Census 2020 Analysis)")
    print("="*70)
    
    # Stratify by original count size
    stratified = errors_df.withColumn(
        'count_bucket',
        F.when(F.col('orig_count') == 0, '0 (empty)')
        .when(F.col('orig_count') <= 5, '1-5 (small)')
        .when(F.col('orig_count') <= 20, '6-20 (medium)')
        .when(F.col('orig_count') <= 100, '21-100 (large)')
        .otherwise('>100 (very large)')
    )
    
    # Buckets cover disjoint count ranges, so sorting by the smallest count in
    # each bucket yields size order. Sorting by the label itself would be
    # lexicographic and put '21-100' before '6-20'.
    bucket_stats = stratified.groupBy('count_bucket').agg(
        F.count('*').alias('n_cells'),
        F.mean('error_count').alias('mean_error'),
        F.stddev('error_count').alias('std_error'),
        F.mean(F.abs('error_count')).alias('mae'),
        F.mean(
            F.when(F.col('orig_count') > 0, 
                   F.abs(F.col('error_count')) / F.col('orig_count') * 100)
            .otherwise(None)
        ).alias('mape'),
        F.min('orig_count').alias('bucket_min_count')
    ).orderBy('bucket_min_count')
    
    def _fmt_stat(value, spec):
        """Format a statistic, or return 'N/A' when the aggregate is NULL."""
        return format(value, spec) if value is not None else "N/A"
    
    print("\nüìä Error by Original Count Size:")
    print("-" * 80)
    print(f"{'Bucket':<20} {'N Cells':>10} {'Mean Err':>12} {'Std Err':>12} {'MAE':>10} {'MAPE %':>10}")
    print("-" * 80)
    
    for row in bucket_stats.collect():
        # The sample standard deviation is NULL for a bucket with one cell,
        # so every statistic is formatted defensively.
        mean_str = _fmt_stat(row['mean_error'], '.2f')
        std_str = _fmt_stat(row['std_error'], '.2f')
        mae_str = _fmt_stat(row['mae'], '.2f')
        mape_str = _fmt_stat(row['mape'], '.1f')
        print(f"{row['count_bucket']:<20} {row['n_cells']:>10,} {mean_str:>12} "
              f"{std_str:>12} {mae_str:>10} {mape_str:>10}")
    
    # =========================================================================
    # E. RESEARCH READINESS SUMMARY
    # =========================================================================
    # Collects the pass/fail outcome of the validation steps into one checklist.
    print("\n" + "="*70)
    print("E. RESEARCH READINESS SUMMARY")
    print("="*70)
    
    research_checks = []
    
    # 1. Unbiasedness: mean count error must stay below 1 in absolute value
    bias_ok = abs(error_stats['bias_count']) < 1.0  # Allow small bias
    research_checks.append(('Unbiased Mechanism', bias_ok))
    
    # 2. Variance calibration (computed in section B)
    research_checks.append(('Variance Calibration', var_check))
    
    # 3. Composition validity (computed in section C)
    research_checks.append(('Budget Composition', composition_valid))
    
    # 4. Reasonable utility (MAPE < 50% for medium+ cells).
    # Row count and mean come from a single aggregate, so only one Spark job
    # runs; the check is skipped when there are no medium+ cells.
    medium_plus = stratified.filter(F.col('orig_count') > 5)
    n_medium_plus, avg_mape = medium_plus.agg(
        F.count('*'),
        F.mean(F.abs(F.col('error_count')) / F.col('orig_count') * 100)
    ).collect()[0]
    if n_medium_plus > 0:
        utility_ok = avg_mape is not None and avg_mape < 50
        research_checks.append(('Reasonable Utility (MAPE<50%)', utility_ok))
    
    # 5. No systematic errors: mean unique-card error below 1 in absolute value
    systematic_ok = abs(error_stats['bias_unique']) < 1.0
    research_checks.append(('No Systematic Errors', systematic_ok))
    
    print("\nüìã Research Validation Checklist:")
    for check_name, passed in research_checks:
        status = '‚úÖ' if passed else '‚ùå'
        print(f"  {status} {check_name}")
    
    all_research_passed = all(passed for _, passed in research_checks)
    
    print("\n" + "="*70)
    if all_research_passed:
        print("üéì RESEARCH READY: This DP implementation passes Census 2020-style validation.")
        print("   The methodology is suitable for academic research and publication.")
    else:
        print("‚ö†Ô∏è NOT RESEARCH READY: Some validation checks failed.")
        print("   Review the failed checks before using for research.")
        failed = [check_name for check_name, passed in research_checks if not passed]
        print(f"   Failed: {', '.join(failed)}")
    print("="*70)


In [None]:
# Production readiness checklist: each check is recorded in `checks` as a
# (name, passed) pair and echoed immediately; a summary follows.
print("="*70, "PRODUCTION READINESS CHECKLIST", "="*70, sep="\n")

checks = []


def _icon(ok, fail_icon='‚ùå'):
    """Return the pass marker when `ok` is truthy, else `fail_icon`."""
    return '‚úÖ' if ok else fail_icon


# 1. Pipeline Success
check_1 = result['success']
checks.append(('Pipeline Execution', check_1))
print(f"\n{_icon(check_1)} Pipeline Execution: {'PASSED' if check_1 else 'FAILED'}")

# 2. Output Files Exist
protected_data_path = os.path.join(output_path, 'protected_data')
output_exists = os.path.exists(protected_data_path)
checks.append(('Output Files', output_exists))
print(f"{_icon(output_exists)} Output Files: {'EXIST' if output_exists else 'MISSING'}")

# 3. Metadata Present
metadata_exists = os.path.exists(os.path.join(output_path, 'metadata.json'))
checks.append(('Metadata', metadata_exists))
print(f"{_icon(metadata_exists)} Metadata: {'PRESENT' if metadata_exists else 'MISSING'}")

# 4. Budget Composition Valid: both weight families must sum to 1 (within 1e-6)
budget_valid = all(
    abs(sum(split.values()) - 1.0) < 1e-6
    for split in (config.privacy.geographic_split, config.privacy.query_split)
)
checks.append(('Budget Composition', budget_valid))
print(f"{_icon(budget_valid)} Budget Composition: {'VALID' if budget_valid else 'INVALID'}")

# 5. No Negative Counts (sanity check) - only possible when the output exists
if output_exists:
    dp_df = spark.read.parquet(protected_data_path)
    neg_counts = dp_df.filter(F.col('transaction_count') < 0).count()
    no_negative = neg_counts == 0
    checks.append(('No Negative Counts', no_negative))
    if no_negative:
        negative_text = 'PASSED'
    else:
        negative_text = f'{neg_counts} negative values'
    print(f"{_icon(no_negative, '‚ö†Ô∏è')} No Negative Counts: {negative_text}")

# 6. Reasonable Processing Time
reasonable_time = duration < 300  # 5 minutes for test data
checks.append(('Processing Time', reasonable_time))
print(f"{_icon(reasonable_time, '‚ö†Ô∏è')} Processing Time: {duration:.1f}s {'(OK)' if reasonable_time else '(SLOW)'}")

# Summary
all_passed = all(ok for _, ok in checks)
passed_count = len([name for name, ok in checks if ok])

print("\n" + "="*70)
print(f"SUMMARY: {passed_count}/{len(checks)} checks passed")
print("="*70)

if not all_passed:
    print("\n‚ö†Ô∏è NOT READY FOR PRODUCTION")
    print("   Please address the failed checks before deployment.")
    failed = [name for name, ok in checks if not ok]
    print(f"   Failed: {', '.join(failed)}")
else:
    print("\nüéâ PRODUCTION READY!")
    print("   The DP system has passed all checks and is ready for deployment.")


---
## 11. Cleanup & Summary


In [None]:
# Final cell: optional cleanup of generated files plus a run summary.
#
# Uncomment to clean up generated output files
# import shutil
# 
# if os.path.exists(config.data.output_path):
#     shutil.rmtree(config.data.output_path)
#     print(f"Removed: {config.data.output_path}")

print("="*70)
print("NOTEBOOK COMPLETE")
print("="*70)
print(f"\nTimestamp: {datetime.now().isoformat()}")
print(f"\nüìã Summary:")

# The thousands separator is only valid for numbers: applying ',' to the
# 'N/A' placeholder (or to None) raises, so format conditionally.
total_records = result.get('total_records')
if isinstance(total_records, (int, float)):
    records_text = f"{total_records:,}"
else:
    records_text = 'N/A'
print(f"  - Records processed: {records_text}")
print(f"  - Privacy budget: œÅ = {config.privacy.total_rho}")
print(f"  - Pipeline status: {'‚úÖ SUCCESS' if result['success'] else '‚ùå FAILED'}")

# `all_passed` only exists once the production readiness cell has run
try:
    production_status = '‚úÖ YES' if all_passed else '‚ùå NO'
    print(f"  - Production ready: {production_status}")
except NameError:
    print(f"  - Production ready: ‚ö†Ô∏è Run cell 22 to check production readiness")

print(f"\nüìä Output Metrics (with DP):")
print(f"  - transaction_count: Count of transactions per (date, city, mcc)")
print(f"  - unique_cards: Count of distinct cards per (date, city, mcc)")
print(f"  - transaction_amount_sum: Sum of amounts per (date, city, mcc)")
