
/workspace/xgboost/rabit/include/rabit/internal/utils.h:90: Allreduce failed #6551

Closed · pseudotensor opened this issue Dec 24, 2020 · 31 comments · Fixed by #7297

@pseudotensor
Contributor

@trivialfis

I turned on early stopping, and even for just the single-node, 2-GPU case I'm getting this error.

/workspace/xgboost/rabit/include/rabit/internal/utils.h:90: Allreduce failed  

It's related to #6272, but I get this allreduce error without the other error, so I wanted to post it separately. However, if one ensures the eval_set has sufficient partitions across the dask workers, one does not hit this problem. The error is a bit confusing by itself, but the worker logs show other errors, such as an empty dataset.

Maybe the error message provided by xgboost can be improved.
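To illustrate the workaround: a minimal sketch, assuming pandas objects X, y, valid_X, valid_y, an already-connected client, and a DaskXGBClassifier model (all names here are placeholders, not from an actual script). The idea is just to give each frame at least as many partitions as there are workers before calling fit:

import dask.dataframe as dd

# Placeholder names: X, y, valid_X, valid_y are pandas objects, `client` is the
# dask Client, and `model` is the DaskXGBClassifier being fit.
n_workers = len(client.scheduler_info()["workers"])

# Give every frame at least one partition per worker so no worker ends up empty.
Xd = dd.from_pandas(X, npartitions=max(n_workers, 1)).persist()
yd = dd.from_pandas(y, npartitions=max(n_workers, 1)).persist()
valid_Xd = dd.from_pandas(valid_X, npartitions=max(n_workers, 1)).persist()
valid_yd = dd.from_pandas(valid_y, npartitions=max(n_workers, 1)).persist()

model.fit(Xd, yd, eval_set=[(valid_Xd, valid_yd)])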

@pseudotensor
Contributor Author

pseudotensor commented Dec 25, 2020

Actually, @trivialfis, after ensuring the 2 workers have data, I still saw this hit once out of about 200 models.

Just before the failure I have num_workers: 2 from the client, and:

Xy_shape=(91457, 126)
Xy npartitions: 16 16
validXy_shape=(22864, 126)
valid Xy npartitions: 4 4

Yet I hit this on a worker:


[09:30:09] WARNING: /workspace/xgboost/src/learner.cc:1219: Empty dataset at worker: 1
[09:30:09] WARNING: /workspace/xgboost/src/learner.cc:1219: Empty dataset at worker: 1
2020-12-24 09:30:11,002 - distributed.worker - WARNING -  Compute Failed
Function:  dispatched_train
args:      ('tcp://172.16.2.210:38283', [b'DMLC_NUM_WORKER=2', b'DMLC_TRACKER_URI=172.16.2.210', b'DMLC_TRACKER_PORT=9091', b'DMLC_TASK_ID=[xgboost.dask]:tcp://172.16.2.210:38283'], {'feature_names': None, 'feature_types': None, 'feature_weights': None, 'meta_names': ['labels'], 'missing': nan, 'parts': [(       100_v88   101_v89     102_v9  ...    97_v85    98_v86     99_v87
0     3.321300  0.095678   9.999999  ...  1.707317  0.866426   9.551836
1          NaN  2.678584        NaN  ...       NaN       NaN   9.848003
2     3.367346  0.111388  12.666667  ...  2.429906  1.071429   8.447465
3     1.408046  0.039051   8.965516  ...  1.587045  1.242817  10.747144
4          NaN       NaN        NaN  ...       NaN       NaN        NaN
...        ...       ...        ...  ...       ...       ...        ...
5711  1.257253  0.297503   7.586206  ...  1.600000  0.783365   6.193490
5712  1.269311  0.078320   8.954704  ...  3.215628  1.083507   8.527496
5713       NaN       NaN        NaN  ...       NaN      
kwargs:    {}
Exception: XGBoostError('[09:30:09] /workspace/xgboost/rabit/include/rabit/internal/utils.h:90: Allreduce failed',)

2020-12-24 09:30:11,034 - distributed.worker - WARNING -  Compute Failed
Function:  dispatched_train
args:      ('tcp://172.16.2.210:46607', [b'DMLC_NUM_WORKER=2', b'DMLC_TRACKER_URI=172.16.2.210', b'DMLC_TRACKER_PORT=9091', b'DMLC_TASK_ID=[xgboost.dask]:tcp://172.16.2.210:46607'], {'feature_names': None, 'feature_types': None, 'feature_weights': None, 'meta_names': ['labels'], 'missing': nan, 'parts': [(        100_v88   101_v89    102_v9  ...    97_v85        98_v86     99_v87
5716        NaN       NaN       NaN  ...       NaN           NaN        NaN
5717        NaN       NaN       NaN  ...       NaN           NaN        NaN
5718        NaN       NaN       NaN  ...       NaN           NaN        NaN
5719        NaN       NaN       NaN  ...       NaN           NaN        NaN
5720        NaN       NaN       NaN  ...       NaN           NaN        NaN
...         ...       ...       ...  ...       ...           ...        ...
11427  1.768530  0.171622  8.051949  ...  1.731602  1.014304e+00  13.394527
11428  3.500000  0.064629  8.888888  ...  3.414634 -5.495691e-07   3.851137
11429  1.299254  5.0
kwargs:    {}
Exception: XGBoostError('[09:30:09] /workspace/xgboost/src/metric/rank_metric.cc:242: Check failed: info.labels_.Size() != 0U (0 vs. 0) : label set cannot be empty\nStack trace:\n  [bt] (0) /home/jon/minicondadai/lib/python3.6/site-packages/xgboost/lib/libxgboost.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x54) [0x14a497b04f64]\n  [bt] (1) /home/jon/minicondadai/lib/python3.6/site-packages/xgboost/lib/libxgboost.so(xgboost::metric::EvalAuc::Eval(xgboost::HostDeviceVector<float> const&, xgboost::MetaInfo const&, bool)+0x9cb) [0x14a497c58a8b]\n  [bt] (2) /home/jon/minicondadai/lib/python3.6/site-packages/xgboost/lib/libxgboost.so(xgboost::LearnerImpl::EvalOneIter(int, std::vector<std::shared_ptr<xgboost::DMatrix>, std::allocator<std::shared_ptr<xgboost::DMatrix> > > const&, std::vector<std::string, std::allocator<std::string> > const&)+0x4f4) [0x14a497c2e964]\n  [bt] (3) /home/jon/minicondadai/lib/python3.6/site-packages/xgboost/lib/libxgboost.so(XGBoosterEvalOneIter+0x22d) [0x14a497b0cb6d]\n  [bt] (4) /home/jon/minicondadai/lib/python3.6/lib-dynload/../../libffi.so.6(ffi_call_unix64+0x4c) [0x14a57cb56630]\n  [bt] (5) /home/jon/minicondadai/lib/python3.6/lib-dynload/../../libffi.so.6(ffi_call+0x22d) [0x14a57cb55fed]\n  [bt] (6) /home/jon/minicondadai/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(_ctypes_callproc+0x2ce) [0x14a57cb6cf9e]\n  [bt] (7) /home/jon/minicondadai/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(+0x139d5) [0x14a57cb6d9d5]\n  [bt] (8) dask-worker [tcp://172.16.2.210:46607](_PyObject_FastCallDict+0x8b) [0x55966fa4e00b]\n\n',)

[09:31:35] task [xgboost.dask]:tcp://172.16.2.210:46607 got new rank 0
[09:31:35] task [xgboost.dask]:tcp://172.16.2.210:38283 got new rank 1

and I see in the logs:

2020-12-24 09:30:11,104 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |     func(X, y, **kwargs_dask)
2020-12-24 09:30:11,105 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |   File "/home/jon/minicondadai/lib/python3.6/site-packages/xgboost/core.py", line 422, in inner_f
2020-12-24 09:30:11,106 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |     return f(**kwargs)
2020-12-24 09:30:11,107 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |   File "/home/jon/minicondadai/lib/python3.6/site-packages/xgboost/dask.py", line 1322, in fit
2020-12-24 09:30:11,108 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |     callbacks=callbacks)
2020-12-24 09:30:11,109 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |   File "/home/jon/minicondadai/lib/python3.6/site-packages/distributed/client.py", line 832, in sync
2020-12-24 09:30:11,110 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |     self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
2020-12-24 09:30:11,111 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |   File "/home/jon/minicondadai/lib/python3.6/site-packages/distributed/utils.py", line 339, in sync
2020-12-24 09:30:11,112 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |     raise exc.with_traceback(tb)
2020-12-24 09:30:11,113 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |   File "/home/jon/minicondadai/lib/python3.6/site-packages/distributed/utils.py", line 323, in f
2020-12-24 09:30:11,114 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |     result[0] = yield future
2020-12-24 09:30:11,114 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |   File "/home/jon/minicondadai/lib/python3.6/site-packages/tornado/gen.py", line 735, in run
2020-12-24 09:30:11,115 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |     value = future.result()
2020-12-24 09:30:11,116 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |   File "/home/jon/minicondadai/lib/python3.6/site-packages/xgboost/dask.py", line 1290, in _fit_async
2020-12-24 09:30:11,117 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |     callbacks=callbacks)
2020-12-24 09:30:11,118 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |   File "/home/jon/minicondadai/lib/python3.6/site-packages/xgboost/dask.py", line 717, in _train_async
2020-12-24 09:30:11,119 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |     results = await client.gather(futures)
2020-12-24 09:30:11,120 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |   File "/home/jon/minicondadai/lib/python3.6/site-packages/distributed/client.py", line 1841, in _gather
2020-12-24 09:30:11,121 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |     raise exception.with_traceback(traceback)
2020-12-24 09:30:11,122 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |   File "/home/jon/minicondadai/lib/python3.6/site-packages/xgboost/dask.py", line 690, in dispatched_train
2020-12-24 09:30:11,123 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |     callbacks=callbacks)
2020-12-24 09:30:11,124 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |   File "/home/jon/minicondadai/lib/python3.6/site-packages/xgboost/training.py", line 189, in train
2020-12-24 09:30:11,125 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |     early_stopping_rounds=early_stopping_rounds)
2020-12-24 09:30:11,126 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |   File "/home/jon/minicondadai/lib/python3.6/site-packages/xgboost/training.py", line 83, in _train_internal
2020-12-24 09:30:11,127 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |     if callbacks.after_iteration(bst, i, dtrain, evals):
2020-12-24 09:30:11,128 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |   File "/home/jon/minicondadai/lib/python3.6/site-packages/xgboost/callback.py", line 432, in after_iteration
2020-12-24 09:30:11,129 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |     score = model.eval_set(evals, epoch, self.metric)
2020-12-24 09:30:11,130 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |   File "/home/jon/minicondadai/lib/python3.6/site-packages/xgboost/core.py", line 1343, in eval_set
2020-12-24 09:30:11,131 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |     ctypes.byref(msg)))
2020-12-24 09:30:11,132 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |   File "/home/jon/minicondadai/lib/python3.6/site-packages/xgboost/core.py", line 189, in _check_call
2020-12-24 09:30:11,133 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   |     raise XGBoostError(py_str(_LIB.XGBGetLastError()))
2020-12-24 09:30:11,133 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   | xgboost.core.XGBoostError: [09:30:09] /workspace/xgboost/rabit/include/rabit/internal/utils.h:90: Allreduce failed
2020-12-24 09:30:11,134 C:  7% D:526.2GB M:97.8GB  NODE:LOCAL1      33707  DATA   | ].

I will try to repro and get the pickled state for when things fail next time. But any insight? It's just regular partitions of data, but for whatever reason xgboost still thinks one worker has an empty dataset.

Is it possible that 1 worker happens to get all 16 or 4 partitions of Xy or valid_Xy?

@pseudotensor
Contributor Author

@trivialfis Actually, I can reproduce the problem if I just run a given script in a loop; it failed only on the 18th trial.

I'll try to get a clean repro for you that is free of our code base.

@pseudotensor
Contributor Author

pseudotensor commented Dec 26, 2020

I'm using (as I have before) a 2-node CUDA cluster with 2 GPUs in each node. I get this very frequently, but not always. Here's the repro. The first time I ran it, it failed on the very first iteration; another time it did not fail on the first iteration, but it always eventually fails within about 20 iterations.

from dask.distributed import Client
from dask import dataframe as dd
import xgboost as xgb

import _pickle as pickle

niter = 100
for iter in range(0, niter):
    print('iter: %d' % iter)
    # Reload the pickled model, data, and fit kwargs that trigger the failure.
    fil = "bad_allreduce.pkl"
    with open(fil, "rb") as f:
        (model, X, y, kwargs) = pickle.load(f)
    target = "__TARGET__"
    X[target] = y
    print(X.shape)
    print(y.shape)

    with Client(scheduler_file="scheduler.json") as client:

        dask_df = dd.from_pandas(X, chunksize=5000).persist()
        # extract
        yd = dask_df[target]
        Xd = dask_df.drop(target, axis=1)
        print("Xy npartitions: %d %d" % (Xd.npartitions, yd.npartitions))

        # eval_set
        valid_X = kwargs['eval_set'][0][0]
        valid_y = kwargs['eval_set'][0][1]
        valid_X[target] = valid_y
        valid_dask_df = dd.from_pandas(valid_X, chunksize=5000).persist()
        # extract
        valid_yd = valid_dask_df[target]
        valid_Xd = valid_dask_df.drop(target, axis=1)
        print("valid Xy npartitions: %d %d" % (valid_Xd.npartitions, valid_yd.npartitions))
        kwargs['eval_set'] = [(valid_Xd, valid_yd)]

        print(model)
        print(model.get_params())
        model.fit(Xd, yd, **kwargs)
        preds = model.predict(Xd).compute()
        print(preds[0:10])
        valid_preds = model.predict(valid_Xd).compute()
        print(valid_preds[0:10])

(28 MB, so GitHub won't take it)
https://0xdata-public.s3.amazonaws.com/jon/bad_allreduce.pkl.zip

scheduler.json.zip

/home/jon/minicondadai/bin/python /data/jon/h2oai.fullcondatest3/bad.dat.py
iter: 0
(91457, 128)
(91457,)
Xy npartitions: 19 19
valid Xy npartitions: 5 5
DaskXGBClassifier(accuracy=7, booster='gbtree', colsample_bytree=0.55,
                  debug_verbose=2, disable_gpus=False,
                  early_stopping_limit=None, early_stopping_rounds=20,
                  early_stopping_threshold=1e-05, encoder=None,
                  ensemble_level=3, eval_metric='aucpr',
                  experiment_description='3.cineweru', gamma=0.0, gpu_id=0,
                  grow_policy='lossguide', interpretability=1, labels=[0, 1],
                  learning_rate=0.15, lossguide=False, max_bin=256,
                  max_delta_step=0.0, max_depth=0, max_leaves=256,
                  min_child_weight=1, model_class_name='XGBoostGBMDaskModel',
                  model_origin='DefaultIndiv: '
                               'do_te:True,interp:11,depth:6,num_as_cat:False',
                  monotonicity_constraints=False, n_estimators=600, n_jobs=9,
                  ngenes=127, ...)
{'base_score': None, 'booster': 'gbtree', 'colsample_bylevel': None, 'colsample_bynode': None, 'colsample_bytree': 0.55, 'gamma': 0.0, 'gpu_id': 0, 'importance_type': 'gain', 'interaction_constraints': None, 'learning_rate': 0.15, 'max_delta_step': 0.0, 'max_depth': 0, 'min_child_weight': 1, 'missing': nan, 'monotone_constraints': None, 'n_estimators': 600, 'n_jobs': 9, 'num_parallel_tree': None, 'objective': 'binary:logistic', 'random_state': 278438169, 'reg_alpha': 0.0, 'reg_lambda': 2.0, 'scale_pos_weight': 1.0, 'subsample': 0.5, 'tree_method': 'gpu_hist', 'validate_parameters': None, 'verbosity': None, 'use_label_encoder': False, 'model_class_name': 'XGBoostGBMDaskModel', 'num_class': 1, 'labels': [0, 1], 'score_f_name': 'LOGLOSS', 'time_column': None, 'encoder': None, 'tgc': None, 'pred_gap': None, 'pred_periods': None, 'target': None, 'tsp': None, 'early_stopping_rounds': 20, 'max_bin': 256, 'grow_policy': 'lossguide', 'max_leaves': 256, 'eval_metric': 'aucpr', 'early_stopping_threshold': 1e-05, 'monotonicity_constraints': False, 'silent': 0, 'debug_verbose': 2, 'seed': 278438169, 'disable_gpus': False, 'lossguide': False, 'accuracy': 7, 'time_tolerance': 10, 'interpretability': 1, 'ensemble_level': 3, 'train_shape': (114321, 133), 'valid_shape': None, 'model_origin': 'DefaultIndiv: do_te:True,interp:11,depth:6,num_as_cat:False', 'resumed_experiment_id': 'bedd7566-45e6-11eb-bb81-0cc47adb058f', 'str_uuid': 'ret_ff6609f1-e952-4a09-af96-9388336d482c', 'experiment_description': '3.cineweru', 'train_dataset_name': 'train.csv.zip', 'valid_data_name': '[Valid]', 'test_data_name': '[Test]', 'ngenes': 127, 'ngenes_max': 133, 'uses_gpu': True, 'early_stopping_limit': None}
Traceback (most recent call last):
  File "/data/jon/h2oai.fullcondatest3/bad.dat.py", line 39, in <module>
    model.fit(Xd, yd, **kwargs)
  File "/home/jon/minicondadai/lib/python3.6/site-packages/xgboost/core.py", line 422, in inner_f
    return f(**kwargs)
  File "/home/jon/minicondadai/lib/python3.6/site-packages/xgboost/dask.py", line 1322, in fit
    callbacks=callbacks)
  File "/home/jon/minicondadai/lib/python3.6/site-packages/distributed/client.py", line 832, in sync
    self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
  File "/home/jon/minicondadai/lib/python3.6/site-packages/distributed/utils.py", line 339, in sync
    raise exc.with_traceback(tb)
  File "/home/jon/minicondadai/lib/python3.6/site-packages/distributed/utils.py", line 323, in f
    result[0] = yield future
  File "/home/jon/minicondadai/lib/python3.6/site-packages/tornado/gen.py", line 735, in run
    value = future.result()
  File "/home/jon/minicondadai/lib/python3.6/site-packages/xgboost/dask.py", line 1290, in _fit_async
    callbacks=callbacks)
  File "/home/jon/minicondadai/lib/python3.6/site-packages/xgboost/dask.py", line 717, in _train_async
    results = await client.gather(futures)
  File "/home/jon/minicondadai/lib/python3.6/site-packages/distributed/client.py", line 1841, in _gather
    raise exception.with_traceback(traceback)
  File "/home/jon/minicondadai/lib/python3.6/site-packages/xgboost/dask.py", line 690, in dispatched_train
    callbacks=callbacks)
  File "/home/jenkins/minicondadai/lib/python3.6/site-packages/xgboost/training.py", line 189, in train
  File "/home/jenkins/minicondadai/lib/python3.6/site-packages/xgboost/training.py", line 83, in _train_internal
  File "/home/jenkins/minicondadai/lib/python3.6/site-packages/xgboost/callback.py", line 432, in after_iteration
  File "/home/jenkins/minicondadai/lib/python3.6/site-packages/xgboost/core.py", line 1343, in eval_set
  File "/home/jenkins/minicondadai/lib/python3.6/site-packages/xgboost/core.py", line 189, in _check_call
xgboost.core.XGBoostError: [23:11:02] /workspace/xgboost/rabit/include/rabit/internal/utils.h:90: Allreduce failed

while when it works, which it sometimes does, I get:

/home/jon/minicondadai/bin/python /data/jon/h2oai.fullcondatest3/bad.dat.py
iter: 0
(91457, 128)
(91457,)
Xy npartitions: 19 19
valid Xy npartitions: 5 5
DaskXGBClassifier(accuracy=7, booster='gbtree', colsample_bytree=0.55,
                  debug_verbose=2, disable_gpus=False,
                  early_stopping_limit=None, early_stopping_rounds=20,
                  early_stopping_threshold=1e-05, encoder=None,
                  ensemble_level=3, eval_metric='aucpr',
                  experiment_description='3.cineweru', gamma=0.0, gpu_id=0,
                  grow_policy='lossguide', interpretability=1, labels=[0, 1],
                  learning_rate=0.15, lossguide=False, max_bin=256,
                  max_delta_step=0.0, max_depth=0, max_leaves=256,
                  min_child_weight=1, model_class_name='XGBoostGBMDaskModel',
                  model_origin='DefaultIndiv: '
                               'do_te:True,interp:11,depth:6,num_as_cat:False',
                  monotonicity_constraints=False, n_estimators=600, n_jobs=9,
                  ngenes=127, ...)
{'base_score': None, 'booster': 'gbtree', 'colsample_bylevel': None, 'colsample_bynode': None, 'colsample_bytree': 0.55, 'gamma': 0.0, 'gpu_id': 0, 'importance_type': 'gain', 'interaction_constraints': None, 'learning_rate': 0.15, 'max_delta_step': 0.0, 'max_depth': 0, 'min_child_weight': 1, 'missing': nan, 'monotone_constraints': None, 'n_estimators': 600, 'n_jobs': 9, 'num_parallel_tree': None, 'objective': 'binary:logistic', 'random_state': 278438169, 'reg_alpha': 0.0, 'reg_lambda': 2.0, 'scale_pos_weight': 1.0, 'subsample': 0.5, 'tree_method': 'gpu_hist', 'validate_parameters': None, 'verbosity': None, 'use_label_encoder': False, 'model_class_name': 'XGBoostGBMDaskModel', 'num_class': 1, 'labels': [0, 1], 'score_f_name': 'LOGLOSS', 'time_column': None, 'encoder': None, 'tgc': None, 'pred_gap': None, 'pred_periods': None, 'target': None, 'tsp': None, 'early_stopping_rounds': 20, 'max_bin': 256, 'grow_policy': 'lossguide', 'max_leaves': 256, 'eval_metric': 'aucpr', 'early_stopping_threshold': 1e-05, 'monotonicity_constraints': False, 'silent': 0, 'debug_verbose': 2, 'seed': 278438169, 'disable_gpus': False, 'lossguide': False, 'accuracy': 7, 'time_tolerance': 10, 'interpretability': 1, 'ensemble_level': 3, 'train_shape': (114321, 133), 'valid_shape': None, 'model_origin': 'DefaultIndiv: do_te:True,interp:11,depth:6,num_as_cat:False', 'resumed_experiment_id': 'bedd7566-45e6-11eb-bb81-0cc47adb058f', 'str_uuid': 'ret_ff6609f1-e952-4a09-af96-9388336d482c', 'experiment_description': '3.cineweru', 'train_dataset_name': 'train.csv.zip', 'valid_data_name': '[Valid]', 'test_data_name': '[Test]', 'ngenes': 127, 'ngenes_max': 133, 'uses_gpu': True, 'early_stopping_limit': None}
[1 1 1 1 1 0 1 0 1 1]
[1 1 1 1 1 1 1 1 1 1]

Process finished with exit code 0

@trivialfis On this 2-node cluster this error happens on about every other model fit!! So this is really bad.

FYI, the predict is not required for the failure; it's just a way to see what is going on.

@pseudotensor
Contributor Author

If it didn't fail only half the time, I would guess that xgboost has some extra, undocumented conditions on the partitions. That's why I asked about the requirements before. You said each DMatrix has to use the same partition count, e.g. for X, y, and sample_weight, since those make up one DMatrix. The validation DMatrix is built from valid_X, valid_y, and sample_weight_eval_set (for each eval set). That makes sense.

However, as a user I'd be worried that 5 partitions do not divide evenly across 4 CUDA workers (2 nodes with 2 GPUs each). If it failed every time, I would have guessed that I need to make the number of partitions evenly divisible by the number of workers. Of course, that is not always possible.

@pseudotensor
Contributor Author

And to be clear, the actual error vacillates between the allreduce error and the empty dataset -> "label set cannot be empty" error from #6272. It's about 50/50 in terms of which error pops up.

@pseudotensor
Contributor Author

I also have reports from colleagues that it happens even without a cluster, in just local-cluster mode, which I'll check on.

@pseudotensor
Contributor Author

The report seems to be true, but so far I have only been able to reproduce the issue when using a cluster. In that case, dask is highly unstable with the repro above. @trivialfis

@pseudotensor
Contributor Author

pseudotensor commented Dec 28, 2020

Any progress? I added a hacky retry with (say) 5 attempts, but retrying only succeeds about 30% of the time. That it works at all shows that part (or all) of the problem is with xgboost.
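It's not the actual code, but the retry amounts to something like this sketch, where fit_once is a hypothetical zero-argument callable that wraps model.fit(...) on the dask collections:

import xgboost as xgb

def fit_with_retries(fit_once, max_attempts=5):
    # `fit_once` is a hypothetical callable that runs model.fit(...) and returns the model.
    last_err = None
    for attempt in range(max_attempts):
        try:
            return fit_once()
        except xgb.core.XGBoostError as err:  # e.g. "Allreduce failed"
            last_err = err
    raise last_err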

But I'm just guessing; it could also be that the data is already badly distributed when doing persist, or something like that. I will keep trying things.

@pseudotensor
Contributor Author

pseudotensor commented Dec 28, 2020

I should also note that nothing like this was seen in the earlier 1.2.x builds. Problems only occurred with that dask PR/commit we discussed, so I don't think it is fully fixed in 1.3.0 with the IP fix alone.

So this is definitely a regression.

@pseudotensor
Contributor Author

FYI, even when the partitions divide evenly across the workers I see the same problem:

2020-12-28 14:21:27,315 C: NA D: NA M: NA NODE:SERVER 22966 PDEBUG | ('Xy npartitions: 32 32',)
2020-12-28 14:21:27,818 C: NA D: NA M: NA NODE:SERVER 22966 PDEBUG | ('valid Xy npartitions: 8 8',)

then it fails the same way.

I can't see why it hits either the Allreduce failure or "label set cannot be empty", even when the data should be scattered. Do I have to scatter manually? I'm without a clue, but this definitely didn't happen in older xgboost.
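For reference, this is roughly what I mean by redistributing manually; a sketch only, reusing the names from my repro script, and whether it actually helps is exactly what I don't know:

# After persisting the frames from the repro script, ask dask to even out where
# the persisted partitions live; client.scatter() would be the alternative for
# pushing raw data out explicitly.
dask_df = dd.from_pandas(X, chunksize=5000).persist()
valid_dask_df = dd.from_pandas(valid_X, chunksize=5000).persist()
client.rebalance()  # move persisted partitions so workers hold a roughly even share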

@trivialfis
Member

I don't have any progress yet. Just got back yesterday. Will go through the issues.

@pseudotensor
Contributor Author

FYI, if I remove the eval_set, then I don't seem to encounter either the Allreduce or the empty label set issue. It's unclear how xgboost ends up with an empty dataset, unless the partitions are unevenly distributed. I'm trying to see on the dask side how to determine which workers have which partitions, but maybe you know?
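This is the kind of check I'm attempting; a sketch only, assuming the persisted valid_Xd and the connected client from my repro script:

from collections import Counter
from distributed import futures_of

# Map each persisted partition (key) of the eval frame to the workers holding it.
who_has = client.who_has(futures_of(valid_Xd))
parts_per_worker = Counter(addr for addrs in who_has.values() for addr in addrs)
print(parts_per_worker)  # a worker missing from this counter holds no eval partitions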

@pseudotensor
Contributor Author

pseudotensor commented Dec 29, 2020

FYI, I also tried explicitly putting the data into dask_cudf by adding dask_df = dask_cudf.from_dask_dataframe(dask_df) and valid_dask_df = dask_cudf.from_dask_dataframe(valid_dask_df) to the above script after the relevant from_pandas calls, but I still see the same problem. Maybe it happens less often, but I can't be sure.
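Concretely, the change is just the following sketch against my repro script (assuming dask_cudf is importable in the worker environment):

import dask_cudf

# After the existing from_pandas(...).persist() calls in the repro script:
dask_df = dask_cudf.from_dask_dataframe(dask_df)              # training frame as dask_cudf
valid_dask_df = dask_cudf.from_dask_dataframe(valid_dask_df)  # eval frame as dask_cudf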

FYI, the problem with using dask_cudf on the local client is that it uses GPU memory on the client, and then I see the dask workers use the same memory. So there seems to be an extra copy for no good reason.

@trivialfis
Member

I think my plan will be improving the implementation of aucpr to handle empty datasets.

@pseudotensor
Contributor Author

I think my plan will be improving the implementation of aucpr to handle empty datasets.

You mean all rank based metrics?

I guess that will fix one side but perhaps not the allreduce. But it's still a concern that the data is not evenly distributed. There's no reason, from what I see, that any worker should ever have an empty dataset.

@trivialfis
Member

trivialfis commented Dec 29, 2020

You mean all rank based metrics?

Yeah if time allows.

I guess that will fix one side but perhaps not the allreduce.

Almost always (I haven't seen another case), the allreduce failure is caused by other failures in 1 or more workers. As for how to sort out correct, consistent error reporting/handling by coordinating all the workers, we don't know yet. It's not exactly an issue that can be solved by a mutex...

But it's still a concern that the data is not evenly distributed. There's no reason, from what I see, that any worker should ever have an empty dataset.

Did you see a Python warning that looks like this:

worker {address} has an empty DMatrix

It's the first report of an empty DMatrix. A trick for detecting an empty DMatrix yourself is:

from typing import Dict

import xgboost as xgb
from distributed import Client

def _get_client_workers(client: "Client") -> Dict[str, Dict]:
    # Map worker address -> worker info, as reported by the scheduler.
    workers = client.scheduler_info()['workers']
    return workers

Xy = xgb.dask.DaskDMatrix(client, X, y)  # X, y are the dask collections used for training
workers = _get_client_workers(client)
for addr in workers:
    parts = Xy.worker_map.get(addr)
    assert parts is not None  # Assert there's a data partition on the worker with address `addr`.

@pseudotensor
Contributor Author

Unintentional indirect close

@pseudotensor
Contributor Author

pseudotensor commented Dec 31, 2020

Thanks, I will take a look. FYI, while I mentioned that doing X_dask = dask_cudf.from_dask_dataframe(X_dask) seems to avoid some of the failures, the problem with that is that it appears to allocate the entire frame on one GPU.

I would have expected this call not to materialize any GPU memory on the client; however, it does. E.g.:

Wed Dec 30 16:46:56 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.80.02    Driver Version: 450.80.02    CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  GeForce GTX 108...  On   | 00000000:02:00.0  On |                  N/A |
| 24%   42C    P2    58W / 250W |   1028MiB / 11178MiB |      8%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce GTX 108...  On   | 00000000:81:00.0 Off |                  N/A |
| 20%   36C    P8     9W / 250W |    139MiB / 11178MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     15019      C   ...ostGBMDaskModel_fit_model      295MiB |
|    0   N/A  N/A     15057      C   ...r [tcp://127.0.0.1:33315]      135MiB |
|    1   N/A  N/A     15064      C   ...r [tcp://127.0.0.1:45231]      135MiB |
+-----------------------------------------------------------------------------+

The first process, 15019, is on the client. The other 2 are workers, with just a CUDA context.

I.e., this is just before the fit; all frames are dask_cudf after passing the dask frames through the above function. Instead of pushing the data to the workers, it just eats up GPU memory on the client side, which defeats the purpose of dask.

This may be a bug or problem with dask_cudf, or just a problem with the RAPIDS 0.14 version of it, but maybe you know more.

This is why I settled on passing xgboost the dask frames and letting xgboost convert them to GPU as needed, which seems to perform the conversion per dask worker instead of on the client.

Do you have an understanding of how to do this correctly? I can see the frames are dask_cudf with partitions, etc., but for whatever reason the client uses GPU memory. And I can see that once the fit actually starts, the client holds onto that memory and each worker starts using GPU memory.

Perhaps the dask_cudf team did not notice this because they usually test with LocalCUDACluster, where this problem would not be as easy to see.

@pseudotensor
Contributor Author

Hi @trivialfis, any progress on this? We keep seeing it, and people who use our products also hit it.

@trivialfis
Member

I don't know the exact cause of the allreduce failure. I'm looking into the aucpr implementation for both CPU and GPU; hopefully, if I can get it right, we will be able to fix your error. It will take some time, as I also want to resolve #4663.

@pseudotensor
Contributor Author

Hi @trivialfis, any progress on this? We still see the issue; it makes dask unusable when using multiple nodes.

@trivialfis
Member

Working on it now.

@trivialfis
Member

I wanted to recommend using an sklearn metric as a temporary workaround, but it doesn't handle an empty dataset. So I will continue revising the implementation in xgboost.

@trivialfis
Member

trivialfis commented Feb 26, 2021

@pseudotensor Could you please try to wrap sklearn.metrics.roc_auc_score with something like:

from typing import Callable, Tuple

import numpy as np
from sklearn.metrics import roc_auc_score

import xgboost as xgb
from xgboost import DMatrix


def _metric_decorator(func: Callable) -> Callable:
    # `func` is a metric like `sklearn.metrics.roc_auc_score`.
    def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
        y_true = dmatrix.get_label()
        if y_true.size == 0:
            # Return 0.5 as the default for roc-auc; this varies for different metrics.
            return func.__name__, 0.5
        return func.__name__, func(y_true, y_score)
    return inner


metric = _metric_decorator(roc_auc_score)
cls = xgb.dask.DaskXGBClassifier()
cls.fit(X, y, eval_metric=metric, eval_set=[(valid_X, valid_y)])

@pseudotensor
Contributor Author

pseudotensor commented Feb 26, 2021

I'm pretty sure that in this case I don't have an empty dataset; I ensure the data is partitioned across all workers. When there were cases with an empty dataset, it would show such an error.

The original repro above still reproduces the problem for me. That said, it doesn't happen every time, so the testing that NVIDIA does would not see it, except perhaps rarely.

Can one try to repro on a real multi-node cluster?

This is the repro from above: #6551 (comment)

If you have a problem with the pickle again, I can remove the "model" and just give the params, etc.

@trivialfis
Member

Can one try to repro on a real multi-node cluster?

I set up an SSH cluster.

So far, the only way I can get the error is if I shut down one of the nodes in the middle of training.

@trivialfis
Member

Judging from the stack trace you provided, I still believe the aucpr implementation somehow kills a worker and leads to an allreduce failure on the other one. That's why I want to see if #6551 (comment) works for you.

@trivialfis
Member

Okay, I got a repro, but it's due to an empty dataset. On the worker:

kwargs:    {}
Exception: XGBoostError('[13:00:39] /gpfs/fs1/Jiaming/XGBoost/xgboost/src/metric/rank_metric.cc:634: Check failed: info.labels_.Size() != 0U (0 vs. 0) : label set cannot be empty\nStack trace:\n  [bt] (0) /gpfs/fs1/Jiaming/XGBoost/xgboost/python-package/xgboost/../../lib/libxgboost.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x67) [0x7f27784514b7]\n  [bt] (1) /gpfs/fs1/Jiaming/XGBoost/xgboost/python-package/xgboost/../../lib/libxgboost.so(xgboost::metric::EvalAucPR::Eval(xgboost::HostDeviceVector<float> const&, xgboost::MetaInfo const&, bool)+0xa5a) [0x7f277853e7ca]\n  [bt] (2) /gpfs/fs1/Jiaming/XGBoost/xgboost/python-package/xgboost/../../lib/libxgboost.so(xgboost::LearnerImpl::EvalOneIter(int, std::vector<std::shared_ptr<xgboost::DMatrix>, std::allocator<std::shared_ptr<xgboost::DMatrix> > > const&, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > const&)+0x32d) [0x7f2778518eed]\n  [bt] (3) /gpfs/fs1/Jiaming/XGBoost/xgboost/python-package/xgboost/../../lib/libxgboost.so(XGBoosterEvalOneIter+0x395) [0x7f277842c1f5]\n  [bt] (4) /gpfs/fs1/Jiaming/anaconda3/envs/XGBoost/lib/python3.7/lib-dynload/../../libffi.so.6(ffi_call_unix64+0x4c) [0x7f2abfdd8630]\n  [bt] (5) /gpfs/fs1/Jiaming/anaconda3/envs/XGBoost/lib/python3.7/lib-dynload/../../libffi.so.6(ffi_call+0x22d) [0x7f2abfdd7fed]\n  [bt] (6) /gpfs/fs1/Jiaming/anaconda3/envs/XGBoost/lib/python3.7/lib-dynload/_ctypes.cpython-37m-x86_64-linux-gnu.so(_ctypes_callproc+0x2e7) [0x7f2abfdee6d7]\n  [bt] (7) /gpfs/fs1/Jiaming/anaconda3/envs/XGBoost/lib/python3.7/lib-dynload/_ctypes.cpython-37m-x86_64-linux-gnu.so(+0x13144) [0x7f2abfdef144]\n  [bt] (8) /gpfs/fs1/Jiaming/anaconda3/envs/XGBoost/bin/python(_PyObject_FastCallKeywords+0x15c) [0x557c8b061a7c]\n\n')

And on the client side:

  File "reproduce.py", line 54, in <module>
    main()
  File "reproduce.py", line 46, in main
    model.fit(Xd, yd, **kwargs)
  File "/gpfs/fs1/Jiaming/XGBoost/xgboost/python-package/xgboost/core.py", line 433, in inner_f
    return f(**kwargs)
  File "/gpfs/fs1/Jiaming/XGBoost/xgboost/python-package/xgboost/dask.py", line 1697, in fit
    return self.client.sync(self._fit_async, **args)
  File "/gpfs/fs1/Jiaming/anaconda3/envs/XGBoost/lib/python3.7/site-packages/distributed/client.py", line 833, in sync    self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
  File "/gpfs/fs1/Jiaming/anaconda3/envs/XGBoost/lib/python3.7/site-packages/distributed/utils.py", line 340, in sync
    raise exc.with_traceback(tb)
  File "/gpfs/fs1/Jiaming/anaconda3/envs/XGBoost/lib/python3.7/site-packages/distributed/utils.py", line 324, in f
    result[0] = yield future
  File "/gpfs/fs1/Jiaming/anaconda3/envs/XGBoost/lib/python3.7/site-packages/tornado/gen.py", line 762, in run
    value = future.result()
  File "/gpfs/fs1/Jiaming/XGBoost/xgboost/python-package/xgboost/dask.py", line 1665, in _fit_async
    xgb_model=model,
  File "/gpfs/fs1/Jiaming/XGBoost/xgboost/python-package/xgboost/dask.py", line 882, in _train_async
    results = await client.gather(futures)
  File "/gpfs/fs1/Jiaming/anaconda3/envs/XGBoost/lib/python3.7/site-packages/distributed/client.py", line 1851, in _gather
    raise exception.with_traceback(traceback)
  File "/gpfs/fs1/Jiaming/XGBoost/xgboost/python-package/xgboost/dask.py", line 850, in dispatched_train
    callbacks=callbacks)
  File "/gpfs/fs1/Jiaming/XGBoost/xgboost/python-package/xgboost/training.py", line 197, in train
    early_stopping_rounds=early_stopping_rounds)
  File "/gpfs/fs1/Jiaming/XGBoost/xgboost/python-package/xgboost/training.py", line 82, in _train_internal
    if callbacks.after_iteration(bst, i, dtrain, evals):
  File "/gpfs/fs1/Jiaming/XGBoost/xgboost/python-package/xgboost/callback.py", line 430, in after_iteration
    score = model.eval_set(evals, epoch, self.metric)
  File "/gpfs/fs1/Jiaming/XGBoost/xgboost/python-package/xgboost/core.py", line 1566, in eval_set
    ctypes.byref(msg)))
  File "/gpfs/fs1/Jiaming/XGBoost/xgboost/python-package/xgboost/core.py", line 210, in _check_call
    raise XGBoostError(py_str(_LIB.XGBGetLastError()))
xgboost.core.XGBoostError: [13:00:39] /gpfs/fs1/Jiaming/XGBoost/xgboost/rabit/include/rabit/internal/utils.h:90: Allreduce failed

@trivialfis
Member

@pseudotensor The client-side error seems to match the one you provided.

@Hasna1994

Hi, was this resolved? I'm getting this error.

@trivialfis
Member

@Hasna1994 Could you please open a new issue with a reproducible example?
