Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions graphdatascience/session/aura_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import math
import time
import warnings
from collections import defaultdict
Expand Down Expand Up @@ -165,7 +166,7 @@ def wait_for_session_running(
session_id: str,
sleep_time: float = 0.2,
max_sleep_time: float = 10,
max_wait_time: float = 300,
max_wait_time: float = math.inf,
) -> WaitResult:
waited_time = 0.0
while waited_time < max_wait_time:
Expand All @@ -186,7 +187,12 @@ def wait_for_session_running(
time.sleep(sleep_time)
sleep_time = min(sleep_time * 2, max_sleep_time, max_wait_time - waited_time)

return WaitResult.from_error(f"Session `{session_id}` is not running after {waited_time} seconds")
return WaitResult.from_error(
f"Session `{session_id}` is not running after {waited_time} seconds.\n"
"\tThe session may become available at a later time.\n"
f'\tConsider running `sessions.delete(session_id="{session_id}")` '
"to avoid resource leakage."
)

def delete_session(self, session_id: str) -> bool:
response = self._request_session.delete(
Expand Down
27 changes: 22 additions & 5 deletions graphdatascience/session/dedicated_sessions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import hashlib
import math
import warnings
from datetime import datetime, timedelta, timezone
from typing import Optional
Expand All @@ -24,7 +25,10 @@ def __init__(self, aura_api: AuraApi) -> None:
self._aura_api = aura_api

def estimate(
self, node_count: int, relationship_count: int, algorithm_categories: Optional[list[AlgorithmCategory]] = None
self,
node_count: int,
relationship_count: int,
algorithm_categories: Optional[list[AlgorithmCategory]] = None,
) -> SessionMemory:
if algorithm_categories is None:
algorithm_categories = []
Expand Down Expand Up @@ -56,6 +60,7 @@ def get_or_create(
db_connection: DbmsConnectionInfo,
ttl: Optional[timedelta] = None,
cloud_location: Optional[CloudLocation] = None,
timeout: Optional[int] = None,
) -> AuraGraphDataScience:
db_runner = Neo4jQueryRunner.create_for_db(
endpoint=db_connection.uri,
Expand Down Expand Up @@ -83,7 +88,8 @@ def get_or_create(

connection_url = session_details.bolt_connection_url()
if session_details.status != "Ready":
wait_result = self._aura_api.wait_for_session_running(session_id)
max_wait_time = float(timeout) if timeout is not None else math.inf
wait_result = self._aura_api.wait_for_session_running(session_id, max_wait_time=max_wait_time)
if err := wait_result.error:
raise RuntimeError(f"Failed to get or create session `{session_name}`: {err}")

Expand All @@ -93,7 +99,11 @@ def get_or_create(
password=password,
)

return self._construct_client(session_id=session_id, session_connection=session_connection, db_runner=db_runner)
return self._construct_client(
session_id=session_id,
session_connection=session_connection,
db_runner=db_runner,
)

def delete(self, *, session_name: Optional[str] = None, session_id: Optional[str] = None) -> bool:
if not session_name and not session_id:
Expand Down Expand Up @@ -160,13 +170,20 @@ def _get_or_create_session(
# If cloud location is provided we go for self managed DBs path
if cloud_location:
return self._aura_api.get_or_create_session(
name=session_name, pwd=pwd, memory=memory, ttl=ttl, cloud_location=cloud_location
name=session_name,
pwd=pwd,
memory=memory,
ttl=ttl,
cloud_location=cloud_location,
)
else:
return self._aura_api.get_or_create_session(name=session_name, dbid=dbid, pwd=pwd, memory=memory, ttl=ttl)

def _construct_client(
self, session_id: str, session_connection: DbmsConnectionInfo, db_runner: Neo4jQueryRunner
self,
session_id: str,
session_connection: DbmsConnectionInfo,
db_runner: Neo4jQueryRunner,
) -> AuraGraphDataScience:
return AuraGraphDataScience.create(
gds_session_connection_info=session_connection,
Expand Down
13 changes: 10 additions & 3 deletions graphdatascience/session/gds_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ def __init__(self, api_credentials: AuraAPICredentials) -> None:
self._impl: DedicatedSessions = DedicatedSessions(aura_api)

def estimate(
self, node_count: int, relationship_count: int, algorithm_categories: Optional[list[AlgorithmCategory]] = None
self,
node_count: int,
relationship_count: int,
algorithm_categories: Optional[list[AlgorithmCategory]] = None,
) -> SessionMemory:
"""
Estimates the memory required for a session with the given node and relationship counts.
Expand Down Expand Up @@ -86,6 +89,7 @@ def get_or_create(
db_connection: DbmsConnectionInfo,
ttl: Optional[timedelta] = None,
cloud_location: Optional[CloudLocation] = None,
timeout: Optional[int] = None,
) -> AuraGraphDataScience:
"""
Retrieves an existing session with the given session name and database connection,
Expand All @@ -98,13 +102,16 @@ def get_or_create(
session_name (str): The name of the session.
memory (SessionMemory): The size of the session specified by memory.
db_connection (DbmsConnectionInfo): The database connection information.
ttl: Optional[timedelta]: The sessions time to live after inactivity in seconds.
ttl: (Optional[timedelta]): The sessions time to live after inactivity in seconds.
cloud_location (Optional[CloudLocation]): The cloud location. Required if the GDS session is for a self-managed database.
timeout (Optional[int]): Optional timeout (in seconds) when waiting for session to become ready. If unset the method will wait forever. If set and session does not become ready an exception will be raised. It is user responsibility to ensure resource gets cleaned up in this situation.

Returns:
AuraGraphDataScience: The session.
"""
return self._impl.get_or_create(session_name, memory, db_connection, ttl=ttl, cloud_location=cloud_location)
return self._impl.get_or_create(
session_name, memory, db_connection, ttl=ttl, cloud_location=cloud_location, timeout=timeout
)

def delete(self, *, session_name: Optional[str] = None, session_id: Optional[str] = None) -> bool:
"""
Expand Down