Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions graphdatascience/session/aura_api_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import sys
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, NamedTuple, Optional, Set

from pandas import Timedelta


@dataclass(repr=True, frozen=True)
class SessionDetails:
Expand All @@ -15,28 +17,34 @@ class SessionDetails:
instance_id: str
memory: str
status: str
host: Optional[str]
expiry_date: Optional[datetime]
host: str
created_at: datetime
expiry_date: Optional[datetime]
ttl: Optional[timedelta]

@classmethod
def fromJson(cls, json: Dict[str, Any]) -> SessionDetails:
expiry_date = json.get("expiry_date")
ttl = json.get("ttl")

return cls(
id=json["id"],
name=json["name"],
instance_id=json["instance_id"],
memory=json["memory"],
status=json["status"],
host=json.get("host"),
host=json["host"],
expiry_date=TimeParser.fromisoformat(expiry_date) if expiry_date else None,
created_at=TimeParser.fromisoformat(json["created_at"]),
ttl=Timedelta(ttl).to_pytimedelta() if ttl else None, # datetime has no support for parsing timedetla
)

def bolt_connection_url(self) -> str:
return f"neo4j+ssc://{self.host}" # TODO use neo4j+s

def is_expired(self) -> bool:
return self.status == "Expired"


@dataclass(repr=True, frozen=True)
class InstanceDetails:
Expand Down
10 changes: 10 additions & 0 deletions graphdatascience/session/dedicated_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import hashlib
import warnings
from datetime import datetime, timedelta, timezone
from typing import List, Optional

from graphdatascience.session.algorithm_category import AlgorithmCategory
Expand Down Expand Up @@ -50,6 +51,7 @@ def get_or_create(

# TODO configure session size (and check existing_session has same size)
if existing_session:
self._check_expiry_date(existing_session)
session_id = existing_session.id
else:
create_details = self._create_session(session_name, dbid, db_connection.uri, password, memory)
Expand Down Expand Up @@ -129,6 +131,14 @@ def _construct_client(
delete_fn=lambda: self.delete(session_name, dbid=AuraApi.extract_id(db_connection.uri)),
)

def _check_expiry_date(self, session: SessionDetails) -> None:
if session.is_expired():
raise RuntimeError(f"Session `{session.name}` is expired. Please delete it and create a new one.")
if session.expiry_date:
until_expiry: timedelta = session.expiry_date - datetime.now(timezone.utc)
if until_expiry < timedelta(days=1):
raise Warning(f"Session `{session.name}` is expiring in less than a day.")

@classmethod
def _fail_ambiguous_session(cls, session_name: str, sessions: List[SessionDetails]) -> None:
candidates = [i.id for i in sessions]
Expand Down
14 changes: 11 additions & 3 deletions graphdatascience/tests/unit/test_aura_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone

import pytest
from _pytest.logging import LogCaptureFixture
Expand Down Expand Up @@ -48,6 +48,7 @@ def test_create_session(requests_mock: Mocker) -> None:
host="1.2.3.4",
memory="4G",
expiry_date=None,
ttl=None,
)


Expand Down Expand Up @@ -82,6 +83,7 @@ def test_list_session(requests_mock: Mocker) -> None:
host="1.2.3.4",
memory="4G",
expiry_date=TimeParser.fromisoformat("1977-01-01T00:00:00Z"),
ttl=None,
)


Expand Down Expand Up @@ -109,6 +111,7 @@ def test_list_sessions(requests_mock: Mocker) -> None:
"instance_id": "dbid-3",
"created_at": "2012-01-01T00:00:00Z",
"memory": "8G",
"host": "foo.bar",
},
],
)
Expand All @@ -124,6 +127,7 @@ def test_list_sessions(requests_mock: Mocker) -> None:
host="1.2.3.4",
memory="4G",
expiry_date=TimeParser.fromisoformat("1977-01-01T00:00:00Z"),
ttl=None,
)

expected2 = SessionDetails(
Expand All @@ -133,8 +137,9 @@ def test_list_sessions(requests_mock: Mocker) -> None:
instance_id="dbid-3",
created_at=TimeParser.fromisoformat("2012-01-01T00:00:00Z"),
memory="8G",
host=None,
host="foo.bar",
expiry_date=None,
ttl=None,
)

assert result == [expected1, expected2]
Expand Down Expand Up @@ -635,6 +640,7 @@ def test_parse_session_info() -> None:
"expiry_date": "2022-01-01T00:00:00Z",
"created_at": "2021-01-01T00:00:00Z",
"host": "a.b",
"ttl": "1d8h1m2s",
}
session_info = SessionDetails.fromJson(session_details)

Expand All @@ -647,10 +653,11 @@ def test_parse_session_info() -> None:
host="a.b",
expiry_date=datetime(2022, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
created_at=datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
ttl=timedelta(days=1, hours=8, minutes=1, seconds=2),
)


def test_parse_session_info_without_expiry() -> None:
def test_parse_session_info_without_optionals() -> None:
session_details = {
"id": "test_id",
"name": "test_session",
Expand All @@ -670,5 +677,6 @@ def test_parse_session_info_without_expiry() -> None:
host="a.b",
status="running",
expiry_date=None,
ttl=None,
created_at=datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
)
59 changes: 57 additions & 2 deletions graphdatascience/tests/unit/test_dedicated_sessions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dataclasses
import re
from datetime import datetime
from typing import List, Optional
from datetime import datetime, timedelta, timezone
from typing import List, Optional, cast

import pytest
from pytest_mock import MockerFixture
Expand Down Expand Up @@ -53,13 +53,20 @@ def create_session(self, name: str, dbid: str, pwd: str, memory: str) -> Session
created_at=datetime.fromisoformat("2021-01-01T00:00:00+00:00"),
host="foo.bar",
expiry_date=None,
ttl=None,
)

self.id_counter += 1
self._sessions[details.id] = details

return details

def add_session(self, session: SessionDetails) -> None:
if session.id in self._sessions:
raise ValueError(f"Session with id {session.id} already exists.")

self._sessions[session.id] = session

def create_instance(self, name: str, memory: str, cloud_provider: str, region: str) -> InstanceCreateDetails:
id = f"ffff{self.id_counter}"
create_details = InstanceCreateDetails(
Expand Down Expand Up @@ -236,6 +243,54 @@ def test_get_or_create_duplicate_session(aura_api: AuraApi) -> None:
sessions.get_or_create("one", SessionMemory.m_8GB, DbmsConnectionInfo(db.connection_url, "", ""))


def test_get_or_create_expired_session(aura_api: AuraApi) -> None:
db = _setup_db_instance(aura_api)

fake_aura_api = cast(FakeAuraApi, aura_api)
fake_aura_api.add_session(
SessionDetails(
id="ffff0-ffff1",
name="one",
instance_id=db.id,
memory=SessionMemory.m_8GB.value,
status="Expired",
created_at=datetime.now(),
host="foo.bar",
expiry_date=None,
ttl=None,
)
)

with pytest.raises(
RuntimeError, match=re.escape("Session `one` is expired. Please delete it and create a new one.")
):
sessions = DedicatedSessions(aura_api)
sessions.get_or_create("one", SessionMemory.m_8GB, DbmsConnectionInfo(db.connection_url, "", ""))


def test_get_or_create_soon_expired_session(aura_api: AuraApi) -> None:
db = _setup_db_instance(aura_api)

fake_aura_api = cast(FakeAuraApi, aura_api)
fake_aura_api.add_session(
SessionDetails(
id="ffff0-ffff1",
name="one",
instance_id=db.id,
memory=SessionMemory.m_8GB.value,
status="Ready",
created_at=datetime.now(),
host="foo.bar",
expiry_date=datetime.now(tz=timezone.utc) - timedelta(hours=23),
ttl=None,
)
)

with pytest.raises(Warning, match=re.escape("Session `one` is expiring in less than a day.")):
sessions = DedicatedSessions(aura_api)
sessions.get_or_create("one", SessionMemory.m_8GB, DbmsConnectionInfo(db.connection_url, "", ""))


def test_delete_session(aura_api: AuraApi) -> None:
db1 = aura_api.create_instance("db1", "1GB", "aura", "leipzig").id
db2 = aura_api.create_instance("db2", "1GB", "aura", "dresden").id
Expand Down