Skip to content

Commit

Permalink
Add modal queue CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Apr 29, 2024
1 parent ad5f987 commit 28b62e3
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 13 deletions.
2 changes: 2 additions & 0 deletions modal/cli/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .launch import launch_cli
from .network_file_system import nfs_cli
from .profile import profile_cli
from .queues import queue_cli
from .secret import secret_cli
from .token import _new_token, token_cli
from .volume import volume_cli
Expand Down Expand Up @@ -92,6 +93,7 @@ async def setup(profile: Optional[str] = None):
entrypoint_cli_typer.add_typer(profile_cli)
entrypoint_cli_typer.add_typer(secret_cli)
entrypoint_cli_typer.add_typer(token_cli)
entrypoint_cli_typer.add_typer(queue_cli)
entrypoint_cli_typer.add_typer(volume_cli)

entrypoint_cli_typer.command("deploy", help="Deploy a Modal stub as an application.", no_args_is_help=True)(run.deploy)
Expand Down
124 changes: 124 additions & 0 deletions modal/cli/queues.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright Modal Labs 2024
from typing import Optional

import typer
from rich.console import Console
from typer import Argument, Option, Typer

from modal._resolver import Resolver
from modal._utils.async_utils import synchronizer
from modal._utils.grpc_utils import retry_transient_errors
from modal.cli.utils import ENV_OPTION, YES_OPTION, display_table, timestamp_to_local
from modal.client import _Client
from modal.environments import ensure_env
from modal.queue import _Queue
from modal_proto import api_pb2

queue_cli = Typer(
name="queue",
no_args_is_help=True,
help="Manage `modal.Queue` objects and inspect their contents.",
)

PARTITION_OPTION = Option(
None,
"-p",
"--partition",
help="Name of the partition to use, otherwise use the default (anonymous) partition.",
)


@queue_cli.command(name="create")
@synchronizer.create_blocking
async def create(name: str, *, env: Optional[str] = ENV_OPTION):
"""Create a named Queue object.
Note: This is a no-op when the Queue already exists.
"""
q = _Queue.from_name(name, environment_name=env, create_if_missing=True)
client = await _Client.from_env()
resolver = Resolver(client=client)
await resolver.load(q)


@queue_cli.command(name="delete")
@synchronizer.create_blocking
async def delete(name: str, *, yes: bool = YES_OPTION, env: Optional[str] = ENV_OPTION):
"""Delete a named Dict object and all of its data."""
# Lookup first to validate the name, even though delete is a staticmethod
await _Queue.lookup(name, environment_name=env)
if not yes:
typer.confirm(
f"Are you sure you want to irrevocably delete the modal.Queue '{name}'?",
default=False,
abort=True,
)
await _Queue.delete(name, environment_name=env)


@queue_cli.command(name="list")
@synchronizer.create_blocking
async def list(*, json: bool = False, env: Optional[str] = ENV_OPTION):
"""List all named Queue objects."""
env = ensure_env(env)

max_total_size = 100_000
client = await _Client.from_env()
request = api_pb2.QueueListRequest(environment_name=env, total_size_limit=max_total_size + 1)
response = await retry_transient_errors(client.stub.QueueList, request)

rows = [
(
q.name,
timestamp_to_local(q.created_at, json),
str(q.num_partitions),
str(q.total_size) if q.total_size <= max_total_size else f">{max_total_size}",
)
for q in response.queues
]
display_table(["Name", "Created at", "Partitions", "Total size"], rows, json)


@queue_cli.command(name="clear")
@synchronizer.create_blocking
async def clear(
name: str,
partition: Optional[str] = PARTITION_OPTION,
all: bool = Option(False, "-a", "--all", help="Clear the contents of all partitions."),
*,
env: Optional[str] = ENV_OPTION,
):
"""Clear the contents of a queue by removing all of its data."""
q = await _Queue.lookup(name, environment_name=env)
await q.clear(partition=partition, all=all)


@queue_cli.command(name="next")
@synchronizer.create_blocking
async def next(
name: str, n: int = Argument(1), partition: Optional[str] = PARTITION_OPTION, *, env: Optional[str] = ENV_OPTION
):
"""Print the next N items in the queue or queue partition (without removal)."""
q = await _Queue.lookup(name, environment_name=env)
console = Console()
i = 0
async for item in q.iterate(partition=partition):
console.print(item)
i += 1
if i >= n:
break


@queue_cli.command(name="len")
@synchronizer.create_blocking
async def len(
name: str,
partition: Optional[str] = PARTITION_OPTION,
total: bool = Option(False, "-t", "--total", help="Compute the sum of the queue lengths across all partitions"),
*,
env: Optional[str] = ENV_OPTION,
):
"""Print the length of a queue partition or the total length of all partitions."""
q = await _Queue.lookup(name, environment_name=env)
console = Console()
console.print(await q.len(partition=partition, total=total))
17 changes: 16 additions & 1 deletion modal/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,18 @@ async def _get_blocking(self, partition: Optional[str], timeout: Optional[float]

raise queue.Empty()

@live_method
async def clear(self, *, partition: Optional[str] = None, all: bool = False) -> None:
"""Clear the contents of a single partition or all partitions."""
if partition and all:
raise InvalidError("Partition must be null when requesting to clear all.")
request = api_pb2.QueueClearRequest(
queue_id=self.object_id,
partition_key=self.validate_partition_key(partition),
all_partitions=all,
)
await retry_transient_errors(self._client.stub.QueueClear, request)

@live_method
async def get(
self, block: bool = True, timeout: Optional[float] = None, *, partition: Optional[str] = None
Expand Down Expand Up @@ -391,11 +403,14 @@ async def _put_many_nonblocking(self, partition: Optional[str], partition_ttl: i
raise queue.Full(exc.message) if exc.status == Status.RESOURCE_EXHAUSTED else exc

@live_method
async def len(self, *, partition: Optional[str] = None) -> int:
async def len(self, *, partition: Optional[str] = None, total: bool = False) -> int:
"""Return the number of objects in the queue partition."""
if partition and total:
raise InvalidError("Partition must be null when requesting total length.")
request = api_pb2.QueueLenRequest(
queue_id=self.object_id,
partition_key=self.validate_partition_key(partition),
total=total,
)
response = await retry_transient_errors(self._client.stub.QueueLen, request)
return response.len
Expand Down
40 changes: 40 additions & 0 deletions test/cli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,3 +725,43 @@ def test_dict_show_get_clear(servicer, server_url_env, set_env_client):

res = _run(["dict", "clear", "baz-dict", "--yes"])
assert servicer.dicts[dict_id] == {}


def test_queue_create_list_delete(servicer, server_url_env, set_env_client):
_run(["queue", "create", "foo-queue"])
_run(["queue", "create", "bar-queue"])
res = _run(["queue", "list"])
assert "foo-queue" in res.stdout
assert "bar-queue" in res.stdout

_run(["queue", "delete", "bar-queue", "--yes"])

res = _run(["queue", "list"])
assert "foo-queue" in res.stdout
assert "bar-queue" not in res.stdout


def test_queue_next_len_clear(servicer, server_url_env, set_env_client):
# Kind of hacky to be modifying the attributes on the servicer like this
name = "queue-who"
key = (name, api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE, os.environ.get("MODAL_ENVIRONMENT", "main"))
queue_id = "qu-abc123"
servicer.deployed_queues[key] = queue_id
servicer.queue = {b"": [dumps("a"), dumps("b"), dumps("c")], b"alt": [dumps("x"), dumps("y")]}

assert _run(["queue", "next", name]).stdout == "a\n"
assert _run(["queue", "next", name, "-p", "alt"]).stdout == "x\n"
assert _run(["queue", "next", name, "3"]).stdout == "a\nb\nc\n"
assert _run(["queue", "next", name, "3", "--partition", "alt"]).stdout == "x\ny\n"

assert _run(["queue", "len", name]).stdout == "3\n"
assert _run(["queue", "len", name, "--partition", "alt"]).stdout == "2\n"
assert _run(["queue", "len", name, "--total"]).stdout == "5\n"

_run(["queue", "clear", name])
assert _run(["queue", "len", name]).stdout == "0\n"
assert _run(["queue", "next", name, "--partition", "alt"]).stdout == "x\n"

_run(["queue", "clear", name, "--all"])
assert _run(["queue", "len", name, "--total"]).stdout == "0\n"
assert _run(["queue", "next", name, "--partition", "alt"]).stdout == ""
52 changes: 40 additions & 12 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import traceback
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterator, Optional, get_args
from typing import Dict, Iterator, List, Optional, get_args

import aiohttp.web
import aiohttp.web_runner
Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(self, blob_host, blobs):
self.container_outputs = []
self.fc_data_in = defaultdict(lambda: asyncio.Queue()) # unbounded
self.fc_data_out = defaultdict(lambda: asyncio.Queue()) # unbounded
self.queue = []
self.queue: Dict[bytes, List[bytes]] = {b"": []}
self.deployed_apps = {
client_mount_name(): "ap-x",
}
Expand Down Expand Up @@ -172,7 +172,7 @@ def __init__(self, blob_host, blobs):
self.sandbox_result: Optional[api_pb2.GenericResult] = None

self.token_flow_localhost_port = None
self.queue_max_len = 1_00
self.queue_max_len = 100

self.container_heartbeat_response = None
self.container_heartbeat_abort = threading.Event()
Expand Down Expand Up @@ -816,6 +816,15 @@ async def ProxyGetOrCreate(self, stream):

### Queue

async def QueueClear(self, stream):
request: api_pb2.QueueClearRequest = await stream.recv_message()
if request.all_partitions:
self.queue = {b"": []}
else:
if request.partition_key in self.queue:
self.queue[request.partition_key] = []
await stream.send_message(Empty())

async def QueueCreate(self, stream):
request: api_pb2.QueueCreateRequest = await stream.recv_message()
if request.existing_queue_id:
Expand All @@ -834,6 +843,7 @@ async def QueueGetOrCreate(self, stream):
self.n_queues += 1
queue_id = f"qu-{self.n_queues}"
self.deployed_queues[k] = queue_id
self.deployed_apps[request.deployment_name] = f"ap-{queue_id}"
elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_EPHEMERAL:
self.n_queues += 1
queue_id = f"qu-{self.n_queues}"
Expand All @@ -853,28 +863,46 @@ async def QueueHeartbeat(self, stream):

async def QueuePut(self, stream):
request: api_pb2.QueuePutRequest = await stream.recv_message()
if len(self.queue) >= self.queue_max_len:
if sum(map(len, self.queue.values())) >= self.queue_max_len:
raise GRPCError(Status.RESOURCE_EXHAUSTED, f"Hit servicer's max len for Queues: {self.queue_max_len}")
self.queue += request.values
q = self.queue.setdefault(request.partition_key, [])
q += request.values
await stream.send_message(Empty())

async def QueueGet(self, stream):
await stream.recv_message()
if len(self.queue) > 0:
values = [self.queue.pop(0)]
request: api_pb2.QueueGetRequest = await stream.recv_message()
q = self.queue.get(request.partition_key, [])
if len(q) > 0:
values = [q.pop(0)]
else:
values = []
await stream.send_message(api_pb2.QueueGetResponse(values=values))

async def QueueLen(self, stream):
await stream.recv_message()
await stream.send_message(api_pb2.QueueLenResponse(len=len(self.queue)))
request = await stream.recv_message()
if request.total:
value = sum(map(len, self.queue.values()))
else:
q = self.queue.get(request.partition_key, [])
value = len(q)
await stream.send_message(api_pb2.QueueLenResponse(len=value))

async def QueueList(self, stream):
# TODO Note that the actual self.queue holding the data assumes we have a single queue
# So there is a mismatch and I am not implementing a mock for the num_partitions / total_size
queues = [
api_pb2.QueueListResponse.QueueInfo(name=name, created_at=1)
for name, _, _ in self.deployed_queues
if name in self.deployed_apps
]
await stream.send_message(api_pb2.QueueListResponse(queues=queues))

async def QueueNextItems(self, stream):
request: api_pb2.QueueNextItemsRequest = await stream.recv_message()
next_item_idx = int(request.last_entry_id) + 1 if request.last_entry_id else 0
if next_item_idx < len(self.queue):
item = api_pb2.QueueItem(value=self.queue[next_item_idx], entry_id=f"{next_item_idx}")
q = self.queue.get(request.partition_key, [])
if next_item_idx < len(q):
item = api_pb2.QueueItem(value=q[next_item_idx], entry_id=f"{next_item_idx}")
await stream.send_message(api_pb2.QueueNextItemsResponse(items=[item]))
else:
if request.item_poll_timeout > 0:
Expand Down

0 comments on commit 28b62e3

Please sign in to comment.