Allow downstream of streaming nodes (#346)
* add handling logic to downstream of a streaming node

* loosen limit

* fix name

* fix decode, add debug

* fix

* fix

* fix chunking logic

* json dump input

* fix re, en/decode

* remove debug

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add another UT for streaming cases

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* further increase coverage

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Spycsh and pre-commit-ci[bot] authored Jul 29, 2024
1 parent d39fee9 commit 90e367e
Showing 4 changed files with 142 additions and 22 deletions.
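
In effect, the change lets a pipeline whose streaming node has a downstream consumer run end to end: streamed chunks are buffered into sentences, each sentence is forwarded to the downstream node, and its reply is re-emitted as an SSE stream. A minimal sketch of such a topology, adapted from the new unit test added in this commit (service names are the test's; registration and startup are elided, see the test file below):

import asyncio

from comps import ServiceOrchestrator, opea_microservices

# s0 streams (LLM-style); s1 post-processes its output downstream.
# Both services are assumed registered and started, as in the new test file.
orchestrator = ServiceOrchestrator()
orchestrator.add(opea_microservices["s0"]).add(opea_microservices["s1"])
orchestrator.flow_to(opea_microservices["s0"], opea_microservices["s1"])


async def main():
    # schedule() now returns a StreamingResponse for s1 even though s0,
    # not s1, is the node that streams.
    result_dict, _ = await orchestrator.schedule(initial_inputs={"text": "hello, "})


asyncio.run(main())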
7 changes: 1 addition & 6 deletions comps/cores/mega/gateway.py
@@ -123,12 +123,7 @@ async def handle_request(self, request: Request):
             initial_inputs={"text": prompt}, llm_parameters=parameters
         )
         for node, response in result_dict.items():
-            # Here it suppose the last microservice in the megaservice is LLM.
-            if (
-                isinstance(response, StreamingResponse)
-                and node == list(self.megaservice.services.keys())[-1]
-                and self.megaservice.services[node].service_type == ServiceType.LLM
-            ):
+            if isinstance(response, StreamingResponse):
                 return response
         last_node = runtime_graph.all_leaves()[-1]
         response = result_dict[last_node]["text"]
49 changes: 47 additions & 2 deletions comps/cores/mega/orchestrator.py
@@ -122,17 +122,42 @@ async def execute(
             response = requests.post(
                 url=endpoint, data=json.dumps(inputs), proxies={"http": None}, stream=True, timeout=1000
             )
+            downstream = runtime_graph.downstream(cur_node)
+            if downstream:
+                assert len(downstream) == 1, "Not supported multiple streaming downstreams yet!"
+                cur_node = downstream[0]
+                hitted_ends = [".", "?", "!", "。", ",", "!"]
+                endpoint = self.services[downstream[0]].endpoint_path

             def generate():
                 if response:
+                    buffered_chunk_str = ""
                     for chunk in response.iter_content(chunk_size=None):
                         if chunk:
-                            yield chunk
+                            if downstream:
+                                chunk = chunk.decode("utf-8")
+                                buffered_chunk_str += self.extract_chunk_str(chunk)
+                                is_last = chunk.endswith("[DONE]\n\n")
+                                if (buffered_chunk_str and buffered_chunk_str[-1] in hitted_ends) or is_last:
+                                    res = requests.post(
+                                        url=endpoint,
+                                        data=json.dumps({"text": buffered_chunk_str}),
+                                        proxies={"http": None},
+                                    )
+                                    res_json = res.json()
+                                    if "text" in res_json:
+                                        res_txt = res_json["text"]
+                                    else:
+                                        raise Exception("Other response types not supported yet!")
+                                    buffered_chunk_str = ""  # clear
+                                    yield from self.token_generator(res_txt, is_last=is_last)
+                            else:
+                                yield chunk

             return StreamingResponse(generate(), media_type="text/event-stream"), cur_node
         else:
             async with session.post(endpoint, json=inputs) as response:
-                print(response.status)
+                print(f"{cur_node}: {response.status}")
                 return await response.json(), cur_node

     def dump_outputs(self, node, response, result_dict):
@@ -143,3 +168,23 @@ def get_all_final_outputs(self, result_dict, runtime_graph):
         for leaf in runtime_graph.all_leaves():
             final_output_dict[leaf] = result_dict[leaf]
         return final_output_dict
+
+    def extract_chunk_str(self, chunk_str):
+        if chunk_str == "data: [DONE]\n\n":
+            return ""
+        prefix = "data: b'"
+        suffix = "'\n\n"
+        if chunk_str.startswith(prefix):
+            chunk_str = chunk_str[len(prefix) :]
+        if chunk_str.endswith(suffix):
+            chunk_str = chunk_str[: -len(suffix)]
+        return chunk_str
+
+    def token_generator(self, sentence, is_last=False):
+        prefix = "data: "
+        suffix = "\n\n"
+        tokens = re.findall(r"\S+\s?", sentence, re.UNICODE)
+        for token in tokens:
+            yield prefix + repr(token.encode("utf-8")) + suffix
+        if is_last:
+            yield "data: [DONE]\n\n"
30 changes: 16 additions & 14 deletions tests/cores/mega/test_service_orchestrator.py
@@ -26,20 +26,22 @@ async def s2_add(request: TextDoc) -> TextDoc:


 class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase):
-    def setUp(self):
-        self.s1 = opea_microservices["s1"]
-        self.s2 = opea_microservices["s2"]
-        self.s1.start()
-        self.s2.start()
-
-        self.service_builder = ServiceOrchestrator()
-
-        self.service_builder.add(opea_microservices["s1"]).add(opea_microservices["s2"])
-        self.service_builder.flow_to(self.s1, self.s2)
-
-    def tearDown(self):
-        self.s1.stop()
-        self.s2.stop()
+    @classmethod
+    def setUpClass(cls):
+        cls.s1 = opea_microservices["s1"]
+        cls.s2 = opea_microservices["s2"]
+        cls.s1.start()
+        cls.s2.start()
+
+        cls.service_builder = ServiceOrchestrator()
+
+        cls.service_builder.add(opea_microservices["s1"]).add(opea_microservices["s2"])
+        cls.service_builder.flow_to(cls.s1, cls.s2)
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.s1.stop()
+        cls.s2.stop()

     async def test_schedule(self):
         result_dict, _ = await self.service_builder.schedule(initial_inputs={"text": "hello, "})
78 changes: 78 additions & 0 deletions tests/cores/mega/test_service_orchestrator_streaming.py
@@ -0,0 +1,78 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import unittest

from fastapi.responses import StreamingResponse

from comps import ServiceOrchestrator, ServiceType, TextDoc, opea_microservices, register_microservice


@register_microservice(name="s1", host="0.0.0.0", port=8083, endpoint="/v1/add")
async def s1_add(request: TextDoc) -> TextDoc:
    req = request.model_dump_json()
    req_dict = json.loads(req)
    text = req_dict["text"]
    text += " ~~~"
    return {"text": text}


@register_microservice(name="s0", host="0.0.0.0", port=8085, endpoint="/v1/add", service_type=ServiceType.LLM)
async def s0_add(request: TextDoc) -> TextDoc:
    req = request.model_dump_json()
    req_dict = json.loads(req)
    text = req_dict["text"]

    async def token_generator():
        for i in [" OPEA", " is", " great.", " I", " think ", " so."]:
            yield i

    text += "project!"
    return StreamingResponse(token_generator(), media_type="text/event-stream")


class TestServiceOrchestratorStreaming(unittest.IsolatedAsyncioTestCase):
    @classmethod
    def setUpClass(cls):
        cls.s0 = opea_microservices["s0"]
        cls.s1 = opea_microservices["s1"]
        cls.s0.start()
        cls.s1.start()

        cls.service_builder = ServiceOrchestrator()

        cls.service_builder.add(opea_microservices["s0"]).add(opea_microservices["s1"])
        cls.service_builder.flow_to(cls.s0, cls.s1)

    @classmethod
    def tearDownClass(cls):
        cls.s0.stop()
        cls.s1.stop()

    async def test_schedule(self):
        result_dict, _ = await self.service_builder.schedule(initial_inputs={"text": "hello, "})
        response = result_dict["s1/MicroService"]
        idx = 0
        res_expected = ["OPEA", "is", "great.", "~~~", "I", "think", "so.", "~~~"]
        async for k in response.__reduce__()[2]["body_iterator"]:
            self.assertEqual(self.service_builder.extract_chunk_str(k).strip(), res_expected[idx])
            idx += 1

    def test_extract_chunk_str(self):
        res = self.service_builder.extract_chunk_str("data: [DONE]\n\n")
        self.assertEqual(res, "")
        res = self.service_builder.extract_chunk_str("data: b'example test.'\n\n")
        self.assertEqual(res, "example test.")

    def test_token_generator(self):
        sentence = "I write an example test.</s>"
        for i in self.service_builder.token_generator(sentence=sentence, is_last=False):
            self.assertTrue(i.startswith("data: b'"))

        for i in self.service_builder.token_generator(sentence=sentence, is_last=True):
            self.assertTrue(i.startswith("data: "))


if __name__ == "__main__":
    unittest.main()
