Skip to content

Commit

Permalink
Merge pull request #30 from davidbrochart/kernel_ycell
Browse files Browse the repository at this point in the history
Pass ycell to kernel.execute()
  • Loading branch information
davidbrochart committed Jan 21, 2023
2 parents daf0968 + 83bfc00 commit 73f499d
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 49 deletions.
1 change: 1 addition & 0 deletions plugins/local_kernels/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"txl",
"python-dateutil >=2.8.2",
"pyzmq >=24.0.1",
"y_py >=0.5.5",
]
dynamic = ["version"]

Expand Down
7 changes: 3 additions & 4 deletions plugins/local_kernels/txl_local_kernels/components.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Any, Dict

import y_py as Y
from asphalt.core import Component, Context

from txl.base import Kernels
Expand All @@ -12,8 +11,8 @@ class LocalKernels(Kernels):
def __init__(self, kernel_name: str | None = None):
self.kernel = KernelDriver(kernel_name)

async def execute(self, cell: Dict[str, Any]):
await self.kernel.execute(cell)
async def execute(self, ydoc: Y.YDoc, ycell: Y.YMap):
await self.kernel.execute(ydoc, ycell)


class LocalKernelsComponent(Component):
Expand Down
28 changes: 18 additions & 10 deletions plugins/local_kernels/txl_local_kernels/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import uuid
from typing import Any, Dict, List, Optional, cast

import y_py as Y

from .connect import cfg_t, connect_channel, launch_kernel, read_connection_file
from .connect import write_connection_file as _write_connection_file
from .kernelspec import find_kernelspec
Expand Down Expand Up @@ -113,18 +115,17 @@ async def listen_shell(self):

async def execute(
self,
cell: Dict[str, Any],
ydoc: Y.YDoc,
ycell: Dict[str, Any],
timeout: float = float("inf"),
msg_id: str = "",
wait_for_executed: bool = True,
) -> None:
await self.started
if cell["cell_type"] != "code":
if ycell["cell_type"] != "code":
return
cell_source = cell["source"]
if isinstance(cell_source, list):
cell_source = "".join(cell_source)
content = {"code": cell_source, "silent": False}
code = str(ycell["source"])
content = {"code": code, "silent": False}
msg = create_message(
"execute_request",
content,
Expand All @@ -143,6 +144,8 @@ async def execute(
"iopub_msg": asyncio.Future(),
"shell_msg": asyncio.Future(),
}
with ydoc.begin_transaction() as txn:
ycell.set(txn, "outputs", [])
while True:
try:
await asyncio.wait_for(
Expand All @@ -153,7 +156,7 @@ async def execute(
error_message = f"Kernel didn't respond in {timeout} seconds"
raise RuntimeError(error_message)
msg = self.execute_requests[msg_id]["iopub_msg"].result()
self._handle_outputs(cell["outputs"], msg)
self._handle_outputs(ydoc, ycell, msg)
if (
msg["header"]["msg_type"] == "status"
and msg["content"]["execution_state"] == "idle"
Expand All @@ -169,7 +172,8 @@ async def execute(
error_message = f"Kernel didn't respond in {timeout} seconds"
raise RuntimeError(error_message)
msg = self.execute_requests[msg_id]["shell_msg"].result()
cell["execution_count"] = msg["content"]["execution_count"]
with ydoc.begin_transaction() as txn:
ycell.set(txn, "execution_count", msg["content"]["execution_count"])
del self.execute_requests[msg_id]

async def _wait_for_ready(self, timeout):
Expand Down Expand Up @@ -199,11 +203,12 @@ async def _wait_for_ready(self, timeout):
break
new_timeout = deadline_to_timeout(deadline)

def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]):
def _handle_outputs(self, ydoc: Y.YDoc, ycell: Y.YMap, msg: Dict[str, Any]):
msg_type = msg["header"]["msg_type"]
content = msg["content"]
outputs = list(ycell["outputs"])
if msg_type == "stream":
if (not outputs) or (outputs[-1]["name"] != content["name"]):
if (len(outputs) == 0) or (outputs[-1]["name"] != content["name"]):
outputs.append(
{"name": content["name"], "output_type": msg_type, "text": []}
)
Expand All @@ -228,3 +233,6 @@ def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]):
)
else:
return

with ydoc.begin_transaction() as txn:
ycell.set(txn, "outputs", outputs)
29 changes: 12 additions & 17 deletions plugins/notebook_viewer/txl_notebook_viewer/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from rich.text import Text
from textual import events
from textual.widgets import DataTable
from textual.widgets._data_table import Coordinate

from txl.base import Contents, Editor, Editors, FileOpenEvent, Kernels, NotebookFactory
from txl.hooks import register_component
Expand Down Expand Up @@ -42,8 +41,8 @@ def __init__(
self.notebook = notebook
self.kernels = kernels
self.kernel = None
self._row_to_cell = []
self._selected_cell = None
self._row_to_cell_idx = []
self._selected_cell_idx = None

async def on_open(self, event: FileOpenEvent) -> None:
await self.open(event.path)
Expand All @@ -61,6 +60,8 @@ async def open(self, path: str) -> None:
self.ynb.observe(self.on_change)

def update_viewer(self):
self.clear()

self.add_column("", width=10)
self.add_column("", width=100)

Expand All @@ -71,7 +72,7 @@ def update_viewer(self):
if self.ynb.cell_number == 0:
return

self._row_to_cell = []
self._row_to_cell_idx = []
for i_cell in range(self.ynb.cell_number):
cell = self.ynb.get_cell(i_cell)
execution_count = (
Expand Down Expand Up @@ -106,7 +107,7 @@ def update_viewer(self):
renderable = Text(source)

self.add_row(execution_count, renderable, height=num_lines)
self._row_to_cell.append(cell)
self._row_to_cell_idx.append(i_cell)

for output in cell.get("outputs", []):
output_type = output["output_type"]
Expand Down Expand Up @@ -135,29 +136,23 @@ def update_viewer(self):

num_lines = len(text.splitlines())
self.add_row(execution_count, renderable, height=num_lines)
self._row_to_cell.append(cell)
self._row_to_cell_idx.append(i_cell)

def on_click(self, event: events.Click) -> None:
self._set_hover_cursor(True)
DataTable.on_click(self, event)
if self.show_cursor and self.cursor_type != "none":
# Only emit selection events if there is a visible row/col/cell cursor.
self._emit_selected_message()
meta = self.get_style_at(event.x, event.y).meta
if meta:
self._selected_cell = self._row_to_cell[meta["row"]]
self.cursor_cell = Coordinate(meta["row"], meta["column"])
self._scroll_cursor_into_view(animate=True)
event.stop()
self._selected_cell_idx = self._row_to_cell_idx[meta["row"]]

def on_change(self, target, event):
self.clear()
self.update_viewer()

async def key_e(self) -> None:
if self.kernel:
print(f"Executing: {self._selected_cell}")
await self.kernel.execute(self._selected_cell)
print(f"Executed: {self._selected_cell}")
ycell = self.ynb._ycells[self._selected_cell_idx]
await self.kernel.execute(self.ynb.ydoc, ycell)
print(f"Executed {ycell}")


class NotebookViewerComponent(Component):
Expand Down
5 changes: 3 additions & 2 deletions plugins/remote_kernels/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ classifiers = [
dependencies = [
"txl",
"python-dateutil >=2.8.2",
"httpx>=0.23.1",
"httpx-ws>=0.2.6",
"httpx >=0.23.1",
"httpx-ws >=0.2.6",
"y_py >=0.5.5",
]
dynamic = ["version"]

Expand Down
6 changes: 3 additions & 3 deletions plugins/remote_kernels/txl_remote_kernels/components.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import partial
from typing import Any, Dict

import y_py as Y
from asphalt.core import Component, Context

from txl.base import Kernels
Expand All @@ -17,8 +17,8 @@ def __init__(
):
self.kernel = KernelDriver(url, kernel_name)

async def execute(self, cell: Dict[str, Any]):
await self.kernel.execute(cell)
async def execute(self, ydoc: Y.YDoc, ycell: Y.YMap):
await self.kernel.execute(ydoc, ycell)


class RemoteKernelsComponent(Component):
Expand Down
33 changes: 20 additions & 13 deletions plugins/remote_kernels/txl_remote_kernels/driver.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
import time
import uuid
from typing import Any, Dict, List
from typing import Dict, List
from urllib import parse

import httpx
import y_py as Y
from httpx_ws import WebSocketNetworkError, aconnect_ws

from .message import create_message, send_message, str_to_date
Expand Down Expand Up @@ -133,18 +134,17 @@ async def _wait_for_ready(self, timeout: float = float("inf")):

async def execute(
self,
cell: Dict[str, Any],
ydoc: Y.YDoc,
ycell: Y.YMap,
timeout: float = float("inf"),
msg_id: str = "",
wait_for_executed: bool = True,
) -> None:
await self.started.wait()
if cell["cell_type"] != "code":
if ycell["cell_type"] != "code":
return
cell_source = cell["source"]
if isinstance(cell_source, list):
cell_source = "".join(cell_source)
content = {"code": cell_source, "silent": False}
code = str(ycell["source"])
content = {"code": code, "silent": False}
msg = create_message(
"execute_request",
content,
Expand Down Expand Up @@ -172,9 +172,8 @@ async def execute(
error_message = f"Kernel didn't respond in {timeout} seconds"
del self.execute_requests[msg_id]
raise RuntimeError(error_message)
cell["outputs"] = []
await self._handle_outputs(
cell["outputs"], self.execute_requests[msg_id]["iopub"]
ydoc, ycell, self.execute_requests[msg_id]["iopub"]
)
try:
await asyncio.wait_for(
Expand All @@ -186,12 +185,16 @@ async def execute(
error_message = f"Kernel didn't respond in {timeout} seconds"
raise RuntimeError(error_message)
msg = self.execute_requests[msg_id]["shell"][0].result()
cell["execution_count"] = msg["content"]["execution_count"]
with ydoc.begin_transaction() as txn:
ycell.set(txn, "execution_count", msg["content"]["execution_count"])
del self.execute_requests[msg_id]

async def _handle_outputs(
self, outputs: List[Dict[str, Any]], future_messages: List[asyncio.Future]
self, ydoc: Y.YDoc, ycell: Y.YMap, future_messages: List[asyncio.Future]
):
with ydoc.begin_transaction() as txn:
ycell.set(txn, "outputs", [])

while True:
if not future_messages:
future_messages.append(asyncio.Future())
Expand All @@ -202,8 +205,9 @@ async def _handle_outputs(
msg = fut.result()
msg_type = msg["header"]["msg_type"]
content = msg["content"]
outputs = list(ycell["outputs"])
if msg_type == "stream":
if (not outputs) or (outputs[-1]["name"] != content["name"]):
if (len(outputs) == 0) or (outputs[-1]["name"] != content["name"]):
outputs.append(
{"name": content["name"], "output_type": msg_type, "text": []}
)
Expand All @@ -227,4 +231,7 @@ async def _handle_outputs(
}
)
elif msg_type == "status" and msg["content"]["execution_state"] == "idle":
return
break

with ydoc.begin_transaction() as txn:
ycell.set(txn, "outputs", outputs)

0 comments on commit 73f499d

Please sign in to comment.