Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 46 additions & 29 deletions src/deepgram/extensions/telemetry/proto_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
import typing
from typing import Dict, List

_varint_single_byte_cache = [bytes([i]) for i in range(0x80)]


# --- Protobuf wire helpers (proto3) ---


def _varint(value: int) -> bytes:
if value < 0:
# Negative values are not expected in this usage; if one arrives, mask it to its unsigned 64-bit form
value &= (1 << 64) - 1
if value < 0x80:
return _varint_single_byte_cache[value]
out = bytearray()
while value > 0x7F:
out.append((value & 0x7F) | 0x80)
Expand All @@ -31,11 +36,13 @@ def _len_delimited(field_number: int, payload: bytes) -> bytes:

def _string(field_number: int, value: str) -> bytes:
    """Encode ``value`` as a string field.

    The text is UTF-8 encoded and handed to ``_len_delimited`` together with
    ``field_number``; whatever that helper returns is returned unchanged.
    """
    return _len_delimited(field_number, value.encode("utf-8"))


def _bool(field_number: int, value: bool) -> bytes:
    """Encode ``value`` as a boolean field.

    The result is the field key produced by ``_key(field_number, 0)`` followed
    by a single payload byte: ``0x01`` when ``value`` is truthy, ``0x00``
    otherwise.
    """
    # 0 and 1 always fit in one byte, so index the prebuilt table directly
    # instead of going through the general-purpose ``_varint`` encoder.
    return _key(field_number, 0) + _varint_single_byte_cache[1 if value else 0]


def _int64(field_number: int, value: int) -> bytes:
Expand All @@ -53,7 +60,10 @@ def _timestamp_message(ts_seconds: float) -> bytes:
if nanos >= 1_000_000_000:
sec += 1
nanos -= 1_000_000_000
# Smallest possible allocations: build all once
msg = bytearray()
from deepgram.extensions.telemetry.proto_encoder import _int64 # Avoiding circular import at top level

msg += _int64(1, sec)
if nanos:
msg += _key(2, 0) + _varint(nanos)
Expand All @@ -64,11 +74,15 @@ def _timestamp_message(ts_seconds: float) -> bytes:
def _map_str_str(field_number: int, items: typing.Mapping[str, str] | None) -> bytes:
if not items:
return b""
out = bytearray()
# Preallocate list to reduce repeated += on bytearray (less copying overall)
outs = []
append = outs.append
ld = _len_delimited
s = _string
for k, v in items.items():
entry = _string(1, k) + _string(2, v)
out += _len_delimited(field_number, entry)
return bytes(out)
entry = s(1, k) + s(2, v)
append(ld(field_number, entry))
return b"".join(outs)


def _map_str_double(field_number: int, items: typing.Mapping[str, float] | None) -> bytes:
Expand All @@ -83,6 +97,7 @@ def _map_str_double(field_number: int, items: typing.Mapping[str, float] | None)

# --- Schema-specific encoders (deepgram.dxtelemetry.v1) ---


def _encode_telemetry_context(ctx: typing.Mapping[str, typing.Any]) -> bytes:
# Map SDK context keys to proto fields
package_name = ctx.get("sdk_name") or ctx.get("package_name") or "python-sdk"
Expand Down Expand Up @@ -123,7 +138,7 @@ def _encode_telemetry_context(ctx: typing.Mapping[str, typing.Any]) -> bytes:
msg += _string(11, installation_id)
if project_id:
msg += _string(12, project_id)

# Include extras as additional context attributes (field 13)
extras = ctx.get("extras", {})
if extras:
Expand All @@ -133,11 +148,13 @@ def _encode_telemetry_context(ctx: typing.Mapping[str, typing.Any]) -> bytes:
if value is not None:
extras_map[str(key)] = str(value)
msg += _map_str_str(13, extras_map)

return bytes(msg)


def _encode_telemetry_event(name: str, ts: float, attributes: Dict[str, str] | None, metrics: Dict[str, float] | None) -> bytes:
def _encode_telemetry_event(
name: str, ts: float, attributes: Dict[str, str] | None, metrics: Dict[str, float] | None
) -> bytes:
msg = bytearray()
msg += _string(1, name)
msg += _len_delimited(2, _timestamp_message(ts))
Expand All @@ -160,24 +177,26 @@ def _encode_error_event(
line: int | None = None,
column: int | None = None,
) -> bytes:
msg = bytearray()
# Gather all chunks in list to reduce bytearray repeated copying
chunks = []
append = chunks.append
if err_type:
msg += _string(1, err_type)
append(_string(1, err_type))
if message:
msg += _string(2, message)
append(_string(2, message))
if stack_trace:
msg += _string(3, stack_trace)
append(_string(3, stack_trace))
if file:
msg += _string(4, file)
append(_string(4, file))
if line is not None:
msg += _key(5, 0) + _varint(line)
append(_key(5, 0) + _varint(line))
if column is not None:
msg += _key(6, 0) + _varint(column)
msg += _key(7, 0) + _varint(severity)
msg += _bool(8, handled)
msg += _len_delimited(9, _timestamp_message(ts))
msg += _map_str_str(10, attributes)
return bytes(msg)
append(_key(6, 0) + _varint(column))
append(_key(7, 0) + _varint(severity))
append(_bool(8, handled))
append(_len_delimited(9, _timestamp_message(ts)))
append(_map_str_str(10, attributes))
return b"".join(chunks)


def _encode_record(record: bytes, kind_field_number: int) -> bytes:
Expand Down Expand Up @@ -253,7 +272,7 @@ def _normalize_events(events: List[dict]) -> List[bytes]:
# Note: URL is never logged for privacy
"connection_type": "websocket",
}

# Add detailed error information to attributes
if e.get("error_type"):
attrs["error_type"] = str(e["error_type"])
Expand All @@ -265,7 +284,7 @@ def _normalize_events(events: List[dict]) -> List[bytes]:
attrs["timeout_occurred"] = str(e["timeout_occurred"])
if e.get("duration_ms"):
attrs["duration_ms"] = str(e["duration_ms"])

# Add WebSocket handshake failure details
if e.get("handshake_status_code"):
attrs["handshake_status_code"] = str(e["handshake_status_code"])
Expand All @@ -278,27 +297,27 @@ def _normalize_events(events: List[dict]) -> List[bytes]:
handshake_headers = e["handshake_response_headers"]
for header_name, header_value in handshake_headers.items():
# Prefix with 'handshake_' to distinguish from request headers
safe_header_name = header_name.lower().replace('-', '_')
safe_header_name = header_name.lower().replace("-", "_")
attrs[f"handshake_{safe_header_name}"] = str(header_value)

# Add connection parameters if available
if e.get("connection_params"):
for key, value in e["connection_params"].items():
if value is not None:
attrs[f"connection_{key}"] = str(value)

# Add request_id if present for server-side correlation
request_id = e.get("request_id")
if request_id:
attrs["request_id"] = str(request_id)

# Include ALL extras in the attributes for comprehensive telemetry
extras = e.get("extras", {})
if extras:
for key, value in extras.items():
if value is not None and key not in attrs:
attrs[str(key)] = str(value)

rec = _encode_error_event(
err_type=str(e.get("error_type", e.get("error", "Error"))),
message=str(e.get("error_message", e.get("message", ""))),
Expand Down Expand Up @@ -375,5 +394,3 @@ def encode_telemetry_batch_iter(events: List[dict], context: typing.Mapping[str,
yield _len_delimited(1, _encode_telemetry_context(context))
for rec in _normalize_events(events):
yield rec