In [0]:
%pip install pyyaml

In [None]:
from databricks.sdk import WorkspaceClient
import os, re, yaml
from math import ceil
from collections import defaultdict, deque

def generate_pipeline_gateway_ymls(
    metadata_df,
    output_name: str = "generated",
    destination_catalog_default: str = "poc2_test_catalog",
    destination_schema_default: str = "final_saas_db",
    source_type_default: str = "SQLSERVER",
    pipeline_table_cap: int = 150,
    gateway_table_cap: int = 1000,
    large_table_threshold: int = 50_000_000,
    cdc_applier_timeout_seconds: str = "600",
    cluster_node_type_id: str = "m5d.large",
    cluster_driver_node_type_id: str = "c5a.8xlarge",
    cluster_num_workers: int = 1,
    output_base_dir: str = "resources",
    use_dabs_references: bool = True,
    work_client: WorkspaceClient | None = None,
    debug: bool = False,
) -> dict:
    """
    Generate Databricks Asset Bundle YMLs (gateways & pipelines) from a DataFrame.

    - The DataFrame must contain: server_name, connection_name, database_name,
      schema_name, table_name, row_count, priority_flag
    - Writes two files (parent directories are created if missing):
        <output_base_dir>/gateways/gateway_<output_name>.yml
        <output_base_dir>/pipelines/pipeline_<output_name>.yml
    - Rows are grouped per normalized server name. Within a server, every table
      with priority_flag == 1 gets its own pipeline; the remaining tables are
      packed into "ingestion" pipelines of at most `pipeline_table_cap` tables,
      with tables at or above `large_table_threshold` rows spread one per
      pipeline before the smaller tables are filled in.
    - Gateway and pipeline resource keys are de-duplicated: when a generated
      name would produce a key that is already taken, a numeric suffix
      ("_2", "_3", ...) is appended and a warning is printed, so no resource
      silently overwrites another in the output.

    Args:
        metadata_df: Spark DataFrame with required columns
        output_name: Base name for generated YAML files
        destination_catalog_default: Default destination catalog name (only used
            when use_dabs_references=False; otherwise the catalog is referenced
            through the destination schema resource)
        destination_schema_default: Default destination schema name (or DABS resource key if use_dabs_references=True)
        source_type_default: Default source type (e.g., SQLSERVER)
        pipeline_table_cap: Max tables per ingestion pipeline (must be >= 1)
        gateway_table_cap: Target number of tables per gateway, used to decide
            how many gateways a server gets (must be >= 1)
        large_table_threshold: Row count threshold for large tables
        cdc_applier_timeout_seconds: CDC applier timeout
        cluster_node_type_id: Worker node type
        cluster_driver_node_type_id: Driver node type
        cluster_num_workers: Number of workers
        output_base_dir: Base directory for output files
        use_dabs_references: If True, generate DABS variable references for schemas/catalogs
        work_client: Optional WorkspaceClient instance; a default one is created when omitted
        debug: Enable debug output

    Returns:
        Summary dict with the keys "servers", "gateways", "pipelines", "tables",
        "paths" (dict with "gateways" and "pipelines" file paths) and
        "use_dabs_references".

    Raises:
        ValueError: If pipeline_table_cap or gateway_table_cap is smaller than 1.
    """

    # Fail fast: a cap of 0 would divide by zero below, and a negative
    # pipeline cap would make the bin-filling loop add bins forever.
    if pipeline_table_cap < 1:
        raise ValueError(f"pipeline_table_cap must be >= 1, got {pipeline_table_cap}")
    if gateway_table_cap < 1:
        raise ValueError(f"gateway_table_cap must be >= 1, got {gateway_table_cap}")

    # --- Helpers ---
    def slug(s: str) -> str:
        """Lowercase `s` and reduce it to [a-z0-9_] with single, non-edge underscores."""
        return re.sub(r"_+", "_", re.sub(r"[^a-z0-9_]+", "_", (s or "").lower())).strip("_")

    def normalize_server_name(s: str) -> str:
        """Collapse whitespace runs, trim and lowercase a server name for grouping."""
        s = re.sub(r"\s+", " ", (s or "")).strip()
        return s.lower()

    def write_yml(obj, path):
        """Dump `obj` as YAML to `path`, creating the parent directory if needed."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            yaml.safe_dump(obj, f, sort_keys=False)

    def get_catalog_reference(catalog_name: str) -> str:
        """Generate catalog reference - either DABS variable or direct name"""
        if use_dabs_references:
            # Assume schemas.yml defines: resources.schemas.<schema_name>.catalog_name
            return f"${{resources.schemas.{destination_schema_default}.catalog_name}}"
        return catalog_name

    def get_schema_reference(schema_name: str) -> str:
        """Generate schema reference - either DABS variable or direct name"""
        if use_dabs_references:
            return f"${{resources.schemas.{schema_name}.name}}"
        return schema_name

    def pipeline_key(name: str) -> str:
        """Resource key under resources.pipelines for an ingestion pipeline name."""
        return f"pipeline_{slug(name)}"

    def gateway_key(name: str) -> str:
        """Resource key under resources.pipelines for a gateway name."""
        return f"pipeline_{name}"

    # Gateways and ingestion pipelines both live under `resources.pipelines`
    # (pipelines reference gateways via ${resources.pipelines.<key>.id}), so
    # one shared set guards against any key being emitted twice.
    max_name_len = 100
    taken_resource_keys = set()

    def reserve_name(base_name: str, key_for):
        """Return (name, resource_key), suffixing the name until its key is unused."""
        name = base_name
        attempt = 1
        while key_for(name) in taken_resource_keys:
            attempt += 1
            suffix = f"_{attempt}"
            name = base_name[: max_name_len - len(suffix)] + suffix
        if name != base_name:
            print(f"Warning: resource name '{base_name}' is already in use; renamed to '{name}'.")
        key = key_for(name)
        taken_resource_keys.add(key)
        return name, key

    # --- Derive output file names ---
    gateways_filename  = f"gateway_{output_name}.yml"
    pipelines_filename = f"pipeline_{output_name}.yml"

    gateways_path  = os.path.join(output_base_dir, "gateways",  gateways_filename)
    pipelines_path = os.path.join(output_base_dir, "pipelines", pipelines_filename)

    # Generate references for catalog and schema
    catalog_ref = get_catalog_reference(destination_catalog_default)
    schema_ref = get_schema_reference(destination_schema_default)

    # --- Load metadata from DataFrame ---
    df = metadata_df.select(
        "server_name", "connection_name", "database_name",
        "schema_name", "table_name", "row_count", "priority_flag"
    )

    rows = []
    for r in df.collect():
        server_key  = normalize_server_name(r["server_name"])
        server_slug = slug(server_key)
        rows.append({
            "server_key": server_key,
            "server_slug": server_slug,
            "connection_name": r["connection_name"],
            "source_catalog":  r["database_name"],
            "source_schema":   r["schema_name"],
            "source_table":    r["table_name"],
            # Null counts/flags are treated as 0 (small, non-priority table).
            "row_count":       int(r["row_count"] or 0),
            "priority_flag":   int(r["priority_flag"] or 0),
            "destination_catalog": catalog_ref,
            "destination_schema":  schema_ref,
            "source_type":         source_type_default,
        })

    if debug:
        print(f"üì¶ Loaded {len(rows)} rows from DataFrame")
        if use_dabs_references:
            print(f"üîó Using DABS references: catalog={catalog_ref}, schema={schema_ref}")

    # --- Initialize Databricks client ---
    w = work_client or WorkspaceClient()
    connection_cache = {c.name: c.connection_id for c in w.connections.list()}

    def resolve_connection_id(name: str) -> str:
        """Look up a connection id by name; warn and return a zero UUID placeholder if unknown."""
        cid = connection_cache.get(name)
        if not cid:
            print(f"‚ö†Ô∏è Warning: Connection '{name}' not found in workspace connections.")
            return "00000000-0000-0000-0000-000000000000"
        return cid

    # --- Group rows per server ---
    by_server = defaultdict(list)
    for r in rows:
        by_server[r["server_key"]].append(r)

    server_gateways = {}
    server_pipelines = {}

    # --- Allocation logic ---
    for server_key, items in by_server.items():
        server_slug = items[0]["server_slug"]

        priority_items = [t for t in items if t["priority_flag"] == 1]
        normal_items   = [t for t in items if t["priority_flag"] != 1]

        large_items = [t for t in normal_items if t["row_count"] >= large_table_threshold]
        small_items = [t for t in normal_items if t["row_count"] <  large_table_threshold]

        pipelines = []

        # 1) Dedicated pipelines for priority tables
        for t in priority_items:
            pname = f"final_{server_slug}_prio_{slug(t['source_table'])}"
            pipelines.append({"name": pname[:max_name_len], "tables": [t]})

        # 2) Normal pipelines: distribute large tables evenly, fill small
        normal_count   = len(normal_items)
        # At least one bin per large table, so each bin holds at most one of them.
        # Always >= 1 because the cap is validated to be positive.
        base_pipelines = max(ceil(max(normal_count, 1) / pipeline_table_cap), len(large_items))
        bins = [{"name": f"final_{server_slug}_ingestion_{i+1}", "tables": []} for i in range(base_pipelines)]

        for idx, t in enumerate(large_items):
            bins[idx % len(bins)]["tables"].append(t)

        def large_count(bin_):
            return sum(1 for tt in bin_["tables"] if tt["row_count"] >= large_table_threshold)

        small_q = deque(small_items)
        while small_q:
            # Always feed the emptiest bin, preferring bins without large tables on ties.
            bins.sort(key=lambda b: (len(b["tables"]), large_count(b)))
            b = bins[0]
            if len(b["tables"]) >= pipeline_table_cap:
                # Even the emptiest bin is full: open a new one and retry.
                bname = f"final_{server_slug}_ingestion_{len(bins)+1}"
                bins.append({"name": bname, "tables": []})
                continue
            b["tables"].append(small_q.popleft())

        pipelines.extend(b for b in bins if b["tables"])

        # Fix names/keys now so that colliding pipelines (e.g. priority tables
        # sharing a table name across schemas) do not overwrite each other later.
        for p in pipelines:
            p["name"], p["resource_key"] = reserve_name(p["name"], pipeline_key)

        # 3) Gateways per server (sized for about gateway_table_cap tables each;
        #    whole pipelines are assigned, so the cap is a target, not a hard limit)
        total_tables = sum(len(p["tables"]) for p in pipelines)
        num_gateways = max(1, ceil(total_tables / gateway_table_cap))

        # The first row's connection is used for every gateway of this server.
        base_conn_name   = items[0]["connection_name"]
        resolved_conn_id = resolve_connection_id(base_conn_name)

        gateways = []
        for gidx in range(num_gateways):
            gname, gkey = reserve_name(f"final_{server_slug}_gateway_{gidx+1}", gateway_key)
            gateways.append({
                "resource_key": gkey,
                "name": gname,
                "connection_name": base_conn_name,
                "connection_id":   resolved_conn_id,
                "storage_catalog": catalog_ref,
                "storage_schema":  schema_ref,
                "gateway_storage_name": gname,
                "source_type": source_type_default,
                "assigned_table_count": 0,
            })

        # Assign pipelines evenly across gateways (largest first, to the least-loaded gateway)
        pipelines.sort(key=lambda p: len(p["tables"]), reverse=True)
        for p in pipelines:
            gateways.sort(key=lambda g: g["assigned_table_count"])
            g = gateways[0]
            g["assigned_table_count"] += len(p["tables"])
            p["gateway_ref"] = g["resource_key"]

        server_gateways[server_key]  = gateways
        server_pipelines[server_key] = pipelines

        if debug:
            print(f"üñ•Ô∏è {server_key}: {len(pipelines)} pipelines, {len(gateways)} gateways")

    # --- Build YMLs ---
    gateways_yml = {"resources": {"pipelines": {}}}
    for gateways in server_gateways.values():
        for g in gateways:
            gateways_yml["resources"]["pipelines"][g["resource_key"]] = {
                "name": g["name"],
                "clusters": [{
                    "node_type_id": cluster_node_type_id,
                    "driver_node_type_id": cluster_driver_node_type_id,
                    "num_workers": cluster_num_workers,
                }],
                "gateway_definition": {
                    "connection_name": g["connection_name"],
                    "connection_id":   g["connection_id"],
                    "gateway_storage_catalog": g["storage_catalog"],
                    "gateway_storage_schema":  g["storage_schema"],
                    "gateway_storage_name":    g["gateway_storage_name"],
                    "source_type": g["source_type"],
                },
                "target": g["storage_schema"],
                "continuous": True,
                "catalog": g["storage_catalog"],
            }

    pipelines_yml = {"resources": {"pipelines": {}}}
    for pipes in server_pipelines.values():
        for p in pipes:
            objects = [{
                "table": {
                    "source_catalog": t["source_catalog"],
                    "source_schema":  t["source_schema"],
                    "source_table":   t["source_table"],
                    "destination_catalog": t["destination_catalog"],
                    "destination_schema":  t["destination_schema"],
                }
            } for t in p["tables"]]
            # Every pipeline kept above has at least one table; all rows share
            # the same destination and source type, so the first one is representative.
            first = p["tables"][0]
            pipelines_yml["resources"]["pipelines"][p["resource_key"]] = {
                "name": p["name"],
                "configuration": {
                    "pipelines.cdcApplierFetchMetadataTimeoutSeconds": cdc_applier_timeout_seconds
                },
                "ingestion_definition": {
                    "ingestion_gateway_id": f"${{resources.pipelines.{p['gateway_ref']}.id}}",
                    "objects": objects,
                    "source_type": first["source_type"],
                },
                "target": first["destination_schema"],
                "catalog": first["destination_catalog"],
            }

    # --- Write .yml files ---
    write_yml(gateways_yml,  gateways_path)
    write_yml(pipelines_yml, pipelines_path)

    # --- Print summary ---
    summary = {
        "servers": len(by_server),
        "gateways": sum(len(g) for g in server_gateways.values()),
        "pipelines": sum(len(p) for p in server_pipelines.values()),
        "tables": sum(len(p2["tables"]) for ps in server_pipelines.values() for p2 in ps),
        "paths": {"gateways": gateways_path, "pipelines": pipelines_path},
        "use_dabs_references": use_dabs_references,
    }

    print("\n‚úÖ YAML generation complete:")
    print(f"  üìÅ Gateway file created:  {gateways_path}")
    print(f"  üìÅ Pipeline file created: {pipelines_path}")
    print(f"  üñ•Ô∏è Servers processed:     {summary['servers']}")
    print(f"  üß© Gateways created:      {summary['gateways']}")
    print(f"  üîÑ Pipelines created:     {summary['pipelines']}")
    print(f"  üìä Tables assigned:       {summary['tables']}")
    if use_dabs_references:
        print(f"  üîó Using DABS schema references")
    print()

    return summary

In [None]:
# Example 1: read the table inventory from Unity Catalog and generate the
# gateway/pipeline YAMLs with DABS schema references.
metadata_df = spark.table(
    "jack_demos.pipeline_split.synthetic_table_inventory_refactor"
)

# All generator options in one place; values mirror the function defaults
# except for the output name and debug output.
generation_options = {
    "output_name": "synthetic_table_inventory_refactor",
    "destination_catalog_default": "poc2_test_catalog",
    "destination_schema_default": "final_saas_db",
    "source_type_default": "SQLSERVER",
    "pipeline_table_cap": 150,
    "gateway_table_cap": 1000,
    "large_table_threshold": 50_000_000,
    "output_base_dir": "resources",
    "use_dabs_references": True,  # Use DABS schema references (default)
    "debug": True,
}

result = generate_pipeline_gateway_ymls(metadata_df=metadata_df, **generation_options)

# Show the returned summary dict in the notebook output.
display(result)

# Example 2: Create DataFrame from other sources
# You can now create the DataFrame from CSV, API, or any other source
# df = spark.read.csv("path/to/metadata.csv", header=True)
# result = generate_pipeline_gateway_ymls(
#     metadata_df=df, 
#     output_name="my_ingestion",
#     use_dabs_references=False,  # Use hardcoded catalog/schema names
#     ...
# )