Usage:
1. Run `python api.py --init-quota 200 --max-generated-tokens 50` to start the server.
2. Run this script to test the server.

You MUST keep the logs of this file in your submission.

In [1]:
import asyncio
import random
import time
import aiohttp
from typing import Dict

In [2]:


class AsyncLLMServiceTester:
    def __init__(self, base_url: str = "http://localhost:8000"):
        self.base_url = base_url.rstrip("/")
        self.session = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session:
            await self.session.close()

    async def generate(self, prompt: str) -> Dict:
        """generate response"""
        # TODO: generate function call
        # you need to post a request to the server
        # and parse the response in async way
        # the response is a json with the following format:
        # {
        #     "status": "success" or "error",
        #     "text": "the generated text"
        # }
        # ==== start your code here ====
        data = {
            "prompt": prompt
        }
        async with self.session as session:
            async with session.post(self.base_url + "/generate", data=data) as response:
                status = response.status
                text = await response.text()
                result = {
                    "status": status,
                    "text": text
                }
        return result
        # ==== end of your code ====

    async def process_test_case(self, test_case: Dict) -> Dict:
        """Process a single test case"""

        await asyncio.sleep(random.uniform(0, 2))  # random delay
        print(f"Submitting test case: {test_case['prompt']}")

        start_time = time.time()
        result = await self.generate(test_case["prompt"])
        end_time = time.time()

        return {
            "test_case": test_case,
            "result": result,
            "time_taken": end_time - start_time,
        }

In [3]:

async def main():
    test_cases = [
        {"prompt": "Hello, how are you?"},
        {"prompt": "What is the capital of France? And what is the capital of Canada?"},
        {"prompt": "Write a poem about spring."},
        {"prompt": "Explain quantum computing in 50 words."},
        {"prompt": "Write a recipe for chocolate cake."},
    ]

    print(f"Starting test with {len(test_cases)} test cases...")
    start_time = time.time()

    async with AsyncLLMServiceTester() as tester:
        # create submit coroutines
        submit_coroutines = [
            tester.process_test_case(test_case) for test_case in test_cases
        ]

        # run all submit coroutines
        results = await asyncio.gather(*submit_coroutines)

    total_time = time.time() - start_time

    # print results
    for i, result in enumerate(results, 1):
        print(f"\n=== Test Case {i} ===")
        print(f"Prompt: {result['test_case']['prompt']}")
        print(f"Time taken: {result['time_taken']:.2f} seconds")
        print(f"Status: {result['result']['status']}")
        print(f"Response: {result['result']['text']}")
        print("=" * 50)

    print(f"\nTest Summary:")
    print(f"Total time: {total_time:.2f} seconds")
    print(f"Average time per request: {total_time/len(test_cases):.2f} seconds")
    print(f"{'='*50}\n")


await main()

Starting test with 5 test cases...
Submitting test case: Write a recipe for chocolate cake.


ClientConnectorError: Cannot connect to host localhost:8000 ssl:default [Multiple exceptions: [Errno 61] Connect call failed ('::1', 8000, 0, 0), [Errno 61] Connect call failed ('127.0.0.1', 8000)]

Submitting test case: What is the capital of France? And what is the capital of Canada?
Submitting test case: Explain quantum computing in 50 words.
Submitting test case: Write a poem about spring.
Submitting test case: Hello, how are you?
