Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Network representation optimizations for JSON encoding #217

Merged
merged 24 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
d6dbb30
Added optimizations for network representation
ianmkenney Dec 22, 2023
4cfe22a
Fixed API tests
ianmkenney Dec 23, 2023
667539f
Added helper functions
ianmkenney Dec 23, 2023
47f9324
Use helper function
ianmkenney Dec 23, 2023
7326948
Return GufeJSONResponse directly
ianmkenney Dec 27, 2023
53f53d3
Updated imports and added unit tests
ianmkenney Dec 28, 2023
9f3d2e0
Registry changes should be thrown away if there is an exception
ianmkenney Dec 28, 2023
6860e6c
Module isolation and network name change
ianmkenney Dec 29, 2023
1e65c19
Separated integration and unit tests in CI
ianmkenney Dec 29, 2023
d032877
Testing typo
ianmkenney Dec 29, 2023
bc41a56
Added debugging output
ianmkenney Dec 29, 2023
8d1b237
Reverted render arg type hint
ianmkenney Dec 29, 2023
70758a2
Test all object types in a network
ianmkenney Dec 29, 2023
cabf810
ensure that registry is rebuilt
ianmkenney Dec 29, 2023
3018bb3
don't pop the transformation components
ianmkenney Dec 29, 2023
0c0df63
typo in test
ianmkenney Dec 29, 2023
c21a79f
Alternative way of getting transfromation scopedkey
ianmkenney Dec 29, 2023
0bde59a
moving to module scope in tests
ianmkenney Dec 29, 2023
ca5ca9a
Revert splitting up tests
ianmkenney Dec 30, 2023
ed10fd8
updated test contents
ianmkenney Dec 30, 2023
a5b8cdb
Cleaning up PR
ianmkenney Dec 30, 2023
bd858a2
Added keep_changes test for RegistryBackup
ianmkenney Dec 30, 2023
43dbe2f
Keep pdr in test
ianmkenney Dec 30, 2023
b1dba97
Avoid removing the transformation
ianmkenney Dec 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion alchemiscale/base/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
oauth2_scheme,
)
from ..security.models import Token, TokenData, CredentialedEntity
from ..utils import gufe_to_keyed_dicts


def validate_scopes(scope: Scope, token: TokenData) -> None:
Expand Down Expand Up @@ -144,7 +145,8 @@ class GufeJSONResponse(JSONResponse):
media_type = "application/json"

def render(self, content: Any) -> bytes:
dotsdl marked this conversation as resolved.
Show resolved Hide resolved
return json.dumps(content, cls=JSON_HANDLER.encoder).encode("utf-8")
keyed_dicts = gufe_to_keyed_dicts(content)
return json.dumps(keyed_dicts, cls=JSON_HANDLER.encoder).encode("utf-8")


class GzipRequest(Request):
Expand Down
3 changes: 1 addition & 2 deletions alchemiscale/compute/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from gufe.tokenization import GufeTokenizable, JSON_HANDLER

from ..base.api import (
GufeJSONResponse,
QueryGUFEHandler,
scope_params,
get_token_data_depends,
Expand Down Expand Up @@ -193,7 +192,7 @@ def claim_taskhub_tasks(
return [str(t) if t is not None else None for t in tasks]


@router.get("/tasks/{task_scoped_key}/transformation", response_class=GufeJSONResponse)
@router.get("/tasks/{task_scoped_key}/transformation")
def get_task_transformation(
task_scoped_key,
*,
Expand Down
51 changes: 18 additions & 33 deletions alchemiscale/interface/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from gufe import AlchemicalNetwork, ChemicalSystem, Transformation
from gufe.protocols import ProtocolDAGResult
from gufe.tokenization import GufeTokenizable, JSON_HANDLER
import networkx as nx
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we use this import in this module.


from ..base.api import (
GufeJSONResponse,
Expand All @@ -38,6 +39,7 @@
from ..models import Scope, ScopedKey
from ..security.auth import get_token_data, oauth2_scheme
from ..security.models import Token, TokenData, CredentialedUserIdentity
from ..utils import gufe_to_keyed_dicts, keyed_dicts_to_gufe


app = FastAPI(title="AlchemiscaleAPI")
Expand Down Expand Up @@ -92,7 +94,7 @@ def list_scopes(
### inputs


@router.get("/exists/{scoped_key}", response_class=GufeJSONResponse)
@router.get("/exists/{scoped_key}")
def check_existence(
scoped_key,
*,
Expand All @@ -108,14 +110,14 @@ def check_existence(
@router.post("/networks", response_model=ScopedKey)
def create_network(
*,
network: Dict = Body(...),
network: List = Body(...),
scope: Scope,
n4js: Neo4jStore = Depends(get_n4js_depends),
token: TokenData = Depends(get_token_data_depends),
):
validate_scopes(scope, token)

an = AlchemicalNetwork.from_dict(network)
an = keyed_dicts_to_gufe(network)

try:
an_sk = n4js.create_network(network=an, scope=scope)
Expand All @@ -131,7 +133,7 @@ def create_network(
return an_sk


@router.get("/networks", response_class=GufeJSONResponse)
@router.get("/networks")
def query_networks(
*,
name: str = None,
Expand Down Expand Up @@ -290,7 +292,7 @@ def get_chemicalsystem_transformations(
]


@router.get("/networks/{network_scoped_key}", response_class=GufeJSONResponse)
@router.get("/networks/{network_scoped_key}")
def get_network(
network_scoped_key,
*,
Expand All @@ -301,12 +303,10 @@ def get_network(
validate_scopes(sk.scope, token)

network = n4js.get_gufe(scoped_key=sk)
return gufe_to_json(network)
return GufeJSONResponse(network)


@router.get(
"/transformations/{transformation_scoped_key}", response_class=GufeJSONResponse
)
@router.get("/transformations/{transformation_scoped_key}")
def get_transformation(
transformation_scoped_key,
*,
Expand All @@ -317,12 +317,10 @@ def get_transformation(
validate_scopes(sk.scope, token)

transformation = n4js.get_gufe(scoped_key=sk)
return gufe_to_json(transformation)
return GufeJSONResponse(content=transformation)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My OCD: remove content= for consistency elsewhere. 😉



@router.get(
"/chemicalsystems/{chemicalsystem_scoped_key}", response_class=GufeJSONResponse
)
@router.get("/chemicalsystems/{chemicalsystem_scoped_key}")
def get_chemicalsystem(
chemicalsystem_scoped_key,
*,
Expand All @@ -333,7 +331,7 @@ def get_chemicalsystem(
validate_scopes(sk.scope, token)

chemicalsystem = n4js.get_gufe(scoped_key=sk)
return gufe_to_json(chemicalsystem)
return GufeJSONResponse(chemicalsystem)


### compute
Expand Down Expand Up @@ -739,7 +737,7 @@ def get_task_status(
return status[0].value


@router.get("/tasks/{task_scoped_key}/transformation", response_class=GufeJSONResponse)
@router.get("/tasks/{task_scoped_key}/transformation")
def get_task_transformation(
task_scoped_key,
*,
Expand All @@ -762,10 +760,7 @@ def get_task_transformation(
### results


@router.get(
"/transformations/{transformation_scoped_key}/results",
response_class=GufeJSONResponse,
)
@router.get("/transformations/{transformation_scoped_key}/results")
def get_transformation_results(
transformation_scoped_key,
*,
Expand All @@ -778,10 +773,7 @@ def get_transformation_results(
return [str(sk) for sk in n4js.get_transformation_results(sk)]


@router.get(
"/transformations/{transformation_scoped_key}/failures",
response_class=GufeJSONResponse,
)
@router.get("/transformations/{transformation_scoped_key}/failures")
def get_transformation_failures(
transformation_scoped_key,
*,
Expand All @@ -795,8 +787,7 @@ def get_transformation_failures(


@router.get(
"/transformations/{transformation_scoped_key}/{route}/{protocoldagresultref_scoped_key}",
response_class=GufeJSONResponse,
"/transformations/{transformation_scoped_key}/{route}/{protocoldagresultref_scoped_key}"
)
def get_protocoldagresult(
protocoldagresultref_scoped_key,
Expand Down Expand Up @@ -844,10 +835,7 @@ def get_protocoldagresult(
return [pdr]


@router.get(
"/tasks/{task_scoped_key}/results",
response_class=GufeJSONResponse,
)
@router.get("/tasks/{task_scoped_key}/results")
def get_task_results(
task_scoped_key,
*,
Expand All @@ -860,10 +848,7 @@ def get_task_results(
return [str(sk) for sk in n4js.get_task_results(sk)]


@router.get(
"/tasks/{task_scoped_key}/failures",
response_class=GufeJSONResponse,
)
@router.get("/tasks/{task_scoped_key}/failures")
def get_task_failures(
task_scoped_key,
*,
Expand Down
58 changes: 30 additions & 28 deletions alchemiscale/interface/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from gufe import AlchemicalNetwork, Transformation, ChemicalSystem
from gufe.tokenization import GufeTokenizable, JSON_HANDLER, GufeKey
from gufe.protocols import ProtocolResult, ProtocolDAGResult
import networkx as nx
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we use this import in this module.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems we do, see get_transformation_tasks

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh that's strange; not sure how we didn't already have this import 🤔

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we did, I added it as a repeated import because I didn't notice it 😮‍💨



from ..base.client import (
Expand All @@ -30,6 +31,7 @@
from ..strategies import Strategy
from ..security.models import CredentialedUserIdentity
from ..validators import validate_network_nonself
from ..utils import gufe_to_keyed_dicts, keyed_dicts_to_gufe


class AlchemiscaleClientError(AlchemiscaleBaseClientError):
Expand Down Expand Up @@ -125,7 +127,8 @@ def create_network(
sk = self.get_scoped_key(network, scope)

def post():
data = dict(network=network.to_dict(), scope=scope.dict())
keyed_dicts = gufe_to_keyed_dicts(network)
data = dict(network=keyed_dicts, scope=scope.dict())
return self._post_resource("/networks", data, compress=compress)

if visualize:
Expand Down Expand Up @@ -298,6 +301,11 @@ def get_network(
The retrieved AlchemicalNetwork.

"""

def _get_network():
content = self._get_resource(f"/networks/{network}", compress=compress)
return keyed_dicts_to_gufe(content)

if visualize:
from rich.progress import (
Progress,
Expand All @@ -311,16 +319,12 @@ def get_network(
f"Retrieving [bold]'{network}'[/bold]...", total=None
)

an = json_to_gufe(
self._get_resource(f"/networks/{network}", compress=compress)
)
an = _get_network()

progress.start_task(task)
progress.update(task, total=1, completed=1)
else:
an = json_to_gufe(
self._get_resource(f"/networks/{network}", compress=compress)
)

an = _get_network()
return an

@lru_cache(maxsize=10000)
Expand Down Expand Up @@ -351,6 +355,13 @@ def get_transformation(
The retrieved Transformation.

"""

def _get_transformation():
content = self._get_resource(
f"/transformations/{transformation}", compress=compress
)
return keyed_dicts_to_gufe(content)

if visualize:
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn

Expand All @@ -359,19 +370,11 @@ def get_transformation(
f"Retrieving [bold]'{transformation}'[/bold]...", total=None
)

tf = json_to_gufe(
self._get_resource(
f"/transformations/{transformation}", compress=compress
)
)
tf = _get_transformation()
progress.start_task(task)
progress.update(task, total=1, completed=1)
else:
tf = json_to_gufe(
self._get_resource(
f"/transformations/{transformation}", compress=compress
)
)
tf = _get_transformation()

return tf

Expand Down Expand Up @@ -403,6 +406,13 @@ def get_chemicalsystem(
The retrieved ChemicalSystem.

"""

def _get_chemicalsystem():
content = self._get_resource(
f"/chemicalsystems/{chemicalsystem}", compress=compress
)
return keyed_dicts_to_gufe(content)

if visualize:
from rich.progress import Progress

Expand All @@ -411,20 +421,12 @@ def get_chemicalsystem(
f"Retrieving [bold]'{chemicalsystem}'[/bold]...", total=None
)

cs = json_to_gufe(
self._get_resource(
f"/chemicalsystems/{chemicalsystem}", compress=compress
)
)
cs = _get_chemicalsystem()

progress.start_task(task)
progress.update(task, total=1, completed=1)
else:
cs = json_to_gufe(
self._get_resource(
f"/chemicalsystems/{chemicalsystem}", compress=compress
)
)
cs = _get_chemicalsystem()

return cs

Expand Down
10 changes: 9 additions & 1 deletion alchemiscale/tests/integration/interface/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from pathlib import Path

from gufe import AlchemicalNetwork, ChemicalSystem, Transformation
from gufe.tokenization import TOKENIZABLE_REGISTRY, GufeKey
from gufe.tokenization import TOKENIZABLE_REGISTRY, GufeKey, get_all_gufe_objs
from gufe.protocols.protocoldag import execute_DAG
from gufe.tests.test_protocol import BrokenProtocol
import networkx as nx

from alchemiscale.models import ScopedKey, Scope
from alchemiscale.storage.models import TaskStatusEnum
from alchemiscale.interface import client
from alchemiscale.utils import gufe_to_keyed_dicts
from alchemiscale.tests.integration.interface.utils import (
get_user_settings_override,
)
Expand Down Expand Up @@ -1152,6 +1153,13 @@ def test_get_transformation_and_network_results(
for pdr in protocoldagresults:
TOKENIZABLE_REGISTRY.pop(pdr.key, None)

# get_transformation_results constructs the GufeTokenizable objects
# needed to create the transformation. Therefore we need to clear the registry
# of these objects to be sure that the correct objects are returned from the
# database.
for gt in get_all_gufe_objs(transformation):
TOKENIZABLE_REGISTRY.pop(gt.key, None)

# user client : pull transformation results, evaluate
protocolresult = user_client.get_transformation_results(transformation_sk)

Expand Down
10 changes: 7 additions & 3 deletions alchemiscale/tests/integration/interface/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

from alchemiscale.models import ScopedKey
from alchemiscale.base.client import json_to_gufe
from alchemiscale.utils import keyed_dicts_to_gufe, gufe_to_keyed_dicts

import networkx as nx
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this import is used in this module.



def pre_load_payload(network, scope, name="incomplete 2"):
Expand All @@ -14,7 +17,7 @@ def pre_load_payload(network, scope, name="incomplete 2"):
edges=list(network.edges)[:-3], nodes=network.nodes, name=name
)
headers = {"Content-type": "application/json"}
data = dict(network=new_network.to_dict(), scope=scope.dict())
data = dict(network=gufe_to_keyed_dicts(new_network), scope=scope.dict())
jsondata = json.dumps(data, cls=JSON_HANDLER.encoder)

return new_network, headers, jsondata
Expand Down Expand Up @@ -106,8 +109,9 @@ def test_get_network(self, prepared_network, test_client):

assert response.status_code == 200

content = json.loads(response.text, cls=JSON_HANDLER.decoder)
network_ = json_to_gufe(content)
network_ = keyed_dicts_to_gufe(
json.loads(response.text, cls=JSON_HANDLER.decoder)
)

assert network_.key == network.key
assert network_ is network
Expand Down
Loading
Loading