Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 215 additions & 18 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
import asyncio
import configparser
import copy
import functools
import inspect
import json
import math
import os
import pathlib
import re
import requests
import statistics
import time
from dataclasses import replace
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from collections import defaultdict
from pathlib import Path
import hashlib
import ast
from mcp.types import Completion
import pytest


from eval_protocol.dataset_logger import default_logger
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
from eval_protocol.human_id import generate_id, num_combinations
Expand Down Expand Up @@ -66,6 +71,42 @@

from ..common_utils import load_jsonl

from pytest import StashKey
from typing_extensions import Literal


# Session-scoped stash key under which experiment link records (dicts with
# "experiment_id", "job_link", "status") are accumulated by
# _store_experiment_link and later read by pytest_sessionfinish in plugin.py.
EXPERIMENT_LINKS_STASH_KEY = StashKey[list]()


def _store_experiment_link(experiment_id: str, job_link: str, status: Literal["success", "failure"]):
"""Store experiment link in pytest session stash."""
try:
import sys

# Walk up the call stack to find the pytest session
session = None
frame = sys._getframe()
while frame:
if "session" in frame.f_locals and hasattr(frame.f_locals["session"], "stash"):
session = frame.f_locals["session"]
break
frame = frame.f_back

if session is not None:
global EXPERIMENT_LINKS_STASH_KEY

if EXPERIMENT_LINKS_STASH_KEY not in session.stash:
session.stash[EXPERIMENT_LINKS_STASH_KEY] = []

session.stash[EXPERIMENT_LINKS_STASH_KEY].append(
{"experiment_id": experiment_id, "job_link": job_link, "status": status}
)
else:
pass

except Exception as e:
pass


def postprocess(
all_results: List[List[EvaluationRow]],
Expand Down Expand Up @@ -214,22 +255,180 @@ def postprocess(
# Do not fail evaluation if summary writing fails
pass

# # Write all rows from active_logger.read() to a JSONL file in the same directory as the summary
# try:
# if active_logger is not None:
# rows = active_logger.read()
# # Write to a .jsonl file alongside the summary file
# jsonl_path = "logs.jsonl"
# import json

# with open(jsonl_path, "w", encoding="utf-8") as f_jsonl:
# for row in rows:
# json.dump(row.model_dump(exclude_none=True, mode="json"), f_jsonl)
# f_jsonl.write("\n")
# except Exception as e:
# # Do not fail evaluation if log writing fails
# print(e)
# pass
try:
# Default is to save and upload experiment JSONL files, unless explicitly disabled
should_save_and_upload = os.getenv("EP_NO_UPLOAD") != "1"

if should_save_and_upload:
current_run_rows = [item for sublist in all_results for item in sublist]
if current_run_rows:
experiments: Dict[str, List[EvaluationRow]] = defaultdict(list)
for row in current_run_rows:
if row.execution_metadata and row.execution_metadata.experiment_id:
experiments[row.execution_metadata.experiment_id].append(row)

exp_dir = pathlib.Path("experiment_results")
exp_dir.mkdir(parents=True, exist_ok=True)

# Create one JSONL file per experiment_id
for experiment_id, exp_rows in experiments.items():
if not experiment_id or not exp_rows:
continue

# Generate dataset name (sanitize for Fireworks API compatibility)
# API requires: lowercase a-z, 0-9, and hyphen (-) only
safe_experiment_id = re.sub(r"[^a-zA-Z0-9-]", "-", experiment_id).lower()
safe_test_func_name = re.sub(r"[^a-zA-Z0-9-]", "-", test_func_name).lower()
dataset_name = f"{safe_test_func_name}-{safe_experiment_id}"

if len(dataset_name) > 63:
dataset_name = dataset_name[:63]

exp_file = exp_dir / f"{experiment_id}.jsonl"
with open(exp_file, "w", encoding="utf-8") as f:
for row in exp_rows:
row_data = row.model_dump(exclude_none=True, mode="json")

if row.evaluation_result:
row_data["evals"] = {"score": row.evaluation_result.score}

row_data["eval_details"] = {
"score": row.evaluation_result.score,
"is_score_valid": row.evaluation_result.is_score_valid,
"reason": row.evaluation_result.reason or "",
"metrics": {
name: metric.model_dump() if metric else {}
for name, metric in (row.evaluation_result.metrics or {}).items()
},
}
else:
# Default values if no evaluation result
row_data["evals"] = {"score": 0}
row_data["eval_details"] = {
"score": 0,
"is_score_valid": True,
"reason": "No evaluation result",
"metrics": {},
}

json.dump(row_data, f, ensure_ascii=False)
f.write("\n")

def get_auth_value(key):
    """Look up *key* in ``~/.fireworks/auth.ini``, falling back to the environment.

    Checks the DEFAULT section first, then the ``[auth]`` section. Returns
    the configured value, or ``os.getenv(key)`` (``None`` if unset) when the
    file is missing, unreadable, or does not define the key.
    """
    try:
        config_path = Path.home() / ".fireworks" / "auth.ini"
        if config_path.exists():
            config = configparser.ConfigParser()
            config.read(config_path)
            # BUG FIX: ConfigParser.has_section() returns False for
            # "DEFAULT" by design, so the original explicit
            # has_section/has_option checks could never read values from
            # the DEFAULT section. get(..., fallback=None) correctly
            # handles both the DEFAULT section and missing sections/options.
            for section in ("DEFAULT", "auth"):
                value = config.get(section, key, fallback=None)
                if value is not None:
                    return value
    except Exception:
        # Any config-parsing problem falls through to the environment lookup.
        pass
    return os.getenv(key)

fireworks_api_key = get_auth_value("FIREWORKS_API_KEY")
fireworks_account_id = get_auth_value("FIREWORKS_ACCOUNT_ID")

if fireworks_api_key and fireworks_account_id:
headers = {"Authorization": f"Bearer {fireworks_api_key}", "Content-Type": "application/json"}

# Make dataset first
dataset_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets"

dataset_payload = {
"dataset": {
"displayName": dataset_name,
"evalProtocol": {},
"format": "FORMAT_UNSPECIFIED",
"exampleCount": f"{len(exp_rows)}",
},
"datasetId": dataset_name,
}

dataset_response = requests.post(dataset_url, json=dataset_payload, headers=headers)

# Skip if dataset creation failed
if dataset_response.status_code not in [200, 201]:
_store_experiment_link(
experiment_id,
f"Dataset creation failed: {dataset_response.status_code} {dataset_response.text}",
"failure",
)
continue

dataset_data = dataset_response.json()
dataset_id = dataset_data.get("datasetId", dataset_name)

# Upload the JSONL file content
upload_url = (
f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload"
)
upload_headers = {"Authorization": f"Bearer {fireworks_api_key}"}

with open(exp_file, "rb") as f:
files = {"file": f}
upload_response = requests.post(upload_url, files=files, headers=upload_headers)

# Skip if upload failed
if upload_response.status_code not in [200, 201]:
_store_experiment_link(
experiment_id,
f"File upload failed: {upload_response.status_code} {upload_response.text}",
"failure",
)
continue

# Create evaluation job (optional - don't skip experiment if this fails)
eval_job_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/evaluationJobs"
# Truncate job ID to fit 63 character limit
job_id_base = f"{dataset_name}-job"
if len(job_id_base) > 63:
# Keep the "-job" suffix and truncate the dataset_name part
max_dataset_name_len = 63 - 4 # 4 = len("-job")
truncated_dataset_name = dataset_name[:max_dataset_name_len]
job_id_base = f"{truncated_dataset_name}-job"

eval_job_payload = {
"evaluationJobId": job_id_base,
"evaluationJob": {
"evaluator": f"accounts/{fireworks_account_id}/evaluators/dummy",
"inputDataset": f"accounts/{fireworks_account_id}/datasets/dummy",
"outputDataset": f"accounts/{fireworks_account_id}/datasets/{dataset_id}",
},
}

eval_response = requests.post(eval_job_url, json=eval_job_payload, headers=headers)

if eval_response.status_code in [200, 201]:
eval_job_data = eval_response.json()
job_id = eval_job_data.get("evaluationJobId", job_id_base)

_store_experiment_link(
experiment_id,
f"https://app.fireworks.ai/dashboard/evaluation-jobs/{job_id}",
"success",
)
else:
_store_experiment_link(
experiment_id,
f"Job creation failed: {eval_response.status_code} {eval_response.text}",
"failure",
)

else:
# Store failure for missing credentials for all experiments
for experiment_id, exp_rows in experiments.items():
if experiment_id and exp_rows:
_store_experiment_link(
experiment_id, "No Fireworks API key or account ID found", "failure"
)

except Exception as e:
# Do not fail evaluation if experiment JSONL writing fails
print(f"Warning: Failed to persist results: {e}")
pass

# Check threshold after logging
if threshold is not None and not passed:
Expand Down Expand Up @@ -816,7 +1015,6 @@ def create_dual_mode_wrapper() -> Callable:
Returns:
A callable that can handle both pytest test execution and direct function calls
"""
import asyncio

# Check if the test function is async
is_async = asyncio.iscoroutinefunction(test_func)
Expand Down Expand Up @@ -863,7 +1061,6 @@ async def dual_mode_wrapper(*args, **kwargs):
}

# Copy all attributes from the pytest wrapper to our dual mode wrapper
import functools

functools.update_wrapper(dual_mode_wrapper, pytest_wrapper)

Expand Down
42 changes: 42 additions & 0 deletions eval_protocol/pytest/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from typing import Optional
import json
import pathlib
import sys
from pytest import StashKey


def pytest_addoption(parser) -> None:
Expand Down Expand Up @@ -104,6 +106,15 @@ def pytest_addoption(parser) -> None:
"Pass a float >= 0.0 (e.g., 0.05). If only this is set, success threshold defaults to 0.0."
),
)
group.addoption(
"--ep-no-upload",
action="store_true",
default=False,
help=(
"Disable saving and uploading of detailed experiment JSON files to Fireworks. "
"Default: false (experiment JSONs are saved and uploaded by default)."
),
)


def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
Expand Down Expand Up @@ -229,6 +240,9 @@ def pytest_configure(config) -> None:
if threshold_env is not None:
os.environ["EP_PASSED_THRESHOLD"] = threshold_env

if config.getoption("--ep-no-upload"):
os.environ["EP_NO_UPLOAD"] = "1"

# Allow ad-hoc overrides of input params via CLI flags
try:
merged: dict = {}
Expand Down Expand Up @@ -263,3 +277,31 @@ def pytest_configure(config) -> None:
except Exception:
# best effort, do not crash pytest session
pass


def pytest_sessionfinish(session, exitstatus):
    """Print all Fireworks experiment links collected in the session stash.

    Reads the records stored by ``_store_experiment_link`` during the run
    and prints one line per experiment to the real stderr —
    ``sys.__stderr__`` bypasses pytest's output capture. Best-effort: any
    failure is swallowed so reporting can never break the test session.
    """
    try:
        from .evaluation_test import EXPERIMENT_LINKS_STASH_KEY

        # Nothing stored (or nothing to print): stay silent.
        if EXPERIMENT_LINKS_STASH_KEY not in session.stash:
            return
        links = session.stash[EXPERIMENT_LINKS_STASH_KEY]
        if not links:
            return

        err = sys.__stderr__
        print("\n" + "=" * 80, file=err)
        print("🔥 FIREWORKS EXPERIMENT LINKS", file=err)
        print("=" * 80, file=err)
        for link in links:
            # Same message either way; only the leading icon differs.
            prefix = "🔗" if link["status"] == "success" else "❌"
            print(f"{prefix} Experiment {link['experiment_id']}: {link['job_link']}", file=err)
        print("=" * 80, file=err)
        err.flush()
    except Exception:
        # Best-effort reporting only; never fail the session here.
        pass
Loading