Merge pull request #2 from c0sogi/dev

Huggingface downloader & Simpler log message & InterruptMixin

c0sogi committed Aug 2, 2023
2 parents efb6956 + 14ad9fd commit 344ab12

Showing 18 changed files with 1,007 additions and 304 deletions.
Binary file added contents/auto-download-model.png
8 changes: 8 additions & 0 deletions install_packages.bat
@@ -0,0 +1,8 @@
set VENV_DIR=.venv

if not exist %VENV_DIR% (
echo Creating virtual environment
python -m venv %VENV_DIR%
)
call %VENV_DIR%\Scripts\activate.bat
python -m llama_api.server.app_settings --install-pkgs
9 changes: 9 additions & 0 deletions install_packages.sh
@@ -0,0 +1,9 @@
#!/bin/bash
VENV_DIR=.venv

if [ ! -d "$VENV_DIR" ]; then
echo "Creating virtual environment"
python3 -m venv $VENV_DIR
fi
source $VENV_DIR/bin/activate
python3 -m llama_api.server.app_settings --install-pkgs
32 changes: 32 additions & 0 deletions llama_api/mixins/interrupt.py
@@ -0,0 +1,32 @@
from threading import Event
from typing import Optional


class InterruptMixin:
    """A mixin class for interrupting(aborting) a job."""

    _interrupt_signal: Optional[Event] = None

    @property
    def is_interrupted(self) -> bool:
        """Check whether the job is interrupted or not."""
        return (
            self.interrupt_signal is not None
            and self.interrupt_signal.is_set()
        )

    @property
    def raise_for_interruption(self) -> None:
        """Raise an InterruptedError if the job is interrupted."""
        if self.is_interrupted:
            raise InterruptedError

    @property
    def interrupt_signal(self) -> Optional[Event]:
        """Get the interrupt signal."""
        return self._interrupt_signal

    @interrupt_signal.setter
    def interrupt_signal(self, value: Optional[Event]) -> None:
        """Set the interrupt signal."""
        self._interrupt_signal = value
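
For context, a minimal usage sketch (not part of this commit) of how a worker built on InterruptMixin can be cancelled from another thread. It assumes the llama_api package is importable; DummyWorker is a made-up class for illustration:

from threading import Event, Thread
from time import sleep

from llama_api.mixins.interrupt import InterruptMixin  # module added in this commit


class DummyWorker(InterruptMixin):
    """Hypothetical worker that polls the interrupt flag between steps."""

    def run(self) -> None:
        for step in range(100):
            if self.is_interrupted:  # flag is set from another thread
                print(f"stopped at step {step}")
                return
            sleep(0.05)  # stand-in for a unit of real work


worker = DummyWorker()
worker.interrupt_signal = Event()  # wire up the shared signal
thread = Thread(target=worker.run)
thread.start()
sleep(0.2)
worker.interrupt_signal.set()  # request cancellation
thread.join()

Because raise_for_interruption is declared as a property, simply accessing self.raise_for_interruption inside a worker performs the check and raises InterruptedError when the signal is set.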
30 changes: 0 additions & 30 deletions llama_api/mixins/waiter.py

This file was deleted.

3 changes: 2 additions & 1 deletion llama_api/modules/base.py
@@ -3,6 +3,7 @@
from typing import Any, Iterator, List, TypeVar

from ..mixins.prompt_utils import PromptUtilsMixin
from ..mixins.interrupt import InterruptMixin
from ..schemas.api import (
APIChatMessage,
ChatCompletion,
@@ -23,7 +24,7 @@ class BaseLLMModel:
max_total_tokens: int = 2048


class BaseCompletionGenerator(ABC, PromptUtilsMixin):
class BaseCompletionGenerator(ABC, PromptUtilsMixin, InterruptMixin):
"""Base class for all completion generators."""

@abstractmethod
8 changes: 7 additions & 1 deletion llama_api/modules/exllama.py
@@ -167,7 +167,9 @@ def _generator_context_manager(
tokenizer=self.tokenizer,
cache=self.cache,
)
generator.settings.temperature = settings.temperature
# Temperature cannot be 0.0, so we use a very small value instead.
# 0.0 will cause a division by zero error.
generator.settings.temperature = settings.temperature or 0.01
generator.settings.top_p = settings.top_p
generator.settings.top_k = settings.top_k
generator.settings.typical = settings.typical_p
@@ -219,7 +221,11 @@ def _generate_text_with_streaming(
n_completion_tokens: int = 0

for n_completion_tokens in range(1, settings.max_tokens + 1):
if self.is_interrupted:
return # the generator was interrupted
token = generator.gen_single_token()
if self.is_interrupted:
return # the generator was interrupted
if token.item() == generator.tokenizer.eos_token_id:
return
if (
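
A side note on the temperature guard above: `settings.temperature or 0.01` relies on 0.0 being falsy in Python, so both an explicit 0.0 and None fall back to the small epsilon. A standalone sketch of the idiom (not from this commit):

def clamp_temperature(temperature, minimum: float = 0.01) -> float:
    # `or` substitutes the fallback for any falsy value, covering both 0.0 and None.
    return temperature or minimum


assert clamp_temperature(0.0) == 0.01  # avoids the division-by-zero path
assert clamp_temperature(0.7) == 0.7   # normal values pass through unchanged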
12 changes: 10 additions & 2 deletions llama_api/modules/llama_cpp.py
@@ -257,7 +257,11 @@ def generate_completion_with_streaming(
)
assert isinstance(completion_chunk_generator, Iterator)
self.generator = completion_chunk_generator
yield from completion_chunk_generator
for chunk in completion_chunk_generator:
if self.is_interrupted:
yield chunk
return # the generator was interrupted
yield chunk

def generate_chat_completion(
self, messages: List[APIChatMessage], settings: TextGenerationSettings
@@ -284,7 +288,11 @@ def generate_chat_completion_with_streaming(
)
assert isinstance(chat_completion_chunk_generator, Iterator)
self.generator = chat_completion_chunk_generator
yield from chat_completion_chunk_generator
for chunk in chat_completion_chunk_generator:
if self.is_interrupted:
yield chunk
return # the generator was interrupted
yield chunk

def encode(self, text: str, add_bos: bool = True) -> List[int]:
assert self.client is not None, "Client is not initialized"
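
The change above swaps `yield from` for an explicit per-chunk loop so the interrupt flag can be checked between chunks. A generic sketch of the same pattern (illustrative only; stream_until_interrupted is a made-up helper, not part of llama_cpp.py):

from threading import Event
from typing import Iterator


def stream_until_interrupted(
    chunks: Iterator[str], interrupt_signal: Event
) -> Iterator[str]:
    # An explicit loop (unlike `yield from`) lets us poll the flag per chunk.
    for chunk in chunks:
        if interrupt_signal.is_set():
            yield chunk  # flush the chunk already pulled, then stop early
            return
        yield chunk


interrupt = Event()
for piece in stream_until_interrupted(iter("abcdef"), interrupt):
    print(piece)
    if piece == "c":
        interrupt.set()  # generation stops after the next chunk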
9 changes: 0 additions & 9 deletions llama_api/schemas/models.py
@@ -92,15 +92,6 @@ class LlamaCppModel(BaseLLMModel):
class ExllamaModel(BaseLLMModel):
"""Exllama model that can be loaded from local path."""

model_path: str = field(
default="YOUR_GPTQ_FOLDER_NAME",
metadata={
"description": "The GPTQ model path to the model."
"e.g. If you have a model folder in 'models/gptq/your_model',"
"then you should set this to 'your_model'."
},
)

compress_pos_emb: float = field(
default=1.0,
metadata={
72 changes: 37 additions & 35 deletions llama_api/server/app_settings.py
@@ -1,7 +1,7 @@
import argparse
import platform
from contextlib import asynccontextmanager
from os import environ
from os import environ, getpid
from pathlib import Path
from typing import Dict, Optional, Union

@@ -21,39 +21,44 @@
logger = ApiLogger(__name__)


def set_priority(pid: Optional[int] = None, priority: str = "high"):
import platform
from os import getpid

import psutil

def set_priority(priority: str = "high", pid: Optional[int] = None) -> bool:
"""Set The Priority of a Process. Priority is a string which can be
'low', 'below_normal', 'normal', 'above_normal', 'high', 'realtime'.
'normal' is the default."""

if platform.system() == "Windows":
priorities = {
"low": psutil.IDLE_PRIORITY_CLASS,
"below_normal": psutil.BELOW_NORMAL_PRIORITY_CLASS,
"normal": psutil.NORMAL_PRIORITY_CLASS,
"above_normal": psutil.ABOVE_NORMAL_PRIORITY_CLASS,
"high": psutil.HIGH_PRIORITY_CLASS,
"realtime": psutil.REALTIME_PRIORITY_CLASS,
}
else: # Linux and other Unix systems
priorities = {
"low": 19,
"below_normal": 10,
"normal": 0,
"above_normal": -5,
"high": -11,
"realtime": -20,
}

'normal' is the default.
Returns True if successful, False if not."""
if pid is None:
pid = getpid()
p = psutil.Process(pid)
p.nice(priorities[priority])
try:
import psutil

if platform.system() == "Windows":
priorities = {
"low": psutil.IDLE_PRIORITY_CLASS,
"below_normal": psutil.BELOW_NORMAL_PRIORITY_CLASS,
"normal": psutil.NORMAL_PRIORITY_CLASS,
"above_normal": psutil.ABOVE_NORMAL_PRIORITY_CLASS,
"high": psutil.HIGH_PRIORITY_CLASS,
"realtime": psutil.REALTIME_PRIORITY_CLASS,
}
else: # Linux and other Unix systems
priorities = {
"low": 19,
"below_normal": 10,
"normal": 0,
"above_normal": -5,
"high": -11,
"realtime": -20,
}
if priority not in priorities:
logger.warning(f"⚠️ Invalid priority [{priority}]")
return False

p = psutil.Process(pid)
p.nice(priorities[priority])
return True
except Exception as e:
logger.warning(f"⚠️ Failed to set priority of process [{pid}]: {e}")
return False


def initialize_before_launch(
Expand Down Expand Up @@ -99,11 +104,8 @@ def initialize_before_launch(
"If any packages are missing, "
"use `--install-pkgs` option to install them."
)
try:
# Set the priority of the process
set_priority(priority="high")
except Exception:
pass
# Set the priority of the process
set_priority(priority="high")
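
The new call site above ignores the return value; a minimal sketch of a caller that reacts to failure (assuming psutil is installed, and noting that negative nice values on Linux usually require elevated privileges, so "high" may fail for unprivileged users):

from llama_api.server.app_settings import set_priority

# Best-effort priority bump; the helper logs a warning and returns False on failure.
if not set_priority(priority="high"):
    print("Could not raise process priority; continuing with defaults.")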


@asynccontextmanager
27 changes: 17 additions & 10 deletions llama_api/server/pools/llama.py
@@ -26,7 +26,7 @@
EmbeddingUsage,
)
from ...schemas.models import ExllamaModel, LlamaCppModel
from ...utils.concurrency import queue_event_manager
from ...utils.concurrency import queue_manager
from ...utils.lazy_imports import LazyImports
from ...utils.logger import ApiLogger
from ...utils.system import free_memory_of_first_item_from_container
@@ -51,10 +51,13 @@ def completion_generator_manager(
CreateChatCompletionRequest,
CreateEmbeddingRequest,
],
interrupt_signal: Event,
):
"""Context manager for completion generators."""
completion_generator = get_completion_generator(body)
completion_generator.interrupt_signal = interrupt_signal
yield completion_generator
completion_generator.interrupt_signal = None


def get_model_names() -> List[str]:
@@ -206,10 +209,12 @@ def get_embedding_generator(
def generate_completion_chunks(
body: Union[CreateChatCompletionRequest, CreateCompletionRequest],
queue: Queue,
event: Event,
interrupt_signal: Event,
) -> None:
with queue_event_manager(queue=queue, event=event):
with completion_generator_manager(body=body) as cg:
with queue_manager(queue=queue):
with completion_generator_manager(
body=body, interrupt_signal=interrupt_signal
) as cg:
if isinstance(body, CreateChatCompletionRequest):
_iterator: Iterator[
Union[ChatCompletionChunk, CompletionChunk]
@@ -234,7 +239,7 @@ def iterator() -> (
yield chunk

for chunk in iterator():
if event.is_set():
if interrupt_signal.is_set():
# If the event is set, it means the client has disconnected
return
queue.put(chunk)
@@ -243,10 +248,12 @@ def generate_completion(
def generate_completion(
body: Union[CreateChatCompletionRequest, CreateCompletionRequest],
queue: Queue,
event: Event,
interrupt_signal: Event,
) -> None:
with queue_event_manager(queue=queue, event=event):
with completion_generator_manager(body=body) as cg:
with queue_manager(queue=queue):
with completion_generator_manager(
body=body, interrupt_signal=interrupt_signal
) as cg:
if isinstance(body, CreateChatCompletionRequest):
completion: Union[
ChatCompletion, Completion
@@ -263,9 +270,9 @@ def generate_completion(


def generate_embeddings(
body: CreateEmbeddingRequest, queue: Queue, event: Event
body: CreateEmbeddingRequest, queue: Queue, interrupt_signal: Event
) -> None:
with queue_event_manager(queue=queue, event=event):
with queue_manager(queue=queue):
try:
llm_model = get_model(body.model)
if not isinstance(llm_model, LlamaCppModel):
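
The worker functions above follow a common producer/consumer shape: the worker pushes chunks into a Queue while the request handler drains it, and a shared Event is set when the client disconnects. A generic sketch of that pattern (illustrative only; produce_chunks and the sentinel value are assumptions, not the project's actual API):

from queue import Queue
from threading import Event, Thread

SENTINEL = None  # assumed end-of-stream marker


def produce_chunks(queue: Queue, interrupt_signal: Event) -> None:
    # Hypothetical worker: emits chunks until done or until the client disconnects.
    for i in range(10):
        if interrupt_signal.is_set():
            break
        queue.put(f"chunk-{i}")
    queue.put(SENTINEL)


queue: Queue = Queue()
interrupt_signal = Event()
Thread(target=produce_chunks, args=(queue, interrupt_signal)).start()

while True:
    chunk = queue.get()
    if chunk is SENTINEL:
        break
    print(chunk)
    # interrupt_signal.set()  # a real handler would set this on client disconnect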