
Working example with Keras #2333

@dzubo

Description

I'm having trouble running Keras models with Dask when using multiple workers.

Is there a minimal working example anywhere?

I tried this code:

import numpy as np

import keras
from keras.layers import Input, Dense
from keras.models import Model

import dask
from dask import compute, delayed
from dask.distributed import Client
from distributed.protocol import serialize, deserialize

@delayed
def get_model(id):
    inputs = Input(shape=(10, ))
    x = Dense(20)(inputs)
    predictions = Dense(1, activation='linear')(x)

    model = Model(inputs=inputs, outputs=predictions)
    model.compile(optimizer='RMSProp', loss='mean_absolute_error')
    return model

client = Client()

params = [{'id': 1}, {'id': 2}]

for p in params:
    p['model'] = get_model(p['id'])

print(params)
results = compute(params)
print(results)

which gives this error message:

Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
Using TensorFlow backend.
^CTraceback (most recent call last):
  File "dask-keras.py", line 35, in <module>
    client = Client()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/client.py", line 628, in __init__
    self.start(timeout=timeout)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/client.py", line 751, in start
    sync(self.loop, self._start, **kwargs)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/utils.py", line 275, in sync
    e.wait(10)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/threading.py", line 551, in wait
    signaled = self._cond.wait(timeout)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/threading.py", line 299, in wait
    gotit = waiter.acquire(True, timeout)
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-11, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-2, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-4, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-10, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-12, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-1, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-9, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-5, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-6, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-8, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-3, started daemon)>
distributed.process - WARNING - reaping stray process <ForkServerProcess(ForkServerProcess-7, started daemon)>
distributed.nanny - WARNING - Worker process 19972 was killed by unknown signal
distributed.nanny - WARNING - Worker process 19975 was killed by unknown signal
distributed.nanny - WARNING - Worker process 19971 was killed by unknown signal
distributed.nanny - WARNING - Worker process 19980 was killed by unknown signal
distributed.nanny - WARNING - Worker process 19977 was killed by unknown signal
distributed.nanny - WARNING - Worker process 19979 was killed by unknown signal
distributed.nanny - WARNING - Worker process 19970 was killed by unknown signal
...
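For what it's worth, the repeated `Using TensorFlow backend.` lines suggest each worker process spawned by `Client()` is re-importing the script at module level. A common pattern that avoids this (a sketch, not a confirmed fix for this issue) is to guard the client setup with `if __name__ == '__main__':` and to import Keras inside the delayed function, so the TensorFlow backend is only initialized in the process that actually builds the model:

```python
from dask import delayed


@delayed
def get_model(id):
    # Import Keras inside the task so the TensorFlow backend is
    # initialized lazily in the worker process, not at module import
    # time in every process that re-imports this script.
    from keras.layers import Input, Dense
    from keras.models import Model

    inputs = Input(shape=(10,))
    x = Dense(20)(inputs)
    predictions = Dense(1, activation='linear')(x)

    model = Model(inputs=inputs, outputs=predictions)
    model.compile(optimizer='RMSProp', loss='mean_absolute_error')
    return model


if __name__ == '__main__':
    # Guard the client setup so worker processes that re-import this
    # module do not recursively try to create clients of their own.
    from dask.distributed import Client
    client = Client()
    # ... build delayed tasks and call compute() here
```

This mirrors the standard `multiprocessing` guideline for scripts that spawn processes; whether it resolves the hang seen above would need testing.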

If I instead start the scheduler and a worker from the command line:

dask-scheduler
dask-worker --memory-limit 10GB --nprocs 1 --nthreads 6 --name local <scheduler-ip>:8786

and replace the `Client()` call in the code with

client = Client('<scheduler-ip>:8786')

then I get this:

Using TensorFlow backend.
[{'id': 1, 'model': Delayed('get_model-de411eed-9c7d-49d0-af7d-95da92d39d5d')}, {'id': 2, 'model': Delayed('get_model-452c16b3-85cc-4472-a1e3-35c215a32647')}]
distributed.protocol.core - CRITICAL - Failed to deserialize
Traceback (most recent call last):
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/core.py", line 131, in loads
    value = _deserialize(head, fs, deserializers=deserializers)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 179, in deserialize
    return loads(header, frames)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 75, in serialization_error_loads
    raise TypeError(msg)
TypeError: Could not serialize object of type Model.
Traceback (most recent call last):
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 139, in serialize
    header, frames = dumps(x, context=context) if wants_context else dumps(x)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 38, in dask_dumps
    header, frames = dumps(x)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/keras.py", line 22, in serialize_keras_model
    weights = model.get_weights()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/keras/engine/network.py", line 492, in get_weights
    return K.batch_get_value(weights)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2420, in batch_get_value
    return get_session().run(ops)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 877, in run
    run_metadata_ptr)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1025, in _run
    raise RuntimeError('The Session graph is empty.  Add operations to the '
RuntimeError: The Session graph is empty.  Add operations to the graph before calling run().

Traceback (most recent call last):
  File "dask-keras.py", line 47, in <module>
    results = compute(params)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/dask/base.py", line 392, in compute
    results = schedule(dsk, keys, **kwargs)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/client.py", line 2308, in get
    direct=direct)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/client.py", line 1647, in gather
    asynchronous=asynchronous)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/client.py", line 665, in sync
    return sync(self.loop, func, *args, **kwargs)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/utils.py", line 277, in sync
    six.reraise(*error[0])
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/six.py", line 693, in reraise
    raise value
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/utils.py", line 262, in f
    result[0] = yield future
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1055, in run
    value = future.result()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/concurrent.py", line 238, in result
    raise_exc_info(self._exc_info)
  File "<string>", line 4, in raise_exc_info
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1063, in run
    yielded = self.gen.throw(*exc_info)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/client.py", line 1518, in _gather
    response = yield future
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1055, in run
    value = future.result()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/concurrent.py", line 238, in result
    raise_exc_info(self._exc_info)
  File "<string>", line 4, in raise_exc_info
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1063, in run
    yielded = self.gen.throw(*exc_info)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/client.py", line 1567, in _gather_remote
    response = yield self.scheduler.gather(keys=keys)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1055, in run
    value = future.result()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/concurrent.py", line 238, in result
    raise_exc_info(self._exc_info)
  File "<string>", line 4, in raise_exc_info
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1063, in run
    yielded = self.gen.throw(*exc_info)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/core.py", line 574, in send_recv_from_rpc
    result = yield send_recv(comm=comm, op=key, **kwargs)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1055, in run
    value = future.result()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/concurrent.py", line 238, in result
    raise_exc_info(self._exc_info)
  File "<string>", line 4, in raise_exc_info
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1063, in run
    yielded = self.gen.throw(*exc_info)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/core.py", line 451, in send_recv
    response = yield comm.read(deserializers=deserializers)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1055, in run
    value = future.result()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/concurrent.py", line 238, in result
    raise_exc_info(self._exc_info)
  File "<string>", line 4, in raise_exc_info
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1063, in run
    yielded = self.gen.throw(*exc_info)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/comm/tcp.py", line 203, in read
    deserializers=deserializers)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 1055, in run
    value = future.result()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/concurrent.py", line 238, in result
    raise_exc_info(self._exc_info)
  File "<string>", line 4, in raise_exc_info
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tornado/gen.py", line 307, in wrapper
    yielded = next(result)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/comm/utils.py", line 79, in from_frames
    res = _from_frames()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/comm/utils.py", line 65, in _from_frames
    deserializers=deserializers)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/core.py", line 131, in loads
    value = _deserialize(head, fs, deserializers=deserializers)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 179, in deserialize
    return loads(header, frames)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 75, in serialization_error_loads
    raise TypeError(msg)
TypeError: Could not serialize object of type Model.
Traceback (most recent call last):
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 139, in serialize
    header, frames = dumps(x, context=context) if wants_context else dumps(x)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 38, in dask_dumps
    header, frames = dumps(x)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/distributed/protocol/keras.py", line 22, in serialize_keras_model
    weights = model.get_weights()
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/keras/engine/network.py", line 492, in get_weights
    return K.batch_get_value(weights)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2420, in batch_get_value
    return get_session().run(ops)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 877, in run
    run_metadata_ptr)
  File "/Users/denis.zubo/miniconda3/envs/pai/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1025, in _run
    raise RuntimeError('The Session graph is empty.  Add operations to the '
RuntimeError: The Session graph is empty.  Add operations to the graph before calling run().
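One way to sidestep the `Could not serialize object of type Model` failure (a sketch under the assumption that you only need the trained state back, not a live object per worker) is to avoid shipping `Model` objects between processes at all: have the delayed task return the model's config and weights, which are plain Python structures and NumPy arrays, and rebuild the model on the receiving side:

```python
from dask import delayed


@delayed
def get_model_state(id):
    # Build and compile the model inside the worker, but return only
    # its config and weights -- ordinary Python data and NumPy arrays,
    # which dask can serialize without touching a TensorFlow session.
    from keras.layers import Input, Dense
    from keras.models import Model

    inputs = Input(shape=(10,))
    x = Dense(20)(inputs)
    predictions = Dense(1, activation='linear')(x)
    model = Model(inputs=inputs, outputs=predictions)
    model.compile(optimizer='RMSProp', loss='mean_absolute_error')
    return model.get_config(), model.get_weights()


def rebuild(config, weights):
    # Reconstruct an equivalent model locally from the returned state.
    from keras.models import Model

    model = Model.from_config(config)
    model.set_weights(weights)
    return model
```

Note the rebuilt model would need to be re-compiled before training further; whether this fits your use case depends on what you do with the models after `compute()`.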

Labels: documentation (Improve or add to documentation)