
[dask] Use client when persisting collections #6712

Closed
jose-moralez opened this issue Feb 17, 2021 · 2 comments

Comments

@jose-moralez
Contributor

Hi. I'm trying to fit XGBoost models concurrently using asynchronous Dask clients. However, calling the collection's persist method, as xgboost/dask.py does in map_local_data (see the traceback below), triggers the computation on whichever client happens to be the default at that moment. So code like this:

import xgboost as xgb
from dask.distributed import Client

async def train(params):
    async with Client(asynchronous=True) as client:
        # create X_train, X_valid, y_train, y_valid as dask collections
        # ...
        dtrain = await xgb.dask.DaskDMatrix(client, X_train, y_train)
        dvalid = await xgb.dask.DaskDMatrix(client, X_valid, y_valid)

won't necessarily create the collections on its corresponding client, and will trigger errors like this one:

  File "./api_xgb.py", line 71, in train
    dvalid = await xgb.dask.DaskDMatrix(client, X_valid, y_valid)
  File "/home/jose_morales/miniconda3/envs/fastapi/lib/python3.8/site-packages/xgboost/dask.py", line 257, in map_local_data
    data = data.persist()
  File "/home/jose_morales/miniconda3/envs/fastapi/lib/python3.8/site-packages/dask/base.py", line 254, in persist
    (result,) = persist(self, traverse=False, **kwargs)
  File "/home/jose_morales/miniconda3/envs/fastapi/lib/python3.8/site-packages/dask/base.py", line 755, in persist
    results = client.persist(
  File "/home/jose_morales/miniconda3/envs/fastapi/lib/python3.8/site-packages/distributed/client.py", line 2944, in persist
    futures = self._graph_to_futures(
  File "/home/jose_morales/miniconda3/envs/fastapi/lib/python3.8/site-packages/distributed/client.py", line 2543, in _graph_to_futures
    dsk = highlevelgraph_pack(dsk, self, keyset)
  File "/home/jose_morales/miniconda3/envs/fastapi/lib/python3.8/site-packages/distributed/protocol/highlevelgraph.py", line 115, in highlevelgraph_pack
    "state": _materialized_layer_pack(
  File "/home/jose_morales/miniconda3/envs/fastapi/lib/python3.8/site-packages/distributed/protocol/highlevelgraph.py", line 40, in _materialized_layer_pack
    raise ValueError(
ValueError: Inputs contain futures that were created by another client.
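To illustrate why concurrent tasks can land on the wrong client, here is a minimal, stdlib-only toy model. It assumes, as described above, that a bare persist() uses whichever client is the process-wide default at the moment it runs. ToyClient and _current_client are hypothetical stand-ins, not the dask.distributed API. Each train coroutine installs its own client as the default on entry, and a single await is enough for the other task to overwrite it.

```python
import asyncio

# Hypothetical stand-in for the process-wide default client shared by all tasks.
_current_client = None


class ToyClient:
    """Toy client (not dask.distributed.Client): becomes the default on entry."""

    def __init__(self, name):
        self.name = name

    async def __aenter__(self):
        global _current_client
        _current_client = self  # the newest client overwrites the shared default
        return self

    async def __aexit__(self, *exc_info):
        return False  # toy: does not restore the previous default


async def train(name, log):
    async with ToyClient(name) as client:
        await asyncio.sleep(0)  # yield to the event loop, as real network I/O would
        implicit = _current_client  # the client a bare .persist() would pick up
        log.append((client.name, implicit.name))


async def main():
    log = []
    await asyncio.gather(train("a", log), train("b", log))
    return log


if __name__ == "__main__":
    # Task "a" ends up using client "b": [('a', 'b'), ('b', 'b')]
    print(asyncio.run(main()))
```

Task "a" holds its own client but the implicit lookup returns task "b"'s client. In the real report, that mismatch surfaces as the ValueError above.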

The fix would be to replace data = data.persist() with data = client.persist(data). Since client is already being passed to these functions, this should be fairly straightforward. I'd be happy to give it a go.
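The proposed change routes persistence through the client the caller already holds. The following stdlib-only toy sketch contrasts the two patterns. ToyFuture, ToyClient, ToyCollection and persist_on are hypothetical names, and the ownership check only loosely imitates the "futures that were created by another client" error. The real change is the one-line data = client.persist(data) swap inside xgboost's dask module.

```python
class ToyFuture:
    """A result owned by one specific client (toy, not distributed.Future)."""

    def __init__(self, owner):
        self.owner = owner


class ToyClient:
    default = None  # toy stand-in for the process-wide default client

    def __init__(self, name):
        self.name = name
        ToyClient.default = self  # assumption: the newest client becomes the default

    def persist(self, coll):
        # Loosely imitates the ownership check that raised the ValueError above
        if any(f.owner is not self for f in coll.futures):
            raise ValueError("Inputs contain futures that were created by another client.")
        return ToyCollection([ToyFuture(self)])


class ToyCollection:
    def __init__(self, futures):
        self.futures = futures

    def persist(self):
        # Buggy pattern: implicitly uses whichever client is the default right now
        return ToyClient.default.persist(self)


def persist_on(client, data):
    """Hypothetical helper: persist on the client the caller passed in."""
    return client.persist(data)  # the explicit-client pattern proposed in the issue


if __name__ == "__main__":
    client_a = ToyClient("a")
    data = ToyCollection([ToyFuture(client_a)])  # built on client_a
    ToyClient("b")  # another task's client becomes the default
    try:
        data.persist()
    except ValueError as err:
        print("implicit persist failed:", err)
    print("explicit persist on:", persist_on(client_a, data).futures[0].owner.name)
```

Passing the client explicitly makes the result independent of which client another coroutine happened to create last.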

@trivialfis
Member

Thanks for raising the issue. Feel free to open a PR and ping me on GitHub. Please note that training multiple models concurrently is not yet supported; I need to try out dask/distributed#4503 later.

@jameslamb
Contributor

Good find, @jose-moralez! I've been experiencing that exact issue but hadn't been able to figure out its source.

3 participants