In [None]:
from typing import Dict, List
from tqdm.notebook import tqdm

class DDLGenerator:
    def _generate_hash_sql(self, hash_config: dict) -> str:

        name = hash_config["name"]
        columns = hash_config["columns"]

        safe_columns = [
            f"COALESCE(NULLIF(UPPER(TRIM(CAST({col} AS STRING))), ''), '^^NULL^^')"  # coalesce to some kind of string that is unlikely to appear, formatting based on AutomateDV
            for col in columns
        ]

        cols_concat = ", '|', ".join(safe_columns)
    
        return f"SHA1(CONCAT({cols_concat})) AS {name}"

    def _generate_staging_DDL(self, staging_map: Dict[str, Stage]) -> Dict[str, str]:
        """Generate the staging DDL statement for every stage in ``staging_map``.

        Each statement creates a materialized lake view that selects all
        source columns (``*``), one hash expression per entry in
        ``stage.hash_keys``, the resolved load-datetime column aliased to
        ``load_datetime`` and the source table name as a ``record_source``
        literal. The statements returned by ``_generate_link_staging_DDL``
        are merged into the result afterwards (and overwrite entries that
        share the same key).

        Args:
            staging_map: Stages to generate DDL for; only the values are used.

        Returns:
            Mapping of ``"<schema_name>.<name>"`` to the DDL string of that view.
        """
        generated_ddl_str: Dict[str, str] = {}

        # Load-datetime column overrides keyed by source table. Satellites are
        # applied after hubs, so a satellite's setting wins for a shared table.
        source_datetime_overrides = {}
        for hub in self._hubs:
            if hub.load_datetime_column:
                source_datetime_overrides[hub.source_table] = hub.load_datetime_column
        for sat in self._sats:
            if sat.load_datetime_column:
                source_datetime_overrides[sat.source_table] = sat.load_datetime_column

        for stage in tqdm(staging_map.values(), desc="Generating staging DDL string", unit="stage", total=len(staging_map)):
            hk_sql_string = [
                self._generate_hash_sql({"name": name, "columns": cols})
                for name, cols in stage.hash_keys.items()
            ]

            # Resolve the load_datetime column from the source table and its
            # override (None when no hub/satellite configured one).
            load_dt_col = self._resolve_load_datetime_column(
                stage.source_table,
                source_datetime_overrides.get(stage.source_table),
            )

            # NOTE: all source columns are passed through with "*"; an explicit
            # passthrough column list is not used here.
            select_columns = [
                "*",
                *hk_sql_string,
                f"{load_dt_col} AS load_datetime",
                f"'{stage.source_table}' AS record_source",
            ]

            columns_sql = ",\n\t\t\t\t".join(select_columns)

            ddl_string = f"""
            CREATE MATERIALIZED LAKE VIEW IF NOT EXISTS `{stage.schema_name}`.`{stage.name}`
            TBLPROPERTIES (delta.enableChangeDataFeed = true)
            AS
            SELECT 
                {columns_sql}
            FROM {stage.source_table}
            """.strip()

            full_name = f"{stage.schema_name}.{stage.name}"
            generated_ddl_str[full_name] = ddl_string

        generated_ddl_str.update(self._generate_link_staging_DDL())

        return generated_ddl_str

    def _generate_vault_entity_DDL(self, *, incremental_ver: bool = True) -> Dict[str, str]:
        """Generate the DDL statements for all hub, link and satellite views.

        Hubs and links are grouped on their stage columns and keep the minimum
        ``load_datetime`` (taken via ``MIN(NAMED_STRUCT(...))``). Links whose
        ``_stage`` is ``None`` are skipped. Satellites that are missing a
        stage, resolved columns or a resolvable parent are reported through
        ``_add_issue`` and skipped.

        Args:
            incremental_ver: When ``True`` (the default, and the only variant
                that was reachable before this flag became a parameter) the
                satellite view groups by parent hash key and hash diff. When
                ``False`` the satellite view instead filters rows with
                ``LAG`` / ``ROW_NUMBER`` window functions.
                NOTE(review): the ``False`` variant is kept verbatim from
                previously unreachable code - validate its SQL before use.

        Returns:
            Mapping of ``"<schema_name>.<name>"`` to the DDL string of that view.
        """
        vault_ddls: Dict[str, str] = {}

        for hub in tqdm(self._hubs, desc="Generating hub DDL string", unit="tables", total=len(self._hubs)):
            hub_columns_sql = ', '.join(hub._stage.columns)
            hub_stage_ref = f"{hub._stage.schema}.{hub._stage.table}"

            ddl_string = f"""
            CREATE MATERIALIZED LAKE VIEW IF NOT EXISTS `{hub.schema_name}`.`{hub.name}`
            TBLPROPERTIES (delta.enableChangeDataFeed = true)
            AS
            SELECT
                {hub_columns_sql},
                MIN(NAMED_STRUCT('ldts', load_datetime, 'rs', '{hub_stage_ref}')).ldts AS load_datetime,
                MIN(NAMED_STRUCT('ldts', load_datetime, 'rs', '{hub_stage_ref}')).rs AS record_source
            FROM {hub_stage_ref}
            WHERE load_datetime IS NOT NULL
                AND record_source IS NOT NULL
            GROUP BY
                {hub_columns_sql}
            """.strip()

            vault_ddls[f"{hub.schema_name}.{hub.name}"] = ddl_string

        for link in tqdm(self._links, desc="Generating link DDL string", unit="tables", total=len(self._links)):
            if link._stage is None:
                continue

            # Collect hub names: anchor hub (if exists) + all join hubs
            hub_names = []
            if link.anchor.hub:
                hub_names.append(link.anchor.hub)
            hub_names.extend(join.hub for join in link.hub_mapping)

            virtual_record_source = ';'.join(hub_names)

            link_hk_name = f"hk_{link.name}"
            hash_sql_string = self._generate_hash_sql({"name": link_hk_name, "columns": link._stage.columns})
            link_columns_sql = ', '.join(link._stage.columns)
            link_stage_ref = f"{link._stage.schema}.{link._stage.table}"

            ddl_string = f"""
            CREATE MATERIALIZED LAKE VIEW IF NOT EXISTS `{link.schema_name}`.`{link.name}`
            TBLPROPERTIES (delta.enableChangeDataFeed = true)
            AS
            SELECT
                {hash_sql_string},
                {link_columns_sql},
                MIN(NAMED_STRUCT('ldts', load_datetime, 'rs', '{link_stage_ref}')).ldts AS load_datetime,
                MIN(NAMED_STRUCT('ldts', load_datetime, 'rs', '{link_stage_ref}')).rs AS record_source,
                '{virtual_record_source}' AS virtual_record_source
            FROM {link_stage_ref}
            WHERE load_datetime IS NOT NULL
                AND record_source IS NOT NULL
            GROUP BY
                {link_hk_name},
                {link_columns_sql}
            """.strip()

            vault_ddls[f"{link.schema_name}.{link.name}"] = ddl_string

        # Looked up once: the parent mapping does not change while iterating.
        all_parents = self._get_all_parents()

        for sat in tqdm(self._sats, desc="Generating satellite DDL string", unit="tables", total=len(self._sats)):
            if sat._stage is None or sat._resolved_columns is None:
                self._add_issue(
                    WarningSeverity.ERROR,
                    "satellite",
                    sat.name,
                    "Satellite not properly processed. Missing stage or resolved columns."
                )
                continue

            parent = all_parents.get(sat.parent_hub_or_link)
            if parent is None:
                # Without a parent there is no hash key to build the view on;
                # record the problem instead of failing on `parent.hash_key_name`.
                self._add_issue(
                    WarningSeverity.ERROR,
                    "satellite",
                    sat.name,
                    f"Parent '{sat.parent_hub_or_link}' could not be resolved. Satellite DDL not generated."
                )
                continue

            # NOTE(review): str.replace drops every 'sat_' occurrence, not only a prefix.
            sat_suffix = sat.name.replace('sat_', '')
            hash_diff_name = f"hd_{sat_suffix}"
            pk_name = f"pk_{sat_suffix}"

            descriptive_cols = sat.resolved_columns

            hash_diff_sql = self._generate_hash_sql({
                "name": hash_diff_name,
                "columns": descriptive_cols
            })

            if incremental_ver:
                # Same normalisation as _generate_hash_sql, but without the
                # trailing alias because the hash is nested inside NAMED_STRUCT.
                safe_columns = [
                    f"COALESCE(NULLIF(UPPER(TRIM(CAST({col} AS STRING))), ''), '^^NULL^^')"
                    for col in [parent.hash_key_name, "load_datetime"]
                ]

                cols_concat = ", '|', ".join(safe_columns)

                pk = f"SHA1(CONCAT({cols_concat}))"

                # One aggregate per descriptive column; the struct's first field
                # is load_datetime, so MIN orders by it before the value.
                payload_sql = ",\n                ".join([
                    f"MIN(NAMED_STRUCT('ldts', load_datetime, 'val', {col})).val AS {col}"
                    for col in descriptive_cols
                ])

                # NOTE(review): unlike the other views this statement has no
                # IF NOT EXISTS / TBLPROPERTIES clause - confirm that is intended.
                ddl_string = f"""
                CREATE MATERIALIZED LAKE VIEW `{sat.schema_name}`.`{sat.name}`
                AS
                SELECT
                    MIN(NAMED_STRUCT('ldts', load_datetime, 'pk', {pk})).pk AS {pk_name},
                    {parent.hash_key_name},
                    {hash_diff_sql},
                    MIN(NAMED_STRUCT('ldts', load_datetime, 'rs', '{sat._stage.schema}.{sat._stage.table}')).ldts AS load_datetime,
                    MIN(NAMED_STRUCT('ldts', load_datetime, 'rs', '{sat._stage.schema}.{sat._stage.table}')).rs AS record_source,
                    {payload_sql}
                FROM {sat._stage.schema}.{sat._stage.table}
                GROUP BY 
                    {parent.hash_key_name}, 
                    {hash_diff_name}
                """.strip()

            else:
                pk_sql = self._generate_hash_sql({
                    "name": pk_name,
                    "columns": [parent.hash_key_name, "load_datetime"]
                })

                select_parts = [
                    pk_sql,
                    parent.hash_key_name,
                    hash_diff_sql,
                    *descriptive_cols
                ]

                ddl_string = f"""
                CREATE MATERIALIZED LAKE VIEW IF NOT EXISTS `{sat.schema_name}`.`{sat.name}`
                TBLPROPERTIES (delta.enableChangeDataFeed = true)
                AS
                SELECT
                    {', '.join(select_parts)},
                    load_datetime,
                    '{sat._stage.schema}.{sat._stage.table}' AS record_source,
                    '{sat.parent_hub_or_link}' AS virtual_record_source
                
                FROM (
                    SELECT
                        *, 
                        LAG({hash_diff_name}) OVER (
                            PARTITION BY {parent.hash_key_name}
                            ORDER BY load_datetime ASC
                        ) AS prev_hash_diff,
                        ROW_NUMBER() OVER (
                            PARTITION BY {parent.hash_key_name}, load_datetime
                            ORDER BY record_source ASC
                        ) AS rn_same_timestamp
                    FROM {sat._stage.schema}.{sat._stage.table}
                    WHERE load_datetime IS NOT NULL
                    AND record_source IS NOT NULL
                )
                WHERE ({hash_diff_name} != prev_hash_diff OR prev_hash_diff IS NULL)
                AND rn_same_timestamp = 1
                """.strip()

            vault_ddls[f"{sat.schema_name}.{sat.name}"] = ddl_string

        return vault_ddls

    def _generate_link_staging_DDL(self) -> Dict[str, str]:
        """Generate the staging DDL statement for every link that has a stage.

        The anchor table is aliased ``t0`` and each entry of
        ``link.hub_mapping`` is LEFT JOINed as ``t1``, ``t2``, ... using its
        ``join_on`` column pairs. For the anchor hub (if any) and each mapped
        hub, a hash expression over the aliased business-key columns is
        selected, followed by the anchor's load-datetime column and a
        ``record_source`` literal. Hubs that are not found in
        ``_get_all_hubs()`` contribute neither a hash column nor a join.

        Returns:
            Mapping of ``"<stage schema>.<stage table>"`` to the DDL string of
            that link staging view.
        """
        ddls: Dict[str, str] = {}
        all_hubs = self._get_all_hubs()

        for link in tqdm(self._links, desc="Generating link staging DDL string", unit="stage", total=len(self._links)):
            if link._stage is None:
                continue

            # Build aliases, hash columns, and joins
            hash_columns = []
            join_clauses = []

            anchor = link.anchor
            anchor_alias = "t0"

            load_dt_col = self._resolve_load_datetime_column(
                anchor.table,  # still resolved against the anchor table
                link.load_datetime_column,  # the override is configured on the link
            )

            # Handle anchor's hub contribution (if any)
            if anchor.hub:
                hub = all_hubs.get(anchor.hub)
                if hub:
                    hash_sql = self._generate_hash_sql({
                        "name": hub.hash_key_name,
                        "columns": [f"{anchor_alias}.{col}" for col in anchor.bk_columns]
                    })
                    hash_columns.append(hash_sql)

            for idx, ref in enumerate(link.hub_mapping):
                hub = all_hubs.get(ref.hub)
                if hub is None:
                    continue

                alias = f"t{idx + 1}"  # Start from t1 since t0 is anchor

                # Generate hash key SQL
                hash_sql = self._generate_hash_sql({
                    "name": hub.hash_key_name,
                    "columns": [f"{alias}.{col}" for col in ref.bk_columns]
                })
                hash_columns.append(hash_sql)

                # Build join clause: keys of join_on are anchor columns,
                # values are the matching columns of the joined table.
                join_conditions = " AND ".join([
                    f"{anchor_alias}.{anchor_col} = {alias}.{join_col}"
                    for anchor_col, join_col in ref.join_on.items()
                ])

                join_clauses.append(
                    f"LEFT JOIN {ref.table} {alias} ON {join_conditions}"
                )

            # Build FROM clause
            from_clause = f"{anchor.table} {anchor_alias}"

            # Build record source from the unqualified table names.
            # NOTE(review): this lists every mapped table, including ones whose
            # join was skipped above because the hub was unknown - confirm intent.
            record_sources = [anchor.table.split(".")[-1]]
            record_sources.extend([j.table.split(".")[-1] for j in link.hub_mapping])
            record_source = ";".join(record_sources)

            # Assemble SELECT columns
            select_columns = (
                hash_columns +
                [f"{anchor_alias}.{load_dt_col} AS load_datetime"] +
                [f"'{record_source}' AS record_source"]
            )

            columns_sql = ",\n                ".join(select_columns)
            joins_sql = "\n            ".join(join_clauses)

            ddl_string = f"""
                CREATE MATERIALIZED LAKE VIEW IF NOT EXISTS `{link._stage.schema}`.`{link._stage.table}`
                TBLPROPERTIES (delta.enableChangeDataFeed = true)
                AS
                SELECT
                    {columns_sql}
                FROM {from_clause}
                {joins_sql}
            """.strip()

            full_name = f"{link._stage.schema}.{link._stage.table}"
            ddls[full_name] = ddl_string

        return ddls
    