diff --git a/data_diff/dbt.py b/data_diff/dbt.py index 67912de0..c824e4bf 100644 --- a/data_diff/dbt.py +++ b/data_diff/dbt.py @@ -46,6 +46,7 @@ class DiffVars: primary_keys: List[str] datasource_id: str connection: Dict[str, str] + threads: Optional[int] def dbt_diff( @@ -110,7 +111,7 @@ def _get_diff_vars( dev_qualified_list = [dev_database, dev_schema, model.alias] prod_qualified_list = [prod_database, prod_schema, model.alias] - return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection) + return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection, dbt_parser.threads) def _local_diff(diff_vars: DiffVars) -> None: @@ -118,8 +119,8 @@ def _local_diff(diff_vars: DiffVars) -> None: dev_qualified_string = ".".join(diff_vars.dev_path) prod_qualified_string = ".".join(diff_vars.prod_path) - table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys)) - table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys)) + table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads) + table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads) table1_columns = list(table1.get_schema()) try: @@ -260,6 +261,7 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str, is_clo self.connection = None self.project_dict = None self.requires_upper = False + self.threads = None self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt() @@ -345,6 +347,7 @@ def set_connection(self): "role": rendered_credentials.get("role"), "schema": rendered_credentials.get("schema"), } + self.threads = rendered_credentials.get("threads") self.requires_upper = True elif conn_type == "bigquery": method = rendered_credentials.get("method") @@ -357,6 +360,7 @@ def set_connection(self): "project": rendered_credentials.get("project"), "dataset": rendered_credentials.get("dataset"), } + self.threads = rendered_credentials.get("threads") elif conn_type == "duckdb": conn_info = { "driver": conn_type, @@ -373,6 +377,7 @@ def set_connection(self): "port": rendered_credentials.get("port"), "dbname": rendered_credentials.get("dbname"), } + self.threads = rendered_credentials.get("threads") elif conn_type == "databricks": conn_info = { "driver": conn_type, @@ -382,6 +387,7 @@ def set_connection(self): "schema": rendered_credentials.get("schema"), "access_token": rendered_credentials.get("token"), } + self.threads = rendered_credentials.get("threads") elif conn_type == "postgres": conn_info = { "driver": "postgresql", @@ -391,6 +397,7 @@ def set_connection(self): "port": rendered_credentials.get("port"), "dbname": rendered_credentials.get("dbname") or rendered_credentials.get("database"), } + self.threads = rendered_credentials.get("threads") else: raise NotImplementedError(f"Provider {conn_type} is not yet supported for dbt diffs") diff --git a/tests/test_dbt.py b/tests/test_dbt.py index b86b46d6..5b6ab25d 100644 --- a/tests/test_dbt.py +++ b/tests/test_dbt.py @@ -359,7 +359,7 @@ def test_local_diff(self, mock_diff_tables): dev_qualified_list = ["dev_db", "dev_schema", "dev_table"] prod_qualified_list = ["prod_db", "prod_schema", "prod_table"] expected_keys = ["key"] - diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection) + diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection, None) with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect: _local_diff(diff_vars) @@ -368,8 +368,8 @@ def test_local_diff(self, mock_diff_tables): ) self.assertEqual(len(mock_diff_tables.call_args[1]['extra_columns']), 2) self.assertEqual(mock_connect.call_count, 2) - mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys)) - mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys)) + mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys), None) + mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys), None) mock_diff.get_stats_string.assert_called_once() @patch("data_diff.dbt.diff_tables") @@ -386,7 +386,7 @@ def test_local_diff_no_diffs(self, mock_diff_tables): dev_qualified_list = ["dev_db", "dev_schema", "dev_table"] prod_qualified_list = ["prod_db", "prod_schema", "prod_table"] expected_keys = ["primary_key_column"] - diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection) + diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection, None) with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect: _local_diff(diff_vars) @@ -395,8 +395,8 @@ def test_local_diff_no_diffs(self, mock_diff_tables): ) self.assertEqual(len(mock_diff_tables.call_args[1]['extra_columns']), 2) self.assertEqual(mock_connect.call_count, 2) - mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys)) - mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys)) + mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys), None) + mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys), None) mock_diff.get_stats_string.assert_not_called() @patch("data_diff.dbt.rich.print") @@ -413,7 +413,7 @@ def test_cloud_diff(self, mock_request, mock_os_environ, mock_print): expected_datasource_id = 1 expected_primary_keys = ["primary_key_column"] diff_vars = DiffVars( - dev_qualified_list, prod_qualified_list, expected_primary_keys, expected_datasource_id, None + dev_qualified_list, prod_qualified_list, expected_primary_keys, expected_datasource_id, None, None ) _cloud_diff(diff_vars) @@ -443,7 +443,7 @@ def test_cloud_diff_ds_id_none(self, mock_request, mock_os_environ, mock_print): prod_qualified_list = ["prod_db", "prod_schema", "prod_table"] expected_datasource_id = None primary_keys = ["primary_key_column"] - diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, expected_datasource_id, None) + diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, expected_datasource_id, None, None) with self.assertRaises(ValueError): _cloud_diff(diff_vars) @@ -463,7 +463,7 @@ def test_cloud_diff_api_key_none(self, mock_request, mock_os_environ, mock_print prod_qualified_list = ["prod_db", "prod_schema", "prod_table"] expected_datasource_id = 1 primary_keys = ["primary_key_column"] - diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, expected_datasource_id, None) + diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, expected_datasource_id, None, None) with self.assertRaises(ValueError): _cloud_diff(diff_vars) @@ -487,7 +487,7 @@ def test_diff_is_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, mock_ mock_dbt_parser.return_value = mock_dbt_parser_inst mock_dbt_parser_inst.get_models.return_value = [mock_model] mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict - expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None) + expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None) mock_get_diff_vars.return_value = expected_diff_vars dbt_diff(is_cloud=True) mock_dbt_parser_inst.get_models.assert_called_once() @@ -514,7 +514,7 @@ def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, m } mock_dbt_parser_inst.get_models.return_value = [mock_model] mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict - expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None) + expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None) mock_get_diff_vars.return_value = expected_diff_vars dbt_diff(is_cloud=False) @@ -542,7 +542,7 @@ def test_diff_no_prod_configs( mock_dbt_parser_inst.get_models.return_value = [mock_model] mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict - expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None) + expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None) mock_get_diff_vars.return_value = expected_diff_vars with self.assertRaises(ValueError): dbt_diff(is_cloud=False) @@ -570,7 +570,7 @@ def test_diff_only_prod_db(self, mock_print, mock_dbt_parser, mock_cloud_diff, m } mock_dbt_parser_inst.get_models.return_value = [mock_model] mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict - expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None) + expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None) mock_get_diff_vars.return_value = expected_diff_vars dbt_diff(is_cloud=False) @@ -599,7 +599,7 @@ def test_diff_only_prod_schema( mock_dbt_parser_inst.get_models.return_value = [mock_model] mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict - expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None) + expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None) mock_get_diff_vars.return_value = expected_diff_vars with self.assertRaises(ValueError): dbt_diff(is_cloud=False) @@ -631,7 +631,7 @@ def test_diff_is_cloud_no_pks( mock_dbt_parser_inst.get_models.return_value = [mock_model] mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict - expected_diff_vars = DiffVars(["dev"], ["prod"], [], 123, None) + expected_diff_vars = DiffVars(["dev"], ["prod"], [], 123, None, None) mock_get_diff_vars.return_value = expected_diff_vars dbt_diff(is_cloud=True) @@ -662,7 +662,7 @@ def test_diff_not_is_cloud_no_pks( mock_dbt_parser_inst.get_models.return_value = [mock_model] mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict - expected_diff_vars = DiffVars(["dev"], ["prod"], [], 123, None) + expected_diff_vars = DiffVars(["dev"], ["prod"], [], 123, None, None) mock_get_diff_vars.return_value = expected_diff_vars dbt_diff(is_cloud=False) mock_dbt_parser_inst.get_models.assert_called_once()