Dask Incremental with a slow fit will grow progressively slower due to spilling to disk #765

Open
2over12 opened this issue Dec 7, 2020 · 6 comments

2over12 commented Dec 7, 2020

When using Dask Incremental to train on a dataset that is larger than memory, if fit is substantially slower than the read tasks, most of the array ends up spilled back to disk. When fit reaches those spilled blocks it has to load them again, which makes fit even slower and lets the read tasks get even further ahead, so each fit step ends up spending its time loading spilled data from disk.

Here is some example code that demonstrates this issue:

import sys

from dask_ml import datasets

if __name__ == "__main__":
    # Generate a synthetic classification DataFrame and write it to parquet.
    # sys.argv[1] is the number of samples, sys.argv[2] the output path.
    X, y = datasets.make_classification_df(
        n_samples=int(sys.argv[1]), n_features=784, n_classes=2, chunks=(3000, 784)
    )

    X["ys"] = y
    X.to_parquet(sys.argv[2])

I used this script to generate roughly 4x as much data as the RAM I have.

Then I trained on it using Incremental:

import sys

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from scikeras.wrappers import KerasClassifier

from dask import dataframe
from dask.distributed import Client, LocalCluster
from dask_ml.wrappers import Incremental


def build_model(lr=0.01, momentum=0.9):
    # A single sigmoid output matches binary_crossentropy and the two
    # classes in the generated data.
    layers = [
        Dense(512, input_shape=(784,), activation="relu"),
        Dense(1, activation="sigmoid"),
    ]
    model = Sequential(layers)

    opt = tf.keras.optimizers.SGD(
        learning_rate=lr, momentum=momentum, nesterov=True,
    )
    model.compile(loss="binary_crossentropy", optimizer=opt, metrics=["accuracy"])
    return model

if __name__ == "__main__":
    # Low spill thresholds make the spilling behaviour show up sooner.
    cluster = LocalCluster(
        n_workers=4, memory_target_fraction=0.4, memory_spill_fraction=0.4
    )
    client = Client(cluster)

    path = sys.argv[1]

    # batch_size=1 intentionally makes each partial_fit call slow.
    model = KerasClassifier(build_fn=build_model, lr=0.1, momentum=0.9, batch_size=1)
    inc_mod = Incremental(model)

    df = dataframe.read_parquet(path)

    # Chunk lengths are needed to convert the DataFrame to dask arrays.
    partition_sizes = list(df.map_partitions(lambda x: x.shape[0]).compute())

    Xs = df.loc[:, df.columns != "ys"].to_dask_array(lengths=partition_sizes)
    ys = df.loc[:, "ys"].to_dask_array(lengths=partition_sizes)

    inc_mod.fit(Xs, ys)

I dropped the batch size to 1 to intentionally make fit slower and dropped the spill fractions to make the issue occur faster.

This results in the following computation graph:
[image: full task graph]

A more zoomed-in version:
[image: zoomed-in view of the task graph]

As you can see, the three early tasks (the reads and transformations) are dependencies of each fit task, while each fit task also depends on the previous fit task in order to guarantee that the model is trained serially with respect to Dask.

The trouble is that Dask will continue to execute the dependencies of each fit ad infinitum:
[image: screenshot]

When fit is slow, these reads get far ahead of the fit tasks. Since Incremental is meant for larger-than-memory datasets, the completed blocks eventually start spilling to disk, and fit then has to load each spilled block back from disk when it reaches it. These loads can be quite slow when blocks are large. In the worst case this degenerates into essentially serial execution, where each fit is preceded by a block read, and training becomes extremely slow because spilling, fit loading spilled blocks, and further read tasks are all competing for disk time.

Ideally, read tasks would only execute until memory is full, and workers would then wait until fit has finished with some blocks, freeing up space to read in more training data. That way data never has to be spilled to disk and read back.

One way I thought about achieving this is by modifying the graph to add synchronization points at regular intervals.

For example, suppose we have the task graph [read0->fit0, read1->fit1, read2->fit2, read3->fit3, read4->fit4, read5->fit5, read6->fit6, read7->fit7, read8->fit8, fit0->fit1, fit1->fit2, fit2->fit3, fit3->fit4, fit4->fit5, fit5->fit6, fit6->fit7, fit7->fit8], and we know we can fit 3 blocks in memory.

Then we can add the edges {fit2->read3, fit5->read6}. These artificial dependencies force the read tasks to wait until fit has released the previous 3 blocks before starting on the next set. This behavior is not ideal, because reading of the next 3 chunks does not start until all 3 previous chunks have been fit; ideally, as soon as the first chunk has been fit and released we could start loading the fourth chunk. That said, I do not believe that behavior can easily be accomplished with normal computation graphs; it would likely require raw futures.
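
To make the idea concrete, here is a rough sketch of those synchronization edges on a plain dict graph. read_block and fit_block are just placeholders (not the real dask-ml tasks), and passing a fit key as an extra, ignored argument is only one way to express an artificial dependency:

from dask.threaded import get


def read_block(i, *ignored):
    # *ignored exists only to accept an earlier fit result as an artificial
    # dependency, forcing this read to wait until that fit has finished.
    return f"block-{i}"


def fit_block(model_state, block):
    # Placeholder for one partial_fit step on one block.
    return model_state + [block]


n_blocks = 9
window = 3  # how many blocks we are willing to hold in memory at once

graph = {"model-0": []}  # dummy initial model state
for i in range(n_blocks):
    extra_deps = ()
    if i >= window:
        # Synchronization edge: this batch of reads waits for the last fit of
        # the previous batch, e.g. fit2 -> read3, read4, read5.
        batch_start = (i // window) * window
        extra_deps = (("fit", batch_start - 1),)
    graph[("read", i)] = (read_block, i) + extra_deps
    prev = "model-0" if i == 0 else ("fit", i - 1)
    graph[("fit", i)] = (fit_block, prev, ("read", i))

print(get(graph, ("fit", n_blocks - 1)))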

I think this is an important issue for Incremental, since its purpose is larger-than-memory datasets and it is not uncommon for model fitting to take longer than reading and transforming the data.

@TomAugspurger (Member)

So, (over-)simplifying a bit, is it fair to say that

  • Dask continues to schedule early tasks (the reads, transforms) when workers are running low on memory
  • This causes spilling to disk, slowing down the fitting even more
  • We'd prefer to only schedule early tasks after a later task (fit) has finished and the memory has been freed.

If that's accurate, this sounds a bit like dask/distributed#2602. So I'd have a follow-up question: what workarounds are appropriate to put in place in Dask-ML while we wait for a more general solution to this scheduling problem (which is hard)?

2over12 commented Dec 7, 2020

Yeah, your summary is accurate. Like I said, the only idea I have currently is to add artificial dependency edges from fit tasks to read tasks at regular intervals. The idea would be that Incremental gets a parameter for how many chunks we want in memory at a given time, say 100. That parameter reaches dask_ml/_partial.py:fit. Since fit already creates a custom graph, we could add edges of the form fit task -> read task, which would prevent the read tasks from getting more than 100 tasks ahead of fit. The main problem with this approach is that if read gets ahead of fit by 100 tasks, it waits until fit completely catches up before reading again, instead of executing a new read task as each individual fit task completes.

The only other option I can see is implementing that fit function with futures instead. I think it would be fairly easy with futures to ensure that the fit tasks execute serially and that only a certain number of read futures are ever in the queue. The approach would be a loop that first submits, say, 100 read futures to the scheduler and then submits a fit future once a read is done. Every time a fit task finishes we submit a new fit task with one of the completed read tasks, and if we now have fewer than 100 pending read tasks we submit another read future. A rough sketch of this loop is below.
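
Something like this, roughly, where read_block and partial_fit are hypothetical stand-ins for the real block loader and the wrapped estimator's partial_fit (not dask-ml's actual internals):

from collections import deque

import numpy as np
from dask.distributed import Client, wait


def read_block(i):
    # Stand-in for loading one block (e.g. one parquet partition).
    return np.random.random((3000, 785))


def partial_fit(model, block):
    # Stand-in for calling model.partial_fit(X, y) on one block.
    return model


def fit_incremental(client, model, n_blocks, window=100):
    reads = deque()
    next_block = 0

    # Prime the queue with at most `window` read futures.
    while next_block < n_blocks and len(reads) < window:
        reads.append(client.submit(read_block, next_block))
        next_block += 1

    model_future = model  # becomes a Future after the first submit
    while reads:
        # Each fit depends on the previous fit result, so training stays serial.
        model_future = client.submit(partial_fit, model_future, reads.popleft())
        wait(model_future)  # that block can now be released, so read one more
        if next_block < n_blocks:
            reads.append(client.submit(read_block, next_block))
            next_block += 1
    return model_future.result()


if __name__ == "__main__":
    client = Client()
    fitted = fit_incremental(client, model=None, n_blocks=9, window=3)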

These are just the ideas I had; please let me know if either of these makes sense, or if you have another idea. Or maybe it's best to wait on more granular scheduling possibilities from dask/distributed#2602.

@stsievert (Member)

@TomAugspurger could Dask's resources be used for this problem? I think it'd require some modification, but it seems relevant because the documentation claims they can be used for "tasks that require a large amount of memory at runtime."
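
For reference, I'm imagining something along these lines (just a sketch; MEMORY_BLOCK is an arbitrary resource name, and as far as I understand resources only limit how many annotated tasks run concurrently, not how many finished blocks sit in memory):

from dask.distributed import Client, LocalCluster


def read_block(i):
    # Placeholder for loading one block of training data.
    return i


if __name__ == "__main__":
    # Give each worker two MEMORY_BLOCK tokens.
    cluster = LocalCluster(n_workers=4, resources={"MEMORY_BLOCK": 2})
    client = Client(cluster)

    # Each read claims one token, so at most two reads run at once per worker;
    # the fit tasks claim no tokens and are unaffected.
    reads = [
        client.submit(read_block, i, resources={"MEMORY_BLOCK": 1})
        for i in range(100)
    ]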

@TomAugspurger (Member)

Resources could perhaps be used at the application level (ensure that at most N preprocessing tasks are live at a time), but I don't think they're the appropriate tool for a library like dask-ml to work around an issue in dask.

2over12 commented Dec 15, 2020

Is the general feeling here to wait until dask comes up with a scheduling improvement to handle cases like this?

TomAugspurger commented Dec 19, 2020 via email
