Skip to content

Commit

Permalink
Merge pull request #213 from openforcefield/200-Task-priority-setter-…
Browse files Browse the repository at this point in the history
…getter

Add getters and setters for Task priority in AlchemiscaleClient
  • Loading branch information
dotsdl committed Dec 21, 2023
2 parents 8e98af3 + 9e22c8c commit 81d977b
Show file tree
Hide file tree
Showing 6 changed files with 439 additions and 77 deletions.
44 changes: 44 additions & 0 deletions alchemiscale/interface/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,50 @@ def cancel_tasks(
return [str(sk) if sk is not None else None for sk in canceled_sks]


@router.post("/bulk/tasks/priority/get")
def tasks_priority_get(
    *,
    tasks: List[ScopedKey] = Body(embed=True),
    n4js: Neo4jStore = Depends(get_n4js_depends),
    token: TokenData = Depends(get_token_data_depends),
) -> List[Union[int, None]]:
    """Get the priority of each given Task.

    The response aligns index-for-index with the request: Tasks whose Scope
    fails validation against the caller's token are masked with ``None``
    before the store lookup, and the store likewise yields ``None`` for
    Tasks it cannot find.
    """
    # preserve input order; unauthorized entries become None placeholders
    valid_tasks = []
    for task_sk in tasks:
        try:
            validate_scopes(task_sk.scope, token)
            valid_tasks.append(task_sk)
        except HTTPException:
            valid_tasks.append(None)

    priorities = n4js.get_task_priority(valid_tasks)

    return priorities


@router.post("/bulk/tasks/priority/set")
def tasks_priority_set(
    *,
    tasks: List[ScopedKey] = Body(embed=True),
    priority: int = Body(embed=True),
    n4js: Neo4jStore = Depends(get_n4js_depends),
    token: TokenData = Depends(get_token_data_depends),
) -> List[Union[str, None]]:
    """Set the priority of each given Task, returning updated ScopedKeys.

    Tasks outside the caller's authorized Scopes are masked with ``None``,
    keeping the response aligned with the request order.
    """

    def _authorized(sk):
        # None placeholder for Tasks the token's Scopes don't cover
        try:
            validate_scopes(sk.scope, token)
        except HTTPException:
            return None
        return sk

    valid_tasks = [_authorized(task_sk) for task_sk in tasks]

    try:
        tasks_updated = n4js.set_task_priority(valid_tasks, priority)
    except ValueError as e:
        # out-of-range priority -> client error, not server error
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))

    return [str(t) if t is not None else None for t in tasks_updated]


@router.post("/bulk/tasks/status/get")
def tasks_status_get(
*,
Expand Down
198 changes: 135 additions & 63 deletions alchemiscale/interface/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import asyncio
from typing import Union, List, Dict, Optional, Tuple
from typing import Union, List, Dict, Optional, Tuple, Any
import json
from itertools import chain
from collections import Counter
Expand Down Expand Up @@ -724,12 +724,67 @@ def cancel_tasks(

return [ScopedKey.from_str(i) if i is not None else None for i in canceled_sks]

def _set_task_status(
self, task: ScopedKey, status: TaskStatusEnum
) -> Optional[ScopedKey]:
"""Set the status of a `Task`."""
task_sk = self._post_resource(f"/tasks/{task}/status", status.value)
return ScopedKey.from_str(task_sk) if task_sk is not None else None
def _task_attribute_getter(
self, tasks: List[ScopedKey], getter_function, batch_size
) -> List[Any]:
tasks = [
ScopedKey.from_str(task) if isinstance(task, str) else task
for task in tasks
]

@use_session
async def async_request(self):
values = await asyncio.gather(
*[
getter_function(task_batch)
for task_batch in self._batched(tasks, batch_size)
]
)

return list(chain.from_iterable(values))

coro = async_request(self)

try:
return asyncio.run(coro)
except RuntimeError:
# we use nest_asyncio to support environments where an event loop
# is already running, such as in a Jupyter notebook
import nest_asyncio

nest_asyncio.apply()
return asyncio.run(coro)

def _task_attribute_setter(
self, tasks: List[ScopedKey], setter_function, setter_args, batch_size
) -> List[Optional[ScopedKey]]:
tasks = [
ScopedKey.from_str(task) if isinstance(task, str) else task
for task in tasks
]

@use_session
async def async_request(self):
scoped_keys = await asyncio.gather(
*[
setter_function(task_batch, *setter_args)
for task_batch in self._batched(tasks, batch_size)
]
)

return list(chain.from_iterable(scoped_keys))

coro = async_request(self)

try:
return asyncio.run(coro)
except RuntimeError:
# we use nest_asyncio to support environments where an event loop
# is already running, such as in a Jupyter notebook
import nest_asyncio

nest_asyncio.apply()
return asyncio.run(coro)

async def _set_task_status(
self, tasks: List[ScopedKey], status: TaskStatusEnum
Expand All @@ -744,12 +799,6 @@ async def _set_task_status(
for task_sk in tasks_updated
]

async def _get_task_status(self, tasks: List[ScopedKey]) -> List[TaskStatusEnum]:
"""Get the statuses for many Tasks"""
data = dict(tasks=[t.dict() for t in tasks])
statuses = await self._post_resource_async(f"/bulk/tasks/status/get", data=data)
return statuses

def set_tasks_status(
self,
tasks: List[ScopedKey],
Expand Down Expand Up @@ -779,33 +828,15 @@ def set_tasks_status(
"""
status = TaskStatusEnum(status)

tasks = [
ScopedKey.from_str(task) if isinstance(task, str) else task
for task in tasks
]

@use_session
async def async_request(self):
scoped_keys = await asyncio.gather(
*[
self._set_task_status(task_batch, status)
for task_batch in self._batched(tasks, batch_size)
]
)

return list(chain.from_iterable(scoped_keys))

coro = async_request(self)

try:
return asyncio.run(coro)
except RuntimeError:
# we use nest_asyncio to support environments where an event loop
# is already running, such as in a Jupyter notebook
import nest_asyncio
return self._task_attribute_setter(
tasks, self._set_task_status, (status,), batch_size
)

nest_asyncio.apply()
return asyncio.run(coro)
async def _get_task_status(self, tasks: List[ScopedKey]) -> List[TaskStatusEnum]:
"""Get the statuses for many Tasks"""
data = dict(tasks=[t.dict() for t in tasks])
statuses = await self._post_resource_async(f"/bulk/tasks/status/get", data=data)
return statuses

def get_tasks_status(
self, tasks: List[ScopedKey], batch_size: int = 1000
Expand All @@ -827,40 +858,81 @@ def get_tasks_status(
given Task doesn't exist, ``None`` will be returned in its place.
"""
tasks = [
ScopedKey.from_str(task) if isinstance(task, str) else task
for task in tasks
return self._task_attribute_getter(tasks, self._get_task_status, batch_size)

async def _set_task_priority(
self, tasks: List[ScopedKey], priority: int
) -> List[Optional[ScopedKey]]:
data = dict(tasks=[t.dict() for t in tasks], priority=priority)
tasks_updated = await self._post_resource_async(
f"/bulk/tasks/priority/set", data=data
)
return [
ScopedKey.from_str(task_sk) if task_sk is not None else None
for task_sk in tasks_updated
]

@use_session
async def async_request(self):
statuses = await asyncio.gather(
*[
self._get_task_status(task_batch)
for task_batch in self._batched(tasks, batch_size)
]
)
def set_tasks_priority(
self,
tasks: List[ScopedKey],
priority: int,
batch_size: int = 1000,
) -> List[Optional[ScopedKey]]:
"""Set the priority of multiple Tasks.
return list(chain.from_iterable(statuses))
Parameters
----------
tasks
The Tasks to set the priority of.
priority
The priority to set for the Task. This value must be between 1 and
2**63 - 1, with lower values indicating an increased priority.
batch_size
The number of Tasks to include in a single request; use to tune
method call speed when requesting many priorities at once.
try:
return asyncio.run(async_request(self))
except RuntimeError:
# we use nest_asyncio to support environments where an event loop
# is already running, such as in a Jupyter notebook
import nest_asyncio
Returns
-------
updated
The ScopedKeys of the Tasks that were updated, in the same order
as given in `tasks`. If a given Task doesn't exist, ``None`` will
be returned in its place.
"""
return self._task_attribute_setter(
tasks, self._set_task_priority, (priority,), batch_size
)

nest_asyncio.apply()
return asyncio.run(async_request(self))
async def _get_task_priority(self, tasks: List[ScopedKey]) -> List[int]:
"""Get the priority for many Tasks"""
data = dict(tasks=[t.dict() for t in tasks])
priorities = await self._post_resource_async(
f"/bulk/tasks/priority/get", data=data
)
return priorities

def get_tasks_priority(
self,
tasks: List[ScopedKey],
):
raise NotImplementedError
batch_size: int = 1000,
) -> List[int]:
"""Get the priority of multiple Tasks.
def set_tasks_priority(self, tasks: List[ScopedKey], priority: int):
raise NotImplementedError
Parameters
----------
tasks
The Tasks to get the priority of.
batch_size
The number of Tasks to include in a single request; use to tune
method call speed when requesting many priorities at once.
Returns
-------
priorities
The priority of each Task in the same order as given in `tasks`. If a
given Task doesn't exist, ``None`` will be returned in its place.
"""
return self._task_attribute_getter(tasks, self._get_task_priority, batch_size)

### results

Expand Down
75 changes: 69 additions & 6 deletions alchemiscale/storage/statestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -1710,14 +1710,77 @@ def set_tasks(
transformation, Transformation, scope
)

def set_task_priority(self, task: ScopedKey, priority: int):
q = f"""
MATCH (t:Task {{_scoped_key: "{task}"}})
SET t.priority = {priority}
RETURN t
def set_task_priority(
self, tasks: List[ScopedKey], priority: int
) -> List[Optional[ScopedKey]]:
"""Set the priority of a list of Tasks.
Parameters
----------
tasks
The list of Tasks to set the priority of.
priority
The priority to set the Tasks to.
Returns
-------
List[Optional[ScopedKey]]
A list of the Task ScopedKeys for which priority was changed; `None`
is given for any Tasks for which the priority could not be changed.
"""
if not (1 <= priority <= 2**63 - 1):
raise ValueError("priority must be between 1 and 2**63 - 1, inclusive")

with self.transaction() as tx:
tx.run(q)
q = """
WITH $scoped_keys AS batch
UNWIND batch AS scoped_key
OPTIONAL MATCH (t:Task {_scoped_key: scoped_key})
SET t.priority = $priority
RETURN scoped_key, t
"""
res = tx.run(q, scoped_keys=[str(t) for t in tasks], priority=priority)

task_results = []
for record in res:
task_i = record["t"]
scoped_key = record["scoped_key"]

# catch missing tasks
if task_i is None:
task_results.append(None)
else:
task_results.append(ScopedKey.from_str(scoped_key))
return task_results

def get_task_priority(self, tasks: List[ScopedKey]) -> List[Optional[int]]:
"""Get the priority of a list of Tasks.
Parameters
----------
tasks
The list of Tasks to get the priority for.
Returns
-------
List[Optional[int]]
A list of priorities in the same order as the provided Tasks.
If an element is ``None``, the Task could not be found.
"""
with self.transaction() as tx:
q = """
WITH $scoped_keys AS batch
UNWIND batch AS scoped_key
OPTIONAL MATCH (t:Task)
WHERE t._scoped_key = scoped_key
RETURN t.priority as priority
"""
res = tx.run(q, scoped_keys=[str(t) for t in tasks])
priorities = [rec["priority"] for rec in res]

return priorities

def delete_task(
self,
Expand Down
2 changes: 1 addition & 1 deletion alchemiscale/tests/integration/compute/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def n4js_preloaded(

# set task priority higher the first transformation
# used for claim determinism in some tests
n4js.set_task_priority(task_sks[transformations[0]][0], 1)
n4js.set_task_priority([task_sks[transformations[0]][0]], 1)

# add tasks from each transformation selected to each task hubs
n4js.action_tasks(
Expand Down
Loading

0 comments on commit 81d977b

Please sign in to comment.