Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optionally capture more frames in computations #7656

Merged
merged 6 commits into from Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 21 additions & 7 deletions distributed/client.py
Expand Up @@ -2956,7 +2956,9 @@
)

@staticmethod
def _get_computation_code(stacklevel: int | None = None) -> str:
def _get_computation_code(
stacklevel: int | None = None, nframes: int = 1
) -> tuple[str, ...]:
"""Walk up the stack to the user code and extract the code surrounding
the compute/submit/persist call. All modules encountered which are
ignored through the option
Expand All @@ -2967,6 +2969,8 @@
``stacklevel`` may be used to explicitly indicate from which frame on
the stack to get the source code.
"""
if nframes <= 0:
return ()
ignore_modules = dask.config.get(
"distributed.diagnostics.computations.ignore-modules"
)
Expand All @@ -2985,7 +2989,10 @@
# stacklevel 0 or less - shows dask internals which likely isn't helpful
stacklevel = stacklevel if stacklevel > 0 else 1

code: list[str] = []
for i, (fr, _) in enumerate(traceback.walk_stack(sys._getframe().f_back), 1):
if len(code) >= nframes:
break
if stacklevel is not None:
if i != stacklevel:
continue
Expand All @@ -2995,7 +3002,7 @@
):
continue
try:
return inspect.getsource(fr)
code.append(inspect.getsource(fr))
except OSError:
# Try to find the source if we are in %%time or %%timeit magic.
if (
Expand All @@ -3007,9 +3014,10 @@
ip = get_ipython()
if ip is not None:
# The current cell
return ip.history_manager._i00
code.append(ip.history_manager._i00)

Check warning on line 3017 in distributed/client.py

View check run for this annotation

Codecov / codecov/patch

distributed/client.py#L3017

Added line #L3017 was not covered by tests
break
return "<Code not available>"

return tuple(reversed(code))

def _graph_to_futures(
self,
Expand Down Expand Up @@ -3067,7 +3075,11 @@
"submitting_task": getattr(thread_state, "key", None),
"fifo_timeout": fifo_timeout,
"actors": actors,
"code": self._get_computation_code(),
"code": self._get_computation_code(
nframes=dask.config.get(
"distributed.diagnostics.computations.nframes"
)
),
}
)
return futures
Expand Down Expand Up @@ -5539,7 +5551,8 @@
async def __aexit__(self, exc_type, exc_value, traceback, code=None):
client = get_client()
if code is None:
code = client._get_computation_code(self._stacklevel + 1)
frames = client._get_computation_code(self._stacklevel + 1, nframes=1)
code = frames[0] if frames else "<Code not available>"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A magic string here doesn't seem ideal; what about using None when we can't get the code?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is inserted into the performance report; _get_computation_code would have returned "<Code not available>" before this PR, so I wanted to just keep the same behavior.

data = await client.scheduler.performance_report(
start=self.start, last_count=self.last_count, code=code, mode=self.mode
)
Expand All @@ -5551,7 +5564,8 @@

def __exit__(self, exc_type, exc_value, traceback):
client = get_client()
code = client._get_computation_code(self._stacklevel + 1)
frames = client._get_computation_code(self._stacklevel + 1, nframes=1)
code = frames[0] if frames else "<Code not available>"

Check warning on line 5568 in distributed/client.py

View check run for this annotation

Codecov / codecov/patch

distributed/client.py#L5567-L5568

Added lines #L5567 - L5568 were not covered by tests
client.sync(self.__aexit__, exc_type, exc_value, traceback, code=code)


Expand Down
5 changes: 5 additions & 0 deletions distributed/distributed-schema.yaml
Expand Up @@ -1021,6 +1021,11 @@ properties:
minimum: 0
description: |
The maximum number of Computations to remember.
nframes:
type: integer
minimum: 0
description: |
The number of frames of code to capture, starting from the innermost frame.
ignore-modules:
type: array
description: |
Expand Down
1 change: 1 addition & 0 deletions distributed/distributed.yaml
Expand Up @@ -270,6 +270,7 @@ distributed:
nvml: True
computations:
max-history: 100
nframes: 0
ignore-modules:
- distributed
- dask
Expand Down
2 changes: 1 addition & 1 deletion distributed/scheduler.py
Expand Up @@ -4350,7 +4350,7 @@ def update_graph(
computation = Computation()
self.computations.append(computation)

if code and code not in computation.code: # add new code blocks
if code: # add new code blocks
computation.code.add(code)

n = 0
Expand Down
95 changes: 63 additions & 32 deletions distributed/tests/test_client.py
Expand Up @@ -7068,12 +7068,15 @@ def test_computation_code_walk_frames():
test_function_code = inspect.getsource(test_computation_code_walk_frames)
code = Client._get_computation_code()

assert test_function_code == code
assert code == (test_function_code,)

def nested_call():
return Client._get_computation_code()
return Client._get_computation_code(nframes=2)

assert nested_call() == inspect.getsource(nested_call)
nested = nested_call()
assert len(nested) == 2
assert nested[-1] == inspect.getsource(nested_call)
assert nested[-2] == test_function_code

with pytest.raises(TypeError, match="Ignored modules must be a list"):
with dask.config.set(
Expand All @@ -7088,15 +7091,15 @@ def nested_call():

upper_frame_code = inspect.getsource(sys._getframe(1))
code = Client._get_computation_code()
assert code == upper_frame_code
assert nested_call() == upper_frame_code
assert code == (upper_frame_code,)
assert nested_call()[-1] == upper_frame_code


def test_computation_object_code_dask_compute(client):
da = pytest.importorskip("dask.array")
x = da.ones((10, 10), chunks=(3, 3))
future = x.sum().compute()
y = future
with dask.config.set({"distributed.diagnostics.computations.nframes": 2}):
x = da.ones((10, 10), chunks=(3, 3))
x.sum().compute()

test_function_code = inspect.getsource(test_computation_object_code_dask_compute)

Expand All @@ -7109,29 +7112,44 @@ def fetch_comp_code(dask_scheduler):

code = client.run_on_scheduler(fetch_comp_code)

assert code == test_function_code
assert len(code) == 2
assert code[-1] == test_function_code
assert code[-2] == inspect.getsource(sys._getframe(1))


def test_computation_object_code_dask_compute_no_frames_default(client):
da = pytest.importorskip("dask.array")
x = da.ones((10, 10), chunks=(3, 3))
x.sum().compute()

def fetch_comp_code(dask_scheduler):
computations = list(dask_scheduler.computations)
assert len(computations) == 1
comp = computations[0]
assert not comp.code

client.run_on_scheduler(fetch_comp_code)


def test_computation_object_code_not_available(client):
np = pytest.importorskip("numpy")
pd = pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
df = pd.DataFrame({"a": range(10)})
ddf = dd.from_pandas(df, npartitions=3)
result = np.where(ddf.a > 4)
with dask.config.set({"distributed.diagnostics.computations.nframes": 2}):
df = pd.DataFrame({"a": range(10)})
ddf = dd.from_pandas(df, npartitions=3)
result = np.where(ddf.a > 4)

def fetch_comp_code(dask_scheduler):
computations = list(dask_scheduler.computations)
assert len(computations) == 1
comp = computations[0]
assert len(comp.code) == 1
return comp.code[0]
assert not comp.code

code = client.run_on_scheduler(fetch_comp_code)
assert code == "<Code not available>"
client.run_on_scheduler(fetch_comp_code)


@gen_cluster(client=True)
@gen_cluster(client=True, config={"distributed.diagnostics.computations.nframes": 2})
async def test_computation_object_code_dask_persist(c, s, a, b):
da = pytest.importorskip("dask.array")
x = da.ones((10, 10), chunks=(3, 3))
Expand All @@ -7146,10 +7164,12 @@ async def test_computation_object_code_dask_persist(c, s, a, b):
comp = computations[0]
assert len(comp.code) == 1

assert comp.code[0] == test_function_code
assert len(comp.code[0]) == 2
assert comp.code[0][-1] == test_function_code
assert comp.code[0][-2] == inspect.getsource(sys._getframe(1))


@gen_cluster(client=True)
@gen_cluster(client=True, config={"distributed.diagnostics.computations.nframes": 2})
async def test_computation_object_code_client_submit_simple(c, s, a, b):
def func(x):
return x
Expand All @@ -7167,10 +7187,12 @@ def func(x):

assert len(comp.code) == 1

assert comp.code[0] == test_function_code
assert len(comp.code[0]) == 2
assert comp.code[0][-1] == test_function_code
assert comp.code[0][-2] == inspect.getsource(sys._getframe(1))


@gen_cluster(client=True)
@gen_cluster(client=True, config={"distributed.diagnostics.computations.nframes": 2})
async def test_computation_object_code_client_submit_list_comp(c, s, a, b):
def func(x):
return x
Expand All @@ -7189,10 +7211,12 @@ def func(x):
# Code is deduplicated
assert len(comp.code) == 1

assert comp.code[0] == test_function_code
assert len(comp.code[0]) == 2
assert comp.code[0][-1] == test_function_code
assert comp.code[0][-2] == inspect.getsource(sys._getframe(1))


@gen_cluster(client=True)
@gen_cluster(client=True, config={"distributed.diagnostics.computations.nframes": 2})
async def test_computation_object_code_client_submit_dict_comp(c, s, a, b):
def func(x):
return x
Expand All @@ -7211,15 +7235,18 @@ def func(x):
# Code is deduplicated
assert len(comp.code) == 1

assert comp.code[0] == test_function_code
assert len(comp.code[0]) == 2
assert comp.code[0][-1] == test_function_code
assert comp.code[0][-2] == inspect.getsource(sys._getframe(1))


@gen_cluster(client=True)
@gen_cluster(client=True, config={"distributed.diagnostics.computations.nframes": 2})
async def test_computation_object_code_client_map(c, s, a, b):
da = pytest.importorskip("dask.array")
x = da.ones((10, 10), chunks=(3, 3))
future = c.compute(x.sum(), retries=2)
y = await future
def func(x):
return x

futs = c.map(func, list(range(5)))
await c.gather(futs)

test_function_code = inspect.getsource(
test_computation_object_code_client_map.__wrapped__
Expand All @@ -7229,10 +7256,12 @@ async def test_computation_object_code_client_map(c, s, a, b):
comp = computations[0]
assert len(comp.code) == 1

assert comp.code[0] == test_function_code
assert len(comp.code[0]) == 2
assert comp.code[0][-1] == test_function_code
assert comp.code[0][-2] == inspect.getsource(sys._getframe(1))


@gen_cluster(client=True)
@gen_cluster(client=True, config={"distributed.diagnostics.computations.nframes": 2})
async def test_computation_object_code_client_compute(c, s, a, b):
da = pytest.importorskip("dask.array")
x = da.ones((10, 10), chunks=(3, 3))
Expand All @@ -7247,7 +7276,9 @@ async def test_computation_object_code_client_compute(c, s, a, b):
comp = computations[0]
assert len(comp.code) == 1

assert comp.code[0] == test_function_code
assert len(comp.code[0]) == 2
assert comp.code[0][-1] == test_function_code
assert comp.code[0][-2] == inspect.getsource(sys._getframe(1))


@pytest.mark.slow
Expand Down
4 changes: 2 additions & 2 deletions distributed/widgets/templates/computation.html.j2
Expand Up @@ -26,9 +26,9 @@

<details>
<summary style="margin-bottom: 20px"><h4 style="display:inline">Code</h4></summary>
{% for segment in code %}
{% for frames in code if frames %}
<h5>Code segment {{ loop.index }} / {{ code | length }}</h5>
<pre><code>{{ segment }}</code></pre>
<pre><code>{{ frames[-1] }}</code></pre>
{% endfor %}
</details>

Expand Down