Skip to content

Commit

Permalink
Merge 403e9e5 into 1cb1a9c
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Sep 27, 2021
2 parents 1cb1a9c + 403e9e5 commit 3f53a52
Show file tree
Hide file tree
Showing 13 changed files with 260 additions and 195 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/test_conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ jobs:
activate-environment: afar
- name: Install dependencies
run: |
conda install -y -c conda-forge distributed pytest
pip install innerscope
conda install -y -c conda-forge distributed pytest innerscope
pip install -e .
- name: PyTest
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_pip.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
run: |
pip install black flake8
flake8 .
black afar *.py --check --diff
black . --check --diff
- name: Coverage
env:
GITHUB_TOKEN: ${{ secrets.github_token }}
Expand Down
15 changes: 15 additions & 0 deletions afar/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
"""afar runs code within a context manager or IPython magic on a Dask cluster.
>>> with afar.run, remotely:
... import dask_cudf
... df = dask_cudf.read_parquet("s3://...")
... result = df.sum().compute()
or to use an IPython magic:
>>> %load_ext afar
>>> %afar z = x + y
Read the documentation at https://github.com/eriknw/afar
"""

from ._core import get, run # noqa
from ._version import get_versions
from ._where import later, locally, remotely # noqa
Expand Down
9 changes: 7 additions & 2 deletions afar/_abra.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
"""Perform a magic trick: given lines of code, create a function to run remotely.
This callable object is able to provide the values of the requested argument
names and return the final expression so it can be displayed.
"""
import dis
from types import FunctionType

from dask.distributed import Future
from innerscope import scoped_function

from ._reprs import get_repr_methods
from ._utils import code_replace, is_kernel
from ._utils import code_replace, is_ipython


def endswith_expr(func):
Expand Down Expand Up @@ -85,7 +90,7 @@ def cadabra(context_body, where, names, data, global_ns, local_ns):
# Create a new function from the code block of the context.
# For now, we require that the source code is available.
source = "def _afar_magic_():\n" + "".join(context_body)
func, display_expr = create_func(source, global_ns, is_kernel())
func, display_expr = create_func(source, global_ns, is_ipython())

# If no variable names were given, only get the last assignment
if not names:
Expand Down
253 changes: 114 additions & 139 deletions afar/_core.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,26 @@
"""Define the user-facing `run` object; this is where it all comes together."""
import dis
from functools import partial
from inspect import currentframe, findsource
import sys
from inspect import currentframe
from uuid import uuid4
from weakref import WeakKeyDictionary, WeakSet

from dask import distributed
from dask.distributed import get_worker

from ._abra import cadabra
from ._printing import PrintRecorder, print_outputs, print_outputs_async
from ._reprs import repr_afar
from ._utils import is_kernel, supports_async_output
from ._inspect import get_body, get_body_start, get_lines
from ._printing import PrintRecorder
from ._reprs import display_repr, repr_afar
from ._utils import supports_async_output
from ._where import find_where


def get_body_start(lines, with_start):
line = lines[with_start]
stripped = line.lstrip()
body = line[: len(line) - len(stripped)] + " pass\n"
body *= 2
with_lines = [stripped]
try:
code = compile(stripped, "<exec>", "exec")
except Exception:
pass
else:
raise RuntimeError(
"Failed to analyze the context! When using afar, "
"please put the context body on a new line."
)
for i, line in enumerate(lines[with_start:]):
if i > 0:
with_lines.append(line)
if ":" in line:
source = "".join(with_lines) + body
try:
code = compile(source, "<exec>", "exec")
except Exception:
pass
else:
num_with = code.co_code.count(dis.opmap["SETUP_WITH"])
body_start = with_start + i + 1
return num_with, body_start
raise RuntimeError("Failed to analyze the context!")


def get_body(lines):
head = "def f():\n with x:\n "
tail = " pass\n pass\n"
while lines:
source = head + " ".join(lines) + tail
try:
compile(source, "<exec>", "exec")
except Exception:
lines.pop()
else:
return lines
raise RuntimeError("Failed to analyze the context body!")


class Run:
_gather_data = False
# Used to update outputs asynchronously
_outputs = {}
_channel = "afar-" + uuid4().hex

def __init__(self, *names, client=None, data=None):
self.names = names
Expand Down Expand Up @@ -94,36 +56,8 @@ def __enter__(self):
if self.data:
raise RuntimeError("uh oh!")
self.data = {}
try:
lines, offset = findsource(self._frame)
except OSError:
# Try to fine the source if we are in %%time or %%timeit magic
if self._frame.f_code.co_filename in {"<timed exec>", "<magic-timeit>"} and is_kernel():
from IPython import get_ipython

ip = get_ipython()
if ip is None:
raise
cell = ip.history_manager._i00 # The current cell!
lines = cell.splitlines(keepends=True)
# strip the magic
for i, line in enumerate(lines):
if line.strip().startswith("%%time"):
lines = lines[i + 1 :]
break
else:
raise
# strip blank lines
for i, line in enumerate(lines):
if line.strip():
if i:
lines = lines[i:]
lines[-1] += "\n"
break
else:
raise
else:
raise

lines = get_lines(self._frame)

while not lines[with_lineno].lstrip().startswith("with"):
with_lineno -= 1
Expand Down Expand Up @@ -236,14 +170,6 @@ def _run(
else:
weak_futures = self._client_to_futures[client]

has_print = "print" in self._magic_func._scoped.builtin_names
capture_print = (
self._gather_data # we're blocking anyway to gather data
or display_expr # we need to display an expression (sync or async)
or has_print # print is in the context body
or supports_async_output() # no need to block, so why not?
)

to_scatter = data.keys() & self._magic_func._scoped.outer_scope.keys()
if to_scatter:
# Scatter value in `data` that we need in this calculation.
Expand All @@ -261,34 +187,36 @@ def _run(
data.update(scattered)
for key in to_scatter:
del self._magic_func._scoped.outer_scope[key]

capture_print = True
if capture_print and self._channel not in client._event_handlers:
client.subscribe_topic(self._channel, self._handle_print)
# When would be a good time to unsubscribe?
async_print = capture_print and supports_async_output()
if capture_print:
unique_key = uuid4().hex
self._setup_print(unique_key, async_print)
else:
unique_key = None

# Scatter magic_func to avoid "Large object" UserWarning
magic_func = client.scatter(self._magic_func)
magic_func = client.scatter(self._magic_func, hash=False)
weak_futures.add(magic_func)

remote_dict = client.submit(
run_afar, magic_func, names, futures, capture_print, pure=False, **submit_kwargs
run_afar,
magic_func,
names,
futures,
capture_print,
self._channel,
unique_key,
pure=False,
**submit_kwargs,
)
weak_futures.add(remote_dict)
magic_func.release() # Let go ASAP
if display_expr:
return_future = client.submit(get_afar, remote_dict, "_afar_return_value_")
repr_future = client.submit(
repr_afar,
return_future,
self._magic_func._repr_methods,
)
weak_futures.add(repr_future)
if return_expr:
weak_futures.add(return_future)
else:
return_future.release() # Let go ASAP
return_future = None
else:
repr_future = None
if capture_print:
stdout_future = client.submit(get_afar, remote_dict, "_afar_stdout_")
weak_futures.add(stdout_future)
stderr_future = client.submit(get_afar, remote_dict, "_afar_stderr_")
weak_futures.add(stderr_future)

if self._gather_data:
futures_to_name = {
client.submit(get_afar, remote_dict, name, **submit_kwargs): name
Expand All @@ -304,21 +232,6 @@ def _run(
weak_futures.add(future)
data[name] = future
remote_dict.release() # Let go ASAP

if capture_print and supports_async_output():
# Display in `out` cell when data is ready: non-blocking
from IPython.display import display
from ipywidgets import Output

out = Output()
display(out)
out.append_stdout("\N{SPARKLES} Running afar... \N{SPARKLES}")
stdout_future.add_done_callback(
partial(print_outputs_async, out, stderr_future, repr_future)
)
elif capture_print:
# blocks!
print_outputs(stdout_future, stderr_future, repr_future)
elif where == "locally":
# Run locally. This is handy for testing and debugging.
results = self._magic_func()
Expand Down Expand Up @@ -352,32 +265,94 @@ def cancel(self, *, client=None, force=False):
)
weak_futures.clear()

def _setup_print(self, key, async_print):
if async_print:
from IPython.display import display
from ipywidgets import Output

out = Output()
display(out)
out.append_stdout("\N{SPARKLES} Running afar... \N{SPARKLES}")
else:
out = None
self._outputs[key] = [out, False] # False means has not been updated

@classmethod
def _handle_print(cls, event):
# XXX: can we assume all messages from a single task arrive in FIFO order?
_, msg = event
key, action, payload = msg
if key not in cls._outputs:
return
out, is_updated = cls._outputs[key]
if out is not None:
if action == "begin":
if is_updated:
out.outputs = type(out.outputs)()
out.append_stdout("\N{SPARKLES} Running afar... (restarted) \N{SPARKLES}")
cls._outputs[key][1] = False # is not updated
else:
if not is_updated:
# Clear the "Running afar..." message
out.outputs = type(out.outputs)()
cls._outputs[key][1] = True # is updated
# ipywidgets.Output is pretty slow if there are lots of messages
if action == "stdout":
out.append_stdout(payload)
elif action == "stderr":
out.append_stderr(payload)
elif action == "stdout":
print(payload, end="")
elif action == "stderr":
print(payload, end="", file=sys.stderr)
if action == "display_expr":
display_repr(payload, out=out)
del cls._outputs[key]
elif action == "finish":
del cls._outputs[key]


class Get(Run):
"""Unlike ``run``, ``get`` automatically gathers the data locally"""

_gather_data = True


def run_afar(magic_func, names, futures, capture_print):
def run_afar(magic_func, names, futures, capture_print, channel, unique_key):
if capture_print:
rec = PrintRecorder()
if "print" in magic_func._scoped.builtin_names and "print" not in futures:
sfunc = magic_func._scoped.bind(futures, print=rec)
try:
worker = get_worker()
send_finish = True
except ValueError:
worker = None
try:
if capture_print and worker is not None:
worker.log_event(channel, (unique_key, "begin", None))
rec = PrintRecorder(channel, unique_key)
if "print" in magic_func._scoped.builtin_names and "print" not in futures:
sfunc = magic_func._scoped.bind(futures, print=rec)
else:
sfunc = magic_func._scoped.bind(futures)
with rec:
results = sfunc()
else:
sfunc = magic_func._scoped.bind(futures)
with rec:
results = sfunc()
else:
sfunc = magic_func._scoped.bind(futures)
results = sfunc()

rv = {key: results[key] for key in names}
if magic_func._display_expr:
rv["_afar_return_value_"] = results.return_value
if capture_print:
rv["_afar_stdout_"] = rec.stdout.getvalue()
rv["_afar_stderr_"] = rec.stderr.getvalue()
rv = {key: results[key] for key in names}

if magic_func._display_expr and worker is not None:
# Hopefully computing the repr is fast. If it is slow, perhaps it would be
# better to add the return value to rv and call repr_afar as a separate task.
# Also, pretty_repr must be msgpack serializable if done via events. Hence,
# custom _ipython_display_ doesn't work, and we resort to using a basic repr.
pretty_repr = repr_afar(results.return_value, magic_func._repr_methods)
if pretty_repr is not None:
worker.log_event(channel, (unique_key, "display_expr", pretty_repr))
send_finish = False
finally:
if capture_print and worker is not None and send_finish:
worker.log_event(channel, (unique_key, "finish", None))
return rv


Expand Down

0 comments on commit 3f53a52

Please sign in to comment.