In [None]:
import time
import uuid
import requests
import json
!pip install -q websocket-client
import websocket

In [None]:
# Host and port of the server under test. Each request() below prepends the
# scheme ("http://" or "ws://") when it builds its endpoint URL.
URL = "localhost:7860"

In [None]:
from concurrent.futures import ThreadPoolExecutor

def run_in_parallel(func, n):
    """Invoke ``func()`` ``n`` times concurrently and gather the return values.

    A thread pool sized to ``n`` is created, so every call can be in flight
    at the same time.

    Args:
        func: Zero-argument callable to execute.
        n: How many times to call ``func``; must be an ``int`` >= 1.

    Returns:
        list: The ``n`` return values, in submission order.

    Raises:
        ValueError: If ``func`` is not callable or ``n`` is not a positive int.
        Exception: Whatever ``func`` raises is re-raised when its result is read.
    """
    arguments_ok = callable(func) and isinstance(n, int) and n >= 1
    if not arguments_ok:
        raise ValueError("Invalid function or number of repetitions")

    with ThreadPoolExecutor(max_workers=n) as pool:
        # Submit everything up front so the calls overlap, then block on each
        # future in turn; leaving the ``with`` block joins the worker threads.
        pending = [pool.submit(func) for _ in range(n)]
        return [job.result() for job in pending]


## Gradio 4

In [None]:
def request():
    """Run one Gradio 4 queue round trip over SSE and time it.

    Posts a job to ``queue/join`` under a fresh session hash, then reads the
    ``queue/data`` event stream for that session until the stream ends.

    Returns:
        tuple: ``(duration, message_count, data)`` where ``duration`` is the
        elapsed time in seconds, ``message_count`` is the number of ``data:``
        lines received, and ``data`` is ``["output"]["data"]`` of the last
        message after JSON decoding.

    Raises:
        AssertionError: If ``queue/join`` does not answer with HTTP 200.
        requests.HTTPError: If the event-stream request returns an error status.
        json.JSONDecodeError: If no message arrived or the last one is not JSON.
        KeyError: If the last message has no ``output``/``data`` entry.
    """
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()
    session_hash = uuid.uuid4().hex
    payload = {"data": ["test"], "fn_index": 0, "session_hash": session_hash}
    url = f"http://{URL}/"
    resp = requests.post(f"{url}queue/join", json=payload, timeout=5)
    assert resp.status_code == 200

    message_count = 0
    output = ""
    with requests.get(f"{url}queue/data?session_hash={session_hash}", stream=True) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if not line:
                continue
            decoded_line = line.decode("utf-8")
            if decoded_line.startswith("data:"):
                # Strip only the leading field name plus one optional space.
                # A blanket str.replace would also delete any "data: " text
                # inside the payload and would keep the prefix when the
                # server omits the space.
                output = decoded_line[len("data:"):]
                if output.startswith(" "):
                    output = output[1:]
                message_count += 1

    duration = time.perf_counter() - start_time
    return (duration, message_count, json.loads(output)["output"]["data"])

In [None]:
# Smoke test: run the currently defined request() once and display its
# (duration, message_count, output) tuple before benchmarking in parallel.
request()

In [None]:
# 5 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 5)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 25 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 25)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 100 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 100)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 250 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 250)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


## Gradio 3.x

In [None]:
def request():
    """Run one Gradio 3.x queue round trip over a WebSocket and time it.

    Connects to ``queue/join`` and answers the server's prompts: the session
    hash is sent on ``send_hash``, the input payload on ``send_data``, and the
    loop stops at ``process_completed``.

    Returns:
        tuple: ``(duration, message_count, output)`` where ``duration`` is the
        elapsed time in seconds, ``message_count`` is the number of frames
        received, and ``output`` is ``["output"]["data"]`` of the completion
        message.

    Raises:
        Exception: Connection, receive, or JSON errors propagate unchanged.
    """
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()
    url = f"ws://{URL}/"
    session_hash = uuid.uuid4().hex
    output = None
    message_count = 0

    # Connect outside the try block. If the connection fails there is nothing
    # to close, and a ``finally: ws.close()`` would raise UnboundLocalError
    # and hide the real connection error.
    ws = websocket.create_connection(f"{url}queue/join")
    try:
        while True:
            message = ws.recv()  # Blocks until the next frame arrives
            message_count += 1
            message = json.loads(message)
            msg = message["msg"]

            if msg == "send_hash":
                ws.send(json.dumps({"session_hash": session_hash, "fn_index": 0}))
            elif msg == "send_data":
                # Send the input only when the server asks for it. The
                # original repeated the "send_hash" test here, so both
                # payloads were sent on "send_hash".
                ws.send(json.dumps({
                    "data": ["test"],
                    "event_data": None,
                    "fn_index": 0,
                    "session_hash": session_hash,
                }))
            elif msg == "process_completed":
                output = message["output"]["data"]
                break
    finally:
        ws.close()  # Always release the socket, even on error

    duration = time.perf_counter() - start_time
    return duration, message_count, output


In [None]:
# Smoke test: run the currently defined request() once and display its
# (duration, message_count, output) tuple before benchmarking in parallel.
request()

In [None]:
# 5 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 5)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 25 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 25)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)

In [None]:
# 100 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 100)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 250 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 250)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


### Simple SSE

In [None]:
def request(max_messages=500):
    """Read messages from the plain ``sse`` endpoint and time the transfer.

    Args:
        max_messages: Stop once this many ``data:`` lines have been received.
            Defaults to 500, the value previously hard-coded.

    Returns:
        tuple: ``(duration, message_count, output)`` where ``duration`` is the
        elapsed time in seconds, ``message_count`` is the number of ``data:``
        lines received, and ``output`` is the text of the last one (``""`` if
        none arrived).

    Raises:
        requests.HTTPError: If the endpoint returns an error status.
    """
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()
    url = f"http://{URL}/"
    message_count = 0
    output = ""
    with requests.get(f"{url}sse", stream=True) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if line:
                decoded_line = line.decode("utf-8")
                if decoded_line.startswith("data:"):
                    # Strip only the leading field name plus one optional
                    # space. A blanket str.replace would also delete any
                    # "data: " text inside the payload.
                    output = decoded_line[len("data:"):]
                    if output.startswith(" "):
                        output = output[1:]
                    message_count += 1
            # The stream may be unbounded, so stop at the message budget.
            if message_count >= max_messages:
                break

    duration = time.perf_counter() - start_time
    return (duration, message_count, output)

In [None]:
# Smoke test: run the currently defined request() once and display its
# (duration, message_count, output) tuple before benchmarking in parallel.
request()

In [None]:
# 5 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 5)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 25 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 25)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 100 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 100)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 250 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 250)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


### Simple Websocket

In [None]:
def request(max_messages=500):
    """Receive messages from the plain ``ws`` endpoint and time the transfer.

    Args:
        max_messages: Number of frames to receive before closing. Defaults to
            500, the value previously hard-coded.

    Returns:
        tuple: ``(duration, message_count, output)`` where ``duration`` is the
        elapsed time in seconds, ``message_count`` is the number of frames
        received, and ``output`` is the last frame (``None`` if none were
        received).

    Raises:
        Exception: Connection and receive errors propagate unchanged.
    """
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()
    url = f"ws://{URL}/"
    output = None
    message_count = 0

    # Connect outside the try block. If the connection fails there is nothing
    # to close, and a ``finally: ws.close()`` would raise UnboundLocalError
    # and hide the real connection error.
    ws = websocket.create_connection(f"{url}ws")
    try:
        while message_count < max_messages:
            output = ws.recv()  # Blocks until the next frame arrives
            message_count += 1
    finally:
        ws.close()  # Always release the socket, even on error

    duration = time.perf_counter() - start_time
    return duration, message_count, output


In [None]:
# Smoke test: run the currently defined request() once and display its
# (duration, message_count, output) tuple before benchmarking in parallel.
request()

In [None]:
# 5 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 5)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 25 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 25)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 100 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 100)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 250 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 250)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


### SSE w/ Workers

In [None]:
def request():
    """Submit a job to ``sse/send`` and time reading its ``sse/listen`` stream.

    The POST response supplies a ``session_id``; the event stream for that
    session is then read until it ends.

    Returns:
        tuple: ``(duration, message_count, output)`` where ``duration`` is the
        elapsed time in seconds, ``message_count`` is the number of ``data:``
        lines received, and ``output`` is the text of the last one (``""`` if
        none arrived).

    Raises:
        AssertionError: If ``sse/send`` does not answer with HTTP 200.
        KeyError: If the ``sse/send`` response has no ``session_id``.
        requests.HTTPError: If the event-stream request returns an error status.
    """
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()
    payload = {"data": "test"}
    url = f"http://{URL}/"
    resp = requests.post(f"{url}sse/send", json=payload, timeout=5)
    assert resp.status_code == 200
    session_id = resp.json()["session_id"]

    message_count = 0
    output = ""
    with requests.get(f"{url}sse/listen?session_id={session_id}", stream=True) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if not line:
                continue
            decoded_line = line.decode("utf-8")
            if decoded_line.startswith("data:"):
                # Strip only the leading field name plus one optional space.
                # A blanket str.replace would also delete any "data: " text
                # inside the payload.
                output = decoded_line[len("data:"):]
                if output.startswith(" "):
                    output = output[1:]
                message_count += 1

    duration = time.perf_counter() - start_time
    return (duration, message_count, output)

In [None]:
# Smoke test: run the currently defined request() once and display its
# (duration, message_count, output) tuple before benchmarking in parallel.
request()

In [None]:
# 5 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 5)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 25 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 25)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 100 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 100)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 250 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 250)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


### Websockets w/ Workers

In [None]:
def request(max_messages=500):
    """Send ``"test"`` to the ``ws`` endpoint and time receiving the replies.

    Args:
        max_messages: Number of frames to receive before closing. Defaults to
            500, the value previously hard-coded.

    Returns:
        tuple: ``(duration, message_count, output)`` where ``duration`` is the
        elapsed time in seconds, ``message_count`` is the number of frames
        received, and ``output`` is the last frame (``None`` if none were
        received).

    Raises:
        Exception: Connection, send, and receive errors propagate unchanged.
    """
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()
    url = f"ws://{URL}/"
    output = None
    message_count = 0

    # Connect outside the try block. If the connection fails there is nothing
    # to close, and a ``finally: ws.close()`` would raise UnboundLocalError
    # and hide the real connection error.
    ws = websocket.create_connection(f"{url}ws")
    try:
        ws.send("test")
        while message_count < max_messages:
            output = ws.recv()  # Blocks until the next frame arrives
            message_count += 1
    finally:
        ws.close()  # Always release the socket, even on error

    duration = time.perf_counter() - start_time
    return duration, message_count, output


In [None]:
# Smoke test: run the currently defined request() once and display its
# (duration, message_count, output) tuple before benchmarking in parallel.
request()

In [None]:
# 5 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 5)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 25 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 25)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 100 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 100)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)


In [None]:
# 250 concurrent request() calls; each result tuple starts with
# (duration, message_count, ...). Print the mean of both.
output = run_in_parallel(request, 250)
avg_duration = sum(duration for duration, *_rest in output) / len(output)
avg_msg = sum(count for _duration, count, *_rest in output) / len(output)
print(avg_duration, avg_msg)
