Skip to content

Commit

Permalink
Add on_start function & add env module for determining if we are runn…
Browse files Browse the repository at this point in the history
…ing locally or remotely (#160)

- Adds an on_start lifecycle method to taskqueues/endpoints
- Adds an is_local method
- Adds a local entrypoint decorator which can be used like this:

```import time

from beta9 import function
from beta9.env import local_entrypoint

def test_func():
    return True # this value will be available in some_func

@task_queue(cpu=0.1, on_start=test_func)
def some_func(context):
    print(context.task_id)
    print(context.on_start_value)
    for i in range(1000):
        print(i)
        time.sleep(0.01)
    return "hi"


# This will be executed automatically when running locally (like if __name__ == '__main__')
@local_entrypoint
def main():
    print("hi there")
```

---------

Co-authored-by: Luke Lombardi <luke@beam.cloud>
  • Loading branch information
luke-lombardi and Luke Lombardi authored Apr 26, 2024
1 parent 02eaa79 commit 8f4748b
Show file tree
Hide file tree
Showing 24 changed files with 270 additions and 127 deletions.
1 change: 1 addition & 0 deletions internal/abstractions/endpoint/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ func registerEndpointRoutes(g *echo.Group, es *HttpEndpointService) *endpointGro
group := &endpointGroup{routeGroup: g, es: es}

g.POST("/id/:stubId/", group.endpointRequest)
g.POST("/id/:stubId", group.endpointRequest)
g.POST("/:deploymentName/v:version", group.endpointRequest)

return group
Expand Down
1 change: 1 addition & 0 deletions internal/abstractions/endpoint/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func (i *endpointInstance) startContainers(containersToRun int) error {
Env: []string{
fmt.Sprintf("BETA9_TOKEN=%s", i.Token.Key),
fmt.Sprintf("HANDLER=%s", i.StubConfig.Handler),
fmt.Sprintf("ON_START=%s", i.StubConfig.OnStart),
fmt.Sprintf("STUB_ID=%s", i.Stub.ExternalId),
fmt.Sprintf("STUB_TYPE=%s", i.Stub.Type),
fmt.Sprintf("CONCURRENCY=%d", i.StubConfig.Concurrency),
Expand Down
1 change: 1 addition & 0 deletions internal/abstractions/function/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ func registerFunctionRoutes(g *echo.Group, fs *RunCFunctionService) *functionGro
group := &functionGroup{routerGroup: g, fs: fs}

g.POST("/id/:stubId", group.FunctionInvoke)
g.POST("/id/:stubId/", group.FunctionInvoke)
g.POST("/:deploymentName/v:version", group.FunctionInvoke)

return group
Expand Down
1 change: 1 addition & 0 deletions internal/abstractions/taskqueue/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ func registerTaskQueueRoutes(g *echo.Group, tq *RedisTaskQueue) *taskQueueGroup
group := &taskQueueGroup{routeGroup: g, tq: tq}

g.POST("/id/:stubId", group.TaskQueuePut)
g.POST("/id/:stubId/", group.TaskQueuePut)
g.POST("/:deploymentName/v:version", group.TaskQueuePut)

return group
Expand Down
1 change: 1 addition & 0 deletions internal/abstractions/taskqueue/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func (i *taskQueueInstance) startContainers(containersToRun int) error {
Env: []string{
fmt.Sprintf("BETA9_TOKEN=%s", i.Token.Key),
fmt.Sprintf("HANDLER=%s", i.StubConfig.Handler),
fmt.Sprintf("ON_START=%s", i.StubConfig.OnStart),
fmt.Sprintf("STUB_ID=%s", i.Stub.ExternalId),
fmt.Sprintf("STUB_TYPE=%s", i.Stub.Type),
fmt.Sprintf("CONCURRENCY=%d", i.StubConfig.Concurrency),
Expand Down
1 change: 1 addition & 0 deletions internal/gateway/gateway.proto
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ message GetOrCreateStubRequest {
uint32 max_pending_tasks = 15;
repeated Volume volumes = 16;
bool force_create = 17;
string on_start = 18;
}

message GetOrCreateStubResponse {
Expand Down
1 change: 1 addition & 0 deletions internal/gateway/services/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ func (gws *GatewayService) GetOrCreateStub(ctx context.Context, in *pb.GetOrCrea
ImageId: in.ImageId,
},
Handler: in.Handler,
OnStart: in.OnStart,
PythonVersion: in.PythonVersion,
TaskPolicy: types.TaskPolicy{
MaxRetries: uint(in.Retries),
Expand Down
1 change: 1 addition & 0 deletions internal/types/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ type TaskCountByTime struct {
type StubConfigV1 struct {
Runtime Runtime `json:"runtime"`
Handler string `json:"handler"`
OnStart string `json:"on_start"`
PythonVersion string `json:"python_version"`
KeepWarmSeconds uint `json:"keep_warm_seconds"`
MaxPendingTasks uint `json:"max_pending_tasks"`
Expand Down
149 changes: 79 additions & 70 deletions proto/gateway.pb.go

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion sdk/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "beta9"
version = "0.1.1"
version = "0.1.2"
description = ""
authors = ["beam.cloud <support@beam.cloud>"]
packages = [
Expand Down
2 changes: 2 additions & 0 deletions sdk/src/beta9/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import env
from .abstractions.container import Container
from .abstractions.endpoint import Endpoint as endpoint
from .abstractions.function import Function as function
Expand All @@ -18,4 +19,5 @@
"function",
"endpoint",
"Container",
"env",
]
19 changes: 14 additions & 5 deletions sdk/src/beta9/abstractions/base/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
CONTAINER_STUB_TYPE = "container"
FUNCTION_STUB_TYPE = "function"
TASKQUEUE_STUB_TYPE = "taskqueue"
WEBSERVER_STUB_TYPE = "endpoint"
ENDPOINT_STUB_TYPE = "endpoint"
TASKQUEUE_DEPLOYMENT_STUB_TYPE = "taskqueue/deployment"
ENDPOINT_DEPLOYMENT_STUB_TYPE = "endpoint/deployment"
FUNCTION_DEPLOYMENT_STUB_TYPE = "function/deployment"
Expand All @@ -45,12 +45,16 @@ def __init__(
retries: int = 3,
timeout: int = 3600,
volumes: Optional[List[Volume]] = None,
on_start: Optional[Callable] = None,
) -> None:
super().__init__()

if image is None:
image = Image()

if on_start is not None:
self._map_callable_to_attr(attr="on_start", func=on_start)

self.image: Image = image
self.image_available: bool = False
self.files_synced: bool = False
Expand Down Expand Up @@ -112,8 +116,12 @@ def _parse_cpu_to_millicores(self, cpu: Union[float, str]) -> int:
else:
raise TypeError("CPU must be a float or a string.")

def _load_handler(self, func: Callable) -> None:
if self.handler or func is None:
def _map_callable_to_attr(self, *, attr: str, func: Callable):
"""
Determine the module and function name of a callable function, and cache on the class.
This is used for passing callables to stub config.
"""
if getattr(self, attr):
return

module = inspect.getmodule(func) # Determine module / function name
Expand All @@ -124,7 +132,7 @@ def _load_handler(self, func: Callable) -> None:
module_name = "__main__"

function_name = func.__name__
self.handler = f"{module_name}:{function_name}"
setattr(self, attr, f"{module_name}:{function_name}")

async def _object_iterator(self, *, dir: str, object_id: str, file_update_queue: Queue):
while True:
Expand Down Expand Up @@ -175,7 +183,7 @@ def prepare_runtime(
name: Optional[str] = None,
) -> bool:
if func is not None:
self._load_handler(func)
self._map_callable_to_attr(attr="handler", func=func)

stub_name = f"{stub_type}/{self.handler}" if self.handler else stub_type

Expand Down Expand Up @@ -222,6 +230,7 @@ def prepare_runtime(
memory=self.memory,
gpu=self.gpu,
handler=self.handler,
on_start=self.on_start,
retries=self.retries,
timeout=self.timeout,
keep_warm_seconds=self.keep_warm_seconds,
Expand Down
11 changes: 8 additions & 3 deletions sdk/src/beta9/abstractions/endpoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any, Callable, Union
from typing import Any, Callable, Optional, Union

from .. import terminal
from ..abstractions.base.runner import (
Expand All @@ -15,6 +15,7 @@
)
from ..clients.gateway import DeployStubRequest, DeployStubResponse
from ..config import GatewayConfig, get_gateway_config
from ..env import is_local


class Endpoint(RunnerAbstraction):
Expand All @@ -29,6 +30,7 @@ def __init__(
max_containers: int = 1,
keep_warm_seconds: int = 300,
max_pending_tasks: int = 100,
on_start: Optional[Callable] = None,
):
super().__init__(
cpu=cpu,
Expand All @@ -41,6 +43,7 @@ def __init__(
retries=0,
keep_warm_seconds=keep_warm_seconds,
max_pending_tasks=max_pending_tasks,
on_start=on_start,
)

self.endpoint_stub: EndpointServiceStub = EndpointServiceStub(self.channel)
Expand All @@ -55,12 +58,14 @@ def __init__(self, func: Callable, parent: Endpoint):
self.parent: Endpoint = parent

def __call__(self, *args, **kwargs) -> Any:
container_id = os.getenv("CONTAINER_ID")
if container_id is not None:
if not is_local():
return self.local(*args, **kwargs)

raise NotImplementedError("Direct calls to Endpoints are not supported.")

def local(self, *args, **kwargs) -> Any:
return self.func(*args, **kwargs)

def deploy(self, name: str) -> bool:
if not self.parent.prepare_runtime(
func=self.func, stub_type=ENDPOINT_DEPLOYMENT_STUB_TYPE, force_create_stub=True
Expand Down
5 changes: 2 additions & 3 deletions sdk/src/beta9/abstractions/function.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import os
from typing import Any, Callable, Iterator, List, Optional, Sequence, Union

import cloudpickle
Expand All @@ -19,6 +18,7 @@
)
from ..clients.gateway import DeployStubRequest, DeployStubResponse
from ..config import GatewayConfig, get_gateway_config
from ..env import is_local
from ..sync import FileSyncer


Expand Down Expand Up @@ -91,8 +91,7 @@ def __init__(self, func: Callable, parent: Function) -> None:
self.parent: Function = parent

def __call__(self, *args, **kwargs) -> Any:
container_id = os.getenv("CONTAINER_ID")
if container_id:
if not is_local():
return self.local(*args, **kwargs)

if not self.parent.prepare_runtime(
Expand Down
11 changes: 8 additions & 3 deletions sdk/src/beta9/abstractions/taskqueue.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import os
from typing import Any, Callable, Union
from typing import Any, Callable, Optional, Union

from .. import terminal
from ..abstractions.base.runner import (
Expand All @@ -19,6 +19,7 @@
TaskQueueServiceStub,
)
from ..config import GatewayConfig, get_gateway_config
from ..env import is_local


class TaskQueue(RunnerAbstraction):
Expand Down Expand Up @@ -62,6 +63,9 @@ class TaskQueue(RunnerAbstraction):
The maximum number of tasks that can be pending in the queue. If the number of
pending tasks exceeds this value, the task queue will stop accepting new tasks.
Default is 100.
on_start (Optional[Callable]):
An optional function to run once (per process) when the container starts. Can be used for downloading data,
loading models, or anything else computationally expensive.
Example:
```python
from beta9 import task_queue, Image
Expand All @@ -88,6 +92,7 @@ def __init__(
max_containers: int = 1,
keep_warm_seconds: int = 10,
max_pending_tasks: int = 100,
on_start: Optional[Callable] = None,
) -> None:
super().__init__(
cpu=cpu,
Expand All @@ -100,6 +105,7 @@ def __init__(
retries=retries,
keep_warm_seconds=keep_warm_seconds,
max_pending_tasks=max_pending_tasks,
on_start=on_start,
)

self.taskqueue_stub: TaskQueueServiceStub = TaskQueueServiceStub(self.channel)
Expand All @@ -114,8 +120,7 @@ def __init__(self, func: Callable, parent: TaskQueue):
self.parent: TaskQueue = parent

def __call__(self, *args, **kwargs) -> Any:
container_id = os.getenv("CONTAINER_ID")
if container_id is not None:
if not is_local():
return self.local(*args, **kwargs)

raise NotImplementedError(
Expand Down
1 change: 1 addition & 0 deletions sdk/src/beta9/clients/gateway/__init__.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions sdk/src/beta9/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os
from functools import wraps
from typing import Callable


def is_local() -> bool:
"""Check if we are currently running in a remote container"""
return os.getenv("CONTAINER_ID", "") == ""


def local_entrypoint(func: Callable) -> None:
"""Decorator that executes the decorated function if the environment is local (i.e. not a remote container)"""

@wraps(func)
def wrapper(*args, **kwargs):
if is_local():
func(*args, **kwargs)

wrapper()
4 changes: 2 additions & 2 deletions sdk/src/beta9/logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
import sys
import json
import sys


class StdoutJsonInterceptor(io.TextIOBase):
def __init__(self, stream=sys.__stdout__, **ctx):
Expand Down Expand Up @@ -34,4 +35,3 @@ def flush(self):

def fileno(self) -> int:
return -1

Loading

0 comments on commit 8f4748b

Please sign in to comment.