Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: preloading of machine learning models #7540

Merged
merged 12 commits into from
Mar 4, 2024
22 changes: 12 additions & 10 deletions docs/docs/install/environment-variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,18 @@ Redis (Sentinel) URL example JSON before encoding:

## Machine Learning

| Variable | Description | Default | Services |
| :----------------------------------------------- | :----------------------------------------------------------------- | :-----------------: | :--------------- |
| `MACHINE_LEARNING_MODEL_TTL` | Inactivity time (s) before a model is unloaded (disabled if \<= 0) | `300` | machine learning |
| `MACHINE_LEARNING_MODEL_TTL_POLL_S` | Interval (s) between checks for the model TTL (disabled if \<= 0) | `10` | machine learning |
| `MACHINE_LEARNING_CACHE_FOLDER` | Directory where models are downloaded | `/cache` | machine learning |
| `MACHINE_LEARNING_REQUEST_THREADS`<sup>\*1</sup> | Thread count of the request thread pool (disabled if \<= 0) | number of CPU cores | machine learning |
| `MACHINE_LEARNING_MODEL_INTER_OP_THREADS` | Number of parallel model operations | `1` | machine learning |
| `MACHINE_LEARNING_MODEL_INTRA_OP_THREADS` | Number of threads for each model operation | `2` | machine learning |
| `MACHINE_LEARNING_WORKERS`<sup>\*2</sup> | Number of worker processes to spawn | `1` | machine learning |
| `MACHINE_LEARNING_WORKER_TIMEOUT` | Maximum time (s) of unresponsiveness before a worker is killed | `120` | machine learning |
| Variable | Description | Default | Services |
| :----------------------------------------------- | :------------------------------------------------------------------- | :-----------------: | :--------------- |
| `MACHINE_LEARNING_MODEL_TTL` | Inactivity time (s) before a model is unloaded (disabled if \<= 0) | `300` | machine learning |
| `MACHINE_LEARNING_MODEL_TTL_POLL_S` | Interval (s) between checks for the model TTL (disabled if \<= 0) | `10` | machine learning |
| `MACHINE_LEARNING_CACHE_FOLDER` | Directory where models are downloaded | `/cache` | machine learning |
| `MACHINE_LEARNING_REQUEST_THREADS`<sup>\*1</sup> | Thread count of the request thread pool (disabled if \<= 0) | number of CPU cores | machine learning |
| `MACHINE_LEARNING_MODEL_INTER_OP_THREADS` | Number of parallel model operations | `1` | machine learning |
| `MACHINE_LEARNING_MODEL_INTRA_OP_THREADS` | Number of threads for each model operation | `2` | machine learning |
| `MACHINE_LEARNING_WORKERS`<sup>\*2</sup> | Number of worker processes to spawn | `1` | machine learning |
| `MACHINE_LEARNING_WORKER_TIMEOUT` | Maximum time (s) of unresponsiveness before a worker is killed | `120` | machine learning |
| `MACHINE_LEARNING_PRELOAD__CLIP` | Name of a CLIP model to be preloaded and kept in cache | | machine learning |
| `MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION` | Name of a facial recognition model to be preloaded and kept in cache | | machine learning |

\*1: It is recommended to begin with this parameter when changing the concurrency levels of the machine learning service and then tune the other ones.

Expand Down
2 changes: 2 additions & 0 deletions machine-learning/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# VS Code
.vscode

*.onnx
*.zip
9 changes: 8 additions & 1 deletion machine-learning/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from socket import socket

from gunicorn.arbiter import Arbiter
from pydantic import BaseSettings
from pydantic import BaseModel, BaseSettings
from rich.console import Console
from rich.logging import RichHandler
from uvicorn import Server
Expand All @@ -15,6 +15,11 @@
from .schemas import ModelType


class PreloadModelData(BaseModel):
clip: str | None
facial_recognition: str | None


class Settings(BaseSettings):
cache_folder: str = "/cache"
model_ttl: int = 300
Expand All @@ -27,10 +32,12 @@ class Settings(BaseSettings):
model_inter_op_threads: int = 0
model_intra_op_threads: int = 0
ann: bool = True
preload: PreloadModelData | None = None

class Config:
env_prefix = "MACHINE_LEARNING_"
case_sensitive = False
env_nested_delimiter = "__"


class LogSettings(BaseSettings):
Expand Down
16 changes: 13 additions & 3 deletions machine-learning/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from app.models.base import InferenceModel

from .config import log, settings
from .config import PreloadModelData, log, settings
from .models.cache import ModelCache
from .schemas import (
MessageResponse,
Expand All @@ -27,7 +27,7 @@

MultiPartParser.max_file_size = 2**26 # spools to disk if payload is 64 MiB or larger

model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
model_cache = ModelCache(revalidate=settings.model_ttl > 0)
thread_pool: ThreadPoolExecutor | None = None
lock = threading.Lock()
active_requests = 0
Expand All @@ -51,6 +51,8 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
asyncio.ensure_future(idle_shutdown_task())
if settings.preload is not None:
await preload_models(settings.preload)
yield
finally:
log.handlers.clear()
Expand All @@ -61,6 +63,14 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
gc.collect()


async def preload_models(preload_models: PreloadModelData) -> None:
log.info(f"Preloading models: {preload_models}")
if preload_models.clip is not None:
await load(await model_cache.get(preload_models.clip, ModelType.CLIP))
if preload_models.facial_recognition is not None:
await load(await model_cache.get(preload_models.facial_recognition, ModelType.FACIAL_RECOGNITION))


def update_state() -> Iterator[None]:
global active_requests, last_called
active_requests += 1
Expand Down Expand Up @@ -103,7 +113,7 @@ async def predict(
except orjson.JSONDecodeError:
raise HTTPException(400, f"Invalid options JSON: {options}")

model = await load(await model_cache.get(model_name, model_type, **kwargs))
model = await load(await model_cache.get(model_name, model_type, ttl=settings.model_ttl, **kwargs))
model.configure(**kwargs)
outputs = await run(model.predict, inputs)
return ORJSONResponse(outputs)
Expand Down
37 changes: 11 additions & 26 deletions machine-learning/app/models/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock
from aiocache.plugins import BasePlugin, TimingPlugin
from aiocache.plugins import TimingPlugin

from app.models import from_model_type

Expand All @@ -15,28 +15,25 @@ class ModelCache:

def __init__(
self,
ttl: float | None = None,
revalidate: bool = False,
timeout: int | None = None,
profiling: bool = False,
) -> None:
"""
Args:
ttl: Unloads model after this duration. Disabled if None. Defaults to None.
revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
"""

self.ttl = ttl
plugins = []

if revalidate:
plugins.append(RevalidationPlugin())
if profiling:
plugins.append(TimingPlugin())

self.cache = SimpleMemoryCache(ttl=ttl, timeout=timeout, plugins=plugins, namespace=None)
self.revalidate_enable = revalidate

self.cache = SimpleMemoryCache(timeout=timeout, plugins=plugins, namespace=None)

async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel:
"""
Expand All @@ -49,11 +46,14 @@ async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any)
"""

key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"

async with OptimisticLock(self.cache, key) as lock:
model: InferenceModel | None = await self.cache.get(key)
if model is None:
model = from_model_type(model_type, model_name, **model_kwargs)
await lock.cas(model, ttl=self.ttl)
await lock.cas(model, ttl=model_kwargs.get("ttl", None))
elif self.revalidate_enable:
await self.revalidate(key, model_kwargs.get("ttl", None))
return model

async def get_profiling(self) -> dict[str, float] | None:
Expand All @@ -62,21 +62,6 @@ async def get_profiling(self) -> dict[str, float] | None:

return self.cache.profiling


class RevalidationPlugin(BasePlugin): # type: ignore[misc]
"""Revalidates cache item's TTL after cache hit."""

async def post_get(
self,
client: SimpleMemoryCache,
key: str,
ret: Any | None = None,
namespace: str | None = None,
**kwargs: Any,
) -> None:
if ret is None:
return
if namespace is not None:
key = client.build_key(key, namespace)
if key in client._handlers:
await client.expire(key, client.ttl)
async def revalidate(self, key: str, ttl: int | None) -> None:
if ttl is not None and key in self.cache._handlers:
await self.cache.expire(key, ttl)
38 changes: 29 additions & 9 deletions machine-learning/app/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
import pytest
from fastapi.testclient import TestClient
from PIL import Image
from pytest import MonkeyPatch
from pytest_mock import MockerFixture

from app.main import load
from app.main import load, preload_models

from .config import log, settings
from .config import Settings, log, settings
from .models.base import InferenceModel
from .models.cache import ModelCache
from .models.clip import MCLIPEncoder, OpenCLIPEncoder
Expand Down Expand Up @@ -509,20 +510,20 @@ async def test_different_clip(self, mock_get_model: mock.Mock) -> None:

@mock.patch("app.models.cache.OptimisticLock", autospec=True)
async def test_model_ttl(self, mock_lock_cls: mock.Mock, mock_get_model: mock.Mock) -> None:
model_cache = ModelCache(ttl=100)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
model_cache = ModelCache()
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
mock_lock_cls.return_value.__aenter__.return_value.cas.assert_called_with(mock.ANY, ttl=100)

@mock.patch("app.models.cache.SimpleMemoryCache.expire")
async def test_revalidate_get(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None:
model_cache = ModelCache(ttl=100, revalidate=True)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
model_cache = ModelCache(revalidate=True)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
mock_cache_expire.assert_called_once_with(mock.ANY, 100)

async def test_profiling(self, mock_get_model: mock.Mock) -> None:
model_cache = ModelCache(ttl=100, profiling=True)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
model_cache = ModelCache(profiling=True)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
profiling = await model_cache.get_profiling()
assert isinstance(profiling, dict)
assert profiling == model_cache.cache.profiling
Expand All @@ -548,6 +549,25 @@ async def test_raises_exception_if_unknown_model_name(self) -> None:
with pytest.raises(ValueError):
await model_cache.get("test_model_name", ModelType.CLIP, mode="text")

async def test_preloads_models(self, monkeypatch: MonkeyPatch, mock_get_model: mock.Mock) -> None:
os.environ["MACHINE_LEARNING_PRELOAD__CLIP"] = "ViT-B-32__openai"
os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION"] = "buffalo_s"

settings = Settings()
assert settings.preload is not None
assert settings.preload.clip == "ViT-B-32__openai"
assert settings.preload.facial_recognition == "buffalo_s"

model_cache = ModelCache()
monkeypatch.setattr("app.main.model_cache", model_cache)

await preload_models(settings.preload)
assert len(model_cache.cache._cache) == 2
assert mock_get_model.call_count == 2
await model_cache.get("ViT-B-32__openai", ModelType.CLIP, ttl=100)
await model_cache.get("buffalo_s", ModelType.FACIAL_RECOGNITION, ttl=100)
assert mock_get_model.call_count == 2


@pytest.mark.asyncio
class TestLoad:
Expand Down
Loading