-
Notifications
You must be signed in to change notification settings - Fork 17
Add LP/v1 logprobs trace payload deserialization (FIR-21499) #452
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
ac3db7b
Add LP/v1 logprobs deserialization for tracing gateway payloads (FIR-…
SunnySoldier357 b8f2d81
Hoist r3 and lp deserializer imports to module level in fireworks_tra…
SunnySoldier357 e87ba81
Keep base_url in completion_params; strip at OpenAI call site.
SunnySoldier357 e901c69
remove redundant function
SunnySoldier357 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| """LP/v1 binary deserializer for per-token logprobs payloads. | ||
|
|
||
| Implements the inverse of the tracing gateway's ``logprobs_serializer.serialize_logprobs``. | ||
| See that module for the full header specification. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import base64 | ||
| import struct | ||
| from typing import Any, Dict, List, Optional, Tuple | ||
|
|
||
| import zstandard as zstd | ||
|
|
||
# Wire-format constants for the LP/v1 payload. Both struct formats are
# little-endian with no padding ("<").
MAGIC = b"LP01"  # expected first four bytes of the header
HEADER_VERSION = 1  # the only header version the parser accepts
MISSING_TOKEN_ID = -1  # sentinel wire id; its presence suppresses the token-id list

# Header fields, in order: magic (4s), version (B), flags (B), reserved_u16 (H),
# token_count (I), body_byte_length (I), reserved_u64 (Q).
HEADER_FORMAT = "<4sBBHIIQ"
HEADER_SIZE = struct.Struct(HEADER_FORMAT).size  # 24 bytes

# One body entry: token id (int32) followed by its logprob (float32).
ENTRY_FORMAT = "<if"
ENTRY_SIZE = struct.Struct(ENTRY_FORMAT).size  # 8 bytes
|
|
||
|
|
||
def _parse_header(raw: bytes) -> Dict[str, Any]:
    """Decode and validate the fixed-size LP/v1 header at the start of ``raw``.

    Args:
        raw: Uncompressed payload bytes beginning with the header. Bytes past
            the header are not inspected here.

    Returns:
        Dict with the decoded ``flags``, ``reserved_u16``, ``token_count``,
        ``body_byte_length`` and ``reserved_u64`` fields. The magic and version
        are validated but not included.

    Raises:
        ValueError: If ``raw`` is shorter than ``HEADER_SIZE``, the magic bytes
            differ from ``MAGIC``, or the version differs from ``HEADER_VERSION``.
    """
    if len(raw) < HEADER_SIZE:
        raise ValueError(f"Payload too short for lp/v1 header: {len(raw)} < {HEADER_SIZE}")

    # unpack_from reads exactly HEADER_SIZE bytes from offset 0, so no slice is needed.
    magic, version, *remaining = struct.unpack_from(HEADER_FORMAT, raw)

    if magic != MAGIC:
        raise ValueError(f"Bad LP/v1 magic: {magic!r}")
    if version != HEADER_VERSION:
        raise ValueError(f"Unsupported lp/v1 header version: {version}")

    # Names follow the wire order of the fields after magic and version.
    field_names = (
        "flags",
        "reserved_u16",
        "token_count",
        "body_byte_length",
        "reserved_u64",
    )
    return dict(zip(field_names, remaining))
|
|
||
|
|
||
def parse_logprobs(raw: bytes) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]:
    """Parse uncompressed LP/v1 bytes into logprobs, optional token ids, and metadata.

    Args:
        raw: The complete uncompressed payload: the header immediately followed
            by ``token_count`` fixed-size ``(token id, logprob)`` entries.

    Returns:
        ``(logprobs, token_ids, metadata)``. ``logprobs`` holds one float per
        entry in wire order. ``token_ids`` is the matching list of ids, or
        ``None`` if any entry carried ``MISSING_TOKEN_ID``. ``metadata`` is the
        parsed header dict extended with ``scope``, ``completion_token_count``
        and ``all_token_ids_valid``.

    Raises:
        ValueError: If the header is invalid, ``token_count`` is zero,
            ``body_byte_length`` is not ``token_count * ENTRY_SIZE``, or ``raw``
            is not exactly header plus body in length.
    """
    header = _parse_header(raw)
    token_count = header["token_count"]
    body_byte_length = header["body_byte_length"]

    if token_count == 0:
        raise ValueError("LP/v1 token_count must be > 0")
    if body_byte_length != token_count * ENTRY_SIZE:
        raise ValueError(
            f"body_byte_length ({body_byte_length}) != token_count * {ENTRY_SIZE} "
            f"({token_count * ENTRY_SIZE})"
        )

    expected_len = HEADER_SIZE + body_byte_length
    if len(raw) != expected_len:
        raise ValueError(f"LP/v1 payload length mismatch: {len(raw)} != {expected_len}")

    # The checks above guarantee the body is exactly token_count entries long,
    # which is the whole-multiple length struct.iter_unpack requires. The
    # memoryview slices the body without copying it.
    body = memoryview(raw)[HEADER_SIZE:]
    logprobs: List[float] = []
    token_ids: List[int] = []
    for wire_id, logprob in struct.iter_unpack(ENTRY_FORMAT, body):
        token_ids.append(wire_id)
        logprobs.append(logprob)

    # A single sentinel anywhere invalidates the whole id list for callers.
    all_token_ids_valid = MISSING_TOKEN_ID not in token_ids

    metadata: Dict[str, Any] = {
        "scope": "completion_only",
        "completion_token_count": token_count,
        "all_token_ids_valid": all_token_ids_valid,
    }
    header.update(metadata)
    ids_out: Optional[List[int]] = token_ids if all_token_ids_valid else None
    return logprobs, ids_out, header
|
|
||
|
|
||
def decompress_and_parse_lp(data_b64: str) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]:
    """Decode a base64, zstd-compressed LP/v1 payload and parse its contents.

    Args:
        data_b64: Base64-encoded zstd-compressed LP binary blob from
            ``payloads.logprobs.data``.

    Returns:
        ``(logprobs, token_ids, metadata)`` exactly as returned by
        ``parse_logprobs``: ``logprobs`` is per-completion-token scalars,
        ``token_ids`` is ``None`` if any wire id was ``MISSING_TOKEN_ID``, and
        ``metadata`` includes ``all_token_ids_valid`` and
        ``completion_token_count``.
    """
    zstd_frame = base64.b64decode(data_b64)
    # NOTE(review): no max_output_size is passed, so this one-shot decompress
    # presumably relies on the frame header carrying the content size — confirm
    # the gateway serializer always writes it.
    return parse_logprobs(zstd.ZstdDecompressor().decompress(zstd_frame))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| """Tests for logprobs payload handling in fireworks_tracing adapter.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import base64 | ||
| import struct | ||
|
|
||
| import pytest | ||
| import zstandard as zstd | ||
|
|
||
| pytest.importorskip("mcp") | ||
|
|
||
| from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row | ||
| from eval_protocol.adapters.lp_deserializer import ( | ||
| ENTRY_FORMAT, | ||
| ENTRY_SIZE, | ||
| HEADER_FORMAT, | ||
| MAGIC, | ||
| MISSING_TOKEN_ID, | ||
| ) | ||
|
|
||
|
|
||
def _lp_b64(tokens: list[tuple[int, float]]) -> str:
    """Build a base64, zstd-compressed LP/v1 payload containing ``tokens``.

    Args:
        tokens: ``(token id, logprob)`` pairs, packed in the given order.

    Returns:
        ASCII base64 text of the compressed header plus body.
    """
    entry_count = len(tokens)
    packed_entries = [struct.pack(ENTRY_FORMAT, token_id, logprob) for token_id, logprob in tokens]
    # Positional values after MAGIC: version, flags, reserved_u16, token_count,
    # body_byte_length, reserved_u64.
    packed_header = struct.pack(HEADER_FORMAT, MAGIC, 1, 0, 0, entry_count, entry_count * ENTRY_SIZE, 0)
    payload = packed_header + b"".join(packed_entries)
    return base64.b64encode(zstd.ZstdCompressor().compress(payload)).decode("ascii")
|
|
||
|
|
||
def _base_trace(*, with_token_ids: bool = True) -> dict:
    """Return a minimal trace dict carrying a two-token LP/v1 logprobs payload.

    Args:
        with_token_ids: When true both tokens have real ids (10 and 11); when
            false the first token's id is ``MISSING_TOKEN_ID``.
    """
    if with_token_ids:
        tokens = [(10, -0.1), (11, -0.2)]
    else:
        tokens = [(MISSING_TOKEN_ID, -0.1), (12, -0.2)]

    logprobs_payload = {
        "data": _lp_b64(tokens),
        "manifest": {"PayloadVersion": "lp/v1"},
    }
    conversation = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ]
    return {
        "id": "trace-1",
        "input": {"messages": conversation},
        # Built as its own dict so it never aliases the message in "input".
        "output": {"role": "assistant", "content": "hello"},
        "payloads": {"logprobs": logprobs_payload},
    }
|
|
||
|
|
||
class TestConvertTraceLogprobs:
    """Checks that trace conversion surfaces the LP/v1 logprobs payload."""

    def test_attaches_completion_logprobs_and_message_logprobs(self):
        """All ids present: logprobs and ids land in metadata and on the last message."""
        row = convert_trace_dict_to_evaluation_row(_base_trace())
        assert row is not None

        extra = row.execution_metadata.extra
        assert extra is not None
        assert extra["completion_logprobs"] == pytest.approx([-0.1, -0.2])
        assert extra["completion_token_ids"] == [10, 11]

        assistant = row.messages[-1]
        assert assistant.role == "assistant"
        content = assistant.logprobs["content"]
        assert len(content) == len(extra["completion_logprobs"])
        first, second = content[0], content[1]
        assert first["token_id"] == 10
        assert second["token_id"] == 11
        assert first["logprob"] == pytest.approx(-0.1)
        assert second["logprob"] == pytest.approx(-0.2)

    def test_omits_token_id_keys_when_any_missing(self):
        """One missing id: logprobs are kept but every token-id field is dropped."""
        row = convert_trace_dict_to_evaluation_row(_base_trace(with_token_ids=False))
        assert row is not None

        extra = row.execution_metadata.extra
        assert "completion_logprobs" in extra
        assert "completion_token_ids" not in extra

        content = row.messages[-1].logprobs["content"]
        assert len(content) == 2
        assert all("token_id" not in entry for entry in content)
        first, second = content[0], content[1]
        assert first["logprob"] == pytest.approx(-0.1)
        assert second["logprob"] == pytest.approx(-0.2)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,78 @@ | ||
| """Tests for LP/v1 binary deserializer (gateway-compatible).""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import base64 | ||
| import struct | ||
|
|
||
| import pytest | ||
| import zstandard as zstd | ||
|
|
||
| from eval_protocol.adapters.lp_deserializer import ( | ||
| ENTRY_FORMAT, | ||
| ENTRY_SIZE, | ||
| HEADER_FORMAT, | ||
| HEADER_SIZE, | ||
| MAGIC, | ||
| MISSING_TOKEN_ID, | ||
| decompress_and_parse_lp, | ||
| parse_logprobs, | ||
| ) | ||
|
|
||
# Golden raw bytes: two tokens (7, -0.25) and (8, -0.5) — must match gateway serializer.
# Split per wire field so each value can be read off directly.
GOLDEN_RAW_HEX = "".join(
    (
        "4c503031",  # magic b"LP01"
        "01",  # version
        "00",  # flags
        "0000",  # reserved_u16
        "02000000",  # token_count = 2
        "10000000",  # body_byte_length = 16
        "0000000000000000",  # reserved_u64
        "07000000" "000080be",  # entry: token id 7, logprob -0.25
        "08000000" "000000bf",  # entry: token id 8, logprob -0.5
    )
)
|
|
||
|
|
||
def _build_raw(tokens: list[tuple[int, float]]) -> bytes:
    """Assemble uncompressed LP/v1 bytes (header plus body) for ``tokens``.

    Args:
        tokens: ``(token id, logprob)`` pairs, packed in the given order.
    """
    entry_count = len(tokens)
    packed_entries = [struct.pack(ENTRY_FORMAT, token_id, logprob) for token_id, logprob in tokens]
    # Positional values after MAGIC: version, flags, reserved_u16, token_count,
    # body_byte_length, reserved_u64.
    packed_header = struct.pack(HEADER_FORMAT, MAGIC, 1, 0, 0, entry_count, entry_count * ENTRY_SIZE, 0)
    return packed_header + b"".join(packed_entries)
|
|
||
|
|
||
def _compress_b64(raw: bytes) -> str:
    """Zstd-compress ``raw`` and return the result as ASCII base64 text."""
    compressed = zstd.ZstdCompressor().compress(raw)
    encoded = base64.b64encode(compressed)
    return encoded.decode("ascii")
|
|
||
|
|
||
class TestParseLogprobs:
    """Unit tests for the LP/v1 parser and its base64/zstd wrapper."""

    def test_golden_bytes_match_gateway(self):
        """The golden payload decodes to the two expected tokens."""
        logprobs, token_ids, meta = parse_logprobs(bytes.fromhex(GOLDEN_RAW_HEX))
        assert logprobs == [-0.25, -0.5]
        assert token_ids == [7, 8]
        assert meta["all_token_ids_valid"] is True
        assert meta["token_count"] == 2

    def test_missing_token_id_omits_token_ids_list(self):
        """A single sentinel id makes the parser return no id list."""
        payload = _build_raw([(MISSING_TOKEN_ID, -0.3), (42, -0.4)])
        logprobs, token_ids, meta = parse_logprobs(payload)
        assert logprobs == pytest.approx([-0.3, -0.4])
        assert token_ids is None
        assert meta["all_token_ids_valid"] is False

    def test_decompress_and_parse_round_trip(self):
        """Compressing then decoding the golden payload yields the same tokens."""
        encoded = _compress_b64(bytes.fromhex(GOLDEN_RAW_HEX))
        logprobs, token_ids, meta = decompress_and_parse_lp(encoded)
        assert logprobs == [-0.25, -0.5]
        assert token_ids == [7, 8]
        assert meta["scope"] == "completion_only"

    def test_rejects_bad_magic(self):
        """Overwriting the four magic bytes is reported as a ValueError."""
        valid = _build_raw([(1, -0.1)])
        corrupted = b"XXXX" + valid[4:]
        with pytest.raises(ValueError, match="Bad LP/v1 magic"):
            parse_logprobs(corrupted)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.