feat: preloading of machine learning models #7540

Merged · 12 commits · Mar 4, 2024
1 change: 1 addition & 0 deletions docs/docs/install/environment-variables.md
@@ -134,6 +134,7 @@ Redis (Sentinel) URL example JSON before encoding:
| `MACHINE_LEARNING_MODEL_INTRA_OP_THREADS` | Number of threads for each model operation | `2` | machine learning |
| `MACHINE_LEARNING_WORKERS`<sup>\*2</sup> | Number of worker processes to spawn | `1` | machine learning |
| `MACHINE_LEARNING_WORKER_TIMEOUT` | Maximum time (s) of unresponsiveness before a worker is killed | `120` | machine learning |
| `MACHINE_LEARNING_PRELOAD` | JSON list of `[model_type, model_name]` pairs to preload (see the example below) | `` | machine learning |

\*1: It is recommended to begin with this parameter when changing the concurrency levels of the machine learning service and then tune the other ones.
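For reference, a minimal sketch of how the service picks this value up, mirroring `test_preloads_models` later in this diff: pydantic-settings reads the `MACHINE_LEARNING_`-prefixed variable and parses the complex `preload` field from JSON. The import path for `ModelType` is an assumption.

```python
# Sketch only, based on test_main.py in this PR.
import os

from app.config import Settings
from app.schemas import ModelType  # assumed location of ModelType

os.environ["MACHINE_LEARNING_PRELOAD"] = '[["clip", "ViT-B-32__openai"], ["facial-recognition", "buffalo_s"]]'

settings = Settings()
assert settings.preload == [
    (ModelType.CLIP, "ViT-B-32__openai"),
    (ModelType.FACIAL_RECOGNITION, "buffalo_s"),
]
```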

2 changes: 2 additions & 0 deletions machine-learning/.gitignore
@@ -167,6 +167,8 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# VS Code
.vscode

*.onnx
*.zip
1 change: 1 addition & 0 deletions machine-learning/app/config.py
@@ -27,6 +27,7 @@ class Settings(BaseSettings):
model_inter_op_threads: int = 0
model_intra_op_threads: int = 0
ann: bool = True
preload: list[tuple[ModelType, str]] | None = None

class Config:
env_prefix = "MACHINE_LEARNING_"
14 changes: 12 additions & 2 deletions machine-learning/app/main.py
@@ -27,7 +27,7 @@

MultiPartParser.max_file_size = 2**26 # spools to disk if payload is 64 MiB or larger

model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
model_cache = ModelCache(revalidate=settings.model_ttl > 0)
thread_pool: ThreadPoolExecutor | None = None
lock = threading.Lock()
active_requests = 0
@@ -51,6 +51,8 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
asyncio.ensure_future(idle_shutdown_task())
if settings.preload is not None:
await preload_models(model_cache, settings.preload)
yield
finally:
log.handlers.clear()
@@ -61,6 +63,12 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
gc.collect()


async def preload_models(cache: ModelCache, preload_models: list[tuple[ModelType, str]]) -> None:
log.info(f"Preloading models: {preload_models}")
    for model_type, model_name in preload_models:
        await load(await cache.get(model_name, model_type))
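A hypothetical standalone invocation of the function above (the PR itself only calls it from the lifespan handler); the model names are the ones used in the tests, and the `ModelType` import path is an assumption:

```python
# Sketch only: warm the module-level cache outside the FastAPI lifespan.
import asyncio

from app.main import model_cache, preload_models
from app.schemas import ModelType  # assumed location of ModelType

asyncio.run(
    preload_models(
        model_cache,
        [(ModelType.CLIP, "ViT-B-32__openai"), (ModelType.FACIAL_RECOGNITION, "buffalo_s")],
    )
)
```

Because `preload_models` calls `cache.get` without a `ttl`, preloaded entries are cached with no expiration and stay resident for the life of the process.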


def update_state() -> Iterator[None]:
global active_requests, last_called
active_requests += 1
@@ -103,7 +111,9 @@ async def predict(
except orjson.JSONDecodeError:
raise HTTPException(400, f"Invalid options JSON: {options}")

model = await load(await model_cache.get(model_name, model_type, **kwargs))
key = f"{model_name}{model_type.value}{kwargs.get('mode', '')}"
ttl = settings.model_ttl if key in model_cache.cache._handlers else None
model = await load(await model_cache.get(model_name, model_type, ttl=ttl, **kwargs))
model.configure(**kwargs)
outputs = await run(model.predict, inputs)
return ORJSONResponse(outputs)
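The `_handlers` check above leans on an aiocache implementation detail: `SimpleMemoryCache` keeps one scheduled-eviction handle per key, and only keys cached with a TTL have one. A short sketch of that behavior (not part of the PR):

```python
# Sketch only: demonstrates the aiocache behavior predict() relies on.
import asyncio

from aiocache.backends.memory import SimpleMemoryCache


async def main() -> None:
    cache = SimpleMemoryCache()
    await cache.set("preloaded", object())           # ttl=None: no eviction handle
    await cache.set("on_demand", object(), ttl=300)  # eviction handle scheduled
    assert "preloaded" not in cache._handlers        # never expires
    assert "on_demand" in cache._handlers            # has a TTL that can be refreshed...
    await cache.expire("on_demand", 300)             # ...by pushing the eviction back


asyncio.run(main())
```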
37 changes: 11 additions & 26 deletions machine-learning/app/models/cache.py
@@ -2,7 +2,7 @@

from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock
from aiocache.plugins import BasePlugin, TimingPlugin
from aiocache.plugins import TimingPlugin

from app.models import from_model_type

@@ -15,28 +15,25 @@ class ModelCache:

def __init__(
self,
ttl: float | None = None,
revalidate: bool = False,
timeout: int | None = None,
profiling: bool = False,
) -> None:
"""
Args:
ttl: Unloads model after this duration. Disabled if None. Defaults to None.
revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
"""

self.ttl = ttl
plugins = []

if revalidate:
plugins.append(RevalidationPlugin())
if profiling:
plugins.append(TimingPlugin())

self.cache = SimpleMemoryCache(ttl=ttl, timeout=timeout, plugins=plugins, namespace=None)
self.revalidate_enable = revalidate

self.cache = SimpleMemoryCache(timeout=timeout, plugins=plugins, namespace=None)

    async def get(self, model_name: str, model_type: ModelType, ttl: float | None = None, **model_kwargs: Any) -> InferenceModel:
"""
@@ -49,11 +46,14 @@ async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any)
"""

key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"

async with OptimisticLock(self.cache, key) as lock:
model: InferenceModel | None = await self.cache.get(key)
if model is None:
model = from_model_type(model_type, model_name, **model_kwargs)
await lock.cas(model, ttl=self.ttl)
                await lock.cas(model, ttl=ttl)
            elif self.revalidate_enable:
                await self.revalidate(key, ttl)
return model

async def get_profiling(self) -> dict[str, float] | None:
Expand All @@ -62,21 +62,6 @@ async def get_profiling(self) -> dict[str, float] | None:

return self.cache.profiling


class RevalidationPlugin(BasePlugin): # type: ignore[misc]
"""Revalidates cache item's TTL after cache hit."""

async def post_get(
self,
client: SimpleMemoryCache,
key: str,
ret: Any | None = None,
namespace: str | None = None,
**kwargs: Any,
) -> None:
if ret is None:
return
if namespace is not None:
key = client.build_key(key, namespace)
if key in client._handlers:
await client.expire(key, client.ttl)
async def revalidate(self, key: str, ttl: int | None) -> None:
if ttl is not None and key in self.cache._handlers:
await self.cache.expire(key, ttl)
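A hypothetical walkthrough of the new flow: with `revalidate=True`, a cache hit within the TTL window calls `revalidate`, which pushes the eviction deadline back so models stay loaded while they are in active use. The `ModelType` import path is an assumption.

```python
# Sketch only: TTL revalidation on cache hits.
import asyncio

from app.models.cache import ModelCache
from app.schemas import ModelType  # assumed location of ModelType


async def main() -> None:
    cache = ModelCache(revalidate=True)
    # First get: loads the model and schedules eviction 300 s out.
    await cache.get("ViT-B-32__openai", ModelType.CLIP, ttl=300)
    await asyncio.sleep(100)
    # Cache hit: revalidate() resets the 300 s window.
    await cache.get("ViT-B-32__openai", ModelType.CLIP, ttl=300)


asyncio.run(main())
```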
32 changes: 23 additions & 9 deletions machine-learning/app/test_main.py
@@ -15,9 +15,9 @@
from PIL import Image
from pytest_mock import MockerFixture

from app.main import load
from app.main import load, preload_models

from .config import log, settings
from .config import Settings, log, settings
from .models.base import InferenceModel
from .models.cache import ModelCache
from .models.clip import MCLIPEncoder, OpenCLIPEncoder
@@ -509,20 +509,20 @@ async def test_different_clip(self, mock_get_model: mock.Mock) -> None:

@mock.patch("app.models.cache.OptimisticLock", autospec=True)
async def test_model_ttl(self, mock_lock_cls: mock.Mock, mock_get_model: mock.Mock) -> None:
model_cache = ModelCache(ttl=100)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
model_cache = ModelCache()
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
mock_lock_cls.return_value.__aenter__.return_value.cas.assert_called_with(mock.ANY, ttl=100)

@mock.patch("app.models.cache.SimpleMemoryCache.expire")
async def test_revalidate_get(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None:
model_cache = ModelCache(ttl=100, revalidate=True)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
model_cache = ModelCache(revalidate=True)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
mock_cache_expire.assert_called_once_with(mock.ANY, 100)

async def test_profiling(self, mock_get_model: mock.Mock) -> None:
model_cache = ModelCache(ttl=100, profiling=True)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
model_cache = ModelCache(profiling=True)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
profiling = await model_cache.get_profiling()
assert isinstance(profiling, dict)
assert profiling == model_cache.cache.profiling
@@ -548,6 +548,20 @@ async def test_raises_exception_if_unknown_model_name(self) -> None:
with pytest.raises(ValueError):
await model_cache.get("test_model_name", ModelType.CLIP, mode="text")

async def test_preloads_models(self, mock_get_model: mock.Mock) -> None:
os.environ["MACHINE_LEARNING_PRELOAD"] = '[["clip", "ViT-B-32__openai"], ["facial-recognition", "buffalo_s"]]'
settings = Settings()
assert settings.preload == [(ModelType.CLIP, "ViT-B-32__openai"), (ModelType.FACIAL_RECOGNITION, "buffalo_s")]

model_cache = ModelCache()

await preload_models(model_cache, settings.preload)
assert len(model_cache.cache._cache) == 2
assert mock_get_model.call_count == 2
await model_cache.get("ViT-B-32__openai", ModelType.CLIP, ttl=100)
await model_cache.get("buffalo_s", ModelType.FACIAL_RECOGNITION, ttl=100)
assert mock_get_model.call_count == 2


@pytest.mark.asyncio
class TestLoad: