From 2b19487da45ffd824edef0cd03636d753c9baaee Mon Sep 17 00:00:00 2001 From: Zach Sailer Date: Wed, 23 Mar 2022 15:20:42 -0700 Subject: [PATCH] add unit tests for pending kernels in sessions --- .../services/sessions/sessionmanager.py | 92 +++++++++-- tests/services/sessions/test_manager.py | 154 ++++++++++++++++++ 2 files changed, 236 insertions(+), 10 deletions(-) diff --git a/jupyter_server/services/sessions/sessionmanager.py b/jupyter_server/services/sessions/sessionmanager.py index 344bca416..70290d98a 100644 --- a/jupyter_server/services/sessions/sessionmanager.py +++ b/jupyter_server/services/sessions/sessionmanager.py @@ -27,9 +27,15 @@ from dataclasses import fields +class KernelRecordConflict(Exception): + """An exception raised when""" + + pass + + @dataclass class KernelRecord: - """A temporary record. + """A record object for tracking a Jupyter Server Kernel Session. Two records are equal if they share the """ @@ -39,39 +45,102 @@ class KernelRecord: def __eq__(self, other: "KernelRecord") -> bool: if isinstance(other, KernelRecord): - if any( + condition1 = self.kernel_id and self.kernel_id == other.kernel_id + condition2 = all( [ - # Check if the session_id matches - self.session_id and other.session_id and self.session_id == other.session_id, - # Check if the kernel_id matches. - self.kernel_id and other.kernel_id and self.kernel_id == other.kernel_id, + self.session_id == other.session_id, + self.kernel_id is None or other.kernel_id is None, ] - ): + ) + if any([condition1, condition2]): return True + # If two records share session_id but have different kernels, this is + # and ill-posed expression. This should never be true. Raise an exception + # to inform the user. + if all( + [ + self.session_id, + self.session_id == other.session_id, + self.kernel_id != other.kernel_id, + ] + ): + raise KernelRecordConflict( + "A single session_id can only have one kernel_id " + "associated with. These two KernelRecords share the same " + "session_id but have different kernel_ids. This should " + "not be possible and is likely an issue with the session " + "records." + ) return False def update(self, other: "KernelRecord") -> None: """Updates in-place a kernel from other (only accepts positive updates""" + if not isinstance(other, KernelRecord): + raise TypeError("'other' must be an instance of KernelRecord.") + + if other.kernel_id and self.kernel_id and other.kernel_id != self.kernel_id: + raise KernelRecordConflict( + "Could not update the record from 'other' because the two records conflict." + ) + for field in fields(self): if hasattr(other, field.name) and getattr(other, field.name): setattr(self, field.name, getattr(other, field.name)) class KernelRecordList: + """Handy object for storing and managing a list of KernelRecords. - _records = [] + When adding a record to the list, first checks if the record + already exists. If it does, the record will be updated with + the new information. + """ + + def __init__(self, *records): + self._records = [] + for record in records: + self.update(record) def __str__(self): return str(self._records) + def __contains__(self, record: Union[KernelRecord, str]): + """Search for records by kernel_id and session_id""" + if isinstance(record, KernelRecord) and record in self._records: + return True + + if isinstance(record, str): + for r in self._records: + if record in [r.session_id, r.kernel_id]: + return True + return False + + def __len__(self): + return len(self._records) + + def get(self, record: Union[KernelRecord, str]) -> KernelRecord: + if isinstance(record, str): + for r in self._records: + if record == r.kernel_id or record == r.session_id: + return r + elif isinstance(record, KernelRecord): + for r in self._records: + if record == r: + return record + raise ValueError(f"{record} not found in KernelRecordList.") + def update(self, record: KernelRecord) -> None: + """Update a record in-place or append it if not in the list.""" try: idx = self._records.index(record) self._records[idx].update(record) except ValueError: - self.append(record) + self._records.append(record) def remove(self, record: KernelRecord) -> None: + """Remove a record if its found in the list. If it's not found, + do nothing. + """ if record in self._records: self._records.remove(record) @@ -116,7 +185,9 @@ def _validate_database_filepath(self, proposal): ] ) - _pending_kernels = KernelRecordList() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._pending_kernels = KernelRecordList() # Session database initialized below _cursor = None @@ -186,6 +257,7 @@ async def create_session( kernel_id = await self.start_kernel_for_session( session_id, path, name, type, kernel_name ) + record.kernel_id = kernel_id self._pending_kernels.update(record) result = await self.save_session( session_id, path=path, name=name, type=type, kernel_id=kernel_id diff --git a/tests/services/sessions/test_manager.py b/tests/services/sessions/test_manager.py index f0142be42..bc0625305 100644 --- a/tests/services/sessions/test_manager.py +++ b/tests/services/sessions/test_manager.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from tornado import web from traitlets import TraitError @@ -6,6 +8,9 @@ from jupyter_server._tz import utcnow from jupyter_server.services.contents.manager import ContentsManager from jupyter_server.services.kernels.kernelmanager import MappingKernelManager +from jupyter_server.services.sessions.sessionmanager import KernelRecord +from jupyter_server.services.sessions.sessionmanager import KernelRecordConflict +from jupyter_server.services.sessions.sessionmanager import KernelRecordList from jupyter_server.services.sessions.sessionmanager import SessionManager @@ -40,11 +45,113 @@ async def shutdown_kernel(self, kernel_id, now=False): del self._kernels[kernel_id] +class SlowDummyMKM(DummyMKM): + async def start_kernel(self, kernel_id=None, path=None, kernel_name="python", **kwargs): + await asyncio.sleep(1.0) + return await super().start_kernel( + kernel_id=kernel_id, path=path, kernel_name=kernel_name, **kwargs + ) + + async def shutdown_kernel(self, kernel_id, now=False): + await asyncio.sleep(1.0) + await super().shutdown_kernel(kernel_id, now=now) + + @pytest.fixture def session_manager(): return SessionManager(kernel_manager=DummyMKM(), contents_manager=ContentsManager()) +def test_kernel_record_equals(): + record1 = KernelRecord(session_id="session1") + record2 = KernelRecord(session_id="session1", kernel_id="kernel1") + record3 = KernelRecord(session_id="session2", kernel_id="kernel1") + record4 = KernelRecord(session_id="session1", kernel_id="kernel2") + + assert record1 == record2 + assert record2 == record3 + assert record3 != record4 + assert record1 != record3 + assert record3 != record4 + + with pytest.raises(KernelRecordConflict): + assert record2 == record4 + + +def test_kernel_record_update(): + record1 = KernelRecord(session_id="session1") + record2 = KernelRecord(session_id="session1", kernel_id="kernel1") + record1.update(record2) + assert record1.kernel_id == "kernel1" + + record1 = KernelRecord(session_id="session1") + record2 = KernelRecord(kernel_id="kernel1") + record1.update(record2) + assert record1.kernel_id == "kernel1" + + record1 = KernelRecord(kernel_id="kernel1") + record2 = KernelRecord(session_id="session1") + record1.update(record2) + assert record1.session_id == "session1" + + record1 = KernelRecord(kernel_id="kernel1") + record2 = KernelRecord(session_id="session1", kernel_id="kernel1") + record1.update(record2) + assert record1.session_id == "session1" + + record1 = KernelRecord(kernel_id="kernel1") + record2 = KernelRecord(session_id="session1", kernel_id="kernel2") + with pytest.raises(KernelRecordConflict): + record1.update(record2) + + record1 = KernelRecord(kernel_id="kernel1", session_id="session1") + record2 = KernelRecord(kernel_id="kernel2") + with pytest.raises(KernelRecordConflict): + record1.update(record2) + + record1 = KernelRecord(kernel_id="kernel1", session_id="session1") + record2 = KernelRecord(kernel_id="kernel2", session_id="session1") + with pytest.raises(KernelRecordConflict): + record1.update(record2) + + record1 = KernelRecord(session_id="session1", kernel_id="kernel1") + record2 = KernelRecord(session_id="session2", kernel_id="kernel1") + record1.update(record2) + assert record1.session_id == "session2" + + +def test_kernel_record_list(): + records = KernelRecordList() + r = KernelRecord(kernel_id="kernel1") + records.update(r) + assert r in records + assert "kernel1" in records + assert len(records) == 1 + + # Test .get() + r_ = records.get(r) + assert r == r_ + r_ = records.get(r.kernel_id) + assert r == r_ + + with pytest.raises(ValueError): + records.get("badkernel") + + r_update = KernelRecord(kernel_id="kernel1", session_id="session1") + records.update(r_update) + assert len(records) == 1 + assert "session1" in records + + r2 = KernelRecord(kernel_id="kernel2") + records.update(r2) + assert r2 in records + assert len(records) == 2 + + records.remove(r2) + assert r2 not in records + assert len(records) == 1 + + async def create_multiple_sessions(session_manager, *kwargs_list): sessions = [] for kwargs in kwargs_list: @@ -363,3 +470,50 @@ async def test_session_persistence(jp_runtime_dir): # Assert that the session database persists. session = await session_manager.get_session(session_id=session["id"]) + + +async def test_pending_kernel(): + session_manager = SessionManager( + kernel_manager=SlowDummyMKM(), contents_manager=ContentsManager() + ) + # Create a session with a slow starting kernel + fut = session_manager.create_session( + path="/path/to/test.ipynb", kernel_name="python", type="notebook" + ) + task = asyncio.create_task(fut) + await asyncio.sleep(0.1) + assert len(session_manager._pending_kernels) == 1 + # Get a handle on the record + record = session_manager._pending_kernels._records[0] + session = await task + # Check that record is cleared after the task has completed. + assert record not in session_manager._pending_kernels + + # Check pending kernel list when sessions are + fut = session_manager.delete_session(session_id=session["id"]) + task = asyncio.create_task(fut) + await asyncio.sleep(0.1) + assert len(session_manager._pending_kernels) == 1 + # Get a handle on the record + record = session_manager._pending_kernels._records[0] + session = await task + # Check that record is cleared after the task has completed. + assert record not in session_manager._pending_kernels + + # Test multiple, parallel pending kernels + fut1 = session_manager.create_session( + path="/path/to/test.ipynb", kernel_name="python", type="notebook" + ) + fut2 = session_manager.create_session( + path="/path/to/test.ipynb", kernel_name="python", type="notebook" + ) + task1 = asyncio.create_task(fut1) + await asyncio.sleep(0.1) + task2 = asyncio.create_task(fut2) + await asyncio.sleep(0.1) + assert len(session_manager._pending_kernels) == 2 + + await task1 + await task2 + session1, session2 = await asyncio.gather(task1, task2) + assert len(session_manager._pending_kernels) == 0