From c9cbb83221f1496d56274bdc4c4e5c858d306006 Mon Sep 17 00:00:00 2001 From: Dave Connors Date: Wed, 8 Mar 2023 12:11:31 -0600 Subject: [PATCH] make a working example --- data_diff/__main__.py | 13 ++++++- data_diff/dbt.py | 87 ++++++++++++++++++++++++------------------ data_diff/dbt_cloud.py | 13 +++++++ 3 files changed, 74 insertions(+), 39 deletions(-) create mode 100644 data_diff/dbt_cloud.py diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 6f2be411..5ab7d092 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -227,6 +227,13 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - metavar="PATH", help="Override the dbt project directory. Otherwise assumed to be the current directory.", ) +@click.option( + "--select", + default=None, + metavar="PATH", + help="select dbt resources to compare", +) + def main(conf, run, **kw): if kw["table2"] is None and kw["database2"]: # Use the "database table table" form @@ -263,8 +270,12 @@ def main(conf, run, **kw): profiles_dir_override=kw["dbt_profiles_dir"], project_dir_override=kw["dbt_project_dir"], is_cloud=kw["cloud"], + selection=kw["select"], ) - render_diff(diff, kw["limit"], kw["stats"], kw["json_output"]) + for d in diff: + # import pdb; pdb.set_trace() + rich.print(f"Diffing {'.'.join(d.info_tree.info.tables[0].table_path)} with {'.'.join(d.info_tree.info.tables[1].table_path)}") + render_diff(d, kw["limit"], kw["stats"], kw["json_output"]) else: return _data_diff(**kw) diff --git a/data_diff/dbt.py b/data_diff/dbt.py index 3ed9876b..64b946af 100644 --- a/data_diff/dbt.py +++ b/data_diff/dbt.py @@ -9,7 +9,7 @@ from typing import List, Optional, Dict import requests -from dbt_artifacts_parser.parser import parse_run_results, parse_manifest +# from dbt_artifacts_parser.parser import parse_run_results, parse_manifest from dbt.config.renderer import ProfileRenderer from .tracking import ( @@ -21,6 +21,8 @@ ) from .utils import run_as_daemon, truncate_error from 
. import connect_to_table, diff_tables, Algorithm +import subprocess +from .dbt_cloud import get_client, dynamic_request RUN_RESULTS_PATH = "/target/run_results.json" MANIFEST_PATH = "/target/manifest.json" @@ -28,7 +30,10 @@ PROFILES_FILE = "/profiles.yml" LOWER_DBT_V = "1.0.0" UPPER_DBT_V = "1.5.0" +DBT_CLOUD_API_KEY = os.getenv('DBT_CLOUD_API_KEY', None) +DBT_CLOUD_PROD_ENV_ID = os.getenv('DBT_CLOUD_PROD_ENV_ID', None) +dbtc_client = get_client(DBT_CLOUD_API_KEY) @dataclass class DiffVars: @@ -40,25 +45,27 @@ class DiffVars: def dbt_diff( - profiles_dir_override: Optional[str] = None, project_dir_override: Optional[str] = None, is_cloud: bool = False + profiles_dir_override: Optional[str] = None, project_dir_override: Optional[str] = None, is_cloud: bool = False, selection: str = None ) -> None: set_entrypoint_name("CLI-dbt") dbt_parser = DbtParser(profiles_dir_override, project_dir_override, is_cloud) - models = dbt_parser.get_models() + models = dbt_parser.get_models(selection) dbt_parser.set_project_dict() datadiff_variables = dbt_parser.get_datadiff_variables() - config_prod_database = datadiff_variables.get("prod_database") - config_prod_schema = datadiff_variables.get("prod_schema") + # config_prod_database = datadiff_variables.get("prod_database") + # config_prod_schema = datadiff_variables.get("prod_schema") datasource_id = datadiff_variables.get("datasource_id") if not is_cloud: dbt_parser.set_connection() - if config_prod_database is None or config_prod_schema is None: - raise ValueError("Expected a value for prod_database: or prod_schema: under \nvars:\n data_diff: ") + # if config_prod_database is None or config_prod_schema is None: + # raise ValueError("Expected a value for prod_database: or prod_schema: under \nvars:\n data_diff: ") + # import pdb; pdb.set_trace() + model_output = [] for model in models: - diff_vars = _get_diff_vars(dbt_parser, config_prod_database, config_prod_schema, model, datasource_id) + diff_vars = _get_diff_vars(dbt_parser, 
model, datasource_id) if is_cloud and len(diff_vars.primary_keys) > 0: _cloud_diff(diff_vars) @@ -73,7 +80,8 @@ def dbt_diff( ) if not is_cloud and len(diff_vars.primary_keys) == 1: - return _local_diff(diff_vars) + model_output.append(_local_diff(diff_vars)) + # print(result.diff) elif not is_cloud: rich.print( "[red]" @@ -83,31 +91,38 @@ def dbt_diff( + "[/] \n" + "Skipped due to missing primary-key tag or multi-column primary-key (unsupported for non --cloud diffs)\n" ) - - rich.print("Diffs Complete!") + return model_output def _get_diff_vars( dbt_parser: "DbtParser", - config_prod_database: Optional[str], - config_prod_schema: Optional[str], model, datasource_id: int, ) -> DiffVars: - dev_database = model.database - dev_schema = model.schema_ + dev_database = model.get("database") + dev_schema = model.get("schema") + unique_id = model.get("unique_id") primary_keys = dbt_parser.get_primary_keys(model) - prod_database = config_prod_database if config_prod_database else dev_database - prod_schema = config_prod_schema if config_prod_schema else dev_schema + prod_model_response = dynamic_request( + dbtc_client.metadata, + 'get_model_by_environment', + environment_id=DBT_CLOUD_PROD_ENV_ID, + unique_id=unique_id, + last_run_count=1 + ) + + prod_model_data = prod_model_response.get('data', {}).get("modelByEnvironment", [])[0] + prod_database = prod_model_data.get("database") + prod_schema = prod_model_data.get("schema") if dbt_parser.requires_upper: - dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.name]] - prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.name]] + dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.get("name")]] + prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.get("name")]] primary_keys = [x.upper() for x in primary_keys] else: - dev_qualified_list = [dev_database, dev_schema, model.name] - prod_qualified_list = [prod_database, prod_schema, 
model.name] + dev_qualified_list = [dev_database, dev_schema, model.get("name")] + prod_qualified_list = [prod_database, prod_schema, model.get("name")] return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection) @@ -243,32 +258,28 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str, is_clo def get_datadiff_variables(self) -> dict: return self.project_dict.get("vars").get("data_diff") - def get_models(self): - with open(self.project_dir + RUN_RESULTS_PATH) as run_results: - run_results_dict = json.load(run_results) - run_results_obj = parse_run_results(run_results=run_results_dict) - - dbt_version = parse_version(run_results_obj.metadata.dbt_version) - - if dbt_version < parse_version(LOWER_DBT_V) or dbt_version >= parse_version(UPPER_DBT_V): - raise Exception( - f"Found dbt: v{dbt_version} Expected the dbt project's version to be >= {LOWER_DBT_V} and < {UPPER_DBT_V}" - ) + def get_models(self, selection): + if selection: + ls_cmd = ["dbt", "ls", "--resource-type", "model", "--select", selection] + else: + ls_cmd = ["dbt", "ls", "--resource-type", "model"] + result = subprocess.run(ls_cmd, capture_output=True, text=True) + model_list = ["model." 
+ model for model in result.stdout.splitlines()] with open(self.project_dir + MANIFEST_PATH) as manifest: manifest_dict = json.load(manifest) - manifest_obj = parse_manifest(manifest=manifest_dict) + # for some reason the manifest parser appears to be choking on my seed + # manifest_obj = parse_manifest(manifest=manifest_dict) - success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"] - models = [manifest_obj.nodes.get(x) for x in success_models] + models = [manifest_dict.get("nodes").get(x) for x in model_list] if not models: - raise ValueError("Expected > 0 successful models runs from the last dbt command.") + raise ValueError("No models selected!") - rich.print(f"Found {str(len(models))} successful model runs from the last dbt command.") + rich.print(f"Found {str(len(models))} models to compare.") return models def get_primary_keys(self, model): - return list((x.name for x in model.columns.values() if "primary-key" in x.tags)) + return [x.get("name") for x in model.get("columns").values() if "primary-key" in x.get("tags")] def set_project_dict(self): with open(self.project_dir + PROJECT_FILE) as project: diff --git a/data_diff/dbt_cloud.py b/data_diff/dbt_cloud.py new file mode 100644 index 00000000..fffda3d9 --- /dev/null +++ b/data_diff/dbt_cloud.py @@ -0,0 +1,13 @@ +from dbtc import dbtCloudClient + +# first party + +def get_client(service_token): + client = dbtCloudClient(service_token=service_token) + return client + +def dynamic_request(_prop, method, *args, **kwargs): + return getattr(_prop, method)(*args, **kwargs) + +if __name__ == '__main__': + pass