Skip to content

Commit

Permalink
[ENH]: add grpc client interceptor (#1818)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - Add grpc client interceptor

## Test plan
*How are these changes tested?*

- Tested locally with distributed tracing
  • Loading branch information
nicolasgere committed Mar 6, 2024
1 parent 401a7f0 commit b7e8b62
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 5 deletions.
7 changes: 4 additions & 3 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(self, system: System):
def heartbeat(self) -> int:
return int(time.time_ns())

@trace_method("SegmentAPI.create_database", OpenTelemetryGranularity.OPERATION)
@override
def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
if len(name) < 3:
Expand All @@ -122,11 +123,11 @@ def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
name=name,
tenant=tenant,
)

@trace_method("SegmentAPI.get_database", OpenTelemetryGranularity.OPERATION)
@override
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> t.Database:
return self._sysdb.get_database(name=name, tenant=tenant)

@trace_method("SegmentAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
@override
def create_tenant(self, name: str) -> None:
if len(name) < 3:
Expand All @@ -135,7 +136,7 @@ def create_tenant(self, name: str) -> None:
self._sysdb.create_tenant(
name=name,
)

@trace_method("SegmentAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
@override
def get_tenant(self, name: str) -> t.Tenant:
return self._sysdb.get_tenant(name=name)
Expand Down
4 changes: 4 additions & 0 deletions chromadb/db/impl/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
UpdateSegmentRequest,
)
from chromadb.proto.coordinator_pb2_grpc import SysDBStub
from chromadb.telemetry.opentelemetry import OpenTelemetryClient
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
from chromadb.types import (
Collection,
Database,
Expand Down Expand Up @@ -64,6 +66,8 @@ def start(self) -> None:
self._channel = grpc.insecure_channel(
f"{self._coordinator_url}:{self._coordinator_port}"
)
interceptors = [OtelInterceptor()]
self._channel = grpc.intercept_channel(self._channel, *interceptors)
self._sys_db_stub = SysDBStub(self._channel) # type: ignore
return super().start()

Expand Down
3 changes: 3 additions & 0 deletions chromadb/logservice/logservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from chromadb.proto.convert import to_proto_submit
from chromadb.proto.logservice_pb2 import PushLogsRequest, PullLogsRequest, RecordLog
from chromadb.proto.logservice_pb2_grpc import LogServiceStub
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
from chromadb.types import (
SubmitEmbeddingRecord,
SeqId,
Expand Down Expand Up @@ -50,6 +51,8 @@ def start(self) -> None:
self._channel = grpc.insecure_channel(
f"{self._log_service_url}:{self._log_service_port}"
)
interceptors = [OtelInterceptor()]
self._channel = grpc.intercept_channel(self._channel, *interceptors)
self._log_service_stub = LogServiceStub(self._channel) # type: ignore
super().start()

Expand Down
2 changes: 2 additions & 0 deletions chromadb/telemetry/opentelemetry/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class OtelInterceptor(
grpc.StreamStreamClientInterceptor
):
def _intercept_call(self, continuation, client_call_details, request_or_iterator):
if tracer is None:
return continuation(client_call_details, request_or_iterator)
with tracer.start_as_current_span(f"RPC {client_call_details.method}", kind=SpanKind.CLIENT) as span:
# Prepare metadata for propagation
metadata = client_call_details.metadata[:] if client_call_details.metadata else []
Expand Down
6 changes: 4 additions & 2 deletions go/shared/otel/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ func decodeSpanID(encodedSpanID string) (s trace.SpanID, err error) {

// ServerGrpcInterceptor is a gRPC server interceptor for tracing and optional metadata-based context enhancement.
func ServerGrpcInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
tracer := otel.GetTracerProvider().Tracer("")

// Init with a default tracer if not already set. (Unit test)
if tracer == nil {
tracer = otel.GetTracerProvider().Tracer("LOCAL")
}
// Attempt to retrieve metadata, but proceed normally if not present.
md, _ := metadata.FromIncomingContext(ctx)

Expand Down

0 comments on commit b7e8b62

Please sign in to comment.