Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion graphdatascience/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from .dbms_connection_info import DbmsConnectionInfo
from .gds_sessions import AuraAPICredentials, GdsSessions
from .session_info import SessionInfo
from .session_sizes import SessionMemory
from .session_sizes import SessionMemory, SessionMemoryValue

__all__ = [
"GdsSessions",
"SessionInfo",
"DbmsConnectionInfo",
"AuraAPICredentials",
"SessionMemory",
"SessionMemoryValue",
"AlgorithmCategory",
]
11 changes: 7 additions & 4 deletions graphdatascience/session/aura_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
TenantDetails,
WaitResult,
)
from graphdatascience.session.session_sizes import SessionMemoryValue
from graphdatascience.version import __version__


Expand Down Expand Up @@ -62,11 +63,11 @@ def extract_id(uri: str) -> str:

return host.split(".")[0].split("-")[0]

def create_session(self, name: str, dbid: str, pwd: str, memory: str) -> SessionDetails:
def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemoryValue) -> SessionDetails:
response = req.post(
f"{self._base_uri}/v1beta5/data-science/sessions",
headers=self._build_header(),
json={"name": name, "instance_id": dbid, "password": pwd, "memory": memory},
json={"name": name, "instance_id": dbid, "password": pwd, "memory": memory.value},
)

response.raise_for_status()
Expand Down Expand Up @@ -141,12 +142,14 @@ def delete_session(self, session_id: str, dbid: str) -> bool:

return False

def create_instance(self, name: str, memory: str, cloud_provider: str, region: str) -> InstanceCreateDetails:
def create_instance(
self, name: str, memory: SessionMemoryValue, cloud_provider: str, region: str
) -> InstanceCreateDetails:
tenant_details = self.tenant_details()

data = {
"name": name,
"memory": memory,
"memory": memory.value,
"version": "5",
"region": region,
"type": tenant_details.ds_type,
Expand Down
10 changes: 6 additions & 4 deletions graphdatascience/session/aura_api_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@

from pandas import Timedelta

from .session_sizes import SessionMemoryValue


@dataclass(repr=True, frozen=True)
class SessionDetails:
id: str
name: str
instance_id: str
memory: str
memory: SessionMemoryValue
status: str
host: str
created_at: datetime
Expand All @@ -31,7 +33,7 @@ def fromJson(cls, json: Dict[str, Any]) -> SessionDetails:
id=json["id"],
name=json["name"],
instance_id=json["instance_id"],
memory=json["memory"],
memory=SessionMemoryValue.fromApiResponse(json["memory"]),
status=json["status"],
host=json["host"],
expiry_date=TimeParser.fromisoformat(expiry_date) if expiry_date else None,
Expand Down Expand Up @@ -67,7 +69,7 @@ def fromJson(cls, json: Dict[str, Any]) -> InstanceDetails:
class InstanceSpecificDetails(InstanceDetails):
status: str
connection_url: str
memory: str
memory: SessionMemoryValue
type: str
region: str

Expand All @@ -80,7 +82,7 @@ def fromJson(cls, json: Dict[str, Any]) -> InstanceSpecificDetails:
cloud_provider=json["cloud_provider"],
status=json["status"],
connection_url=json.get("connection_url", ""),
memory=json.get("memory", ""),
memory=SessionMemoryValue.fromApiResponse(json.get("memory", "")),
type=json["type"],
region=json["region"],
)
Expand Down
14 changes: 7 additions & 7 deletions graphdatascience/session/aurads_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
from graphdatascience.session.region_suggester import closest_match
from graphdatascience.session.session_info import SessionInfo
from graphdatascience.session.session_sizes import SessionMemory
from graphdatascience.session.session_sizes import SessionMemory, SessionMemoryValue


class AuraDsSessions:
Expand All @@ -41,7 +41,7 @@ def estimate(
ResourceWarning,
)

return SessionMemory(estimation.recommended_size)
return SessionMemory(SessionMemoryValue(estimation.recommended_size))

def get_or_create(
self,
Expand All @@ -54,13 +54,13 @@ def get_or_create(
if existing_session:
session_id = existing_session.id
# 0MB is AuraAPI default value for memory if none can be retrieved
if existing_session.memory != "0MB" and existing_session.memory != memory.value:
if existing_session.memory.value != "0MB" and existing_session.memory != memory.value:
raise ValueError(
f"Session `{session_name}` already exists with memory `{existing_session.memory}`. "
f"Session `{session_name}` already exists with memory `{existing_session.memory.value}`. "
f"Requested memory `{memory.value}` does not match."
)
else:
create_details = self._create_session(session_name, memory, db_connection)
create_details = self._create_session(session_name, memory.value, db_connection)
session_id = create_details.id

wait_result = self._aura_api.wait_for_instance_running(session_id)
Expand Down Expand Up @@ -118,7 +118,7 @@ def _find_existing_session(self, session_name: str) -> Optional[InstanceSpecific
return self._aura_api.list_instance(matched_instances[0].id)

def _create_session(
self, session_name: str, memory: SessionMemory, db_connection: DbmsConnectionInfo
self, session_name: str, memory: SessionMemoryValue, db_connection: DbmsConnectionInfo
) -> InstanceCreateDetails:
db_instance_id = AuraApi.extract_id(db_connection.uri)
db_instance = self._aura_api.list_instance(db_instance_id)
Expand All @@ -128,7 +128,7 @@ def _create_session(
region = self._ds_region(db_instance.region, db_instance.cloud_provider)

create_details = self._aura_api.create_instance(
SessionNameHelper.instance_name(session_name), memory.value, db_instance.cloud_provider, region
SessionNameHelper.instance_name(session_name), memory, db_instance.cloud_provider, region
)
return create_details

Expand Down
20 changes: 15 additions & 5 deletions graphdatascience/session/dedicated_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
from graphdatascience.session.session_info import SessionInfo
from graphdatascience.session.session_sizes import SessionMemory
from graphdatascience.session.session_sizes import SessionMemory, SessionMemoryValue


class DedicatedSessions:
Expand All @@ -35,7 +35,7 @@ def estimate(
ResourceWarning,
)

return SessionMemory(estimation.recommended_size)
return SessionMemory(SessionMemoryValue(estimation.recommended_size))

def get_or_create(
self,
Expand All @@ -52,9 +52,10 @@ def get_or_create(
# TODO configure session size (and check existing_session has same size)
if existing_session:
self._check_expiry_date(existing_session)
self._check_memory_configuration(existing_session, memory.value)
session_id = existing_session.id
else:
create_details = self._create_session(session_name, dbid, db_connection.uri, password, memory)
create_details = self._create_session(session_name, dbid, db_connection.uri, password, memory.value)
session_id = create_details.id

wait_result = self._aura_api.wait_for_session_running(session_id, dbid)
Expand Down Expand Up @@ -108,7 +109,7 @@ def _find_existing_session(self, session_name: str, dbid: str) -> Optional[Sessi
return matched_sessions[0]

def _create_session(
self, session_name: str, dbid: str, dburi: str, pwd: str, memory: SessionMemory
self, session_name: str, dbid: str, dburi: str, pwd: str, memory: SessionMemoryValue
) -> SessionDetails:
db_instance = self._aura_api.list_instance(dbid)
if not db_instance:
Expand All @@ -118,7 +119,7 @@ def _create_session(
name=session_name,
dbid=dbid,
pwd=pwd,
memory=memory.value,
memory=memory,
)
return create_details

Expand All @@ -139,6 +140,15 @@ def _check_expiry_date(self, session: SessionDetails) -> None:
if until_expiry < timedelta(days=1):
raise Warning(f"Session `{session.name}` is expiring in less than a day.")

def _check_memory_configuration(
self, existing_session: SessionDetails, requested_memory: SessionMemoryValue
) -> None:
if existing_session.memory != requested_memory:
raise RuntimeError(
f"Session `{existing_session.name}` exists with a different memory configuration. "
f"Current: {existing_session.memory}, Requested: {requested_memory}."
)

@classmethod
def _fail_ambiguous_session(cls, session_name: str, sessions: List[SessionDetails]) -> None:
candidates = [i.id for i in sessions]
Expand Down
3 changes: 2 additions & 1 deletion graphdatascience/session/session_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional

from graphdatascience.session.aura_api_responses import SessionDetails
from graphdatascience.session.session_sizes import SessionMemoryValue


@dataclass(frozen=True)
Expand All @@ -18,7 +19,7 @@ class SessionInfo:
"""

name: str
memory: str
memory: SessionMemoryValue

@classmethod
def from_session_details(cls, details: SessionDetails) -> ExtendedSessionInfo:
Expand Down
51 changes: 39 additions & 12 deletions graphdatascience/session/session_sizes.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,53 @@
from dataclasses import dataclass
from enum import Enum
from typing import List


@dataclass(frozen=True)
class SessionMemoryValue:
value: str

def __str__(self) -> str:
return self.value

@staticmethod
def fromApiResponse(value: str) -> "SessionMemoryValue":
"""
Converts the string value from an API response to a SessionMemory enumeration value.
Args:
value: The string value from the API response.
Returns:
The SessionMemory enumeration value.
"""
if value == "":
raise ValueError("memory configuration cannot be empty")

return SessionMemoryValue(value.replace("Gi", "GB"))


class SessionMemory(Enum):
"""
Enumeration representing session main memory configurations.
"""

m_8GB = "8GB"
m_16GB = "16GB"
m_24GB = "24GB"
m_32GB = "32GB"
m_48GB = "48GB"
m_64GB = "64GB"
m_96GB = "96GB"
m_128GB = "128GB"
m_192GB = "192GB"
m_256GB = "256GB"
m_384GB = "384GB"
m_4GB = SessionMemoryValue("4GB")
m_8GB = SessionMemoryValue("8GB")
m_16GB = SessionMemoryValue("16GB")
m_24GB = SessionMemoryValue("24GB")
m_32GB = SessionMemoryValue("32GB")
m_48GB = SessionMemoryValue("48GB")
m_64GB = SessionMemoryValue("64GB")
m_96GB = SessionMemoryValue("96GB")
m_128GB = SessionMemoryValue("128GB")
m_192GB = SessionMemoryValue("192GB")
m_256GB = SessionMemoryValue("256GB")
m_384GB = SessionMemoryValue("384GB")

@classmethod
def all_values(cls) -> List[str]:
def all_values(cls) -> List[SessionMemoryValue]:
"""
All supported memory configurations.
Expand Down
Loading