# Semantic Model Generator

## Requirements

In [1]:
!pip install -q uv
!uv pip install -q mssql-python --system

In [2]:
import base64
import contextlib
import hashlib
import io
import json
import os
import re
import struct
import sys
import time
import uuid
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any, Union

import mssql_python
import pandas as pd
import requests

## Helper Functions

In [3]:
def connect_to_warehouse(sql_endpoint: str, database: str):
    """
    Open an Entra ID token-authenticated connection to a Fabric SQL endpoint and return a cursor.

    The access token comes from the notebook runtime (`notebookutils.credentials.getToken`)
    and is handed to the driver as a pre-connect attribute, so no username/password is used.

    Args:
        sql_endpoint: Host name of the warehouse / SQL analytics endpoint (port 1433 is appended).
        database: Name of the database to connect to.

    Returns:
        A cursor created on a new `mssql_python` connection. The connection object itself is
        not returned (NOTE(review): assumes the cursor keeps its connection alive — verify).
    """
    # Pre-connect attribute id under which the driver accepts an access token
    # (SQL_COPT_SS_ACCESS_TOKEN).
    access_token_attr = 1256

    connection_string = (
        f"Server={sql_endpoint},1433;"
        f"Database={database};"
        "Encrypt=yes;"
        "TrustServerCertificate=no;"
    )

    # The token is passed as UTF-16-LE bytes preceded by their byte length (little-endian uint32).
    encoded_token = notebookutils.credentials.getToken("https://database.windows.net").encode("UTF-16-LE")
    token_length = len(encoded_token)
    token_blob = struct.pack(f"<I{token_length}s", token_length, encoded_token)

    connection = mssql_python.connect(
        connection_string,
        attrs_before={access_token_attr: token_blob},
    )
    return connection.cursor()


In [4]:
def load_information_schema(cursor, schemas=("Dim", "Fact")) -> pd.DataFrame:
    """
    Load INFORMATION_SCHEMA.COLUMNS rows for the given schemas into a DataFrame.

    Args:
        cursor: An open DB-API cursor (e.g. the one returned by `connect_to_warehouse`).
        schemas: A single schema name or an iterable of schema names to include
            (default: "Dim" and "Fact").

    Returns:
        DataFrame with one row per column. Its column names come from `cursor.description`.
        Rows are ordered by catalog, schema, table and ordinal position.

    Raises:
        ValueError: If no schema names are supplied. An empty `IN ()` list is not valid T-SQL.
    """
    if isinstance(schemas, str):
        schemas = [schemas]
    schema_names = list(schemas)
    if not schema_names:
        raise ValueError("load_information_schema requires at least one schema name")

    # Double any embedded single quote so names such as O'Brien stay valid T-SQL string literals.
    schema_list = ", ".join("'" + name.replace("'", "''") + "'" for name in schema_names)

    sql = f"""
    SELECT
        [TABLE_CATALOG],
        [TABLE_SCHEMA],
        [TABLE_NAME],
        [COLUMN_NAME],
        [DATA_TYPE],
        [IS_NULLABLE],
        [CHARACTER_MAXIMUM_LENGTH],
        [NUMERIC_PRECISION],
        [NUMERIC_SCALE],
        [ORDINAL_POSITION]
    FROM [INFORMATION_SCHEMA].[COLUMNS]
    WHERE [TABLE_SCHEMA] IN ({schema_list})
    ORDER BY
        [TABLE_CATALOG],
        [TABLE_SCHEMA],
        [TABLE_NAME],
        [ORDINAL_POSITION];
    """

    cursor.execute(sql)
    rows = cursor.fetchall()

    return pd.DataFrame.from_records(
        rows,
        columns=[c[0] for c in cursor.description]
    )


In [5]:
def deterministic_uuid(namespace: str, *values: str) -> str:
    """
    Build a stable UUID string from a namespace and a sequence of values.

    The text "<namespace>:<v1>|<v2>|..." is hashed with MD5 and the 16-byte digest is
    interpreted directly as a UUID. Identical inputs always give identical output, which
    keeps lineage tags and ids stable across regenerations.

    Args:
        namespace: Kind of object the id is for (e.g. 'column', 'relationship', 'table').
        *values: Identifying parts, such as table and column names.

    Returns:
        UUID in canonical string form.

    Examples:
        >>> uuid1 = deterministic_uuid('column', 'customers', 'id')
        >>> uuid2 = deterministic_uuid('column', 'customers', 'id')
        >>> uuid1 == uuid2
        True
        >>> uuid3 = deterministic_uuid('column', 'customers', 'name')
        >>> uuid1 != uuid3
        True
    """
    joined_values = '|'.join(values)
    seed_text = f"{namespace}:{joined_values}"
    digest = hashlib.md5(seed_text.encode('utf-8')).digest()
    return str(uuid.UUID(bytes=digest))

In [6]:
def sql_type_to_tmdl_type(sql_type: str, column_name: str) -> Tuple[str, str]:
    """
    Convert a SQL data type name to a TMDL data type and format string.

    Matching is case-insensitive and ignores surrounding whitespace. It also ignores any
    length/precision suffix, so 'DECIMAL(18, 2)' maps like 'decimal'. Unknown types
    (and a missing type) map to ('string', '').

    Args:
        sql_type: SQL data type (e.g., 'bigint', 'varchar', 'datetime2').
        column_name: Column name. Accepted for API compatibility; not currently used.

    Returns:
        Tuple of (dataType, formatString).

    Examples:
        >>> sql_type_to_tmdl_type('bigint', 'count')
        ('int64', '0')
        >>> sql_type_to_tmdl_type('varchar', 'name')
        ('string', '')
        >>> sql_type_to_tmdl_type('datetime', 'created_at')
        ('dateTime', 'General Date')
        >>> sql_type_to_tmdl_type('decimal', 'amount')
        ('decimal', '0.00')
        >>> sql_type_to_tmdl_type('bit', 'is_active')
        ('boolean', '')
    """
    date_time = ('dateTime', 'General Date')
    whole_number = ('int64', '0')
    fixed_decimal = ('decimal', '0.00')
    floating = ('double', '0.00')
    boolean = ('boolean', '')

    type_map = {
        # Date/time types
        'datetime': date_time,
        'datetime2': date_time,
        'datetimeoffset': date_time,
        'date': date_time,
        'smalldatetime': date_time,
        'timestamp': date_time,
        'timestamptz': date_time,
        'timestamp_tz': date_time,
        'time': ('dateTime', 'Long Time'),
        # Integer types
        'int': whole_number,
        'integer': whole_number,
        'bigint': whole_number,
        'smallint': whole_number,
        'tinyint': whole_number,
        # Fixed-point types
        'decimal': fixed_decimal,
        'numeric': fixed_decimal,
        'money': fixed_decimal,
        'smallmoney': fixed_decimal,
        # Floating-point types
        'float': floating,
        'real': floating,
        'double': floating,
        # Boolean types
        'bit': boolean,
        'boolean': boolean,
    }

    # Drop any "(length)" / "(precision, scale)" suffix before lookup.
    normalized = (sql_type or '').split('(', 1)[0].strip().lower()
    return type_map.get(normalized, ('string', ''))

In [7]:
def determine_summarization(column_name: str, data_type: str) -> str:
    """
    Choose the TMDL summarizeBy value for a column.

    Only columns whose name carries the 'measure__' prefix are summed. Every other column
    is left unsummarized, and aggregation is handled by explicit DAX measures instead.

    Args:
        column_name: Name of the column.
        data_type: SQL data type. Accepted for API compatibility; not consulted.

    Returns:
        'sum' for measure columns, otherwise 'none'.

    Examples:
        >>> determine_summarization('measure__amount', 'decimal')
        'sum'
        >>> determine_summarization('name', 'varchar')
        'none'
        >>> determine_summarization('quantity', 'int')
        'none'
        >>> determine_summarization('year', 'bigint')
        'none'
        >>> determine_summarization('id', 'int')
        'none'
    """
    is_measure_column = column_name.startswith('measure__')
    return 'sum' if is_measure_column else 'none'

In [8]:
def identify_relationships(
    tables_metadata: Dict[str, List[Dict[str, Any]]],
    key_prefixes: Union[str, List[str]] = None,
    exact_match_prefixes: Union[str, List[str]] = None
) -> Tuple[List[Dict[str, str]], List[str]]:
    """
    Derive star-schema relationships (fact * -> dimension 1) from key-column naming.

    Tables are classified by how many key columns they have, where a key column is one
    whose name starts with one of `key_prefixes`:
    - exactly one key column: dimension (that column is its primary key);
    - more than one: fact (its own key plus foreign keys);
    - none: neither. Such tables only appear in the "no relationships" list.

    A fact key column links to a dimension when either of these holds:
    - its name equals the dimension key exactly; or
    - both names reduce to the same base key, meaning the prefix plus the first
      '__'-separated segment. For example, _hk__customer__bill_to and
      _hk__customer__sell_to both reduce to _hk__customer. This supports role-playing
      dimensions.

    Columns starting with an `exact_match_prefixes` entry never get a base key, so they
    only match exactly.

    For each fact/dimension pair, the first matching column (in the fact's column order)
    gives the active relationship. Any further matches are inactive.

    Args:
        tables_metadata: Mapping of table name -> list of column dicts (each with 'COLUMN_NAME').
        key_prefixes: Prefix or prefixes marking key columns (default: ['_hk__', '_wk__']).
        exact_match_prefixes: Prefix or prefixes that disable base-key matching
            (default: ['_wk__ref__']).

    Returns:
        Tuple of (relationship dicts with keys id/from_table/from_column/to_table/to_column/
        is_active, sorted names of tables that ended up with no relationship).

    Examples:
        >>> metadata = {
        ...     'dim__customers': [
        ...         {'COLUMN_NAME': '_hk__customer', 'DATA_TYPE': 'string'},
        ...         {'COLUMN_NAME': 'name', 'DATA_TYPE': 'string'}
        ...     ],
        ...     'fact__sales': [
        ...         {'COLUMN_NAME': '_hk__sale', 'DATA_TYPE': 'string'},
        ...         {'COLUMN_NAME': '_hk__customer', 'DATA_TYPE': 'string'},
        ...         {'COLUMN_NAME': 'amount', 'DATA_TYPE': 'decimal'}
        ...     ]
        ... }
        >>> rels, no_rels = identify_relationships(metadata)
        >>> len(rels)
        1
        >>> rels[0]['from_table']  # Fact is the 'from' (* side)
        'fact__sales'
        >>> rels[0]['to_table']  # Dimension is the 'to' (1 side)
        'dim__customers'
    """
    def _as_prefix_list(value, default):
        # None -> fresh default list; a bare string -> one-element list; anything else as given.
        if value is None:
            return list(default)
        if isinstance(value, str):
            return [value]
        return value

    key_prefixes = _as_prefix_list(key_prefixes, ('_hk__', '_wk__'))
    exact_match_prefixes = _as_prefix_list(exact_match_prefixes, ('_wk__ref__',))

    def _is_key(name: str) -> bool:
        return any(name.startswith(p) for p in key_prefixes)

    def _base_key(name: str) -> Optional[str]:
        # e.g. _hk__customer__bill_to -> _hk__customer; _wk__period -> _wk__period.
        # Exact-match-only columns (and non-key names) have no base key.
        if any(name.startswith(p) for p in exact_match_prefixes):
            return None
        for p in key_prefixes:
            if name.startswith(p):
                return p + name[len(p):].split('__')[0]
        return None

    # Classify tables, keeping fact key columns in their original column order.
    dimensions: Dict[str, str] = {}
    facts: Dict[str, List[Dict[str, Any]]] = {}
    for table, columns in tables_metadata.items():
        keys = [c for c in columns if _is_key(c['COLUMN_NAME'])]
        if len(keys) == 1:
            dimensions[table] = keys[0]['COLUMN_NAME']
        elif keys:
            facts[table] = keys

    relationships: List[Dict[str, str]] = []
    linked_tables = set()

    for dim_table, dim_key in dimensions.items():
        dim_base = _base_key(dim_key)
        for fact_table, fact_keys in facts.items():
            matched_columns = [
                c['COLUMN_NAME']
                for c in fact_keys
                if c['COLUMN_NAME'] == dim_key
                or (dim_base is not None and _base_key(c['COLUMN_NAME']) == dim_base)
            ]
            for position, fact_column in enumerate(matched_columns):
                relationships.append({
                    # Fact column is part of the id so role-playing relationships stay distinct.
                    'id': deterministic_uuid('relationship', fact_table, dim_table, fact_column, dim_key),
                    'from_table': fact_table,     # Fact (* side)
                    'from_column': fact_column,
                    'to_table': dim_table,        # Dimension (1 side)
                    'to_column': dim_key,
                    'is_active': position == 0,
                })
                linked_tables.update((dim_table, fact_table))

    unlinked_tables = sorted(set(tables_metadata) - linked_tables)
    return relationships, unlinked_tables

In [9]:
def generate_column_tmdl(col: Dict[str, Any], table_name: str, schema: str) -> str:
    """
    Generate the TMDL fragment for a single column.

    The column name is always emitted quoted. Any single quote inside it is doubled, which
    is TMDL's escape for quoted object names.

    Args:
        col: Column metadata dictionary with at least 'COLUMN_NAME' and 'DATA_TYPE'.
        table_name: Name of the table. Used in the deterministic lineage tag.
        schema: Schema name. Accepted for API compatibility; not used in the output.

    Returns:
        TMDL string for the column. It is tab-indented and ends with a blank line.

    Examples:
        >>> col = {'COLUMN_NAME': 'id', 'DATA_TYPE': 'bigint'}
        >>> tmdl = generate_column_tmdl(col, 'customers', 'star_schema')
        >>> "column 'id'" in tmdl
        True
        >>> 'dataType: int64' in tmdl
        True
    """
    col_name = col['COLUMN_NAME']
    sql_type = col['DATA_TYPE']

    data_type, format_string = sql_type_to_tmdl_type(sql_type, col_name)
    lineage_tag = deterministic_uuid('column', table_name, col_name)
    summarize_by = determine_summarization(col_name, sql_type)

    # Always quote column names for safety and consistency; '' escapes an embedded quote.
    quoted_name = col_name.replace("'", "''")
    lines = [
        f"\tcolumn '{quoted_name}'",
        f"\t\tdataType: {data_type}",
    ]
    if format_string:
        lines.append(f"\t\tformatString: {format_string}")

    lines.extend([
        f"\t\tlineageTag: {lineage_tag}",
        f"\t\tsourceLineageTag: {col_name}",
        f"\t\tsummarizeBy: {summarize_by}",
        f"\t\tsourceColumn: {col_name}",
        "",
        "\t\tannotation SummarizationSetBy = Automatic",
        "",
    ])

    return "\n".join(lines)

In [10]:
def generate_measure_tmdl(col: Dict[str, Any], table_name: str) -> str:
    """
    Generate the TMDL fragment for an explicit SUM measure built from a measure__ column.

    The measure is named after the column with the leading 'measure__' removed.
    Names are escaped for both languages involved:
    - TMDL measure name: quote doubled ('').
    - DAX table reference: quote doubled ('').
    - DAX column reference in brackets: closing bracket doubled (]]).

    Args:
        col: Column metadata dictionary. 'COLUMN_NAME' is expected to start with 'measure__'.
            'DATA_TYPE' selects the format string.
        table_name: Name of the table that owns the column.

    Returns:
        TMDL string for the measure. It is tab-indented and ends with a newline.

    Examples:
        >>> col = {'COLUMN_NAME': 'measure__total_amount', 'DATA_TYPE': 'decimal'}
        >>> tmdl = generate_measure_tmdl(col, 'sales')
        >>> "measure 'total_amount'" in tmdl
        True
        >>> "SUM('sales'[measure__total_amount])" in tmdl
        True
    """
    col_name = col['COLUMN_NAME']
    prefix = 'measure__'
    measure_name = col_name[len(prefix):] if col_name.startswith(prefix) else col_name

    tmdl_measure_name = measure_name.replace("'", "''")
    dax_table = table_name.replace("'", "''")
    dax_column = col_name.replace("]", "]]")

    # Always quote measure names for safety and consistency
    lines = [
        f"\tmeasure '{tmdl_measure_name}'",
        f"\t\texpression: SUM('{dax_table}'[{dax_column}])",
    ]

    # Format string follows the underlying column's data type
    _, format_string = sql_type_to_tmdl_type(col['DATA_TYPE'], col_name)
    if format_string:
        lines.append(f"\t\tformatString: {format_string}")

    measure_lineage_tag = deterministic_uuid('measure', table_name, col_name)
    lines.append(f"\t\tlineageTag: {measure_lineage_tag}")
    lines.append("")

    return "\n".join(lines)

In [11]:
def generate_table_tmdl(
    table_name: str,
    columns: List[Dict[str, Any]],
    schema: str,
    catalog: str,
    entity_name: Optional[str] = None
) -> str:
    """
    Generate the complete TMDL file content for one Direct Lake table.

    The first line is the generator watermark, which `get_output_directory` uses to tell
    generated files from manually maintained ones. The file then contains:
    - the table header;
    - every column;
    - one SUM measure per 'measure__' column;
    - a Direct Lake entity partition.
    Table and partition names are always quoted, and an embedded quote is doubled.

    Args:
        table_name: Logical name of the table in the model.
        columns: List of column metadata dictionaries, in the order to emit them.
        schema: Schema name of the source table.
        catalog: Catalog name. It selects the 'DirectLake - <catalog>' expression source.
        entity_name: Source table name in the lakehouse (defaults to table_name).

    Returns:
        Complete TMDL string for the table, ending with a blank line.

    Examples:
        >>> cols = [
        ...     {'COLUMN_NAME': 'id', 'DATA_TYPE': 'bigint'},
        ...     {'COLUMN_NAME': 'measure__amount', 'DATA_TYPE': 'decimal'}
        ... ]
        >>> tmdl = generate_table_tmdl('sales', cols, 'star_schema', 'gold')
        >>> "table 'sales'" in tmdl
        True
        >>> "measure 'amount'" in tmdl
        True
        >>> "partition 'sales' = entity" in tmdl
        True
    """
    if entity_name is None:
        entity_name = table_name

    quoted_table = table_name.replace("'", "''")
    table_lineage_tag = deterministic_uuid('table', table_name)

    lines = [
        "/// Generated by generate_semantic_model.py - Do not edit manually",
        f"table '{quoted_table}'",
        f"\tlineageTag: {table_lineage_tag}",
        f"\tsourceLineageTag: [{schema}].[{entity_name}]",
        "",
    ]

    # Columns first, then explicit measures for every measure__ column
    lines.extend(generate_column_tmdl(col, table_name, schema) for col in columns)
    lines.extend(
        generate_measure_tmdl(col, table_name)
        for col in columns
        if col['COLUMN_NAME'].startswith('measure__')
    )

    # Partition (Direct Lake mode) and trailing annotation
    lines.extend([
        f"\tpartition '{quoted_table}' = entity",
        "\t\tmode: directLake",
        "\t\tsource",
        f"\t\t\tentityName: {entity_name}",
        f"\t\t\tschemaName: {schema}",
        f"\t\t\texpressionSource: 'DirectLake - {catalog}'",
        "",
        "\tannotation PBI_ResultMetrics = []",
    ])

    return "\n".join(lines) + "\n\n"


In [12]:
def generate_relationships_tmdl(
    relationships: List[Dict[str, str]],
    manual_relationships: Optional[str] = None,
    assume_referential_integrity: bool = False
) -> str:
    """
    Generate relationships.tmdl content.

    Column references are written as 'Table'.column. The table part is always quoted, with
    an embedded quote doubled. The column part stays bare when it contains only ASCII
    letters, digits and underscores, and is quoted otherwise, so names with spaces or
    punctuation remain valid TMDL.

    Args:
        relationships: List of relationship dictionaries with keys id, from_table,
            from_column, to_table, to_column, and optionally is_active.
        manual_relationships: Optional pre-rendered block of manually maintained
            relationships, appended after the generated ones.
        assume_referential_integrity: Whether to emit relyOnReferentialIntegrity
            (default: False).

    Returns:
        TMDL string for all relationships, ending with a blank line. Returns '' when there
        is nothing to write.

    Examples:
        >>> rels = [{
        ...     'id': '123e4567-e89b-12d3-a456-426614174000',
        ...     'from_table': '_events',
        ...     'from_column': '_uid__customers',
        ...     'to_table': 'customers',
        ...     'to_column': '_uid__customers'
        ... }]
        >>> tmdl = generate_relationships_tmdl(rels)
        >>> 'relationship 123e4567-e89b-12d3-a456-426614174000' in tmdl
        True
        >>> "fromColumn: '_events'._uid__customers" in tmdl
        True
    """
    def _quoted(name: str) -> str:
        return "'" + name.replace("'", "''") + "'"

    def _column_ref(table: str, column: str) -> str:
        is_plain = bool(column) and all(
            ch.isascii() and (ch.isalnum() or ch == '_') for ch in column
        )
        return f"{_quoted(table)}.{column if is_plain else _quoted(column)}"

    lines = []

    for rel in relationships:
        lines.append(f"relationship {rel['id']}")

        # isActive defaults to true in TMDL, so it is only emitted when false
        if 'is_active' in rel and not rel['is_active']:
            lines.append("\tisActive: false")

        if assume_referential_integrity:
            lines.append("\trelyOnReferentialIntegrity")

        lines.append(f"\tfromColumn: {_column_ref(rel['from_table'], rel['from_column'])}")
        lines.append(f"\ttoColumn: {_column_ref(rel['to_table'], rel['to_column'])}")
        lines.append("")

    # Add any preserved manual relationships
    if manual_relationships:
        lines.append(manual_relationships.rstrip())

    # No trailing empty lines in the body, but exactly one blank line at the end of the file
    content = "\n".join(lines).rstrip()
    return content + "\n\n" if content else ""

In [13]:
def generate_model_tmdl(
    table_names: List[str],
    catalog: str,
    preserved_tables: Optional[List[str]] = None,
    culture: str = "en-US",
    source_query_culture: str = "sv-SE",
) -> str:
    """
    Generate model.tmdl content.

    Table references are emitted in this order:
    - generated tables (every name in `table_names` that is not preserved), sorted;
    - a blank line, then the preserved/manual tables in the order given.
    Table names are always quoted, and an embedded quote is doubled.

    Args:
        table_names: Names of all tables in the model.
        catalog: Catalog name. It is used in the PBI_QueryOrder annotation.
        preserved_tables: Manually maintained table names to list after the generated ones.
        culture: Model culture (default "en-US").
        source_query_culture: Culture used for source queries (default "sv-SE").

    Returns:
        TMDL string for the model, ending with a blank line.

    Examples:
        >>> tmdl = generate_model_tmdl(['customers', 'orders'], 'gold')
        >>> 'model Model' in tmdl
        True
        >>> "ref table 'customers'" in tmdl
        True
        >>> "ref table 'orders'" in tmdl
        True
    """
    def _ref(table: str) -> str:
        return "ref table '" + table.replace("'", "''") + "'"

    lines = [
        "model Model",
        f"\tculture: {culture}",
        "\tdefaultPowerBIDataSourceVersion: powerBI_V3",
        "\tdiscourageImplicitMeasures",
        f"\tsourceQueryCulture: {source_query_culture}",
        "\tdataAccessOptions",
        "\t\tlegacyRedirects",
        "\t\treturnErrorValuesAsNull",
        "",
        f'annotation PBI_QueryOrder = ["DirectLake - {catalog}"]',
        "",
        "annotation __PBI_TimeIntelligenceEnabled = 1",
        "",
        'annotation PBI_ProTooling = ["RemoteModeling","DirectLakeOnOneLakeCreatedInDesktop","DirectLakeOnOneLakeInWeb","WebModelingEdit","TMDLView_Desktop","CalcGroup"]',
        ""
    ]

    preserved = list(preserved_tables or [])
    preserved_set = set(preserved)

    # Generated tables: everything not preserved, sorted for stable output
    lines.extend(_ref(t) for t in sorted(t for t in table_names if t not in preserved_set))

    # Preserved/manual tables go last, in their given order (not sorted)
    if preserved:
        lines.append("")  # Blank line before manual section
        lines.extend(_ref(t) for t in preserved)

    return "\n".join(lines) + "\n\n"

In [14]:
def generate_database_tmdl(compatibility_level: int = 1604) -> str:
    """
    Generate database.tmdl content.

    Args:
        compatibility_level: Tabular compatibility level to declare (default 1604, the
            value this generator has always emitted).

    Returns:
        TMDL string for the database, ending with a blank line.

    Examples:
        >>> tmdl = generate_database_tmdl()
        >>> 'database' in tmdl
        True
        >>> 'compatibilityLevel: 1604' in tmdl
        True
    """
    return f"database\n\tcompatibilityLevel: {compatibility_level}\n\n"

In [15]:
def generate_expressions_tmdl(
    catalog: str,
    schema: str,
    direct_lake_url: str,
    preserved_annotations: Optional[str] = None
) -> str:
    """
    Build expressions.tmdl with the shared 'DirectLake - <catalog>' OneLake source expression.

    The M expression opens the OneLake location with AzureStorage.DataLake and hierarchical
    navigation. Any preserved annotations are appended after PBI_IncludeFutureArtifacts,
    separated by a blank line.

    Args:
        catalog: Catalog name. It is used in the expression name and lineage tag.
        schema: Schema name. Accepted for API compatibility; not used in the output.
        direct_lake_url: OneLake URL (e.g. "https://onelake.dfs.fabric.microsoft.com/workspace-guid/lakehouse-guid").
        preserved_annotations: Optional extra annotation text to carry over (e.g. PBI_RemovedChildren).

    Returns:
        TMDL string for expressions, ending with a blank line.

    Examples:
        >>> tmdl = generate_expressions_tmdl('gold', 'star_schema', 'https://onelake.dfs.fabric.microsoft.com/workspace/lakehouse')
        >>> "expression 'DirectLake - gold'" in tmdl
        True
        >>> 'AzureStorage.DataLake' in tmdl
        True
    """
    extra = "" if not preserved_annotations else f"\n\n{preserved_annotations}"
    expression_tag = deterministic_uuid('expression', catalog)

    body = [
        f"expression 'DirectLake - {catalog}' =",
        "\t\tlet",
        f'\t\t    Källa = AzureStorage.DataLake("{direct_lake_url}", [HierarchicalNavigation=true])',
        "\t\tin",
        "\t\t    Källa",
        f"\tlineageTag: {expression_tag}",
        "",
        "\tannotation PBI_IncludeFutureArtifacts = False" + extra,
        "",
        "",
    ]
    return "\n".join(body)

In [16]:
def generate_definition_pbism() -> str:
    """
    Generate definition.pbism content.
    
    Returns:
        JSON string for definition.pbism
        
    Examples:
        >>> pbism = generate_definition_pbism()
        >>> '"version": "4.2"' in pbism
        True
        >>> '"$schema"' in pbism
        True
    """
    return """{
  "$schema": "https://developer.microsoft.com/json-schemas/fabric/item/semanticModel/definitionProperties/1.0.0/schema.json",
  "version": "4.2",
  "settings": {}
}"""

In [17]:
def generate_platform_file(model_name: str) -> str:
    """
    Generate .platform file content.

    The document is serialized with `json.dumps`, so quotes or backslashes in the model
    name are escaped correctly. The layout (two-space indent, non-ASCII kept as-is) matches
    the previous hand-written template.

    Args:
        model_name: Display name of the semantic model. It also seeds the deterministic logicalId.

    Returns:
        JSON string for the .platform file (no trailing newline).

    Examples:
        >>> platform = generate_platform_file('Test Model')
        >>> '"displayName": "Test Model"' in platform
        True
        >>> '"type": "SemanticModel"' in platform
        True
    """
    platform = {
        "$schema": "https://developer.microsoft.com/json-schemas/fabric/gitIntegration/platformProperties/2.0.0/schema.json",
        "metadata": {
            "type": "SemanticModel",
            "displayName": model_name,
        },
        "config": {
            "version": "2.0",
            "logicalId": deterministic_uuid('platform', model_name),
        },
    }
    return json.dumps(platform, indent=2, ensure_ascii=False)

In [18]:
# Project root for resolving output paths
project_root = Path.cwd()


def _extract_manual_relationships(content: str, generated_table_names: set) -> Optional[str]:
    """
    Return the relationship blocks in relationships.tmdl content that touch a non-generated table.

    Each block is cut at the next top-level `relationship`, the next `///` description, or
    the end of the text. The table part of fromColumn/toColumn may take either form:
    - quoted, as in 'My Table'.col, where '' is an escaped quote;
    - bare, as in Table.col.
    Blocks without both references are skipped.

    Returns:
        The stripped manual blocks joined by blank lines, or None when there are none.
    """
    block_pattern = re.compile(
        r"^relationship[ \t].*?(?=^relationship[ \t]|^///|\Z)", re.DOTALL | re.MULTILINE
    )
    table_ref = r"[ \t]*(?:'((?:[^']|'')*)'|([^\s.']+))\."
    from_pattern = re.compile(r"^[ \t]*fromColumn:" + table_ref, re.MULTILINE)
    to_pattern = re.compile(r"^[ \t]*toColumn:" + table_ref, re.MULTILINE)

    def _table_name(match) -> str:
        quoted, bare = match.group(1), match.group(2)
        return quoted.replace("''", "'") if quoted is not None else bare

    manual_blocks = []
    for block_match in block_pattern.finditer(content):
        block = block_match.group(0)
        from_match = from_pattern.search(block)
        to_match = to_pattern.search(block)
        if from_match is None or to_match is None:
            continue
        # Only relationships between two generated tables are regenerated; keep the rest.
        if (_table_name(from_match) not in generated_table_names
                or _table_name(to_match) not in generated_table_names):
            manual_blocks.append(block.strip())

    return "\n\n".join(manual_blocks) if manual_blocks else None


def _extract_expression_annotations(content: str) -> List[str]:
    """Return single-line annotation lines that follow PBI_IncludeFutureArtifacts in expressions.tmdl."""
    include_future = re.search(r'annotation\s+PBI_IncludeFutureArtifacts\s*=\s*\w+', content)
    if include_future is None:
        return []
    tail = content[include_future.end():]
    # Indented lines of the form "annotation <Name> = <value>"
    annotation_pattern = r'([ \t]+annotation\s+\S+\s*=\s*[^\n]+)'
    return [m.group(1) for m in re.finditer(annotation_pattern, tail) if m.group(1).strip()]


def get_output_directory(
    model_name: str,
    output_dir: Optional[str] = None,
    script_dir: Optional[Path] = None
) -> Tuple[Path, List[str], Optional[str], Optional[str]]:
    """
    Resolve the semantic model output directory and clear generated files from it. Manual content is kept.

    Table files whose text starts with the generator watermark are treated as generated and
    deleted. All other .tmdl table files are kept, and their names are returned (sorted, so
    output is stable across runs).

    When the tables folder exists, two more things are carried over:
    - Relationships in relationships.tmdl that reference at least one non-generated table
      (quoted or bare table names).
    - Extra annotations in expressions.tmdl that follow PBI_IncludeFutureArtifacts.

    Args:
        model_name: Name of the semantic model (directory is "<model_name>.SemanticModel").
        output_dir: Optional output directory, absolute or relative to script_dir.
            Defaults to "<script_dir>/builtin".
        script_dir: Base directory for relative paths (defaults to the current working directory).

    Returns:
        Tuple of (model directory path, preserved table names, preserved manual relationships
        text or None, preserved expression annotations text or None).
    """
    script_dir = Path.cwd() if script_dir is None else Path(script_dir)

    if output_dir:
        base_dir = Path(output_dir)
        if not base_dir.is_absolute():
            base_dir = script_dir / output_dir
    else:
        base_dir = script_dir / "builtin"
    result_path = base_dir / f"{model_name}.SemanticModel"

    preserved_table_names: List[str] = []
    preserved_relationships: Optional[str] = None
    preserved_expr_annotations: Optional[str] = None
    watermark = "/// Generated by generate_semantic_model.py"

    definition_dir = result_path / "definition"
    tables_dir = definition_dir / "tables"

    if result_path.exists() and tables_dir.exists():
        generated_table_names = set()

        # Sorted so preserved tables come back in a deterministic order on every filesystem.
        for table_file in sorted(tables_dir.iterdir()):
            if table_file.suffix != ".tmdl" or not table_file.is_file():
                continue
            if table_file.read_text(encoding="utf-8").startswith(watermark):
                # Generated file - remember its name and delete it
                generated_table_names.add(table_file.stem)
                table_file.unlink()
            else:
                # Manual file - preserve it
                preserved_table_names.append(table_file.stem)

        rel_path = definition_dir / "relationships.tmdl"
        if rel_path.exists():
            preserved_relationships = _extract_manual_relationships(
                rel_path.read_text(encoding="utf-8"), generated_table_names
            )

        expr_path = definition_dir / "expressions.tmdl"
        if expr_path.exists():
            annotations = _extract_expression_annotations(expr_path.read_text(encoding="utf-8"))
            if annotations:
                preserved_expr_annotations = "\n\n".join(annotations)
                print(f"Preserved {len(annotations)} expression annotation(s)")

        if preserved_table_names:
            print(f"Preserved {len(preserved_table_names)} manually maintained table(s): {', '.join(preserved_table_names)}")

    # Ensure directory structure exists
    result_path.mkdir(parents=True, exist_ok=True)
    tables_dir.mkdir(parents=True, exist_ok=True)

    return result_path, preserved_table_names, preserved_relationships, preserved_expr_annotations


In [19]:
def save_semantic_model_files(
    output_dir: Path,
    tables_metadata: Dict[str, List[Dict[str, Any]]],
    relationships: List[Dict[str, str]],
    catalog: str,
    model_name: str,
    direct_lake_url: str,
    table_schemas: Dict[str, str],
    table_entities: Dict[str, str],
    preserved_tables: Optional[List[str]] = None,
    preserved_relationships: Optional[str] = None,
    preserved_expr_annotations: Optional[str] = None,
    assume_referential_integrity: bool = False
) -> None:
    """
    Save all generated TMDL files to disk.

    Files written (existing files with the same names are overwritten):
        <output_dir>/definition.pbism
        <output_dir>/.platform
        <output_dir>/definition/database.tmdl
        <output_dir>/definition/model.tmdl
        <output_dir>/definition/expressions.tmdl
        <output_dir>/definition/relationships.tmdl
        <output_dir>/definition/tables/<table_name>.tmdl   (one per tables_metadata key)

    Args:
        output_dir: Output directory path (the semantic model folder)
        tables_metadata: Dictionary mapping table names to column metadata
        relationships: List of relationship dictionaries
        catalog: Catalog name
        model_name: Model name (passed to the .platform generator)
        direct_lake_url: Direct Lake URL
        table_schemas: Mapping of table_name -> schema
        table_entities: Mapping of table_name -> source entity/table name
        preserved_tables: Optional names of manually maintained tables, passed to generate_model_tmdl
        preserved_relationships: Optional manually maintained relationship TMDL text, passed to
            generate_relationships_tmdl so manual relationships are kept in relationships.tmdl
        preserved_expr_annotations: Optional preserved expression annotations (e.g., PBI_RemovedChildren)
        assume_referential_integrity: Whether to add relyOnReferentialIntegrity to relationships (default: False)
    """
    def _display_path(path: Path) -> str:
        # Prefer a path relative to the notebook-level `project_root` global. That name is not
        # defined in this cell: if it is missing (NameError) or `path` lies outside it
        # (ValueError), fall back to the absolute resolved path.
        try:
            return path.resolve().relative_to(project_root.resolve()).as_posix()
        except Exception:
            return str(path.resolve())

    print(f"\nGenerating semantic model files in {_display_path(output_dir)}...")

    # Create directory structure
    definition_dir = output_dir / "definition"
    tables_dir = definition_dir / "tables"
    tables_dir.mkdir(parents=True, exist_ok=True)

    # Generate and save definition.pbism
    pbism_path = output_dir / "definition.pbism"
    pbism_path.write_text(generate_definition_pbism(), encoding="utf-8")
    print(f"  Created {_display_path(pbism_path)}")

    # Generate and save .platform
    platform_path = output_dir / ".platform"
    platform_path.write_text(generate_platform_file(model_name), encoding="utf-8")
    print(f"  Created {_display_path(platform_path)}")

    # Generate and save database.tmdl
    db_path = definition_dir / "database.tmdl"
    db_path.write_text(generate_database_tmdl(), encoding="utf-8")
    print(f"  Created {_display_path(db_path)}")

    # Generate and save model.tmdl
    model_path = definition_dir / "model.tmdl"
    table_names = list(tables_metadata.keys())
    model_path.write_text(generate_model_tmdl(table_names, catalog, preserved_tables), encoding="utf-8")
    print(f"  Created {_display_path(model_path)}")

    # Generate and save expressions.tmdl
    expr_path = definition_dir / "expressions.tmdl"
    expr_path.write_text(generate_expressions_tmdl(catalog, None, direct_lake_url, preserved_expr_annotations), encoding="utf-8")
    print(f"  Created {_display_path(expr_path)}")

    # Generate and save relationships.tmdl (with preserved manual relationships)
    rel_path = definition_dir / "relationships.tmdl"
    if preserved_relationships:
        print(f"  Preserving {len(preserved_relationships.splitlines())} line(s) of manual relationships")

    rel_path.write_text(generate_relationships_tmdl(relationships, preserved_relationships, assume_referential_integrity), encoding="utf-8")
    print(f"  Created {_display_path(rel_path)}")

    # Generate and save one TMDL file per table; table files are written silently.
    total_measures = 0
    for table_name, columns in tables_metadata.items():
        table_path = tables_dir / f"{table_name}.tmdl"
        schema = table_schemas.get(table_name)
        entity_name = table_entities.get(table_name, table_name)
        tmdl_content = generate_table_tmdl(table_name, columns, schema, catalog, entity_name=entity_name)
        table_path.write_text(tmdl_content, encoding="utf-8")
        # Columns named 'measure__*' are counted as measures for the summary below.
        total_measures += sum(1 for col in columns if col['COLUMN_NAME'].startswith('measure__'))

    print("\n✓ Semantic model generated successfully!")
    print(f"  Total tables: {len(tables_metadata)}")
    print(f"  Total relationships: {len(relationships)}")
    print(f"  Total measures: {total_measures}")


In [20]:
def build_tables_metadata_from_df(
    df,
    catalog: str,
    schemas
) -> Tuple[Dict[str, List[Dict[str, Any]]], Dict[str, str], Dict[str, str]]:
    """
    Build table metadata from INFORMATION_SCHEMA.COLUMNS DataFrame.

    Rows are kept when TABLE_CATALOG equals ``catalog`` and TABLE_SCHEMA is one of
    ``schemas``, both compared case-insensitively. Table ids are ``"<schema>.<table>"``
    when more than one schema was requested, or when the same TABLE_NAME occurs under
    more than one schema. Otherwise the bare table name is used.

    Args:
        df: DataFrame shaped like INFORMATION_SCHEMA.COLUMNS. Requires TABLE_CATALOG,
            TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE and ORDINAL_POSITION.
            IS_NULLABLE (default "YES"), CHARACTER_MAXIMUM_LENGTH, NUMERIC_PRECISION
            and NUMERIC_SCALE (default None) are optional.
        catalog: Catalog (database) name to keep.
        schemas: A schema name, or any iterable of schema names.

    Returns:
        tables_metadata: table_id -> list of column dicts, in ORDINAL_POSITION order
        table_schemas: table_id -> schema name
        table_entities: table_id -> source table name
    """
    if isinstance(schemas, str):
        schemas = [schemas]
    else:
        # Materialise once: the schemas are iterated for lower-casing and measured with
        # len() below, which fails for one-shot iterables such as generators.
        schemas = list(schemas)

    schemas_lower = [s.lower() for s in schemas]

    # Filter to target catalog/schemas, then sort by schema/table + ordinal position so
    # columns come out in their declared order. sort_values returns a new frame, so the
    # caller's DataFrame is never modified.
    in_scope = (
        (df["TABLE_CATALOG"].str.lower() == catalog.lower())
        & (df["TABLE_SCHEMA"].str.lower().isin(schemas_lower))
    )
    df_filtered = df[in_scope].sort_values(["TABLE_SCHEMA", "TABLE_NAME", "ORDINAL_POSITION"])

    # Qualify table ids with the schema when several schemas are in play or a table name
    # would otherwise collide across schemas.
    name_counts = df_filtered.groupby("TABLE_NAME")["TABLE_SCHEMA"].nunique()
    has_collisions = bool((name_counts > 1).any())
    multi_schema = len(schemas) > 1
    qualify_ids = multi_schema or has_collisions

    tables_metadata: Dict[str, List[Dict[str, Any]]] = {}
    table_schemas: Dict[str, str] = {}
    table_entities: Dict[str, str] = {}

    for _, row in df_filtered.iterrows():
        schema = row["TABLE_SCHEMA"]
        table_name = row["TABLE_NAME"]
        table_id = f"{schema}.{table_name}" if qualify_ids else table_name

        tables_metadata.setdefault(table_id, []).append({
            "COLUMN_NAME": row["COLUMN_NAME"],
            "DATA_TYPE": row["DATA_TYPE"],
            "IS_NULLABLE": row.get("IS_NULLABLE", "YES"),
            "CHARACTER_MAXIMUM_LENGTH": row.get("CHARACTER_MAXIMUM_LENGTH"),
            "NUMERIC_PRECISION": row.get("NUMERIC_PRECISION"),
            "NUMERIC_SCALE": row.get("NUMERIC_SCALE"),
            "ORDINAL_POSITION": row.get("ORDINAL_POSITION"),
        })
        table_schemas[table_id] = schema
        table_entities[table_id] = table_name

    return tables_metadata, table_schemas, table_entities


In [21]:
def classify_tables_by_keys(tables_metadata: Dict[str, List[Dict[str, Any]]], key_prefixes) -> Tuple[List[str], List[str], List[str]]:
    """
    Split tables into (dims, facts, others) by how many key columns each one has.

    A key column is a column whose COLUMN_NAME starts with any of ``key_prefixes``.
    ``key_prefixes`` may be a single string, a list of strings, or None (meaning
    ['_hk__', '_wk__']). Exactly one key column makes a dimension, more than one a
    fact, and none an "other" table. Each returned list is sorted by table name.
    """
    if isinstance(key_prefixes, str):
        prefixes = [key_prefixes]
    elif key_prefixes is None:
        prefixes = ['_hk__', '_wk__']
    else:
        prefixes = key_prefixes

    def _is_key(column: Dict[str, Any]) -> bool:
        return any(column['COLUMN_NAME'].startswith(p) for p in prefixes)

    buckets: Dict[str, List[str]] = {"dim": [], "fact": [], "other": []}
    for name, table_columns in tables_metadata.items():
        n_keys = sum(1 for column in table_columns if _is_key(column))
        if n_keys == 0:
            bucket = "other"
        elif n_keys == 1:
            bucket = "dim"
        else:
            bucket = "fact"
        buckets[bucket].append(name)

    return sorted(buckets["dim"]), sorted(buckets["fact"]), sorted(buckets["other"])


def generate_diagram_layout_stub(
    tables_metadata: Dict[str, List[Dict[str, Any]]],
    key_prefixes,
    table_width: int = 220,
    table_height: int = 140,
    x_gap: int = 40,
    y_gap: int = 40
) -> Dict[str, Any]:
    """
    Build a simple grid-style diagram layout for the model's tables.

    Tables are classified with classify_tables_by_keys and placed as follows:
      * dimensions: left to right along the top row (y = 0);
      * facts: top to bottom in one column to the right of the dimension row, with one
        empty column slot left as a gutter;
      * others: left to right in a row below the bottom of the fact column, with one
        empty row slot left as a gutter.

    Every node gets the same ``table_width`` x ``table_height`` size.

    NOTE(review): the returned shape ({"version": 1, "tables": [...]}) is this notebook's
    own stub format. Confirm it matches what Fabric/Power BI expect for diagramLayout.json.
    """
    dims, facts, others = classify_tables_by_keys(tables_metadata, key_prefixes)
    step_x = table_width + x_gap
    step_y = table_height + y_gap

    def _place(names, start_x, start_y, dx, dy):
        """Lay out `names` one after another; return the nodes plus the cursor after the last one."""
        nodes = []
        cur_x, cur_y = start_x, start_y
        for name in names:
            nodes.append({
                "name": name,
                "x": cur_x,
                "y": cur_y,
                "width": table_width,
                "height": table_height
            })
            cur_x += dx
            cur_y += dy
        return nodes, cur_x, cur_y

    dim_nodes, dims_end_x, _ = _place(dims, 0, 0, step_x, 0)

    fact_column_x = max(dims_end_x, step_x) + step_x
    fact_nodes, _, facts_end_y = _place(facts, fact_column_x, 0, 0, step_y)

    others_row_y = max(facts_end_y, step_y) + step_y
    other_nodes, _, _ = _place(others, 0, others_row_y, step_x, 0)

    return {
        "version": 1,
        "tables": dim_nodes + fact_nodes + other_nodes
    }


def write_diagram_layout_json(
    output_dir: Path,
    tables_metadata: Dict[str, List[Dict[str, Any]]],
    key_prefixes
) -> Path:
    """
    Serialise the stub diagram layout to ``<output_dir>/diagramLayout.json``.

    The layout comes from generate_diagram_layout_stub and is written as UTF-8 JSON
    indented by two spaces, overwriting any existing file. ``output_dir`` must already
    exist. Returns the path of the written file.
    """
    layout_doc = json.dumps(generate_diagram_layout_stub(tables_metadata, key_prefixes), indent=2)
    target = output_dir / "diagramLayout.json"
    target.write_text(layout_doc, encoding="utf-8")
    return target


def build_definition_payload(model_dir: Path) -> Dict[str, Any]:
    """
    Build the Fabric REST API definition payload from a SemanticModel folder.

    Every file under ``model_dir`` (recursively, in sorted path order) becomes one part:
    its POSIX path relative to ``model_dir``, its bytes base64-encoded, and payloadType
    "InlineBase64". Any file named ``definition_payload.json`` is skipped, so a
    previously saved payload is not embedded into the new one.
    """
    def _encode(file_path: Path) -> str:
        return base64.b64encode(file_path.read_bytes()).decode('utf-8')

    files = [
        candidate
        for candidate in sorted(model_dir.rglob('*'))
        if candidate.is_file() and candidate.name != "definition_payload.json"
    ]

    return {
        "parts": [
            {
                "path": f.relative_to(model_dir).as_posix(),
                "payload": _encode(f),
                "payloadType": "InlineBase64"
            }
            for f in files
        ]
    }


In [22]:
def create_semantic_model_in_fabric(
    workspace_id: str,
    display_name: str,
    definition: Dict[str, Any],
    description: Optional[str] = None,
    folder_id: Optional[str] = None,
    poll_interval_seconds: int = 5,
    max_wait_seconds: Optional[float] = 1800,
    request_timeout_seconds: Optional[float] = 60
) -> Dict[str, Any]:
    """
    Create a semantic model via the Fabric REST API and return the result.

    POSTs to ``/v1/workspaces/{workspace_id}/semanticModels``:
      * 201 Created: the response JSON is returned directly.
      * 202 Accepted: the ``Location`` header is polled until the operation reports a
        terminal status (Succeeded, Failed, Canceled or Cancelled), and that final JSON
        is returned. Failed/Canceled states are returned rather than raised, so callers
        must inspect ``status``.

    Args:
        workspace_id: Target workspace id.
        display_name: Display name of the new semantic model.
        definition: Definition payload (``{"parts": [...]}``).
        description: Optional item description.
        folder_id: Optional workspace folder id.
        poll_interval_seconds: Delay between polls when the service sends no numeric
            ``Retry-After`` header.
        max_wait_seconds: Stop polling and raise after roughly this many seconds without
            a terminal status. None polls indefinitely.
        request_timeout_seconds: Per-request timeout passed to ``requests``. None means
            no timeout.

    Returns:
        The created item JSON (201) or the terminal operation-state JSON (202).

    Raises:
        RuntimeError: On a create response other than 201/202, a 202 without a
            ``Location`` header, or an HTTP status >= 400 while polling.
        TimeoutError: If no terminal status is seen within ``max_wait_seconds``.
    """
    terminal_states = ("Succeeded", "Failed", "Canceled", "Cancelled")

    def _poll_delay(resp) -> float:
        # Honour a numeric Retry-After hint from the service. Fall back to the configured
        # interval when the header is absent or not a number (e.g. an HTTP-date).
        raw = resp.headers.get("Retry-After")
        try:
            return max(0.0, float(raw))
        except (TypeError, ValueError):
            return poll_interval_seconds

    token = notebookutils.credentials.getToken("https://api.fabric.microsoft.com")
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json"
    }

    payload = {
        "displayName": display_name,
        "definition": definition
    }
    if description:
        payload["description"] = description
    if folder_id:
        payload["folderId"] = folder_id

    url = f"https://api.fabric.microsoft.com/v1/workspaces/{workspace_id}/semanticModels"
    response = requests.post(url, headers=headers, json=payload, timeout=request_timeout_seconds)

    if response.status_code == 201:
        return response.json()

    if response.status_code != 202:
        raise RuntimeError(f"Create failed: {response.status_code} {response.text}")

    location = response.headers.get("Location")
    if not location:
        raise RuntimeError("202 Accepted without Location header")

    # Poll the long-running operation until it reaches a terminal state or the deadline passes.
    deadline = None if max_wait_seconds is None else time.monotonic() + max_wait_seconds
    last_response = response
    while True:
        time.sleep(_poll_delay(last_response))
        lro = requests.get(location, headers=headers, timeout=request_timeout_seconds)
        if lro.status_code >= 400:
            raise RuntimeError(f"LRO failed: {lro.status_code} {lro.text}")
        data = lro.json() if lro.content else {}
        status = data.get("status") or data.get("provisioningState")
        if status in terminal_states:
            return data
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Semantic model creation did not finish within {max_wait_seconds}s "
                f"(last status: {status!r}, operation: {location})"
            )
        last_response = lro


In [23]:
def generate_semantic_model(
    metadata_df,
    catalog: str,
    schemas,
    model_name: str,
    direct_lake_url: str,
    output_dir: Optional[str] = None,
    key_prefixes: Optional[Union[str, List[str]]] = None,
    exact_match_prefixes: Optional[Union[str, List[str]]] = None,
    assume_referential_integrity: bool = False
) -> Path:
    """
    Main workflow function to generate a complete semantic model.

    Automatically preserves manually maintained content (non-watermarked files).

    Relationship detection:
    - Dimension: Tables with exactly 1 key column (their primary key)
    - Fact: Tables with >1 key columns (their PK + foreign keys to dimensions)
    - Creates 1:* relationships from dimension to fact

    Args:
        metadata_df: DataFrame holding INFORMATION_SCHEMA.COLUMNS
        catalog: Database catalog name (e.g., "gold")
        schemas: Schema name or list of schemas (e.g., "Dim" or ["Dim", "Fact"])
        model_name: Name for the semantic model
        direct_lake_url: Direct Lake URL for OneLake connection
        output_dir: Optional output directory path
        key_prefixes: Column prefix(es) to identify relationship keys. Can be a string or list of strings (default: ['_hk__', '_wk__'])
        exact_match_prefixes: Column prefix(es) that should use exact matching only, not role-playing pattern (default: ['_wk__ref__'])
        assume_referential_integrity: Whether to add relyOnReferentialIntegrity to relationships (default: False)

    Returns:
        The model folder path returned by get_output_directory, into which
        diagramLayout.json and all TMDL files were written.
    """
    # Normalise prefix arguments to lists (None -> defaults, str -> single-item list).
    if key_prefixes is None:
        key_prefixes = ['_hk__', '_wk__']
    elif isinstance(key_prefixes, str):
        key_prefixes = [key_prefixes]

    if exact_match_prefixes is None:
        exact_match_prefixes = ['_wk__ref__']
    elif isinstance(exact_match_prefixes, str):
        exact_match_prefixes = [exact_match_prefixes]

    # Get output directory (auto-preserves non-watermarked content)
    script_dir = Path.cwd()
    output_path, preserved_table_names, preserved_relationships, preserved_expr_annotations = get_output_directory(model_name, output_dir, script_dir)

    # Build column metadata from information schema dataframe
    print(f"\nBuilding tables from {catalog} information schema...")
    tables_metadata, table_schemas, table_entities = build_tables_metadata_from_df(metadata_df, catalog, schemas)
    tables = list(tables_metadata.keys())
    print(f"Found {len(tables)} tables from information schema")

    # Identify relationships (1 key col = dim, >1 = fact)
    prefixes_str = ', '.join(f"'{p}'" for p in key_prefixes)
    print(f"\nIdentifying relationships (key prefixes: {prefixes_str})...")
    relationships, tables_without_relationships = identify_relationships(
        tables_metadata,
        key_prefixes=key_prefixes,
        exact_match_prefixes=exact_match_prefixes
    )
    print(f"Found {len(relationships)} relationships")

    # Print tables with no relationships
    if tables_without_relationships:
        print(f"\nWarning: {len(tables_without_relationships)} table(s) have no relationships detected:")
        for table in tables_without_relationships:
            print(f"  - {table}")
    else:
        print("\nAll tables have relationships.")

    # Write diagram layout (best-effort), then all TMDL/metadata files
    write_diagram_layout_json(output_path, tables_metadata, key_prefixes)
    save_semantic_model_files(
        output_path,
        tables_metadata,
        relationships,
        catalog,
        model_name,
        direct_lake_url,
        table_schemas,
        table_entities,
        preserved_table_names,
        preserved_relationships,
        preserved_expr_annotations,
        assume_referential_integrity
    )

    return output_path


## Execute

In [24]:
import sempy.fabric as sf

# Target workspace / warehouse item ids and the display name of the model to create.
workspace_id = "e0573fbd-c1f4-4993-afa3-320620c17110"
warehouse_id = "9a3d14f5-8ce7-49e4-b79d-3ab8b834e16e"
model_name = "One Model To Rule Them All (Chris)"

# OneLake location of the warehouse, passed to the generator as the Direct Lake URL.
direct_lake_url = "/".join(["https://onelake.dfs.fabric.microsoft.com", workspace_id, warehouse_id])

In [25]:
# Connect to warehouse: open a cursor on the WH_Gold database via its Fabric SQL endpoint
database = "WH_Gold"
sql_endpoint = "5dscrplirguurh56sz3vy633zu-xu7vpyhuygjutl5dgidcbqlrca.datawarehouse.fabric.microsoft.com"
cursor = connect_to_warehouse(sql_endpoint=sql_endpoint, database=database)


In [26]:
# Load INFORMATION_SCHEMA.COLUMNS for the Dim and Fact schemas into a DataFrame
metadata = load_information_schema(cursor=cursor, schemas=["Dim", "Fact"])


In [28]:
# Run generator: writes the semantic model folder under ./builtin and returns its path
model_dir = generate_semantic_model(
    metadata,
    catalog=database,
    schemas=["Dim", "Fact"],
    model_name=model_name,
    direct_lake_url=direct_lake_url,
    key_prefixes="ID_",
    output_dir="./builtin",
)



Building tables from WH_Gold information schema...
Found 45 tables from information schema

Identifying relationships (key prefixes: 'ID_')...
Found 72 relationships

  - Dim.AttributeTable
  - Dim.BaseCurrency
  - Dim.BusinessUnit
  - Dim.CurrencyExRateLatest
  - Dim.Date
  - Dim.Fardigvara
  - Dim.FinishedGoods
  - Dim.PL_Link
  - Dim.ProdCodeForInvPart
  - Dim.TechnicalSpecificationGroup
  - Dim.YearMonth
  - Dim.vBaseCurrency
  - Dim.vProdCodeForInvPart
  - Fact.ModuleStatusInterval
  - Fact.StockBalance

Generating semantic model files in /synfs/resource/nb_resource/builtin/One Model To Rule Them All (Chris).SemanticModel...
  Created /synfs/resource/nb_resource/builtin/One Model To Rule Them All (Chris).SemanticModel/definition.pbism
  Created /synfs/resource/nb_resource/builtin/One Model To Rule Them All (Chris).SemanticModel/.platform
  Created /synfs/resource/nb_resource/builtin/One Model To Rule Them All (Chris).SemanticModel/definition/database.tmdl
  Created /synfs/resourc

In [29]:
# Build REST API definition payload and keep a copy next to the model for inspection.
# build_definition_payload skips definition_payload.json, so re-running this cell
# does not embed the previous copy.
definition_payload = build_definition_payload(model_dir)
payload_path = model_dir / "definition_payload.json"
payload_json = json.dumps(definition_payload, indent=2)
payload_path.write_text(payload_json, encoding="utf-8")
print(f"Wrote {payload_path}")


Wrote /home/trusted-service-user/work/builtin/One Model To Rule Them All (Chris).SemanticModel/definition_payload.json


In [30]:
# Create semantic model in Fabric.
# NOTE(review): re-running this cell sends another create request with the same display
# name; confirm whether the service rejects or duplicates it.
description = "Generated from INFORMATION_SCHEMA"

result = create_semantic_model_in_fabric(
    workspace_id=workspace_id,
    display_name=model_name,
    description=description,
    definition=definition_payload
)
print(result)

# create_semantic_model_in_fabric returns terminal LRO states (including Failed/Canceled)
# instead of raising, so fail loudly here rather than letting Run All continue silently.
# A 201 response returns the item JSON, which has no status field and is treated as success.
create_status = result.get("status") or result.get("provisioningState")
if create_status is not None and create_status != "Succeeded":
    raise RuntimeError(f"Semantic model creation ended with status {create_status!r}: {result.get('error')}")


{'status': 'Succeeded', 'createdTimeUtc': '2026-02-06T11:05:56.1997074', 'lastUpdatedTimeUtc': '2026-02-06T11:06:09.5313677', 'percentComplete': 100, 'error': None}
