Skip to content

Commit

Permalink
Initial plugins code cleanup (#3120)
Browse files Browse the repository at this point in the history
- Introduce `PromptedLLM` class to manage the repeated process of
preparing prompt with memory and template, calling LLM, and
postprocessing output
- Separating out functions in `openapi_parser`
- General improvements to clarity/readability
  • Loading branch information
olliestanley committed May 19, 2023
1 parent 0b577aa commit 1a14d9d
Show file tree
Hide file tree
Showing 9 changed files with 357 additions and 369 deletions.
55 changes: 55 additions & 0 deletions inference/server/oasst_inference_server/plugin_utils.py
@@ -0,0 +1,55 @@
import asyncio
import json

import aiohttp
import yaml
from aiohttp.client_exceptions import ClientConnectorError, ServerTimeoutError
from fastapi import HTTPException
from loguru import logger
from oasst_shared.schemas import inference


async def attempt_fetch_plugin(session: aiohttp.ClientSession, url: str, timeout: float = 5.0):
    """Fetch a plugin config from `url` once and parse it into a PluginConfig.

    Accepts JSON (by content type or a `.json` URL suffix) and YAML (by content
    type or a `.yaml`/`.yml` suffix).

    Raises:
        HTTPException: 404 if the plugin URL is not found, 500 on any other
            non-200 status, 400 for an unsupported content type.
    """
    async with session.get(url, timeout=timeout) as response:
        # Default to "" so the substring checks below don't raise TypeError
        # when the server omits the Content-Type header entirely.
        content_type = response.headers.get("Content-Type", "")

        if response.status == 404:
            raise HTTPException(status_code=404, detail="Plugin not found")
        if response.status != 200:
            raise HTTPException(status_code=500, detail="Failed to fetch plugin")

        if "application/json" in content_type or "text/json" in content_type or url.endswith(".json"):
            if "text/json" in content_type:
                logger.warning(f"Plugin {url} is using text/json as its content type. This is not recommended.")
                # text/json responses can't use response.json() (wrong mimetype),
                # so parse the raw text manually.
                config = json.loads(await response.text())
            else:
                config = await response.json()
        elif (
            "application/yaml" in content_type
            or "application/x-yaml" in content_type
            or url.endswith(".yaml")
            or url.endswith(".yml")
        ):
            config = yaml.safe_load(await response.text())
        else:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported content type: {content_type}. Only JSON and YAML are supported.",
            )

    return inference.PluginConfig(**config)


async def fetch_plugin(url: str, retries: int = 3, timeout: float = 5.0) -> inference.PluginConfig:
    """Fetch a plugin config, retrying transient network failures.

    Connection and timeout errors are retried up to `retries` times with
    exponential backoff; any other client error fails immediately.
    """
    last_attempt = retries - 1
    async with aiohttp.ClientSession() as session:
        for attempt in range(retries):
            try:
                return await attempt_fetch_plugin(session, url, timeout=timeout)
            except (ClientConnectorError, ServerTimeoutError) as e:
                # Transient: retry unless this was the final attempt.
                if attempt == last_attempt:
                    raise HTTPException(status_code=500, detail=f"Request failed after {retries} retries: {e}")
                # exponential backoff
                await asyncio.sleep(2**attempt)
            except aiohttp.ClientError as e:
                # Non-transient client error: no point retrying.
                raise HTTPException(status_code=500, detail=f"Request failed: {e}")
        raise HTTPException(status_code=500, detail="Failed to fetch plugin")
55 changes: 3 additions & 52 deletions inference/server/oasst_inference_server/routes/configs.py
@@ -1,13 +1,8 @@
import asyncio
import json

import aiohttp
import fastapi
import pydantic
import yaml
from aiohttp.client_exceptions import ClientConnectorError, ServerTimeoutError
from fastapi import HTTPException
from loguru import logger
from oasst_inference_server import plugin_utils
from oasst_inference_server.settings import settings
from oasst_shared import model_configs
from oasst_shared.schemas import inference
Expand Down Expand Up @@ -114,50 +109,6 @@ class ModelConfigInfo(pydantic.BaseModel):
]


async def fetch_plugin(url: str, retries: int = 3, timeout: float = 5.0) -> inference.PluginConfig:
    """Fetch and parse a plugin config from `url`, retrying transient failures.

    Supports JSON (by content type or `.json` suffix) and YAML (by content type
    or `.yaml`/`.yml` suffix). Connection/timeout errors are retried up to
    `retries` times with exponential backoff.

    Raises:
        HTTPException: 404 if the plugin is not found, 400 for an unsupported
            content type, otherwise 500 (or the upstream status) on failure.
    """
    async with aiohttp.ClientSession() as session:
        for attempt in range(retries):
            try:
                async with session.get(url, timeout=timeout) as response:
                    # Default to "" so the substring checks below don't raise
                    # TypeError when the Content-Type header is missing.
                    content_type = response.headers.get("Content-Type", "")

                    if response.status == 200:
                        if "application/json" in content_type or "text/json" in content_type or url.endswith(".json"):
                            if "text/json" in content_type:
                                logger.warning(
                                    f"Plugin {url} is using text/json as its content type. This is not recommended."
                                )
                                # response.json() rejects text/json, so parse manually.
                                config = json.loads(await response.text())
                            else:
                                config = await response.json()
                        elif (
                            "application/yaml" in content_type
                            or "application/x-yaml" in content_type
                            or url.endswith(".yaml")
                            or url.endswith(".yml")
                        ):
                            config = yaml.safe_load(await response.text())
                        else:
                            raise HTTPException(
                                status_code=400,
                                detail=f"Unsupported content type: {content_type}. Only JSON and YAML are supported.",
                            )

                        return inference.PluginConfig(**config)
                    elif response.status == 404:
                        raise HTTPException(status_code=404, detail="Plugin not found")
                    else:
                        raise HTTPException(status_code=response.status, detail="Unexpected status code")
            except (ClientConnectorError, ServerTimeoutError) as e:
                if attempt == retries - 1:  # last attempt
                    raise HTTPException(status_code=500, detail=f"Request failed after {retries} retries: {e}")
                await asyncio.sleep(2**attempt)  # exponential backoff

            except aiohttp.ClientError as e:
                raise HTTPException(status_code=500, detail=f"Request failed: {e}")
    raise HTTPException(status_code=500, detail="Failed to fetch plugin")


@router.get("/model_configs")
async def get_model_configs() -> list[ModelConfigInfo]:
return [
Expand All @@ -173,7 +124,7 @@ async def get_model_configs() -> list[ModelConfigInfo]:
@router.post("/plugin_config")
async def get_plugin_config(plugin: inference.PluginEntry) -> inference.PluginEntry:
try:
plugin_config = await fetch_plugin(plugin.url)
plugin_config = await plugin_utils.fetch_plugin(plugin.url)
except HTTPException as e:
logger.warning(f"Failed to fetch plugin config from {plugin.url}: {e.detail}")
raise fastapi.HTTPException(status_code=e.status_code, detail=e.detail)
Expand All @@ -187,7 +138,7 @@ async def get_builtin_plugins() -> list[inference.PluginEntry]:

for plugin in OA_PLUGINS:
try:
plugin_config = await fetch_plugin(plugin.url)
plugin_config = await plugin_utils.fetch_plugin(plugin.url)
except HTTPException as e:
logger.warning(f"Failed to fetch plugin config from {plugin.url}: {e.detail}")
continue
Expand Down
6 changes: 4 additions & 2 deletions inference/worker/basic_hf_server.py
Expand Up @@ -5,13 +5,13 @@
from queue import Queue

import fastapi
import hf_stopping
import hf_streamer
import interface
import torch
import transformers
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
from hf_stopping import SequenceStoppingCriteria
from loguru import logger
from oasst_shared import model_configs
from settings import settings
Expand Down Expand Up @@ -85,7 +85,9 @@ def print_text(token_id: int):
streamer = hf_streamer.HFStreamer(input_ids=ids, printer=print_text)
ids = ids.to(model.device)
stopping_criteria = (
transformers.StoppingCriteriaList([SequenceStoppingCriteria(tokenizer, stop_sequences, prompt)])
transformers.StoppingCriteriaList(
[hf_stopping.SequenceStoppingCriteria(tokenizer, stop_sequences, prompt)]
)
if stop_sequences
else None
)
Expand Down

0 comments on commit 1a14d9d

Please sign in to comment.