Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions eval_protocol/adapters/fireworks_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message
from .base import BaseAdapter
from .lp_deserializer import decompress_and_parse_lp
from .r3_deserializer import decompress_and_parse_r3
from .utils import extract_messages_from_data
from ..common_utils import get_user_agent

Expand Down Expand Up @@ -106,8 +108,6 @@ def convert_trace_dict_to_evaluation_row(
router_replay = payloads.get("router_replay")
if isinstance(router_replay, dict) and router_replay.get("data"):
try:
from .r3_deserializer import decompress_and_parse_r3

matrices, r3_meta = decompress_and_parse_r3(router_replay["data"])
if execution_metadata.extra is None:
execution_metadata.extra = {}
Expand All @@ -116,6 +116,32 @@ def convert_trace_dict_to_evaluation_row(
except Exception as e:
logger.warning("Failed to decompress R3 payload for trace %s: %s", trace.get("id"), e)

logprobs_payload = payloads.get("logprobs")
if isinstance(logprobs_payload, dict) and logprobs_payload.get("data"):
try:
logprobs, token_ids, lp_meta = decompress_and_parse_lp(logprobs_payload["data"])
if execution_metadata.extra is None:
execution_metadata.extra = {}
execution_metadata.extra["completion_logprobs"] = logprobs
if token_ids is not None:
execution_metadata.extra["completion_token_ids"] = token_ids
execution_metadata.extra["logprobs_metadata"] = lp_meta

for i in range(len(messages) - 1, -1, -1):
if messages[i].role == "assistant":
content_entries = [{"logprob": lp} for lp in logprobs]
if token_ids is not None:
for entry, tid in zip(content_entries, token_ids):
entry["token_id"] = tid
messages[i].logprobs = {"content": content_entries}
break
except Exception as e:
logger.warning(
"Failed to decompress logprobs payload for trace %s: %s",
trace.get("id"),
e,
)

return EvaluationRow(
messages=messages,
tools=tools,
Expand Down
109 changes: 109 additions & 0 deletions eval_protocol/adapters/lp_deserializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""LP/v1 binary deserializer for per-token logprobs payloads.

Implements the inverse of the tracing gateway's ``logprobs_serializer.serialize_logprobs``.
See that module for the full header specification.
"""

from __future__ import annotations

import base64
import struct
from typing import Any, Dict, List, Optional, Tuple

import zstandard as zstd

MAGIC = b"LP01"
HEADER_VERSION = 1
MISSING_TOKEN_ID = -1
ENTRY_FORMAT = "<if"
ENTRY_SIZE = struct.calcsize(ENTRY_FORMAT) # 8 bytes
HEADER_FORMAT = "<4sBBHIIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 24 bytes


def _parse_header(raw: bytes) -> Dict[str, Any]:
if len(raw) < HEADER_SIZE:
raise ValueError(f"Payload too short for lp/v1 header: {len(raw)} < {HEADER_SIZE}")

(
magic,
version,
flags,
reserved_u16,
token_count,
body_byte_length,
reserved_u64,
) = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE])

if magic != MAGIC:
raise ValueError(f"Bad LP/v1 magic: {magic!r}")
if version != HEADER_VERSION:
raise ValueError(f"Unsupported lp/v1 header version: {version}")

return {
"flags": flags,
"reserved_u16": reserved_u16,
"token_count": token_count,
"body_byte_length": body_byte_length,
"reserved_u64": reserved_u64,
}


def parse_logprobs(raw: bytes) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]:
"""Parse uncompressed LP/v1 bytes into logprobs, optional token ids, and metadata."""
header = _parse_header(raw)
token_count = header["token_count"]
body_byte_length = header["body_byte_length"]

if token_count == 0:
raise ValueError("LP/v1 token_count must be > 0")
if body_byte_length != token_count * ENTRY_SIZE:
raise ValueError(
f"body_byte_length ({body_byte_length}) != token_count * {ENTRY_SIZE} "
f"({token_count * ENTRY_SIZE})"
)

expected_len = HEADER_SIZE + body_byte_length
if len(raw) != expected_len:
raise ValueError(f"LP/v1 payload length mismatch: {len(raw)} != {expected_len}")

logprobs: List[float] = []
token_ids: List[int] = []
all_token_ids_valid = True
offset = HEADER_SIZE
for _ in range(token_count):
wire_id, logprob = struct.unpack(ENTRY_FORMAT, raw[offset : offset + ENTRY_SIZE])
offset += ENTRY_SIZE
logprobs.append(logprob)
if wire_id == MISSING_TOKEN_ID:
all_token_ids_valid = False
token_ids.append(wire_id)
else:
token_ids.append(wire_id)

metadata: Dict[str, Any] = {
"scope": "completion_only",
"completion_token_count": token_count,
"all_token_ids_valid": all_token_ids_valid,
}
header.update(metadata)
ids_out: Optional[List[int]] = token_ids if all_token_ids_valid else None
return logprobs, ids_out, header


def decompress_and_parse_lp(data_b64: str) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]:
"""Decompress and unpack an LP/v1 payload into completion logprobs and token ids.

Args:
data_b64: Base64-encoded zstd-compressed LP binary blob from
``payloads.logprobs.data``.

Returns:
``(logprobs, token_ids, metadata)`` where ``logprobs`` is per-completion-token
scalars, ``token_ids`` is ``None`` if any wire id was ``MISSING_TOKEN_ID``,
and ``metadata`` includes ``all_token_ids_valid`` and ``completion_token_count``.
"""
compressed = base64.b64decode(data_b64)
decompressor = zstd.ZstdDecompressor()
raw = decompressor.decompress(compressed)
return parse_logprobs(raw)
2 changes: 1 addition & 1 deletion eval_protocol/pytest/tracing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def build_init_request(
if not completion_params_dict.get("model"):
raise ValueError("Model must be provided in completion_params")

# Extract base_url from completion_params
# Extract base_url from completion_params for tracing-gateway URL encoding
Comment thread
SunnySoldier357 marked this conversation as resolved.
completion_params_base_url: Optional[str] = completion_params_dict.get("base_url")

# Strip non-OpenAI fields from messages
Expand Down
93 changes: 93 additions & 0 deletions tests/adapters/test_fireworks_tracing_logprobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Tests for logprobs payload handling in fireworks_tracing adapter."""

from __future__ import annotations

import base64
import struct

import pytest
import zstandard as zstd

pytest.importorskip("mcp")

from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row
from eval_protocol.adapters.lp_deserializer import (
ENTRY_FORMAT,
ENTRY_SIZE,
HEADER_FORMAT,
MAGIC,
MISSING_TOKEN_ID,
)


def _lp_b64(tokens: list[tuple[int, float]]) -> str:
token_count = len(tokens)
body_byte_length = token_count * ENTRY_SIZE
header = struct.pack(
HEADER_FORMAT,
MAGIC,
1,
0,
0,
token_count,
body_byte_length,
0,
)
body = b"".join(struct.pack(ENTRY_FORMAT, tid, lp) for tid, lp in tokens)
raw = header + body
compressed = zstd.ZstdCompressor().compress(raw)
return base64.b64encode(compressed).decode("ascii")


def _base_trace(*, with_token_ids: bool = True) -> dict:
tokens = [(10, -0.1), (11, -0.2)] if with_token_ids else [(MISSING_TOKEN_ID, -0.1), (12, -0.2)]
return {
"id": "trace-1",
"input": {
"messages": [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "hello"},
],
},
"output": {"role": "assistant", "content": "hello"},
"payloads": {
"logprobs": {
"data": _lp_b64(tokens),
"manifest": {"PayloadVersion": "lp/v1"},
},
},
}


class TestConvertTraceLogprobs:
def test_attaches_completion_logprobs_and_message_logprobs(self):
row = convert_trace_dict_to_evaluation_row(_base_trace())
assert row is not None

extra = row.execution_metadata.extra
assert extra is not None
assert extra["completion_logprobs"] == pytest.approx([-0.1, -0.2])
assert extra["completion_token_ids"] == [10, 11]

assistant = row.messages[-1]
assert assistant.role == "assistant"
content = assistant.logprobs["content"]
assert len(content) == len(extra["completion_logprobs"])
assert content[0]["token_id"] == 10
assert content[1]["token_id"] == 11
assert content[0]["logprob"] == pytest.approx(-0.1)
assert content[1]["logprob"] == pytest.approx(-0.2)

def test_omits_token_id_keys_when_any_missing(self):
row = convert_trace_dict_to_evaluation_row(_base_trace(with_token_ids=False))
assert row is not None

extra = row.execution_metadata.extra
assert "completion_logprobs" in extra
assert "completion_token_ids" not in extra

content = row.messages[-1].logprobs["content"]
assert len(content) == 2
assert all("token_id" not in entry for entry in content)
assert content[0]["logprob"] == pytest.approx(-0.1)
assert content[1]["logprob"] == pytest.approx(-0.2)
78 changes: 78 additions & 0 deletions tests/adapters/test_lp_deserializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Tests for LP/v1 binary deserializer (gateway-compatible)."""

from __future__ import annotations

import base64
import struct

import pytest
import zstandard as zstd

from eval_protocol.adapters.lp_deserializer import (
ENTRY_FORMAT,
ENTRY_SIZE,
HEADER_FORMAT,
HEADER_SIZE,
MAGIC,
MISSING_TOKEN_ID,
decompress_and_parse_lp,
parse_logprobs,
)

# Golden raw bytes: two tokens (7, -0.25) and (8, -0.5) — must match gateway serializer.
GOLDEN_RAW_HEX = (
"4c503031010000000200000010000000000000000000000007000000000080be"
"08000000000000bf"
)


def _build_raw(tokens: list[tuple[int, float]]) -> bytes:
token_count = len(tokens)
body_byte_length = token_count * ENTRY_SIZE
header = struct.pack(
HEADER_FORMAT,
MAGIC,
1,
0,
0,
token_count,
body_byte_length,
0,
)
body = b"".join(struct.pack(ENTRY_FORMAT, tid, lp) for tid, lp in tokens)
return header + body


def _compress_b64(raw: bytes) -> str:
return base64.b64encode(zstd.ZstdCompressor().compress(raw)).decode("ascii")


class TestParseLogprobs:
def test_golden_bytes_match_gateway(self):
raw = bytes.fromhex(GOLDEN_RAW_HEX)
logprobs, token_ids, meta = parse_logprobs(raw)
assert logprobs == [-0.25, -0.5]
assert token_ids == [7, 8]
assert meta["all_token_ids_valid"] is True
assert meta["token_count"] == 2

def test_missing_token_id_omits_token_ids_list(self):
raw = _build_raw([(MISSING_TOKEN_ID, -0.3), (42, -0.4)])
logprobs, token_ids, meta = parse_logprobs(raw)
assert logprobs == pytest.approx([-0.3, -0.4])
assert token_ids is None
assert meta["all_token_ids_valid"] is False

def test_decompress_and_parse_round_trip(self):
raw = bytes.fromhex(GOLDEN_RAW_HEX)
b64 = _compress_b64(raw)
logprobs, token_ids, meta = decompress_and_parse_lp(b64)
assert logprobs == [-0.25, -0.5]
assert token_ids == [7, 8]
assert meta["scope"] == "completion_only"

def test_rejects_bad_magic(self):
raw = _build_raw([(1, -0.1)])
bad = b"XXXX" + raw[4:]
with pytest.raises(ValueError, match="Bad LP/v1 magic"):
parse_logprobs(bad)
8 changes: 6 additions & 2 deletions tests/remote_server/remote_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,12 @@ def _worker():
md = {k: v for k, v in md.items() if v is not None}
messages_payload.append(md)

# Spread all completion_params (model, temperature, max_tokens, etc.)
completion_kwargs = {"messages": messages_payload, **req.completion_params}
# Spread completion_params; omit base_url (client uses req.model_base_url; gateway
# encodes inference base_url into the tracing path via build_init_request).
completion_kwargs = {
"messages": messages_payload,
**{k: v for k, v in req.completion_params.items() if k != "base_url"},
}

if req.tools:
completion_kwargs["tools"] = req.tools
Expand Down
Loading