Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add debug auth between Website & Inference (#1893)
Refs #1880, #1881. This PR does not add any new features to the UI, but it enables debug authorization when talking to the inference API.
- Loading branch information
Showing
9 changed files
with
243 additions
and
124 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import { Box, Button, Flex, Textarea, useColorModeValue } from "@chakra-ui/react"; | ||
import { useTranslation } from "next-i18next"; | ||
import { useCallback, useRef, useState } from "react"; | ||
import { InferenceMessage, InferenceResponse } from "src/types/Chat"; | ||
|
||
interface ChatConversationProps {
  /** ID of the inference chat this conversation view posts messages to. */
  chatId: string;
}
|
||
/**
 * Renders one inference chat: the message history, a streaming placeholder for
 * the assistant's in-flight reply, and an input box to send the next prompt.
 */
export const ChatConversation = ({ chatId }: ChatConversationProps) => {
  const { t } = useTranslation("common");
  const inputRef = useRef<HTMLTextAreaElement>(null);
  const [messages, setMessages] = useState<InferenceMessage[]>([]);
  // Partial assistant response while tokens stream in; null when idle.
  const [streamedResponse, setResponse] = useState<string | null>(null);

  const isLoading = Boolean(streamedResponse);

  const send = useCallback(async () => {
    const content = inputRef.current?.value.trim();

    if (!content || !chatId) {
      return;
    }

    // Placeholder so the pending assistant entry shows immediately and the
    // submit button is disabled while the request is in flight.
    setResponse("...");

    const parent_id = messages[messages.length - 1]?.id ?? null;
    try {
      // we have to do this manually since we want to stream the chunks
      // there is also EventSource, but it only works with get requests.
      const { body, ok, status } = await fetch("/api/chat/message", {
        method: "POST",
        headers: { "content-type": "application/json" },
        body: JSON.stringify({ chat_id: chatId, content, parent_id }),
      });

      if (!ok || !body) {
        throw new Error(`sending the message failed with status ${status}`);
      }

      // first chunk is message information
      const stream = iteratorSSE(body);
      const { value } = await stream.next();
      if (!value) {
        throw new Error("did not receive message information from the server");
      }
      const response: InferenceResponse = JSON.parse(value.data);

      setMessages((messages) => [...messages, response.prompter_message]);

      // remaining messages are the tokens
      let responseMessage = "";
      for await (const { data } of stream) {
        const text = JSON.parse(data).token.text;
        responseMessage += text;
        setResponse(responseMessage);
        // wait for re-render
        await new Promise(requestAnimationFrame);
      }

      setMessages((old) => [...old, { ...response.assistant_message, content: responseMessage }]);
    } finally {
      // Always clear the placeholder — even on failure — so the input and
      // submit button become usable again instead of staying stuck on "...".
      setResponse(null);
    }
  }, [chatId, messages]);

  return (
    <Flex flexDir="column" gap={4} overflowY="auto">
      {messages.map((message) => (
        <Entry key={message.id} isAssistant={message.role === "assistant"}>
          {message.content}
        </Entry>
      ))}
      {streamedResponse ? <Entry isAssistant>{streamedResponse}</Entry> : <Textarea ref={inputRef} autoFocus />}
      <Button onClick={send} isDisabled={isLoading}>
        {t("submit")}
      </Button>
    </Flex>
  );
};
|
||
const Entry = ({ children, isAssistant }) => { | ||
const bgUser = useColorModeValue("gray.100", "gray.700"); | ||
const bgAssistant = useColorModeValue("#DFE8F1", "#42536B"); | ||
return ( | ||
<Box bg={isAssistant ? bgAssistant : bgUser} borderRadius="lg" p="4" whiteSpace="pre-line"> | ||
{children} | ||
</Box> | ||
); | ||
}; | ||
|
||
async function* iteratorSSE(stream: ReadableStream<Uint8Array>) { | ||
const reader = stream.pipeThrough(new TextDecoderStream()).getReader(); | ||
|
||
let done = false, | ||
value = ""; | ||
while (!done) { | ||
({ value, done } = await reader.read()); | ||
if (done) { | ||
break; | ||
} | ||
|
||
const fields = value | ||
.split(/\r?\n/) | ||
.filter(Boolean) | ||
.map((line) => { | ||
const colonIdx = line.indexOf(":"); | ||
return [line.slice(0, colonIdx), line.slice(colonIdx + 1).trimStart()]; | ||
}); | ||
yield Object.fromEntries(fields); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
const flags = [ | ||
{ name: "flagTest", isActive: false }, | ||
{ name: "chatEnabled", isActive: false }, | ||
{ name: "chat", isActive: process.env.NODE_ENV === "development" }, | ||
]; | ||
|
||
export default flags; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import axios, { AxiosRequestConfig } from "axios"; | ||
import Cookies from "cookies"; | ||
import type { NextApiRequest, NextApiResponse } from "next"; | ||
import { JWT } from "next-auth/jwt"; | ||
import { InferenceCreateChatResponse, InferenceDebugTokenResponse } from "src/types/Chat"; | ||
|
||
// TODO: this class could be structured better | ||
export class OasstInferenceClient { | ||
private readonly cookies: Cookies; | ||
private inferenceToken: string; | ||
private readonly userTokenSub: string; | ||
|
||
constructor(req: NextApiRequest, res: NextApiResponse, token: JWT) { | ||
this.cookies = new Cookies(req, res); | ||
this.inferenceToken = this.cookies.get("inference_token"); | ||
this.userTokenSub = token.sub; | ||
} | ||
|
||
async request(method: "GET" | "POST" | "PUT" | "DELETE", path: string, init?: AxiosRequestConfig) { | ||
const token = await this.get_token(); | ||
const { data } = await axios(process.env.INFERENCE_SERVER_HOST + path, { | ||
method, | ||
...init, | ||
headers: { | ||
...init?.headers, | ||
Authorization: `Bearer ${token}`, | ||
"Content-Type": "application/json", | ||
}, | ||
}); | ||
return data; | ||
} | ||
|
||
async get_token() { | ||
// TODO: handle the case where the token is outdated and requires a refresh. | ||
if (this.inferenceToken) { | ||
return this.inferenceToken; | ||
} | ||
console.log("fetching new token"); | ||
// we might want to include the inference token in our JWT, but this won't be trivial. | ||
// or we might have to force log-in the user every time a new JWT is created | ||
|
||
// TODO: we have not decided on a format for the user yet, this is here for debug only | ||
const res = await fetch(process.env.INFERENCE_SERVER_HOST + `/auth/login/debug?username=${this.userTokenSub}`); | ||
const inferenceResponse: InferenceDebugTokenResponse = await res.json(); | ||
this.inferenceToken = inferenceResponse.access_token; | ||
this.cookies.set("inference_token", this.inferenceToken); | ||
// console.dir(this.inferenceToken); | ||
return this.inferenceToken; | ||
} | ||
|
||
create_chat(): Promise<InferenceCreateChatResponse> { | ||
return this.request("POST", "/chat", { data: "" }); | ||
} | ||
|
||
post_prompt({ chat_id, parent_id, content }: { chat_id: string; parent_id: string | null; content: string }) { | ||
return this.request("POST", `/chat/${chat_id}/message`, { | ||
data: { parent_id, content }, | ||
responseType: "stream", | ||
}); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,10 @@ | ||
import { post } from "src/lib/api"; | ||
import { withoutRole } from "src/lib/auth"; | ||
|
||
export const INFERENCE_HOST = process.env.INFERENCE_SERVER_HOST; | ||
import { OasstInferenceClient } from "src/lib/oasst_inference_client"; | ||
|
||
const handler = withoutRole("banned", async (req, res, token) => { | ||
const chat = await post(INFERENCE_HOST + "/chat", { arg: {} }); | ||
return res.status(200).json(chat); | ||
const client = new OasstInferenceClient(req, res, token); | ||
const data = await client.create_chat(); | ||
return res.status(200).json(data); | ||
}); | ||
|
||
export default handler; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,12 @@ | ||
import axios from "axios"; | ||
import { IncomingMessage } from "http"; | ||
import { withoutRole } from "src/lib/auth"; | ||
|
||
import { INFERENCE_HOST } from "."; | ||
import { OasstInferenceClient } from "src/lib/oasst_inference_client"; | ||
|
||
const handler = withoutRole("banned", async (req, res, token) => { | ||
const { chat_id, parent_id, content } = req.body; | ||
|
||
const { data } = await axios.post<IncomingMessage>( | ||
INFERENCE_HOST + `/chat/${chat_id}/message`, | ||
{ parent_id, content }, | ||
{ responseType: "stream" } | ||
); | ||
const client = new OasstInferenceClient(req, res, token); | ||
const responseStream = await client.post_prompt({ chat_id, parent_id, content }); | ||
res.status(200); | ||
data.pipe(res); | ||
responseStream.pipe(res); | ||
}); | ||
|
||
export default handler; |
Oops, something went wrong.