# Table Partition Column Profiler 



This notebook helps identify the best partition column for ingesting large tables from SQL Server into Databricks using the Lakefed Ingest tool.

### Prerequisites

- `USE CONNECTION` privilege on the foreign catalog connection
- `SELECT` privilege on the source SQL Server tables
- Access to a compute cluster with DBR 17.3+
- Foreign catalog configured for SQL Server connection

#### **Requirements for `remote_query()`:**

- Unity Catalog enabled workspace

- Network connectivity from your Databricks Runtime cluster or SQL warehouse to the target database systems. [See Networking recommendations for Lakehouse Federation.](https://docs.databricks.com/aws/en/query-federation/networking)
- Databricks clusters must use Databricks Runtime 17.3 or above.

**Permissions required:**

- To use the remote_query function, you must have the `USE CONNECTION` privilege on the connection or the `SELECT` privilege on a view that wraps the function. 
- Single-user clusters also require the `MANAGE` permission on the connection.

> Note - `remote_query()` vs Federation - `remote_query()` pushes computation to SQL Server (preferred for large tables)
 > whereas, direct federation through foreign catalogs pulls data to Databricks.
> This notebook provides both options - use remote_query when available. Refer documentation [here](https://docs.databricks.com/aws/en/query-federation/remote-queries).

**Key considerations:**
- having a clustered index on the cluster/partition column on the sql server side is mandatory for ingestion performance
- Primary keys and unique indexes are poor partition/cluster columns (excluded automatically)
- Only integer columns are considered (timestamps/floats cause issues)
- Tables >10TB should use sampling mode for performance

In [0]:
# Widget parameters for easy configuration
dbutils.widgets.text("connection_name", "", "Connection Name (for remote_query - leave empty to use federation)")
dbutils.widgets.text("sql_database", "", "SQL Server Database (for remote_query)")
dbutils.widgets.text("src_catalog", "sqlserver_catalog", "UC Foreign Catalog")
dbutils.widgets.text("src_schema", "dbo", "Source Schema") 
dbutils.widgets.text("src_table", "store_sales_1tb", "Source Table (REQUIRED)")
dbutils.widgets.dropdown("use_sampling", "false", ["true", "false"], "Use Sampling (for 10TB+ tables)")
dbutils.widgets.text("table_size_gb", "", "Table Size in GB (REQUIRED)")

# Get values
connection_name = dbutils.widgets.get("connection_name")
sql_database = dbutils.widgets.get("sql_database")
src_catalog = dbutils.widgets.get("src_catalog")
src_schema = dbutils.widgets.get("src_schema")
src_table = dbutils.widgets.get("src_table")
use_sampling = dbutils.widgets.get("use_sampling") == "true"
table_size_gb = dbutils.widgets.get("table_size_gb")


# Validate required fields
if not src_table:
    raise ValueError("Source table is required!")
if not table_size_gb:
    raise ValueError("Table size is required!")

try:
    table_size_gb = float(table_size_gb)
except ValueError:
    raise ValueError("Table size must be a number!")

# Auto-enable sampling for very large tables
if table_size_gb > 10000 and not use_sampling:
    print(f"⚠️ WARNING: Table is {table_size_gb:.0f}GB - consider enabling sampling!")
# Determine query mode
USE_REMOTE_QUERY = bool(connection_name and sql_database)

print(f"Configuration:")
print(f"  Mode: {'remote_query' if USE_REMOTE_QUERY else 'Federation (fallback)'}")
print(f"  Catalog: {src_catalog}")
print(f"  Table: {src_schema}.{src_table}")
print(f"  Sampling: {'Enabled' if use_sampling else 'Disabled'}")

Configuration:
  Mode: remote_query
  Catalog: sqlserver_edwia_catalog
  Table: dbo.store_sales_1tb
  Sampling: Disabled


### Step 1: Identify Candidate Columns
First, we identify indexed columns that could be good partition candidates. We automatically exclude primary keys and unique indexes.

In [0]:
# Find Indexed Columns (Excluding Keys)

# Get primary key columns to exclude
pk_query = f"""
SELECT 
    kc.column_name,
    'PRIMARY KEY' as constraint_type
FROM {src_catalog}.information_schema.key_column_usage kc
JOIN {src_catalog}.information_schema.table_constraints tc
    ON kc.constraint_name = tc.constraint_name
    AND kc.table_name = tc.table_name
    AND kc.table_schema = tc.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
    AND kc.table_schema = '{src_schema}'
    AND kc.table_name = '{src_table}'
"""

print(f"Checking for primary keys on {src_schema}.{src_table}...")

pk_columns = [row['column_name'] for row in spark.sql(pk_query).collect()]

if pk_columns:
    print(f"⚠️ Excluding primary key columns: {', '.join(pk_columns)}")

# Get indexed integer columns (best partition candidates)
index_query = f"""
SELECT DISTINCT
    c.name as column_name,
    i.type_desc as index_type
FROM {src_catalog}.sys.indexes i
INNER JOIN {src_catalog}.sys.index_columns ic 
    ON i.object_id = ic.object_id AND i.index_id = ic.index_id
INNER JOIN {src_catalog}.sys.columns c 
    ON ic.object_id = c.object_id AND ic.column_id = c.column_id
INNER JOIN {src_catalog}.sys.tables t 
    ON i.object_id = t.object_id
INNER JOIN {src_catalog}.sys.schemas s 
    ON t.schema_id = s.schema_id
WHERE s.name = '{src_schema}' 
    AND t.name = '{src_table}'
    AND i.type > 0
    AND NOT ic.is_included_column
    AND c.system_type_id IN (48, 52, 56, 127)  -- integer types only
    AND NOT i.is_primary_key
    AND NOT i.is_unique
ORDER BY index_type DESC
"""

candidate_columns = [row['column_name'] for row in spark.sql(index_query).collect()]

print(f"\n✓ Found {len(candidate_columns)} candidate columns: {', '.join(candidate_columns[:10])}")

if not candidate_columns:
    print("\n⚠️ No indexed columns found. Checking all integer columns...")
    fallback_query = f"""
    SELECT column_name 
    FROM {src_catalog}.information_schema.columns
    WHERE table_schema = '{src_schema}'
        AND table_name = '{src_table}'
        AND data_type IN ('int', 'bigint', 'smallint', 'tinyint')
        AND column_name NOT IN ({','.join([f"'{c}'" for c in pk_columns]) if pk_columns else "''"})
    LIMIT 10
    """
    candidate_columns = [row['column_name'] for row in spark.sql(fallback_query).collect()]

Checking for primary keys on dbo.store_sales_1tb...

✓ Found 2 candidate columns: ss_item_sk, ss_sold_date_sk


### Step 2: Profile Candidate Columns
Now we analyze each candidate column for distribution characteristics.

In [0]:
import pandas as pd

# Helper function for SQL Server identifier escaping
def tsql_ident(name: str) -> str:
    """Escape SQL Server identifiers"""
    return '[' + name.replace(']', ']]') + ']'

# Query templates
QUERIES = {
    'basic_stats': """
        SELECT 
            COUNT_BIG(*) AS total_rows,
            COUNT_BIG(DISTINCT {column}) AS distinct_values,
            COUNT_BIG({column}) AS non_null_count,
            MIN({column}) AS min_value,
            MAX({column}) AS max_value
        FROM {schema}.{table}
    """,
    
    'distribution_stats': """
        SELECT 
            AVG(CAST(freq AS FLOAT)) AS avg_frequency,
            MAX(freq) AS max_frequency,
            MIN(freq) AS min_frequency,
            STDEV(freq) AS stddev_frequency
        FROM (
            SELECT COUNT_BIG(*) AS freq
            FROM {schema}.{table}
            WHERE {column} IS NOT NULL
            GROUP BY {column}
        ) AS frequency_table
    """,

    'distribution_stats_sampled': """
    SELECT 
        AVG(CAST(freq AS FLOAT)) AS avg_frequency,
        MAX(freq) AS max_frequency,
        MIN(freq) AS min_frequency,
        STDEV(freq) AS stddev_frequency
    FROM (
        SELECT {column}, COUNT(*) AS freq
        FROM (
            SELECT TOP 10000000 {column}
            FROM {schema}.{table} TABLESAMPLE (1 PERCENT)
            WHERE {column} IS NOT NULL
        ) AS sampled
        GROUP BY {column}
    ) AS frequency_table
    """
}

def profile_with_remote_query(column, connection, database):
    """Profile using remote_query (computation on SQL Server)"""
    safe_column = tsql_ident(column)
    safe_schema = tsql_ident(src_schema)
    safe_table = tsql_ident(src_table)
    
    # Use the template from QUERIES dict
    query = QUERIES['basic_stats'].format(
        column=safe_column,
        schema=safe_schema,
        table=safe_table
    )
    escaped_query = query.replace("'", "''")
    
    stats_sql = f"""
    SELECT * FROM remote_query(
        '{connection}',
        database => '{database}',
        query => '{escaped_query}',
        fetchsize => '10'
    )
    """
    stats = spark.sql(stats_sql).collect()[0]
    
    # Use appropriate distribution template
    dist_template = 'distribution_stats_sampled' if use_sampling else 'distribution_stats'
    dist_query = QUERIES[dist_template].format(
        column=safe_column,
        schema=safe_schema,
        table=safe_table
    )
    escaped_dist = dist_query.replace("'", "''")
    
    dist_sql = f"""
    SELECT * FROM remote_query(
        '{connection}',
        database => '{database}',
        query => '{escaped_dist}',
        fetchsize => '10'
    )
    """
    dist = spark.sql(dist_sql).collect()[0]
    
    return stats, dist

def profile_with_federation(column):
    """Profile using direct federation (fallback)"""
    # Basic stats
    stats = spark.sql(f"""
        SELECT 
            COUNT(*) as total_rows,
            COUNT(DISTINCT {column}) as distinct_values,
            COUNT({column}) as non_null_count,
            MIN({column}) as min_value,
            MAX({column}) as max_value
        FROM {src_catalog}.{src_schema}.{src_table}
    """).collect()[0]
    
    # Distribution stats (with sampling for large tables)
    if use_sampling:
        dist = spark.sql(f"""
            WITH sampled AS (
                SELECT {column}
                FROM {src_catalog}.{src_schema}.{src_table}
                WHERE {column} IS NOT NULL
                LIMIT 10000000
            )
            SELECT 
                AVG(cnt) as avg_frequency,
                MAX(cnt) as max_frequency
            FROM (
                SELECT COUNT(*) as cnt
                FROM sampled
                GROUP BY {column}
            ) freq
        """).collect()[0]
    else:
        dist = spark.sql(f"""
            SELECT 
                AVG(cnt) as avg_frequency,
                MAX(cnt) as max_frequency
            FROM (
                SELECT COUNT(*) as cnt
                FROM {src_catalog}.{src_schema}.{src_table}
                WHERE {column} IS NOT NULL
                GROUP BY {column}
            ) freq
        """).collect()[0]
    
    return stats, dist

# Profile each candidate
results = []

for i, column in enumerate(candidate_columns[:10], 1):  # Limit to top 10
    print(f"[{i}/{min(10, len(candidate_columns))}] Profiling {column}... ", end="")
    
    try:
        if USE_REMOTE_QUERY:
            stats, dist = profile_with_remote_query(column, connection_name, sql_database)
        else:
            stats, dist = profile_with_federation(column)
        
        total_rows = int(stats['total_rows'])
        distinct_values = int(stats['distinct_values'])
        null_count = total_rows - int(stats['non_null_count'])
        
        result = {
            'column_name': column,
            'total_rows': total_rows,
            'distinct_values': distinct_values,
            'null_percentage': round(null_count * 100.0 / total_rows, 2) if total_rows > 0 else 0,
            'min_value': stats['min_value'],
            'max_value': stats['max_value'],
            'avg_rows_per_value': int(dist['avg_frequency']) if dist['avg_frequency'] else 0,
            'max_rows_per_value': int(dist['max_frequency']) if dist['max_frequency'] else 0
        }
        
        # Calculate skew for display
        skew = result['max_rows_per_value'] / result['avg_rows_per_value'] if result['avg_rows_per_value'] > 0 else 999
        print(f"✓ Distinct: {distinct_values:,}, Skew: {skew:.1f}x")
        
    except Exception as e:
        print(f"✗ Failed: {str(e)[:50]}")
        result = {'column_name': column, 'error': str(e)[:50]}
    
    results.append(result)

# Create DataFrame
results_df = pd.DataFrame(results)
display(results_df)

[1/2] Profiling ss_item_sk... ✓ Distinct: 402,000, Skew: 2.1x
[2/2] Profiling ss_sold_date_sk... ✓ Distinct: 1,827, Skew: 21.6x


column_name,total_rows,distinct_values,null_percentage,min_value,max_value,avg_rows_per_value,max_rows_per_value
ss_item_sk,7031221350,402000,0.0,1,402000,17490,36780
ss_sold_date_sk,7031221350,1827,0.0,2450816,2452642,3848506,83202375


### Step 3: Score and Rank Columns
Score each column based on its suitability as a partition column.

In [0]:
def calculate_partition_score(row):
    """Score column suitability for partitioning (0-100)"""
    if 'error' in row or not row.get('distinct_values'):
        return -1
    
    total_rows = row['total_rows']
    distinct = row['distinct_values']
    null_pct = row['null_percentage']
    avg_rows = row['avg_rows_per_value']
    max_rows = row['max_rows_per_value']
    
    # HARD DISQUALIFIERS
    
    # Nulls > 1% - REJECT
    if null_pct > 1:
        return -4
    
    # Calculate skew ratio
    skew_ratio = max_rows / avg_rows if avg_rows > 0 else 999
    
    # Extreme skew (>10x) 
    if skew_ratio > 10:
        return -5
    
    # Too unique (>50% distinct) 
    cardinality_ratio = distinct / total_rows if total_rows > 0 else 1
    if cardinality_ratio > 0.5:
        return -1
    
    # Too few distinct values (<100) 
    if distinct < 100:
        return -2
    
    # SCORING FOR VALID CANDIDATES
    score = 100
    
    # Heavy penalty for skew (this is critical!)
    if skew_ratio > 5:
        score -= 60  # 5-10x skew
    elif skew_ratio > 3:
        score -= 40  # 3-5x skew
    elif skew_ratio > 2:
        score -= 20  # 2-3x skew
    
    # Penalty for any nulls
    if null_pct > 0.5:
        score -= 20
    elif null_pct > 0:
        score -= 10
    
    # Cardinality penalties
    if distinct > 100000:
        score -= 30  # Too many partitions
    elif distinct < 500:
        score -= 20  # Too few partitions
    
    return max(0, score)

def get_recommendation(row):
    score = row.get('partition_score', -1)
    skew = row.get('max_rows_per_value', 0) / row.get('avg_rows_per_value', 1) if row.get('avg_rows_per_value', 0) > 0 else 999
    
    if score == -5:
        return f' Extreme skew ({skew:.0f}x)'
    elif score == -4:
        return f' Too many NULLs ({row.get("null_percentage", 0):.1f}%)'
    elif score == -1:
        return ' Too unique'
    elif score == -2:
        return ' Too few distinct values'
    elif score >= 70:
        return ' EXCELLENT'
    elif score >= 40:
        return ' ACCEPTABLE'
    else:
        return ' POOR - High skew or other issues'

# Apply scoring
results_df['partition_score'] = results_df.apply(calculate_partition_score, axis=1)
results_df['recommendation'] = results_df.apply(get_recommendation, axis=1)

# Sort by score
results_df = results_df.sort_values('partition_score', ascending=False)

# Display with skew ratio visible
results_df['skew_ratio'] = results_df['max_rows_per_value'] / results_df['avg_rows_per_value']
display(results_df[['column_name', 'distinct_values', 'null_percentage', 
                    'avg_rows_per_value', 'max_rows_per_value', 'skew_ratio',
                    'partition_score', 'recommendation']])

print("\n" + "="*80)
print("⚠️ IMPORTANT: Please consult with your Databricks Solution Architect or")
print("   SME team to validate the partition column selection,") 
print("   especially for production workloads over 10TB.")
print("="*80)

column_name,distinct_values,null_percentage,avg_rows_per_value,max_rows_per_value,skew_ratio,partition_score,recommendation
ss_item_sk,402000,0.0,17490,36780,2.102915951972556,50,ACCEPTABLE
ss_sold_date_sk,1827,0.0,3848506,83202375,21.619395942217576,-5,Extreme skew (22x)



⚠️ IMPORTANT: Please consult with your Databricks Solution Architect or
   SME team to validate the partition column selection,
   especially for production workloads over 10TB.


### Step 4: Final Recommendations
Generate configuration recommendations for the best partition column.

In [0]:
valid_columns = results_df[results_df['partition_score'] > 0]

if not valid_columns.empty:
    best = valid_columns.iloc[0]
    
    print("=" * 60)
    print("RECOMMENDED CONFIGURATION")
    print("=" * 60)
    print(f"\nBest Partition Column: {best['column_name']}")
    print(f"Score: {best['partition_score']:.0f}/100")
    print(f"Distribution: {best['distinct_values']:,} distinct values")
    
    skew_ratio = best['max_rows_per_value'] / best['avg_rows_per_value'] if best['avg_rows_per_value'] > 0 else 0
    print(f"Skew: {skew_ratio:.1f}x (max vs avg)")
    print(f"Avg rows per value: {best['avg_rows_per_value']:,}")
    print(f"Max rows per value: {best['max_rows_per_value']:,}")
    
    # Use actual table size provided by user
    table_size_mb = table_size_gb * 1024
    total_rows = best['total_rows']
    distinct_values = best['distinct_values']

    # Target partition counts based on actual size
    if table_size_gb < 100:  # <100GB
        target_partitions = 100
    elif table_size_gb < 1000:  # 100GB-1TB
        target_partitions = 500
    elif table_size_gb < 10000:  # 1-10TB
        target_partitions = 2000
    else:  # >10TB
        target_partitions = 5000

    # Calculate ideal partition size
    ideal_partition_mb = int(table_size_mb / target_partitions)

    # Round to standard sizes
    if ideal_partition_mb < 512:
        partition_mb = 512
    elif ideal_partition_mb < 1024:
        partition_mb = 1024
    elif ideal_partition_mb < 2048:
        partition_mb = 2048
    elif ideal_partition_mb < 4096:
        partition_mb = 4096
    else:
        partition_mb = 8192

    print(f"\n📊 PARTITION SIZING ANALYSIS")
    print(f"Table size: {table_size_gb:.1f} GB")
    print(f"Target partitions: ~{target_partitions}")
    print(f"Recommended partition_size_mb: {partition_mb}")
    
    # Estimate number of partitions this will create
    estimated_partitions = (avg_mb_per_value * distinct_values) / partition_mb
    
    print(f"\n📊 PARTITION SIZING ANALYSIS")
    print(f"Estimated table size: {estimated_table_size_mb/1024:.1f} GB")
    print(f"Average MB per {best['column_name']}: {avg_mb_per_value:.1f} MB")
    print(f"Recommended partition_size_mb: {partition_mb}")
    print(f"Estimated total partitions: {estimated_partitions:.0f}")
    
    # Warnings
    if estimated_partitions > 10000:
        print(f"\n⚠️ WARNING: This will create {estimated_partitions:.0f} partitions!")
        print("Consider using a larger partition_size_mb or different column")
    
    if skew_ratio > 10:
        print(f"\n⚠️ WARNING: High skew detected ({skew_ratio:.1f}x)")
        print("Some partitions will be much larger than others")
        print(f"Largest partition may be ~{avg_mb_per_value * skew_ratio:.0f} MB")
    
    print(f"""
Control Table Configuration:
----------------------------
UPDATE {src_catalog}.control_table
SET 
    partition_col = '{best['column_name']}',
    partition_size_mb = {partition_mb},
    load_partitioned = true
WHERE src_table = '{src_table}';
""")
    
else:
    print("⚠️ No suitable partition columns found!")
    print("\nRecommendations:")
    print(" Use non-partitioned ingestion:")
    print(f"""
UPDATE {src_catalog}.control_table
SET 
    load_partitioned = false,
    partition_col = NULL,
    partition_size_mb = NULL
WHERE src_table = '{src_table}';
""")

RECOMMENDED CONFIGURATION

Best Partition Column: ss_item_sk
Score: 50/100
Distribution: 402,000 distinct values
Skew: 2.1x (max vs avg)
Avg rows per value: 17,490
Max rows per value: 36,780

📊 PARTITION SIZING ANALYSIS
Table size: 1134.0 GB
Target partitions: ~2000
Recommended partition_size_mb: 1024

📊 PARTITION SIZING ANALYSIS
Estimated table size: 3274.2 GB
Average MB per ss_item_sk: 8.3 MB
Recommended partition_size_mb: 1024
Estimated total partitions: 3274

Control Table Configuration:
----------------------------
UPDATE sqlserver_edwia_catalog.control_table
SET 
    partition_col = 'ss_item_sk',
    partition_size_mb = 1024,
    load_partitioned = true
WHERE src_table = 'store_sales_1tb';

