Skip to content

Commit

Permalink
Boilerplate OAuth website <=> inference (#2127)
Browse files Browse the repository at this point in the history
Refs #2101

Use the website's backend as a callback url for login to discord and
github, so that the website also knows the token.


To use this, you need to configure your discord oauth provider to use
the url: `http://localhost:3000/api/inference_auth/discord`, I used the
same provider I use for logging in to the website and it worked like a
charm.

github also: `http://localhost:3000/api/inference_auth` or
`http://localhost:3000/api/inference_auth/gihtub`, both should work
since github allows sub paths.

you need to set these 4 env variables for the inference server:

```
AUTH_DISCORD_CLIENT_ID
AUTH_DISCORD_CLIENT_SECRET

AUTH_GITHUB_CLIENT_ID
AUTH_GITHUB_CLIENT_SECRET
```

then you can navigate to 

```
localhost:8000/auth/login/github
localhost:8000/auth/login/discord
```
  • Loading branch information
AbdBarho committed Mar 20, 2023
1 parent 42c8c3d commit b1f37f9
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 14 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Expand Up @@ -19,3 +19,5 @@ backend/openapi.json
# edit docs using obsidian.md, these files should not appear in the repo
.obsidian/
.pytest_cache/

/docker-compose.override.yml
1 change: 1 addition & 0 deletions docker-compose.yaml
Expand Up @@ -223,6 +223,7 @@ services:
POSTGRES_DB: oasst_inference
DEBUG_API_KEYS: "0000"
ALLOW_DEBUG_AUTH: "True"
AUTH_CALLBACK_ROOT: "http://localhost:3000/api/inference_auth"
volumes:
- "./oasst-shared:/opt/inference/lib/oasst-shared"
- "./inference/server:/opt/inference/server"
Expand Down
21 changes: 13 additions & 8 deletions inference/server/oasst_inference_server/routes/auth.py
Expand Up @@ -13,10 +13,15 @@
)


@router.get("/check")
async def check_user_auth(user_id: str = Depends(auth.get_current_user_id)):
return user_id


@router.get("/login/discord")
async def login_discord():
redirect_uri = f"{settings.api_root}/auth/callback/discord"
auth_url = f"https://discord.com/api/oauth2/authorize?client_id={settings.auth_discord_client_id}&redirect_uri={redirect_uri}&response_type=code&scope=identify"
async def login_discord(state: str = r"{}"):
redirect_uri = f"{settings.auth_callback_root}/discord"
auth_url = f"https://discord.com/api/oauth2/authorize?client_id={settings.auth_discord_client_id}&redirect_uri={redirect_uri}&response_type=code&scope=identify&state={state}"
raise HTTPException(status_code=302, headers={"location": auth_url})


Expand All @@ -25,7 +30,7 @@ async def callback_discord(
code: str,
db: database.AsyncSession = Depends(deps.create_session),
):
redirect_uri = f"{settings.api_root}/auth/callback/discord"
redirect_uri = f"{settings.auth_callback_root}/discord"

async with aiohttp.ClientSession(raise_for_status=True) as session:
# Exchange the auth code for a Discord access token
Expand Down Expand Up @@ -77,9 +82,9 @@ async def callback_discord(


@router.get("/login/github")
async def login_github():
redirect_uri = f"{settings.api_root}/auth/callback/github"
auth_url = f"https://github.com/login/oauth/authorize?client_id={settings.auth_github_client_id}&redirect_uri={redirect_uri}"
async def login_github(state: str = r"{}"):
redirect_uri = f"{settings.auth_callback_root}/github"
auth_url = f"https://github.com/login/oauth/authorize?client_id={settings.auth_github_client_id}&redirect_uri={redirect_uri}&state={state}"
raise HTTPException(status_code=302, headers={"location": auth_url})


Expand All @@ -88,7 +93,7 @@ async def callback_github(
code: str,
db: database.AsyncSession = Depends(deps.create_session),
):
redirect_uri = f"{settings.api_root}/auth/callback/github"
redirect_uri = f"{settings.auth_callback_root}/github"

async with aiohttp.ClientSession(raise_for_status=True) as session:
# Exchange the auth code for a GitHub access token
Expand Down
6 changes: 5 additions & 1 deletion inference/server/oasst_inference_server/settings.py
Expand Up @@ -58,7 +58,11 @@ def debug_api_keys_list(self) -> list[str]:
compliance_check_interval: int = 60
compliance_check_timeout: int = 60

api_root: str = "https://inference.prod.open-assistant.io"
# this is the URL which will be redirected to when authenticating with oauth2
# we decided on letting the nextjs / website backend handle the token at first
# and then proxy this information back to the inference server
# in short: this should refer to the website, not to this server
auth_callback_root: str = "https://open-assistant.io/api/inference_auth"

allow_debug_auth: bool = False

Expand Down
4 changes: 2 additions & 2 deletions website/src/lib/oasst_inference_client.ts
Expand Up @@ -3,7 +3,7 @@ import axios, { AxiosRequestConfig } from "axios";
import Cookies from "cookies";
import type { NextApiRequest, NextApiResponse } from "next";
import { JWT } from "next-auth/jwt";
import { ChatItem, InferenceDebugTokenResponse, InferenceMessage, InferencePostMessageResponse } from "src/types/Chat";
import { ChatItem, InferenceTokenResponse, InferenceMessage, InferencePostMessageResponse } from "src/types/Chat";

// TODO: this class could be structured better
export class OasstInferenceClient {
Expand Down Expand Up @@ -41,7 +41,7 @@ export class OasstInferenceClient {

// TODO: we have not decided on a format for the user yet, this is here for debug only
const res = await fetch(process.env.INFERENCE_SERVER_HOST + `/auth/login/debug?username=${this.userTokenSub}`);
const inferenceResponse: InferenceDebugTokenResponse = await res.json();
const inferenceResponse: InferenceTokenResponse = await res.json();
this.inferenceToken = inferenceResponse.access_token;
this.cookies.set("inference_token", this.inferenceToken, {
maxAge: 1000 * 60 * 5, // 5 minutes
Expand Down
16 changes: 16 additions & 0 deletions website/src/pages/api/inference_auth/[...parts].ts
@@ -0,0 +1,16 @@
import axios from "axios";
import type { NextApiRequest, NextApiResponse } from "next";
import { InferenceTokenResponse } from "src/types/Chat";

export default async function inferenceAuthCallback(req: NextApiRequest, res: NextApiResponse) {
const { code, parts } = req.query;
console.log(req.query);
if (!Array.isArray(parts) || parts.length !== 1) {
return res.status(400).end();
}
const [provider] = parts as string[];
const url = process.env.INFERENCE_SERVER_HOST + `/auth/callback/${provider}?code=${code}`;
const { data } = await axios<InferenceTokenResponse>(url);
console.log(data);
return res.send(data);
}
4 changes: 2 additions & 2 deletions website/src/pages/team.tsx
Expand Up @@ -2,8 +2,8 @@ export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_
import { Avatar, Badge, Box, Card, CardBody, Flex, Grid, Heading, Text } from "@chakra-ui/react";
import { Github } from "lucide-react";
import Head from "next/head";
import { useTranslation } from "next-i18next";
import Link from "next/link";
import { useTranslation } from "next-i18next";
import React from "react";
import { getTransparentHeaderLayout } from "src/components/Layout";

Expand All @@ -16,7 +16,7 @@ const Team = () => {
<>
<Head>
<title>{t("who_are_we")} - Open Assistant</title>
<meta name="description" content="The team begind Open Assistant" />
<meta name="description" content="The team behind Open Assistant" />
</Head>
<Box fontFamily="Inter" p="6" className="oa-basic-theme">
<Box className="max-w-6xl mx-auto">
Expand Down
2 changes: 1 addition & 1 deletion website/src/types/Chat.ts
@@ -1,4 +1,4 @@
export interface InferenceDebugTokenResponse {
export interface InferenceTokenResponse {
access_token: string;
token_type: string;
}
Expand Down

0 comments on commit b1f37f9

Please sign in to comment.