Skip to content

Commit

Permalink
Fix casting void columns to categoricals (#3362)
Browse files Browse the repository at this point in the history
Fix casting `void` columns to categoricals and add corresponding tests.

WIP for #1691
  • Loading branch information
oleksiyskononenko committed Sep 27, 2022
1 parent 79fcaff commit 0e79950
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 20 deletions.
15 changes: 11 additions & 4 deletions src/core/column.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,15 @@ dt::SType Column::stype() const noexcept {
return impl_->type_.stype();
}

// Returns the stype of the data stored in this column.
//
// For a non-categorical column this is simply `stype()`. For a categorical
// column it is the stype of the first child column, or `SType::VOID` when
// the column has no children.
dt::SType Column::data_stype() const noexcept {
  if (!impl_->type_.is_categorical()) {
    return stype();
  }
  return n_children() ? child(0).stype() : dt::SType::VOID;
}

// Returns the logical type (ltype) obtained by passing the implementation's
// stype through `dt::stype_to_ltype()`.
dt::LType Column::ltype() const noexcept {
  const auto impl_stype = impl_->stype();
  return dt::stype_to_ltype(impl_stype);
}
Expand Down Expand Up @@ -273,8 +282,7 @@ static inline py::oobj getelem(const Column& col, size_t i) {
}

py::oobj Column::get_element_as_pyobject(size_t i) const {
dt::SType st = type().is_categorical()? child(0).stype()
: stype();
dt::SType st = data_stype();

switch (st) {
case dt::SType::VOID: return py::None();
Expand Down Expand Up @@ -328,8 +336,7 @@ py::oobj Column::get_element_as_pyobject(size_t i) const {
}

bool Column::get_element_isvalid(size_t i) const {
dt::SType st = type().is_categorical()? child(0).stype()
: stype();
dt::SType st = data_stype();

switch (st) {
case dt::SType::VOID: return false;
Expand Down
8 changes: 7 additions & 1 deletion src/core/column.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,14 @@ class Column
size_t nrows() const noexcept;
size_t na_count() const;
const dt::Type& type() const noexcept;
dt::SType stype() const noexcept;
dt::LType ltype() const noexcept;
dt::SType stype() const noexcept;

// For categorical columns, this method returns the stype of the data
// the column is backed by. For all other column types, this method
// is equivalent to `stype()`.
dt::SType data_stype() const noexcept;

size_t elemsize() const noexcept;
bool is_fixedwidth() const noexcept;
bool is_virtual() const noexcept;
Expand Down
3 changes: 1 addition & 2 deletions src/core/column/latent.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ class Latent_ColumnImpl : public Virtual_ColumnImpl {
}

static void vivify(const Column& col) {
dt::SType st = col.type().is_categorical()? col.child(0).stype()
: col.stype();
dt::SType st = col.data_stype();
switch (st) {
case SType::VOID:
case SType::BOOL:
Expand Down
3 changes: 1 addition & 2 deletions src/core/frame/repr/text_column.cc
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,7 @@ tstring Data_TextColumn::_render_value_string(const Column& col, size_t i) const


tstring Data_TextColumn::_render_value(const Column& col, size_t i) const {
SType st = col.type().is_categorical()? col.child(0).stype()
: col.stype();
SType st = col.data_stype();

switch (st) {
case SType::VOID: return na_value_;
Expand Down
20 changes: 10 additions & 10 deletions src/core/types/type_categorical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ void Type_Cat::cast_obj_column_(Column& col) const {
Groupby gb = std::move(res.second);
auto offsets = gb.offsets_r();

Buffer buf = Buffer::mem(col.nrows() * sizeof(T));
Buffer buf_cat = Buffer::mem(gb.size() * sizeof(int32_t));
auto buf_ptr = static_cast<T*>(buf.xptr());
auto buf_cat_ptr = static_cast<int32_t*>(buf_cat.xptr());
Buffer buf_codes = Buffer::mem(col.nrows() * sizeof(T));
Buffer buf_cats = Buffer::mem(gb.size() * sizeof(int32_t));
auto buf_codes_ptr = static_cast<T*>(buf_codes.xptr());
auto buf_cats_ptr = static_cast<int32_t*>(buf_cats.xptr());

const size_t MAX_CATS = std::numeric_limits<T>::max() + size_t(1);

Expand All @@ -177,22 +177,22 @@ void Type_Cat::cast_obj_column_(Column& col) const {
}

// Fill out two buffers:
// - `buf_cat` with row indices of unique elements (one element per category)
// - `buf` with the codes of categories (group ids).
// - `buf_cats` with row indices of unique elements (one element per category)
// - `buf_codes` with the codes of categories (group ids).
dt::parallel_for_dynamic(gb.size(),
[&](size_t i) {
size_t jj;
ri.get_element(static_cast<size_t>(offsets[i]), &jj);
buf_cat_ptr[i] = static_cast<int32_t>(jj);
buf_cats_ptr[i] = static_cast<int32_t>(jj);

for (int32_t j = offsets[i]; j < offsets[i + 1]; ++j) {
ri.get_element(static_cast<size_t>(j), &jj);
buf_ptr[static_cast<size_t>(jj)] = static_cast<T>(i);
buf_codes_ptr[static_cast<size_t>(jj)] = static_cast<T>(i);
}
});

// Modify `col` in-place, leaving only one element per category
const RowIndex ri_cat(std::move(buf_cat), RowIndex::ARR32);
const RowIndex ri_cat(std::move(buf_cats), RowIndex::ARR32);
col.apply_rowindex(ri_cat);
col.materialize();

Expand All @@ -205,7 +205,7 @@ void Type_Cat::cast_obj_column_(Column& col) const {
col = Column(new Categorical_ColumnImpl<T>(
nrows,
std::move(val),
std::move(buf),
std::move(buf_codes),
std::move(col)
));
}
Expand Down
30 changes: 29 additions & 1 deletion tests/types/test-categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,36 @@ def test_create_multicolumn(t):
assert_equals(DT2, DT2[:, :])



#-------------------------------------------------------------------------------
# Casting to other types
#-------------------------------------------------------------------------------

@pytest.mark.parametrize('t', [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32])
def test_void_to_cat(t):
    # A frame built from `None` values only, with column 0 then assigned
    # the categorical type, must equal a frame created directly with
    # that categorical type.
    values = [None] * 11
    frame = dt.Frame(values)
    frame[0] = t(dt.Type.str32)
    expected = dt.Frame(values, type=t(dt.Type.str32))
    assert_equals(frame, expected)


@pytest.mark.parametrize('t', [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32])
def test_obj_to_cat(t):
    # A column created with `type=object` (strings mixed with `None`s) and
    # then assigned the categorical type must equal a frame created
    # directly with that categorical type.
    values = [None, "cat", "cat", "dog", "mouse", None, "panda", "dog"]
    frame = dt.Frame(A=values, type=object)
    frame['A'] = t(dt.Type.str32)
    expected = dt.Frame(A=values, type=t(dt.Type.str32))
    assert_equals(frame, expected)



#-------------------------------------------------------------------------------
# Conversion
# Conversion to other formats
#-------------------------------------------------------------------------------

@pytest.mark.parametrize('t', [dt.Type.cat8,
Expand Down

0 comments on commit 0e79950

Please sign in to comment.