Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse query comment and use as bigquery job labels. #3145

Merged
merged 10 commits into from
Mar 22, 2021
5 changes: 3 additions & 2 deletions core/dbt/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dbt.logger import GLOBAL_LOGGER as logger
from typing_extensions import Protocol
from dbt.dataclass_schema import (
dbtClassMixin, StrEnum, ExtensibleDbtClassMixin,
dbtClassMixin, StrEnum, ExtensibleDbtClassMixin, HyphenatedDbtClassMixin,
ValidatedStringMixin, register_pattern
)
from dbt.contracts.util import Replaceable
Expand Down Expand Up @@ -212,9 +212,10 @@ def to_target_dict(self):


@dataclass
class QueryComment(dbtClassMixin):
class QueryComment(HyphenatedDbtClassMixin):
comment: str = DEFAULT_QUERY_COMMENT
append: bool = False
job_label: bool = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The config is query-comment, rather than query_comment (docs). Not exactly sure why we did this, but I guess kebab casing is common in dbt_project.yml / project-level configs. I think this should be job-label (instead of job_label) for consistency

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I switched the base class to HyphenatedDbtClassMixin.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought, I may not understand what this base class does, since that change broke tests that I thought were unrelated. Reverted for now, but let me know if you have thoughts about the right way to do this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took another look at this, and I think I found a subtle bug in the casing logic. Should be fixed now.



class AdapterRequiredConfig(HasCredentials, Protocol):
Expand Down
4 changes: 3 additions & 1 deletion core/dbt/dataclass_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def __post_serialize__(self, dct):
# performing the conversion to a dict
@classmethod
def __pre_deserialize__(cls, data):
if cls._hyphenated:
# `data` might not be a dict, e.g. for `query_comment`, which accepts
# a dict or a string; only snake-case for dict values.
if cls._hyphenated and isinstance(data, dict):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this make sense @gshank?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at all the existing calls to this class method, it seems to be true.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly, I'm not sure exactly how this code gets reached, but I dropped into a debugger at this point on the unit test suite and found non-dictionary values--specifically, there's a test that sets query-comment to "". You should be able to reproduce the issue by dropping the isinstance check that I added and running something like pytest -s test/unit/test_query_headers.py -k test_disable_query_comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can also run that test with the following added to pre_deserialize:

        if cls._hyphenated and not isinstance(data, dict):
            import ipdb; ipdb.set_trace()

Looking at the stack, I'm inside a function called __unpack_union_Project_query_comment__9f5a8d0e89384cc286d2b99acef5623d that I'm guessing is dynamically generated and eval-ed by mashumaro 😱.

new_dict = {}
for key in data:
if '-' in key:
Expand Down
35 changes: 31 additions & 4 deletions plugins/bigquery/dbt/adapters/bigquery/connections.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
import re
from contextlib import contextmanager
from dataclasses import dataclass
from functools import lru_cache
Expand Down Expand Up @@ -305,12 +307,16 @@ def raw_execute(self, sql, fetch=False, *, use_legacy_sql=False):

logger.debug('On {}: {}', conn.name, sql)

job_params = {'use_legacy_sql': use_legacy_sql}
if self.profile.query_comment.job_label:
query_comment = self.query_header.comment.query_comment
labels = self._labels_from_query_comment(query_comment)
else:
labels = {}

if active_user:
job_params['labels'] = {
'dbt_invocation_id': active_user.invocation_id
}
labels['dbt_invocation_id'] = active_user.invocation_id

job_params = {'use_legacy_sql': use_legacy_sql, 'labels': labels}

priority = conn.credentials.priority
if priority == Priority.Batch:
Expand Down Expand Up @@ -544,6 +550,16 @@ def _retry_generator(self):
initial=self.DEFAULT_INITIAL_DELAY,
maximum=self.DEFAULT_MAXIMUM_DELAY)

def _labels_from_query_comment(self, comment: str) -> Dict:
try:
comment_labels = json.loads(comment)
except (TypeError, ValueError):
return {'query_comment': _sanitize_label(comment)}
return {
_sanitize_label(key): _sanitize_label(str(value))
for key, value in comment_labels.items()
}


class _ErrorCounter(object):
"""Counts errors seen up to a threshold then raises the next error."""
Expand Down Expand Up @@ -573,3 +589,14 @@ def _is_retryable(error):
e['reason'] == 'rateLimitExceeded' for e in error.errors):
return True
return False


_SANITIZE_LABEL_PATTERN = re.compile(r"[^a-z0-9_-]")


def _sanitize_label(value: str, max_length: int = 63) -> str:
"""Return a legal value for a BigQuery label."""
value = value.lower()
value = _SANITIZE_LABEL_PATTERN.sub("_", value)
value = value[: max_length - 1]
return value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anecdotally, this has the effect of appending an extra underscore at the end of my non-dict comment, e.g. comment: whatever in dbt_project.yml becomes query_comment: whatever_ in BigQuery, and comment: something else becomes query_comment: something_else_. Is that intended?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why this is happening--there may be something upstream of this code adding trailing whitespace, because I don't get that behavior testing the function directly. I tried adding a strip(). Does that help?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like the strip() fixed it!

Screen Shot 2021-03-09 at 10 39 33 AM

(The numerous _ in the top example are all special characters.)

12 changes: 10 additions & 2 deletions test/unit/test_bigquery_adapter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import agate
import decimal
import json
import re
import unittest
from contextlib import contextmanager
Expand Down Expand Up @@ -588,7 +589,6 @@ def test_query_and_results(self, mock_bq):
self.mock_client.query.assert_called_once_with(
'sql', job_config=mock_bq.QueryJobConfig())


def test_copy_bq_table_appends(self):
self._copy_table(
write_disposition=dbt.adapters.bigquery.impl.WRITE_APPEND)
Expand All @@ -615,12 +615,20 @@ def test_copy_bq_table_truncates(self):
kwargs['job_config'].write_disposition,
dbt.adapters.bigquery.impl.WRITE_TRUNCATE)

def test_job_labels_valid_json(self):
expected = {"key": "value"}
labels = self.connections._labels_from_query_comment(json.dumps(expected))
self.assertEqual(labels, expected)

def test_job_labels_invalid_json(self):
labels = self.connections._labels_from_query_comment("not json")
self.assertEqual(labels, {"query_comment": "not_json"})

def _table_ref(self, proj, ds, table, conn):
return google.cloud.bigquery.table.TableReference.from_string(
'{}.{}.{}'.format(proj, ds, table))

def _copy_table(self, write_disposition):

self.connections.table_ref = self._table_ref
source = BigQueryRelation.create(
database='project', schema='dataset', identifier='table1')
Expand Down