Skip to content

Commit

Permalink
Track and delete gathered data
Browse files Browse the repository at this point in the history
We are now more careful to record when workers copy data from each other
and to delete that data when appropriate. Previously, data could be copied
between workers and then forgotten, so that it filled up memory
needlessly.
  • Loading branch information
mrocklin committed Aug 14, 2017
1 parent 8e9eaf8 commit 4f1db8b
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 39 deletions.
6 changes: 4 additions & 2 deletions distributed/scheduler.py
Expand Up @@ -1935,8 +1935,10 @@ def add_keys(self, comm=None, worker=None, keys=()):
DEFAULT_DATA_SIZE)
self.has_what[worker].add(key)
self.who_has[key].add(worker)
# else:
# TODO: delete key from worker
else:
self.worker_comms[worker].send({'op': 'delete-data',
'keys': [key],
'report': False})
return 'OK'

def update_data(self, comm=None, who_has=None, nbytes=None, client=None):
Expand Down
88 changes: 61 additions & 27 deletions distributed/tests/test_steal.py
Expand Up @@ -6,21 +6,21 @@
import random
import sys
from time import sleep
import weakref

import pytest
from toolz import sliding_window, concat
from tornado import gen

import dask
from dask import delayed
from distributed import Worker, Nanny, worker_client
from distributed import Worker, Nanny, worker_client, Client, wait
from distributed.config import config
from distributed.client import Client, _wait, wait
from distributed.metrics import time
from distributed.scheduler import BANDWIDTH, key_split
from distributed.utils_test import (cluster, slowinc, slowadd, randominc,
loop, inc, dec, div, throws, gen_cluster, gen_test, double, deep,
slowidentity)
slowidentity, slowdouble)

import pytest

Expand All @@ -33,7 +33,7 @@ def test_work_stealing(c, s, a, b):
[x] = yield c._scatter([1], workers=a.address)
futures = c.map(slowadd, range(50), [x] * 50)
yield gen.sleep(0.1)
yield _wait(futures)
yield wait(futures)
assert len(a.data) > 10
assert len(b.data) > 10
assert len(a.data) > len(b.data) - 5
Expand All @@ -43,25 +43,26 @@ def test_work_stealing(c, s, a, b):
def test_dont_steal_expensive_data_fast_computation(c, s, a, b):
np = pytest.importorskip('numpy')
x = c.submit(np.arange, 1000000, workers=a.address)
yield _wait([x])
yield wait([x])
future = c.submit(np.sum, [1], workers=a.address) # learn that sum is fast
yield _wait([future])
yield wait([future])

cheap = [c.submit(np.sum, x, pure=False, workers=a.address,
allow_other_workers=True) for i in range(10)]
yield _wait(cheap)
yield wait(cheap)
assert len(s.who_has[x.key]) == 1
assert len(b.data) == 0
assert len(a.data) == 12


@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2)
def test_steal_cheap_data_slow_computation(c, s, a, b):
    """Slow tasks submitted to ``a`` end up spread across both workers.

    One ``slowinc`` task is run first, then ten more are mapped with
    ``workers=a.address`` and ``allow_other_workers=True``.  The test passes
    when the number of entries in ``a.data`` and ``b.data`` differ by at
    most five.
    """
    x = c.submit(slowinc, 100, delay=0.1)  # learn that slowinc is slow
    # Only the public ``wait`` is used; the older ``_wait`` call that stood
    # next to it waited on the same future and was redundant.
    yield wait(x)

    futures = c.map(slowinc, range(10), delay=0.1, workers=a.address,
                    allow_other_workers=True)
    yield wait(futures)
    assert abs(len(a.data) - len(b.data)) <= 5


Expand All @@ -78,6 +79,7 @@ def test_steal_expensive_data_slow_computation(c, s, a, b):

slow = [c.submit(slowinc, x, delay=0.1, pure=False) for i in range(20)]
yield wait(slow)
assert len(s.who_has[x.key]) > 1

assert b.data # not empty

Expand All @@ -94,13 +96,14 @@ def test_worksteal_many_thieves(c, s, *workers):
for w, keys in s.has_what.items():
assert 2 < len(keys) < 30

assert len(s.who_has[x.key]) > 1
assert sum(map(len, s.has_what.values())) < 150


@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2)
def test_dont_steal_unknown_functions(c, s, a, b):
    """Two quick ``inc`` tasks submitted to ``a`` both stay on ``a``.

    The tasks are mapped with ``workers=a.address`` and
    ``allow_other_workers=True``; afterwards ``a.data`` must hold exactly
    two entries and ``b.data`` none.
    NOTE(review): per the test name, the intent is presumably that tasks
    with no known duration are not stolen -- confirm against the stealing
    extension.
    """
    futures = c.map(inc, [1, 2], workers=a.address, allow_other_workers=True)
    # Wait once through the public ``wait``; the duplicate ``_wait`` call on
    # the same futures was redundant.
    yield wait(futures)
    assert len(a.data) == 2
    assert len(b.data) == 0

Expand All @@ -109,7 +112,7 @@ def test_dont_steal_unknown_functions(c, s, a, b):
def test_eventually_steal_unknown_functions(c, s, a, b):
futures = c.map(slowinc, range(10), delay=0.1, workers=a.address,
allow_other_workers=True)
yield _wait(futures)
yield wait(futures)
assert len(a.data) >= 3
assert len(b.data) >= 3

Expand All @@ -120,7 +123,7 @@ def test_steal_related_tasks(e, s, a, b, c):
futures = e.map(slowinc, range(20), delay=0.05, workers=a.address,
allow_other_workers=True)

yield _wait(futures)
yield wait(futures)

nearby = 0
for f1, f2 in sliding_window(2, futures):
Expand All @@ -138,18 +141,19 @@ def test_dont_steal_fast_tasks(c, s, *workers):
def do_nothing(x, y=None):
pass

yield _wait(c.submit(do_nothing, 1))
yield wait(c.submit(do_nothing, 1))

futures = c.map(do_nothing, range(1000), y=x)

yield _wait(futures)
yield wait(futures)

assert len(s.who_has[x.key]) == 1
assert len(s.has_what[workers[0].address]) == 1001


@gen_cluster(client=True, ncores=[('127.0.0.1', 1)], timeout=20)
def test_new_worker_steals(c, s, a):
yield _wait(c.submit(slowinc, 1, delay=0.01))
yield wait(c.submit(slowinc, 1, delay=0.01))

futures = c.map(slowinc, range(100), delay=0.05)
total = c.submit(sum, futures)
Expand All @@ -172,12 +176,12 @@ def test_new_worker_steals(c, s, a):

@gen_cluster(client=True, timeout=20)
def test_work_steal_no_kwargs(c, s, a, b):
yield _wait(c.submit(slowinc, 1, delay=0.05))
yield wait(c.submit(slowinc, 1, delay=0.05))

futures = c.map(slowinc, range(100), workers=a.address,
allow_other_workers=True, delay=0.05)

yield _wait(futures)
yield wait(futures)

assert 20 < len(a.data) < 80
assert 20 < len(b.data) < 80
Expand Down Expand Up @@ -280,7 +284,7 @@ def slow(x):
sleep(y)
return y
futures = c.map(slow, range(100))
yield _wait(futures)
yield wait(futures)

durations = [sum(w.data.values()) for w in workers]
assert max(durations) / min(durations) < 3
Expand All @@ -291,7 +295,7 @@ def test_dont_steal_executing_tasks(c, s, a, b):
futures = c.map(slowinc, range(4), delay=0.1, workers=a.address,
allow_other_workers=True)

yield _wait(futures)
yield wait(futures)
assert len(a.data) == 4
assert len(b.data) == 0

Expand All @@ -300,23 +304,22 @@ def test_dont_steal_executing_tasks(c, s, a, b):
def test_dont_steal_few_saturated_tasks_many_workers(c, s, a, *rest):
s.extensions['stealing']._pc.callback_time = 20
x = c.submit(mul, b'0', 100000000, workers=a.address) # 100 MB
yield _wait(x)
yield wait(x)
s.task_duration['slowidentity'] = 0.2

futures = [c.submit(slowidentity, x, pure=False, delay=0.2) for i in range(2)]

yield _wait(futures)
yield wait(futures)

assert len(a.data) == 3
assert not any(w.task_state for w in rest)


@pytest.mark.skip(reason='leaks large amount of memory')
@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10)
def test_steal_when_more_tasks(c, s, a, *rest):
s.extensions['stealing']._pc.callback_time = 20
x = c.submit(mul, b'0', 100000000, workers=a.address) # 100 MB
yield _wait(x)
yield wait(x)
s.task_duration['slowidentity'] = 0.2

futures = [c.submit(slowidentity, x, pure=False, delay=0.2)
Expand All @@ -334,7 +337,7 @@ def slow2(x):
return x
s.extensions['stealing']._pc.callback_time = 20
x = c.submit(mul, b'0', 100000000, workers=a.address) # 100 MB
yield _wait(x)
yield wait(x)
s.task_duration['slowidentity'] = 0.2
s.task_duration['slow2'] = 1

Expand Down Expand Up @@ -497,7 +500,7 @@ def test_steal_communication_heavy_tasks(c, s, a, b):
@gen_cluster(client=True)
def test_steal_twice(c, s, a, b):
x = c.submit(inc, 1, workers=a.address)
yield _wait(x)
yield wait(x)

futures = [c.submit(slowadd, x, i, delay=0.2) for i in range(100)]

Expand All @@ -507,7 +510,7 @@ def test_steal_twice(c, s, a, b):
workers = [Worker(s.ip, s.port, loop=s.loop) for _ in range(30)]
yield [w._start() for w in workers] # army of new workers arrives to help

yield _wait(futures)
yield wait(futures)

assert all(s.has_what.values())
assert max(map(len, s.has_what.values())) < 20
Expand Down Expand Up @@ -561,4 +564,35 @@ def long(delay):

assert sum(1 for k in s.processing[b.address] if k.startswith('long')) <= nb

yield _wait(long_tasks)
yield wait(long_tasks)


@gen_cluster(client=True, ncores=[('127.0.0.1', 5)] * 2)
def test_cleanup_repeated_tasks(c, s, a, b):
    """Workers and scheduler hold no data once every future is released.

    Fifty ``Foo`` instances are created on ``a`` and passed through
    ``slowidentity`` with ``allow_other_workers=True``.  After the client
    drops its futures, both workers' ``data`` must empty within one second,
    the scheduler's ``who_has`` / ``has_what`` must be empty, and none of
    the result objects may still be alive.
    """
    class Foo(object):
        # Instances of a plain class can be weakly referenced, which lets
        # the WeakSet below observe whether they have been collected.
        pass

    # NOTE(review): presumably the interval of the stealing extension's
    # periodic callback -- confirm meaning and units.
    s.extensions['stealing']._pc.callback_time = 20
    # Run slowidentity once before the batch.
    # NOTE(review): presumably so the scheduler has seen this function
    # before the real tasks arrive -- confirm.
    yield c.submit(slowidentity, -1, delay=0.1)
    objects = [c.submit(Foo, pure=False, workers=a.address) for _ in range(50)]

    x = c.map(slowidentity, objects, workers=a.address, allow_other_workers=True,
              delay=0.05)
    # Drop the local list of input futures; only ``x`` still refers to the
    # computation from this function.
    del objects
    yield wait(x)
    # Both workers ended up holding results.
    assert a.data and b.data
    assert len(a.data) + len(b.data) > 10
    # Track every stored value weakly so that a leaked copy is detectable
    # after the workers' data dicts have been emptied.
    ws = weakref.WeakSet()
    ws.update(a.data.values())
    ws.update(b.data.values())
    del x

    # Poll until both workers have dropped all data; fail if that takes
    # longer than one second.
    start = time()
    while a.data or b.data:
        yield gen.sleep(0.01)
        assert time() < start + 1

    # The scheduler must not list any key as held by any worker.
    assert not s.who_has
    assert not any(s.has_what.values())

    # No previously stored value may still be alive anywhere.
    assert not list(ws)
8 changes: 3 additions & 5 deletions distributed/tests/test_utils_comm.py
Expand Up @@ -15,11 +15,9 @@ def test_pack_data():
assert pack_data({'a': ['x'], 'b': 'y'}, data) == {'a': [1], 'b': 'y'}


@gen_cluster()
def test_gather_from_workers_permissive(s, a, b):
while not a.batched_stream:
yield gen.sleep(0.01)
a.update_data(data={'x': 1})
@gen_cluster(client=True)
def test_gather_from_workers_permissive(c, s, a, b):
x = yield c.scatter({'x': 1})

data, missing, bad_workers = yield gather_from_workers(
{'x': [a.address], 'y': [b.address]}, rpc=rpc)
Expand Down
10 changes: 9 additions & 1 deletion distributed/utils_test.py
Expand Up @@ -166,6 +166,11 @@ def slowdec(x, delay=0.02):
return x - 1


def slowdouble(x, delay=0.02):
    """Sleep for ``delay`` seconds, then return ``2 * x``.

    Parameters
    ----------
    x : any object supporting ``2 * x``
    delay : number of seconds to block before returning (default 0.02)
    """
    sleep(delay)
    return 2 * x


def randominc(x, scale=1):
from random import random
sleep(random() * scale)
Expand All @@ -185,7 +190,10 @@ def slowsum(seq, delay=0.02):
def slowidentity(*args, **kwargs):
    """Sleep, then return the positional arguments unchanged.

    A single positional argument is returned as-is; zero or several are
    returned as the ``args`` tuple.  The sleep length is taken from the
    ``delay`` keyword (seconds, default 0.02); other keywords are ignored.
    """
    delay = kwargs.get('delay', 0.02)
    sleep(delay)
    # An unconditional ``return args`` used to sit here, which made the
    # single-argument branch below unreachable and wrapped a lone value in
    # a one-element tuple.
    if len(args) == 1:
        return args[0]
    else:
        return args


@gen.coroutine
Expand Down
14 changes: 10 additions & 4 deletions distributed/worker.py
Expand Up @@ -1088,7 +1088,7 @@ def on_closed():
elif op == 'compute-task':
self.add_task(**msg)
elif op == 'release-task':
self.log.append((msg['key'], 'release-task'))
self.log.append((msg['key'], 'release-task', msg.get('reason')))
self.release_key(report=False, **msg)
elif op == 'delete-data':
self.delete_data(**msg)
Expand Down Expand Up @@ -1278,8 +1278,14 @@ def transition_dep_flight_memory(self, dep, value=None):
assert dep in self.in_flight_tasks

del self.in_flight_tasks[dep]
self.dep_state[dep] = 'memory'
self.put_key_in_memory(dep, value)
if self.dependents[dep]:
self.dep_state[dep] = 'memory'
self.put_key_in_memory(dep, value)
self.batched_stream.send({'op': 'add-keys',
'keys': [dep]})
else:
self.release_dep(dep)

except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand Down Expand Up @@ -1788,7 +1794,7 @@ def release_key(self, key, cause=None, reason=None, report=True):

for dep in self.dependencies.pop(key, ()):
self.dependents[dep].remove(key)
if not self.dependents[dep] and self.dep_state[dep] == 'waiting':
if not self.dependents[dep] and self.dep_state[dep] in ('waiting', 'flight'):
self.release_dep(dep)

if key in self.threads:
Expand Down

0 comments on commit 4f1db8b

Please sign in to comment.