switched to HF text-generation-inference
yk committed Feb 8, 2023
1 parent af6885d commit bab056a
Showing 6 changed files with 47 additions and 12 deletions.
2 changes: 1 addition & 1 deletion inference/README.md
@@ -86,7 +86,7 @@ For the worker, you'll also want to have the text-generation-inference server
running:

```bash
-docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference
+docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference
```

Run the client:
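As a quick check that the new image actually serves before wiring up the worker and client, something like the following can be used (a minimal sketch; the `/generate` endpoint and payload shape follow text-generation-inference's HTTP API as of this commit, so treat the exact field names as assumptions):

```python
import requests

# Assumes the container from the README is running and mapped to host port 8001.
resp = requests.post(
    "http://127.0.0.1:8001/generate",
    json={"inputs": "Hello", "parameters": {"max_new_tokens": 8}},
)
resp.raise_for_status()
print(resp.json())  # expected to contain a "generated_text" field
```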
2 changes: 1 addition & 1 deletion inference/full-dev-setup.sh
@@ -5,7 +5,7 @@
tmux new-session -d -s "inference-dev-setup"
tmux send-keys "docker run --rm -it -p 6379:6379 redis" C-m
tmux split-window -h
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference" C-m
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference" C-m
tmux split-window -h
tmux send-keys "cd server" C-m
tmux send-keys "uvicorn main:app --reload" C-m
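The script leaves everything running in a detached tmux session, which can be inspected with `tmux attach -t inference-dev-setup`; each service gets its own pane.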
4 changes: 2 additions & 2 deletions inference/server/main.py
@@ -57,7 +57,7 @@ def compatible_with(self, worker_config: inference.WorkerConfig) -> bool:


class TokenResponseEvent(pydantic.BaseModel):
-    token: str
+    token: inference.TokenResponse


class MessageRequestState(str, enum.Enum):
@@ -143,7 +143,7 @@ async def event_generator():

chat.conversation.messages.append(
    protocol.ConversationMessage(
-        text="".join([d.token for d in result_data[:-1]]),
+        text=response_packet.generated_text.text,
        is_assistant=True,
    )
)
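With `token` now typed as `inference.TokenResponse`, each server-sent event carries a structured object rather than a bare string. A minimal sketch of the new wire shape, using the schema this commit adds to `oasst_shared` below (values illustrative):

```python
import pydantic

class TokenResponse(pydantic.BaseModel):
    text: str
    log_prob: float
    token_id: int

class TokenResponseEvent(pydantic.BaseModel):
    token: TokenResponse

event = TokenResponseEvent(
    token=TokenResponse(text="Hel", log_prob=-0.12, token_id=15496)
)
# Serializes to {"token": {"text": "Hel", "log_prob": -0.12, "token_id": 15496}},
# which is why the text client below now indexes data["token"]["text"].
print(event.json())
```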
2 changes: 1 addition & 1 deletion inference/text-client/__main__.py
@@ -32,7 +32,7 @@ def main(backend_url: str = "http://127.0.0.1:8000"):
print("Assistant: ", end="", flush=True)
for event in client.events():
data = json.loads(event.data)
print(data["token"], end="", flush=True)
print(data["token"]["text"], end="", flush=True)
print()


36 changes: 30 additions & 6 deletions inference/worker/__main__.py
@@ -54,24 +54,48 @@ def _prepare_message(message: protocol.ConversationMessage) -> str:
"top_p": work_request.top_p,
"temperature": work_request.temperature,
"seed": work_request.seed,
# "stop": ["User:", "Assistant:"], # TODO: this doesn't work... why?
},
},
stream=True,
headers={"Accept": "text/event-stream"},
)
response.raise_for_status()
try:
response.raise_for_status()
except requests.HTTPError:
logger.exception("Failed to get response from inference server")
return

client = sseclient.SSEClient(response)
for event in client.events():
logger.debug(f"Received event: {event}")
data = json.loads(event.data)
if data["is_end"]:
if data["generated_text"]:
break
intermediate = data["event"]
ws.send(inference.WorkResponsePacket(token=intermediate["token"]).json())
ws.send(inference.WorkResponsePacket(is_end=True).json())
token = data["token"]
ws.send(
inference.WorkResponsePacket(
token=inference.TokenResponse(
text=token["text"],
log_prob=token["logprob"],
token_id=token["id"],
)
).json()
)
ws.send(
inference.WorkResponsePacket(
is_end=True,
generated_text=inference.GeneratedTextResponse(
text=data["generated_text"],
),
).json()
)

def on_error(ws: websocket.WebSocket, error: Exception):
-    logger.error(f"Connection error: {error}")
+    try:
+        raise error
+    except Exception:
+        logger.exception("Error in websocket")

def on_close(ws: websocket.WebSocket, close_status_code: int, close_msg: str):
    logger.warning(f"Connection closed: {close_status_code=} {close_msg=}")
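For reference, the worker loop above assumes each server-sent event from text-generation-inference decodes to JSON shaped roughly like the sketch below. Only the keys the code actually reads are shown; anything beyond them is an assumption:

```python
# Intermediate event: one streamed token, no final text yet.
intermediate = {
    "token": {"id": 15496, "text": "Hel", "logprob": -0.12},
    "generated_text": None,
}

# Final event: "generated_text" is set, which triggers the break above and is
# forwarded in the closing WorkResponsePacket.
final = {
    "token": {"id": 13, "text": ".", "logprob": -0.03},
    "generated_text": "Hello there.",
}
```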
13 changes: 12 additions & 1 deletion oasst-shared/oasst_shared/schemas/inference.py
@@ -20,6 +20,17 @@ class WorkRequest(pydantic.BaseModel):
    temperature: float = 1.0


+class TokenResponse(pydantic.BaseModel):
+    text: str
+    log_prob: float
+    token_id: int
+
+
+class GeneratedTextResponse(pydantic.BaseModel):
+    text: str
+
+
class WorkResponsePacket(pydantic.BaseModel):
-    token: str | None = None
+    token: TokenResponse | None = None
+    generated_text: GeneratedTextResponse | None = None
    is_end: bool = False
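A consumer of the worker's packet stream can then distinguish streamed tokens from the final result by which optional field is set. A minimal sketch (the `collect_text` helper is hypothetical, not part of this commit):

```python
from oasst_shared.schemas import inference

def collect_text(packets: list[inference.WorkResponsePacket]) -> str:
    """Fold a stream of WorkResponsePackets into the final text."""
    pieces: list[str] = []
    for packet in packets:
        if packet.is_end and packet.generated_text is not None:
            # Prefer the server's final text over re-joining tokens,
            # mirroring what main.py now does.
            return packet.generated_text.text
        if packet.token is not None:
            pieces.append(packet.token.text)
    return "".join(pieces)  # fallback if no end packet arrived
```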
