# Table Partition Column Profiler 
This notebook helps identify the best partition column for ingesting large tables from SQL Server into Databricks using the Lakefed Ingest tool.

### Prerequisites
- `USE CONNECTION` privilege on the foreign catalog connection
- `SELECT` privilege on the source SQL Server tables
- Access to a compute cluster running DBR 17.3+ to take advantage of `remote_query()`. If it is not available, the notebook can fall back to Lakehouse Federation (leave the connection name and SQL database widgets empty).
- Foreign catalog configured for SQL Server connection

#### Requirements for `remote_query()`:
- Unity Catalog enabled workspace
- Network connectivity from your Databricks Runtime cluster or SQL warehouse to the target database systems. [See Networking recommendations for Lakehouse Federation.](https://docs.databricks.com/aws/en/query-federation/networking)
- Databricks clusters must use Databricks Runtime 17.3 or above.
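
As a quick pre-flight check, the runtime requirement above can be tested in code. This is a minimal sketch that assumes the cluster exposes its runtime version in the `DATABRICKS_RUNTIME_VERSION` environment variable; `parse_dbr_version` and `supports_remote_query` are hypothetical helpers, not part of any Databricks API. Versions are compared as `(major, minor)` tuples so that, for example, 17.10 correctly ranks above 17.3.

```python
import os
import re
from typing import Optional, Tuple

MIN_REMOTE_QUERY_DBR = (17, 3)  # Minimum runtime listed in the requirements above


def parse_dbr_version(version: Optional[str]) -> Optional[Tuple[int, int]]:
    """Return (major, minor) from strings like '17.3' or '17.3.x-scala2.13', else None."""
    if not version:
        return None
    match = re.match(r"(\d+)\.(\d+)", version.strip())
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def supports_remote_query(version: Optional[str]) -> bool:
    """True when the parsed runtime version is at least 17.3 (hypothetical helper)."""
    parsed = parse_dbr_version(version)
    return parsed is not None and parsed >= MIN_REMOTE_QUERY_DBR


# Assumption: the cluster exposes its runtime version in this environment variable.
runtime_version = os.environ.get("DATABRICKS_RUNTIME_VERSION")
mode = "remote_query" if supports_remote_query(runtime_version) else "Federation (fallback)"
print(f"Runtime {runtime_version!r} -> {mode}")
```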

**Permissions required:**
- To use the `remote_query()` function, you must have the `USE CONNECTION` privilege on the connection or the `SELECT` privilege on a view that wraps the function.
- Single-user clusters also require the `MANAGE` permission on the connection.

> Note: `remote_query()` vs. Federation. `remote_query()` pushes the entire query down to SQL Server, which is preferred for large tables. Federated queries are not always fully pushed down and may be slower.  
> This notebook supports both options; use `remote_query()` when it is available. See the [`remote_query()` documentation](https://docs.databricks.com/aws/en/query-federation/remote-queries).
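
Because `remote_query()` runs the query text on SQL Server, that text has to be valid T-SQL, including SQL Server's bracket quoting for identifiers. The sketch below builds the same kind of basic-statistics query the profiling step sends. `basic_stats_tsql` is a hypothetical helper and the schema, table and column names are made up; embedding its output in the `query =>` argument of `remote_query()` still requires escaping it as a Spark SQL string literal, which this sketch does not attempt.

```python
def tsql_ident(name: str) -> str:
    """Quote a SQL Server identifier as [name], doubling any ']' inside it."""
    return "[" + name.replace("]", "]]") + "]"


def basic_stats_tsql(schema: str, table: str, column: str) -> str:
    """Build a T-SQL statement with row counts and MIN/MAX for one column.

    The whole statement is meant to run on SQL Server, so it uses T-SQL
    features (COUNT_BIG, [bracket] quoting) rather than Spark SQL syntax.
    """
    col = tsql_ident(column)
    source = f"{tsql_ident(schema)}.{tsql_ident(table)}"
    return (
        "SELECT COUNT_BIG(*) AS total_rows, "
        f"COUNT_BIG(DISTINCT {col}) AS distinct_values, "
        f"COUNT_BIG({col}) AS non_null_count, "
        f"MIN({col}) AS min_value, MAX({col}) AS max_value "
        f"FROM {source}"
    )


# Hypothetical names; the ']' in the table name is escaped as ']]'
print(basic_stats_tsql("dbo", "Order]Lines", "OrderID"))
```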

**Key considerations:**
- Selecting a partition column that is indexed on the SQL Server side is necessary to achieve reasonable performance.
- Primary keys with a clustered index are ideal partition columns (good distribution, no skew). This profiler is most useful for tables without a clustered primary key.
- Only integer columns are considered (timestamp and floating-point columns cause issues).
- Tables >10TB should use sampling mode for performance.
- Default partition size of 2048MB is recommended for partitioned ingestion.
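
To make the 2048 MB guideline concrete, here is back-of-the-envelope arithmetic for splitting an integer column's `[MIN, MAX]` range into equal-width ranges of roughly 2048 MB each. This is a sketch under two assumptions: the ingestion splits on evenly spaced value ranges, and rows are spread uniformly across the range. How Lakefed Ingest actually computes its boundaries is not described here, so treat `partition_ranges` as a hypothetical helper. When values cluster in part of the range, some ranges hold far more rows than others, which is why the distribution checks later in this notebook matter.

```python
import math
from typing import List, Tuple


def partition_ranges(min_value: int, max_value: int, table_size_mb: float,
                     partition_size_mb: int = 2048) -> List[Tuple[int, int]]:
    """Split [min_value, max_value] into equal-width, inclusive integer ranges.

    The number of ranges is ceil(table_size_mb / partition_size_mb), capped so
    that no range is empty.
    """
    if max_value < min_value:
        raise ValueError("max_value must be >= min_value")
    span = max_value - min_value + 1                      # number of integer values
    num_partitions = max(1, math.ceil(table_size_mb / partition_size_mb))
    num_partitions = min(num_partitions, span)            # never more ranges than values
    ranges = []
    for i in range(num_partitions):
        lo = min_value + (span * i) // num_partitions
        hi = min_value + (span * (i + 1)) // num_partitions - 1
        ranges.append((lo, hi))
    return ranges


# A hypothetical 10 GB table keyed by IDs 1..1,000,000 -> 5 ranges of 200,000 IDs each
for lo, hi in partition_ranges(1, 1_000_000, table_size_mb=10_240):
    print(lo, hi)
```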

In [0]:
# Read widget values
connection_name = dbutils.widgets.get("connection_name")
sql_database = dbutils.widgets.get("sql_database")
src_catalog = dbutils.widgets.get("src_catalog")
src_schema = dbutils.widgets.get("src_schema")
src_table = dbutils.widgets.get("src_table")
use_sampling = dbutils.widgets.get("use_sampling").strip().lower() == "true"

# Validate required fields (catalog, schema and table are all used to build queries)
required = {"src_catalog": src_catalog, "src_schema": src_schema, "src_table": src_table}
missing = [name for name, value in required.items() if not value]
if missing:
    raise ValueError(f"Required widget(s) not set: {', '.join(missing)}")

# Use remote_query() only when both a connection and a database are provided
USE_REMOTE_QUERY = bool(connection_name and sql_database)

print("Configuration:")
print(f"  Mode: {'remote_query' if USE_REMOTE_QUERY else 'Federation (fallback)'}")
print(f"  Table: {src_catalog}.{src_schema}.{src_table}")
print(f"  Sampling: {'Enabled (10TB+ mode)' if use_sampling else 'Disabled'}")

### Step 1: Identify Candidate Columns
First, we check for a clustered primary key on an integer column, then identify other indexed integer columns that could be good partition candidates.
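
One reasonable way to rank indexed candidates can be sketched in plain Python over metadata rows: keep integer columns only, require the column to be the leading key of an index (only an index's first key column lets SQL Server seek on a range of that column), and list clustered indexes before nonclustered ones. This is a sketch, not the notebook's exact query. The input rows are hypothetical dictionaries shaped like `sys.index_columns` joined to `sys.indexes` and `sys.columns`; the type IDs are SQL Server `system_type_id` values for tinyint (48), smallint (52), int (56) and bigint (127), and `sys.indexes.type` is 1 for clustered and 2 for nonclustered rowstore indexes.

```python
from typing import Dict, List

INTEGER_TYPE_IDS = {48, 52, 56, 127}  # tinyint, smallint, int, bigint (sys.columns.system_type_id)
CLUSTERED, NONCLUSTERED = 1, 2         # sys.indexes.type values for rowstore indexes


def rank_candidates(index_rows: List[Dict], limit: int = 10) -> List[str]:
    """Return integer columns that lead an index, clustered indexes first.

    Each row is a dict with keys column_name, system_type_id, index_type,
    key_ordinal and is_included_column (one row per index column).
    """
    best_type: Dict[str, int] = {}
    for row in index_rows:
        if (row["system_type_id"] in INTEGER_TYPE_IDS
                and row["index_type"] in (CLUSTERED, NONCLUSTERED)
                and row["key_ordinal"] == 1
                and not row["is_included_column"]):
            name = row["column_name"]
            # A column can lead several indexes; keep the best (lowest) index type
            best_type[name] = min(best_type.get(name, row["index_type"]), row["index_type"])
    # Clustered before nonclustered, then by name so the order is deterministic
    return sorted(best_type, key=lambda name: (best_type[name], name))[:limit]


# Hypothetical metadata for one table (system_type_id 61 is datetime)
rows = [
    {"column_name": "OrderID", "system_type_id": 127, "index_type": CLUSTERED, "key_ordinal": 1, "is_included_column": False},
    {"column_name": "CustomerID", "system_type_id": 56, "index_type": NONCLUSTERED, "key_ordinal": 1, "is_included_column": False},
    {"column_name": "CustomerID", "system_type_id": 56, "index_type": NONCLUSTERED, "key_ordinal": 1, "is_included_column": False},
    {"column_name": "StoreID", "system_type_id": 56, "index_type": NONCLUSTERED, "key_ordinal": 2, "is_included_column": False},
    {"column_name": "OrderDate", "system_type_id": 61, "index_type": NONCLUSTERED, "key_ordinal": 1, "is_included_column": False},
]
print(rank_candidates(rows))  # ['OrderID', 'CustomerID']
```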

In [0]:
pk_column = None  # Set below when a single-column clustered integer PK is found

# Check for clustered PK
pk_query = f"""
SELECT 
    c.name as column_name,
    i.type_desc as index_type,
    t.name as data_type
FROM {src_catalog}.sys.indexes i
INNER JOIN {src_catalog}.sys.index_columns ic 
    ON i.object_id = ic.object_id AND i.index_id = ic.index_id
INNER JOIN {src_catalog}.sys.columns c
    ON ic.object_id = c.object_id AND ic.column_id = c.column_id
INNER JOIN {src_catalog}.sys.types t
    ON c.system_type_id = t.user_type_id  -- exactly one row per column: its base system type
INNER JOIN {src_catalog}.sys.tables tab
    ON i.object_id = tab.object_id
INNER JOIN {src_catalog}.sys.schemas s
    ON tab.schema_id = s.schema_id
WHERE i.is_primary_key
    AND i.type_desc = 'CLUSTERED'
    AND s.name = '{src_schema}'
    AND tab.name = '{src_table}'
    AND t.name IN ('int', 'bigint', 'smallint', 'tinyint')
"""

clustered_pk_df = spark.sql(pk_query)
display(clustered_pk_df)
clustered_pk = clustered_pk_df.collect()

if clustered_pk:
    if len(clustered_pk) == 1:
        pk_column = clustered_pk[0]['column_name']
        print("=" * 60)
        print("✅ FOUND CLUSTERED PRIMARY KEY - USE THIS!")
        print("=" * 60)
        print(f"\nColumn: {pk_column}")
        print("\nWhy this is ideal:")
        print("  • Great distribution / no skew")
        print("  • Index for fast MIN/MAX")
        print("\nRecommended Configuration:")
        print(f"""
        UPDATE control_table
        SET 
            partition_col = '{pk_column}',
            partition_size_mb = 2048,
            load_partitioned = true
        WHERE src_table = '{src_catalog}.{src_schema}.{src_table}';
        """)
    else:
        # Composite PK
        pk_columns = [row['column_name'] for row in clustered_pk]
        print("=" * 60)
        print(f"⚠️ COMPOSITE PRIMARY KEY ({len(pk_columns)} columns): {', '.join(pk_columns)}")
        print("Cannot use composite key directly for partitioning")
        print("\nRun the next sections to profile the individual columns and find the best option.")
        print("=" * 60)
        pk_column = None
    SKIP_PROFILING = False

else:
    print("No clustered primary key found.")
    print("Run the next section to analyze alternative columns...")
    SKIP_PROFILING = False

In [0]:
if not SKIP_PROFILING:
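    # Find integer columns that are the leading key of an index. Only the leading
    # key column lets SQL Server seek on the range predicates used for partitioning.
    # PK indexes are included so the leading column of a composite PK is profiled too.
    index_query = f"""
    SELECT
        c.name AS column_name,
        MIN(i.type) AS best_index_type  -- 1 = clustered, 2 = nonclustered
    FROM {src_catalog}.sys.indexes i
    INNER JOIN {src_catalog}.sys.index_columns ic 
        ON i.object_id = ic.object_id AND i.index_id = ic.index_id
    INNER JOIN {src_catalog}.sys.columns c
        ON ic.object_id = c.object_id AND ic.column_id = c.column_id
    INNER JOIN {src_catalog}.sys.tables t
        ON i.object_id = t.object_id
    INNER JOIN {src_catalog}.sys.schemas s
        ON t.schema_id = s.schema_id
    WHERE s.name = '{src_schema}'
        AND t.name = '{src_table}'
        AND i.type > 0                  -- exclude heaps
        AND ic.key_ordinal = 1          -- leading key column only
        AND NOT ic.is_included_column
        AND c.system_type_id IN (48, 52, 56, 127)  -- tinyint, smallint, int, bigint
    GROUP BY c.name                     -- one row per column, even if it leads several indexes
    ORDER BY best_index_type            -- clustered indexes first
    LIMIT 10
    """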
    
    candidate_columns = [row['column_name'] for row in spark.sql(index_query).collect()]
    
    if not candidate_columns:
        print("No indexed integer columns found. Checking all integer columns...")
        fallback_query = f"""
        SELECT column_name 
        FROM {src_catalog}.information_schema.columns
        WHERE table_schema = '{src_schema}'
            AND table_name = '{src_table}'
            -- lower() so the filter matches whether type names are upper- or lowercase
            AND lower(data_type) IN ('int', 'bigint', 'smallint', 'tinyint')
        ORDER BY ordinal_position
        LIMIT 10
        """
        candidate_columns = [row['column_name'] for row in spark.sql(fallback_query).collect()]
    
    print(f"Found {len(candidate_columns)} candidates to analyze: {', '.join(candidate_columns)}")

### Step 2: Profile Candidate Columns
Now we analyze each candidate column for distribution characteristics.
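
The metrics computed in this step are easiest to see on a tiny in-memory column. The sketch below mirrors the SQL: count rows per non-null value (the `GROUP BY`), then take the average and maximum of those counts; `skew_ratio` is `max / avg`, the ratio scored in Step 3. `profile_values` is a hypothetical helper for illustration and works on Python lists, not on the source table.

```python
from collections import Counter
from typing import Any, Dict, Iterable, Optional


def profile_values(values: Iterable[Optional[int]]) -> Dict[str, Any]:
    """Compute the Step 2 metrics for an in-memory list of column values."""
    vals = list(values)
    total_rows = len(vals)
    freq = Counter(v for v in vals if v is not None)        # GROUP BY column, NULLs excluded
    non_null_count = sum(freq.values())
    avg_freq = non_null_count / len(freq) if freq else None  # AVG(freq)
    max_freq = max(freq.values()) if freq else None          # MAX(freq)
    return {
        "total_rows": total_rows,
        "distinct_values": len(freq),
        "null_percentage": round((total_rows - non_null_count) * 100.0 / total_rows, 2) if total_rows else 0,
        "avg_rows_per_value": avg_freq,
        "max_rows_per_value": max_freq,
        "skew_ratio": max_freq / avg_freq if avg_freq else None,
    }


# Value 1 appears 4 times, three other values once each, plus one NULL
stats = profile_values([1, 1, 1, 1, 2, 3, 4, None])
print(stats)  # avg 1.75, max 4, skew ≈ 2.29, null_percentage 12.5
```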

In [0]:
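import pandas as pd

# Stays empty when profiling is skipped or no candidate columns were found
results_df = pd.DataFrame()

if not SKIP_PROFILING and candidate_columns: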
    
    # Helper function
    def tsql_ident(name: str) -> str:
        """Escape SQL Server identifiers"""
        return '[' + name.replace(']', ']]') + ']'
    
    # Query templates - INCLUDING sampling versions
    QUERIES = {
        'basic_stats': """
            SELECT 
                COUNT_BIG(*) AS total_rows,
                COUNT_BIG(DISTINCT {column}) AS distinct_values,
                COUNT_BIG({column}) AS non_null_count,
                MIN({column}) AS min_value,
                MAX({column}) AS max_value
            FROM {schema}.{table}
        """,
        
        'distribution_stats': """
            SELECT 
                AVG(CAST(freq AS FLOAT)) AS avg_frequency,
                MAX(freq) AS max_frequency
            FROM (
                SELECT COUNT_BIG(*) AS freq
                FROM {schema}.{table}
                WHERE {column} IS NOT NULL
                GROUP BY {column}
            ) AS freq_table
        """,
        
        'distribution_stats_sampled': """
            SELECT 
                AVG(CAST(freq AS FLOAT)) AS avg_frequency,
                MAX(freq) AS max_frequency
            FROM (
                SELECT {column}, COUNT(*) AS freq
                FROM (
                    SELECT TOP 10000000 {column}
                    FROM {schema}.{table} TABLESAMPLE (1 PERCENT)
                    WHERE {column} IS NOT NULL
                ) AS sampled
                GROUP BY {column}
            ) AS freq_table
        """
    }

    def profile_column(column):
        """Profile a single column - always returns consistent types"""
        safe_column = tsql_ident(column)
        safe_schema = tsql_ident(src_schema)
        safe_table = tsql_ident(src_table)
        
        if USE_REMOTE_QUERY:
            # Basic stats (always full scan)
            basic_query = QUERIES['basic_stats'].format(
                column=safe_column, schema=safe_schema, table=safe_table
            )
            escaped_query = basic_query.replace("'", "''")
            
            stats_sql = f"""
            SELECT * FROM remote_query(
                '{connection_name}',
                database => '{sql_database}',
                query => '{escaped_query}',
                fetchsize => '10'
            )
            """
            stats = spark.sql(stats_sql).collect()[0]
            
            # Distribution stats (with sampling option)
            dist_template = 'distribution_stats_sampled' if use_sampling else 'distribution_stats'
            dist_query = QUERIES[dist_template].format(
                column=safe_column, schema=safe_schema, table=safe_table
            )
            escaped_dist = dist_query.replace("'", "''")
            
            dist_sql = f"""
            SELECT * FROM remote_query(
                '{connection_name}',
                database => '{sql_database}',
                query => '{escaped_dist}',
                fetchsize => '10'
            )
            """
            dist = spark.sql(dist_sql).collect()[0]
            
        else:
            # Federation approach: quote identifiers with Spark SQL backticks
            def spark_ident(name: str) -> str:
                return '`' + name.replace('`', '``') + '`'

            col = spark_ident(column)
            fq_table = '.'.join(spark_ident(part) for part in (src_catalog, src_schema, src_table))

            stats = spark.sql(f"""
                SELECT 
                    COUNT(*) AS total_rows,
                    COUNT(DISTINCT {col}) AS distinct_values,
                    COUNT({col}) AS non_null_count,
                    MIN({col}) AS min_value,
                    MAX({col}) AS max_value
                FROM {fq_table}
            """).collect()[0]
            
            if use_sampling:
                # Sample-based distribution
                dist = spark.sql(f"""
                    WITH sampled AS (
                        SELECT {col}
                        FROM {fq_table} TABLESAMPLE (1000000 ROWS)
                        WHERE {col} IS NOT NULL
                    )
                    SELECT 
                        AVG(CAST(cnt AS DOUBLE)) AS avg_frequency,
                        MAX(cnt) AS max_frequency
                    FROM (
                        SELECT COUNT(*) AS cnt
                        FROM sampled
                        GROUP BY {col}
                    ) freq
                """).collect()[0]
            else:
                # Full distribution
                dist = spark.sql(f"""
                    SELECT 
                        AVG(CAST(cnt AS DOUBLE)) AS avg_frequency,
                        MAX(cnt) AS max_frequency
                    FROM (
                        SELECT COUNT(*) AS cnt
                        FROM {fq_table}
                        WHERE {col} IS NOT NULL
                        GROUP BY {col}
                    ) freq
                """).collect()[0]
        
        # Always return Row objects
        return stats, dist

    # Profile each column - with fixed result processing
    results = []
    for i, column in enumerate(candidate_columns, 1):
        print(f"[{i}/{len(candidate_columns)}] Profiling {column}... ", end="")
        if use_sampling:
            print("(sampling) ", end="")
        
        try:
            stats, dist = profile_column(column)
            
            # Access Row fields directly - no .get() needed
            total_rows = int(stats['total_rows'])
            non_null_count = int(stats['non_null_count'])
            distinct_values = int(stats['distinct_values'])
            
            # Distribution stats can be NULL (e.g., no non-null values in the sample)
            avg_freq = dist['avg_frequency']
            max_freq = dist['max_frequency']
            
            result = {
                'column_name': column,
                'total_rows': total_rows,
                'distinct_values': distinct_values,
                'null_percentage': round((total_rows - non_null_count) * 100.0 / total_rows, 2) if total_rows > 0 else 0,
                'min_value': stats['min_value'],
                'max_value': stats['max_value'],
                'avg_rows_per_value': int(avg_freq) if avg_freq is not None else None,
                'max_rows_per_value': int(max_freq) if max_freq is not None else None
            }
            
            # Calculate and display skew if available
            if avg_freq and max_freq and avg_freq > 0:
                skew = max_freq / avg_freq
                print(f"✓ Skew: {skew:.1f}x")
            else:
                print("✓ Stats collected")
                
        except Exception as e:
            print(f"✗ Failed: {str(e)[:100]}")
            result = {'column_name': column, 'error': str(e)}
        
        results.append(result)

    # Build the DataFrame once, after every column has been profiled
    results_df = pd.DataFrame(results)

if not results_df.empty:
    display(results_df)

### Step 3: Score and Rank Columns
Score each column based on its suitability as a partition column.

### Important Note: 
The skew ratio tiers (2x, 5x, 10x) are designed to catch problematic distribution patterns that can actually impact job performance:
- <2x skew: Negligible impact - partition processing times stay balanced
- 2-5x skew: Noticeable but manageable - some partitions take longer but parallelism still effective
- 5-10x skew: Performance degradation - long-tail partitions start blocking executor slots
- More than 10x skew: Severe impact - one partition could take longer than processing 10 average ones combined. Columns in this tier receive a score of 0 and are not recommended.
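
The tiers above can be checked with a small worked example. This sketch is a hypothetical helper, not part of the notebook; its thresholds follow the `> 10`, `> 5` and `> 2` comparisons in `score_column` below, so the boundaries are exclusive (a ratio of exactly 2.0 still counts as negligible). It covers skew only; the null-percentage rule in `score_column` is left out.

```python
from typing import Tuple


def skew_tier(avg_rows_per_value: float, max_rows_per_value: float) -> Tuple[float, str, int]:
    """Classify max/avg rows-per-value into the tiers above, with the matching Step 3 score."""
    if avg_rows_per_value <= 0:
        raise ValueError("avg_rows_per_value must be positive")
    skew = max_rows_per_value / avg_rows_per_value
    if skew > 10:
        return skew, "severe", 0
    if skew > 5:
        return skew, "performance degradation", 40
    if skew > 2:
        return skew, "noticeable", 70
    return skew, "negligible", 100


# Worked example: 7 non-null rows over 4 values -> avg 1.75, max 4, skew ≈ 2.29
print(skew_tier(1.75, 4))        # noticeable, score 70
# A hypothetical column where one value holds 50,000 rows but the average is 1,000
print(skew_tier(1_000, 50_000))  # (50.0, 'severe', 0)
```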

In [0]:
# Evaluate Columns Based on Skew and Nulls

if not SKIP_PROFILING and not results_df.empty:
    
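    def score_column(row):
        """Score 0-100 from skew and nulls; negative values mean the column could not be scored."""
        # Rows for columns that failed to profile carry an error message
        # ('error' is a DataFrame column, so test its value, not its presence)
        if pd.notna(row.get('error')):
            return -1
        
        # Skip if no distribution stats are available (pandas stores missing values as NaN)
        if pd.isna(row.get('avg_rows_per_value')) or pd.isna(row.get('max_rows_per_value')):
            return -2  # Can't score without distribution data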
        
        null_pct = row['null_percentage']
        avg_rows = row['avg_rows_per_value']
        max_rows = row['max_rows_per_value']
        
        # Disqualify if >1% nulls
        if null_pct > 1:
            return 0
        
        # Calculate skew
        if avg_rows > 0:
            skew_ratio = max_rows / avg_rows
            
            # Score based on skew
            if skew_ratio > 10:
                return 0  # Too skewed
            elif skew_ratio > 5:
                return 40  # High skew
            elif skew_ratio > 2:
                return 70  # Moderate skew
            else:
                return 100  # Low skew
        else:
            return 0

    # Apply scoring - handle None values properly
    results_df['skew_ratio'] = results_df.apply(
        lambda r: r['max_rows_per_value'] / r['avg_rows_per_value'] 
        if r.get('avg_rows_per_value') and r['avg_rows_per_value'] > 0 
        else None, axis=1
    )
    results_df['score'] = results_df.apply(score_column, axis=1)
    
    # Sort by score
    results_df = results_df.sort_values('score', ascending=False)
    
    # Display
    display_cols = ['column_name', 'distinct_values', 'null_percentage', 
                   'skew_ratio', 'score']
    display(results_df[[c for c in display_cols if c in results_df.columns]])

### Step 4: Final Recommendations
Generate configuration recommendations for the best partition column.
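
The two possible outcomes of this step (a partitioned load on the chosen column, or a non-partitioned load) can be captured in one helper that renders the `control_table` statement. `control_table_update` is hypothetical; the `control_table` name and its `partition_col`, `partition_size_mb`, `load_partitioned` and `src_table` columns are taken from the statements this notebook prints, the example catalog and table names are made up, and the statement is returned as text for review rather than executed. Names containing a single quote are rejected rather than escaped, to avoid guessing at quoting rules.

```python
from typing import Optional


def control_table_update(src_table_fqn: str, partition_col: Optional[str],
                         partition_size_mb: int = 2048) -> str:
    """Build the control_table UPDATE that this notebook prints (hypothetical helper)."""
    # Values are embedded as string literals, so reject embedded single quotes
    for value in (src_table_fqn, partition_col or ""):
        if "'" in value:
            raise ValueError(f"Unexpected single quote in {value!r}")
    if partition_col is None:
        set_clause = "load_partitioned = false"  # no suitable partition column
    else:
        set_clause = (f"partition_col = '{partition_col}', "
                      f"partition_size_mb = {int(partition_size_mb)}, "
                      "load_partitioned = true")
    return f"UPDATE control_table SET {set_clause} WHERE src_table = '{src_table_fqn}';"


# Hypothetical foreign catalog and table
print(control_table_update("sqlserver_cat.dbo.orders", "OrderID"))
print(control_table_update("sqlserver_cat.dbo.orders", None))
```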

In [0]:
# Configuration Recommendations

print("=" * 60)
print("PARTITION COLUMN RECOMMENDATION")
print("=" * 60)

if SKIP_PROFILING:
    # Already showed PK recommendation
    pass
elif not results_df.empty:
    valid = results_df[results_df['score'] > 0]
    
    if not valid.empty:
        # Find ALL columns that share the top score (results_df is sorted by score)
        best_score = valid.iloc[0]['score']
        top_columns = valid[valid['score'] == best_score]
        best = top_columns.iloc[0]  # Used in the configuration below
    
        if len(top_columns) > 1:
            print(f"\nMULTIPLE OPTIONS (Score: {best_score}/100)")
            print("=" * 60)
            
            for option_num, (_, row) in enumerate(top_columns.iterrows(), start=1):
                print(f"\nOption {option_num}: {row['column_name']}")
                print(f"  • Distinct values: {int(row['distinct_values']):,}")
                print(f"  • Null %: {row['null_percentage']:.2f}%")
                if pd.notna(row.get('skew_ratio')):
                    print(f"  • Skew ratio: {row['skew_ratio']:.1f}x")
                
                # Special indicators
                if pk_column and row['column_name'] == pk_column:
                    print("  • CLUSTERED PRIMARY KEY - Recommended!")
                elif row['distinct_values'] == row['total_rows']:
                    print("  • Unique column - check whether values are sequential")
            
            print("\nRecommendation: choose based on:")
            print("  1. Clustered PK (if available)")
            print("  2. Higher distinct values (more flexibility in partitioning)")
            print("  3. Business logic (predictable growth)")
            print(f"\nThe configuration below uses Option 1 ({best['column_name']}); edit it if you pick another option.")
            
        else:
            # Single best column
            print(f"\n✅ RECOMMENDED: {best['column_name']}")
    else:
        print("\nNo suitable partition columns found!")
        print("\nRecommendations:")
        print("1. Add a clustered primary key (for example, on an IDENTITY column). If a primary key is not an option, add a clustered index on a high-cardinality integer column with no NULL values.")
        print("2. Use non-partitioned ingestion if the dataset is smaller than 10 GB.")
    
    print("\nConfiguration:")
    if not valid.empty:
        print(f"""
UPDATE control_table
SET 
    partition_col = '{best['column_name']}',
    partition_size_mb = 2048,
    load_partitioned = true
WHERE src_table = '{src_catalog}.{src_schema}.{src_table}';
""")
    else:
        print(f"""
UPDATE control_table
SET 
    load_partitioned = false  -- No suitable partition column
WHERE src_table = '{src_catalog}.{src_schema}.{src_table}';
""")

print("\n" + "=" * 60)
print("NOTE: Clustered primary keys are ideal partition columns.")
print("=" * 60)