From b72a7fcf37dc4bcf04d891bb6faf52bd628fa1a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Thu, 13 Jun 2024 16:32:50 +0200 Subject: [PATCH 1/4] Use actual types for instance memory specifications --- graphdatascience/session/aura_api.py | 11 +-- .../session/aura_api_responses.py | 10 +-- graphdatascience/session/aurads_sessions.py | 6 +- graphdatascience/session/session_info.py | 3 +- graphdatascience/session/session_sizes.py | 18 +++++ graphdatascience/tests/unit/test_aura_api.py | 44 ++++++------ .../tests/unit/test_aurads_sessions.py | 44 ++++++------ .../tests/unit/test_dedicated_sessions.py | 67 ++++++++++++++----- .../tests/unit/test_session_sizes.py | 2 +- 9 files changed, 136 insertions(+), 69 deletions(-) diff --git a/graphdatascience/session/aura_api.py b/graphdatascience/session/aura_api.py index 4ce34d18e..cb621ed13 100644 --- a/graphdatascience/session/aura_api.py +++ b/graphdatascience/session/aura_api.py @@ -9,6 +9,7 @@ import requests as req from requests import HTTPError +from graphdatascience.session.session_sizes import SessionMemory from graphdatascience.session.algorithm_category import AlgorithmCategory from graphdatascience.session.aura_api_responses import ( EstimationDetails, @@ -62,11 +63,11 @@ def extract_id(uri: str) -> str: return host.split(".")[0].split("-")[0] - def create_session(self, name: str, dbid: str, pwd: str, memory: str) -> SessionDetails: + def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemory) -> SessionDetails: response = req.post( f"{self._base_uri}/v1beta5/data-science/sessions", headers=self._build_header(), - json={"name": name, "instance_id": dbid, "password": pwd, "memory": memory}, + json={"name": name, "instance_id": dbid, "password": pwd, "memory": memory.value}, ) response.raise_for_status() @@ -141,12 +142,14 @@ def delete_session(self, session_id: str, dbid: str) -> bool: return False - def create_instance(self, name: str, memory: str, cloud_provider: str, region: str) -> InstanceCreateDetails: + def create_instance( + self, name: str, memory: SessionMemory, cloud_provider: str, region: str + ) -> InstanceCreateDetails: tenant_details = self.tenant_details() data = { "name": name, - "memory": memory, + "memory": memory.value, "version": "5", "region": region, "type": tenant_details.ds_type, diff --git a/graphdatascience/session/aura_api_responses.py b/graphdatascience/session/aura_api_responses.py index cd533a192..77bc785f0 100644 --- a/graphdatascience/session/aura_api_responses.py +++ b/graphdatascience/session/aura_api_responses.py @@ -9,13 +9,15 @@ from pandas import Timedelta +from .session_sizes import SessionMemory + @dataclass(repr=True, frozen=True) class SessionDetails: id: str name: str instance_id: str - memory: str + memory: SessionMemory status: str host: str created_at: datetime @@ -31,7 +33,7 @@ def fromJson(cls, json: Dict[str, Any]) -> SessionDetails: id=json["id"], name=json["name"], instance_id=json["instance_id"], - memory=json["memory"], + memory=SessionMemory.fromApiResponse(json["memory"]), status=json["status"], host=json["host"], expiry_date=TimeParser.fromisoformat(expiry_date) if expiry_date else None, @@ -67,7 +69,7 @@ def fromJson(cls, json: Dict[str, Any]) -> InstanceDetails: class InstanceSpecificDetails(InstanceDetails): status: str connection_url: str - memory: str + memory: SessionMemory type: str region: str @@ -80,7 +82,7 @@ def fromJson(cls, json: Dict[str, Any]) -> InstanceSpecificDetails: cloud_provider=json["cloud_provider"], status=json["status"], connection_url=json.get("connection_url", ""), - memory=json.get("memory", ""), + memory=SessionMemory.fromApiResponse(json.get("memory", "")), type=json["type"], region=json["region"], ) diff --git a/graphdatascience/session/aurads_sessions.py b/graphdatascience/session/aurads_sessions.py index c54ec34a8..fa3b79166 100644 --- a/graphdatascience/session/aurads_sessions.py +++ b/graphdatascience/session/aurads_sessions.py @@ -54,9 +54,9 @@ def get_or_create( if existing_session: session_id = existing_session.id # 0MB is AuraAPI default value for memory if none can be retrieved - if existing_session.memory != "0MB" and existing_session.memory != memory.value: + if existing_session.memory != memory: raise ValueError( - f"Session `{session_name}` already exists with memory `{existing_session.memory}`. " + f"Session `{session_name}` already exists with memory `{existing_session.memory.value}`. " f"Requested memory `{memory.value}` does not match." ) else: @@ -128,7 +128,7 @@ def _create_session( region = self._ds_region(db_instance.region, db_instance.cloud_provider) create_details = self._aura_api.create_instance( - SessionNameHelper.instance_name(session_name), memory.value, db_instance.cloud_provider, region + SessionNameHelper.instance_name(session_name), memory, db_instance.cloud_provider, region ) return create_details diff --git a/graphdatascience/session/session_info.py b/graphdatascience/session/session_info.py index 50d675764..23fae9165 100644 --- a/graphdatascience/session/session_info.py +++ b/graphdatascience/session/session_info.py @@ -4,6 +4,7 @@ from datetime import datetime from typing import Optional +from graphdatascience.session.session_sizes import SessionMemory from graphdatascience.session.aura_api_responses import SessionDetails @@ -18,7 +19,7 @@ class SessionInfo: """ name: str - memory: str + memory: SessionMemory @classmethod def from_session_details(cls, details: SessionDetails) -> ExtendedSessionInfo: diff --git a/graphdatascience/session/session_sizes.py b/graphdatascience/session/session_sizes.py index 2967eb7ad..f372873f3 100644 --- a/graphdatascience/session/session_sizes.py +++ b/graphdatascience/session/session_sizes.py @@ -7,6 +7,7 @@ class SessionMemory(Enum): Enumeration representing session main memory configurations. """ + m_4GB = "4GB" m_8GB = "8GB" m_16GB = "16GB" m_24GB = "24GB" @@ -29,3 +30,20 @@ def all_values(cls) -> List[str]: """ return [e.value for e in cls] + + @staticmethod + def fromApiResponse(value: str) -> "SessionMemory": + """ + Converts the string value from an API response to a SessionMemory enumeration value. + + Args: + value: The string value from the API response. + + Returns: + The SessionMemory enumeration value. + + """ + try: + return SessionMemory(value.replace("Gi", "GB")) + except ValueError: + raise ValueError(f"Unsupported memory configuration: {value}") diff --git a/graphdatascience/tests/unit/test_aura_api.py b/graphdatascience/tests/unit/test_aura_api.py index 5b87816ca..87383117c 100644 --- a/graphdatascience/tests/unit/test_aura_api.py +++ b/graphdatascience/tests/unit/test_aura_api.py @@ -6,6 +6,7 @@ from requests import HTTPError from requests_mock import Mocker +from graphdatascience.session import SessionMemory from graphdatascience.session.algorithm_category import AlgorithmCategory from graphdatascience.session.aura_api import AuraApi from graphdatascience.session.aura_api_responses import ( @@ -33,11 +34,11 @@ def test_create_session(requests_mock: Mocker) -> None: "instance_id": "dbid-1", "created_at": "1970-01-01T00:00:00Z", "host": "1.2.3.4", - "memory": "4G", + "memory": "4Gi", }, ) - result = api.create_session("name-0", "dbid-1", "pwd-2", "4G") + result = api.create_session("name-0", "dbid-1", "pwd-2", SessionMemory.m_4GB) assert result == SessionDetails( id="id0", @@ -46,7 +47,7 @@ def test_create_session(requests_mock: Mocker) -> None: instance_id="dbid-1", created_at=TimeParser.fromisoformat("1970-01-01T00:00:00Z"), host="1.2.3.4", - memory="4G", + memory=SessionMemory.m_4GB, expiry_date=None, ttl=None, ) @@ -67,7 +68,7 @@ def test_list_session(requests_mock: Mocker) -> None: "instance_id": "dbid-1", "created_at": "1970-01-01T00:00:00Z", "host": "1.2.3.4", - "memory": "4G", + "memory": "4Gi", "expiry_date": "1977-01-01T00:00:00Z", }, ) @@ -81,7 +82,7 @@ def test_list_session(requests_mock: Mocker) -> None: instance_id="dbid-1", created_at=TimeParser.fromisoformat("1970-01-01T00:00:00Z"), host="1.2.3.4", - memory="4G", + memory=SessionMemory.m_4GB, expiry_date=TimeParser.fromisoformat("1977-01-01T00:00:00Z"), ttl=None, ) @@ -101,7 +102,7 @@ def test_list_sessions(requests_mock: Mocker) -> None: "instance_id": "dbid-1", "created_at": "1970-01-01T00:00:00Z", "host": "1.2.3.4", - "memory": "4G", + "memory": "4Gi", "expiry_date": "1977-01-01T00:00:00Z", }, { @@ -110,7 +111,7 @@ def test_list_sessions(requests_mock: Mocker) -> None: "status": "Creating", "instance_id": "dbid-3", "created_at": "2012-01-01T00:00:00Z", - "memory": "8G", + "memory": "8Gi", "host": "foo.bar", }, ], @@ -125,7 +126,7 @@ def test_list_sessions(requests_mock: Mocker) -> None: instance_id="dbid-1", created_at=TimeParser.fromisoformat("1970-01-01T00:00:00Z"), host="1.2.3.4", - memory="4G", + memory=SessionMemory.m_4GB, expiry_date=TimeParser.fromisoformat("1977-01-01T00:00:00Z"), ttl=None, ) @@ -136,7 +137,7 @@ def test_list_sessions(requests_mock: Mocker) -> None: status="Creating", instance_id="dbid-3", created_at=TimeParser.fromisoformat("2012-01-01T00:00:00Z"), - memory="8G", + memory=SessionMemory.m_8GB, host="foo.bar", expiry_date=None, ttl=None, @@ -200,7 +201,7 @@ def test_dont_wait_forever_for_session(requests_mock: Mocker, caplog: LogCapture "instance_id": "dbid-1", "created_at": "1970-01-01T00:00:00Z", "host": "foo.bar", - "memory": "4G", + "memory": "4Gi", "expiry_date": "1977-01-01T00:00:00Z", }, ) @@ -227,7 +228,7 @@ def test_wait_for_session_running(requests_mock: Mocker) -> None: "instance_id": "dbid-1", "created_at": "1970-01-01T00:00:00Z", "host": "foo.bar", - "memory": "4G", + "memory": "4Gi", "expiry_date": "1977-01-01T00:00:00Z", }, ) @@ -253,7 +254,7 @@ def test_delete_instance(requests_mock: Mocker) -> None: "connection_url": "", "tenant_id": "", "cloud_provider": "", - "memory": "", + "memory": "4Gi", "region": "", "type": "", } @@ -269,7 +270,7 @@ def test_delete_instance(requests_mock: Mocker) -> None: cloud_provider="", status="deleting", connection_url="", - memory="", + memory=SessionMemory.m_4GB, region="", type="", ) @@ -336,7 +337,7 @@ def test_create_instance(requests_mock: Mocker) -> None: }, ) - api.create_instance("name", "16GB", "gcp", "leipzig-1") + api.create_instance("name", SessionMemory.m_16GB, "gcp", "leipzig-1") requested_data = requests_mock.request_history[-1].json() assert requested_data["name"] == "name" @@ -437,6 +438,7 @@ def test_list_instance_missing_memory_field(requests_mock: Mocker) -> None: "status": "creating", "tenant_id": "046046d1-6996-53e4-8880-5b822766e1f9", "type": "enterprise-ds", + "memory": "16Gi", } }, ) @@ -444,7 +446,7 @@ def test_list_instance_missing_memory_field(requests_mock: Mocker) -> None: result = api.list_instance("id0") assert result and result.id == "a10fb995" - assert result.memory == "" + assert result.memory == SessionMemory.m_16GB def test_list_missing_instance(requests_mock: Mocker) -> None: @@ -483,6 +485,7 @@ def test_dont_wait_forever(requests_mock: Mocker, caplog: LogCaptureFixture) -> "region": None, "tenant_id": None, "type": None, + "memory": "4Gi", } }, ) @@ -512,6 +515,7 @@ def test_wait_for_instance_running(requests_mock: Mocker) -> None: "region": None, "tenant_id": None, "type": None, + "memory": "4Gi", } }, ) @@ -535,6 +539,7 @@ def test_wait_for_instance_deleting(requests_mock: Mocker) -> None: "region": None, "tenant_id": None, "type": None, + "memory": "4Gi", } }, ) @@ -550,6 +555,7 @@ def test_wait_for_instance_deleting(requests_mock: Mocker) -> None: "region": None, "tenant_id": None, "type": None, + "memory": "4Gi", } }, ) @@ -634,7 +640,7 @@ def test_parse_session_info() -> None: session_details = { "id": "test_id", "name": "test_session", - "memory": "small", + "memory": "4Gi", "instance_id": "test_instance", "status": "running", "expiry_date": "2022-01-01T00:00:00Z", @@ -647,7 +653,7 @@ def test_parse_session_info() -> None: assert session_info == SessionDetails( id="test_id", name="test_session", - memory="small", + memory=SessionMemory.m_4GB, instance_id="test_instance", status="running", host="a.b", @@ -661,7 +667,7 @@ def test_parse_session_info_without_optionals() -> None: session_details = { "id": "test_id", "name": "test_session", - "memory": "small", + "memory": "16Gi", "instance_id": "test_instance", "status": "running", "host": "a.b", @@ -672,7 +678,7 @@ def test_parse_session_info_without_optionals() -> None: assert session_info == SessionDetails( id="test_id", name="test_session", - memory="small", + memory=SessionMemory.m_16GB, instance_id="test_instance", host="a.b", status="running", diff --git a/graphdatascience/tests/unit/test_aurads_sessions.py b/graphdatascience/tests/unit/test_aurads_sessions.py index aa0610fa0..487a6dbcd 100644 --- a/graphdatascience/tests/unit/test_aurads_sessions.py +++ b/graphdatascience/tests/unit/test_aurads_sessions.py @@ -37,7 +37,9 @@ def __init__( self._status_after_creating = status_after_creating self._size_estimation = size_estimation or EstimationDetails("1GB", "8GB", False) - def create_instance(self, name: str, memory: str, cloud_provider: str, region: str) -> InstanceCreateDetails: + def create_instance( + self, name: str, memory: SessionMemory, cloud_provider: str, region: str + ) -> InstanceCreateDetails: create_details = InstanceCreateDetails( id=f"ffff{self.id_counter}", username="neo4j", @@ -102,11 +104,13 @@ def aura_api() -> AuraApi: def test_list_session(aura_api: AuraApi) -> None: - aura_api.create_instance(name="gds-session-my-session-name", cloud_provider="gcp", memory="16GB", region="") + aura_api.create_instance( + name="gds-session-my-session-name", cloud_provider="gcp", memory=SessionMemory.m_16GB, region="" + ) sessions = AuraDsSessions(aura_api) - assert sessions.list() == [SessionInfo("my-session-name", "16GB")] + assert sessions.list() == [SessionInfo("my-session-name", SessionMemory.m_16GB)] def test_create_session(mocker: MockerFixture, aura_api: AuraApi) -> None: @@ -135,12 +139,12 @@ def assert_db_credentials(*args: List[Any], **kwargs: str) -> None: "gds_url": "fake-url", "session_name": "my-session", } - assert sessions.list() == [SessionInfo("my-session", "384GB")] + assert sessions.list() == [SessionInfo("my-session", SessionMemory.m_384GB)] instance_details: InstanceSpecificDetails = aura_api.list_instance("ffff1") # type: ignore assert instance_details.cloud_provider == "aws" assert instance_details.region == "leipzig-1" - assert instance_details.memory == "384GB" + assert instance_details.memory == SessionMemory.m_384GB def test_create_default_session(mocker: MockerFixture, aura_api: AuraApi) -> None: @@ -163,11 +167,11 @@ def test_create_default_session(mocker: MockerFixture, aura_api: AuraApi) -> Non instance_details: InstanceSpecificDetails = aura_api.list_instance("ffff1") # type: ignore assert instance_details.cloud_provider == "aws" assert instance_details.region == "leipzig-1" - assert instance_details.memory == "8GB" + assert instance_details.memory == SessionMemory.m_8GB def test_create_session_override_region(mocker: MockerFixture, aura_api: AuraApi) -> None: - aura_api.create_instance("test", "8GB", "aws", "dresden-2") + aura_api.create_instance("test", SessionMemory.m_8GB, "aws", "dresden-2") sessions = AuraDsSessions(aura_api) @@ -186,7 +190,7 @@ def test_create_session_override_region(mocker: MockerFixture, aura_api: AuraApi instance_details: InstanceSpecificDetails = aura_api.list_instance("ffff1") # type: ignore assert instance_details.cloud_provider == "aws" assert instance_details.region == "leipzig-1" - assert instance_details.memory == "8GB" + assert instance_details.memory == SessionMemory.m_8GB def test_get_or_create(mocker: MockerFixture, aura_api: AuraApi) -> None: @@ -221,7 +225,7 @@ def test_get_or_create(mocker: MockerFixture, aura_api: AuraApi) -> None: } assert gds_args1 == gds_args2 - assert sessions.list() == [SessionInfo("my-session", "8GB")] + assert sessions.list() == [SessionInfo("my-session", SessionMemory.m_8GB)] def test_get_or_create_different_size(mocker: MockerFixture, aura_api: AuraApi) -> None: @@ -265,7 +269,7 @@ def test_get_or_create_duplicate_session() -> None: cloud_provider="", status="RUNNING", connection_url="", - memory="", + memory=SessionMemory.m_8GB, type="", region="", ), @@ -276,7 +280,7 @@ def test_get_or_create_duplicate_session() -> None: cloud_provider="", status="RUNNING", connection_url="", - memory="", + memory=SessionMemory.m_8GB, type="", region="", ), @@ -298,7 +302,7 @@ def test_delete_session() -> None: cloud_provider="cloud_provider", status="RUNNING", connection_url="", - memory="", + memory=SessionMemory.m_8GB, type="", region="", ), @@ -309,7 +313,7 @@ def test_delete_session() -> None: cloud_provider="cloud_provider", status="RUNNING", connection_url="", - memory="", + memory=SessionMemory.m_8GB, type="", region="", ), @@ -318,7 +322,7 @@ def test_delete_session() -> None: sessions = AuraDsSessions(FakeAuraApi(existing_instances=existing_instances)) assert sessions.delete("one") - assert sessions.list() == [SessionInfo("other", "")] + assert sessions.list() == [SessionInfo("other", SessionMemory.m_8GB)] def test_delete_nonexisting_session() -> None: @@ -330,7 +334,7 @@ def test_delete_nonexisting_session() -> None: cloud_provider="cloud_provider", status="RUNNING", connection_url="", - memory="", + memory=SessionMemory.m_8GB, type="", region="", ), @@ -339,7 +343,7 @@ def test_delete_nonexisting_session() -> None: sessions = AuraDsSessions(FakeAuraApi(existing_instances=existing_instances)) assert sessions.delete("other") is False - assert sessions.list() == [SessionInfo("one", "")] + assert sessions.list() == [SessionInfo("one", SessionMemory.m_8GB)] def test_delete_nonunique_session() -> None: @@ -351,7 +355,7 @@ def test_delete_nonunique_session() -> None: cloud_provider="", status="RUNNING", connection_url="", - memory="", + memory=SessionMemory.m_8GB, type="", region="", ), @@ -362,7 +366,7 @@ def test_delete_nonunique_session() -> None: cloud_provider="", status="RUNNING", connection_url="", - memory="", + memory=SessionMemory.m_8GB, type="", region="", ), @@ -378,7 +382,7 @@ def test_delete_nonunique_session() -> None: ): sessions.delete("one") - assert sessions.list() == [SessionInfo("one", ""), SessionInfo("one", "")] + assert sessions.list() == [SessionInfo("one", SessionMemory.m_8GB), SessionInfo("one", SessionMemory.m_8GB)] def test_create_immediate_delete() -> None: @@ -435,4 +439,4 @@ def test_estimate_size_exceeds() -> None: def _setup_db_instance(aura_api: AuraApi) -> None: - aura_api.create_instance("test", "8GB", "aws", "leipzig-1") + aura_api.create_instance("test", SessionMemory.m_8GB, "aws", "leipzig-1") diff --git a/graphdatascience/tests/unit/test_dedicated_sessions.py b/graphdatascience/tests/unit/test_dedicated_sessions.py index 69653fdd4..4ff5e6af5 100644 --- a/graphdatascience/tests/unit/test_dedicated_sessions.py +++ b/graphdatascience/tests/unit/test_dedicated_sessions.py @@ -43,7 +43,7 @@ def __init__( self._status_after_creating = status_after_creating self._size_estimation = size_estimation or EstimationDetails("1GB", "8GB", False) - def create_session(self, name: str, dbid: str, pwd: str, memory: str) -> SessionDetails: + def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemory) -> SessionDetails: details = SessionDetails( id=f"{dbid}-ffff{self.id_counter}", name=name, @@ -67,7 +67,9 @@ def add_session(self, session: SessionDetails) -> None: self._sessions[session.id] = session - def create_instance(self, name: str, memory: str, cloud_provider: str, region: str) -> InstanceCreateDetails: + def create_instance( + self, name: str, memory: SessionMemory, cloud_provider: str, region: str + ) -> InstanceCreateDetails: id = f"ffff{self.id_counter}" create_details = InstanceCreateDetails( id=id, @@ -168,7 +170,10 @@ def aura_api() -> AuraApi: def test_list_session(aura_api: AuraApi) -> None: _setup_db_instance(aura_api) session = aura_api.create_session( - name="gds-session-my-session-name", dbid=aura_api.list_instances()[0].id, pwd="some_pwd", memory="8GB" + name="gds-session-my-session-name", + dbid=aura_api.list_instances()[0].id, + pwd="some_pwd", + memory=SessionMemory.m_8GB, ) sessions = DedicatedSessions(aura_api) @@ -234,8 +239,8 @@ def test_get_or_create(mocker: MockerFixture, aura_api: AuraApi) -> None: def test_get_or_create_duplicate_session(aura_api: AuraApi) -> None: db = _setup_db_instance(aura_api) - aura_api.create_session("one", db.id, "1234", memory="1GB") - aura_api.create_session("one", db.id, "12345", memory="1GB") + aura_api.create_session("one", db.id, "1234", memory=SessionMemory.m_4GB) + aura_api.create_session("one", db.id, "12345", memory=SessionMemory.m_4GB) sessions = DedicatedSessions(aura_api) @@ -252,7 +257,7 @@ def test_get_or_create_expired_session(aura_api: AuraApi) -> None: id="ffff0-ffff1", name="one", instance_id=db.id, - memory=SessionMemory.m_8GB.value, + memory=SessionMemory.m_8GB, status="Expired", created_at=datetime.now(), host="foo.bar", @@ -277,7 +282,7 @@ def test_get_or_create_soon_expired_session(aura_api: AuraApi) -> None: id="ffff0-ffff1", name="one", instance_id=db.id, - memory=SessionMemory.m_8GB.value, + memory=SessionMemory.m_8GB, status="Ready", created_at=datetime.now(), host="foo.bar", @@ -291,11 +296,39 @@ def test_get_or_create_soon_expired_session(aura_api: AuraApi) -> None: sessions.get_or_create("one", SessionMemory.m_8GB, DbmsConnectionInfo(db.connection_url, "", "")) +def test_get_or_create_with_different_memory_config(aura_api: AuraApi) -> None: + db = _setup_db_instance(aura_api) + + fake_aura_api = cast(FakeAuraApi, aura_api) + fake_aura_api.add_session( + SessionDetails( + id="ffff0-ffff1", + name="one", + instance_id=db.id, + memory=SessionMemory.m_8GB, + status="Ready", + created_at=datetime.now(), + host="foo.bar", + expiry_date=None, + ttl=None, + ) + ) + + with pytest.raises( + RuntimeError, + match=re.escape( + "Session `one` exists with a different memory configuration. Current: 8GB, Requested: 16GB." + ), + ): + sessions = DedicatedSessions(aura_api) + sessions.get_or_create("one", SessionMemory.m_16GB, DbmsConnectionInfo(db.connection_url, "", "")) + + def test_delete_session(aura_api: AuraApi) -> None: - db1 = aura_api.create_instance("db1", "1GB", "aura", "leipzig").id - db2 = aura_api.create_instance("db2", "1GB", "aura", "dresden").id - aura_api.create_session("one", db1, "12345", memory="8GB") - aura_api.create_session("other", db2, "123123", memory="8GB") + db1 = aura_api.create_instance("db1", SessionMemory.m_4GB, "aura", "leipzig").id + db2 = aura_api.create_instance("db2", SessionMemory.m_4GB, "aura", "dresden").id + aura_api.create_session("one", db1, "12345", memory=SessionMemory.m_8GB) + aura_api.create_session("other", db2, "123123", memory=SessionMemory.m_8GB) sessions = DedicatedSessions(aura_api) @@ -304,8 +337,8 @@ def test_delete_session(aura_api: AuraApi) -> None: def test_delete_nonexisting_session(aura_api: AuraApi) -> None: - db1 = aura_api.create_instance("db1", "1gb", "aura", "leipzig").id - aura_api.create_session("one", db1, "12345", memory="8GB") + db1 = aura_api.create_instance("db1", SessionMemory.m_4GB, "aura", "leipzig").id + aura_api.create_session("one", db1, "12345", memory=SessionMemory.m_8GB) sessions = DedicatedSessions(aura_api) assert sessions.delete("other") is False @@ -313,9 +346,9 @@ def test_delete_nonexisting_session(aura_api: AuraApi) -> None: def test_delete_nonunique_session(aura_api: AuraApi) -> None: - db1 = aura_api.create_instance("db1", "1GB", "aura", "leipzig").id - aura_api.create_session("one", db1, "12345", memory="8GB") - aura_api.create_session("one", db1, "12345", memory="8GB") + db1 = aura_api.create_instance("db1", SessionMemory.m_4GB, "aura", "leipzig").id + aura_api.create_session("one", db1, "12345", memory=SessionMemory.m_8GB) + aura_api.create_session("one", db1, "12345", memory=SessionMemory.m_8GB) sessions = DedicatedSessions(aura_api) with pytest.raises( @@ -361,7 +394,7 @@ def test_estimate_size_exceeds() -> None: def _setup_db_instance(aura_api: AuraApi) -> InstanceCreateDetails: - return aura_api.create_instance("test", "8GB", "aws", "leipzig-1") + return aura_api.create_instance("test", SessionMemory.m_8GB, "aws", "leipzig-1") def patch_construct_client(mocker: MockerFixture) -> None: diff --git a/graphdatascience/tests/unit/test_session_sizes.py b/graphdatascience/tests/unit/test_session_sizes.py index e89e55bcd..416fa8e52 100644 --- a/graphdatascience/tests/unit/test_session_sizes.py +++ b/graphdatascience/tests/unit/test_session_sizes.py @@ -2,6 +2,6 @@ def test_all_values() -> None: - assert len(SessionMemory.all_values()) == 11 + assert len(SessionMemory.all_values()) == 12 for e in SessionMemory.all_values(): assert e.endswith("GB") From 3938b69868b62a1244903733ddd3da4040064d98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Thu, 13 Jun 2024 16:33:09 +0200 Subject: [PATCH 2/4] Check if a dedicated session exists with a different memory configuration --- graphdatascience/session/dedicated_sessions.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/graphdatascience/session/dedicated_sessions.py b/graphdatascience/session/dedicated_sessions.py index 94260ac08..e29335ddb 100644 --- a/graphdatascience/session/dedicated_sessions.py +++ b/graphdatascience/session/dedicated_sessions.py @@ -52,6 +52,7 @@ def get_or_create( # TODO configure session size (and check existing_session has same size) if existing_session: self._check_expiry_date(existing_session) + self._check_memory_configuration(existing_session, memory) session_id = existing_session.id else: create_details = self._create_session(session_name, dbid, db_connection.uri, password, memory) @@ -118,7 +119,7 @@ def _create_session( name=session_name, dbid=dbid, pwd=pwd, - memory=memory.value, + memory=memory, ) return create_details @@ -139,6 +140,13 @@ def _check_expiry_date(self, session: SessionDetails) -> None: if until_expiry < timedelta(days=1): raise Warning(f"Session `{session.name}` is expiring in less than a day.") + def _check_memory_configuration(self, existing_session: SessionDetails, requested_memory: SessionMemory) -> None: + if existing_session.memory != requested_memory: + raise RuntimeError( + f"Session `{existing_session.name}` exists with a different memory configuration. " + f"Current: {existing_session.memory.value}, Requested: {requested_memory.value}." + ) + @classmethod def _fail_ambiguous_session(cls, session_name: str, sessions: List[SessionDetails]) -> None: candidates = [i.id for i in sessions] From c2c4e484355f717456088a174d41a51b54b15eb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Thu, 13 Jun 2024 17:29:33 +0200 Subject: [PATCH 3/4] Use a data class instead of an enum to represent session memory * We will keep the enum to make the selection easier for users --- graphdatascience/session/aura_api.py | 6 +- .../session/aura_api_responses.py | 10 +-- graphdatascience/session/aurads_sessions.py | 10 +-- .../session/dedicated_sessions.py | 16 +++-- graphdatascience/session/session_info.py | 4 +- graphdatascience/session/session_sizes.py | 69 +++++++++++-------- graphdatascience/tests/unit/test_aura_api.py | 20 +++--- .../tests/unit/test_aurads_sessions.py | 45 ++++++------ .../tests/unit/test_dedicated_sessions.py | 42 ++++++----- .../tests/unit/test_session_sizes.py | 2 +- 10 files changed, 118 insertions(+), 106 deletions(-) diff --git a/graphdatascience/session/aura_api.py b/graphdatascience/session/aura_api.py index cb621ed13..3cd4d449d 100644 --- a/graphdatascience/session/aura_api.py +++ b/graphdatascience/session/aura_api.py @@ -9,7 +9,6 @@ import requests as req from requests import HTTPError -from graphdatascience.session.session_sizes import SessionMemory from graphdatascience.session.algorithm_category import AlgorithmCategory from graphdatascience.session.aura_api_responses import ( EstimationDetails, @@ -20,6 +19,7 @@ TenantDetails, WaitResult, ) +from graphdatascience.session.session_sizes import SessionMemoryValue from graphdatascience.version import __version__ @@ -63,7 +63,7 @@ def extract_id(uri: str) -> str: return host.split(".")[0].split("-")[0] - def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemory) -> SessionDetails: + def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemoryValue) -> SessionDetails: response = req.post( f"{self._base_uri}/v1beta5/data-science/sessions", headers=self._build_header(), @@ -143,7 +143,7 @@ def delete_session(self, session_id: str, dbid: str) -> bool: return False def create_instance( - self, name: str, memory: SessionMemory, cloud_provider: str, region: str + self, name: str, memory: SessionMemoryValue, cloud_provider: str, region: str ) -> InstanceCreateDetails: tenant_details = self.tenant_details() diff --git a/graphdatascience/session/aura_api_responses.py b/graphdatascience/session/aura_api_responses.py index 77bc785f0..99c1c1156 100644 --- a/graphdatascience/session/aura_api_responses.py +++ b/graphdatascience/session/aura_api_responses.py @@ -9,7 +9,7 @@ from pandas import Timedelta -from .session_sizes import SessionMemory +from .session_sizes import SessionMemoryValue @dataclass(repr=True, frozen=True) @@ -17,7 +17,7 @@ class SessionDetails: id: str name: str instance_id: str - memory: SessionMemory + memory: SessionMemoryValue status: str host: str created_at: datetime @@ -33,7 +33,7 @@ def fromJson(cls, json: Dict[str, Any]) -> SessionDetails: id=json["id"], name=json["name"], instance_id=json["instance_id"], - memory=SessionMemory.fromApiResponse(json["memory"]), + memory=SessionMemoryValue.fromApiResponse(json["memory"]), status=json["status"], host=json["host"], expiry_date=TimeParser.fromisoformat(expiry_date) if expiry_date else None, @@ -69,7 +69,7 @@ def fromJson(cls, json: Dict[str, Any]) -> InstanceDetails: class InstanceSpecificDetails(InstanceDetails): status: str connection_url: str - memory: SessionMemory + memory: SessionMemoryValue type: str region: str @@ -82,7 +82,7 @@ def fromJson(cls, json: Dict[str, Any]) -> InstanceSpecificDetails: cloud_provider=json["cloud_provider"], status=json["status"], connection_url=json.get("connection_url", ""), - memory=SessionMemory.fromApiResponse(json.get("memory", "")), + memory=SessionMemoryValue.fromApiResponse(json.get("memory", "")), type=json["type"], region=json["region"], ) diff --git a/graphdatascience/session/aurads_sessions.py b/graphdatascience/session/aurads_sessions.py index fa3b79166..d98e9f020 100644 --- a/graphdatascience/session/aurads_sessions.py +++ b/graphdatascience/session/aurads_sessions.py @@ -16,7 +16,7 @@ from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo from graphdatascience.session.region_suggester import closest_match from graphdatascience.session.session_info import SessionInfo -from graphdatascience.session.session_sizes import SessionMemory +from graphdatascience.session.session_sizes import SessionMemory, SessionMemoryValue class AuraDsSessions: @@ -41,7 +41,7 @@ def estimate( ResourceWarning, ) - return SessionMemory(estimation.recommended_size) + return SessionMemory(SessionMemoryValue(estimation.recommended_size)) def get_or_create( self, @@ -54,13 +54,13 @@ def get_or_create( if existing_session: session_id = existing_session.id # 0MB is AuraAPI default value for memory if none can be retrieved - if existing_session.memory != memory: + if existing_session.memory.value != "0MB" and existing_session.memory != memory.value: raise ValueError( f"Session `{session_name}` already exists with memory `{existing_session.memory.value}`. " f"Requested memory `{memory.value}` does not match." ) else: - create_details = self._create_session(session_name, memory, db_connection) + create_details = self._create_session(session_name, memory.value, db_connection) session_id = create_details.id wait_result = self._aura_api.wait_for_instance_running(session_id) @@ -118,7 +118,7 @@ def _find_existing_session(self, session_name: str) -> Optional[InstanceSpecific return self._aura_api.list_instance(matched_instances[0].id) def _create_session( - self, session_name: str, memory: SessionMemory, db_connection: DbmsConnectionInfo + self, session_name: str, memory: SessionMemoryValue, db_connection: DbmsConnectionInfo ) -> InstanceCreateDetails: db_instance_id = AuraApi.extract_id(db_connection.uri) db_instance = self._aura_api.list_instance(db_instance_id) diff --git a/graphdatascience/session/dedicated_sessions.py b/graphdatascience/session/dedicated_sessions.py index e29335ddb..cff902249 100644 --- a/graphdatascience/session/dedicated_sessions.py +++ b/graphdatascience/session/dedicated_sessions.py @@ -11,7 +11,7 @@ from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo from graphdatascience.session.session_info import SessionInfo -from graphdatascience.session.session_sizes import SessionMemory +from graphdatascience.session.session_sizes import SessionMemory, SessionMemoryValue class DedicatedSessions: @@ -35,7 +35,7 @@ def estimate( ResourceWarning, ) - return SessionMemory(estimation.recommended_size) + return SessionMemory(SessionMemoryValue(estimation.recommended_size)) def get_or_create( self, @@ -52,10 +52,10 @@ def get_or_create( # TODO configure session size (and check existing_session has same size) if existing_session: self._check_expiry_date(existing_session) - self._check_memory_configuration(existing_session, memory) + self._check_memory_configuration(existing_session, memory.value) session_id = existing_session.id else: - create_details = self._create_session(session_name, dbid, db_connection.uri, password, memory) + create_details = self._create_session(session_name, dbid, db_connection.uri, password, memory.value) session_id = create_details.id wait_result = self._aura_api.wait_for_session_running(session_id, dbid) @@ -109,7 +109,7 @@ def _find_existing_session(self, session_name: str, dbid: str) -> Optional[Sessi return matched_sessions[0] def _create_session( - self, session_name: str, dbid: str, dburi: str, pwd: str, memory: SessionMemory + self, session_name: str, dbid: str, dburi: str, pwd: str, memory: SessionMemoryValue ) -> SessionDetails: db_instance = self._aura_api.list_instance(dbid) if not db_instance: @@ -140,11 +140,13 @@ def _check_expiry_date(self, session: SessionDetails) -> None: if until_expiry < timedelta(days=1): raise Warning(f"Session `{session.name}` is expiring in less than a day.") - def _check_memory_configuration(self, existing_session: SessionDetails, requested_memory: SessionMemory) -> None: + def _check_memory_configuration( + self, existing_session: SessionDetails, requested_memory: SessionMemoryValue + ) -> None: if existing_session.memory != requested_memory: raise RuntimeError( f"Session `{existing_session.name}` exists with a different memory configuration. " - f"Current: {existing_session.memory.value}, Requested: {requested_memory.value}." + f"Current: {existing_session.memory}, Requested: {requested_memory}." ) @classmethod diff --git a/graphdatascience/session/session_info.py b/graphdatascience/session/session_info.py index 23fae9165..072cb0dab 100644 --- a/graphdatascience/session/session_info.py +++ b/graphdatascience/session/session_info.py @@ -4,8 +4,8 @@ from datetime import datetime from typing import Optional -from graphdatascience.session.session_sizes import SessionMemory from graphdatascience.session.aura_api_responses import SessionDetails +from graphdatascience.session.session_sizes import SessionMemoryValue @dataclass(frozen=True) @@ -19,7 +19,7 @@ class SessionInfo: """ name: str - memory: SessionMemory + memory: SessionMemoryValue @classmethod def from_session_details(cls, details: SessionDetails) -> ExtendedSessionInfo: diff --git a/graphdatascience/session/session_sizes.py b/graphdatascience/session/session_sizes.py index f372873f3..c7bb1a4ca 100644 --- a/graphdatascience/session/session_sizes.py +++ b/graphdatascience/session/session_sizes.py @@ -1,27 +1,53 @@ +from dataclasses import dataclass from enum import Enum from typing import List +@dataclass(frozen=True) +class SessionMemoryValue: + value: str + + def __str__(self) -> str: + return self.value + + @staticmethod + def fromApiResponse(value: str) -> "SessionMemoryValue": + """ + Converts the string value from an API response to a SessionMemory enumeration value. + + Args: + value: The string value from the API response. + + Returns: + The SessionMemory enumeration value. + + """ + if value == "": + raise ValueError("memory configuration cannot be empty") + + return SessionMemoryValue(value.replace("Gi", "GB")) + + class SessionMemory(Enum): """ Enumeration representing session main memory configurations. """ - m_4GB = "4GB" - m_8GB = "8GB" - m_16GB = "16GB" - m_24GB = "24GB" - m_32GB = "32GB" - m_48GB = "48GB" - m_64GB = "64GB" - m_96GB = "96GB" - m_128GB = "128GB" - m_192GB = "192GB" - m_256GB = "256GB" - m_384GB = "384GB" + m_4GB = SessionMemoryValue("4GB") + m_8GB = SessionMemoryValue("8GB") + m_16GB = SessionMemoryValue("16GB") + m_24GB = SessionMemoryValue("24GB") + m_32GB = SessionMemoryValue("32GB") + m_48GB = SessionMemoryValue("48GB") + m_64GB = SessionMemoryValue("64GB") + m_96GB = SessionMemoryValue("96GB") + m_128GB = SessionMemoryValue("128GB") + m_192GB = SessionMemoryValue("192GB") + m_256GB = SessionMemoryValue("256GB") + m_384GB = SessionMemoryValue("384GB") @classmethod - def all_values(cls) -> List[str]: + def all_values(cls) -> List[SessionMemoryValue]: """ All supported memory configurations. @@ -30,20 +56,3 @@ def all_values(cls) -> List[str]: """ return [e.value for e in cls] - - @staticmethod - def fromApiResponse(value: str) -> "SessionMemory": - """ - Converts the string value from an API response to a SessionMemory enumeration value. - - Args: - value: The string value from the API response. - - Returns: - The SessionMemory enumeration value. - - """ - try: - return SessionMemory(value.replace("Gi", "GB")) - except ValueError: - raise ValueError(f"Unsupported memory configuration: {value}") diff --git a/graphdatascience/tests/unit/test_aura_api.py b/graphdatascience/tests/unit/test_aura_api.py index 87383117c..a20e3294a 100644 --- a/graphdatascience/tests/unit/test_aura_api.py +++ b/graphdatascience/tests/unit/test_aura_api.py @@ -38,7 +38,7 @@ def test_create_session(requests_mock: Mocker) -> None: }, ) - result = api.create_session("name-0", "dbid-1", "pwd-2", SessionMemory.m_4GB) + result = api.create_session("name-0", "dbid-1", "pwd-2", SessionMemory.m_4GB.value) assert result == SessionDetails( id="id0", @@ -47,7 +47,7 @@ def test_create_session(requests_mock: Mocker) -> None: instance_id="dbid-1", created_at=TimeParser.fromisoformat("1970-01-01T00:00:00Z"), host="1.2.3.4", - memory=SessionMemory.m_4GB, + memory=SessionMemory.m_4GB.value, expiry_date=None, ttl=None, ) @@ -82,7 +82,7 @@ def test_list_session(requests_mock: Mocker) -> None: instance_id="dbid-1", created_at=TimeParser.fromisoformat("1970-01-01T00:00:00Z"), host="1.2.3.4", - memory=SessionMemory.m_4GB, + memory=SessionMemory.m_4GB.value, expiry_date=TimeParser.fromisoformat("1977-01-01T00:00:00Z"), ttl=None, ) @@ -126,7 +126,7 @@ def test_list_sessions(requests_mock: Mocker) -> None: instance_id="dbid-1", created_at=TimeParser.fromisoformat("1970-01-01T00:00:00Z"), host="1.2.3.4", - memory=SessionMemory.m_4GB, + memory=SessionMemory.m_4GB.value, expiry_date=TimeParser.fromisoformat("1977-01-01T00:00:00Z"), ttl=None, ) @@ -137,7 +137,7 @@ def test_list_sessions(requests_mock: Mocker) -> None: status="Creating", instance_id="dbid-3", created_at=TimeParser.fromisoformat("2012-01-01T00:00:00Z"), - memory=SessionMemory.m_8GB, + memory=SessionMemory.m_8GB.value, host="foo.bar", expiry_date=None, ttl=None, @@ -270,7 +270,7 @@ def test_delete_instance(requests_mock: Mocker) -> None: cloud_provider="", status="deleting", connection_url="", - memory=SessionMemory.m_4GB, + memory=SessionMemory.m_4GB.value, region="", type="", ) @@ -337,7 +337,7 @@ def test_create_instance(requests_mock: Mocker) -> None: }, ) - api.create_instance("name", SessionMemory.m_16GB, "gcp", "leipzig-1") + api.create_instance("name", SessionMemory.m_16GB.value, "gcp", "leipzig-1") requested_data = requests_mock.request_history[-1].json() assert requested_data["name"] == "name" @@ -446,7 +446,7 @@ def test_list_instance_missing_memory_field(requests_mock: Mocker) -> None: result = api.list_instance("id0") assert result and result.id == "a10fb995" - assert result.memory == SessionMemory.m_16GB + assert result.memory == SessionMemory.m_16GB.value def test_list_missing_instance(requests_mock: Mocker) -> None: @@ -653,7 +653,7 @@ def test_parse_session_info() -> None: assert session_info == SessionDetails( id="test_id", name="test_session", - memory=SessionMemory.m_4GB, + memory=SessionMemory.m_4GB.value, instance_id="test_instance", status="running", host="a.b", @@ -678,7 +678,7 @@ def test_parse_session_info_without_optionals() -> None: assert session_info == SessionDetails( id="test_id", name="test_session", - memory=SessionMemory.m_16GB, + memory=SessionMemory.m_16GB.value, instance_id="test_instance", host="a.b", status="running", diff --git a/graphdatascience/tests/unit/test_aurads_sessions.py b/graphdatascience/tests/unit/test_aurads_sessions.py index 487a6dbcd..40948c8ae 100644 --- a/graphdatascience/tests/unit/test_aurads_sessions.py +++ b/graphdatascience/tests/unit/test_aurads_sessions.py @@ -18,7 +18,7 @@ from graphdatascience.session.aurads_sessions import AuraDsSessions, SessionNameHelper from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo from graphdatascience.session.session_info import SessionInfo -from graphdatascience.session.session_sizes import SessionMemory +from graphdatascience.session.session_sizes import SessionMemory, SessionMemoryValue class FakeAuraApi(AuraApi): @@ -38,7 +38,7 @@ def __init__( self._size_estimation = size_estimation or EstimationDetails("1GB", "8GB", False) def create_instance( - self, name: str, memory: SessionMemory, cloud_provider: str, region: str + self, name: str, memory: SessionMemoryValue, cloud_provider: str, region: str ) -> InstanceCreateDetails: create_details = InstanceCreateDetails( id=f"ffff{self.id_counter}", @@ -105,12 +105,12 @@ def aura_api() -> AuraApi: def test_list_session(aura_api: AuraApi) -> None: aura_api.create_instance( - name="gds-session-my-session-name", cloud_provider="gcp", memory=SessionMemory.m_16GB, region="" + name="gds-session-my-session-name", cloud_provider="gcp", memory=SessionMemory.m_16GB.value, region="" ) sessions = AuraDsSessions(aura_api) - assert sessions.list() == [SessionInfo("my-session-name", SessionMemory.m_16GB)] + assert sessions.list() == [SessionInfo("my-session-name", SessionMemory.m_16GB.value)] def test_create_session(mocker: MockerFixture, aura_api: AuraApi) -> None: @@ -139,12 +139,12 @@ def assert_db_credentials(*args: List[Any], **kwargs: str) -> None: "gds_url": "fake-url", "session_name": "my-session", } - assert sessions.list() == [SessionInfo("my-session", SessionMemory.m_384GB)] + assert sessions.list() == [SessionInfo("my-session", SessionMemory.m_384GB.value)] instance_details: InstanceSpecificDetails = aura_api.list_instance("ffff1") # type: ignore assert instance_details.cloud_provider == "aws" assert instance_details.region == "leipzig-1" - assert instance_details.memory == SessionMemory.m_384GB + assert instance_details.memory == SessionMemory.m_384GB.value def test_create_default_session(mocker: MockerFixture, aura_api: AuraApi) -> None: @@ -167,11 +167,11 @@ def test_create_default_session(mocker: MockerFixture, aura_api: AuraApi) -> Non instance_details: InstanceSpecificDetails = aura_api.list_instance("ffff1") # type: ignore assert instance_details.cloud_provider == "aws" assert instance_details.region == "leipzig-1" - assert instance_details.memory == SessionMemory.m_8GB + assert instance_details.memory == SessionMemory.m_8GB.value def test_create_session_override_region(mocker: MockerFixture, aura_api: AuraApi) -> None: - aura_api.create_instance("test", SessionMemory.m_8GB, "aws", "dresden-2") + aura_api.create_instance("test", SessionMemory.m_8GB.value, "aws", "dresden-2") sessions = AuraDsSessions(aura_api) @@ -190,7 +190,7 @@ def test_create_session_override_region(mocker: MockerFixture, aura_api: AuraApi instance_details: InstanceSpecificDetails = aura_api.list_instance("ffff1") # type: ignore assert instance_details.cloud_provider == "aws" assert instance_details.region == "leipzig-1" - assert instance_details.memory == SessionMemory.m_8GB + assert instance_details.memory == SessionMemory.m_8GB.value def test_get_or_create(mocker: MockerFixture, aura_api: AuraApi) -> None: @@ -225,7 +225,7 @@ def test_get_or_create(mocker: MockerFixture, aura_api: AuraApi) -> None: } assert gds_args1 == gds_args2 - assert sessions.list() == [SessionInfo("my-session", SessionMemory.m_8GB)] + assert sessions.list() == [SessionInfo("my-session", SessionMemory.m_8GB.value)] def test_get_or_create_different_size(mocker: MockerFixture, aura_api: AuraApi) -> None: @@ -269,7 +269,7 @@ def test_get_or_create_duplicate_session() -> None: cloud_provider="", status="RUNNING", connection_url="", - memory=SessionMemory.m_8GB, + memory=SessionMemory.m_8GB.value, type="", region="", ), @@ -280,7 +280,7 @@ def test_get_or_create_duplicate_session() -> None: cloud_provider="", status="RUNNING", connection_url="", - memory=SessionMemory.m_8GB, + memory=SessionMemory.m_8GB.value, type="", region="", ), @@ -302,7 +302,7 @@ def test_delete_session() -> None: cloud_provider="cloud_provider", status="RUNNING", connection_url="", - memory=SessionMemory.m_8GB, + memory=SessionMemory.m_8GB.value, type="", region="", ), @@ -313,7 +313,7 @@ def test_delete_session() -> None: cloud_provider="cloud_provider", status="RUNNING", connection_url="", - memory=SessionMemory.m_8GB, + memory=SessionMemory.m_8GB.value, type="", region="", ), @@ -322,7 +322,7 @@ def test_delete_session() -> None: sessions = AuraDsSessions(FakeAuraApi(existing_instances=existing_instances)) assert sessions.delete("one") - assert sessions.list() == [SessionInfo("other", SessionMemory.m_8GB)] + assert sessions.list() == [SessionInfo("other", SessionMemory.m_8GB.value)] def test_delete_nonexisting_session() -> None: @@ -334,7 +334,7 @@ def test_delete_nonexisting_session() -> None: cloud_provider="cloud_provider", status="RUNNING", connection_url="", - memory=SessionMemory.m_8GB, + memory=SessionMemory.m_8GB.value, type="", region="", ), @@ -343,7 +343,7 @@ def test_delete_nonexisting_session() -> None: sessions = AuraDsSessions(FakeAuraApi(existing_instances=existing_instances)) assert sessions.delete("other") is False - assert sessions.list() == [SessionInfo("one", SessionMemory.m_8GB)] + assert sessions.list() == [SessionInfo("one", SessionMemory.m_8GB.value)] def test_delete_nonunique_session() -> None: @@ -355,7 +355,7 @@ def test_delete_nonunique_session() -> None: cloud_provider="", status="RUNNING", connection_url="", - memory=SessionMemory.m_8GB, + memory=SessionMemory.m_8GB.value, type="", region="", ), @@ -366,7 +366,7 @@ def test_delete_nonunique_session() -> None: cloud_provider="", status="RUNNING", connection_url="", - memory=SessionMemory.m_8GB, + memory=SessionMemory.m_8GB.value, type="", region="", ), @@ -382,7 +382,10 @@ def test_delete_nonunique_session() -> None: ): sessions.delete("one") - assert sessions.list() == [SessionInfo("one", SessionMemory.m_8GB), SessionInfo("one", SessionMemory.m_8GB)] + assert sessions.list() == [ + SessionInfo("one", SessionMemory.m_8GB.value), + SessionInfo("one", SessionMemory.m_8GB.value), + ] def test_create_immediate_delete() -> None: @@ -439,4 +442,4 @@ def test_estimate_size_exceeds() -> None: def _setup_db_instance(aura_api: AuraApi) -> None: - aura_api.create_instance("test", SessionMemory.m_8GB, "aws", "leipzig-1") + aura_api.create_instance("test", SessionMemory.m_8GB.value, "aws", "leipzig-1") diff --git a/graphdatascience/tests/unit/test_dedicated_sessions.py b/graphdatascience/tests/unit/test_dedicated_sessions.py index 4ff5e6af5..7b54430a1 100644 --- a/graphdatascience/tests/unit/test_dedicated_sessions.py +++ b/graphdatascience/tests/unit/test_dedicated_sessions.py @@ -20,7 +20,7 @@ from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo from graphdatascience.session.dedicated_sessions import DedicatedSessions from graphdatascience.session.session_info import SessionInfo -from graphdatascience.session.session_sizes import SessionMemory +from graphdatascience.session.session_sizes import SessionMemory, SessionMemoryValue class FakeAuraApi(AuraApi): @@ -43,7 +43,7 @@ def __init__( self._status_after_creating = status_after_creating self._size_estimation = size_estimation or EstimationDetails("1GB", "8GB", False) - def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemory) -> SessionDetails: + def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemoryValue) -> SessionDetails: details = SessionDetails( id=f"{dbid}-ffff{self.id_counter}", name=name, @@ -68,7 +68,7 @@ def add_session(self, session: SessionDetails) -> None: self._sessions[session.id] = session def create_instance( - self, name: str, memory: SessionMemory, cloud_provider: str, region: str + self, name: str, memory: SessionMemoryValue, cloud_provider: str, region: str ) -> InstanceCreateDetails: id = f"ffff{self.id_counter}" create_details = InstanceCreateDetails( @@ -173,7 +173,7 @@ def test_list_session(aura_api: AuraApi) -> None: name="gds-session-my-session-name", dbid=aura_api.list_instances()[0].id, pwd="some_pwd", - memory=SessionMemory.m_8GB, + memory=SessionMemory.m_8GB.value, ) sessions = DedicatedSessions(aura_api) @@ -239,8 +239,8 @@ def test_get_or_create(mocker: MockerFixture, aura_api: AuraApi) -> None: def test_get_or_create_duplicate_session(aura_api: AuraApi) -> None: db = _setup_db_instance(aura_api) - aura_api.create_session("one", db.id, "1234", memory=SessionMemory.m_4GB) - aura_api.create_session("one", db.id, "12345", memory=SessionMemory.m_4GB) + aura_api.create_session("one", db.id, "1234", memory=SessionMemory.m_4GB.value) + aura_api.create_session("one", db.id, "12345", memory=SessionMemory.m_4GB.value) sessions = DedicatedSessions(aura_api) @@ -257,7 +257,7 @@ def test_get_or_create_expired_session(aura_api: AuraApi) -> None: id="ffff0-ffff1", name="one", instance_id=db.id, - memory=SessionMemory.m_8GB, + memory=SessionMemory.m_8GB.value, status="Expired", created_at=datetime.now(), host="foo.bar", @@ -282,7 +282,7 @@ def test_get_or_create_soon_expired_session(aura_api: AuraApi) -> None: id="ffff0-ffff1", name="one", instance_id=db.id, - memory=SessionMemory.m_8GB, + memory=SessionMemory.m_8GB.value, status="Ready", created_at=datetime.now(), host="foo.bar", @@ -305,7 +305,7 @@ def test_get_or_create_with_different_memory_config(aura_api: AuraApi) -> None: id="ffff0-ffff1", name="one", instance_id=db.id, - memory=SessionMemory.m_8GB, + memory=SessionMemory.m_8GB.value, status="Ready", created_at=datetime.now(), host="foo.bar", @@ -316,19 +316,17 @@ def test_get_or_create_with_different_memory_config(aura_api: AuraApi) -> None: with pytest.raises( RuntimeError, - match=re.escape( - "Session `one` exists with a different memory configuration. Current: 8GB, Requested: 16GB." - ), + match=re.escape("Session `one` exists with a different memory configuration. Current: 8GB, Requested: 16GB."), ): sessions = DedicatedSessions(aura_api) sessions.get_or_create("one", SessionMemory.m_16GB, DbmsConnectionInfo(db.connection_url, "", "")) def test_delete_session(aura_api: AuraApi) -> None: - db1 = aura_api.create_instance("db1", SessionMemory.m_4GB, "aura", "leipzig").id - db2 = aura_api.create_instance("db2", SessionMemory.m_4GB, "aura", "dresden").id - aura_api.create_session("one", db1, "12345", memory=SessionMemory.m_8GB) - aura_api.create_session("other", db2, "123123", memory=SessionMemory.m_8GB) + db1 = aura_api.create_instance("db1", SessionMemory.m_4GB.value, "aura", "leipzig").id + db2 = aura_api.create_instance("db2", SessionMemory.m_4GB.value, "aura", "dresden").id + aura_api.create_session("one", db1, "12345", memory=SessionMemory.m_8GB.value) + aura_api.create_session("other", db2, "123123", memory=SessionMemory.m_8GB.value) sessions = DedicatedSessions(aura_api) @@ -337,8 +335,8 @@ def test_delete_session(aura_api: AuraApi) -> None: def test_delete_nonexisting_session(aura_api: AuraApi) -> None: - db1 = aura_api.create_instance("db1", SessionMemory.m_4GB, "aura", "leipzig").id - aura_api.create_session("one", db1, "12345", memory=SessionMemory.m_8GB) + db1 = aura_api.create_instance("db1", SessionMemory.m_4GB.value, "aura", "leipzig").id + aura_api.create_session("one", db1, "12345", memory=SessionMemory.m_8GB.value) sessions = DedicatedSessions(aura_api) assert sessions.delete("other") is False @@ -346,9 +344,9 @@ def test_delete_nonexisting_session(aura_api: AuraApi) -> None: def test_delete_nonunique_session(aura_api: AuraApi) -> None: - db1 = aura_api.create_instance("db1", SessionMemory.m_4GB, "aura", "leipzig").id - aura_api.create_session("one", db1, "12345", memory=SessionMemory.m_8GB) - aura_api.create_session("one", db1, "12345", memory=SessionMemory.m_8GB) + db1 = aura_api.create_instance("db1", SessionMemory.m_4GB.value, "aura", "leipzig").id + aura_api.create_session("one", db1, "12345", memory=SessionMemory.m_8GB.value) + aura_api.create_session("one", db1, "12345", memory=SessionMemory.m_8GB.value) sessions = DedicatedSessions(aura_api) with pytest.raises( @@ -394,7 +392,7 @@ def test_estimate_size_exceeds() -> None: def _setup_db_instance(aura_api: AuraApi) -> InstanceCreateDetails: - return aura_api.create_instance("test", SessionMemory.m_8GB, "aws", "leipzig-1") + return aura_api.create_instance("test", SessionMemory.m_8GB.value, "aws", "leipzig-1") def patch_construct_client(mocker: MockerFixture) -> None: diff --git a/graphdatascience/tests/unit/test_session_sizes.py b/graphdatascience/tests/unit/test_session_sizes.py index 416fa8e52..b736a1ea5 100644 --- a/graphdatascience/tests/unit/test_session_sizes.py +++ b/graphdatascience/tests/unit/test_session_sizes.py @@ -4,4 +4,4 @@ def test_all_values() -> None: assert len(SessionMemory.all_values()) == 12 for e in SessionMemory.all_values(): - assert e.endswith("GB") + assert e.value.endswith("GB") From 1e0e3c887a91c7f6f8aa27a637627b6da4b97be5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florentin=20D=C3=B6rre?= Date: Tue, 18 Jun 2024 10:33:58 +0200 Subject: [PATCH 4/4] Expose public facing type --- graphdatascience/session/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/graphdatascience/session/__init__.py b/graphdatascience/session/__init__.py index 110aa49bc..7cf3651c3 100644 --- a/graphdatascience/session/__init__.py +++ b/graphdatascience/session/__init__.py @@ -2,7 +2,7 @@ from .dbms_connection_info import DbmsConnectionInfo from .gds_sessions import AuraAPICredentials, GdsSessions from .session_info import SessionInfo -from .session_sizes import SessionMemory +from .session_sizes import SessionMemory, SessionMemoryValue __all__ = [ "GdsSessions", @@ -10,5 +10,6 @@ "DbmsConnectionInfo", "AuraAPICredentials", "SessionMemory", + "SessionMemoryValue", "AlgorithmCategory", ]