Skip to content
53 changes: 49 additions & 4 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import time
import base64
import threading
import logging

from google.protobuf.struct_pb2 import ListValue
from google.protobuf.struct_pb2 import Value
Expand All @@ -29,16 +30,27 @@
from google.api_core import datetime_helpers
from google.api_core.exceptions import Aborted
from google.cloud._helpers import _date_from_iso8601_date
from google.cloud.spanner_v1 import TypeCode
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import JsonObject, Interval
from google.cloud.spanner_v1 import TransactionOptions
from google.cloud.spanner_v1.types import ExecuteSqlRequest
from google.cloud.spanner_v1.types import TransactionOptions
from google.cloud.spanner_v1.data_types import JsonObject, Interval
from google.cloud.spanner_v1.request_id_header import with_request_id
from google.cloud.spanner_v1.types import TypeCode

from google.rpc.error_details_pb2 import RetryInfo

try:
from opentelemetry.propagate import inject
from opentelemetry.propagators.textmap import Setter
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.resourcedetector import gcp_resource_detector
from opentelemetry.resourcedetector.gcp_resource_detector import (
GoogleCloudResourceDetector,
)

# Overwrite the requests timeout for the detector.
# This is necessary as the client will wait the full timeout if the
# code is not run in a GCP environment, with the location endpoints available.
gcp_resource_detector._TIMEOUT_SEC = 0.2

HAS_OPENTELEMETRY_INSTALLED = True
except ImportError:
Expand All @@ -55,6 +67,12 @@
+ "numeric has a whole component with precision {}"
)

# Fallback region value used by _get_cloud_region() when no specific region
# is detected or detection fails.
GOOGLE_CLOUD_REGION_GLOBAL = "global"

# Module-level logger for this helpers module.
log = logging.getLogger(__name__)

# Cache for the value resolved by _get_cloud_region(). It stays None until the
# first call resolves a region, after which that value is reused.
_cloud_region: str = None


if HAS_OPENTELEMETRY_INSTALLED:

Expand All @@ -79,6 +97,33 @@ def set(self, carrier: List[Tuple[str, str]], key: str, value: str) -> None:
carrier.append((key, value))


def _get_cloud_region() -> str:
    """Get the location of the resource, caching the result.

    The first call resolves the region and stores it in the module-level
    ``_cloud_region`` variable; later calls return the stored value without
    running detection again. A fallback to "global" is cached as well, so a
    failed detection is not retried.

    Returns:
        str: The detected cloud region. Returns the "global" region if
        OpenTelemetry is not installed, if the detector reports no region,
        or if detection raises an error (in which case a warning is logged).
    """
    global _cloud_region
    if _cloud_region is not None:
        return _cloud_region

    if not HAS_OPENTELEMETRY_INSTALLED:
        # The detector and semantic-convention names are only bound when the
        # optional OpenTelemetry imports succeed. Short-circuit here instead
        # of relying on the broad ``except`` below to swallow a NameError and
        # log a misleading "failed to detect" warning.
        _cloud_region = GOOGLE_CLOUD_REGION_GLOBAL
        return _cloud_region

    try:
        attributes = GoogleCloudResourceDetector().detect().attributes
        _cloud_region = attributes.get(
            ResourceAttributes.CLOUD_REGION, GOOGLE_CLOUD_REGION_GLOBAL
        )
    except Exception as e:
        # Detection is best-effort: never let it break the caller.
        log.warning(
            "Failed to detect GCP resource location for Spanner metrics, defaulting to 'global'. Error: %s",
            e,
        )
        _cloud_region = GOOGLE_CLOUD_REGION_GLOBAL

    return _cloud_region


def _try_to_coerce_bytes(bytestring):
"""Try to coerce a byte string into the right thing based on Python
version and whether or not it is base64 encoded.
Expand Down
9 changes: 9 additions & 0 deletions google/cloud/spanner_v1/_opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from google.cloud.spanner_v1 import SpannerClient
from google.cloud.spanner_v1 import gapic_version
from google.cloud.spanner_v1._helpers import (
_get_cloud_region,
_metadata_with_span_context,
)

Expand Down Expand Up @@ -75,6 +76,7 @@ def trace_call(
enable_end_to_end_tracing = False

db_name = ""
cloud_region = None
if session and getattr(session, "_database", None):
db_name = session._database.name

Expand All @@ -88,6 +90,7 @@ def trace_call(
)
db_name = observability_options.get("db_name", db_name)

cloud_region = _get_cloud_region()
tracer = get_tracer(tracer_provider)

# Set base attributes that we know for every trace created
Expand All @@ -97,6 +100,7 @@ def trace_call(
"db.instance": db_name,
"net.host.name": SpannerClient.DEFAULT_ENDPOINT,
OTEL_SCOPE_NAME: TRACER_NAME,
"cloud.region": cloud_region,
OTEL_SCOPE_VERSION: TRACER_VERSION,
# Standard GCP attributes for OTel, attributes are used for internal purpose and are subjected to change
"gcp.client.service": "spanner",
Expand All @@ -107,6 +111,11 @@ def trace_call(
if extra_attributes:
attributes.update(extra_attributes)

if "request_options" in attributes:
request_options = attributes.pop("request_options")
if request_options and request_options.request_tag:
attributes["request.tag"] = request_options.request_tag

if extended_tracing_globally_disabled:
enable_extended_tracing = False

Expand Down
6 changes: 6 additions & 0 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,8 +1025,14 @@ def run_in_transaction(self, func, *args, **kw):
reraises any non-ABORT exceptions raised by ``func``.
"""
observability_options = getattr(self, "observability_options", None)
transaction_tag = kw.get("transaction_tag")
extra_attributes = {}
if transaction_tag:
extra_attributes["transaction.tag"] = transaction_tag

with trace_call(
"CloudSpanner.Database.run_in_transaction",
extra_attributes=extra_attributes,
observability_options=observability_options,
), MetricsCapture():
# Sanity check: Is there a transaction already running?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,9 @@
from .metrics_tracer_factory import MetricsTracerFactory
import os
import logging
from .constants import (
SPANNER_SERVICE_NAME,
GOOGLE_CLOUD_REGION_KEY,
GOOGLE_CLOUD_REGION_GLOBAL,
)
from .constants import SPANNER_SERVICE_NAME

try:
from opentelemetry.resourcedetector import gcp_resource_detector

# Overwrite the requests timeout for the detector.
# This is necessary as the client will wait the full timeout if the
# code is not run in a GCP environment, with the location endpoints available.
gcp_resource_detector._TIMEOUT_SEC = 0.2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check whether we still need this timeout in the new code after the refactoring. Currently, the new region-lookup helper does not set any timeout.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it got missed while moving the logic to the helper. I have added it back.


import mmh3

logging.getLogger("opentelemetry.resourcedetector.gcp_resource_detector").setLevel(
Expand All @@ -44,6 +33,7 @@

from .metrics_tracer import MetricsTracer
from google.cloud.spanner_v1 import __version__
from google.cloud.spanner_v1._helpers import _get_cloud_region
from uuid import uuid4

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -86,7 +76,7 @@ def __new__(
cls._metrics_tracer_factory.set_client_hash(
cls._generate_client_hash(client_uid)
)
cls._metrics_tracer_factory.set_location(cls._get_location())
cls._metrics_tracer_factory.set_location(_get_cloud_region())
cls._metrics_tracer_factory.gfe_enabled = gfe_enabled

if cls._metrics_tracer_factory.enabled != enabled:
Expand Down Expand Up @@ -153,28 +143,3 @@ def _generate_client_hash(client_uid: str) -> str:

# Return as 6 digit zero padded hex string
return f"{sig_figs:06x}"

@staticmethod
def _get_location() -> str:
    """Resolve the location of the resource.

    Detection is best-effort: if the detector raises, a warning is logged
    and the "global" location is used instead.

    Returns:
        str: The detected region, or the global region when OpenTelemetry
        is not installed, no region attribute is reported, or detection
        fails.
    """
    location = GOOGLE_CLOUD_REGION_GLOBAL
    if HAS_OPENTELEMETRY_INSTALLED:
        try:
            detected = gcp_resource_detector.GoogleCloudResourceDetector().detect()
            if GOOGLE_CLOUD_REGION_KEY in detected.attributes:
                location = detected.attributes[GOOGLE_CLOUD_REGION_KEY]
        except Exception as e:
            log.warning(
                "Failed to detect GCP resource location for Spanner metrics, defaulting to 'global'. Error: %s",
                e,
            )
    return location
5 changes: 5 additions & 0 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,9 +532,14 @@ def run_in_transaction(self, func, *args, **kw):
database = self._database
log_commit_stats = database.log_commit_stats

extra_attributes = {}
if transaction_tag:
extra_attributes["transaction.tag"] = transaction_tag

with trace_call(
"CloudSpanner.Session.run_in_transaction",
self,
extra_attributes=extra_attributes,
observability_options=getattr(database, "observability_options", None),
) as span, MetricsCapture():
attempts: int = 0
Expand Down
8 changes: 6 additions & 2 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,11 @@ def read(
method=streaming_read_method,
request=read_request,
metadata=metadata,
trace_attributes={"table_id": table, "columns": columns},
trace_attributes={
"table_id": table,
"columns": columns,
"request_options": request_options,
},
column_info=column_info,
lazy_decode=lazy_decode,
)
Expand Down Expand Up @@ -601,7 +605,7 @@ def execute_sql(
method=execute_streaming_sql_method,
request=execute_sql_request,
metadata=metadata,
trace_attributes={"db.statement": sql},
trace_attributes={"db.statement": sql, "request_options": request_options},
column_info=column_info,
lazy_decode=lazy_decode,
)
Expand Down
8 changes: 6 additions & 2 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,10 @@ def execute_update(
request_options = RequestOptions(request_options)
request_options.transaction_tag = self.transaction_tag

trace_attributes = {"db.statement": dml}
trace_attributes = {
"db.statement": dml,
"request_options": request_options,
}

# If this request begins the transaction, we need to lock
# the transaction until the transaction ID is updated.
Expand Down Expand Up @@ -629,7 +632,8 @@ def batch_update(

trace_attributes = {
# Get just the queries from the DML statement batch
"db.statement": ";".join([statement.sql for statement in parsed])
"db.statement": ";".join([statement.sql for statement in parsed]),
"request_options": request_options,
}

# If this request begins the transaction, we need to lock
Expand Down
2 changes: 2 additions & 0 deletions tests/system/test_session_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from google.cloud.spanner_admin_database_v1 import DatabaseDialect
from google.cloud._helpers import UTC

from google.cloud.spanner_v1._helpers import _get_cloud_region
from google.cloud.spanner_v1._helpers import AtomicCounter
from google.cloud.spanner_v1.data_types import JsonObject
from google.cloud.spanner_v1.database_sessions_manager import TransactionType
Expand Down Expand Up @@ -356,6 +357,7 @@ def _make_attributes(db_instance, **kwargs):
"db.url": "spanner.googleapis.com",
"net.host.name": "spanner.googleapis.com",
"db.instance": db_instance,
"cloud.region": _get_cloud_region(),
"gcp.client.service": "spanner",
"gcp.client.version": ot_helpers.LIB_VERSION,
"gcp.client.repo": "googleapis/python-spanner",
Expand Down
48 changes: 47 additions & 1 deletion tests/unit/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
import unittest
import mock

from google.cloud.spanner_v1 import TransactionOptions
from opentelemetry.sdk.resources import Resource
from opentelemetry.semconv.resource import ResourceAttributes


from google.cloud.spanner_v1 import TransactionOptions, _helpers


class Test_merge_query_options(unittest.TestCase):
Expand Down Expand Up @@ -89,6 +93,48 @@ def test_base_object_merge_dict(self):
self.assertEqual(result, expected)


class Test_get_cloud_region(unittest.TestCase):
    """Tests for ``_helpers._get_cloud_region`` with the detector patched."""

    # Patch target for the resource detector's ``detect`` method.
    DETECT_TARGET = (
        "google.cloud.spanner_v1._helpers.GoogleCloudResourceDetector.detect"
    )

    def setUp(self):
        # Clear the module-level cache so every test performs its own lookup.
        _helpers._cloud_region = None

    def _callFUT(self, *args, **kw):
        from google.cloud.spanner_v1._helpers import _get_cloud_region as target

        return target(*args, **kw)

    def test_get_location_with_region(self):
        """A region reported by the detector is returned as-is."""
        with mock.patch(self.DETECT_TARGET) as detect:
            detect.return_value = Resource.create(
                {ResourceAttributes.CLOUD_REGION: "us-central1"}
            )
            region = self._callFUT()

        self.assertEqual(region, "us-central1")

    def test_get_location_without_region(self):
        """A resource without a region attribute yields 'global'."""
        with mock.patch(self.DETECT_TARGET) as detect:
            detect.return_value = Resource.create({})  # No region attribute
            region = self._callFUT()

        self.assertEqual(region, "global")

    def test_get_location_with_exception(self):
        """A detector failure yields 'global' and logs a warning."""
        with mock.patch(self.DETECT_TARGET) as detect:
            detect.side_effect = Exception("detector failed")

            with self.assertLogs(
                "google.cloud.spanner_v1._helpers", level="WARNING"
            ) as captured:
                region = self._callFUT()
                self.assertEqual(region, "global")
                self.assertIn(
                    "Failed to detect GCP resource location", captured.output[0]
                )


class Test_make_value_pb(unittest.TestCase):
def _callFUT(self, *args, **kw):
from google.cloud.spanner_v1._helpers import _make_value_pb
Expand Down
Loading