In [None]:
import asyncio
import aiohttp
from aiohttp import ClientSession, FormData
import time
import json
from pathlib import Path

# Inference service under test; every request is POSTed here as form data.
ENDPOINT = "http://localhost:80/infer"

# One entry per pipeline to exercise. Keys read by send_request():
#   name       - label printed in the per-request report
#   model_name - model identifier sent as the "model_name" form field
#   task       - pipeline task sent as the "task" form field
#   use_file   - if True, the file at ``file_path`` is uploaded (and must exist)
#   inputs     - text payload, or a dict that is JSON-encoded before sending
#   parameters - optional dict, JSON-encoded before sending
TEST_CASES = [
    dict(
        name="sentiment-analysis",
        model_name="distilbert-base-uncased-finetuned-sst-2-english",
        task="text-classification",
        inputs="I love this product!",
        use_file=False,
    ),
    dict(
        name="question-answering",
        model_name="distilbert-base-cased-distilled-squad",
        task="question-answering",
        inputs=dict(
            question="Where is Hugging Face based?",
            context="Hugging Face is a company based in New York.",
        ),
        use_file=False,
    ),
    dict(
        name="image-classification",
        model_name="facebook/deit-tiny-patch16-224",
        task="image-classification",
        file_path="test_image.jpg",
        use_file=True,
    ),
    dict(
        name="speech-recognition",
        model_name="openai/whisper-small",
        task="automatic-speech-recognition",
        file_path="test_audio.wav",
        use_file=True,
    ),
    dict(
        name="text-generation",
        model_name="gpt2",
        task="text-generation",
        inputs="The future of AI is",
        parameters=dict(max_length=50, temperature=0.7),
        use_file=False,
    ),
]

# How many times each runnable test case is submitted in a single run.
NUM_REPEATS = 1

def _build_form(test_case: dict, file_obj=None, filename=None) -> FormData:
    """Build the multipart form for one test case.

    ``model_name`` is always sent; ``task``, ``inputs`` and ``parameters`` are
    added only when the test case defines them. When ``file_obj`` is given it
    is attached as the ``file`` field under ``filename``.
    """
    data = FormData()
    data.add_field("model_name", test_case["model_name"])

    if "task" in test_case:
        data.add_field("task", test_case["task"])

    if "inputs" in test_case:
        inputs = test_case["inputs"]
        # Text/bytes go verbatim; structured payloads (dict, list, ...) must be
        # serialized because a form field carries a flat string.
        if not isinstance(inputs, (str, bytes)):
            inputs = json.dumps(inputs)
        data.add_field("inputs", inputs)

    if "parameters" in test_case:
        data.add_field("parameters", json.dumps(test_case["parameters"]))

    if file_obj is not None:
        data.add_field("file", file_obj, filename=filename)

    return data


async def _post_and_report(session: ClientSession, data: FormData, test_case: dict, idx: int) -> None:
    """POST ``data`` to ENDPOINT and print the status and decoded response body."""
    async with session.post(ENDPOINT, data=data) as response:
        body = await response.text()
        # Error pages are frequently plain text/HTML; response.json() would raise
        # on them and the status code would never be printed, so decode leniently.
        if not body.strip():
            resp = None
        else:
            try:
                resp = json.loads(body)
            except json.JSONDecodeError:
                resp = body
        print(f"\nRequest {idx} | {test_case['name']}")
        print(f"Model: {test_case['model_name']}")
        print(f"Status: {response.status}")
        print("Response:", resp)


async def send_request(session: ClientSession, test_case: dict, idx: int):
    """Send one inference request for ``test_case`` and print the outcome.

    Args:
        session: shared aiohttp session used for the POST.
        test_case: one entry shaped like those in TEST_CASES.
        idx: 1-based request number shown in the printed report.

    File-backed cases whose file is missing are skipped with a message. Any
    ``Exception`` is caught and printed so that one failing request does not
    abort the surrounding ``asyncio.gather``.
    """
    try:
        if test_case["use_file"]:
            file_path = Path(test_case["file_path"])
            # Re-checked here: the file may have vanished since main() validated it.
            if not file_path.exists():
                print(f"File not found: {file_path}. Skipping...")
                return
            # `with` closes the handle even if building the form or the POST raises.
            with open(file_path, "rb") as f:
                data = _build_form(test_case, f, file_path.name)
                await _post_and_report(session, data, test_case, idx)
        else:
            await _post_and_report(session, _build_form(test_case), test_case, idx)
    except Exception as e:
        # Name the case and the exception type: some exceptions (e.g.
        # asyncio.TimeoutError) stringify to "", which would leave the line blank.
        print(f"\nRequest {idx} | {test_case.get('name', '?')} | Error: {type(e).__name__}: {e}")

async def main():
    """Fire every runnable test case NUM_REPEATS times concurrently and report throughput.

    File-backed cases whose file is missing are skipped up front and excluded
    from the request total. Note that ``Requests/sec`` is computed from the
    number of requests *attempted*, including ones that fail or are skipped
    inside send_request().
    """
    valid_cases = []
    for case in TEST_CASES:
        if case["use_file"]:
            file_path = Path(case["file_path"])
            if not file_path.exists():
                print(f"Skipping {case['name']} - file not found: {file_path}")
                continue
        valid_cases.append(case)

    if not valid_cases:
        print("No valid test cases available!")
        return

    total_requests = len(valid_cases) * NUM_REPEATS
    print(f"\nStarting test with {total_requests} requests...")
    print(f"Endpoint: {ENDPOINT}")
    print("=" * 60)

    # perf_counter is monotonic; time.time() can jump when the wall clock is adjusted.
    start_time = time.perf_counter()

    async with ClientSession() as session:
        # Repeat order matches a nested "for repeat: for case:" loop.
        schedule = [case for _ in range(NUM_REPEATS) for case in valid_cases]
        await asyncio.gather(
            *(send_request(session, case, idx) for idx, case in enumerate(schedule, start=1))
        )

    duration = time.perf_counter() - start_time
    # Guard the division: a zero-length interval is possible if all requests fail instantly.
    rate = total_requests / duration if duration > 0 else float("inf")

    print("\n" + "=" * 60)
    print("Test completed")
    print(f"Total requests: {total_requests}")
    print(f"Total time: {duration:.2f} seconds")
    print(f"Requests/sec: {rate:.2f}")
    print("=" * 60)


if __name__ == "__main__":
    # Top-level `await` only compiles under IPython/Jupyter autoawait; in a plain
    # `python script.py` run it is a SyntaxError. Dispatch on whether an event
    # loop is already running so this works in both environments.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Plain interpreter: no loop is running, so asyncio.run can own one.
        asyncio.run(main())
    else:
        # Notebook kernel: its loop is already running and asyncio.run refuses
        # to nest unless nest_asyncio patches the loop first.
        import nest_asyncio

        nest_asyncio.apply()
        asyncio.run(main())




Starting test with 5 requests...
Endpoint: http://localhost:80/infer
