# Multi-Table Synthetic Data Example

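Generate synthetic data for a set of related tables (customers, orders, products, and order items) while maintaining referential integrity: every foreign key in a synthetic child table must point to an existing row in its synthetic parent table.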

In [None]:
import numpy as np
import pandas as pd

from genesis.multitable import MultiTableGenerator, RelationalSchema

## Create Sample Database Tables

In [None]:
# Seed the legacy global RNG so the sample database is reproducible.
np.random.seed(42)

# Customers table
n_customers = 100
customers = pd.DataFrame({
    'customer_id': range(1, n_customers + 1),
    'name': [f'Customer_{i}' for i in range(1, n_customers + 1)],
    'age': np.random.randint(18, 70, n_customers),  # 18..69 (upper bound exclusive)
    'city': np.random.choice(['NYC', 'LA', 'Chicago', 'Houston', 'Phoenix'], n_customers),
    'account_type': np.random.choice(['Basic', 'Premium', 'Enterprise'], n_customers, p=[0.6, 0.3, 0.1])
})

# Orders table: each customer gets Poisson(3) orders. That is 3 on average,
# possibly 0, and unbounded above, so a customer can occasionally have more
# than 10 orders.
orders_list = []
order_id = 1
for cust_id in customers['customer_id']:
    n_orders = np.random.poisson(3)
    for _ in range(n_orders):
        orders_list.append({
            'order_id': order_id,
            'customer_id': cust_id,
            'amount': np.random.exponential(100) + 10,  # always >= 10
            'status': np.random.choice(['Completed', 'Pending', 'Cancelled'], p=[0.8, 0.15, 0.05])
        })
        order_id += 1

# Explicit columns keep the schema even if no order rows were generated.
orders = pd.DataFrame(orders_list, columns=['order_id', 'customer_id', 'amount', 'status'])

# Products table
n_products = 50
products = pd.DataFrame({
    'product_id': range(1, n_products + 1),
    'name': [f'Product_{i}' for i in range(1, n_products + 1)],
    'category': np.random.choice(['Electronics', 'Clothing', 'Home', 'Food'], n_products),
    'price': np.random.uniform(5, 500, n_products).round(2)
})

# Order items: the junction table that realises the N:M relationship between
# orders and products. Each order gets 1-4 items, each with quantity 1-4
# (randint's upper bound is exclusive).
order_items_list = []
item_id = 1
for oid in orders['order_id']:
    n_items = np.random.randint(1, 5)
    for _ in range(n_items):
        order_items_list.append({
            'item_id': item_id,
            'order_id': oid,
            'product_id': np.random.randint(1, n_products + 1),
            'quantity': np.random.randint(1, 5)
        })
        item_id += 1

# Explicit columns keep the schema even if no item rows were generated.
order_items = pd.DataFrame(order_items_list, columns=['item_id', 'order_id', 'product_id', 'quantity'])

tables = {
    'customers': customers,
    'orders': orders,
    'products': products,
    'order_items': order_items
}

print("Table sizes:")
for name, df in tables.items():
    print(f"  {name}: {len(df)} rows")

## Define Schema and Relationships

In [None]:
# Declare each relationship compactly as
# (child_table, child_column, parent_table, parent_column), then expand to the
# dict form that RelationalSchema.from_dataframes is given below.
_fk_specs = [
    ('orders', 'customer_id', 'customers', 'customer_id'),
    ('order_items', 'order_id', 'orders', 'order_id'),
    ('order_items', 'product_id', 'products', 'product_id'),
]
_fk_fields = ('child_table', 'child_column', 'parent_table', 'parent_column')
foreign_keys = [dict(zip(_fk_fields, spec)) for spec in _fk_specs]

# Primary-key column of every table in `tables`.
primary_keys = {
    'customers': 'customer_id',
    'orders': 'order_id',
    'products': 'product_id',
    'order_items': 'item_id',
}

# Build the relational schema from the sample tables plus the declared keys.
schema = RelationalSchema.from_dataframes(
    tables,
    foreign_keys=foreign_keys,
    primary_keys=primary_keys,
)

print("Schema graph:")
print(schema)

## Generate Synthetic Multi-Table Data

In [None]:
# Only the parent-less tables get explicit row counts. Per the original
# intent, `orders` and `order_items` are sized by the generator based on
# cardinality, keeping child/parent ratios.
root_table_sizes = {
    'customers': 50,
    'products': 30,
}

generator = MultiTableGenerator(method='gaussian_copula', verbose=True)

# Learn every table, and the relationships described by `schema`, in one call.
generator.fit_tables(tables, schema)
synthetic_tables = generator.generate_tables(n_samples=root_table_sizes)

print("\nSynthetic table sizes:")
for table_name, table_df in synthetic_tables.items():
    print(f"  {table_name}: {len(table_df)} rows")

## Verify Referential Integrity

In [None]:
syn_customers = synthetic_tables['customers']
syn_orders = synthetic_tables['orders']
syn_products = synthetic_tables['products']
syn_order_items = synthetic_tables['order_items']

# Verify every declared relationship instead of hand-copying one check per FK,
# so this cell stays in sync with the `foreign_keys` list used for the schema.
# A relationship holds when:
#   1. every child FK value appears among the parent's key values (no orphans);
#   2. the parent key is unique. A duplicated parent key would make the child's
#      parent ambiguous and fan out joins, even when (1) passes.
for fk in foreign_keys:
    child_values = synthetic_tables[fk['child_table']][fk['child_column']]
    parent_values = synthetic_tables[fk['parent_table']][fk['parent_column']]

    n_orphans = int((~child_values.isin(parent_values)).sum())
    status = '✓' if n_orphans == 0 else f'✗ ({n_orphans} orphan rows)'
    print(f"{fk['child_table']}.{fk['child_column']} → "
          f"{fk['parent_table']}.{fk['parent_column']}: {status}")

    if not parent_values.is_unique:
        print(f"  ⚠ {fk['parent_table']}.{fk['parent_column']} contains duplicate keys")

## Compare Cardinality Distributions

In [None]:
import matplotlib.pyplot as plt


def _children_per_parent(children, fk_col, parent_keys):
    """Count child rows per parent key, including parents with zero children.

    A plain ``groupby(fk_col).size()`` only sees keys that occur in the child
    table, so parents with no children would silently vanish from the
    distribution. Child values absent from ``parent_keys`` (orphans) are
    dropped here; the referential-integrity check reports those.
    """
    return children[fk_col].value_counts().reindex(parent_keys, fill_value=0)


def _integer_bins(*count_series):
    """Return shared bin edges centred on each integer from 0 to the largest count."""
    upper = max((int(s.max()) for s in count_series if len(s)), default=0)
    return np.arange(upper + 2) - 0.5


# Orders per customer (customers without orders count as 0)
real_orders_per_cust = _children_per_parent(orders, 'customer_id', customers['customer_id'])
syn_orders_per_cust = _children_per_parent(syn_orders, 'customer_id', syn_customers['customer_id'])

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Shared bins make the two histograms directly comparable. density=True
# normalises away the different table sizes (e.g. 100 real vs 50 synthetic
# customers), so the shapes of the distributions are compared, not raw counts.
order_bins = _integer_bins(real_orders_per_cust, syn_orders_per_cust)
axes[0].hist(real_orders_per_cust, bins=order_bins, alpha=0.7, density=True, label='Real')
axes[0].hist(syn_orders_per_cust, bins=order_bins, alpha=0.7, density=True, label='Synthetic')
axes[0].set_title('Orders per Customer')
axes[0].set_xlabel('Orders')
axes[0].set_ylabel('Fraction of customers')
axes[0].legend()

# Items per order (orders without items count as 0)
real_items_per_order = _children_per_parent(order_items, 'order_id', orders['order_id'])
syn_items_per_order = _children_per_parent(syn_order_items, 'order_id', syn_orders['order_id'])

item_bins = _integer_bins(real_items_per_order, syn_items_per_order)
axes[1].hist(real_items_per_order, bins=item_bins, alpha=0.7, density=True, label='Real')
axes[1].hist(syn_items_per_order, bins=item_bins, alpha=0.7, density=True, label='Synthetic')
axes[1].set_title('Items per Order')
axes[1].set_xlabel('Items')
axes[1].set_ylabel('Fraction of orders')
axes[1].legend()

plt.tight_layout()
plt.show()

## Evaluate Per-Table Quality

In [None]:
from genesis.evaluation.evaluator import QualityEvaluator

# Score each synthetic table against its real counterpart and report
# fidelity and utility as percentages. The junction table `order_items`
# is not included in this loop.
for table_name in ('customers', 'orders', 'products'):
    report = QualityEvaluator(tables[table_name], synthetic_tables[table_name]).evaluate()
    fidelity_pct = report.fidelity_score * 100
    utility_pct = report.utility_score * 100
    print(f"{table_name}: Fidelity={fidelity_pct:.1f}%, Utility={utility_pct:.1f}%")

## Test Join Operations

In [None]:
# Join synthetic tables into one denormalised frame: one row per order item,
# with its order, customer and product attributes.
#
# In the real tables both `customers` and `products` carry a `name` column
# (presumably the synthetic tables keep the same columns). Explicit suffixes
# replace pandas' ambiguous `name_x` / `name_y` with `name_customer` /
# `name_product`.
#
# `validate` makes pandas raise MergeError if a parent key is duplicated,
# which would otherwise silently multiply rows in the joined result.
syn_joined = (
    syn_orders
    .merge(syn_customers, on='customer_id', validate='many_to_one')
    .merge(syn_order_items, on='order_id', validate='one_to_many')
    .merge(syn_products, on='product_id', suffixes=('_customer', '_product'),
           validate='many_to_one')
)

print(f"Joined synthetic data: {len(syn_joined)} rows")
syn_joined.head()

## Export

In [None]:
# Write each synthetic table to ./synthetic_<table>.csv, omitting the row index.
for table_name, table_df in synthetic_tables.items():
    csv_path = f'synthetic_{table_name}.csv'
    table_df.to_csv(csv_path, index=False)

print("All tables exported!")