Skip to content

Commit

Permalink
Merge pull request #60 from nicholasyager/fix/inject_groups_into_mode…
Browse files Browse the repository at this point in the history
…lnodes

Feature: Support cross-project Group evaluation
  • Loading branch information
nicholasyager committed Jun 1, 2024
2 parents b5144d8 + 083dc5a commit 5f6681d
Show file tree
Hide file tree
Showing 10 changed files with 151 additions and 23 deletions.
65 changes: 57 additions & 8 deletions dbt_loom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import os
import re
from pathlib import Path
from typing import Callable, Dict, Optional
from typing import Callable, Dict, Optional, Set

import yaml
from dbt.contracts.graph.node_args import ModelNodeArgs

from dbt.contracts.graph.nodes import ModelNode

from dbt.plugins.manager import dbt_hook, dbtPlugin
from dbt.plugins.manifest import PluginNodes
Expand All @@ -30,16 +30,18 @@ class LoomModelNodeArgs(ModelNodeArgs):
"""A dbt-loom extension of ModelNodeArgs to preserve resource types across lineages."""

resource_type: NodeType = NodeType.Model
group: Optional[str] = None

def __init__(self, **kwargs):
super().__init__(
**{
key: value
for key, value in kwargs.items()
if key not in ("resource_type")
if key not in ("resource_type", "group")
}
)
self.resource_type = kwargs["resource_type"]
self.resource_type = kwargs.get("resource_type", NodeType.Model)
self.group = kwargs.get("group")

@property
def unique_id(self) -> str:
Expand Down Expand Up @@ -90,7 +92,7 @@ def convert_model_nodes_to_model_node_args(
identifier=node.identifier,
**(
# Small bit of logic to support both pydantic 2 and pydantic 1
node.model_dump(exclude={"schema_name", "depends_on", "node_config"})
node.model_dump(exclude={"schema_name", "depends_on", "node_config"}) # type: ignore
if hasattr(node, "model_dump")
else node.dict(exclude={"schema_name", "depends_on", "node_config"})
),
Expand All @@ -100,10 +102,11 @@ def convert_model_nodes_to_model_node_args(
}


@dataclass
class LoomRunnableConfig:
"""A shim class to allow is_invalid_*_ref functions to correctly handle access for loom-injected models."""

restrict_access: bool = True
restrict_access: bool = False
vars: VarProvider = VarProvider(vars={})


Expand All @@ -124,6 +127,7 @@ def __init__(self, project_name: str):
)

self._manifest_loader = ManifestLoader()
self.manifests: Dict[str, Dict] = {}

self.config: Optional[dbtLoomConfig] = self.read_config(configuration_path)
self.models: Dict[str, LoomModelNodeArgs] = {}
Expand All @@ -144,18 +148,61 @@ def __init__(self, project_name: str):
)
)

dbt.parser.manifest.ManifestLoader.check_valid_group_config_node = ( # type: ignore
self.group_validation_wrapper(
dbt.parser.manifest.ManifestLoader.check_valid_group_config_node # type: ignore
)
)

dbt.contracts.graph.nodes.ModelNode.from_args = ( # type: ignore
self.model_node_wrapper(dbt.contracts.graph.nodes.ModelNode.from_args) # type: ignore
)

super().__init__(project_name)

def model_node_wrapper(self, function) -> Callable:
"""Wrap the ModelNode.from_args function and inject extra properties from the LoomModelNodeArgs."""

def outer_function(args: LoomModelNodeArgs) -> ModelNode:
model = function(args)
model.group = args.group
return model

return outer_function

def group_validation_wrapper(self, function) -> Callable:
"""Wrap the check_valid_group_config_node function to inject upstream group names."""

def outer_function(
inner_self, groupable_node, valid_group_names: Set[str]
) -> bool:
new_groups: Set[str] = {
model.group for model in self.models.values() if model.group is not None
}

return function(
inner_self, groupable_node, valid_group_names.union(new_groups)
)

return outer_function

def dependency_wrapper(self, function) -> Callable:
def outer_function(inner_self, node, target_model, dependencies) -> bool:
if self.config is not None:
for manifest in self.config.manifests:
dependencies[manifest.name] = LoomRunnableConfig()
for manifest_name in self.manifests.keys():
dependencies[manifest_name] = LoomRunnableConfig()

return function(inner_self, node, target_model, dependencies)

return outer_function

def get_groups(self) -> Set[str]:
"""Get all groups defined in injected models."""

return {
model.group for model in self.models.values() if model.group is not None
}

def read_config(self, path: Path) -> Optional[dbtLoomConfig]:
"""Read the dbt-loom configuration file."""
if not path.exists():
Expand Down Expand Up @@ -196,6 +243,8 @@ def initialize(self) -> None:
if manifest is None:
continue

self.manifests[manifest_reference.name] = manifest

selected_nodes = identify_node_subgraph(manifest)
self.models.update(convert_model_nodes_to_model_node_args(selected_nodes))

Expand Down
1 change: 1 addition & 0 deletions dbt_loom/manifests.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class ManifestNode(BaseModel):
latest_version: Optional[str] = None
deprecation_date: Optional[datetime.datetime] = None
access: Optional[str] = "protected"
group: Optional[str] = None
generated_at: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
depends_on: Optional[DependsOn] = None
depends_on_nodes: List[str] = Field(default_factory=list)
Expand Down
2 changes: 1 addition & 1 deletion test_projects/revenue/dbt_project.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ vars:
# Configuring models
# Full documentation: https://docs.getdbt.com/docs/configuring-models

restrict-access: true
restrict-access: false

# In this example config, we tell dbt to build all models in the example/ directory
# as tables. These settings can be overridden in the individual model files
Expand Down
4 changes: 4 additions & 0 deletions test_projects/revenue/models/groups.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
groups:
- name: sales
owner:
email: sales@example.com
24 changes: 21 additions & 3 deletions test_projects/revenue/models/marts/__models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ version: 2

models:
- name: orders
description: Order overview data mart, offering key details for each order inlcluding if it's a customer's first order and a food vs. drink item breakdown. One row per order.
description: >
Order overview data mart, offering key details for each order including if it's a customer's
first order and a food vs. drink item breakdown. One row per order.
access: public
tests:
- dbt_utils.expression_is_true:
expression: "count_food_items + count_drink_items = count_items"
- dbt_utils.expression_is_true:
expression: "subtotal_food_items + subtotal_drink_items = subtotal"

columns:
- name: order_id
description: The unique key of the orders mart.
Expand Down Expand Up @@ -38,9 +41,11 @@ models:
- name: order_cost
description: The sum of supply expenses to fulfill the order.
- name: location_name
description: The full location name of where this order was placed. Denormalized from `stg_locations`.
description: >
The full location name of where this order was placed. Denormalized from `stg_locations`.
- name: is_first_order
description: A boolean indicating if this order is from a new customer placing their first order.
description: >
A boolean indicating if this order is from a new customer placing their first order.
- name: is_food_order
description: A boolean indicating if this order included any food items.
- name: is_drink_order
Expand All @@ -55,3 +60,16 @@ models:
columns:
- include: all
exclude: [location_id]

- name: accounts
description: >
All accounts with whom we have done business. This is a very sensitive asset.
access: private
group: sales

columns:
- name: name
description: Name of the account.
tests:
- not_null
- unique
8 changes: 8 additions & 0 deletions test_projects/revenue/models/marts/accounts.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
with

final as (
select name from {{ ref('stg_accounts') }}
)


select * from final
2 changes: 1 addition & 1 deletion test_projects/revenue/models/staging/stg_accounts.sql
Original file line number Diff line number Diff line change
@@ -1 +1 @@
select * from {{ ref('accounts') }}
select * from {{ ref('seed_accounts') }}
2 changes: 1 addition & 1 deletion test_projects/revenue/seeds/__seeds.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ seeds:
# cannot have both in latest version of 1.7.x.
access: public

- name: accounts
- name: seed_accounts
config:
access: private
File renamed without changes.
66 changes: 57 additions & 9 deletions tests/test_dbt_core_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

import dbt
from dbt.cli.main import dbtRunner, dbtRunnerResult
from dbt.contracts.results import RunExecutionResult, NodeResult

from dbt.contracts.graph.nodes import ModelNode


import dbt.exceptions
import pytest

starting_path = os.getcwd()


def test_dbt_core_runs_loom_plugin():
Expand All @@ -17,20 +19,21 @@ def test_dbt_core_runs_loom_plugin():
runner = dbtRunner()

# Compile the revenue project
starting_path = os.getcwd()

os.chdir(f"{starting_path}/test_projects/revenue")
runner.invoke(["clean"])
runner.invoke(["deps"])
runner.invoke(["compile"])

# Run `build` in the customer_success project
os.chdir(f"{starting_path}/test_projects/customer_success")
runner.invoke(["clean"])
runner.invoke(["deps"])
output: dbtRunnerResult = runner.invoke(["build"])

# Make sure nothing failed
assert output.exception is None

runner.invoke(["deps"])
output: dbtRunnerResult = runner.invoke(["ls"])

# Make sure nothing failed
Expand All @@ -45,21 +48,32 @@ def test_dbt_core_runs_loom_plugin():
"revenue.orders.v2",
}

os.chdir(starting_path)

assert set(output.result).issuperset(
subset
), "The child project is missing expected nodes. Check that injection still works."


@pytest.mark.skip(
reason="This only applies when a project has restrict-access: true, which bugs all dbt tests "
"on private nodes. We can bring this back when that is not the case."
)
def test_dbt_loom_injects_dependencies():
"""Verify that dbt-core runs the dbt-loom plugin and that it flags access violations."""

starting_path = os.getcwd()
runner = dbtRunner()

# Compile the revenue project
os.chdir(f"{starting_path}/test_projects/revenue")
runner.invoke(["clean"])
runner.invoke(["deps"])
output = runner.invoke(["compile"])

assert output.exception is None, output.exception.get_message() # type: ignore

path = Path(
f"{starting_path}/test_projects/customer_success/models/staging/stg_orders_enhanced.sql"
)
print(path)

with open(path, "w") as file:
file.write(
"""
Expand All @@ -72,20 +86,54 @@ def test_dbt_loom_injects_dependencies():
"""
)

# Run `ls`` in the customer_success project
os.chdir(f"{starting_path}/test_projects/customer_success")
runner.invoke(["clean"])
runner.invoke(["deps"])
output: dbtRunnerResult = runner.invoke(["build"])

path.unlink()

# Make sure nothing failed
assert isinstance(output.exception, dbt.exceptions.DbtReferenceError)


def test_dbt_loom_injects_groups():
"""Verify that dbt-core runs the dbt-loom plugin and that it flags group violations."""

runner = dbtRunner()

# Compile the revenue project
os.chdir(f"{starting_path}/test_projects/revenue")
runner.invoke(["clean"])
runner.invoke(["deps"])
runner.invoke(["compile"])
output = runner.invoke(["compile"])

assert output.exception is None

path = Path(
f"{starting_path}/test_projects/customer_success/models/marts/marketing_lists.sql"
)

with open(path, "w") as file:
file.write(
"""
with
upstream as (
select * from {{ ref('accounts') }}
)
select * from upstream
"""
)

# Run `ls`` in the customer_success project
os.chdir(f"{starting_path}/test_projects/customer_success")
runner.invoke(["clean"])
runner.invoke(["deps"])
output: dbtRunnerResult = runner.invoke(["build"])

path.unlink()
os.chdir(starting_path)

# Make sure nothing failed
assert isinstance(output.exception, dbt.exceptions.DbtReferenceError)

0 comments on commit 5f6681d

Please sign in to comment.