Plugins intermediate steps reporting (#3238)
This PR is extracted from my OA Browsing experimental branch, and it should improve the visual and functional usability of plugins.

Currently, when a plugin is enabled, the front-end renders QueueInfo with a message saying you are at position 0 for the entire time the plugin system is executing its steps.

With this PR, the worker instead sends intermediate steps while executing each call to the LLM, so QueueInfo is discarded as soon as the first step of the plugin system is executed.

Users also get "real-time" tracking of the LLM's inner thoughts, plans, etc.
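
As an illustrative sketch (not part of this commit): the snippet below constructs the new worker-side payload, assuming the oasst-shared package is installed; the request id and plugin texts are made-up placeholders. The worker ships this over its websocket, and the inference server relays it to the browser as a "plugin_intermediate" SSE event.

from oasst_shared.schemas import inference

# Hypothetical values, for illustration only.
step = inference.PluginIntermediateResponse(
    request_id="example-work-request-id",  # placeholder for the active work request's id
    current_plugin_thought="I should call the weather tool first.",
    current_plugin_action_taken="weather_lookup",
    current_plugin_action_input="Berlin, today",
    current_plugin_action_response="",
)

# Serialized form that travels worker -> inference server -> browser.
print(step.json())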

Demo:


https://github.com/LAION-AI/Open-Assistant/assets/13547364/30252e22-d5d9-4200-9ccf-3a209d8581a7

---------

Co-authored-by: Oliver Stanley <oliver.stanley@kainos.com>
draganjovanovich and olliestanley committed May 29, 2023
1 parent 54f4408 commit 21d9ebf
Showing 9 changed files with 176 additions and 6 deletions.
11 changes: 11 additions & 0 deletions inference/server/oasst_inference_server/routes/chats.py
@@ -244,6 +244,17 @@ async def event_generator(chat_id: str, message_id: str, worker_compat_hash: str
                ).json(),
            }

        if response_packet.response_type == "plugin_intermediate":
            logger.info(f"Received plugin intermediate response {chat_id}")
            yield {
                "data": chat_schema.PluginIntermediateResponseEvent(
                    current_plugin_thought=response_packet.current_plugin_thought,
                    current_plugin_action_taken=response_packet.current_plugin_action_taken,
                    current_plugin_action_input=response_packet.current_plugin_action_input,
                    current_plugin_action_response=response_packet.current_plugin_action_response,
                ).json(),
            }

        if response_packet.response_type == "internal_error":
            yield {
                "data": chat_schema.ErrorResponseEvent(
19 changes: 19 additions & 0 deletions inference/server/oasst_inference_server/routes/workers.py
@@ -218,6 +218,12 @@ def _add_receive(ftrs: set):
                        response=worker_response,
                        work_request_map=work_request_map,
                    )
                case "plugin_intermediate":
                    worker_response = cast(inference.PluginIntermediateResponse, worker_response)
                    await handle_plugin_intermediate_response(
                        work_request_map=work_request_map,
                        response=worker_response,
                    )
                case _:
                    raise RuntimeError(f"Unknown response type: {worker_response.response_type}")
        finally:
@@ -338,6 +344,19 @@ async def handle_token_response(
    work_response_container.num_responses += 1


async def handle_plugin_intermediate_response(
    response: inference.PluginIntermediateResponse,
    work_request_map: WorkRequestContainerMap,
):
    work_response_container = get_work_request_container(work_request_map, response.request_id)
    message_queue = queueing.message_queue(
        deps.redis_client,
        message_id=work_response_container.message_id,
    )
    await message_queue.enqueue(response.json())
    work_response_container.num_responses += 1


async def handle_generated_text_response(
    response: inference.GeneratedTextResponse,
    work_request_map: WorkRequestContainerMap,
17 changes: 16 additions & 1 deletion inference/server/oasst_inference_server/schemas/chat.py
@@ -46,8 +46,23 @@ class SafePromptResponseEvent(pydantic.BaseModel):
    message: inference.MessageRead


class PluginIntermediateResponseEvent(pydantic.BaseModel):
    event_type: Literal["plugin_intermediate"] = "plugin_intermediate"
    current_plugin_thought: str
    current_plugin_action_taken: str
    current_plugin_action_input: str
    current_plugin_action_response: str
    message: inference.MessageRead | None = None


ResponseEvent = Annotated[
    Union[TokenResponseEvent, ErrorResponseEvent, MessageResponseEvent, SafePromptResponseEvent],
    Union[
        TokenResponseEvent,
        ErrorResponseEvent,
        MessageResponseEvent,
        SafePromptResponseEvent,
        PluginIntermediateResponseEvent,
    ],
    pydantic.Field(discriminator="event_type"),
]
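
As a side note (not part of the diff): a self-contained sketch of how the event_type discriminator routes a payload to the new event model, assuming pydantic v1 as used by the server. The two models below are cut-down stand-ins for the real ones in this file, and the payload values are invented.

from typing import Annotated, Literal, Union

import pydantic


# Cut-down stand-ins for the real event models in schemas/chat.py.
class TokenResponseEvent(pydantic.BaseModel):
    event_type: Literal["token"] = "token"
    text: str = ""


class PluginIntermediateResponseEvent(pydantic.BaseModel):
    event_type: Literal["plugin_intermediate"] = "plugin_intermediate"
    current_plugin_thought: str
    current_plugin_action_taken: str
    current_plugin_action_input: str
    current_plugin_action_response: str


ResponseEvent = Annotated[
    Union[TokenResponseEvent, PluginIntermediateResponseEvent],
    pydantic.Field(discriminator="event_type"),
]

# Invented payload; parse_obj_as picks the right model from event_type.
event = pydantic.parse_obj_as(
    ResponseEvent,
    {
        "event_type": "plugin_intermediate",
        "current_plugin_thought": "thinking...",
        "current_plugin_action_taken": "web_search",
        "current_plugin_action_input": "open assistant plugins",
        "current_plugin_action_response": "",
    },
)
assert isinstance(event, PluginIntermediateResponseEvent)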

51 changes: 51 additions & 0 deletions inference/worker/chat_chain.py
@@ -3,6 +3,7 @@
import interface
import transformers
import utils
import websocket
from chat_chain_prompts import (
    ASSISTANT_PREFIX,
    HUMAN_PREFIX,
@@ -108,6 +109,8 @@ def handle_plugin_usage(
    tools: list[Tool],
    plugin: inference.PluginEntry | None,
    plugin_max_depth: int,
    ws: websocket.WebSocket,
    work_request_id: str,
) -> tuple[str, inference.PluginUsed]:
    execution_details = inference.PluginExecutionDetails(
        inner_monologue=[],
@@ -142,17 +145,46 @@
        tokenizer, worker_config, parameters, prompt_template, memory, tool_names, language, action_input_format
    )

    # send "thinking..." intermediate step to UI (This will discard queue position 0) immediately
    utils.send_response(
        ws,
        inference.PluginIntermediateResponse(
            request_id=work_request_id,
            current_plugin_thought="thinking...",
            current_plugin_action_taken="",
            current_plugin_action_input="",
            current_plugin_action_response="",
        ),
    )

    init_prompt = f"{input_prompt}{eos_token}{V2_ASST_PREFIX}"
    init_prompt, chain_response = chain.call(init_prompt)

    inner_monologue.append("In: " + str(init_prompt))
    inner_monologue.append("Out: " + str(chain_response))

    current_action_thought = ""
    if THOUGHT_SEQ in chain_response:
        current_action_thought = chain_response.split(THOUGHT_SEQ)[1].split("\n")[0]

    # Tool name/assistant prefix, Tool input/assistant response
    prefix, response = extract_tool_and_input(llm_output=chain_response, ai_prefix=ASSISTANT_PREFIX)
    assisted = False if ASSISTANT_PREFIX in prefix else True
    chain_finished = not assisted

    if assisted:
        # model decided to use a tool, so send that thought to the client
        utils.send_response(
            ws,
            inference.PluginIntermediateResponse(
                request_id=work_request_id,
                current_plugin_thought=current_action_thought,
                current_plugin_action_taken=prefix,
                current_plugin_action_input=chain_response,
                current_plugin_action_response=response,
            ),
        )

    while not chain_finished and assisted and achieved_depth < plugin_max_depth:
        tool_response = use_tool(prefix, response, tools)

@@ -165,6 +197,22 @@
        inner_monologue.append("In: " + str(new_prompt))
        inner_monologue.append("Out: " + str(chain_response))

        current_action_thought = ""
        if THOUGHT_SEQ in chain_response:
            current_action_thought = chain_response.split(THOUGHT_SEQ)[1].split("\n")[0]

        # Send deep plugin intermediate steps to UI
        utils.send_response(
            ws,
            inference.PluginIntermediateResponse(
                request_id=work_request_id,
                current_plugin_thought=current_action_thought,
                current_plugin_action_taken=prefix,
                current_plugin_action_input=chain_response,
                current_plugin_action_response=response,
            ),
        )

        prefix, response = extract_tool_and_input(llm_output=chain_response, ai_prefix=ASSISTANT_PREFIX)
        assisted = False if ASSISTANT_PREFIX in prefix else True

@@ -286,6 +334,7 @@ def handle_conversation(
    worker_config: inference.WorkerConfig,
    parameters: interface.GenerateStreamParameters,
    tokenizer: transformers.PreTrainedTokenizer,
    ws: websocket.WebSocket,
) -> tuple[str, inference.PluginUsed | None]:
    try:
        original_prompt = work_request.thread.messages[-1].content
@@ -323,6 +372,8 @@
                tools,
                plugin,
                work_request.parameters.plugin_max_depth,
                ws,
                work_request.id,
            )

        return handle_standard_usage(original_prompt, prompt_template, language, memory, worker_config, tokenizer)
2 changes: 1 addition & 1 deletion inference/worker/work.py
@@ -84,7 +84,7 @@ def handle_work_request(

    for plugin in parameters.plugins:
        if plugin.enabled:
            prompt, used_plugin = chat_chain.handle_conversation(work_request, worker_config, parameters, tokenizer)
            prompt, used_plugin = chat_chain.handle_conversation(work_request, worker_config, parameters, tokenizer, ws)
            # When using plugins and final prompt is truncated due to length limit
            # LLaMA has tendency to leak internal prompts and generate bad continuations
            # So we add keywords/sequences to the stop sequences to reduce this
10 changes: 10 additions & 0 deletions oasst-shared/oasst_shared/schemas/inference.py
@@ -329,6 +329,15 @@ class SafePromptResponse(WorkerResponseBase):
    safety_rots: str


class PluginIntermediateResponse(WorkerResponseBase):
    response_type: Literal["plugin_intermediate"] = "plugin_intermediate"
    text: str = ""
    current_plugin_thought: str
    current_plugin_action_taken: str
    current_plugin_action_input: str
    current_plugin_action_response: str


class TokenResponse(WorkerResponseBase):
    response_type: Literal["token"] = "token"
    text: str
@@ -389,6 +398,7 @@ class GeneralErrorResponse(WorkerResponseBase):
        InternalFinishedMessageResponse,
        InternalErrorResponse,
        SafePromptResponse,
        PluginIntermediateResponse,
    ],
    pydantic.Field(discriminator="response_type"),
]
39 changes: 37 additions & 2 deletions website/src/components/Chat/ChatConversation.tsx
@@ -7,7 +7,7 @@ import { UseFormGetValues } from "react-hook-form";
import SimpleBar from "simplebar-react";
import { useMessageVote } from "src/hooks/chat/useMessageVote";
import { get, post } from "src/lib/api";
import { handleChatEventStream, QueueInfo } from "src/lib/chat_stream";
import { handleChatEventStream, QueueInfo, PluginIntermediateResponse } from "src/lib/chat_stream";
import { OasstError } from "src/lib/oasst_api_client";
import { API_ROUTES } from "src/lib/routes";
import {
@@ -38,6 +38,7 @@ export const ChatConversation = memo(function ChatConversation({ chatId, getConf

  const [streamedResponse, setResponse] = useState<string | null>(null);
  const [queueInfo, setQueueInfo] = useState<QueueInfo | null>(null);
  const [pluginIntermediateResponse, setPluginIntermediateResponse] = useState<PluginIntermediateResponse | null>(null);
  const [isSending, setIsSending] = useBoolean();
  const [showEncourageMessage, setShowEncourageMessage] = useBoolean(false);

@@ -98,9 +99,16 @@ export const ChatConversation = memo(function ChatConversation({ chatId, getConf
      message = await handleChatEventStream({
        stream: body!,
        onError: console.error,
        onPending: setQueueInfo,
        onPending: (data) => {
          setQueueInfo(data);
          setPluginIntermediateResponse(null);
        },
        onPluginIntermediateResponse: setPluginIntermediateResponse,
        onToken: async (text) => {
          setQueueInfo(null);
          if (text != "") {
            setPluginIntermediateResponse(null);
          }
          setResponse(text);
          await new Promise(requestAnimationFrame);
        },
@@ -305,6 +313,33 @@ export const ChatConversation = memo(function ChatConversation({ chatId, getConf
            {t("queue_info", queueInfo)}
          </Badge>
        )}
        {pluginIntermediateResponse && pluginIntermediateResponse.currentPluginThought && (
          <Box
            position="absolute"
            bottom="0"
            left="50%"
            transform="translate(-50%)"
            display="flex"
            flexDirection="row"
            gap="1"
            justifyContent="center"
            alignItems="center"
          >
            <Box
              bg="purple.700"
              color="white"
              px="2"
              py="2.5px"
              borderRadius="8px"
              maxWidth="50vw"
              fontSize="11"
              fontWeight="bold"
              isTruncated
            >
              {pluginIntermediateResponse.currentPluginThought}
            </Box>
          </Box>
        )}
      </Box>
      <ChatForm ref={inputRef} isSending={isSending} onSubmit={sendPrompterMessage}></ChatForm>
      <ChatWarning />
18 changes: 17 additions & 1 deletion website/src/lib/chat_stream.ts
@@ -5,18 +5,27 @@ export interface QueueInfo {
  queueSize: number;
}

export interface PluginIntermediateResponse {
  currentPluginThought: string;
  currentPluginAction: string;
  currentPluginActionResponse: string;
  currentPluginActionInput: string;
}

export interface ChatStreamHandlerOptions {
  stream: ReadableStream<Uint8Array>;
  onError: (err: unknown) => unknown;
  onPending: (info: QueueInfo) => unknown;
  onToken: (partialMessage: string) => unknown;
  onPluginIntermediateResponse: (pluginIntermediateResponse: PluginIntermediateResponse) => unknown;
}

export async function handleChatEventStream({
  stream,
  onError,
  onPending,
  onToken,
  onPluginIntermediateResponse,
}: ChatStreamHandlerOptions): Promise<InferenceMessage | null> {
  let tokens = "";
  for await (const { event, data } of iteratorSSE(stream)) {
@@ -39,8 +48,15 @@ export async function handleChatEventStream({
        // handle error
        await onError(chunk.error);
        return chunk.message;
      } else if (chunk.event_type === "plugin_intermediate") {
        await onPluginIntermediateResponse({
          currentPluginThought: chunk.current_plugin_thought,
          currentPluginAction: chunk.current_plugin_action_taken,
          currentPluginActionResponse: chunk.current_plugin_action_response,
          currentPluginActionInput: chunk.current_plugin_action_input,
        });
      } else {
        console.error("Unexpected event", chunk);
        console.log("Unexpected event", chunk);
      }
    } catch (e) {
      console.error(`Error parsing data: ${data}, error: ${e}`);
15 changes: 14 additions & 1 deletion website/src/types/Chat.ts
@@ -80,7 +80,20 @@ interface InferenceEventPending {
  queue_size: number;
}

export type InferenceEvent = InferenceEventMessage | InferenceEventError | InferenceEventToken | InferenceEventPending;
interface InferenceEventPluginIntermediateStep {
  event_type: "plugin_intermediate";
  current_plugin_thought: string;
  current_plugin_action_taken: string;
  current_plugin_action_response: string;
  current_plugin_action_input: string;
}

export type InferenceEvent =
  | InferenceEventMessage
  | InferenceEventError
  | InferenceEventToken
  | InferenceEventPending
  | InferenceEventPluginIntermediateStep;

export type ModelInfo = {
  name: string;
