simplify implementation with new pyllamacpp apis #7

Merged (1 commit) on Apr 30, 2023
16 changes: 4 additions & 12 deletions README.md
@@ -2,11 +2,13 @@

[![PyPI version](https://img.shields.io/pypi/v/llama-server)](https://pypi.org/project/llama-server/) [![Unit test](https://github.com/nuance1979/llama-server/actions/workflows/test.yml/badge.svg?branch=main&&event=push)](https://github.com/nuance1979/llama-server/actions) [![GitHub stars](https://img.shields.io/github/stars/nuance1979/llama-server)](https://star-history.com/#nuance1979/llama-server&Date) [![GitHub license](https://img.shields.io/github/license/nuance1979/llama-server)](https://github.com/nuance1979/llama-server/blob/master/LICENSE)

LLaMA Server combines the power of [LLaMA C++](https://github.com/ggerganov/llama.cpp) (via [PyLLaMACpp](https://github.com/nomic-ai/pyllamacpp)) with the beauty of [Chatbot UI](https://github.com/mckaywrigley/chatbot-ui).
LLaMA Server combines the power of [LLaMA C++](https://github.com/ggerganov/llama.cpp) (via [PyLLaMACpp](https://github.com/abdeladim-s/pyllamacpp)) with the beauty of [Chatbot UI](https://github.com/mckaywrigley/chatbot-ui).

🦙LLaMA C++ (via 🐍PyLLaMACpp) ➕ 🤖Chatbot UI ➕ 🔗LLaMA Server 🟰 😊

**UPDATE**: Now supports better streaming through [PyLLaMACpp](https://github.com/nomic-ai/pyllamacpp)!
**UPDATE**: Greatly simplified implementation thanks to the [awesome Pythonic APIs](https://github.com/abdeladim-s/pyllamacpp#different-persona) of PyLLaMACpp 2.0.0!

**UPDATE**: Now supports better streaming through [PyLLaMACpp](https://github.com/abdeladim-s/pyllamacpp)!

**UPDATE**: Now supports streaming!

@@ -57,11 +59,6 @@ conda activate llama
python -m pip install git+https://github.com/nuance1979/llama-server.git
```

- Install a patched version of PyLLaMACpp: (*Note:* this step will not be needed **after** PyLLaMACpp makes a new release.)
```bash
python -m pip install git+https://github.com/nuance1979/pyllamacpp.git@dev --upgrade
```

- Start LLaMA Server with your `models.yml` file:
```bash
llama-server --models-yml models.yml --model-id llama-7b
@@ -97,11 +94,6 @@ export LLAMA_STREAM_MODE=0 # 1 to enable streaming
npm run dev
```

## Limitations

- "Regenerate response" is currently not working;
- IMHO, the prompt/reverse-prompt mechanism of LLaMA C++'s interactive mode needs an overhaul. I tried very hard to dance around it, but the whole thing is still a hack.

## Fun facts

I am not fluent in JavaScript at all but I was able to make the changes in Chatbot UI by chatting with [ChatGPT](https://chat.openai.com); no more StackOverflow.
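
The "greatly simplified implementation" mentioned in the README update above comes from PyLLaMACpp 2.0.0 exposing the model as a plain Python generator. Below is a minimal sketch of that API, based on the calls this PR makes in `llama_server/server.py` further down; the model path, prompt text, and parameter values here are illustrative, not taken from the PR:

```python
# Sketch of the PyLLaMACpp 2.x API as used by this PR; values are illustrative.
from pyllamacpp.model import Model

model = Model(
    ggml_model="models/ggml-llama-7b.bin",  # hypothetical path to a ggml model file
    n_ctx=512,
    prompt_context="Transcript of a dialog between User and Bob...",  # system prompt
    prompt_prefix="\n User:",   # text prepended to each user turn
    prompt_suffix="\n Bob:",    # text the model continues from
)

# generate() yields text pieces as they are produced, so streaming becomes a
# plain for-loop instead of the old worker process + queues + ring buffer.
for piece in model.generate("Hello, Bob.", n_predict=256, n_threads=8):
    print(piece, end="", flush=True)
```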
2 changes: 1 addition & 1 deletion llama_server/__init__.py
@@ -1 +1 @@
__version__ = "0.1.0"
__version__ = "0.2.0"
153 changes: 26 additions & 127 deletions llama_server/server.py
@@ -5,11 +5,6 @@
"""Serve llama models with llama.cpp ."""
import logging
import os
import time
from collections import deque
from contextlib import asynccontextmanager
from multiprocessing import Process
from multiprocessing import Queue
from pathlib import Path
from typing import Any
from typing import Dict
@@ -27,10 +22,9 @@


PROMPT_PATH = Path(__file__).parent / "prompts" / "chat-with-bob.txt"
PROMPT = PROMPT_PATH.read_text(encoding="utf-8").strip()
PROMPT_SIZE = len(PROMPT)
REVERSE_PROMPT = "User:"
REPLY_PREFIX = "Bob: "
PROMPT = PROMPT_PATH.read_text(encoding="utf-8")
REVERSE_PROMPT = "\n User:"
REPLY_PREFIX = "\n Bob:"


class Message(BaseModel):
@@ -66,137 +60,27 @@ class ModelList(BaseModel):
object: str = "list"


class Buffer:
def __init__(self, prompt_size: int, reverse_prompt: str) -> None:
self._q = deque(maxlen=10)
self._c = 0
self._prompt_size = prompt_size
self._reverse_prompt = reverse_prompt
self._is_first = True
self._c_first = 0

def __len__(self) -> int:
return self._c

def prompt_consumed(self) -> bool:
return self._c_first >= self._prompt_size

def clear(self) -> None:
self._q.clear()
self._c = 0

def append(self, data: str) -> None:
if self._is_first:
self._c_first += len(data)
if self._c_first < self._prompt_size:
return
else:
self._is_first = False
diff = self._c_first - self._prompt_size
if diff > 0:
self.append(data[-diff:])
self._c_first = self._prompt_size
else:
self._c += len(data)
self._q.append(data)

def popleft(self) -> str:
if self._c < len(self._reverse_prompt):
return ""
data = self._q.popleft()
self._c -= len(data)
if self._c < len(self._reverse_prompt):
diff = self._c - len(self._reverse_prompt)
self._q.appendleft(data[diff:])
self._c = len(self._reverse_prompt)
data = data[:diff]
return data

def turnends(self) -> bool:
return "".join(self._q).endswith(self._reverse_prompt)


logging.basicConfig(
format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s", level=logging.INFO
)
logger = logging.getLogger(name=__name__)

logger = None
model_id = None
model_path = None
output_q = Queue()
input_q = Queue()
buffer = Buffer(PROMPT_SIZE, REVERSE_PROMPT)


def generate(model_path: str, input_q: Queue, output_q: Queue) -> None:
def output_callback(text: str) -> None:
output_q.put(text)

def input_callback() -> str:
return input_q.get()

model = Model(ggml_model=model_path, n_ctx=512)
model.generate(
PROMPT,
new_text_callback=output_callback,
grab_text_callback=input_callback,
n_predict=256,
n_batch=1024,
n_keep=48,
repeat_penalty=1.0,
n_threads=8,
interactive=True,
antiprompt=[REVERSE_PROMPT],
)
model = None

app = FastAPI()

@asynccontextmanager
async def lifespan(app: FastAPI):
p = Process(target=generate, args=(model_path, input_q, output_q))
p.start()
# Skip system prompt
while not buffer.prompt_consumed():
buffer.append(output_q.get())
logger.info("ready to serve...")
yield
input_q.close()
output_q.close()
if p.is_alive():
p.terminate()
time.sleep(5)
p.kill()


app = FastAPI(lifespan=lifespan)
def _chat(user_utt: str) -> Generator[str, None, None]:
return model.generate(user_utt, n_predict=256, repeat_penalty=1.0, n_threads=8)


def chat_stream(user_utt: str) -> Generator[Dict[str, Any], None, None]:
for text in _chat(user_utt):
logger.debug("text: %s", text)
payload = Completion(
choices=[Choice(delta=Message(role="assistant", content=text))]
)
yield {"event": "event", "data": payload.json()}
yield {"event": "event", "data": "[DONE]"}


def _chat(user_utt: str) -> Generator[str, None, None]:
input_q.put(user_utt)
counter = 0
while not buffer.turnends():
text = output_q.get()
counter += len(text)
if counter <= len(REPLY_PREFIX):
continue
buffer.append(text)
yield buffer.popleft()
while True:
text = buffer.popleft()
if not text:
break
yield text
buffer.clear()


def chat_nonstream(user_utt: str) -> Completion:
assistant_utt = "".join(_chat(user_utt))
logger.info("assistant: %s", assistant_utt)
@@ -212,7 +96,7 @@ def chat(conv: Conversation):
if not conv.stream:
return chat_nonstream(user_utt)
else:
return EventSourceResponse(chat_stream(user_utt))
return EventSourceResponse(chat_stream(user_utt), ping_message_factory=None)


@app.get("/v1/models")
@@ -247,13 +131,20 @@ class KnownModels(BaseModel):
)
@click.option("--model-id", type=click.STRING, default="llama-7b", help="Model id.")
@click.option("--model-path", type=click.Path(exists=True), help="Model path.")
@click.option(
"--log-level",
type=click.Choice(["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"]),
default="INFO",
help="Log level.",
)
def main(
models_yml: Path,
host: str,
port: int,
reload: bool,
model_id: Optional[str] = None,
model_path: Optional[Path] = None,
log_level: Optional[str] = None,
):
with open(models_yml, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
@@ -266,7 +157,15 @@ def main(
if not model_path.is_absolute():
model_path = Path(KNOWN_MODELS.model_home) / model_path
globals()["model_id"] = model_id
globals()["model_path"] = str(model_path)
globals()["model"] = Model(
ggml_model=str(model_path),
n_ctx=512,
prompt_context=PROMPT,
prompt_prefix=REVERSE_PROMPT,
prompt_suffix=REPLY_PREFIX,
)
globals()["logger"] = logging.getLogger(name=__name__)
globals()["logger"].setLevel(log_level)

uvicorn.run("llama_server.server:app", host=host, port=port, reload=reload)

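For context, the streaming path in the new `server.py` relies on sse-starlette turning a generator of dicts into Server-Sent Events. Here is a self-contained sketch of that pattern; the endpoint name and token list are illustrative, not the PR's code:

```python
# Standalone sketch of the SSE pattern used by chat_stream(); illustrative only.
from typing import Any, Dict, Generator, List

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()


def token_events(tokens: List[str]) -> Generator[Dict[str, Any], None, None]:
    # Each yielded dict becomes one SSE message ("event: ..." / "data: ..." lines).
    for tok in tokens:
        yield {"event": "event", "data": tok}
    yield {"event": "event", "data": "[DONE]"}  # sentinel the client stops on


@app.get("/demo/stream")
def demo_stream() -> EventSourceResponse:
    return EventSourceResponse(token_events(["44th", " president", " of", " USA"]))
```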
1 change: 1 addition & 0 deletions setup.py
@@ -27,6 +27,7 @@ def get_install_requires() -> str:
def get_extras_require() -> str:
req = {
"dev": [
"httpx",
"pre-commit",
"pytest",
"pytest-cov",
111 changes: 66 additions & 45 deletions test/llama_server_test.py
@@ -3,54 +3,75 @@
# Copyright (c) 2023 Yi Su. All rights reserved.
#
"""Test llama server."""
import logging
from unittest import TestCase
from unittest.mock import patch

from llama_server.server import Buffer
from fastapi.testclient import TestClient

from llama_server.server import app
from llama_server.server import Conversation
from llama_server.server import Message

class TestBuffer(TestCase):
"""Test Buffer class."""

class MockModel:
tokens = ["44th", " presi", "dent", " of", " USA"]

def generate(self, *args, **kwargs):
for tok in MockModel.tokens:
yield tok


class Testapp(TestCase):
"""Test the server."""

def setUp(self):
self._buffer = Buffer(10, "user:")

def testInit(self):
self.assertTrue(isinstance(self._buffer, Buffer))

def testAppendPopleft(self):
self.assertEqual(0, len(self._buffer))
self._buffer.append("0123456789")
self.assertEqual(0, len(self._buffer)) # consume prompt
self._buffer.append("test")
self.assertEqual(4, len(self._buffer))
self._buffer.append("abc")
self.assertEqual(7, len(self._buffer))
self.assertEqual("te", self._buffer.popleft()) # because len("user:") == 5
self.assertEqual("", self._buffer.popleft())
self._buffer.append("this is test")
self.assertEqual("st", self._buffer.popleft())
self.assertEqual("abc", self._buffer.popleft())

def testPromptConsumed(self):
self._buffer.append("abcdefgh")
self.assertFalse(self._buffer.prompt_consumed())
self._buffer.append("123")
self.assertTrue(self._buffer.prompt_consumed())

def testTurnends(self):
self.assertFalse(self._buffer.turnends())
self._buffer.append("0123456789")
self._buffer.append("user")
self.assertFalse(self._buffer.turnends())
self._buffer.append(":")
self.assertTrue(self._buffer.turnends())

def testClear(self):
self._buffer.append("0123456789")
self.assertTrue(self._buffer.prompt_consumed())
self._buffer.append("abc")
self._buffer.append("xyz")
self.assertEqual(6, len(self._buffer))
self._buffer.clear()
self.assertEqual(0, len(self._buffer))
self.assertTrue(self._buffer.prompt_consumed())
self._client = TestClient(app)

@patch("llama_server.server.model_id", "myModelId")
def testGetModels(self):
response = self._client.get("/v1/models")
self.assertEqual(200, response.status_code)
json = response.json()
self.assertEqual(1, len(json["data"]))
self.assertEqual("myModelId", json["data"][0]["id"])

@patch("llama_server.server.logger", logging)
@patch("llama_server.server.model_id", "myModelId")
@patch("llama_server.server.model", MockModel())
def testPostChat(self):
conv = Conversation(
model="myModelId",
messages=[Message(role="user", content="who is barack obama?")],
max_tokens=256,
temperature=0.8,
stream=False,
)
response = self._client.post("/v1/chat/completions", data=conv.json())
json = response.json()
self.assertEqual(1, len(json["choices"]))
self.assertEqual("assistant", json["choices"][0]["message"]["role"])
self.assertEqual(
"".join(MockModel.tokens), json["choices"][0]["message"]["content"]
)

@patch("llama_server.server.logger", logging)
@patch("llama_server.server.model_id", "myModelId")
@patch("llama_server.server.model", MockModel())
def testPostChatStreaming(self):
conv = Conversation(
model="myModelId",
messages=[Message(role="user", content="who is barack obama?")],
max_tokens=256,
temperature=0.8,
stream=True,
)
response = self._client.post("/v1/chat/completions", data=conv.json())
from json import loads

datalines = [line for line in response.iter_lines() if line.startswith("data")]
for line, tok in zip(datalines, MockModel.tokens):
json = loads(line[6:])
self.assertEqual(1, len(json["choices"]))
self.assertEqual("assistant", json["choices"][0]["delta"]["role"])
self.assertEqual(tok, json["choices"][0]["delta"]["content"])
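
The streaming test above parses the SSE wire format by hand: keep the `data:` lines, strip the `data: ` prefix, parse the JSON, and stop at `[DONE]`. Below is a hedged sketch of doing the same thing against a running server with httpx; the host, port, and model id are assumptions for illustration, not taken from this PR:

```python
# Client-side sketch for consuming the streaming endpoint; URL and payload
# values are assumptions.
import json

import httpx

payload = {
    "model": "llama-7b",
    "messages": [{"role": "user", "content": "who is barack obama?"}],
    "stream": True,
}

with httpx.stream(
    "POST", "http://127.0.0.1:8000/v1/chat/completions", json=payload, timeout=None
) as resp:
    for line in resp.iter_lines():
        if not line.startswith("data"):
            continue  # skip "event: ..." lines and keep-alives
        data = line[6:]  # drop the "data: " prefix, as the test above does
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)
```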