Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add modal queue CLI #1772

Merged
merged 6 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 7 additions & 7 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ We appreciate your patience while we speedily work towards a stable release of t

### 0.62.116 (2024-04-26)

* Added a command-line interface for interacting with modal.Dict objects. Run modal dict --help in your terminal to see what is available.
* Added a command-line interface for interacting with `modal.Dict` objects. Run `modal dict --help` in your terminal to see what is available.



### 0.62.114 (2024-04-25)

* `Secret.from_dotenv` now accepts an optional filename keyword argument:
```python
@app.function(secrets=[modal.Secret.from_dotenv(filename=".env-dev")])
def run():
...
* `Secret.from_dotenv` now accepts an optional filename keyword argument:

```python
@app.function(secrets=[modal.Secret.from_dotenv(filename=".env-dev")])
def run():
...
```


Expand Down
4 changes: 2 additions & 2 deletions modal/cli/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def create(name: str, *, env: Optional[str] = ENV_OPTION):
@dict_cli.command(name="list")
@synchronizer.create_blocking
async def list(*, json: bool = False, env: Optional[str] = ENV_OPTION):
"""List all named Dict objects."""
"""List all named Dicts."""
env = ensure_env(env)
client = await _Client.from_env()
request = api_pb2.DictListRequest(environment_name=env)
Expand All @@ -64,7 +64,7 @@ async def clear(name: str, *, yes: bool = YES_OPTION, env: Optional[str] = ENV_O
@dict_cli.command(name="delete")
@synchronizer.create_blocking
async def delete(name: str, *, yes: bool = YES_OPTION, env: Optional[str] = ENV_OPTION):
"""Delete a named Dict object and all of its data."""
"""Delete a named Dict and all of its data."""
# Lookup first to validate the name, even though delete is a staticmethod
await _Dict.lookup(name, environment_name=env)
if not yes:
Expand Down
2 changes: 2 additions & 0 deletions modal/cli/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .launch import launch_cli
from .network_file_system import nfs_cli
from .profile import profile_cli
from .queues import queue_cli
from .secret import secret_cli
from .token import _new_token, token_cli
from .volume import volume_cli
Expand Down Expand Up @@ -92,6 +93,7 @@ async def setup(profile: Optional[str] = None):
entrypoint_cli_typer.add_typer(profile_cli)
entrypoint_cli_typer.add_typer(secret_cli)
entrypoint_cli_typer.add_typer(token_cli)
entrypoint_cli_typer.add_typer(queue_cli)
entrypoint_cli_typer.add_typer(volume_cli)

entrypoint_cli_typer.command("deploy", help="Deploy a Modal stub as an application.", no_args_is_help=True)(run.deploy)
Expand Down
131 changes: 131 additions & 0 deletions modal/cli/queues.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright Modal Labs 2024
from typing import Optional

import typer
from rich.console import Console
from typer import Argument, Option, Typer

from modal._resolver import Resolver
from modal._utils.async_utils import synchronizer
from modal._utils.grpc_utils import retry_transient_errors
from modal.cli.utils import ENV_OPTION, YES_OPTION, display_table, timestamp_to_local
from modal.client import _Client
from modal.environments import ensure_env
from modal.queue import _Queue
from modal_proto import api_pb2

# Typer sub-application exposed under the `queue` command name.
queue_cli = Typer(
    name="queue",
    help="Manage `modal.Queue` objects and inspect their contents.",
    no_args_is_help=True,
)

# Shared `-p/--partition` option used by the commands that accept a partition.
PARTITION_OPTION = Option(
    None,
    "-p",
    "--partition",
    help="Name of the partition to use, otherwise use the default (anonymous) partition.",
)


@queue_cli.command(name="create")
@synchronizer.create_blocking
async def create(name: str, *, env: Optional[str] = ENV_OPTION):
    """Create a named Queue.

    Note: This is a no-op when the Queue already exists.
    """
    # Build the named handle with create_if_missing=True, then load it through
    # a Resolver bound to the client obtained from the environment.
    queue_handle = _Queue.from_name(name, environment_name=env, create_if_missing=True)
    env_client = await _Client.from_env()
    await Resolver(client=env_client).load(queue_handle)


@queue_cli.command(name="delete")
@synchronizer.create_blocking
async def delete(name: str, *, yes: bool = YES_OPTION, env: Optional[str] = ENV_OPTION):
    """Delete a named Queue and all of its data."""
    # Look the queue up before anything else so the name is validated,
    # even though `delete` is a staticmethod and does not need the handle.
    await _Queue.lookup(name, environment_name=env)

    confirmation_needed = not yes
    if confirmation_needed:
        prompt = f"Are you sure you want to irrevocably delete the modal.Queue '{name}'?"
        # The prompt defaults to "no"; abort=True is passed so a refusal stops the command.
        typer.confirm(prompt, default=False, abort=True)

    await _Queue.delete(name, environment_name=env)


@queue_cli.command(name="list")
@synchronizer.create_blocking
async def list(*, json: bool = False, env: Optional[str] = ENV_OPTION):
    """List all named Queues."""
    env = ensure_env(env)

    # Request a limit one above the display cap so that sizes beyond the cap
    # can be told apart and rendered as ">cap".
    size_cap = 100_000
    client = await _Client.from_env()
    list_request = api_pb2.QueueListRequest(environment_name=env, total_size_limit=size_cap + 1)
    list_response = await retry_transient_errors(client.stub.QueueList, list_request)

    table_rows = []
    for info in list_response.queues:
        created_text = timestamp_to_local(info.created_at, json)
        partitions_text = str(info.num_partitions)
        if info.total_size <= size_cap:
            size_text = str(info.total_size)
        else:
            size_text = f">{size_cap}"
        table_rows.append((info.name, created_text, partitions_text, size_text))

    display_table(["Name", "Created at", "Partitions", "Total size"], table_rows, json)


@queue_cli.command(name="clear")
@synchronizer.create_blocking
async def clear(
    name: str,
    partition: Optional[str] = PARTITION_OPTION,
    all: bool = Option(False, "-a", "--all", help="Clear the contents of all partitions."),
    yes: bool = YES_OPTION,
    *,
    env: Optional[str] = ENV_OPTION,
):
    """Clear the contents of a queue by removing all of its data."""
    target = await _Queue.lookup(name, environment_name=env)

    skip_prompt = yes
    if not skip_prompt:
        prompt = f"Are you sure you want to irrevocably delete the contents of modal.Queue '{name}'?"
        # The prompt defaults to "no"; abort=True is passed so a refusal stops the command.
        typer.confirm(prompt, default=False, abort=True)

    # The partition / all combination is forwarded unchanged to Queue.clear.
    await target.clear(partition=partition, all=all)


@queue_cli.command(name="peek")
@synchronizer.create_blocking
async def peek(
    name: str, n: int = Argument(1), partition: Optional[str] = PARTITION_OPTION, *, env: Optional[str] = ENV_OPTION
):
    """Print the next N items in the queue or queue partition (without removal)."""
    # Look the queue up first so an unknown name is still reported, whatever N is.
    q = await _Queue.lookup(name, environment_name=env)

    # The count check below only runs after an item has been printed, so a
    # non-positive N would otherwise still print one item. Print nothing instead.
    if n <= 0:
        return

    console = Console()
    remaining = n
    async for item in q.iterate(partition=partition):
        console.print(item)
        remaining -= 1
        # Stop as soon as N items are printed, without requesting another item.
        if remaining <= 0:
            break


@queue_cli.command(name="len")
@synchronizer.create_blocking
async def len(
    name: str,
    partition: Optional[str] = PARTITION_OPTION,
    total: bool = Option(False, "-t", "--total", help="Compute the sum of the queue lengths across all partitions"),
    *,
    env: Optional[str] = ENV_OPTION,
):
    """Print the length of a queue partition or the total length of all partitions."""
    target = await _Queue.lookup(name, environment_name=env)
    out = Console()
    # The partition / total combination is forwarded unchanged to Queue.len.
    length = await target.len(partition=partition, total=total)
    out.print(length)
17 changes: 16 additions & 1 deletion modal/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,18 @@ async def _get_blocking(self, partition: Optional[str], timeout: Optional[float]

raise queue.Empty()

@live_method
async def clear(self, *, partition: Optional[str] = None, all: bool = False) -> None:
    """Remove the items of one partition, or of every partition when `all` is set.

    Raises InvalidError when a partition is given together with `all=True`.
    """
    conflicting_args = bool(partition) and all
    if conflicting_args:
        raise InvalidError("Partition must be null when requesting to clear all.")

    clear_request = api_pb2.QueueClearRequest(
        queue_id=self.object_id,
        partition_key=self.validate_partition_key(partition),
        all_partitions=all,
    )
    await retry_transient_errors(self._client.stub.QueueClear, clear_request)

@live_method
async def get(
self, block: bool = True, timeout: Optional[float] = None, *, partition: Optional[str] = None
Expand Down Expand Up @@ -391,11 +403,14 @@ async def _put_many_nonblocking(self, partition: Optional[str], partition_ttl: i
raise queue.Full(exc.message) if exc.status == Status.RESOURCE_EXHAUSTED else exc

@live_method
async def len(self, *, partition: Optional[str] = None) -> int:
async def len(self, *, partition: Optional[str] = None, total: bool = False) -> int:
"""Return the number of objects in the queue partition."""
if partition and total:
raise InvalidError("Partition must be null when requesting total length.")
request = api_pb2.QueueLenRequest(
queue_id=self.object_id,
partition_key=self.validate_partition_key(partition),
total=total,
)
response = await retry_transient_errors(self._client.stub.QueueLen, request)
return response.len
Expand Down
40 changes: 40 additions & 0 deletions test/cli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,3 +725,43 @@ def test_dict_show_get_clear(servicer, server_url_env, set_env_client):

res = _run(["dict", "clear", "baz-dict", "--yes"])
assert servicer.dicts[dict_id] == {}


def test_queue_create_list_delete(servicer, server_url_env, set_env_client):
    # Create two named queues and check that both show up in the listing.
    for queue_name in ("foo-queue", "bar-queue"):
        _run(["queue", "create", queue_name])
    listing = _run(["queue", "list"]).stdout
    assert "foo-queue" in listing
    assert "bar-queue" in listing

    # Deleting one queue must leave the other one listed.
    _run(["queue", "delete", "bar-queue", "--yes"])

    listing = _run(["queue", "list"]).stdout
    assert "foo-queue" in listing
    assert "bar-queue" not in listing


def test_queue_peek_len_clear(servicer, server_url_env, set_env_client):
    # Kind of hacky to be modifying the attributes on the servicer like this
    name = "queue-who"
    environment = os.environ.get("MODAL_ENVIRONMENT", "main")
    servicer.deployed_queues[(name, api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE, environment)] = "qu-abc123"
    # Seed the default (b"") partition and a second partition named b"alt".
    servicer.queue = {
        b"": [dumps(value) for value in ("a", "b", "c")],
        b"alt": [dumps(value) for value in ("x", "y")],
    }

    def stdout_of(*args):
        # Run `modal queue <args...>` and return what it printed.
        return _run(["queue", *args]).stdout

    assert stdout_of("peek", name) == "a\n"
    assert stdout_of("peek", name, "-p", "alt") == "x\n"
    assert stdout_of("peek", name, "3") == "a\nb\nc\n"
    assert stdout_of("peek", name, "3", "--partition", "alt") == "x\ny\n"

    assert stdout_of("len", name) == "3\n"
    assert stdout_of("len", name, "--partition", "alt") == "2\n"
    assert stdout_of("len", name, "--total") == "5\n"

    # Clearing without --all empties only the default partition.
    stdout_of("clear", name, "--yes")
    assert stdout_of("len", name) == "0\n"
    assert stdout_of("peek", name, "--partition", "alt") == "x\n"

    # Clearing with --all empties every partition.
    stdout_of("clear", name, "--all", "--yes")
    assert stdout_of("len", name, "--total") == "0\n"
    assert stdout_of("peek", name, "--partition", "alt") == ""
52 changes: 40 additions & 12 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import traceback
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterator, Optional, get_args
from typing import Dict, Iterator, List, Optional, get_args

import aiohttp.web
import aiohttp.web_runner
Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(self, blob_host, blobs):
self.container_outputs = []
self.fc_data_in = defaultdict(lambda: asyncio.Queue()) # unbounded
self.fc_data_out = defaultdict(lambda: asyncio.Queue()) # unbounded
self.queue = []
self.queue: Dict[bytes, List[bytes]] = {b"": []}
self.deployed_apps = {
client_mount_name(): "ap-x",
}
Expand Down Expand Up @@ -172,7 +172,7 @@ def __init__(self, blob_host, blobs):
self.sandbox_result: Optional[api_pb2.GenericResult] = None

self.token_flow_localhost_port = None
self.queue_max_len = 1_00
self.queue_max_len = 100

self.container_heartbeat_response = None
self.container_heartbeat_abort = threading.Event()
Expand Down Expand Up @@ -816,6 +816,15 @@ async def ProxyGetOrCreate(self, stream):

### Queue

async def QueueClear(self, stream):
    # Mock handler: reset the whole store when all partitions are requested,
    # otherwise empty only the requested partition if it is present.
    clear_request: api_pb2.QueueClearRequest = await stream.recv_message()
    if clear_request.all_partitions:
        self.queue = {b"": []}
    elif clear_request.partition_key in self.queue:
        self.queue[clear_request.partition_key] = []
    await stream.send_message(Empty())

async def QueueCreate(self, stream):
request: api_pb2.QueueCreateRequest = await stream.recv_message()
if request.existing_queue_id:
Expand All @@ -834,6 +843,7 @@ async def QueueGetOrCreate(self, stream):
self.n_queues += 1
queue_id = f"qu-{self.n_queues}"
self.deployed_queues[k] = queue_id
self.deployed_apps[request.deployment_name] = f"ap-{queue_id}"
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_EPHEMERAL:
self.n_queues += 1
queue_id = f"qu-{self.n_queues}"
Expand All @@ -853,28 +863,46 @@ async def QueueHeartbeat(self, stream):

async def QueuePut(self, stream):
request: api_pb2.QueuePutRequest = await stream.recv_message()
if len(self.queue) >= self.queue_max_len:
if sum(map(len, self.queue.values())) >= self.queue_max_len:
raise GRPCError(Status.RESOURCE_EXHAUSTED, f"Hit servicer's max len for Queues: {self.queue_max_len}")
self.queue += request.values
q = self.queue.setdefault(request.partition_key, [])
q += request.values
await stream.send_message(Empty())

async def QueueGet(self, stream):
await stream.recv_message()
if len(self.queue) > 0:
values = [self.queue.pop(0)]
request: api_pb2.QueueGetRequest = await stream.recv_message()
q = self.queue.get(request.partition_key, [])
if len(q) > 0:
values = [q.pop(0)]
else:
values = []
await stream.send_message(api_pb2.QueueGetResponse(values=values))

async def QueueLen(self, stream):
await stream.recv_message()
await stream.send_message(api_pb2.QueueLenResponse(len=len(self.queue)))
request = await stream.recv_message()
if request.total:
value = sum(map(len, self.queue.values()))
else:
q = self.queue.get(request.partition_key, [])
value = len(q)
await stream.send_message(api_pb2.QueueLenResponse(len=value))

async def QueueList(self, stream):
    # TODO Note that the actual self.queue holding the data assumes we have a single queue
    # So there is a mismatch and I am not implementing a mock for the num_partitions / total_size
    listed = []
    for queue_name, _, _ in self.deployed_queues:
        # Only names that also have an entry in deployed_apps are reported.
        if queue_name not in self.deployed_apps:
            continue
        listed.append(api_pb2.QueueListResponse.QueueInfo(name=queue_name, created_at=1))
    await stream.send_message(api_pb2.QueueListResponse(queues=listed))

async def QueueNextItems(self, stream):
request: api_pb2.QueueNextItemsRequest = await stream.recv_message()
next_item_idx = int(request.last_entry_id) + 1 if request.last_entry_id else 0
if next_item_idx < len(self.queue):
item = api_pb2.QueueItem(value=self.queue[next_item_idx], entry_id=f"{next_item_idx}")
q = self.queue.get(request.partition_key, [])
if next_item_idx < len(q):
item = api_pb2.QueueItem(value=q[next_item_idx], entry_id=f"{next_item_idx}")
await stream.send_message(api_pb2.QueueNextItemsResponse(items=[item]))
else:
if request.item_poll_timeout > 0:
Expand Down