diff --git a/scripts/test_api_empty_responses.py b/scripts/test_api_empty_responses.py
new file mode 100644
index 000000000..4510b2463
--- /dev/null
+++ b/scripts/test_api_empty_responses.py
@@ -0,0 +1,312 @@
+"""
+Example usage:
+python scripts/test_api_empty_responses.py --concurrent 4 --queries 1000 --out test
+
+Results will be stored in test/
+
+Additional requirements:
+pip install plotly kaleido
+"""
+import argparse
+import asyncio
+import csv
+import random
+import sys
+import time
+from pathlib import Path
+
+import nltk
+import openai
+import pandas as pd
+import plotly.graph_objects as go
+from httpx import HTTPStatusError, Timeout
+from nltk.corpus import words
+
+nltk.download("words")
+word_list = words.words()
+
+
+def approximate_tokens(text: str) -> int:
+    """Approximate the number of tokens from a given string.
+
+    Approach:
+    - Split by whitespace to get a word count.
+    - Multiply by ~1.3 to get an approximate token count.
+    """
+    words_in_text = text.split()
+    return int(len(words_in_text) * 1.3)
+
+
+async def make_completion(
+    client: openai.AsyncOpenAI,
+    prompt: str,
+    stream: bool = True,
+    seed: str = "1759348",
+) -> dict:
+    """Make a completion request to the API.
+
+    Measures:
+    - Time to first token.
+    - Time to full response.
+    - Approximate tokens received.
+    - Status code.
+
+    Returns a dictionary with measurements.
+    """
+    start_time = time.perf_counter()
+    time_to_first_token = None
+    total_approx_tokens = 0
+    status_code = None
+
+    try:
+        result = await client.chat.completions.create(
+            model=None,
+            messages=[{"role": "user", "content": prompt}],
+            stream=stream,
+            extra_body={
+                "seed": seed,
+                "sampling_parameters": {
+                    "temperature": 0.7,
+                    "top_p": 0.95,
+                    "top_k": 50,
+                    "max_new_tokens": 256,
+                    "do_sample": True,
+                    "seed": None,
+                },
+                "task": "InferenceTask",
+                "mixture": False,
+            },
+        )
+
+        if not stream:
+            # TODO: Non-streaming part is not tested.
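+            # The lines below are a minimal, untested sketch: with openai-python,
+            # a non-streaming call returns a ChatCompletion whose text is at
+            # result.choices[0].message.content.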
+            raise NotImplementedError("Implement non-streaming mode")
+            text = result.choices[0].message.content or ""
+            total_latency = time.perf_counter() - start_time
+            total_approx_tokens = approximate_tokens(text)
+            status_code = 200
+        else:
+            # If streaming, measure time-to-first-token and accumulate approximate tokens
+            async for chunk in result:
+                if hasattr(result, "response") and result.response is not None:
+                    status_code = result.response.status_code
+
+                delta_content = chunk.choices[0].delta.content
+                if delta_content:
+                    if time_to_first_token is None:
+                        time_to_first_token = time.perf_counter() - start_time
+                    total_approx_tokens += approximate_tokens(delta_content)
+
+        total_latency = time.perf_counter() - start_time
+
+        if time_to_first_token is None:
+            time_to_first_token = total_latency
+
+        if status_code is None:
+            status_code = 200
+
+        if total_latency > 0:
+            tokens_per_second = total_approx_tokens / total_latency
+        else:
+            tokens_per_second = 0.0
+
+        return {
+            "time_to_first_token": time_to_first_token,
+            "total_latency": total_latency,
+            "tokens_per_second": tokens_per_second,
+            "status_code": status_code,
+            "success": True,
+            "total_approx_tokens": total_approx_tokens,
+        }
+
+    except HTTPStatusError as e:
+        return {
+            "time_to_first_token": None,
+            "total_latency": None,
+            "tokens_per_second": 0,
+            "status_code": e.response.status_code,
+            "success": False,
+            "total_approx_tokens": 0,
+        }
+    except Exception as e:
+        print(f"Unexpected error: {e}", file=sys.stderr)
+        return {
+            "time_to_first_token": None,
+            "total_latency": None,
+            "tokens_per_second": 0,
+            "status_code": 0,
+            "success": False,
+            "total_approx_tokens": 0,
+        }
+
+
+async def run_stress_test(
+    api_key: str, output_dir: str, concurrent: int = 4, queries: int = 1000, url: str = "http://0.0.0.0:8005/v1"
+):
+    """Run a stress test by sending concurrent API requests.
+
+    Args:
+        api_key (str): API key for authentication.
+        output_dir (str): Directory to save outputs.
+        concurrent (int): Number of concurrent workers.
+        queries (int): Number of queries per worker.
+        url (str): API endpoint URL.
+
+    Measures:
+    - Total successes and failures.
+    - Success rate.
+    - Cumulative failures over query execution time.
+
+    Saves to result.csv and generates a Plotly fail rate chart.
+
+    Additionally, prints the total number of empty or errored queries to the console.
+    """
+    client = openai.AsyncOpenAI(
+        base_url=url,
+        max_retries=0,
+        timeout=Timeout(120, connect=120, read=120),
+        api_key=api_key,
+    )
+    client._client.headers["api-key"] = api_key
+    word = random.choice(word_list)
+    prompt = f"Write a short story about {word}."
+
+    total_queries = concurrent * queries
+    success_count = 0
+    fail_count = 0
+    fail_list = []
+
+    print(
+        f"\nStarting stress test with {concurrent} concurrent workers, each performing {queries} queries (Total: {total_queries} queries)."
+    )
+
+    semaphore = asyncio.Semaphore(concurrent)
+
+    async def worker(worker_id: int):
+        nonlocal success_count, fail_count
+        for i in range(queries):
+            async with semaphore:
+                if i % 10 == 0:
+                    print(f"Worker {worker_id}: {i} / {queries}")
+                response = await make_completion(client, prompt=prompt, stream=True)
+                if response["success"] and response["total_approx_tokens"] > 0:
+                    success_count += 1
+                    fail_list.append(0)
+                else:
+                    fail_count += 1
+                    fail_list.append(1)
+
+    # Launch all workers
+    workers = [asyncio.create_task(worker(w_id)) for w_id in range(concurrent)]
+    await asyncio.gather(*workers)
+
+    # Calculate success rate
+    success_rate = (success_count / total_queries) * 100 if total_queries > 0 else 0.0
+
+    # Save result.csv
+    result_data = {
+        "concurrent": concurrent,
+        "queries": queries,
+        "success": success_count,
+        "fail": fail_count,
+        "success_rate": round(success_rate, 2),
+    }
+
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    csv_file = output_dir / "result.csv"
+    with open(csv_file, "w", newline="", encoding="utf-8") as f:
+        writer = csv.DictWriter(f, fieldnames=["concurrent", "queries", "success", "fail", "success_rate"])
+        writer.writeheader()
+        writer.writerow(result_data)
+
+    print("\nSaved result.csv with the following data:")
+    print(result_data)
+
+    print(f"Total number of empty or errored queries: {fail_count}")
+
+    # Create DataFrame for plotting
+    fail_df = pd.DataFrame({"fail": fail_list})
+    fail_df["cumulative_fail"] = fail_df["fail"].cumsum()
+    fail_df["query_number"] = fail_df.index + 1
+
+    # Generate the fail rate chart using Plotly with dark background
+    fig = go.Figure()
+
+    fig.add_trace(
+        go.Scatter(
+            x=fail_df["query_number"],
+            y=fail_df["cumulative_fail"],
+            mode="lines",
+            name="Cumulative Failed/Empty Queries",
+            line=dict(color="red"),
+            hoverinfo="x+y",
+        )
+    )
+
+    fig.update_layout(
+        template="plotly_dark",
+        title="Cumulative Failed/Empty Queries Over Time",
+        xaxis_title="Query Number",
+        yaxis_title="Number of Failed/Empty Queries",
+        autosize=False,
+        width=1200,
+        height=600,
+    )
+
+    # Save the plot to the output directory
+    plot_file = output_dir / "fail_rate_chart.png"
+    try:
+        fig.write_image(plot_file)
+        print(f"Saved fail rate chart to {plot_file}")
+    except Exception as e:
+        print(f"Failed to save fail rate chart as image: {e}")
+
+    # Optionally display the figure
+    fig.show()
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser(description="Run a stress test against the specified API endpoint.")
+
+    parser.add_argument(
+        "--key",
+        type=str,
+        # Specify your API key; this default is left here only for local testing.
+ default="0566dbe21ee33bba9419549716cd6f1f", + help="API key for authentication (default: 0566dbe21ee33bba9419549716cd6f1f).", + ) + + parser.add_argument( + "--url", + type=str, + default="http://0.0.0.0:8005/v1", + help="URL of the API endpoint to test (default: http://0.0.0.0:8005/v1).", + ) + + parser.add_argument( + "--out", + type=str, + default="stress_test", + help="Output directory for storing test results (default: stress_test).", + ) + + parser.add_argument("--concurrent", type=int, default=4, help="Number of concurrent workers to query (default: 4).") + + parser.add_argument("--queries", type=int, default=1000, help="Number of queries per worker (default: 1000).") + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_arguments() + + asyncio.run( + run_stress_test( + api_key=args.key, + url=args.url, + output_dir=args.out, + concurrent=args.concurrent, + queries=args.queries, + ) + ) diff --git a/scripts/stress_test_api.py b/scripts/test_api_load.py similarity index 78% rename from scripts/stress_test_api.py rename to scripts/test_api_load.py index aabf0c747..350a33d3c 100644 --- a/scripts/stress_test_api.py +++ b/scripts/test_api_load.py @@ -1,8 +1,19 @@ +""" +Example usage: +python scripts/test_api_load.py --levels 4 --out stress_test --key API_KEY + +Results will be stored in stress_test/ + +Additional requirements: +pip install plotly kaleido +""" +import argparse import asyncio import csv import random import sys import time +from pathlib import Path import nltk import openai @@ -16,8 +27,7 @@ def approximate_tokens(text: str) -> int: - """ - Approximate the number of tokens from a given string. + """Approximate the number of tokens from a given string. Approach: - Split by whitespace to get a word count. @@ -29,13 +39,13 @@ def approximate_tokens(text: str) -> int: def get_color_for_code(code: int) -> str: - """ - Return a color string for a given status code. - - 200 -> green - - 4xx -> crimson - - 5xx -> darkred - - 0 -> firebrick (means unknown error in this script) - - else -> red + """Return a color string for a given status code. + + - 200 -> green + - 4xx -> crimson + - 5xx -> darkred + - 0 -> firebrick (means unknown error in this script) + - else -> red """ if code == 200: return "green" @@ -55,14 +65,15 @@ async def make_completion( stream: bool = True, seed: str = "1759348", ) -> dict: - """ - Make a completion request to the API, measuring: + """Make a completion request to the API. + + Measures: - Time to first token - Time to full response - Approximate tokens received - Status code - Returns a dictionary with these measurements. + Returns a dictionary with measurements. 
""" start_time = time.perf_counter() time_to_first_token = None @@ -126,7 +137,7 @@ async def make_completion( "tokens_per_second": tokens_per_second, "status_code": status_code, "success": True, - "total_approx_tokens": total_approx_tokens, # Added for empty response tracking + "total_approx_tokens": total_approx_tokens, } except HTTPStatusError as e: @@ -136,23 +147,22 @@ async def make_completion( "tokens_per_second": 0, "status_code": e.response.status_code, "success": False, - "total_approx_tokens": 0, # Assuming zero tokens on failure + "total_approx_tokens": 0, } except Exception as e: - # For other errors, we record status_code=0 - print(f"Unexpected error: {e}", file=sys.stderr) # Optional: Better error logging + print(f"Unexpected error: {e}", file=sys.stderr) return { "time_to_first_token": None, "total_latency": None, "tokens_per_second": 0, "status_code": 0, "success": False, - "total_approx_tokens": 0, # Assuming zero tokens on failure + "total_approx_tokens": 0, } -async def run_stress_test(api_key: str, url: str = "http://0.0.0.0:8005/v1"): - """Run a stress test by sending concurrent requests at levels 1, 2, 4, 8, etc. +async def run_stress_test(api_key: str, output_dir: str, levels: int = 10, url: str = "http://0.0.0.0:8005/v1"): + """Run a stress test by sending exponentially increasing amount of concurrent requests till `2**levels`. Measures: - Time to first token @@ -171,8 +181,7 @@ async def run_stress_test(api_key: str, url: str = "http://0.0.0.0:8005/v1"): client._client.headers["api-key"] = api_key word = random.choice(word_list) prompt = f"Write a short story about {word}." - # concurrency_levels = [2**i for i in range(0, 11)] - concurrency_levels = [2**i for i in range(0, 8)] + concurrency_levels = [2**i for i in range(0, levels)] results = [] for concurrency in concurrency_levels: @@ -238,7 +247,11 @@ async def run_stress_test(api_key: str, url: str = "http://0.0.0.0:8005/v1"): field_names = list(csv_rows[0].keys()) - with open("stress_test.csv", "w", newline="", encoding="utf-8") as f: + output_dir = Path(output_dir) + if not output_dir.exists(): + output_dir.mkdir(exist_ok=True, parents=True) + + with open(output_dir / "stress_test.csv", "w", newline="", encoding="utf-8") as f: writer = csv.DictWriter(f, fieldnames=field_names) writer.writeheader() writer.writerows(csv_rows) @@ -259,7 +272,7 @@ async def run_stress_test(api_key: str, url: str = "http://0.0.0.0:8005/v1"): ) fig1.update_xaxes(title_text="Concurrent Queries") fig1.update_yaxes(title_text="Avg. Tokens/Second (Approx.)") - fig1.write_image("stress_test_tokens_per_second.png") + fig1.write_image(output_dir / "stress_test_tokens_per_second.png") fig1.show() # 2. First-Token Latency vs. Concurrency. @@ -272,7 +285,7 @@ async def run_stress_test(api_key: str, url: str = "http://0.0.0.0:8005/v1"): ) fig2.update_xaxes(title_text="Concurrent Queries") fig2.update_yaxes(title_text="Avg. Latency to First Token (s)") - fig2.write_image("stress_test_first_token_latency.png") + fig2.write_image(output_dir / "stress_test_first_token_latency.png") fig2.show() # 3. Full Response Latency vs. Concurrency. @@ -285,7 +298,7 @@ async def run_stress_test(api_key: str, url: str = "http://0.0.0.0:8005/v1"): ) fig3.update_xaxes(title_text="Concurrent Queries") fig3.update_yaxes(title_text="Avg. Total Latency (s)") - fig3.write_image("stress_test_full_response_latency.png") + fig3.write_image(output_dir / "stress_test_full_response_latency.png") fig3.show() # 4. Status Code Counts vs. Concurrency. 
@@ -318,7 +331,7 @@ async def run_stress_test(api_key: str, url: str = "http://0.0.0.0:8005/v1"):
     )
     fig4.update_xaxes(title_text="Concurrent Queries")
     fig4.update_yaxes(title_text="Count of Responses")
-    fig4.write_image("stress_test_status_codes.png")
+    fig4.write_image(output_dir / "stress_test_status_codes.png")
     fig4.show()
 
     # 5. Empty Responses vs. Concurrency.
@@ -332,12 +345,50 @@ async def run_stress_test(api_key: str, url: str = "http://0.0.0.0:8005/v1"):
     )
     fig5.update_xaxes(title_text="Concurrent Queries")
     fig5.update_yaxes(title_text="Count of Empty Responses")
-    fig5.write_image("stress_test_empty_responses.png")
+    fig5.write_image(output_dir / "stress_test_empty_responses.png")
     fig5.show()
 
     print("All plots saved to .png files and displayed.")
 
 
+def parse_arguments():
+    parser = argparse.ArgumentParser(description="Run a stress test against the specified API endpoint.")
+
+    parser.add_argument(
+        "--key",
+        type=str,
+        # Specify your API key; this default is left here only for local testing.
+        default="0566dbe21ee33bba9419549716cd6f1f",
+        help="API key for authentication.",
+    )
+
+    parser.add_argument(
+        "--url",
+        type=str,
+        default="http://0.0.0.0:8005/v1",
+        help="URL of the API endpoint to test (default: http://0.0.0.0:8005/v1).",
+    )
+
+    parser.add_argument(
+        "--out",
+        type=str,
+        default="stress_test",
+        help="Output directory for storing test results (default: stress_test).",
+    )
+
+    parser.add_argument("--levels", type=int, default=10, help="Number of stress test levels to execute (default: 10).")
+
+    return parser.parse_args()
+
+
 if __name__ == "__main__":
-    # Replace api_key and url with appropriate values.
-    asyncio.run(run_stress_test(api_key="0566dbe21ee33bba9419549716cd6f1f", url="http://0.0.0.0:8005/v1"))
+    args = parse_arguments()
+
+    asyncio.run(
+        run_stress_test(
+            api_key=args.key,
+            url=args.url,
+            output_dir=args.out,
+            levels=args.levels,
+        )
+    )
diff --git a/validator_api/api_management.py b/validator_api/api_management.py
index 109efe57e..6ed306af1 100644
--- a/validator_api/api_management.py
+++ b/validator_api/api_management.py
@@ -26,7 +26,6 @@ def save_api_keys(api_keys):
 # Use lifespan to initialize API keys
 _keys = load_api_keys()
 logger.info(f"Loaded API keys: {_keys}")
-save_api_keys(_keys)
 
 # Dependency to validate the admin key