Add debug auth between Website & Inference (#1893)
Refs #1880 #1881 

This PR does not add any new UI features, but it enables debug
authorization when talking to the inference API.
AbdBarho committed Feb 26, 2023
1 parent 54070f7 commit abdeb25
Showing 9 changed files with 243 additions and 124 deletions.
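
At a high level, the website's API routes now attach a bearer token to every call to the inference server, fetching one from the debug login endpoint when none is cached. Below is a minimal sketch of that flow, simplified from website/src/lib/oasst_inference_client.ts further down; the real implementation wraps this logic in a class.

import Cookies from "cookies";

// Simplified outline of the token handling added in this PR.
async function getInferenceToken(cookies: Cookies, username: string): Promise<string> {
  // Reuse the token cached in the "inference_token" cookie, if present.
  const cached = cookies.get("inference_token");
  if (cached) {
    return cached;
  }
  // Otherwise ask the inference server's debug login endpoint for a new one
  // and cache it for subsequent requests.
  const res = await fetch(
    process.env.INFERENCE_SERVER_HOST + `/auth/login/debug?username=${encodeURIComponent(username)}`
  );
  const { access_token } = await res.json();
  cookies.set("inference_token", access_token);
  return access_token;
}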
33 changes: 32 additions & 1 deletion website/package-lock.json

(Generated lockfile; diff not rendered.)

1 change: 1 addition & 0 deletions website/package.json
@@ -47,6 +47,7 @@
     "boolean": "^3.2.0",
     "chart.js": "^4.2.1",
     "clsx": "^1.2.1",
+    "cookies": "^0.8.0",
     "date-fns": "^2.29.3",
     "eslint": "8.29.0",
     "eslint-config-next": "13.0.6",
102 changes: 102 additions & 0 deletions website/src/components/Chat/ChatConversation.tsx
@@ -0,0 +1,102 @@
import { Box, Button, Flex, Textarea, useColorModeValue } from "@chakra-ui/react";
import { useTranslation } from "next-i18next";
import { ReactNode, useCallback, useRef, useState } from "react";
import { InferenceMessage, InferenceResponse } from "src/types/Chat";

interface ChatConversationProps {
chatId: string;
}

export const ChatConversation = ({ chatId }: ChatConversationProps) => {
const { t } = useTranslation("common");
  const inputRef = useRef<HTMLTextAreaElement>(null);
  const [messages, setMessages] = useState<InferenceMessage[]>([]);
  const [streamedResponse, setResponse] = useState<string | null>(null);

const isLoading = Boolean(streamedResponse);

const send = useCallback(async () => {
    const content = inputRef.current?.value.trim();

if (!content || !chatId) {
return;
}

setResponse("...");

const parent_id = messages[messages.length - 1]?.id ?? null;
    // We have to do this manually since we want to stream the chunks.
    // There is also EventSource, but it only works with GET requests.
const { body } = await fetch("/api/chat/message", {
method: "POST",
headers: { "content-type": "application/json" },
body: JSON.stringify({ chat_id: chatId, content, parent_id }),
});

// first chunk is message information
const stream = iteratorSSE(body);
const { value } = await stream.next();
const response: InferenceResponse = JSON.parse(value.data);

setMessages((messages) => [...messages, response.prompter_message]);

// remaining messages are the tokens
let responseMessage = "";
for await (const { data } of stream) {
const text = JSON.parse(data).token.text;
responseMessage += text;
setResponse(responseMessage);
// wait for re-render
await new Promise(requestAnimationFrame);
}

setMessages((old) => [...old, { ...response.assistant_message, content: responseMessage }]);
setResponse(null);
}, [chatId, messages]);

return (
<Flex flexDir="column" gap={4} overflowY="auto">
{messages.map((message) => (
<Entry key={message.id} isAssistant={message.role === "assistant"}>
{message.content}
</Entry>
))}
{streamedResponse ? <Entry isAssistant>{streamedResponse}</Entry> : <Textarea ref={inputRef} autoFocus />}
<Button onClick={send} isDisabled={isLoading}>
{t("submit")}
</Button>
</Flex>
);
};

const Entry = ({ children, isAssistant }: { children: ReactNode; isAssistant?: boolean }) => {
const bgUser = useColorModeValue("gray.100", "gray.700");
const bgAssistant = useColorModeValue("#DFE8F1", "#42536B");
return (
<Box bg={isAssistant ? bgAssistant : bgUser} borderRadius="lg" p="4" whiteSpace="pre-line">
{children}
</Box>
);
};

async function* iteratorSSE(stream: ReadableStream<Uint8Array>) {
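  // NOTE: this assumes every chunk read from the stream contains only whole
  // lines; an event split across two chunks would be parsed incorrectly.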
const reader = stream.pipeThrough(new TextDecoderStream()).getReader();

let done = false,
value = "";
while (!done) {
({ value, done } = await reader.read());
if (done) {
break;
}

const fields = value
.split(/\r?\n/)
.filter(Boolean)
.map((line) => {
const colonIdx = line.indexOf(":");
return [line.slice(0, colonIdx), line.slice(colonIdx + 1).trimStart()];
});
yield Object.fromEntries(fields);
}
}
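
For reference, here is what iteratorSSE yields for a single token event. The wire format (one "data: <json>" line per event) is inferred from how the component consumes the stream; the exact payload shape from the inference API is an assumption here.

// Parse one raw SSE chunk the way iteratorSSE does (assumed wire format).
const chunk = 'data: {"token": {"text": "Hello"}}';
const colonIdx = chunk.indexOf(":");
const event = Object.fromEntries([[chunk.slice(0, colonIdx), chunk.slice(colonIdx + 1).trimStart()]]);
console.log(event); // { data: '{"token": {"text": "Hello"}}' }
console.log(JSON.parse(event.data).token.text); // "Hello"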
2 changes: 1 addition & 1 deletion website/src/flags.ts
@@ -1,6 +1,6 @@
 const flags = [
   { name: "flagTest", isActive: false },
-  { name: "chatEnabled", isActive: false },
+  { name: "chat", isActive: process.env.NODE_ENV === "development" },
 ];
 
 export default flags;
61 changes: 61 additions & 0 deletions website/src/lib/oasst_inference_client.ts
@@ -0,0 +1,61 @@
import axios, { AxiosRequestConfig } from "axios";
import Cookies from "cookies";
import type { NextApiRequest, NextApiResponse } from "next";
import { JWT } from "next-auth/jwt";
import { InferenceCreateChatResponse, InferenceDebugTokenResponse } from "src/types/Chat";

// TODO: this class could be structured better
export class OasstInferenceClient {
private readonly cookies: Cookies;
private inferenceToken: string;
private readonly userTokenSub: string;

constructor(req: NextApiRequest, res: NextApiResponse, token: JWT) {
this.cookies = new Cookies(req, res);
this.inferenceToken = this.cookies.get("inference_token");
this.userTokenSub = token.sub;
}

async request(method: "GET" | "POST" | "PUT" | "DELETE", path: string, init?: AxiosRequestConfig) {
const token = await this.get_token();
const { data } = await axios(process.env.INFERENCE_SERVER_HOST + path, {
method,
...init,
headers: {
...init?.headers,
Authorization: `Bearer ${token}`,
"Content-Type": "application/json",
},
});
return data;
}

async get_token() {
// TODO: handle the case where the token is outdated and requires a refresh.
if (this.inferenceToken) {
return this.inferenceToken;
}
console.log("fetching new token");
    // We might want to include the inference token in our JWT, but that won't be trivial;
    // alternatively, we might have to force the user to log in every time a new JWT is created.

// TODO: we have not decided on a format for the user yet, this is here for debug only
    const res = await fetch(process.env.INFERENCE_SERVER_HOST + `/auth/login/debug?username=${encodeURIComponent(this.userTokenSub)}`);
const inferenceResponse: InferenceDebugTokenResponse = await res.json();
this.inferenceToken = inferenceResponse.access_token;
this.cookies.set("inference_token", this.inferenceToken);
return this.inferenceToken;
}

create_chat(): Promise<InferenceCreateChatResponse> {
return this.request("POST", "/chat", { data: "" });
}

post_prompt({ chat_id, parent_id, content }: { chat_id: string; parent_id: string | null; content: string }) {
return this.request("POST", `/chat/${chat_id}/message`, {
data: { parent_id, content },
responseType: "stream",
});
}
}
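
The refresh TODO in get_token is left open by this PR. One way to handle it later (a hypothetical sketch, not part of this commit) is to decode the cached JWT's exp claim and discard stale tokens before reuse:

// Hypothetical helper, not part of this PR: treat a cached JWT as stale
// once its "exp" claim (seconds since the epoch) has passed.
function isExpired(jwt: string): boolean {
  try {
    const payload = JSON.parse(Buffer.from(jwt.split(".")[1], "base64url").toString());
    return typeof payload.exp !== "number" || payload.exp * 1000 <= Date.now();
  } catch {
    return true; // unparsable token: better to fetch a fresh one
  }
}

get_token could then fall through to the debug login fetch whenever the cookie value is missing or isExpired returns true.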
9 changes: 4 additions & 5 deletions website/src/pages/api/chat/index.ts
@@ -1,11 +1,10 @@
-import { post } from "src/lib/api";
 import { withoutRole } from "src/lib/auth";
-
-export const INFERENCE_HOST = process.env.INFERENCE_SERVER_HOST;
+import { OasstInferenceClient } from "src/lib/oasst_inference_client";
 
 const handler = withoutRole("banned", async (req, res, token) => {
-  const chat = await post(INFERENCE_HOST + "/chat", { arg: {} });
-  return res.status(200).json(chat);
+  const client = new OasstInferenceClient(req, res, token);
+  const data = await client.create_chat();
+  return res.status(200).json(data);
 });
 
 export default handler;
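
For context, the UI is expected to call this route roughly like so (a sketch; the call site is not part of this diff, and the response field name is an assumption):

// Create a chat, then hand its id to the conversation component,
// assuming the response carries the new chat's id.
const chat = await fetch("/api/chat", { method: "POST" }).then((r) => r.json());
// <ChatConversation chatId={chat.id} />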
15 changes: 4 additions & 11 deletions website/src/pages/api/chat/message.ts
@@ -1,19 +1,12 @@
-import axios from "axios";
-import { IncomingMessage } from "http";
 import { withoutRole } from "src/lib/auth";
-
-import { INFERENCE_HOST } from ".";
+import { OasstInferenceClient } from "src/lib/oasst_inference_client";
 
 const handler = withoutRole("banned", async (req, res, token) => {
   const { chat_id, parent_id, content } = req.body;
-
-  const { data } = await axios.post<IncomingMessage>(
-    INFERENCE_HOST + `/chat/${chat_id}/message`,
-    { parent_id, content },
-    { responseType: "stream" }
-  );
+  const client = new OasstInferenceClient(req, res, token);
+  const responseStream = await client.post_prompt({ chat_id, parent_id, content });
   res.status(200);
-  data.pipe(res);
+  responseStream.pipe(res);
 });
 
 export default handler;
