# JSON Flattening Toolkit - Comprehensive Guide for Data Engineers & Scientists

> **A hands-on guide to JSON flattening techniques, patterns, and real-world applications**

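This notebook is organized into **10 milestones**, each focusing on specific aspects of JSON flattening. Work through them in order on a first pass: some later cells reuse helpers and variables (such as `run_scenario` and `records`) defined in earlier milestones. After that, you can jump to specific topics of interest.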

## 📚 Table of Contents

### Foundations
- **[Milestone 1: Foundations & Core Concepts](#milestone-1)** - Basic flattening, list policies, separators
- **[Milestone 2: Array Handling Strategies](#milestone-2)** - Index vs join, explosion, cartesian products

### Advanced Techniques  
- **[Milestone 3: Complex Structures](#milestone-3)** - Deep nesting, mixed types, null handling

### Real-World Use Cases
- **[Milestone 4: E-commerce Data](#milestone-4)** - Orders, products, customers, transactions
- **[Milestone 5: API & Event Data](#milestone-5)** - API responses, webhooks, event logs

### Data Pipelines
- **[Milestone 6: CSV Operations & Pipelines](#milestone-6)** - Read/write, transformations, batch processing

### Database Integration
- **[Milestone 7: MongoDB Integration](#milestone-7)** - Ingestion, querying, type inference
- **[Milestone 8: Snowflake Integration](#milestone-8)** - Schema generation, ingestion, queries

### Production Patterns
- **[Milestone 9: Advanced Patterns & Best Practices](#milestone-9)** - Performance, memory, error handling
- **[Milestone 10: End-to-End Workflows](#milestone-10)** - Complete pipelines, production examples

---

## 🎯 Learning Objectives

By the end of this notebook, you will be able to:
- ✅ Flatten complex nested JSON structures efficiently
- ✅ Choose appropriate array handling strategies for your use case
- ✅ Build data pipelines from JSON to CSV to databases
- ✅ Handle edge cases (nulls, empty arrays, mixed types)
- ✅ Integrate with MongoDB and Snowflake
- ✅ Apply best practices for production systems

## 🚀 Quick Start

Let's set up our environment and import the necessary modules.

In [None]:
# ============================================================================
# IMPORTS - All imports at the top for clarity
# ============================================================================

import json
import sys
import os
import time
from pathlib import Path
from datetime import datetime
from typing import Any, Dict, List, Optional
from collections import Counter

# Ensure local package is importable when running in Docker/Repo root
PROJECT_ROOT = Path.cwd()
if (PROJECT_ROOT / "json_flatten").exists():
    sys.path.insert(0, str(PROJECT_ROOT))
elif (PROJECT_ROOT / "mongodb" / "json_flatten").exists():
    sys.path.insert(0, str(PROJECT_ROOT / "mongodb"))

# Core flattening functions
from json_flatten import flatten_json, flatten_records, write_csv, read_csv
from json_flatten.scenarios import get_scenarios

# Optional: MongoDB and Snowflake (may not be available)
try:
    from json_flatten.mongodb_io import ingest_csv_to_mongodb, query_mongodb, infer_type
    MONGO_AVAILABLE = True
except ImportError:
    MONGO_AVAILABLE = False
    print("⚠ MongoDB integration not available (pymongo not installed)")

try:
    from json_flatten.snowflake_io import create_table_schema, ingest_csv_to_snowflake, query_snowflake
    SNOWFLAKE_AVAILABLE = True
except ImportError:
    SNOWFLAKE_AVAILABLE = False
    print("⚠ Snowflake integration not available (snowflake-connector-python not installed)")

# PySpark imports (for large-scale processing)
try:
    import findspark
    findspark.init()
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, explode, from_json, schema_of_json
    from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, BooleanType, ArrayType
    PYSPARK_AVAILABLE = True
except ImportError:
    PYSPARK_AVAILABLE = False
    print("⚠ PySpark not available (pyspark not installed)")

# Setup output directory
OUTPUT_DIR = Path("notebook_output")
OUTPUT_DIR.mkdir(exist_ok=True)

# Helper functions for pretty printing and analysis
def print_section(title: str, char: str = "="):
    """Print a formatted section header."""
    print(f"\n{char * 60}")
    print(f"  {title}")
    print(f"{char * 60}\n")

def compare_before_after(before: Any, after: Dict[str, Any], title: str = "Transformation"):
    """Compare original and flattened data side by side."""
    print_section(title)
    print("BEFORE (Original JSON):")
    print(json.dumps(before, indent=2))
    print("\nAFTER (Flattened):")
    print(json.dumps(after, indent=2))
    print(f"\n📊 Flattened to {len(after)} fields")

def measure_time(func):
    """Decorator to measure execution time."""
    def wrapper(*args, **kwargs):
        # perf_counter() is a monotonic, high-resolution clock intended for timing
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(f"⏱️  Execution time: {elapsed:.4f} seconds")
        return result
    return wrapper

# Initialize PySpark if available
if PYSPARK_AVAILABLE:
    spark = SparkSession.builder \
        .appName("JSONFlattening") \
        .config("spark.sql.adaptive.enabled", "true") \
        .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
        .getOrCreate()
    spark.sparkContext.setLogLevel("WARN")  # Reduce verbosity
    print("✅ PySpark session initialized")

print("✅ Environment setup complete!")
print(f"📁 Output directory: {OUTPUT_DIR.absolute()}")
print(f"🔧 MongoDB available: {MONGO_AVAILABLE}")
print(f"❄️  Snowflake available: {SNOWFLAKE_AVAILABLE}")
print(f"⚡ PySpark available: {PYSPARK_AVAILABLE}")

---

<a id="milestone-1"></a>

# Milestone 1: Foundations & Core Concepts

## Learning Objectives
- Understand the fundamental concept of JSON flattening
- Learn how nested structures are converted to flat dictionaries
- Explore different list handling policies
- Master custom separator usage

## Why Flatten JSON?

Nested JSON is awkward for most of the downstream tools that data engineers and data scientists rely on:
- **Tabular formats** (CSV, databases) require flat structures
- **Analytics tools** work better with normalized data
- **Schema inference** is easier with flat structures
- **Database ingestion** requires consistent column structures

Let's start with the basics!

### 1.1 Understanding Nested Structures

**What is nesting?**  
Nesting occurs when JSON objects contain other objects or arrays inside them. Think of it like Russian dolls - objects within objects.

**Why is this a problem?**  
- Databases and CSV files store rows and columns, so every nested value needs a column of its own
- Analytics tools and schema inference have to guess how to treat objects and arrays embedded in a single field

**How does flattening work?**  
The `flatten_json()` function recursively traverses nested structures and creates dot-delimited keys. For example:
- `user.profile.name` represents the `name` field inside `profile` inside `user`
- The dot (`.`) is the default separator, but you can customize it

Let's see this in action:

In [None]:
# Example 1: Simple nested structure
data1 = {
    "user": {
        "id": 42,
        "profile": {
            "name": "Alice",
            "active": True
        }
    },
    "score": 9.5
}

flattened1 = flatten_json(data1)
compare_before_after(data1, flattened1, "Example 1: Simple Nested Structure")

### 1.2 Custom Separators

**Why use custom separators?**  
Sometimes the default dot (`.`) separator can conflict with your data:
- Field names might contain dots
- You might prefer underscores (`_`) or double underscores (`__`)
- Some systems have naming conventions

**Example use cases:**
- MongoDB uses dots for nested queries, so you might want `_` instead
- Some databases prefer `__` for clarity
- Your organization might have specific naming standards

Let's explore different separators:
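
The toolkit's own separator argument is not shown in this section, so the following is a minimal, self-contained sketch of the idea rather than the `json_flatten` API. The function `flatten_sketch` and its `sep` parameter are hypothetical names used only for illustration. Lists are indexed by position, and empty objects or lists simply produce no keys.

```python
from typing import Any, Dict


def flatten_sketch(obj: Any, sep: str = ".", prefix: str = "") -> Dict[str, Any]:
    """Recursively flatten dicts and lists into a single-level dict (illustrative only)."""
    flat: Dict[str, Any] = {}
    if isinstance(obj, dict):
        children = obj.items()
    elif isinstance(obj, list):
        children = enumerate(obj)
    else:
        # Leaf value: store it under the path built so far.
        flat[prefix] = obj
        return flat
    for key, value in children:
        path = f"{prefix}{sep}{key}" if prefix else str(key)
        flat.update(flatten_sketch(value, sep=sep, prefix=path))
    return flat


doc = {"user": {"id": 42, "profile": {"name": "Alice"}}, "tags": ["a", "b"]}

dotted = flatten_sketch(doc)              # keys like "user.profile.name"
underscored = flatten_sketch(doc, "__")   # keys like "user__profile__name"

# A field name that already contains a dot is ambiguous with "." but not with "__".
tricky = {"a.b": 1, "a": {"b": 2}}
collides = flatten_sketch(tricky)         # both values map to "a.b"; the later one wins
distinct = flatten_sketch(tricky, "__")   # "a.b" and "a__b" stay separate
```

The `tricky` case is the main practical argument for a non-dot separator: with `.` the two source fields silently merge into one column, while `__` keeps them apart.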

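## Array Handling

Arrays can be handled in three ways:
- **Index policy**: Creates indexed keys (e.g., `tags.0`, `tags.1`)
- **Join policy**: Joins primitive arrays with commas
- **Explosion**: Creates one record per array element via `flatten_records(..., explode_paths=[...])`, as the `data5` example in this section shows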
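
To make the difference between the index and join policies concrete, here is a small stand-alone sketch for a list of primitives. The helper names `index_policy` and `join_policy`, and the choice to render `None` as an empty string, are assumptions for illustration; the toolkit's actual output may differ in detail.

```python
from typing import Any, Dict, List


def index_policy(key: str, values: List[Any], sep: str = ".") -> Dict[str, Any]:
    """One output key per element: tags.0, tags.1, ..."""
    return {f"{key}{sep}{i}": v for i, v in enumerate(values)}


def join_policy(key: str, values: List[Any], delimiter: str = ",") -> Dict[str, str]:
    """One output key holding all elements as a delimited string."""
    # Assumption: None becomes an empty string; a real implementation may differ.
    return {key: delimiter.join("" if v is None else str(v) for v in values)}


tags = ["alpha", "beta", "gamma"]
indexed = index_policy("tags", tags)
joined = join_policy("tags", tags)
```

Indexing keeps each element's type and position but makes the column count depend on the longest list. Joining gives one stable column, at the cost of turning everything into text and becoming ambiguous when a value contains the delimiter.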

In [None]:
# Example: Array explosion - creating multiple records
data5 = {
    "order_id": 1001,
    "customer": "Alice",
    "items": [
        {"sku": "A1", "qty": 2, "price": 10.50},
        {"sku": "B2", "qty": 1, "price": 5.25},
        {"sku": "C3", "qty": 3, "price": 8.00}
    ]
}

records = flatten_records(data5, explode_paths=["items"])
print(f"Created {len(records)} records from array explosion:")
for i, record in enumerate(records, 1):
    print(f"\nRecord {i}:")
    print(json.dumps(record, indent=2))

## CSV Operations

Converting flattened JSON to CSV format for database ingestion or analysis.

In [None]:
# Create output directory
output_dir = Path("notebook_output")
output_dir.mkdir(exist_ok=True)

# Flatten and write to CSV
sample_data = {
    "order_id": 1001,
    "customer": {"name": "Alice", "email": "alice@example.com"},
    "items": [
        {"sku": "A1", "qty": 2},
        {"sku": "B2", "qty": 1}
    ]
}

records = flatten_records(sample_data, explode_paths=["items"])
csv_path = output_dir / "orders.csv"
write_csv(records, csv_path)

print(f"✓ Wrote {len(records)} records to {csv_path}")
print("\nCSV content:")
print(csv_path.read_text())

## Milestone 1 Summary

The toolkit covers the capabilities below. Milestone 1 touched on flattening, array handling, and CSV conversion; later milestones cover the rest:
1. **Flattening complex JSON structures** with configurable policies
2. **Handling arrays** through indexing or explosion
3. **Creating cartesian products** from multiple array paths
4. **CSV conversion** for tabular data formats
5. **Database ingestion** into MongoDB and Snowflake

See README.md for complete documentation.

---

<a id="milestone-2"></a>

# Milestone 2: Array Handling Strategies

## Learning Objectives
- Compare index vs join list policies
- Understand array explosion into multiple records
- Create cartesian products across multiple array paths

Arrays are where flattening decisions have the biggest downstream impact. We'll compare policies and then explode arrays into multiple records.

In [None]:
print_section("Index vs Join list policies")

array_data = {
    "tags": ["alpha", "beta", "gamma"],
    "metrics": {"scores": [10, 20, None]},
    "meta": {"ids": [1, 2, 3]},
}

flatten_index = flatten_json(array_data, list_policy="index")
flatten_join = flatten_json(array_data, list_policy="join")

print("Index policy output:")
print(json.dumps(flatten_index, indent=2))
print("\nJoin policy output:")
print(json.dumps(flatten_join, indent=2))

In [None]:
def get_scenario_by_name(name: str):
    scenarios = {scenario.name: scenario for scenario in get_scenarios()}
    if name not in scenarios:
        raise KeyError(f"Scenario {name!r} not found")
    return scenarios[name]


def run_scenario(name: str, max_records: int = 3):
    scenario = get_scenario_by_name(name)
    print_section(f"Scenario: {scenario.name}")
    print(scenario.description)

    if scenario.mode == "records":
        records = flatten_records(
            scenario.data,
            explode_paths=scenario.explode_paths,
            list_policy=scenario.list_policy,
        )
        print(f"Records: {len(records)}")
        for record in records[:max_records]:
            print(json.dumps(record, indent=2))
        if len(records) > max_records:
            print(f"... {len(records) - max_records} more")
        return records

    flattened = flatten_json(scenario.data, list_policy=scenario.list_policy)
    print(json.dumps(flattened, indent=2))
    return flattened

In [None]:
print_section("Array explosion and cartesian products")

scenario = get_scenario_by_name("multi_path_explosion")
records = flatten_records(
    scenario.data,
    explode_paths=scenario.explode_paths,
    list_policy=scenario.list_policy,
)

print(f"Exploded to {len(records)} records")
for record in records:
    print(json.dumps(record, indent=2))
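
Exploding several array paths at once amounts to taking the cartesian product of their elements. The sketch below shows that idea with `itertools.product` for top-level list fields only. `explode_sketch` is a hypothetical helper, and treating an empty or missing list as a single `None` is an assumption, not necessarily what `flatten_records` does.

```python
from itertools import product
from typing import Any, Dict, List


def explode_sketch(record: Dict[str, Any], paths: List[str]) -> List[Dict[str, Any]]:
    """Return one row per combination of elements across the given top-level list fields."""
    base = {k: v for k, v in record.items() if k not in paths}
    # Assumption: a missing or empty list contributes a single None so the row is kept.
    choices = [record.get(p) or [None] for p in paths]
    rows = []
    for combo in product(*choices):
        row = dict(base)
        row.update(zip(paths, combo))
        rows.append(row)
    return rows


order = {
    "order_id": 1001,
    "items": ["A1", "B2", "C3"],
    "discounts": ["SPRING", "VIP"],
}
rows = explode_sketch(order, ["items", "discounts"])
print(len(rows))  # 3 items x 2 discounts = 6 rows
```

Because row counts multiply, each extra exploded path scales the output by that path's length, which is why the choice of `explode_paths` deserves care.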

---

<a id="milestone-3"></a>

# Milestone 3: Complex Structures

## Learning Objectives
- Handle deep nesting and mixed types
- Process nulls, empty arrays, and optional fields
- Work with nested arrays inside arrays

These scenarios mirror real data engineering edge cases.

In [None]:
run_scenario("deep_nesting")
run_scenario("mixed_types")
run_scenario("empty_and_null_handling")
run_scenario("nested_arrays", max_records=2)

---

<a id="milestone-4"></a>

# Milestone 4: E-commerce Data

## Learning Objectives
- Flatten orders with line items
- Create cartesian combinations across items and discounts
- Preserve customer metadata

We'll use the built-in `data/orders.json` and enrich it with customer fields.

In [None]:
orders_path = Path("data/orders.json")
orders = json.loads(orders_path.read_text())

orders["customer"] = {
    "id": "cust_001",
    "name": "Ada Lovelace",
    "segment": "enterprise",
}

records = flatten_records(orders, explode_paths=["items", "discounts"])
print(f"Created {len(records)} order records")
for record in records:
    print(json.dumps(record, indent=2))

---

<a id="milestone-5"></a>

# Milestone 5: API & Event Data

## Learning Objectives
- Flatten nested API responses
- Handle event log arrays
- Normalize timestamps for analytics

In [None]:
api_response = {
    "request_id": "req_123",
    "status": "ok",
    "data": {
        "user": {"id": 7, "name": "Grace"},
        "roles": ["admin", "editor"],
        "metadata": {"source": "web", "region": "us-east-1"},
    },
}

flattened_api = flatten_json(api_response, list_policy="join")
compare_before_after(api_response, flattened_api, "API Response Flattening")

print_section("Event log normalization")
event_payload = {
    "service": "billing",
    "events": [
        {"type": "created", "timestamp": "2024-01-15T10:30:00Z", "amount": 45.5},
        {"type": "captured", "timestamp": "2024-01-15T10:31:05Z", "amount": 45.5},
    ],
}

records = flatten_records(event_payload, explode_paths=["events"])
print(f"Created {len(records)} event records")
for record in records:
    print(json.dumps(record, indent=2))
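
The cell above leaves timestamps as raw strings. One possible way to normalize them for analytics is sketched below using only the standard library. The helper `normalize_timestamp`, the flattened key name `events.timestamp`, and the rule that naive timestamps are already UTC are all assumptions for illustration.

```python
from datetime import datetime, timezone


def normalize_timestamp(value: str) -> str:
    """Parse an ISO 8601 string and return it in UTC with an explicit offset."""
    # datetime.fromisoformat() only accepts a trailing "Z" from Python 3.11 on,
    # so replace it with an explicit offset first.
    parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    if parsed.tzinfo is None:
        # Assumption: naive timestamps are already in UTC.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc).isoformat()


rows = [
    {"events.type": "created", "events.timestamp": "2024-01-15T10:30:00Z"},
    {"events.type": "captured", "events.timestamp": "2024-01-15T11:31:05+01:00"},
]
for row in rows:
    row["events.timestamp"] = normalize_timestamp(row["events.timestamp"])
```

After this step every row carries the same `+00:00` offset, so string sorting and time-window joins behave consistently across sources.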

---

<a id="milestone-6"></a>

# Milestone 6: CSV Operations & Pipelines

## Learning Objectives
- Write flattened records to CSV
- Read CSV back into Python
- Build repeatable batch pipelines

In [None]:
scenario = get_scenario_by_name("list_of_objects_explode")
records = flatten_records(
    scenario.data,
    explode_paths=scenario.explode_paths,
    list_policy=scenario.list_policy,
)

csv_path = OUTPUT_DIR / "items_pipeline.csv"
write_csv(records, csv_path)

print(f"✓ Wrote {len(records)} records to {csv_path}")
print("\nRound-trip read:")
round_trip = read_csv(csv_path)
for row in round_trip:
    print(row)
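
For readers who want to see what a round trip involves without the toolkit, here is a rough equivalent built on the standard `csv` module. It is a sketch, not the implementation of `write_csv` or `read_csv`; `records_to_csv_text` is a hypothetical helper, and the sample keys are made up.

```python
import csv
import io
from typing import Any, Dict, List


def records_to_csv_text(records: List[Dict[str, Any]]) -> str:
    """Serialize flat records to CSV text, using the union of keys as the header."""
    fieldnames: List[str] = []
    for record in records:
        for key in record:
            if key not in fieldnames:
                fieldnames.append(key)  # keep first-seen order
    buffer = io.StringIO(newline="")
    # restval fills in cells for keys that a given record does not have.
    writer = csv.DictWriter(buffer, fieldnames=fieldnames, restval="")
    writer.writeheader()
    writer.writerows(records)
    return buffer.getvalue()


records = [
    {"order_id": 1001, "items.sku": "A1", "items.qty": 2},
    {"order_id": 1001, "items.sku": "B2", "items.qty": 1, "items.note": "gift"},
]
text = records_to_csv_text(records)
rows = list(csv.DictReader(io.StringIO(text, newline="")))
```

Two things are worth noticing in `rows`: every value comes back as a string, and keys missing from a record come back as empty strings. Any batch pipeline built on CSV has to decide how to restore types and nulls.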

---

<a id="milestone-7"></a>

# Milestone 7: MongoDB Integration

## Learning Objectives
- Ingest flattened records into MongoDB
- Query collections for analytics
- Understand type inference behavior

In [None]:
if MONGO_AVAILABLE:
    mongo_uri = os.getenv("MONGO_URI", "mongodb://localhost:27017")
    database_name = os.getenv("MONGO_DB", "json_flatten_demo")
    collection_name = os.getenv("MONGO_COLLECTION", "orders")

    print_section("MongoDB ingestion")
    try:
        inserted = ingest_csv_to_mongodb(
            records,
            mongo_uri=mongo_uri,
            database_name=database_name,
            collection_name=collection_name,
        )
        print(f"Inserted {inserted} documents into {database_name}.{collection_name}")

        sample = query_mongodb(
            mongo_uri=mongo_uri,
            database_name=database_name,
            collection_name=collection_name,
            limit=3,
        )
        print("Sample documents:")
        for doc in sample:
            print(doc)
    except Exception as exc:
        print(f"MongoDB ingestion skipped: {exc}")
else:
    print("MongoDB integration not available in this environment.")
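
Since CSV values arrive as strings, some form of type inference usually runs before documents are inserted. The sketch below shows one simple set of rules. It is not the `infer_type` function imported earlier, whose behavior is not shown in this notebook; `infer_value` and its rules are assumptions for illustration.

```python
from typing import Any


def infer_value(text: str) -> Any:
    """Best-effort conversion of a CSV string to None, bool, int, float, or str."""
    stripped = text.strip()
    if stripped == "":
        return None  # Assumption: empty cells mean "missing"
    lowered = stripped.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    try:
        return int(stripped)
    except ValueError:
        pass
    try:
        return float(stripped)
    except ValueError:
        return text  # leave anything else as the original string


row = {"order_id": "1001", "price": "10.50", "active": "True", "sku": "A1", "note": ""}
typed = {key: infer_value(value) for key, value in row.items()}
```

Rules like these are lossy for identifier-like columns: a value such as `"007"` becomes the integer `7`. Pipelines that care about such fields often pin their types explicitly instead of inferring them.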

---

<a id="milestone-8"></a>

# Milestone 8: Snowflake Integration

## Learning Objectives
- Generate Snowflake table schemas
- Ingest CSV data into Snowflake
- Query flattened data with SQL

In [None]:
if SNOWFLAKE_AVAILABLE:
    print_section("Snowflake schema generation")
    try:
        schema_sql = create_table_schema(records, "orders_flat", "public")
        print(schema_sql)

        # Optional: ingest if credentials are set in environment variables
        required_env = [
            "SNOWFLAKE_ACCOUNT",
            "SNOWFLAKE_USER",
            "SNOWFLAKE_PASSWORD",
            "SNOWFLAKE_WAREHOUSE",
            "SNOWFLAKE_DATABASE",
            "SNOWFLAKE_SCHEMA",
        ]
        if all(os.getenv(key) for key in required_env):
            count = ingest_csv_to_snowflake(
                records,
                account=os.getenv("SNOWFLAKE_ACCOUNT"),
                user=os.getenv("SNOWFLAKE_USER"),
                password=os.getenv("SNOWFLAKE_PASSWORD"),
                warehouse=os.getenv("SNOWFLAKE_WAREHOUSE"),
                database=os.getenv("SNOWFLAKE_DATABASE"),
                schema=os.getenv("SNOWFLAKE_SCHEMA"),
                table_name="orders_flat",
            )
            print(f"Ingested {count} rows into Snowflake")
        else:
            print("Snowflake credentials not set; ingestion skipped.")
    except Exception as exc:
        print(f"Snowflake integration skipped: {exc}")
else:
    print("Snowflake integration not available in this environment.")
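
Schema generation from flattened records can be approximated by mapping Python value types to column types. The sketch below is a coarse illustration and not the implementation of `create_table_schema`. The helpers `snowflake_type` and `create_table_sketch` are hypothetical, and the mapping covers only `BOOLEAN`, `NUMBER`, `FLOAT`, and `VARCHAR`.

```python
from typing import Any, Dict, List, Set


def snowflake_type(value: Any) -> str:
    """Map a Python value to a Snowflake column type (coarse, illustrative mapping)."""
    if isinstance(value, bool):  # check bool before int: bool is a subclass of int
        return "BOOLEAN"
    if isinstance(value, int):
        return "NUMBER"
    if isinstance(value, float):
        return "FLOAT"
    return "VARCHAR"


def create_table_sketch(records: List[Dict[str, Any]], table: str) -> str:
    """Build a CREATE TABLE statement from the first non-null value seen per key."""
    columns: Dict[str, str] = {}
    typed: Set[str] = set()
    for record in records:
        for key, value in record.items():
            columns.setdefault(key, "VARCHAR")  # fallback for columns that are always null
            if value is not None and key not in typed:
                columns[key] = snowflake_type(value)
                typed.add(key)
    lines = []
    for name, column_type in columns.items():
        # Double-quote identifiers so keys containing dots remain valid column names.
        quoted = '"' + name.replace('"', '""') + '"'
        lines.append(f"  {quoted} {column_type}")
    return f"CREATE TABLE {table} (\n" + ",\n".join(lines) + "\n);"


records = [
    {"order_id": 1001, "items.sku": "A1", "items.price": 10.5, "gift": None},
    {"order_id": 1002, "items.sku": "B2", "items.price": 5.25, "gift": True},
]
ddl = create_table_sketch(records, "orders_flat")
print(ddl)
```

Quoted identifiers are case-sensitive in Snowflake, so queries against such a table need to quote column names the same way. Choosing `_` or `__` as the flattening separator avoids the quoting altogether.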

---

<a id="milestone-9"></a>

# Milestone 9: Advanced Patterns & Best Practices

## Learning Objectives
- Measure performance for large workloads
- Avoid unintended cartesian explosions
- Leverage PySpark for scale when available

In [None]:
@measure_time
def flatten_large_cartesian():
    scenario = get_scenario_by_name("large_cartesian_product")
    return flatten_records(
        scenario.data,
        explode_paths=scenario.explode_paths,
        list_policy=scenario.list_policy,
    )

large_records = flatten_large_cartesian()
print(f"Generated {len(large_records)} records from cartesian explosion")

print_section("PySpark (optional)")
if PYSPARK_AVAILABLE:
    scenario = get_scenario_by_name("multi_path_explosion")
    df = spark.createDataFrame([scenario.data])
    df.select("order_id", explode("items").alias("item")) \
        .select("order_id", col("item.sku").alias("item_sku")) \
        .show(truncate=False)
else:
    print("PySpark not available; skipping Spark example.")
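
One inexpensive guard against unintended cartesian explosions is to estimate the output size before exploding anything. The sketch below multiplies list lengths for top-level paths and refuses to proceed past a limit. `estimate_rows`, `check_explosion`, and the default limit of 10,000 rows are hypothetical choices for illustration.

```python
from math import prod
from typing import Any, Dict, List


def estimate_rows(record: Dict[str, Any], paths: List[str]) -> int:
    """Rows a cartesian explosion would produce for one record (top-level paths only)."""
    # Assumption: an empty or missing list still yields one row, so it counts as 1.
    return prod(max(1, len(record.get(p) or [])) for p in paths)


def check_explosion(record: Dict[str, Any], paths: List[str], max_rows: int = 10_000) -> int:
    """Return the estimated row count, or raise if it exceeds max_rows."""
    rows = estimate_rows(record, paths)
    if rows > max_rows:
        raise ValueError(f"Exploding {paths} would create {rows} rows (limit {max_rows})")
    return rows


record = {
    "items": list(range(50)),
    "promos": list(range(20)),
    "regions": list(range(30)),
}
print(estimate_rows(record, ["items", "promos"]))  # 50 * 20 = 1000
```

With all three paths the same record would produce 50 * 20 * 30 = 30,000 rows, so `check_explosion(record, ["items", "promos", "regions"])` raises under the default limit.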

---

<a id="milestone-10"></a>

# Milestone 10: End-to-End Workflows

## Learning Objectives
- Build a complete JSON → CSV pipeline
- Validate round-trip data integrity
- Prepare data for database ingestion

In [None]:
sample_path = Path("data/sample.json")
sample = json.loads(sample_path.read_text())

flattened = flatten_json(sample, list_policy="index")
end_to_end_path = OUTPUT_DIR / "sample_flat.csv"
write_csv([flattened], end_to_end_path)

print(f"✓ Wrote flattened sample to {end_to_end_path}")
print("Round-trip read:")
print(read_csv(end_to_end_path))

if MONGO_AVAILABLE:
    print("You can now ingest this CSV into MongoDB with ingest_csv_to_mongodb().")
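
The cell above prints the round-trip result but does not check it. A simple integrity check might compare each original value with the text that came back, as sketched here. `roundtrip_mismatches` is a hypothetical helper, and it assumes `None` is written as an empty cell and every other value via `str()`.

```python
from typing import Any, Dict, List


def roundtrip_mismatches(original: Dict[str, Any], restored: Dict[str, str]) -> List[str]:
    """List the keys whose CSV-restored text differs from the original value."""
    problems = []
    for key in sorted(set(original) | set(restored)):
        if key not in original or key not in restored:
            problems.append(key)  # key appeared or disappeared during the round trip
            continue
        # Assumption: None is written as an empty cell; everything else via str().
        expected = "" if original[key] is None else str(original[key])
        if restored[key] != expected:
            problems.append(key)
    return problems


flat = {"user.id": 42, "user.name": "Alice", "score": 9.5, "note": None}
restored_ok = {"user.id": "42", "user.name": "Alice", "score": "9.5", "note": ""}
restored_bad = {"user.id": "42", "user.name": "alice", "score": "9.5"}

print(roundtrip_mismatches(flat, restored_ok))   # []
print(roundtrip_mismatches(flat, restored_bad))  # ['note', 'user.name']
```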

---

## Troubleshooting & Tips

If a cell fails, try these first:
- **Imports fail**: run the first setup cell again and verify you started Jupyter from the repo root.
- **Module not found in Docker**: ensure the notebook was started with `make notebook-docker` and the repo is mounted.
- **MongoDB errors**: confirm a local MongoDB is running and `MONGO_URI` points to it.
- **Snowflake errors**: verify environment variables (`SNOWFLAKE_*`) and network access.
- **PySpark errors**: confirm Java is installed and Docker has enough memory (4GB+).

Tip: restart the kernel and re-run all cells if the environment feels inconsistent.

---

## Puzzle for Data Scientists

You receive 1,000 JSON records with **three array fields**: `items` (avg 4), `promos` (avg 2), and `regions` (avg 3). You flatten by exploding all three paths to create a cartesian product.

**Riddle:**
- How many records do you expect on average after explosion?
- If `regions` is an empty list in 10% of records, how does that change the expected total? State whether an empty list drops the record or keeps it as a single row.

Write your answer and then validate by simulating a small sample in code.

In [None]:
import random

random.seed(42)

num_records = 1000
avg_items = 4
avg_promos = 2
avg_regions = 3
missing_region_rate = 0.10

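# Quick Monte Carlo simulation for expected exploded record count.
# Note: max(1, int(expovariate(...))) truncates each draw, so for the averages
# used here the simulated means land somewhat below their nominal values.
# Expect the simulated total to come in under the closed-form estimates below.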
simulated_total = 0
for _ in range(num_records):
    items = max(1, int(random.expovariate(1 / avg_items)))
    promos = max(1, int(random.expovariate(1 / avg_promos)))

    if random.random() < missing_region_rate:
        regions = 0
    else:
        regions = max(1, int(random.expovariate(1 / avg_regions)))

    simulated_total += items * promos * max(1, regions)

print(f"Simulated total records: {simulated_total}")
print(f"Average per input record: {simulated_total / num_records:.2f}")

expected_no_missing = avg_items * avg_promos * avg_regions
expected_with_missing = avg_items * avg_promos * ((1 - missing_region_rate) * avg_regions + missing_region_rate * 1)

print(f"Expected (no missing regions): {expected_no_missing:.2f}")
print(f"Expected (10% missing regions): {expected_with_missing:.2f}")

---

## Puzzle 2: The Duplicate-Key Trap (Intermediate)

You flatten 50,000 JSON docs. Each doc has an array `events` with a `type` field. You explode `events` and count `type` frequencies. Later you discover some events have duplicated keys (e.g., `"type"` appears twice in the raw JSON, last one wins during parsing).

**Questions:**
- How could this bias your frequency counts?
- What visual signal would you expect if you plot top-10 type frequencies before vs after a "last-key-wins" parser?

Validate by simulating a small dataset and plotting the before/after counts.

In [None]:
import matplotlib.pyplot as plt

random.seed(7)

# Simulate "true" types and an alternate value that overwrites it
true_types = ["click", "view", "purchase", "refund", "signup"]
shadow_types = ["click", "view", "purchase", "fraud", "bot"]

n_events = 2000
true_counts = Counter()
parsed_counts = Counter()

for _ in range(n_events):
    t = random.choices(true_types, weights=[50, 30, 10, 5, 5])[0]
    s = random.choices(shadow_types, weights=[40, 25, 15, 10, 10])[0]
    # "Raw" event has duplicated key; parser keeps last value (s)
    true_counts[t] += 1
    parsed_counts[s] += 1

labels = sorted(set(true_counts) | set(parsed_counts))
true_vals = [true_counts[l] for l in labels]
parsed_vals = [parsed_counts[l] for l in labels]

x = range(len(labels))
plt.figure(figsize=(8, 4))
plt.bar(x, true_vals, alpha=0.7, label="Before (true)")
plt.bar(x, parsed_vals, alpha=0.7, label="After (parsed)")
plt.xticks(x, labels, rotation=30)
plt.title("Type frequency shift from duplicate-key parsing")
plt.legend()
plt.tight_layout()
plt.show()

print("Top-5 before:", true_counts.most_common(5))
print("Top-5 after:", parsed_counts.most_common(5))
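
Duplicate keys can be detected at parse time instead of being discovered later. Python's `json.loads` accepts an `object_pairs_hook` that receives every key/value pair of an object in order, including repeats. The sketch below records repeated keys while still returning the usual last-key-wins dictionary. The hook name `record_duplicates` and the module-level list are illustrative choices.

```python
import json
from collections import Counter
from typing import Any, List, Tuple

duplicates_seen: List[str] = []


def record_duplicates(pairs: List[Tuple[str, Any]]) -> dict:
    """object_pairs_hook that notes keys appearing more than once in one JSON object."""
    counts = Counter(key for key, _ in pairs)
    duplicates_seen.extend(key for key, n in counts.items() if n > 1)
    return dict(pairs)  # same last-key-wins result as the default parser


raw = '{"type": "purchase", "amount": 45.5, "type": "fraud"}'
plain = json.loads(raw)                                        # silently keeps "fraud"
audited = json.loads(raw, object_pairs_hook=record_duplicates)  # same dict, but logged
print(duplicates_seen)  # ['type']
```

Counting how many documents trigger the hook gives a direct estimate of how much of the frequency shift in the plot comes from overwritten keys.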

---

## Puzzle 3: Missingness Meets Explosion (Intermediate+)

You explode `items` (avg 5) and `coupons` (avg 2). But `coupons` is missing in 30% of records. You keep missing paths as `None` (not dropping records).

**Questions:**
- What is the expected multiplier on record count?
- How does the distribution look compared to the no-missing case?

Simulate and plot the distribution of expanded record counts per input record.

In [None]:
random.seed(21)

n = 5000
avg_items = 5
avg_coupons = 2
missing_rate = 0.30

counts_missing = []
counts_full = []

for _ in range(n):
    items = max(1, int(random.expovariate(1 / avg_items)))
    coupons = max(1, int(random.expovariate(1 / avg_coupons)))
    coupons_missing = 1 if random.random() < missing_rate else coupons

    counts_missing.append(items * coupons_missing)
    counts_full.append(items * coupons)

plt.figure(figsize=(8, 4))
plt.hist(counts_full, bins=30, alpha=0.6, label="No missing coupons")
plt.hist(counts_missing, bins=30, alpha=0.6, label="30% missing (kept as None)")
plt.title("Explosion size per record")
plt.xlabel("Expanded records per input")
plt.ylabel("Frequency")
plt.legend()
plt.tight_layout()
plt.show()

print("Expected no-missing multiplier:", avg_items * avg_coupons)
print("Expected with missing multiplier:", avg_items * ((1 - missing_rate) * avg_coupons + missing_rate * 1))

---

## Puzzle 4: Schema Drift Heatmap (Advanced)

You flatten daily event logs and track the **set of flattened keys** per day. Over a month, product teams add and remove fields.

**Questions:**
- How would you visualize drift so that it highlights new, removed, and rare fields?
- Which day shows the most schema churn?

Simulate 30 days of keys and plot a heatmap of key presence.

In [None]:
import numpy as np

random.seed(5)
np.random.seed(5)

base_keys = [f"k{i}" for i in range(1, 21)]
extra_keys = [f"x{i}" for i in range(1, 16)]

days = 30
all_keys = base_keys + extra_keys
presence = []

for day in range(days):
    day_keys = set(base_keys)
    # Gradually introduce extra keys
    for key in extra_keys:
        if random.random() < (day / days):
            day_keys.add(key)
    # Random removals
    for key in list(day_keys):
        if random.random() < 0.05:
            day_keys.remove(key)
    presence.append([1 if key in day_keys else 0 for key in all_keys])

presence = np.array(presence)

plt.figure(figsize=(10, 5))
plt.imshow(presence, aspect="auto", cmap="viridis")
plt.colorbar(label="Key present")
plt.yticks(range(0, days, 5), [f"Day {i+1}" for i in range(0, days, 5)])
plt.xticks(range(0, len(all_keys), 5), all_keys[::5], rotation=45)
plt.title("Schema drift heatmap")
plt.tight_layout()
plt.show()

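# churn[i] counts keys whose presence changed between Day i+1 and Day i+2
# (1-indexed), so the change is attributed to the later of the two days.
churn = np.abs(np.diff(presence, axis=0)).sum(axis=1)
max_day = int(np.argmax(churn)) + 2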
print(f"Highest churn on Day {max_day}")
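
The heatmap shows presence, but not which keys changed. Set differences between consecutive days give that directly. The sketch below uses a tiny hand-written example; `daily_drift` and the key names are hypothetical.

```python
from typing import Dict, List, Set


def daily_drift(keys_by_day: List[Set[str]]) -> List[Dict[str, Set[str]]]:
    """For each consecutive pair of days, report which keys were added and removed."""
    drift = []
    for previous, current in zip(keys_by_day, keys_by_day[1:]):
        drift.append({"added": current - previous, "removed": previous - current})
    return drift


days = [
    {"id", "type", "amount"},
    {"id", "type", "amount", "currency"},  # Day 2 adds a field
    {"id", "type", "currency", "region"},  # Day 3 adds one field and drops one
]
drift = daily_drift(days)
churn = [len(d["added"]) + len(d["removed"]) for d in drift]
busiest_day = churn.index(max(churn)) + 2  # drift[0] describes the change into Day 2
print(busiest_day)  # 3
```

Logging the `added` and `removed` sets alongside the churn count makes it easier to tell a deliberate schema change from a one-day data gap.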

---

## Puzzle 5: Simpson’s Paradox in Flattened Metrics (Advanced)

You flatten experiments by user and compute conversion rate by `device`. Overall, **mobile** has higher conversion. But after segmenting by `region`, desktop wins in every region.

**Questions:**
- How can this happen?
- What would the plot of per-region conversion rates look like vs the overall rate?

Simulate and plot the overall vs per-region conversion rates.

In [None]:
random.seed(11)

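regions = ["NA", "EU", "APAC"]
# Desktop converts better within each region, but mobile traffic is concentrated
# in the higher-converting regions. Sizes are large enough that sampling noise
# rarely hides the 1-point gap inside each region.
region_sizes = {"NA": 40000, "EU": 30000, "APAC": 30000}
conversion = {
    "NA": {"desktop": 0.06, "mobile": 0.05},
    "EU": {"desktop": 0.08, "mobile": 0.07},
    "APAC": {"desktop": 0.04, "mobile": 0.03},
}
# Device mix skews mobile toward NA/EU (higher rates) and desktop toward APAC
# (lowest rates). In expectation the pooled rates are about 5.1% for desktop
# and 5.6% for mobile, so mobile wins overall despite losing in every region.
device_mix = {
    "NA": {"desktop": 0.30, "mobile": 0.70},
    "EU": {"desktop": 0.20, "mobile": 0.80},
    "APAC": {"desktop": 0.80, "mobile": 0.20},
}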

counts = {r: {"desktop": [0, 0], "mobile": [0, 0]} for r in regions}

for region in regions:
    for _ in range(region_sizes[region]):
        device = "mobile" if random.random() < device_mix[region]["mobile"] else "desktop"
        conv = 1 if random.random() < conversion[region][device] else 0
        counts[region][device][0] += conv
        counts[region][device][1] += 1

# Overall rates
overall = {"desktop": [0, 0], "mobile": [0, 0]}
for region in regions:
    for device in ["desktop", "mobile"]:
        overall[device][0] += counts[region][device][0]
        overall[device][1] += counts[region][device][1]

rate_overall = {
    d: overall[d][0] / overall[d][1] for d in overall
}

# Plot per-region vs overall
plt.figure(figsize=(8, 4))
for device in ["desktop", "mobile"]:
    regional_rates = [counts[r][device][0] / counts[r][device][1] for r in regions]
    plt.plot(regions, regional_rates, marker="o", label=f"{device} (by region)")

plt.hlines(rate_overall["desktop"], 0, len(regions) - 1, colors="C0", linestyles="--", label="desktop overall")
plt.hlines(rate_overall["mobile"], 0, len(regions) - 1, colors="C1", linestyles="--", label="mobile overall")
plt.title("Simpson's paradox: overall vs per-region")
plt.ylabel("Conversion rate")
plt.xticks(range(len(regions)), regions)
plt.legend()
plt.tight_layout()
plt.show()

print("Overall rates:", rate_overall)
for region in regions:
    print(region, {d: counts[region][d][0] / counts[region][d][1] for d in ["desktop", "mobile"]})
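
The reversal can also be checked without any randomness by computing pooled rates directly. The sketch below uses its own small, hypothetical numbers (two regions, round user counts) chosen only to make the arithmetic easy to follow; they are not taken from the simulation above.

```python
# Hypothetical user counts and conversion rates, chosen only to illustrate the effect.
users = {
    "EU": {"desktop": 1_000, "mobile": 9_000},
    "APAC": {"desktop": 9_000, "mobile": 1_000},
}
rates = {
    "EU": {"desktop": 0.10, "mobile": 0.09},
    "APAC": {"desktop": 0.03, "mobile": 0.02},
}


def overall_rate(device: str) -> float:
    """Pooled conversion rate: total conversions divided by total users."""
    conversions = sum(users[r][device] * rates[r][device] for r in users)
    return conversions / sum(users[r][device] for r in users)


desktop_overall = overall_rate("desktop")  # (100 + 270) / 10_000, about 0.037
mobile_overall = overall_rate("mobile")    # (810 + 20) / 10_000, about 0.083
```

Desktop has the higher rate in both regions, yet mobile has the higher pooled rate, because the pooled rate is a weighted average and mobile's weight sits almost entirely in the high-converting region.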