# Data Vault 2.0 for Microsoft Fabric

This project provides an automated tool for building and managing a Data Vault 2.0 architecture in Microsoft Fabric. It replaces hand-written SparkSQL queries with configurations (YAML or Python code) and builds the vault as Materialized Lake Views via PySpark.


## Core Architecture
1. **Staging Layer**: Raw data ingestion with added metadata (Load Date, Record Source, and hash keys pre-calculated with SHA1). The Raw Vault pulls its data directly from this layer.
    - Standard Stage: Processes the tables needed by Hubs and Satellites.
    - Special Link Stage: A dedicated staging table for each link that pre-calculates hashes and joins, so that satellites attached to the link later do not need to repeat that work.

2. **Raw Vault**: Contains the basic vault entities
    - Hubs: Unique business keys.
    - Links: Relationships between business keys.
    - Satellites: Descriptive state over time.
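
The pre-calculated hash keys mentioned for the staging layer can be illustrated with a minimal sketch. The function name `hash_key`, the `||` delimiter, and the trim/upper-case normalization are assumptions made for illustration; the SparkSQL that the project actually generates is not shown here and may differ.

```python
import hashlib
from typing import Mapping, Sequence


def hash_key(row: Mapping[str, object],
             business_key_columns: Sequence[str],
             delimiter: str = "||") -> str:
    """Return a SHA-1 hex digest over the normalized business key values."""
    parts = []
    for col in business_key_columns:
        value = row.get(col)
        # Assumption: NULLs become empty strings; values are trimmed and upper-cased
        parts.append("" if value is None else str(value).strip().upper())
    return hashlib.sha1(delimiter.join(parts).encode("utf-8")).hexdigest()


# Two spellings of the same business key produce the same hash key
quote_a = {"id": " q-1001 ", "amount": 250}
quote_b = {"id": "Q-1001", "amount": 300}
print(hash_key(quote_a, ["id"]) == hash_key(quote_b, ["id"]))
```

Computing the key once in staging means hubs, links, and satellites can all join on the same stored value instead of re-hashing the business key.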

## Features
1. **SQL-Centric Automation**: We use PySpark's `sql()` function as a "code generator". The tool turns your metadata into SparkSQL statements that build and manage your Data Vault entities and Materialized Lake Views (MLVs) without manual coding.

2. **Scalable Templating**: Scaling your architecture is as simple as updating a configuration. Whether you prefer using a YAML file or a Python dictionary, the "Template-First" approach ensures that adding 100 new Hubs or Links follows the exact same standardized logic.

3. **Simplicity & Visibility**: Using Materialized Lake Views provides a "setup-once" architecture that is both performant and transparent. Because MLVs are native to Fabric, they appear automatically in the Lineage View, allowing you to visually trace data flow from your Stage tables all the way to your final consumer views.

4. **Configuration-as-Code**: Your entire Data Vault design is portable. You can export your definitions to a YAML file to back up your architecture, or import them into a new Fabric Workspace to keep environments (Dev/Test/Prod) in sync.
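
As a rough illustration of the template-first idea, the sketch below renders hub definitions into SparkSQL strings. The `HubSpec` dataclass, the `render_hub_sql` function, the `hk_` naming, the staging table name, and the exact statement shape are hypothetical; the project's own generator lives in `dv_ddl_generator` and may emit different SQL.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class HubSpec:  # hypothetical stand-in for the project's Hub dataclass
    name: str
    schema_name: str
    business_key_columns: List[str]
    stage_table: str  # staging table that already holds the hash key


def render_hub_sql(hub: HubSpec) -> str:
    """Render one hub definition into a SparkSQL statement (string only)."""
    suffix = hub.name[4:] if hub.name.startswith("hub_") else hub.name
    keys = ", ".join([f"hk_{suffix}", *hub.business_key_columns])
    return (
        f"CREATE MATERIALIZED LAKE VIEW IF NOT EXISTS {hub.schema_name}.{hub.name} AS\n"
        f"SELECT {keys}, MIN(load_datetime) AS load_datetime\n"
        f"FROM {hub.stage_table}\n"
        f"GROUP BY {keys}"
    )


# Every hub goes through the same template, whether there is one or a hundred
specs = [
    HubSpec("hub_quotes", "silver", ["id"], "staging.stg_hubspot_quotes"),
    HubSpec("hub_deals", "silver", ["id"], "staging.stg_hubspot_deals"),
]
for spec in specs:
    print(render_hub_sql(spec), end="\n\n")
```

Because the function only returns strings, the output can be reviewed before anything is passed to `spark.sql()`.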

### Example Usage
```
dv = DataVaultManager("ingestion_time")  # default ingestion-timestamp column, used for load datetimes downstream

dv.add_hub(Hub(
    name="hub_quotes",
    schema_name="silver",
    business_key_columns=["id"],
    source_table="cc_fabric_lakehouse_new.bronze.hubspot_quotes"
)).add_hub(Hub(
    name="hub_deals",
    schema_name="silver",
    business_key_columns=["deals"],
    load_datetime_column="load_ts",  # override the default timestamp column
    source_table="cc_fabric_lakehouse_new.bronze.deals"
)).add_link(Link(
    name="link_quote_deal",
    schema_name="silver",
    staging_schema="staging",
    anchor=LinkAnchor(
        table="cc_fabric_lakehouse_new.bronze.hubspot_quotes",  # Anchor table: the table in the FROM clause
        hub="hub_quotes",
        bk_columns=["id"]
    ),
    hub_mapping=[
        LinkHubJoin(
            hub="hub_deals",
            table="cc_fabric_lakehouse_new.bronze.hubspot_deals",  # Join table: the table in the JOIN clause
            bk_columns=["id"],
            join_on={"value": "id"}  # anchor-table key column -> join-table key column
        )
    ]
)).add_satellite(Satellite(
    name="sat_quotes",
    schema_name="silver",
    parent_hub_or_link="hub_quotes",
    descriptive_columns=['id', 'ingestion_time'],
    include_mode=False,  # False: use every source column except those listed in descriptive_columns
    source_table="cc_fabric_lakehouse_new.bronze.hubspot_quotes_v1"
))

# construct_vault parameters:
#   execute: run the generated DDL (False only generates and validates)
#   verbose: print the generated SQL strings
#   force:   execute even when warnings exist (errors still block execution)

dv.construct_vault(execute=True, verbose=True, force=True)
dv.export_config('Files/vault_config.yaml') # export the Data Vault config to yaml

# import yaml config file and run
dv = DataVaultManager.from_config("Files/vault_config.yaml")
dv.construct_vault(execute=True, verbose=True, force=True)
```
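
The interaction between `execute` and `force` follows the branching in `construct_vault`: errors always block execution, and warnings block it unless `force=True`. The standalone function below mirrors that decision so it can be tried without a Spark session. The name `decide_execution` and its string return values are illustrative and not part of the project's API.

```python
def decide_execution(execute: bool, force: bool, err_count: int, warn_count: int) -> str:
    """Return 'skip', 'run', or 'blocked' for a given validation outcome."""
    if not execute:
        return "skip"       # dry run: DDL is generated and validated only
    if err_count > 0:
        return "blocked"    # errors are never overridden, even with force=True
    if warn_count > 0 and not force:
        return "blocked"    # warnings need force=True to proceed
    return "run"


print(decide_execution(execute=False, force=True, err_count=3, warn_count=0))   # skip
print(decide_execution(execute=True, force=True, err_count=0, warn_count=2))    # run
print(decide_execution(execute=True, force=False, err_count=0, warn_count=2))   # blocked
print(decide_execution(execute=True, force=True, err_count=1, warn_count=0))    # blocked
```

In the manager itself the blocked cases raise `UnresolvedVaultErrors` rather than returning a value.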


### Internal Notebooks
- **dv_core**: Dataclasses and Enums
- **dv_ddl_generator**: DDL (SQL String) Generation for Stage and Vault Entities
- **dv_validator**: Config Parsing validations
- **dv_utils**: Config import/export and other utilities
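
The manager calls `_validate_identifier(name, allow_dots=...)` on every entity, schema, and table name before it is interpolated into SQL. The implementation in `dv_validator` is not shown in this notebook, so the following is only a sketch of what such a check might look like. The function name `is_valid_identifier`, the regular expression, and the three-part limit are assumptions.

```python
import re

# Assumption: identifiers are letters, digits, and underscores, not starting with a digit
_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def is_valid_identifier(name: str, allow_dots: bool = False) -> bool:
    """Check that a name is safe to interpolate into generated SQL."""
    if not isinstance(name, str) or not name:
        return False
    parts = name.split(".") if allow_dots else [name]
    # Assumption: at most three parts (lakehouse.schema.table)
    if len(parts) > 3:
        return False
    return all(_IDENTIFIER.fullmatch(part) is not None for part in parts)


print(is_valid_identifier("hub_quotes"))                                          # True
print(is_valid_identifier("silver.hub_quotes"))                                   # False
print(is_valid_identifier("cc_fabric_lakehouse_new.bronze.hubspot_quotes", True))  # True
print(is_valid_identifier("quotes; DROP TABLE x", True))                           # False
```

Rejecting anything outside a narrow character set matters here because the generated DDL is built by string formatting rather than bound parameters.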


In [1]:
%run dv_core


In [2]:
%run dv_validator


In [3]:
%run dv_ddl_generator


In [4]:
%run dv_utils


In [5]:
from tqdm.notebook import tqdm
from typing import Dict, List, Optional, Union


class DataVaultManager(DDLGenerator, DVUtils, Validator):
    def __init__(self, default_load_datetime_column: str):
        self._stage = []
        self._hubs = []
        self._links = []
        self._sats = []
        self._registered_hubs: Dict[str, RegisteredHub] = {}
        self._registered_links: Dict[str, RegisteredLink] = {}
        self._generated_queries: List[ParametrizedQuery] = []
        self._validation_issues: List[ValidationIssue] = []
        self._has_errors = False

        self._dedupe = True
        self._default_load_datetime_column = default_load_datetime_column
        self._source_load_datetime_map: Dict[str, str] = {}
        self._source_columns_cache: Dict[str, List[str]] = {}

        self.validator = SQLValidator()

    # PUBLIC API
    def add_hub(self, hub: Hub) -> "DataVaultManager":
        hub.source_table = self._normalize_source_table(hub.source_table)

        if self._check_vault_entity_duplicate(hub.name, self._hubs, "hub"):
            return self

        self._validate_identifier(hub.name, allow_dots=False)
        self._validate_identifier(hub.schema_name, allow_dots=False)
        self._validate_identifier(hub.source_table, allow_dots=True)  # Allow schema.table
        self._validate_batch(hub.business_key_columns)  # Validate all BK columns

        self._hubs.append(hub)
        return self

    def add_link(self, link: Link) -> "DataVaultManager":
        if self._check_vault_entity_duplicate(link.name, self._links, "link"):
            return self

        if link.name in self._registered_links:
            self._add_issue(
                WarningSeverity.ERROR,
                "data_vault_rule",
                link.name,
                f"Cannot modify existing link '{link.name}'. Links are immutable once created.",
                {
                    "rule": "Data Vault: Links cannot have hubs added after creation",
                    "suggestion": "Create a new link with a different name that includes all required hubs"
                }
            )
            return self

        self._validate_identifier(link.name, allow_dots=False)
        self._validate_identifier(link.schema_name, allow_dots=False)
        self._validate_identifier(link.anchor.table, allow_dots=True)

        all_hubs = self._get_all_hubs()
        referenced_hubs = []

        if link.source_columns:
            self._validate_batch(link.source_columns)

        if link.anchor.hub:
            self._validate_identifier(link.anchor.hub, allow_dots=False)
            referenced_hubs.append(link.anchor.hub)
            
            if link.anchor.bk_columns:
                self._validate_batch(link.anchor.bk_columns)
            
            if link.anchor.hub not in all_hubs:
                self._add_issue(
                    WarningSeverity.ERROR,
                    "link",
                    link.name,
                    f"Anchor references non-existent hub '{link.anchor.hub}'",
                    {"available_hubs": sorted(all_hubs.keys())}
                )

        for join in link.hub_mapping:
            self._validate_identifier(join.hub, allow_dots=False)
            self._validate_identifier(join.table, allow_dots=True)
            self._validate_batch(join.bk_columns)
            referenced_hubs.append(join.hub)

            if join.hub not in all_hubs:
                self._add_issue(
                    WarningSeverity.ERROR,
                    "link",
                    link.name,
                    f"hub_mapping references non-existent hub '{join.hub}'",
                    {"available_hubs": sorted(all_hubs.keys())}
                )

        # Validate minimum hub count
        if len(referenced_hubs) < 2:
            self._add_issue(
                WarningSeverity.ERROR,
                "link",
                link.name,
                f"At least 2 hubs required, found {len(referenced_hubs)}: {referenced_hubs}"
            )

        self._links.append(link)
        return self

    def add_satellite(self, sat: Satellite) -> "DataVaultManager":
        sat.source_table = self._normalize_source_table(sat.source_table)

        if self._check_vault_entity_duplicate(sat.name, self._sats, "satellite"):
            return self

        self._validate_identifier(sat.name, allow_dots=False)
        self._validate_identifier(sat.schema_name, allow_dots=False)
        self._validate_identifier(sat.parent_hub_or_link, allow_dots=False)
        self._validate_identifier(sat.source_table, allow_dots=True)
        self._validate_batch(sat.descriptive_columns)

        self._validate_satellite_parent(sat)
        self._sats.append(sat)
        return self

    def register_hub(
        self, 
        name: str, 
        schema_name: str = "silver", 
        business_key_columns: Optional[List[str]] = None
    ) -> "DataVaultManager":
        """Register an existing hub for reference by new satellites/links."""
        
        full_name = f"{schema_name}.{name}"
        
        if not self._check_table_existence(schema_name, name):
            self._add_issue(
                WarningSeverity.ERROR,
                "register",
                name,
                f"Hub '{full_name}' does not exist in catalog"
            )
            return self
        
        if business_key_columns is None:
            self._add_issue(
                WarningSeverity.ERROR,
                "register",
                name,
                f"Provide business key columns explicitly for '{full_name}'."
            )
            return self
        
        self._registered_hubs[name] = RegisteredHub(name, schema_name, business_key_columns)
        
        self._add_issue(
            WarningSeverity.INFO,
            "register",
            name,
            f"Registered existing hub '{full_name}' with keys: {business_key_columns}"
        )
        
        return self

    def register_link(self, name: str, schema_name: str = "silver") -> "DataVaultManager":
        """Register an existing link for reference by new satellites."""
        
        full_name = f"{schema_name}.{name}"
        
        if not self._check_table_existence(schema_name, name):
            self._add_issue(
                WarningSeverity.ERROR,
                "register",
                name,
                f"Link '{full_name}' does not exist in catalog"
            )
            return self
        
        self._registered_links[name] = RegisteredLink(name, schema_name)
        
        self._add_issue(
            WarningSeverity.INFO,
            "register",
            name,
            f"Registered existing link '{full_name}'"
        )
        
        return self

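    def construct_vault(self, execute: bool = False, force: bool = False, verbose: bool = False):
        """Validate the configuration, generate all DDL, and optionally run it.

        execute: run the generated DDL; when False, only generate and validate.
        force: run even if warnings exist; errors always block execution.
        verbose: print every generated SQL string.
        """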
        
        staging_map = self._build_staging_map(self._get_unique_tables())

        self._process_entities(staging_map)
        
        self._validate_all_columns()
        self._detect_duplicate_satellites()

        all_ddls = self._generate_staging_DDL(staging_map) | self._generate_vault_entity_DDL()

        if verbose:
            for table, ddl in all_ddls.items():
                print(f"\n{'='*60}")
                print(f"-- {table}")
                print(f"{'='*60}")
                print(ddl)

        self._print_validation_summary()

        err_count = len([v for v in self._validation_issues if v.severity == WarningSeverity.ERROR])
        warn_count = len([v for v in self._validation_issues if v.severity == WarningSeverity.WARNING])

        if execute:
            if force:
                if err_count > 0:
                    raise UnresolvedVaultErrors(err_count)
                else:
                    if warn_count > 0:
                        print(f"Running with warnings ({warn_count} warnings). Proceeding with execution.")
                    self._execute_query(all_ddls)
            else:
                if err_count > 0 or warn_count > 0:
                    raise UnresolvedVaultErrors(err_count + warn_count)
                self._execute_query(all_ddls)
    
    # PIPELINE STAGES
    def _build_staging_map(self, source_set: set[str]) -> Dict[str, Stage]:
        staging_map: Dict[str, Stage] = {}

        # TODO: using dict data type dedupes for us nicely, maybe some way to warn that we deduped the specific columns
        for source in source_set:
            staging_map[source] = Stage(source_table=source)

        return staging_map

    def _process_entities(self, staging_map: Dict[str, Stage]) -> None:
        # Process in order: hubs, then links, then satellites
        hub_map = {hub.name: hub for hub in self._hubs}
        link_map = {link.name: link for link in self._links}
        
        self._process_hubs(staging_map)
        self._process_links(staging_map)
        self._process_satellites(staging_map, hub_map, link_map)

    def _process_hubs(self, staging_map: Dict[str, Stage]) -> None:
        for hub in self._hubs:
            stg = staging_map[hub.source_table]
            
            self._validate_identifier(hub.hash_key_name, allow_dots=False)
            stg.hash_keys[hub.hash_key_name] = hub.business_key_columns
            
            hub._set_stage(VaultEntityMetadata(
                schema=stg.schema_name,
                table=stg.name,
                columns=[hub.hash_key_name, *hub.business_key_columns]
            ))

    def _process_links(self, staging_map: Dict[str, Stage]) -> None:
        all_hubs = self._get_all_hubs()
        
        for link in self._links:
            link_stage_name = f"stg_{link.name}"
            link_hash_keys = []
            skip_link = False

            if link.anchor.hub:
                hub = all_hubs.get(link.anchor.hub)
                if hub:
                    self._validate_identifier(hub.hash_key_name, allow_dots=False)
                    link_hash_keys.append(hub.hash_key_name)
                else:
                    self._add_issue(
                        WarningSeverity.ERROR,
                        "link",
                        link.name,
                        f"Anchor references non-existent hub '{link.anchor.hub}'"
                    )
                    skip_link = True

            for mapping in link.hub_mapping:
                hub = all_hubs.get(mapping.hub)
            
                if hub is None:
                    skip_link = True
                    continue
                
                self._validate_identifier(hub.hash_key_name, allow_dots=False)
                link_hash_keys.append(hub.hash_key_name)

            if skip_link:
                continue

            # Links get their own staging table in the configured staging schema
            link._set_stage(VaultEntityMetadata(
                schema=link.staging_schema,
                table=link_stage_name,
                columns=link_hash_keys
            ))

    def _process_satellites(
        self,
        staging_map: Dict[str, Stage],
        hub_map: Dict[str, Hub],
        link_map: Dict[str, Link]
    ) -> None:
        all_hubs = {**hub_map, **self._registered_hubs}
        all_links = {**link_map, **self._registered_links}

        for sat in self._sats:
            stg = staging_map[sat.source_table]
            hash_diff_name = f"hd_{sat.name.replace('sat_', '')}"
            
            self._validate_identifier(hash_diff_name, allow_dots=False)

            parent = all_hubs.get(sat.parent_hub_or_link) or all_links.get(sat.parent_hub_or_link)

            if parent is None:
                self._add_issue(
                    WarningSeverity.ERROR,
                    "satellite",
                    sat.name,
                    f"Parent '{sat.parent_hub_or_link}' not found during processing"
                )
                continue

            if sat.hash_column:
                stg.hash_keys[parent.hash_key_name] = sat.hash_column
            
            sat._set_stage(VaultEntityMetadata(
                schema=stg.schema_name,
                table=stg.name,
                columns=[hash_diff_name, parent.hash_key_name]
            ))

            resolved = self._resolve_satellite_columns(sat)
            sat._set_resolved_columns(resolved)
    
    # LOOKUPS
    def _get_all_hubs(self) -> Dict[str, Union[Hub, RegisteredHub]]:
        all_hubs = {hub.name: hub for hub in self._hubs}
        all_hubs.update(self._registered_hubs)
        return all_hubs

    def _get_all_links(self) -> Dict[str, Union[Link, RegisteredLink]]:
        all_links = {link.name: link for link in self._links}
        all_links.update(self._registered_links)
        return all_links

    def _get_all_parents(self) -> Dict[str, Union[Hub, Link, RegisteredHub, RegisteredLink]]:
        return {**self._get_all_hubs(), **self._get_all_links()}

    def _get_unique_tables(self) -> set[str]:
        sources = set()
    
        for hub in self._hubs:
            sources.add(hub.source_table)
        
        for sat in self._sats:
            sources.add(sat.source_table)
        
        return sources
    
    # NORMALIZATION
    def _normalize_source_table(self, source_table: str, default_schema: str = "raw") -> str:
        # Reject non-strings and strings that are empty or whitespace-only
        if not isinstance(source_table, str) or not source_table.strip():
            raise ValueError(f"Source table must be a non-empty string, got: {source_table!r}")
        
        # Remove extra whitespace
        source_table = source_table.strip()
        
        # Split by dots
        parts = source_table.split(".")
        
        if len(parts) == 1:
            # Just table name - add default schema
            table = parts[0]
            normalized = f"{default_schema}.{table}"
            
            self._add_issue(
                WarningSeverity.INFO,
                "normalization",
                source_table,
                f"Source table normalized: '{source_table}' → '{normalized}'",
                {"original": source_table, "normalized": normalized}
            )
            
            return normalized
            
        elif len(parts) == 2:
            # schema.table - already correct format
            schema, table = parts
            
            # Validate parts are not empty
            if not schema or not table:
                raise ValueError(f"Invalid source table format: '{source_table}'. Schema and table cannot be empty.")
            
            return source_table
            
        elif len(parts) == 3:
            # three-part naming (lakehouse.schema.table), e.g. cross-lakehouse references
            lakehouse, schema, table = parts
            
            # Validate parts are not empty
            if not lakehouse or not schema or not table:
                raise ValueError(f"Invalid source table format: '{source_table}'. lakehouse, schema, and table cannot be empty.")

            return source_table
            
        else:
            raise ValueError(
                f"Invalid source table format: '{source_table}'. "
                f"Expected: 'table', 'schema.table', or 'lakehouse.schema.table'"
            )

    # COLUMN UTILS
    def _get_reserved_columns(self, sat: Optional[Satellite] = None, include_parent_hk: bool = False) -> set:
        
        reserved = {
            'load_datetime', 
            'record_source', 
            'load_date', 
            'load_ts',
            self._default_load_datetime_column
        }
        
        if sat:
            if sat.load_datetime_column:
                reserved.add(sat.load_datetime_column)
            
            if include_parent_hk:
                parent = self._get_all_parents().get(sat.parent_hub_or_link)
                if parent:
                    reserved.add(parent.hash_key_name)
        
        return reserved

    def _filter_columns(
        self,
        columns: List[str],
        reserved: set,
        stage: str,
        entity_name: str,
        reason: str
    ) -> List[str]:
        """Filter columns against reserved set with case-insensitive matching"""
        
        reserved_lower = {r.lower() for r in reserved}
        
        safe_columns = [
            col for col in columns
            if col.lower() not in reserved_lower
        ]
        
        excluded = set(columns) - set(safe_columns)
        if excluded:
            self._add_issue(
                WarningSeverity.WARNING,
                stage,
                entity_name,
                f"Excluded {len(excluded)} column(s): {sorted(excluded)}",
                {
                    "excluded_columns": sorted(excluded),
                    "reason": reason
                }
            )
        
        return safe_columns

    def _resolve_satellite_columns(self, sat: Satellite) -> List[str]:  
        reserved = self._get_reserved_columns(sat, include_parent_hk=True)
        
        if sat.include_mode:
            # Include mode: use specified columns
            return self._filter_columns(
                columns=sat.descriptive_columns,
                reserved=reserved,
                stage="satellite",
                entity_name=sat.name,
                reason="System columns excluded"
            )
        
        # Exclude mode: use ALL source columns except specified
        source_columns = self._get_source_columns(sat.source_table)
        
        if not source_columns:
            self._add_issue(
                WarningSeverity.ERROR,
                "satellite",
                sat.name,
                "Cannot use include_mode=False: failed to retrieve source columns"
            )
            return []
        
        all_excluded = reserved | set(sat.descriptive_columns)
        
        return self._filter_columns(
            columns=source_columns,
            reserved=all_excluded,
            stage="satellite",
            entity_name=sat.name,
            reason="User exclusion (include_mode=False) + system columns"
        )
    
    def _resolve_load_datetime_column(self, source_table: str, entity_override: Optional[str] = None) -> str:
        """Resolve load_datetime column: entity override > cached > default"""
        
        # Entity-level override takes priority
        if entity_override:
            self._source_load_datetime_map[source_table] = entity_override
            return entity_override
        
        # Check cached mapping
        if source_table in self._source_load_datetime_map:
            return self._source_load_datetime_map[source_table]
        
        # Validate default exists in source
        resolved = self._default_load_datetime_column
        source_columns = self._get_source_columns(source_table)
        
        if source_columns and resolved.lower() not in {c.lower() for c in source_columns}:
            self._add_issue(
                WarningSeverity.ERROR,
                "staging",
                source_table,
                f"load_datetime column '{resolved}' not found in source",
                {
                    "available_columns": source_columns,
                    "suggestion": "Specify load_datetime_column in entity definition"
                }
            )
        
        return resolved

    # SPARK INTERACTIONS
    def _check_table_existence(self, schema_name: str, table_name: str):
        try:
            full_name = f"{schema_name}.{table_name}"
            return spark.catalog.tableExists(full_name)
        except Exception as e:
            self._add_issue(
                WarningSeverity.WARNING,
                "lakehouse",
                table_name,
                f"Could not check table existence: {e}"
            )
            return False

    def _infer_hub_business_keys(self, schema_name: str, hub_name: str) -> List[str]:
        try:
            df = spark.table(f"{schema_name}.{hub_name}")
            # Exclude known system columns
            exclude = {'load_datetime', 'record_source', f"hk_{hub_name.replace('hub_', '')}"}
            return [c for c in df.columns if c.lower() not in {e.lower() for e in exclude}]
        except Exception:
            return []

    def _get_source_columns(self, source_table: str) -> List[str]:
        if source_table in self._source_columns_cache:
            return self._source_columns_cache[source_table]

        try:
            df = spark.table(source_table)
            self._source_columns_cache[source_table] = df.columns
            return df.columns
            
        except Exception as e:
            self._add_issue(
                WarningSeverity.ERROR,
                "staging",
                source_table,
                f"Could not retrieve source columns: {e}"
            )
            return []

    # EXECUTION AND OUTPUT
    def _print_validation_summary(self) -> None:
        
        if not self._validation_issues:
            print("✅ No validation issues found\n")
            return
        
        severity_config = {
            WarningSeverity.ERROR: "❌ ERRORS",
            WarningSeverity.WARNING: "⚠️  WARNINGS",
            WarningSeverity.INFO: "ℹ️  INFO"
        }

        for severity, label in severity_config.items():
            issues = [i for i in self._validation_issues if i.severity == severity]
            
            if not issues:
                continue
                
            print(f"\n{label} ({len(issues)}):")
            for issue in issues:
                print(f"   [{issue.stage}] {issue.entity}: {issue.message}")
                if issue.details:
                    for key, value in issue.details.items():
                        print(f"      └─ {key}: {value}")

    def _execute_query(self, ddl_string_list: Dict[str, str]):
        print("EXECUTING!")
        
        created_count = 0
        skipped_count = 0

        for table, ddl in ddl_string_list.items():

            # Split on the last dot so three-part names (lakehouse.schema.table) also unpack
            schema_name, table_name = table.rsplit(".", 1)
            
            # Check if table already exists
            if self._check_table_existence(schema_name, table_name):
                print(f"⏭️  Skipping '{table}': already exists")
                skipped_count += 1
                continue
            
            # Table doesn't exist, create it
            print(f"🔨 Creating '{table}'...")
            df = spark.sql(ddl)
            display(df)
            created_count += 1
        
        print(f"\n✅ Execution complete: {created_count} created, {skipped_count} skipped")


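
The satellite column resolution in `_resolve_satellite_columns` and `_filter_columns` can be tried without Spark using a simplified standalone version. The sketch below keeps the two modes and the case-insensitive matching but drops the issue reporting and the catalog lookup. The function name `resolve_columns` and the sample column names are illustrative.

```python
from typing import Iterable, List


def resolve_columns(source_columns: List[str],
                    descriptive_columns: List[str],
                    reserved: Iterable[str],
                    include_mode: bool) -> List[str]:
    """Pick satellite payload columns, matching names case-insensitively."""
    blocked = {c.lower() for c in reserved}
    if include_mode:
        # Include mode: start from the listed columns, drop system columns
        candidates = descriptive_columns
    else:
        # Exclude mode: start from every source column, drop listed and system columns
        candidates = source_columns
        blocked |= {c.lower() for c in descriptive_columns}
    return [c for c in candidates if c.lower() not in blocked]


source = ["id", "ingestion_time", "amount", "Status", "record_source"]
reserved = {"load_datetime", "record_source", "ingestion_time", "hk_quotes"}

print(resolve_columns(source, ["id", "ingestion_time"], reserved, include_mode=False))
print(resolve_columns(source, ["amount", "Record_Source"], reserved, include_mode=True))
```

With `include_mode=False` the listed columns act as an exclusion list, which is why the README example lists `id` and `ingestion_time` for `sat_quotes`.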