Use case with deep-learning frameworks: prediction in parallel #281

Closed
lesteve opened this Issue Jul 4, 2018 · 19 comments


lesteve commented Jul 4, 2018

Here is the use case that some people around me are trying to tackle with dask: object detection on many videos in parallel.

  • we have a pretrained model (typically using TensorFlow or PyTorch) that does object detection
  • we have plenty of videos we want to run our model on, and we want to parallelize over the videos
  • we are using dask-jobqueue on an SGE cluster

Here are the problems/questions we have so far:

  • is there a way to make sure that the model is created on each worker, rather than created on one worker and then serialized and sent across the network (a lot slower: 2-3 minutes vs 20 s in our early attempts)? One possible pattern is sketched after this list. Our attempts look like this:
# this creates the model on one worker and then serializes it and sends it across the network
model = client.submit(create_model)
client.map(run_object_detection, [model] * n_videos, videos)
# this adds a dummy argument to create_model, but that does not seem very clean
models = client.map(create_model, range(4))
client.map(run_object_detection, models, videos)
  • in the log files from SGE (i.e. dask-worker.o<job_id> files), we get "Event loop was unresponsive in Worker for 41.37s. This is often caused by long-running GIL-holding functions or moving large chunks of data. This can cause timeouts and instability." Should we worry about that? Googling about this I found this SO answer and this github comment, both by @mrocklin. IIUC I should not worry about this too much until it starts becoming a problem.
  • we have some print statements in our functions (create_model, run_object_detection), but their output did not seem to appear anywhere in the SGE logs. Not sure what the recommended way is to do lightweight debugging like this.
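One possible way to keep the model local to each worker (a sketch, assuming create_model and run_object_detection are the functions from the snippets above) is to build the model lazily inside the task and cache it on the Worker object, so that it is never serialized over the network:

from distributed import get_worker

def run_object_detection_local(video):
    # build the model at most once per worker process and reuse it afterwards
    worker = get_worker()
    model = getattr(worker, "_cached_model", None)
    if model is None:
        model = create_model()
        worker._cached_model = model
    return run_object_detection(model, video)

futures = client.map(run_object_detection_local, videos)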

mrocklin commented Jul 4, 2018

For the first problem, you might first try the following:

model = client.submit(create_model)
dask.distributed.wait(model)
client.replicate(model)
client.map(run_object_detection, [model] * n_videos, videos)

This still moves the model around, but avoids all of your workers hammering the one that has it right away. It might improve things (this is also fixed in dask/distributed#2092)

I'm curious, how large is your model? Can you investigate how long it takes to serialize, and how large the result is?

%time len(pickle.dumps(model))

41 seconds is a long time to wait for something. During this time the worker is unable to respond to requests from other workers. Something is holding onto the GIL. It might be interesting to find out what.

I would expect print statements to show up in SGE's logs. An alternative would be to use distributed.worker.logger, whose output will also be available in the dashboard on the "info" pages (look for the "logs" links on the right of the main info page)
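A minimal illustration of the distributed.worker.logger suggestion above (a sketch; detect_objects here is just a stand-in for whatever the task really does):

from distributed.worker import logger

def run_object_detection(model, video):
    # these messages end up in the worker's log (the SGE output files)
    # and on the worker "logs" page of the dashboard
    logger.info("starting object detection on %s", video)
    result = detect_objects(model, video)  # placeholder for the real detection code
    logger.info("finished %s", video)
    return result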


lesteve commented Jul 4, 2018

With the wait + replicate approach I reproduce timings similar to the ones I mentioned originally, i.e. ~30 s to create the model on a worker and roughly 3 minutes per worker to serialize and ship it (I have 4 workers in total in this example, and I naively assume the serialization happens serially, which may not be the case):

%%time
model = client.submit(my_torch_create_model)
from dask.distributed import wait
wait(model)
CPU times: user 467 ms, sys: 79 ms, total: 546 ms
Wall time: 28.3 s
%%time
client.replicate(model)
CPU times: user 16.4 s, sys: 1.87 s, total: 18.3 s
Wall time: 9min 50s

If I trust the dashboard, the created model takes about 2 GB of memory.

Another worry we have is making sure that the model is on the GPU (with a PyTorch model we can call .cuda() explicitly to ensure that), but we were not quite sure what happens when the model is serialized, sent to another worker and then deserialized (quick testing seems to indicate that the model ends up on the GPU).
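One quick way to check the GPU worry directly on a worker (a sketch, assuming model is the future from the snippet above and client is the same client):

def model_is_on_gpu(model):
    # True only if every parameter tensor lives on a CUDA device
    return all(p.is_cuda for p in model.parameters())

on_gpu = client.submit(model_is_on_gpu, model).result()
print(on_gpu)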


mrocklin commented Jul 4, 2018

Can I ask for %time len(pickle.dumps(model)) ?

I wonder what is taking so long. Is it serializing? Is it a really big blob to move around? It might also be interesting to look at profile output with something like the %prun magic or snakeviz.
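A plain-stdlib alternative to the %prun/snakeviz suggestion above, run in the same session where model is defined (a sketch):

import cProfile
import pickle
import pstats

# the statement is executed in the __main__ namespace, so model and pickle
# must be defined there (as in the interactive session above)
cProfile.run("pickle.dumps(model)", "pickle_profile")
pstats.Stats("pickle_profile").sort_stats("cumulative").print_stats(20)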


lesteve commented Jul 4, 2018

Not entirely sure, but I think the slow part is serializing while the model is on the GPU. It feels like there is something to be done by combining http://distributed.readthedocs.io/en/latest/serialization.html#dask-serialization-family and https://pytorch.org/docs/stable/notes/serialization.html?highlight=saving%20models.

A simple snippet that reproduces a similar behaviour:

import pickle

from torchvision.models.resnet import resnet18
model = resnet18(pretrained=True)
model.cuda()
%time len(pickle.dumps(model)) / 1e6
Out[7]: 105.332536
CPU times: user 1min 18s, sys: 19 s, total: 1min 37s
Wall time: 1min 37s
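To isolate the effect of GPU placement, one could time the same pickle call before and after moving the model to the GPU (an untested sketch extending the snippet above):

import pickle
import time

from torchvision.models.resnet import resnet18

model = resnet18(pretrained=True)

start = time.time()
size_cpu = len(pickle.dumps(model))
print("CPU: %.0f MB in %.1f s" % (size_cpu / 1e6, time.time() - start))

model.cuda()
start = time.time()
size_gpu = len(pickle.dumps(model))
print("GPU: %.0f MB in %.1f s" % (size_gpu / 1e6, time.time() - start))
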
mrocklin commented Jul 4, 2018

I'm inclined to first try to solve the problem there. We're serializing at 1MB/s. No part of the pipeline needs to be that slow.


lesteve commented Jul 4, 2018

> I'm inclined to first try to solve the problem there. We're serializing at 1MB/s. No part of the pipeline needs to be that slow.

Not entirely sure what you mean by "there": do you mean it should be fixed inside PyTorch, or fixed through dask serialization families?


mrocklin commented Jul 4, 2018

See pytorch/pytorch#9168

Short term the right thing to do is probably to special-case PyTorch with the class-based dask serialization family: http://distributed.readthedocs.io/en/latest/serialization.html#id2

This comment has sufficient information for both PyTorch models and tensors: pytorch/pytorch#9168 (comment)


mrocklin commented Jul 4, 2018

A naive implementation is probably as follows:

from distributed.protocol.serialize import register_serialization

def serialize(model):
    import torch, io
    bio = io.BytesIO()
    torch.save(model, bio)
    header = {}
    frames = [bio.getvalue()]
    return header, frames

def deserialize(header, frames):
    import torch, io
    [frame] = frames
    bio = io.BytesIO(frame)
    return torch.load(bio)

register_serialization('torch.Tensor', serialize, deserialize)
register_serialization('...', serialize, deserialize)
...

In the future it would be nice to avoid the torch.save process and instead just pass buffers around to avoid copies, but it looks like this will get us up to about 1GB/s, which is probably enough for standard cases.
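A quick local round-trip check of the serialize/deserialize pair above (a sketch; it only exercises the torch.Tensor case):

import torch

tensor = torch.randn(3, 3)
header, frames = serialize(tensor)
restored = deserialize(header, frames)
assert torch.equal(tensor, restored)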


mrocklin commented Jul 4, 2018

It looks like your model object, though, is a special subclass, which is a bit unpleasant. The dask class-based serialization doesn't do any kind of inheritance checking.

In [1]: from torchvision.models.resnet import resnet18
   ...: model = resnet18(pretrained=True)
   ...: type(model).mro()
Out[1]: [torchvision.models.resnet.ResNet, torch.nn.modules.module.Module, object]

lesteve commented Jul 4, 2018

Thanks a lot for the additional info!

So IIUC I cannot register a custom serialization for torch.nn.modules.module.Module, for example, and expect it to work for the resnet18 instance?

I am guessing that it does not magically work on attributes either (a bit of a joblib bias here for sure ...). So I cannot register a custom serialization for torch.Tensor and expect it to have an effect on the resnet18 instance?

Out of curiosity, do I need to register the custom serialization on all the workers, e.g. by using client.run?


mrocklin commented Jul 4, 2018

Given the class-inheritance issue, the dask solution here is probably to create a new serialization family (sorry for the two kinds of solutions here) that checks every object, fails quickly if it's not a torch object, and, if it is, calls the torch.save solution in that issue.

def serialize(obj):
    if type(obj).__module__.startswith('torch'):
        header = {}
        frames = [torch_save_solution(obj)]
        return header, frames
    else:
        raise NotImplementedError()

def deserialize(header, frames):
    ...

register...('torch', serialize, deserialize)

See http://distributed.readthedocs.io/en/latest/serialization.html#extend for explicit details

Then we would need to include this in our list of serializers like

client = Client(..., serializers=['torch', 'dask', 'pickle'])

This code would need to be run on all of the workers. I would be happy to put it into the main dask codebase (assuming it doesn't need to import anything immediately on import). Alternatively you could distribute it yourself. If you do this I would recommend putting it in a worker preload script so that it runs whenever a worker starts (see the --preload flag to dask-worker).

This is all a workaround though, and still somewhat fragile; it would be very nice to see PyTorch handle things on its own.
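For reference, a preload module along those lines might look roughly like this (a sketch, not tested; it assumes register_serialization_family from distributed.protocol.serialize is the registration hook described in the "extend" docs linked above, and it reuses the torch.save approach from the earlier snippet):

# torch_serializers.py
# started on each worker with e.g.: dask-worker <scheduler-address> --preload torch_serializers
from distributed.protocol.serialize import register_serialization_family

def torch_dumps(obj):
    import io
    import torch

    if not type(obj).__module__.startswith('torch'):
        # let dask fall back to the next serializer in the list
        raise NotImplementedError()
    bio = io.BytesIO()
    torch.save(obj, bio)
    return {}, [bio.getvalue()]

def torch_loads(header, frames):
    import io
    import torch

    [frame] = frames
    return torch.load(io.BytesIO(frame))

# module-level registration runs when the worker imports the preload module
register_serialization_family('torch', torch_dumps, torch_loads)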


mrocklin commented Jul 5, 2018

@lesteve, if you're comfortable living with short-term patches, my current recommendation is just to patch your installation of PyTorch with the solution in pytorch/pytorch#9184. It's five lines and solves the problem. See pytorch/pytorch#9184 (comment)


mrocklin commented Jul 6, 2018

@lesteve I think that this has now been resolved upstream. With pytorch/pytorch#9184 I get 1 GB/s serialization bandwidths rather than 50 MB/s (CPU). My guess is that this will resolve the GPU serialization issue as well, but it would be good to check.


lesteve commented Jul 6, 2018

Great, thanks a lot @mrocklin! I haven't gotten around to trying out your suggestions yet, but it's good to know that the slow pickling problem has been fixed in PyTorch.

Something I realised recently is that keras models are not picklable in general (there is a two-week-old PR, keras-team/keras#10483, to fix that), so the custom dask serialization may come in handy, e.g.:

import pickle
from keras.applications.resnet50 import ResNet50

resnet50 = ResNet50()
pickle.dumps(resnet50)
TypeError: can't pickle _thread.lock objects
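Since the thread already sketches class-based dask serialization for torch objects, a keras equivalent could look roughly like this, using Keras' own to_json()/get_weights() round trip (an untested sketch; the exact class name to register, keras.engine.training.Model, is an assumption, and compile/optimizer state is not preserved):

import pickle

from distributed.protocol.serialize import register_serialization

def serialize_keras_model(model):
    # architecture as JSON, weights as a pickled list of numpy arrays
    header = {}
    frames = [model.to_json().encode(), pickle.dumps(model.get_weights())]
    return header, frames

def deserialize_keras_model(header, frames):
    from keras.models import model_from_json

    architecture, weights = frames
    model = model_from_json(architecture.decode())
    model.set_weights(pickle.loads(weights))
    return model

register_serialization('keras.engine.training.Model',
                       serialize_keras_model, deserialize_keras_model)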

lesteve commented Jul 10, 2018

Nice, great to know!


lesteve commented Jul 30, 2018

Closing this one, since most of the problems have been addressed.

lesteve closed this Jul 30, 2018
