diff --git a/src/deepgram/extensions/telemetry/proto_encoder.py b/src/deepgram/extensions/telemetry/proto_encoder.py index a085ed0e..7e055cd7 100644 --- a/src/deepgram/extensions/telemetry/proto_encoder.py +++ b/src/deepgram/extensions/telemetry/proto_encoder.py @@ -6,13 +6,18 @@ import typing from typing import Dict, List +_varint_single_byte_cache = [bytes([i]) for i in range(0x80)] + # --- Protobuf wire helpers (proto3) --- + def _varint(value: int) -> bytes: if value < 0: # For this usage we only encode non-negative values value &= (1 << 64) - 1 + if value < 0x80: + return _varint_single_byte_cache[value] out = bytearray() while value > 0x7F: out.append((value & 0x7F) | 0x80) @@ -31,11 +36,13 @@ def _len_delimited(field_number: int, payload: bytes) -> bytes: def _string(field_number: int, value: str) -> bytes: data = value.encode("utf-8") + # call _len_delimited directly; no extra local needed return _len_delimited(field_number, data) def _bool(field_number: int, value: bool) -> bytes: - return _key(field_number, 0) + _varint(1 if value else 0) + # Single-byte cache for '0' or '1' + return _key(field_number, 0) + _varint_single_byte_cache[1 if value else 0] def _int64(field_number: int, value: int) -> bytes: @@ -53,7 +60,10 @@ def _timestamp_message(ts_seconds: float) -> bytes: if nanos >= 1_000_000_000: sec += 1 nanos -= 1_000_000_000 + # Smallest possible allocations: build all once msg = bytearray() + from deepgram.extensions.telemetry.proto_encoder import _int64 # Avoiding circular import at top level + msg += _int64(1, sec) if nanos: msg += _key(2, 0) + _varint(nanos) @@ -64,11 +74,15 @@ def _timestamp_message(ts_seconds: float) -> bytes: def _map_str_str(field_number: int, items: typing.Mapping[str, str] | None) -> bytes: if not items: return b"" - out = bytearray() + # Preallocate list to reduce repeated += on bytearray (less copying overall) + outs = [] + append = outs.append + ld = _len_delimited + s = _string for k, v in items.items(): - entry = _string(1, k) + _string(2, v) - out += _len_delimited(field_number, entry) - return bytes(out) + entry = s(1, k) + s(2, v) + append(ld(field_number, entry)) + return b"".join(outs) def _map_str_double(field_number: int, items: typing.Mapping[str, float] | None) -> bytes: @@ -83,6 +97,7 @@ def _map_str_double(field_number: int, items: typing.Mapping[str, float] | None) # --- Schema-specific encoders (deepgram.dxtelemetry.v1) --- + def _encode_telemetry_context(ctx: typing.Mapping[str, typing.Any]) -> bytes: # Map SDK context keys to proto fields package_name = ctx.get("sdk_name") or ctx.get("package_name") or "python-sdk" @@ -123,7 +138,7 @@ def _encode_telemetry_context(ctx: typing.Mapping[str, typing.Any]) -> bytes: msg += _string(11, installation_id) if project_id: msg += _string(12, project_id) - + # Include extras as additional context attributes (field 13) extras = ctx.get("extras", {}) if extras: @@ -133,11 +148,13 @@ def _encode_telemetry_context(ctx: typing.Mapping[str, typing.Any]) -> bytes: if value is not None: extras_map[str(key)] = str(value) msg += _map_str_str(13, extras_map) - + return bytes(msg) -def _encode_telemetry_event(name: str, ts: float, attributes: Dict[str, str] | None, metrics: Dict[str, float] | None) -> bytes: +def _encode_telemetry_event( + name: str, ts: float, attributes: Dict[str, str] | None, metrics: Dict[str, float] | None +) -> bytes: msg = bytearray() msg += _string(1, name) msg += _len_delimited(2, _timestamp_message(ts)) @@ -160,24 +177,26 @@ def _encode_error_event( line: int | None = None, column: int | None = None, ) -> bytes: - msg = bytearray() + # Gather all chunks in list to reduce bytearray repeated copying + chunks = [] + append = chunks.append if err_type: - msg += _string(1, err_type) + append(_string(1, err_type)) if message: - msg += _string(2, message) + append(_string(2, message)) if stack_trace: - msg += _string(3, stack_trace) + append(_string(3, stack_trace)) if file: - msg += _string(4, file) + append(_string(4, file)) if line is not None: - msg += _key(5, 0) + _varint(line) + append(_key(5, 0) + _varint(line)) if column is not None: - msg += _key(6, 0) + _varint(column) - msg += _key(7, 0) + _varint(severity) - msg += _bool(8, handled) - msg += _len_delimited(9, _timestamp_message(ts)) - msg += _map_str_str(10, attributes) - return bytes(msg) + append(_key(6, 0) + _varint(column)) + append(_key(7, 0) + _varint(severity)) + append(_bool(8, handled)) + append(_len_delimited(9, _timestamp_message(ts))) + append(_map_str_str(10, attributes)) + return b"".join(chunks) def _encode_record(record: bytes, kind_field_number: int) -> bytes: @@ -253,7 +272,7 @@ def _normalize_events(events: List[dict]) -> List[bytes]: # Note: URL is never logged for privacy "connection_type": "websocket", } - + # Add detailed error information to attributes if e.get("error_type"): attrs["error_type"] = str(e["error_type"]) @@ -265,7 +284,7 @@ def _normalize_events(events: List[dict]) -> List[bytes]: attrs["timeout_occurred"] = str(e["timeout_occurred"]) if e.get("duration_ms"): attrs["duration_ms"] = str(e["duration_ms"]) - + # Add WebSocket handshake failure details if e.get("handshake_status_code"): attrs["handshake_status_code"] = str(e["handshake_status_code"]) @@ -278,27 +297,27 @@ def _normalize_events(events: List[dict]) -> List[bytes]: handshake_headers = e["handshake_response_headers"] for header_name, header_value in handshake_headers.items(): # Prefix with 'handshake_' to distinguish from request headers - safe_header_name = header_name.lower().replace('-', '_') + safe_header_name = header_name.lower().replace("-", "_") attrs[f"handshake_{safe_header_name}"] = str(header_value) - + # Add connection parameters if available if e.get("connection_params"): for key, value in e["connection_params"].items(): if value is not None: attrs[f"connection_{key}"] = str(value) - + # Add request_id if present for server-side correlation request_id = e.get("request_id") if request_id: attrs["request_id"] = str(request_id) - + # Include ALL extras in the attributes for comprehensive telemetry extras = e.get("extras", {}) if extras: for key, value in extras.items(): if value is not None and key not in attrs: attrs[str(key)] = str(value) - + rec = _encode_error_event( err_type=str(e.get("error_type", e.get("error", "Error"))), message=str(e.get("error_message", e.get("message", ""))), @@ -375,5 +394,3 @@ def encode_telemetry_batch_iter(events: List[dict], context: typing.Mapping[str, yield _len_delimited(1, _encode_telemetry_context(context)) for rec in _normalize_events(events): yield rec - -