Skip to content

Commit

Permalink
Merge pull request #217 from openforcefield/216-network-JSON-encoding
Browse files Browse the repository at this point in the history
Network representation optimizations for JSON encoding
  • Loading branch information
dotsdl committed Jan 4, 2024
2 parents c17a423 + b1dba97 commit 16b2776
Show file tree
Hide file tree
Showing 10 changed files with 374 additions and 81 deletions.
4 changes: 3 additions & 1 deletion alchemiscale/base/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
oauth2_scheme,
)
from ..security.models import Token, TokenData, CredentialedEntity
from ..utils import gufe_to_keyed_dicts


def validate_scopes(scope: Scope, token: TokenData) -> None:
Expand Down Expand Up @@ -144,7 +145,8 @@ class GufeJSONResponse(JSONResponse):
media_type = "application/json"

def render(self, content: Any) -> bytes:
return json.dumps(content, cls=JSON_HANDLER.encoder).encode("utf-8")
keyed_dicts = gufe_to_keyed_dicts(content)
return json.dumps(keyed_dicts, cls=JSON_HANDLER.encoder).encode("utf-8")


class GzipRequest(Request):
Expand Down
3 changes: 1 addition & 2 deletions alchemiscale/compute/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from gufe.tokenization import GufeTokenizable, JSON_HANDLER

from ..base.api import (
GufeJSONResponse,
QueryGUFEHandler,
scope_params,
get_token_data_depends,
Expand Down Expand Up @@ -193,7 +192,7 @@ def claim_taskhub_tasks(
return [str(t) if t is not None else None for t in tasks]


@router.get("/tasks/{task_scoped_key}/transformation", response_class=GufeJSONResponse)
@router.get("/tasks/{task_scoped_key}/transformation")
def get_task_transformation(
task_scoped_key,
*,
Expand Down
50 changes: 17 additions & 33 deletions alchemiscale/interface/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from ..models import Scope, ScopedKey
from ..security.auth import get_token_data, oauth2_scheme
from ..security.models import Token, TokenData, CredentialedUserIdentity
from ..utils import keyed_dicts_to_gufe


app = FastAPI(title="AlchemiscaleAPI")
Expand Down Expand Up @@ -92,7 +93,7 @@ def list_scopes(
### inputs


@router.get("/exists/{scoped_key}", response_class=GufeJSONResponse)
@router.get("/exists/{scoped_key}")
def check_existence(
scoped_key,
*,
Expand All @@ -108,14 +109,14 @@ def check_existence(
@router.post("/networks", response_model=ScopedKey)
def create_network(
*,
network: Dict = Body(...),
network: List = Body(...),
scope: Scope,
n4js: Neo4jStore = Depends(get_n4js_depends),
token: TokenData = Depends(get_token_data_depends),
):
validate_scopes(scope, token)

an = AlchemicalNetwork.from_dict(network)
an = keyed_dicts_to_gufe(network)

try:
an_sk = n4js.create_network(network=an, scope=scope)
Expand All @@ -131,7 +132,7 @@ def create_network(
return an_sk


@router.get("/networks", response_class=GufeJSONResponse)
@router.get("/networks")
def query_networks(
*,
name: str = None,
Expand Down Expand Up @@ -290,7 +291,7 @@ def get_chemicalsystem_transformations(
]


@router.get("/networks/{network_scoped_key}", response_class=GufeJSONResponse)
@router.get("/networks/{network_scoped_key}")
def get_network(
network_scoped_key,
*,
Expand All @@ -301,12 +302,10 @@ def get_network(
validate_scopes(sk.scope, token)

network = n4js.get_gufe(scoped_key=sk)
return gufe_to_json(network)
return GufeJSONResponse(network)


@router.get(
"/transformations/{transformation_scoped_key}", response_class=GufeJSONResponse
)
@router.get("/transformations/{transformation_scoped_key}")
def get_transformation(
transformation_scoped_key,
*,
Expand All @@ -317,12 +316,10 @@ def get_transformation(
validate_scopes(sk.scope, token)

transformation = n4js.get_gufe(scoped_key=sk)
return gufe_to_json(transformation)
return GufeJSONResponse(transformation)


@router.get(
"/chemicalsystems/{chemicalsystem_scoped_key}", response_class=GufeJSONResponse
)
@router.get("/chemicalsystems/{chemicalsystem_scoped_key}")
def get_chemicalsystem(
chemicalsystem_scoped_key,
*,
Expand All @@ -333,7 +330,7 @@ def get_chemicalsystem(
validate_scopes(sk.scope, token)

chemicalsystem = n4js.get_gufe(scoped_key=sk)
return gufe_to_json(chemicalsystem)
return GufeJSONResponse(chemicalsystem)


### compute
Expand Down Expand Up @@ -739,7 +736,7 @@ def get_task_status(
return status[0].value


@router.get("/tasks/{task_scoped_key}/transformation", response_class=GufeJSONResponse)
@router.get("/tasks/{task_scoped_key}/transformation")
def get_task_transformation(
task_scoped_key,
*,
Expand All @@ -762,10 +759,7 @@ def get_task_transformation(
### results


@router.get(
"/transformations/{transformation_scoped_key}/results",
response_class=GufeJSONResponse,
)
@router.get("/transformations/{transformation_scoped_key}/results")
def get_transformation_results(
transformation_scoped_key,
*,
Expand All @@ -778,10 +772,7 @@ def get_transformation_results(
return [str(sk) for sk in n4js.get_transformation_results(sk)]


@router.get(
"/transformations/{transformation_scoped_key}/failures",
response_class=GufeJSONResponse,
)
@router.get("/transformations/{transformation_scoped_key}/failures")
def get_transformation_failures(
transformation_scoped_key,
*,
Expand All @@ -795,8 +786,7 @@ def get_transformation_failures(


@router.get(
"/transformations/{transformation_scoped_key}/{route}/{protocoldagresultref_scoped_key}",
response_class=GufeJSONResponse,
"/transformations/{transformation_scoped_key}/{route}/{protocoldagresultref_scoped_key}"
)
def get_protocoldagresult(
protocoldagresultref_scoped_key,
Expand Down Expand Up @@ -844,10 +834,7 @@ def get_protocoldagresult(
return [pdr]


@router.get(
"/tasks/{task_scoped_key}/results",
response_class=GufeJSONResponse,
)
@router.get("/tasks/{task_scoped_key}/results")
def get_task_results(
task_scoped_key,
*,
Expand All @@ -860,10 +847,7 @@ def get_task_results(
return [str(sk) for sk in n4js.get_task_results(sk)]


@router.get(
"/tasks/{task_scoped_key}/failures",
response_class=GufeJSONResponse,
)
@router.get("/tasks/{task_scoped_key}/failures")
def get_task_failures(
task_scoped_key,
*,
Expand Down
57 changes: 29 additions & 28 deletions alchemiscale/interface/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ..strategies import Strategy
from ..security.models import CredentialedUserIdentity
from ..validators import validate_network_nonself
from ..utils import gufe_to_keyed_dicts, keyed_dicts_to_gufe


class AlchemiscaleClientError(AlchemiscaleBaseClientError):
Expand Down Expand Up @@ -125,7 +126,8 @@ def create_network(
sk = self.get_scoped_key(network, scope)

def post():
data = dict(network=network.to_dict(), scope=scope.dict())
keyed_dicts = gufe_to_keyed_dicts(network)
data = dict(network=keyed_dicts, scope=scope.dict())
return self._post_resource("/networks", data, compress=compress)

if visualize:
Expand Down Expand Up @@ -298,6 +300,11 @@ def get_network(
The retrieved AlchemicalNetwork.
"""

def _get_network():
content = self._get_resource(f"/networks/{network}", compress=compress)
return keyed_dicts_to_gufe(content)

if visualize:
from rich.progress import (
Progress,
Expand All @@ -311,16 +318,12 @@ def get_network(
f"Retrieving [bold]'{network}'[/bold]...", total=None
)

an = json_to_gufe(
self._get_resource(f"/networks/{network}", compress=compress)
)
an = _get_network()

progress.start_task(task)
progress.update(task, total=1, completed=1)
else:
an = json_to_gufe(
self._get_resource(f"/networks/{network}", compress=compress)
)

an = _get_network()
return an

@lru_cache(maxsize=10000)
Expand Down Expand Up @@ -351,6 +354,13 @@ def get_transformation(
The retrieved Transformation.
"""

def _get_transformation():
content = self._get_resource(
f"/transformations/{transformation}", compress=compress
)
return keyed_dicts_to_gufe(content)

if visualize:
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn

Expand All @@ -359,19 +369,11 @@ def get_transformation(
f"Retrieving [bold]'{transformation}'[/bold]...", total=None
)

tf = json_to_gufe(
self._get_resource(
f"/transformations/{transformation}", compress=compress
)
)
tf = _get_transformation()
progress.start_task(task)
progress.update(task, total=1, completed=1)
else:
tf = json_to_gufe(
self._get_resource(
f"/transformations/{transformation}", compress=compress
)
)
tf = _get_transformation()

return tf

Expand Down Expand Up @@ -403,6 +405,13 @@ def get_chemicalsystem(
The retrieved ChemicalSystem.
"""

def _get_chemicalsystem():
content = self._get_resource(
f"/chemicalsystems/{chemicalsystem}", compress=compress
)
return keyed_dicts_to_gufe(content)

if visualize:
from rich.progress import Progress

Expand All @@ -411,20 +420,12 @@ def get_chemicalsystem(
f"Retrieving [bold]'{chemicalsystem}'[/bold]...", total=None
)

cs = json_to_gufe(
self._get_resource(
f"/chemicalsystems/{chemicalsystem}", compress=compress
)
)
cs = _get_chemicalsystem()

progress.start_task(task)
progress.update(task, total=1, completed=1)
else:
cs = json_to_gufe(
self._get_resource(
f"/chemicalsystems/{chemicalsystem}", compress=compress
)
)
cs = _get_chemicalsystem()

return cs

Expand Down
20 changes: 10 additions & 10 deletions alchemiscale/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def n4js_fresh(graph):
return n4js


@fixture(scope="session")
@fixture(scope="module")
def s3objectstore_settings():
os.environ["AWS_ACCESS_KEY_ID"] = "test-key-id"
os.environ["AWS_SECRET_ACCESS_KEY"] = "test-key"
Expand Down Expand Up @@ -219,7 +219,7 @@ def s3os(s3objectstore_settings):
# TODO: add in atom mapping once `gufe`#35 is settled


@fixture(scope="session")
@fixture(scope="module")
def network_tyk2():
tyk2s = tyk2.get_system()

Expand Down Expand Up @@ -266,17 +266,17 @@ def network_tyk2():
)


@fixture(scope="session")
@fixture(scope="module")
def transformation(network_tyk2):
return list(network_tyk2.edges)[0]


@fixture(scope="session")
@fixture(scope="module")
def chemicalsystem(network_tyk2):
return list(network_tyk2.nodes)[0]


@fixture(scope="session")
@fixture(scope="module")
def protocoldagresults(tmpdir_factory, transformation):
pdrs = []
for i in range(3):
Expand All @@ -300,7 +300,7 @@ def protocoldagresults(tmpdir_factory, transformation):
return pdrs


@fixture(scope="session")
@fixture(scope="module")
def network_tyk2_failure(network_tyk2):
transformation = list(network_tyk2.edges)[0]

Expand All @@ -316,12 +316,12 @@ def network_tyk2_failure(network_tyk2):
)


@fixture(scope="session")
@fixture(scope="module")
def transformation_failure(network_tyk2_failure):
return [t for t in network_tyk2_failure.edges if t.name == "broken"][0]


@fixture(scope="session")
@fixture(scope="module")
def protocoldagresults_failure(tmpdir_factory, transformation_failure):
pdrs = []
for i in range(3):
Expand All @@ -346,13 +346,13 @@ def protocoldagresults_failure(tmpdir_factory, transformation_failure):
return pdrs


@fixture(scope="session")
@fixture(scope="module")
def scope_test():
"""Primary scope for individual tests"""
return Scope(org="test_org", campaign="test_campaign", project="test_project")


@fixture(scope="session")
@fixture(scope="module")
def multiple_scopes(scope_test):
scopes = [scope_test] # Append initial test
# Augment
Expand Down
Loading

0 comments on commit 16b2776

Please sign in to comment.