Skip to content

Commit

Permalink
Fix DMatrix slice with feature types. (#6689)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Feb 9, 2021
1 parent 218a5fb commit 5d48d40
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
out.feature_weigths.Copy(this->feature_weigths);

out.feature_names = this->feature_names;
out.feature_types.Resize(this->feature_types.Size());
out.feature_types.Copy(this->feature_types);
out.feature_type_names = this->feature_type_names;

Expand Down
14 changes: 12 additions & 2 deletions tests/python/test_with_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@


class TestPandas:

def test_pandas(self):

df = pd.DataFrame([[1, 2., True], [2, 3., False]],
columns=['a', 'b', 'c'])
dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
Expand Down Expand Up @@ -110,6 +108,18 @@ def test_pandas(self):
assert dm.num_row() == 2
assert dm.num_col() == 6

def test_slice(self):
rng = np.random.RandomState(1994)
rows = 100
X = rng.randint(3, 7, size=rows)
X = pd.DataFrame({'f0': X})
y = rng.randn(rows)
ridxs = [1, 2, 3, 4, 5, 6]
m = xgb.DMatrix(X, y)
sliced = m.slice(ridxs)

assert m.feature_types == sliced.feature_types

def test_pandas_categorical(self):
rng = np.random.RandomState(1994)
rows = 100
Expand Down

0 comments on commit 5d48d40

Please sign in to comment.