# ConversionCentral Managed Profiling
Run this notebook from a Databricks Repo so backend deployments control profiling logic.

In [None]:
# Register every job parameter the FastAPI backend may pass in.
# Declaring each widget (empty default) up front lets Databricks jobs safely supply overrides.
for _widget_name in (
    "table_group_id",
    "profile_run_id",
    "data_quality_schema",
    "payload_path",
    "callback_url",
    "callback_token",
    "catalog",
    "schema_name",
    "connection_id",
    "connection_name",
    "system_id",
    "project_key",
    "http_path",
):
    dbutils.widgets.text(_widget_name, "")

from datetime import datetime
import json
import requests
from pyspark.sql import SparkSession

def _read_widget(name: str) -> str:
    """Return the named widget's value with surrounding whitespace removed ("" when unset)."""
    return (dbutils.widgets.get(name) or "").strip()


spark = SparkSession.builder.getOrCreate()

# The two run identifiers are normalized exactly like every other widget, so a
# whitespace-only value is rejected by the required-widget check below instead of
# slipping through and matching no metadata rows later.
table_group_id = _read_widget("table_group_id")
profile_run_id = _read_widget("profile_run_id")
dq_schema = _read_widget("data_quality_schema")

# Optional settings collapse to None when left blank.
payload_path = _read_widget("payload_path") or None
callback_url = _read_widget("callback_url") or None
callback_token = _read_widget("callback_token") or None

connection_catalog = _read_widget("catalog")
connection_schema = _read_widget("schema_name")

if not table_group_id or not profile_run_id:
    raise ValueError("Required widgets missing: table_group_id/profile_run_id")
if not dq_schema:
    raise ValueError("Data quality schema widget is required for profiling runs.")

In [None]:
# Profile the tables registered for this table group and build the result payload.
import re
from contextlib import suppress
from typing import Iterable

from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, BinaryType, MapType, StructType
from pyspark.sql.utils import AnalysisException

# Upper bound on how many columns are profiled per table (applied in _select_profile_columns).
MAX_COLUMNS_TO_PROFILE = 25
# A column whose null ratio is at or above this value is recorded as a "null_ratio" anomaly.
NULL_RATIO_ALERT_THRESHOLD = 0.5
# At or above this null ratio the anomaly severity is "high"; below it the severity is "medium".
HIGH_NULL_RATIO_THRESHOLD = 0.9

def _split_identifier(value: str | None) -> list[str]:
    cleaned = (value or "").replace("`", "").strip()
    if not cleaned:
        return []
    return [segment.strip() for segment in cleaned.split(".") if segment.strip()]

def _catalog_component(value: str | None) -> str | None:
    """Return the leading segment of a multi-part identifier, or ``None`` for fewer than two segments."""
    segments = _split_identifier(value)
    return segments[0] if len(segments) > 1 else None

def _schema_component(value: str | None) -> str | None:
    """Return the trailing segment of an identifier, or ``None`` when it has no segments."""
    segments = _split_identifier(value)
    return segments[-1] if segments else None

def _qualify(*parts: Iterable[str | None]) -> str:
    tokens: list[str] = []
    for part in parts:
        if isinstance(part, (list, tuple)):
            tokens.extend([token for token in part if token])
        elif part:
            tokens.append(part)
    if not tokens:
        raise ValueError("Cannot build a fully qualified identifier with no parts.")
    return ".".join(f"`{token}`" for token in tokens)

# Work out where the data-quality metadata tables live.
metadata_catalog = _catalog_component(dq_schema)
metadata_schema = _schema_component(dq_schema)
if metadata_schema is None:
    raise ValueError("Unable to resolve schema portion of the data quality schema setting.")

# _catalog_component returns None for a single-segment value, so running the `catalog`
# widget through it would discard a bare catalog name such as "main". Take the first
# segment directly instead; a multi-part value still resolves to its leading segment.
_connection_catalog_parts = _split_identifier(connection_catalog)
connection_catalog_clean = _connection_catalog_parts[0] if _connection_catalog_parts else None
connection_schema_clean = _schema_component(connection_schema)

if metadata_catalog is None:
    if connection_catalog_clean:
        metadata_catalog = connection_catalog_clean
    else:
        # Best effort only: if the session cannot report a current catalog, metadata
        # tables are addressed by schema alone (see _metadata_table).
        with suppress(Exception):
            metadata_catalog = spark.catalog.currentCatalog()

def _metadata_table(name: str) -> str:
    """Return the quoted name of ``name`` inside the metadata schema, catalog-qualified when a catalog is known."""
    if metadata_catalog:
        return _qualify(metadata_catalog, metadata_schema, name)
    return _qualify(metadata_schema, name)

def _compile_patterns(mask: str | None) -> list[re.Pattern[str]]:
    if not mask:
        return []
    tokens = [token.strip() for token in re.split(r"[\n,]+", mask) if token.strip()]
    compiled: list[re.Pattern[str]] = []
    for token in tokens:
        escaped = re.escape(token).replace("\\*", ".*").replace("\\%", ".*")
        compiled.append(re.compile(f"^{escaped}$", re.IGNORECASE))
    return compiled

def _matches_pattern(patterns: list[re.Pattern[str]], schema_name: str | None, table_name: str) -> bool:
    if not patterns:
        return False
    candidate_full = ".".join(filter(None, [(schema_name or "").lower(), table_name.lower()])).strip(".")
    short_name = table_name.lower()
    for pattern in patterns:
        if pattern.match(candidate_full) or pattern.match(short_name):
            return True
    return False

def _qualify_data_table(raw_schema: str | None, table_name: str) -> str:
    """Build the backtick-quoted name used to read a data table.

    Resolution order:
      1. a multi-part ``table_name`` is used exactly as given;
      2. a multi-part ``raw_schema`` is prefixed to the table name;
      3. otherwise the connection catalog is combined with the single-part
         schema (or the connection schema when ``raw_schema`` is blank).

    Parts that are missing are simply left out of the result.

    Raises:
        ValueError: (from ``_qualify``) when no part at all is available.
    """
    table_tokens = _split_identifier(table_name)
    if len(table_tokens) >= 2:
        return _qualify(table_tokens)
    schema_tokens = _split_identifier(raw_schema)
    if len(schema_tokens) >= 2:
        return _qualify(schema_tokens + table_tokens)
    # At this point the schema has at most one segment, and that segment is the schema.
    # Reusing it as the catalog as well would produce `schema`.`schema`.`table`.
    schema_part = schema_tokens[0] if schema_tokens else connection_schema_clean
    table_part = table_tokens[0] if table_tokens else table_name
    return _qualify(connection_catalog_clean, schema_part, table_part)

def _select_profile_columns(df) -> list[str]:
    """List the columns to profile, in schema order.

    Binary, map, array and struct columns are skipped, and collection stops once
    MAX_COLUMNS_TO_PROFILE names have been gathered.
    """
    skipped_types = (BinaryType, MapType, ArrayType, StructType)
    chosen: list[str] = []
    eligible = (item for item in df.schema.fields if not isinstance(item.dataType, skipped_types))
    for item in eligible:
        chosen.append(item.name)
        if len(chosen) >= MAX_COLUMNS_TO_PROFILE:
            break
    return chosen

def _record_anomaly(buffer: list[dict[str, str]], table_name: str, column_name: str | None, anomaly_type: str, severity: str, description: str, detected_at: str) -> None:
    buffer.append(
        {
            "table_name": table_name,
            "column_name": column_name,
            "anomaly_type": anomaly_type,
            "severity": severity,
            "description": description,
            "detected_at": detected_at,
        }
    )

# Look up the table group definition and the tables registered under it.
metadata_tables_name = _metadata_table("dq_tables")
group_table_name = _metadata_table("dq_table_groups")

group_lookup = spark.table(group_table_name).where(F.col("table_group_id") == table_group_id)
group_rows = (
    group_lookup.select("name", "profiling_include_mask", "profiling_exclude_mask").limit(1).collect()
)
if not group_rows:
    raise ValueError(f"Table group '{table_group_id}' not found in schema '{dq_schema}'.")
group_details = group_rows[0].asDict()
include_patterns = _compile_patterns(group_details.get("profiling_include_mask"))
exclude_patterns = _compile_patterns(group_details.get("profiling_exclude_mask"))

registered_tables = spark.table(metadata_tables_name).where(F.col("table_group_id") == table_group_id)
table_rows = registered_tables.select("schema_name", "table_name").collect()
if not table_rows:
    raise ValueError(f"No dq_tables rows registered for table_group_id '{table_group_id}'.")

# Keep rows with a table name that pass the include mask (when one is set) and do
# not hit the exclude mask. Rows without a schema fall back to the connection schema.
table_candidates: list[dict[str, str]] = []
for row in table_rows:
    schema_value = (row["schema_name"] or connection_schema_clean or "").strip() or None
    table_value = (row["table_name"] or "").strip()
    if not table_value:
        continue
    passes_include = not include_patterns or _matches_pattern(include_patterns, schema_value, table_value)
    if not passes_include:
        continue
    hits_exclude = bool(exclude_patterns) and _matches_pattern(exclude_patterns, schema_value, table_value)
    if hits_exclude:
        continue
    label = ".".join(part for part in (schema_value, table_value) if part) or table_value
    table_candidates.append(
        {
            "schema_name": schema_value,
            "table_name": table_value,
            "label": label,
        }
    )

if not table_candidates:
    raise ValueError("All candidate tables were filtered out by include/exclude masks.")

# Profile every candidate table: row count plus per-column null ratios.
# All anomalies recorded in this run share a single UTC timestamp.
generated_at = datetime.utcnow().isoformat() + "Z"
anomalies: list[dict[str, str]] = []
table_profiles: list[dict[str, object]] = []
total_rows = 0
profiling_failures = 0
profiling_successes = 0

print(f"Profiling {len(table_candidates)} tables for group {table_group_id}.")
for candidate in table_candidates:
    schema_value = candidate["schema_name"]
    table_value = candidate["table_name"]
    label = candidate["label"]
    qualified_name = _qualify_data_table(schema_value, table_value)
    table_result: dict[str, object] = {
        "table_name": label,
        "qualified_name": qualified_name,
    }
    print(f"-> Scanning {qualified_name}")
    try:
        df = spark.read.table(qualified_name)
    except AnalysisException as exc:
        profiling_failures += 1
        table_result["error"] = str(exc)
        _record_anomaly(anomalies, label, None, "missing_table", "high", f"Spark could not read {qualified_name}: {exc}", generated_at)
        table_profiles.append(table_result)
        continue
    except Exception as exc:
        profiling_failures += 1
        table_result["error"] = str(exc)
        _record_anomaly(anomalies, label, None, "profiling_error", "high", f"Unexpected error while reading {qualified_name}: {exc}", generated_at)
        table_profiles.append(table_result)
        continue

    # DataFrames are evaluated lazily, so scan-time failures surface at the first
    # action rather than at read.table(). Handle them per table, like read failures,
    # so one bad table does not abort the whole run.
    try:
        row_count = df.count()
    except Exception as exc:
        profiling_failures += 1
        table_result["error"] = str(exc)
        _record_anomaly(anomalies, label, None, "profiling_error", "high", f"Unexpected error while counting rows in {qualified_name}: {exc}", generated_at)
        table_profiles.append(table_result)
        continue
    table_result["row_count"] = int(row_count)
    total_rows += row_count

    if row_count == 0:
        # An empty table is reported as a failure with a high-severity anomaly.
        profiling_failures += 1
        _record_anomaly(anomalies, label, None, "empty_table", "high", "Table returned zero rows during profiling.", generated_at)
        table_profiles.append(table_result)
        continue

    profile_columns = _select_profile_columns(df)
    table_result["profiled_columns"] = profile_columns
    full_column_count = len(df.columns)
    if full_column_count > len(profile_columns):
        # Number of columns left out (unsupported type or over the column cap).
        table_result["profiled_columns_truncated"] = full_column_count - len(profile_columns)

    if profile_columns:
        # Backtick-quote each name (doubling embedded backticks) so a column whose
        # name contains a dot is not parsed as a nested-field reference.
        agg_exprs = [
            F.sum(F.when(F.col("`" + col_name.replace("`", "``") + "`").isNull(), 1).otherwise(0)).alias(col_name)
            for col_name in profile_columns
        ]
        try:
            null_counts = df.agg(*agg_exprs).collect()[0].asDict()
        except Exception as exc:
            profiling_failures += 1
            table_result["error"] = str(exc)
            _record_anomaly(anomalies, label, None, "profiling_error", "high", f"Unexpected error while computing null ratios for {qualified_name}: {exc}", generated_at)
            table_profiles.append(table_result)
            continue
        column_null_ratios: dict[str, float] = {}
        for column in profile_columns:
            # row_count is non-zero here, so the division is safe.
            null_ratio = float((null_counts.get(column, 0) or 0) / row_count)
            column_null_ratios[column] = null_ratio
            if null_ratio >= NULL_RATIO_ALERT_THRESHOLD:
                severity = "high" if null_ratio >= HIGH_NULL_RATIO_THRESHOLD else "medium"
                description = f"Null ratio {null_ratio:.2%} exceeds {NULL_RATIO_ALERT_THRESHOLD:.0%} threshold."
                _record_anomaly(anomalies, label, column, "null_ratio", severity, description, generated_at)
        table_result["column_null_ratios"] = column_null_ratios

    # Counted as a success only once every profiling step for the table has finished.
    profiling_successes += 1
    table_profiles.append(table_result)

# Assemble the result payload. The run counts as completed when at least one table was profiled.
if profiling_successes:
    status = "completed"
else:
    status = "failed"

run_diagnostics = {
    "tables_requested": len(table_rows),
    "tables_profiled": profiling_successes,
    "tables_failed": profiling_failures,
    "include_mask_applied": bool(include_patterns),
    "exclude_mask_applied": bool(exclude_patterns),
}
results = {
    "table_group_id": table_group_id,
    "profile_run_id": profile_run_id,
    "table_group_name": group_details.get("name"),
    "status": status,
    "row_count": int(total_rows),
    "anomaly_count": len(anomalies),
    "anomalies": anomalies,
    "generated_at": generated_at,
    "table_profiles": table_profiles,
    "diagnostics": run_diagnostics,
}

summary_line = f"Profiling complete: {profiling_successes} succeeded, {profiling_failures} failed, total rows={total_rows}."
print(summary_line)

In [None]:
# Persist payload and call back into the API

def _resolve_callback_target(base_url: str | None, run_id: str) -> str | None:
    if not base_url:
        return None
    normalized = base_url.strip()
    if not normalized:
        return None
    if "{profile_run_id}" in normalized:
        try:
            return normalized.format(profile_run_id=run_id)
        except (KeyError, ValueError):
            pass
    normalized = normalized.rstrip("/")
    if normalized.endswith("/complete"):
        return normalized
    return f"{normalized}/{run_id}/complete"

# Export the full payload when a destination path was supplied.
if not payload_path:
    print("No payload path supplied; skipping artifact export.")
else:
    serialized_payload = json.dumps(results, indent=2)
    dbutils.fs.put(payload_path, serialized_payload, overwrite=True)
    print(f"Wrote profiling payload to {payload_path}")

# Report completion to the backend. The callback carries only the summary fields and anomalies.
callback_target = _resolve_callback_target(callback_url, profile_run_id)
if not callback_target:
    print("Callback URL not provided; skipping completion POST.")
else:
    headers = {"Content-Type": "application/json"}
    if callback_token:
        headers["Authorization"] = f"Bearer {callback_token}"
    callback_body = {
        field: results[field]
        for field in ("status", "row_count", "anomaly_count", "anomalies")
    }
    response = requests.post(
        callback_target,
        headers=headers,
        json=callback_body,
        timeout=30,
    )
    # A non-2xx response raises here and fails the cell.
    response.raise_for_status()
    print(f"Callback succeeded: {callback_target} ({response.status_code})")
results