diff --git a/graphdatascience/session/__init__.py b/graphdatascience/session/__init__.py index 110aa49bc..7cf3651c3 100644 --- a/graphdatascience/session/__init__.py +++ b/graphdatascience/session/__init__.py @@ -2,7 +2,7 @@ from .dbms_connection_info import DbmsConnectionInfo from .gds_sessions import AuraAPICredentials, GdsSessions from .session_info import SessionInfo -from .session_sizes import SessionMemory +from .session_sizes import SessionMemory, SessionMemoryValue __all__ = [ "GdsSessions", @@ -10,5 +10,6 @@ "DbmsConnectionInfo", "AuraAPICredentials", "SessionMemory", + "SessionMemoryValue", "AlgorithmCategory", ] diff --git a/graphdatascience/session/aura_api.py b/graphdatascience/session/aura_api.py index 4ce34d18e..3cd4d449d 100644 --- a/graphdatascience/session/aura_api.py +++ b/graphdatascience/session/aura_api.py @@ -19,6 +19,7 @@ TenantDetails, WaitResult, ) +from graphdatascience.session.session_sizes import SessionMemoryValue from graphdatascience.version import __version__ @@ -62,11 +63,11 @@ def extract_id(uri: str) -> str: return host.split(".")[0].split("-")[0] - def create_session(self, name: str, dbid: str, pwd: str, memory: str) -> SessionDetails: + def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemoryValue) -> SessionDetails: response = req.post( f"{self._base_uri}/v1beta5/data-science/sessions", headers=self._build_header(), - json={"name": name, "instance_id": dbid, "password": pwd, "memory": memory}, + json={"name": name, "instance_id": dbid, "password": pwd, "memory": memory.value}, ) response.raise_for_status() @@ -141,12 +142,14 @@ def delete_session(self, session_id: str, dbid: str) -> bool: return False - def create_instance(self, name: str, memory: str, cloud_provider: str, region: str) -> InstanceCreateDetails: + def create_instance( + self, name: str, memory: SessionMemoryValue, cloud_provider: str, region: str + ) -> InstanceCreateDetails: tenant_details = self.tenant_details() data = { "name": name, - "memory": memory, + "memory": memory.value, "version": "5", "region": region, "type": tenant_details.ds_type, diff --git a/graphdatascience/session/aura_api_responses.py b/graphdatascience/session/aura_api_responses.py index cd533a192..99c1c1156 100644 --- a/graphdatascience/session/aura_api_responses.py +++ b/graphdatascience/session/aura_api_responses.py @@ -9,13 +9,15 @@ from pandas import Timedelta +from .session_sizes import SessionMemoryValue + @dataclass(repr=True, frozen=True) class SessionDetails: id: str name: str instance_id: str - memory: str + memory: SessionMemoryValue status: str host: str created_at: datetime @@ -31,7 +33,7 @@ def fromJson(cls, json: Dict[str, Any]) -> SessionDetails: id=json["id"], name=json["name"], instance_id=json["instance_id"], - memory=json["memory"], + memory=SessionMemoryValue.fromApiResponse(json["memory"]), status=json["status"], host=json["host"], expiry_date=TimeParser.fromisoformat(expiry_date) if expiry_date else None, @@ -67,7 +69,7 @@ def fromJson(cls, json: Dict[str, Any]) -> InstanceDetails: class InstanceSpecificDetails(InstanceDetails): status: str connection_url: str - memory: str + memory: SessionMemoryValue type: str region: str @@ -80,7 +82,7 @@ def fromJson(cls, json: Dict[str, Any]) -> InstanceSpecificDetails: cloud_provider=json["cloud_provider"], status=json["status"], connection_url=json.get("connection_url", ""), - memory=json.get("memory", ""), + memory=SessionMemoryValue.fromApiResponse(json.get("memory", "")), type=json["type"], region=json["region"], ) diff --git a/graphdatascience/session/aurads_sessions.py b/graphdatascience/session/aurads_sessions.py index c54ec34a8..d98e9f020 100644 --- a/graphdatascience/session/aurads_sessions.py +++ b/graphdatascience/session/aurads_sessions.py @@ -16,7 +16,7 @@ from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo from graphdatascience.session.region_suggester import closest_match from graphdatascience.session.session_info import SessionInfo -from graphdatascience.session.session_sizes import SessionMemory +from graphdatascience.session.session_sizes import SessionMemory, SessionMemoryValue class AuraDsSessions: @@ -41,7 +41,7 @@ def estimate( ResourceWarning, ) - return SessionMemory(estimation.recommended_size) + return SessionMemory(SessionMemoryValue(estimation.recommended_size)) def get_or_create( self, @@ -54,13 +54,13 @@ def get_or_create( if existing_session: session_id = existing_session.id # 0MB is AuraAPI default value for memory if none can be retrieved - if existing_session.memory != "0MB" and existing_session.memory != memory.value: + if existing_session.memory.value != "0MB" and existing_session.memory != memory.value: raise ValueError( - f"Session `{session_name}` already exists with memory `{existing_session.memory}`. " + f"Session `{session_name}` already exists with memory `{existing_session.memory.value}`. " f"Requested memory `{memory.value}` does not match." ) else: - create_details = self._create_session(session_name, memory, db_connection) + create_details = self._create_session(session_name, memory.value, db_connection) session_id = create_details.id wait_result = self._aura_api.wait_for_instance_running(session_id) @@ -118,7 +118,7 @@ def _find_existing_session(self, session_name: str) -> Optional[InstanceSpecific return self._aura_api.list_instance(matched_instances[0].id) def _create_session( - self, session_name: str, memory: SessionMemory, db_connection: DbmsConnectionInfo + self, session_name: str, memory: SessionMemoryValue, db_connection: DbmsConnectionInfo ) -> InstanceCreateDetails: db_instance_id = AuraApi.extract_id(db_connection.uri) db_instance = self._aura_api.list_instance(db_instance_id) @@ -128,7 +128,7 @@ def _create_session( region = self._ds_region(db_instance.region, db_instance.cloud_provider) create_details = self._aura_api.create_instance( - SessionNameHelper.instance_name(session_name), memory.value, db_instance.cloud_provider, region + SessionNameHelper.instance_name(session_name), memory, db_instance.cloud_provider, region ) return create_details diff --git a/graphdatascience/session/dedicated_sessions.py b/graphdatascience/session/dedicated_sessions.py index 94260ac08..cff902249 100644 --- a/graphdatascience/session/dedicated_sessions.py +++ b/graphdatascience/session/dedicated_sessions.py @@ -11,7 +11,7 @@ from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo from graphdatascience.session.session_info import SessionInfo -from graphdatascience.session.session_sizes import SessionMemory +from graphdatascience.session.session_sizes import SessionMemory, SessionMemoryValue class DedicatedSessions: @@ -35,7 +35,7 @@ def estimate( ResourceWarning, ) - return SessionMemory(estimation.recommended_size) + return SessionMemory(SessionMemoryValue(estimation.recommended_size)) def get_or_create( self, @@ -52,9 +52,10 @@ def get_or_create( # TODO configure session size (and check existing_session has same size) if existing_session: self._check_expiry_date(existing_session) + self._check_memory_configuration(existing_session, memory.value) session_id = existing_session.id else: - create_details = self._create_session(session_name, dbid, db_connection.uri, password, memory) + create_details = self._create_session(session_name, dbid, db_connection.uri, password, memory.value) session_id = create_details.id wait_result = self._aura_api.wait_for_session_running(session_id, dbid) @@ -108,7 +109,7 @@ def _find_existing_session(self, session_name: str, dbid: str) -> Optional[Sessi return matched_sessions[0] def _create_session( - self, session_name: str, dbid: str, dburi: str, pwd: str, memory: SessionMemory + self, session_name: str, dbid: str, dburi: str, pwd: str, memory: SessionMemoryValue ) -> SessionDetails: db_instance = self._aura_api.list_instance(dbid) if not db_instance: @@ -118,7 +119,7 @@ def _create_session( name=session_name, dbid=dbid, pwd=pwd, - memory=memory.value, + memory=memory, ) return create_details @@ -139,6 +140,15 @@ def _check_expiry_date(self, session: SessionDetails) -> None: if until_expiry < timedelta(days=1): raise Warning(f"Session `{session.name}` is expiring in less than a day.") + def _check_memory_configuration( + self, existing_session: SessionDetails, requested_memory: SessionMemoryValue + ) -> None: + if existing_session.memory != requested_memory: + raise RuntimeError( + f"Session `{existing_session.name}` exists with a different memory configuration. " + f"Current: {existing_session.memory}, Requested: {requested_memory}." + ) + @classmethod def _fail_ambiguous_session(cls, session_name: str, sessions: List[SessionDetails]) -> None: candidates = [i.id for i in sessions] diff --git a/graphdatascience/session/session_info.py b/graphdatascience/session/session_info.py index 50d675764..072cb0dab 100644 --- a/graphdatascience/session/session_info.py +++ b/graphdatascience/session/session_info.py @@ -5,6 +5,7 @@ from typing import Optional from graphdatascience.session.aura_api_responses import SessionDetails +from graphdatascience.session.session_sizes import SessionMemoryValue @dataclass(frozen=True) @@ -18,7 +19,7 @@ class SessionInfo: """ name: str - memory: str + memory: SessionMemoryValue @classmethod def from_session_details(cls, details: SessionDetails) -> ExtendedSessionInfo: diff --git a/graphdatascience/session/session_sizes.py b/graphdatascience/session/session_sizes.py index 2967eb7ad..c7bb1a4ca 100644 --- a/graphdatascience/session/session_sizes.py +++ b/graphdatascience/session/session_sizes.py @@ -1,26 +1,53 @@ +from dataclasses import dataclass from enum import Enum from typing import List +@dataclass(frozen=True) +class SessionMemoryValue: + value: str + + def __str__(self) -> str: + return self.value + + @staticmethod + def fromApiResponse(value: str) -> "SessionMemoryValue": + """ + Converts the string value from an API response to a SessionMemory enumeration value. + + Args: + value: The string value from the API response. + + Returns: + The SessionMemory enumeration value. + + """ + if value == "": + raise ValueError("memory configuration cannot be empty") + + return SessionMemoryValue(value.replace("Gi", "GB")) + + class SessionMemory(Enum): """ Enumeration representing session main memory configurations. """ - m_8GB = "8GB" - m_16GB = "16GB" - m_24GB = "24GB" - m_32GB = "32GB" - m_48GB = "48GB" - m_64GB = "64GB" - m_96GB = "96GB" - m_128GB = "128GB" - m_192GB = "192GB" - m_256GB = "256GB" - m_384GB = "384GB" + m_4GB = SessionMemoryValue("4GB") + m_8GB = SessionMemoryValue("8GB") + m_16GB = SessionMemoryValue("16GB") + m_24GB = SessionMemoryValue("24GB") + m_32GB = SessionMemoryValue("32GB") + m_48GB = SessionMemoryValue("48GB") + m_64GB = SessionMemoryValue("64GB") + m_96GB = SessionMemoryValue("96GB") + m_128GB = SessionMemoryValue("128GB") + m_192GB = SessionMemoryValue("192GB") + m_256GB = SessionMemoryValue("256GB") + m_384GB = SessionMemoryValue("384GB") @classmethod - def all_values(cls) -> List[str]: + def all_values(cls) -> List[SessionMemoryValue]: """ All supported memory configurations. diff --git a/graphdatascience/tests/unit/test_aura_api.py b/graphdatascience/tests/unit/test_aura_api.py index 5b87816ca..a20e3294a 100644 --- a/graphdatascience/tests/unit/test_aura_api.py +++ b/graphdatascience/tests/unit/test_aura_api.py @@ -6,6 +6,7 @@ from requests import HTTPError from requests_mock import Mocker +from graphdatascience.session import SessionMemory from graphdatascience.session.algorithm_category import AlgorithmCategory from graphdatascience.session.aura_api import AuraApi from graphdatascience.session.aura_api_responses import ( @@ -33,11 +34,11 @@ def test_create_session(requests_mock: Mocker) -> None: "instance_id": "dbid-1", "created_at": "1970-01-01T00:00:00Z", "host": "1.2.3.4", - "memory": "4G", + "memory": "4Gi", }, ) - result = api.create_session("name-0", "dbid-1", "pwd-2", "4G") + result = api.create_session("name-0", "dbid-1", "pwd-2", SessionMemory.m_4GB.value) assert result == SessionDetails( id="id0", @@ -46,7 +47,7 @@ def test_create_session(requests_mock: Mocker) -> None: instance_id="dbid-1", created_at=TimeParser.fromisoformat("1970-01-01T00:00:00Z"), host="1.2.3.4", - memory="4G", + memory=SessionMemory.m_4GB.value, expiry_date=None, ttl=None, ) @@ -67,7 +68,7 @@ def test_list_session(requests_mock: Mocker) -> None: "instance_id": "dbid-1", "created_at": "1970-01-01T00:00:00Z", "host": "1.2.3.4", - "memory": "4G", + "memory": "4Gi", "expiry_date": "1977-01-01T00:00:00Z", }, ) @@ -81,7 +82,7 @@ def test_list_session(requests_mock: Mocker) -> None: instance_id="dbid-1", created_at=TimeParser.fromisoformat("1970-01-01T00:00:00Z"), host="1.2.3.4", - memory="4G", + memory=SessionMemory.m_4GB.value, expiry_date=TimeParser.fromisoformat("1977-01-01T00:00:00Z"), ttl=None, ) @@ -101,7 +102,7 @@ def test_list_sessions(requests_mock: Mocker) -> None: "instance_id": "dbid-1", "created_at": "1970-01-01T00:00:00Z", "host": "1.2.3.4", - "memory": "4G", + "memory": "4Gi", "expiry_date": "1977-01-01T00:00:00Z", }, { @@ -110,7 +111,7 @@ def test_list_sessions(requests_mock: Mocker) -> None: "status": "Creating", "instance_id": "dbid-3", "created_at": "2012-01-01T00:00:00Z", - "memory": "8G", + "memory": "8Gi", "host": "foo.bar", }, ], @@ -125,7 +126,7 @@ def test_list_sessions(requests_mock: Mocker) -> None: instance_id="dbid-1", created_at=TimeParser.fromisoformat("1970-01-01T00:00:00Z"), host="1.2.3.4", - memory="4G", + memory=SessionMemory.m_4GB.value, expiry_date=TimeParser.fromisoformat("1977-01-01T00:00:00Z"), ttl=None, ) @@ -136,7 +137,7 @@ def test_list_sessions(requests_mock: Mocker) -> None: status="Creating", instance_id="dbid-3", created_at=TimeParser.fromisoformat("2012-01-01T00:00:00Z"), - memory="8G", + memory=SessionMemory.m_8GB.value, host="foo.bar", expiry_date=None, ttl=None, @@ -200,7 +201,7 @@ def test_dont_wait_forever_for_session(requests_mock: Mocker, caplog: LogCapture "instance_id": "dbid-1", "created_at": "1970-01-01T00:00:00Z", "host": "foo.bar", - "memory": "4G", + "memory": "4Gi", "expiry_date": "1977-01-01T00:00:00Z", }, ) @@ -227,7 +228,7 @@ def test_wait_for_session_running(requests_mock: Mocker) -> None: "instance_id": "dbid-1", "created_at": "1970-01-01T00:00:00Z", "host": "foo.bar", - "memory": "4G", + "memory": "4Gi", "expiry_date": "1977-01-01T00:00:00Z", }, ) @@ -253,7 +254,7 @@ def test_delete_instance(requests_mock: Mocker) -> None: "connection_url": "", "tenant_id": "", "cloud_provider": "", - "memory": "", + "memory": "4Gi", "region": "", "type": "", } @@ -269,7 +270,7 @@ def test_delete_instance(requests_mock: Mocker) -> None: cloud_provider="", status="deleting", connection_url="", - memory="", + memory=SessionMemory.m_4GB.value, region="", type="", ) @@ -336,7 +337,7 @@ def test_create_instance(requests_mock: Mocker) -> None: }, ) - api.create_instance("name", "16GB", "gcp", "leipzig-1") + api.create_instance("name", SessionMemory.m_16GB.value, "gcp", "leipzig-1") requested_data = requests_mock.request_history[-1].json() assert requested_data["name"] == "name" @@ -437,6 +438,7 @@ def test_list_instance_missing_memory_field(requests_mock: Mocker) -> None: "status": "creating", "tenant_id": "046046d1-6996-53e4-8880-5b822766e1f9", "type": "enterprise-ds", + "memory": "16Gi", } }, ) @@ -444,7 +446,7 @@ def test_list_instance_missing_memory_field(requests_mock: Mocker) -> None: result = api.list_instance("id0") assert result and result.id == "a10fb995" - assert result.memory == "" + assert result.memory == SessionMemory.m_16GB.value def test_list_missing_instance(requests_mock: Mocker) -> None: @@ -483,6 +485,7 @@ def test_dont_wait_forever(requests_mock: Mocker, caplog: LogCaptureFixture) -> "region": None, "tenant_id": None, "type": None, + "memory": "4Gi", } }, ) @@ -512,6 +515,7 @@ def test_wait_for_instance_running(requests_mock: Mocker) -> None: "region": None, "tenant_id": None, "type": None, + "memory": "4Gi", } }, ) @@ -535,6 +539,7 @@ def test_wait_for_instance_deleting(requests_mock: Mocker) -> None: "region": None, "tenant_id": None, "type": None, + "memory": "4Gi", } }, ) @@ -550,6 +555,7 @@ def test_wait_for_instance_deleting(requests_mock: Mocker) -> None: "region": None, "tenant_id": None, "type": None, + "memory": "4Gi", } }, ) @@ -634,7 +640,7 @@ def test_parse_session_info() -> None: session_details = { "id": "test_id", "name": "test_session", - "memory": "small", + "memory": "4Gi", "instance_id": "test_instance", "status": "running", "expiry_date": "2022-01-01T00:00:00Z", @@ -647,7 +653,7 @@ def test_parse_session_info() -> None: assert session_info == SessionDetails( id="test_id", name="test_session", - memory="small", + memory=SessionMemory.m_4GB.value, instance_id="test_instance", status="running", host="a.b", @@ -661,7 +667,7 @@ def test_parse_session_info_without_optionals() -> None: session_details = { "id": "test_id", "name": "test_session", - "memory": "small", + "memory": "16Gi", "instance_id": "test_instance", "status": "running", "host": "a.b", @@ -672,7 +678,7 @@ def test_parse_session_info_without_optionals() -> None: assert session_info == SessionDetails( id="test_id", name="test_session", - memory="small", + memory=SessionMemory.m_16GB.value, instance_id="test_instance", host="a.b", status="running", diff --git a/graphdatascience/tests/unit/test_aurads_sessions.py b/graphdatascience/tests/unit/test_aurads_sessions.py index aa0610fa0..40948c8ae 100644 --- a/graphdatascience/tests/unit/test_aurads_sessions.py +++ b/graphdatascience/tests/unit/test_aurads_sessions.py @@ -18,7 +18,7 @@ from graphdatascience.session.aurads_sessions import AuraDsSessions, SessionNameHelper from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo from graphdatascience.session.session_info import SessionInfo -from graphdatascience.session.session_sizes import SessionMemory +from graphdatascience.session.session_sizes import SessionMemory, SessionMemoryValue class FakeAuraApi(AuraApi): @@ -37,7 +37,9 @@ def __init__( self._status_after_creating = status_after_creating self._size_estimation = size_estimation or EstimationDetails("1GB", "8GB", False) - def create_instance(self, name: str, memory: str, cloud_provider: str, region: str) -> InstanceCreateDetails: + def create_instance( + self, name: str, memory: SessionMemoryValue, cloud_provider: str, region: str + ) -> InstanceCreateDetails: create_details = InstanceCreateDetails( id=f"ffff{self.id_counter}", username="neo4j", @@ -102,11 +104,13 @@ def aura_api() -> AuraApi: def test_list_session(aura_api: AuraApi) -> None: - aura_api.create_instance(name="gds-session-my-session-name", cloud_provider="gcp", memory="16GB", region="") + aura_api.create_instance( + name="gds-session-my-session-name", cloud_provider="gcp", memory=SessionMemory.m_16GB.value, region="" + ) sessions = AuraDsSessions(aura_api) - assert sessions.list() == [SessionInfo("my-session-name", "16GB")] + assert sessions.list() == [SessionInfo("my-session-name", SessionMemory.m_16GB.value)] def test_create_session(mocker: MockerFixture, aura_api: AuraApi) -> None: @@ -135,12 +139,12 @@ def assert_db_credentials(*args: List[Any], **kwargs: str) -> None: "gds_url": "fake-url", "session_name": "my-session", } - assert sessions.list() == [SessionInfo("my-session", "384GB")] + assert sessions.list() == [SessionInfo("my-session", SessionMemory.m_384GB.value)] instance_details: InstanceSpecificDetails = aura_api.list_instance("ffff1") # type: ignore assert instance_details.cloud_provider == "aws" assert instance_details.region == "leipzig-1" - assert instance_details.memory == "384GB" + assert instance_details.memory == SessionMemory.m_384GB.value def test_create_default_session(mocker: MockerFixture, aura_api: AuraApi) -> None: @@ -163,11 +167,11 @@ def test_create_default_session(mocker: MockerFixture, aura_api: AuraApi) -> Non instance_details: InstanceSpecificDetails = aura_api.list_instance("ffff1") # type: ignore assert instance_details.cloud_provider == "aws" assert instance_details.region == "leipzig-1" - assert instance_details.memory == "8GB" + assert instance_details.memory == SessionMemory.m_8GB.value def test_create_session_override_region(mocker: MockerFixture, aura_api: AuraApi) -> None: - aura_api.create_instance("test", "8GB", "aws", "dresden-2") + aura_api.create_instance("test", SessionMemory.m_8GB.value, "aws", "dresden-2") sessions = AuraDsSessions(aura_api) @@ -186,7 +190,7 @@ def test_create_session_override_region(mocker: MockerFixture, aura_api: AuraApi instance_details: InstanceSpecificDetails = aura_api.list_instance("ffff1") # type: ignore assert instance_details.cloud_provider == "aws" assert instance_details.region == "leipzig-1" - assert instance_details.memory == "8GB" + assert instance_details.memory == SessionMemory.m_8GB.value def test_get_or_create(mocker: MockerFixture, aura_api: AuraApi) -> None: @@ -221,7 +225,7 @@ def test_get_or_create(mocker: MockerFixture, aura_api: AuraApi) -> None: } assert gds_args1 == gds_args2 - assert sessions.list() == [SessionInfo("my-session", "8GB")] + assert sessions.list() == [SessionInfo("my-session", SessionMemory.m_8GB.value)] def test_get_or_create_different_size(mocker: MockerFixture, aura_api: AuraApi) -> None: @@ -265,7 +269,7 @@ def test_get_or_create_duplicate_session() -> None: cloud_provider="", status="RUNNING", connection_url="", - memory="", + memory=SessionMemory.m_8GB.value, type="", region="", ), @@ -276,7 +280,7 @@ def test_get_or_create_duplicate_session() -> None: cloud_provider="", status="RUNNING", connection_url="", - memory="", + memory=SessionMemory.m_8GB.value, type="", region="", ), @@ -298,7 +302,7 @@ def test_delete_session() -> None: cloud_provider="cloud_provider", status="RUNNING", connection_url="", - memory="", + memory=SessionMemory.m_8GB.value, type="", region="", ), @@ -309,7 +313,7 @@ def test_delete_session() -> None: cloud_provider="cloud_provider", status="RUNNING", connection_url="", - memory="", + memory=SessionMemory.m_8GB.value, type="", region="", ), @@ -318,7 +322,7 @@ def test_delete_session() -> None: sessions = AuraDsSessions(FakeAuraApi(existing_instances=existing_instances)) assert sessions.delete("one") - assert sessions.list() == [SessionInfo("other", "")] + assert sessions.list() == [SessionInfo("other", SessionMemory.m_8GB.value)] def test_delete_nonexisting_session() -> None: @@ -330,7 +334,7 @@ def test_delete_nonexisting_session() -> None: cloud_provider="cloud_provider", status="RUNNING", connection_url="", - memory="", + memory=SessionMemory.m_8GB.value, type="", region="", ), @@ -339,7 +343,7 @@ def test_delete_nonexisting_session() -> None: sessions = AuraDsSessions(FakeAuraApi(existing_instances=existing_instances)) assert sessions.delete("other") is False - assert sessions.list() == [SessionInfo("one", "")] + assert sessions.list() == [SessionInfo("one", SessionMemory.m_8GB.value)] def test_delete_nonunique_session() -> None: @@ -351,7 +355,7 @@ def test_delete_nonunique_session() -> None: cloud_provider="", status="RUNNING", connection_url="", - memory="", + memory=SessionMemory.m_8GB.value, type="", region="", ), @@ -362,7 +366,7 @@ def test_delete_nonunique_session() -> None: cloud_provider="", status="RUNNING", connection_url="", - memory="", + memory=SessionMemory.m_8GB.value, type="", region="", ), @@ -378,7 +382,10 @@ def test_delete_nonunique_session() -> None: ): sessions.delete("one") - assert sessions.list() == [SessionInfo("one", ""), SessionInfo("one", "")] + assert sessions.list() == [ + SessionInfo("one", SessionMemory.m_8GB.value), + SessionInfo("one", SessionMemory.m_8GB.value), + ] def test_create_immediate_delete() -> None: @@ -435,4 +442,4 @@ def test_estimate_size_exceeds() -> None: def _setup_db_instance(aura_api: AuraApi) -> None: - aura_api.create_instance("test", "8GB", "aws", "leipzig-1") + aura_api.create_instance("test", SessionMemory.m_8GB.value, "aws", "leipzig-1") diff --git a/graphdatascience/tests/unit/test_dedicated_sessions.py b/graphdatascience/tests/unit/test_dedicated_sessions.py index 69653fdd4..7b54430a1 100644 --- a/graphdatascience/tests/unit/test_dedicated_sessions.py +++ b/graphdatascience/tests/unit/test_dedicated_sessions.py @@ -20,7 +20,7 @@ from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo from graphdatascience.session.dedicated_sessions import DedicatedSessions from graphdatascience.session.session_info import SessionInfo -from graphdatascience.session.session_sizes import SessionMemory +from graphdatascience.session.session_sizes import SessionMemory, SessionMemoryValue class FakeAuraApi(AuraApi): @@ -43,7 +43,7 @@ def __init__( self._status_after_creating = status_after_creating self._size_estimation = size_estimation or EstimationDetails("1GB", "8GB", False) - def create_session(self, name: str, dbid: str, pwd: str, memory: str) -> SessionDetails: + def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemoryValue) -> SessionDetails: details = SessionDetails( id=f"{dbid}-ffff{self.id_counter}", name=name, @@ -67,7 +67,9 @@ def add_session(self, session: SessionDetails) -> None: self._sessions[session.id] = session - def create_instance(self, name: str, memory: str, cloud_provider: str, region: str) -> InstanceCreateDetails: + def create_instance( + self, name: str, memory: SessionMemoryValue, cloud_provider: str, region: str + ) -> InstanceCreateDetails: id = f"ffff{self.id_counter}" create_details = InstanceCreateDetails( id=id, @@ -168,7 +170,10 @@ def aura_api() -> AuraApi: def test_list_session(aura_api: AuraApi) -> None: _setup_db_instance(aura_api) session = aura_api.create_session( - name="gds-session-my-session-name", dbid=aura_api.list_instances()[0].id, pwd="some_pwd", memory="8GB" + name="gds-session-my-session-name", + dbid=aura_api.list_instances()[0].id, + pwd="some_pwd", + memory=SessionMemory.m_8GB.value, ) sessions = DedicatedSessions(aura_api) @@ -234,8 +239,8 @@ def test_get_or_create(mocker: MockerFixture, aura_api: AuraApi) -> None: def test_get_or_create_duplicate_session(aura_api: AuraApi) -> None: db = _setup_db_instance(aura_api) - aura_api.create_session("one", db.id, "1234", memory="1GB") - aura_api.create_session("one", db.id, "12345", memory="1GB") + aura_api.create_session("one", db.id, "1234", memory=SessionMemory.m_4GB.value) + aura_api.create_session("one", db.id, "12345", memory=SessionMemory.m_4GB.value) sessions = DedicatedSessions(aura_api) @@ -291,11 +296,37 @@ def test_get_or_create_soon_expired_session(aura_api: AuraApi) -> None: sessions.get_or_create("one", SessionMemory.m_8GB, DbmsConnectionInfo(db.connection_url, "", "")) +def test_get_or_create_with_different_memory_config(aura_api: AuraApi) -> None: + db = _setup_db_instance(aura_api) + + fake_aura_api = cast(FakeAuraApi, aura_api) + fake_aura_api.add_session( + SessionDetails( + id="ffff0-ffff1", + name="one", + instance_id=db.id, + memory=SessionMemory.m_8GB.value, + status="Ready", + created_at=datetime.now(), + host="foo.bar", + expiry_date=None, + ttl=None, + ) + ) + + with pytest.raises( + RuntimeError, + match=re.escape("Session `one` exists with a different memory configuration. Current: 8GB, Requested: 16GB."), + ): + sessions = DedicatedSessions(aura_api) + sessions.get_or_create("one", SessionMemory.m_16GB, DbmsConnectionInfo(db.connection_url, "", "")) + + def test_delete_session(aura_api: AuraApi) -> None: - db1 = aura_api.create_instance("db1", "1GB", "aura", "leipzig").id - db2 = aura_api.create_instance("db2", "1GB", "aura", "dresden").id - aura_api.create_session("one", db1, "12345", memory="8GB") - aura_api.create_session("other", db2, "123123", memory="8GB") + db1 = aura_api.create_instance("db1", SessionMemory.m_4GB.value, "aura", "leipzig").id + db2 = aura_api.create_instance("db2", SessionMemory.m_4GB.value, "aura", "dresden").id + aura_api.create_session("one", db1, "12345", memory=SessionMemory.m_8GB.value) + aura_api.create_session("other", db2, "123123", memory=SessionMemory.m_8GB.value) sessions = DedicatedSessions(aura_api) @@ -304,8 +335,8 @@ def test_delete_session(aura_api: AuraApi) -> None: def test_delete_nonexisting_session(aura_api: AuraApi) -> None: - db1 = aura_api.create_instance("db1", "1gb", "aura", "leipzig").id - aura_api.create_session("one", db1, "12345", memory="8GB") + db1 = aura_api.create_instance("db1", SessionMemory.m_4GB.value, "aura", "leipzig").id + aura_api.create_session("one", db1, "12345", memory=SessionMemory.m_8GB.value) sessions = DedicatedSessions(aura_api) assert sessions.delete("other") is False @@ -313,9 +344,9 @@ def test_delete_nonexisting_session(aura_api: AuraApi) -> None: def test_delete_nonunique_session(aura_api: AuraApi) -> None: - db1 = aura_api.create_instance("db1", "1GB", "aura", "leipzig").id - aura_api.create_session("one", db1, "12345", memory="8GB") - aura_api.create_session("one", db1, "12345", memory="8GB") + db1 = aura_api.create_instance("db1", SessionMemory.m_4GB.value, "aura", "leipzig").id + aura_api.create_session("one", db1, "12345", memory=SessionMemory.m_8GB.value) + aura_api.create_session("one", db1, "12345", memory=SessionMemory.m_8GB.value) sessions = DedicatedSessions(aura_api) with pytest.raises( @@ -361,7 +392,7 @@ def test_estimate_size_exceeds() -> None: def _setup_db_instance(aura_api: AuraApi) -> InstanceCreateDetails: - return aura_api.create_instance("test", "8GB", "aws", "leipzig-1") + return aura_api.create_instance("test", SessionMemory.m_8GB.value, "aws", "leipzig-1") def patch_construct_client(mocker: MockerFixture) -> None: diff --git a/graphdatascience/tests/unit/test_session_sizes.py b/graphdatascience/tests/unit/test_session_sizes.py index e89e55bcd..b736a1ea5 100644 --- a/graphdatascience/tests/unit/test_session_sizes.py +++ b/graphdatascience/tests/unit/test_session_sizes.py @@ -2,6 +2,6 @@ def test_all_values() -> None: - assert len(SessionMemory.all_values()) == 11 + assert len(SessionMemory.all_values()) == 12 for e in SessionMemory.all_values(): - assert e.endswith("GB") + assert e.value.endswith("GB")