
Python exception in Pool() when pairs_weight is a numpy ndarray #1913

Closed
LachlanStuart opened this issue Nov 11, 2021 · 0 comments

@LachlanStuart

Problem: passing a NumPy ndarray as pairs_weight to Pool() raises a ValueError.

Minimal reproduction:

import numpy as np
from catboost import Pool

pool = Pool(
    np.ones((5, 1)),
    np.ones(5),
    pairs=np.array([[0, 1], [2, 3]]),
    pairs_weight=np.array([0.5, 0.5]),
)

Exception:

Traceback (most recent call last):
  File "/home/lachlan/miniconda3/envs/sm38/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3417, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-49-2c10cac71c1a>", line 1, in <module>
    Pool(np.ones((5,1)), np.ones(5), pairs=np.array([[0,1],[2,3]]), pairs_weight=np.array([0.5, 0.5]))
  File "/home/lachlan/miniconda3/envs/sm38/lib/python3.8/site-packages/catboost/core.py", line 628, in __init__
    self._init(data, label, cat_features, text_features, embedding_features, pairs, weight, group_id, group_weight, subgroup_id, pairs_weight, baseline, timestamp, feature_names, thread_count)
  File "/home/lachlan/miniconda3/envs/sm38/lib/python3.8/site-packages/catboost/core.py", line 1171, in _init
    self._init_pool(data, label, cat_features, text_features, embedding_features, pairs, weight, group_id, group_weight, subgroup_id, pairs_weight, baseline, timestamp, feature_names, thread_count)
  File "_catboost.pyx", line 3755, in _catboost._PoolBase._init_pool
  File "_catboost.pyx", line 3803, in _catboost._PoolBase._init_pool
  File "_catboost.pyx", line 3676, in _catboost._PoolBase._init_features_order_layout_pool
  File "_catboost.pyx", line 3311, in _catboost._set_pairs
  File "_catboost.pyx", line 3292, in _catboost._make_pairs_vector
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

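It runs correctly if a regular Python list is used for pairs_weight. The bug is clear in the code: _make_pairs_vector tests pairs_weight for truthiness, which is ambiguous for an ndarray with more than one element. The check should be if pairs_weight is not None:. I'd make a PR, but I have no experience with .pyx (Cython) code and don't know how to test it after making changes.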
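A minimal sketch of why the truthiness check fails, assuming (from the traceback and the suggested fix) that _make_pairs_vector currently uses an if pairs_weight: test. The two functions below are hypothetical plain-Python stand-ins, not CatBoost's actual Cython code; they only need NumPy.

```python
import numpy as np


def has_weights_truthy(pairs_weight):
    # Hypothetical stand-in for an `if pairs_weight:` check (assumed, not
    # copied from _catboost.pyx). Works for None and lists, but calling
    # bool() on an ndarray with more than one element raises ValueError.
    if pairs_weight:
        return True
    return False


def has_weights_explicit(pairs_weight):
    # The fix proposed in the issue: only test whether a value was passed.
    if pairs_weight is not None:
        return True
    return False


weights = np.array([0.5, 0.5])

# A plain Python list is fine with either check.
print(has_weights_truthy([0.5, 0.5]))
# The explicit None check also accepts the ndarray.
print(has_weights_explicit(weights))

# The truthiness check reproduces the error from the traceback.
error = None
try:
    has_weights_truthy(weights)
except ValueError as exc:
    error = exc
print(type(error).__name__, error)
```

As a workaround until the check is changed, converting the weights with weights.tolist() before passing them to Pool() gives the plain-list behaviour that the issue reports as working.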

catboost version: 1.0.3
Operating System: Ubuntu 20.04
CPU: Intel Core i7
GPU: N/A

andrey-khropov changed the title from "Python exception in Pool() when pairs_weight is a numpy array" to "Python exception in Pool() when pairs_weight is a numpy ndarray" on Jan 24, 2024.