Skip to content

Commit

Permalink
Implement slicing for categorical columns (#3379)
Browse files Browse the repository at this point in the history
WIP for #1691
  • Loading branch information
oleksiyskononenko committed Apr 25, 2023
1 parent dbbc664 commit 887ad6b
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 14 deletions.
4 changes: 4 additions & 0 deletions src/core/column.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ const dt::Type& Column::type() const noexcept {
return impl_->type_;
}

const dt::Type& Column::data_type() const noexcept {
return impl_->data_type();
}

dt::SType Column::stype() const noexcept {
return impl_->type_.stype();
}
Expand Down
1 change: 1 addition & 0 deletions src/core/column.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class Column
size_t nrows() const noexcept;
size_t na_count() const;
const dt::Type& type() const noexcept;
const dt::Type& data_type() const noexcept;
dt::LType ltype() const noexcept;
dt::SType stype() const noexcept;

Expand Down
7 changes: 2 additions & 5 deletions src/core/column/column_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,8 @@ size_t ColumnImpl::null_count() const {


SType ColumnImpl::data_stype() const {
if (type_.is_categorical()) {
return n_children()? child(0).stype()
: SType::VOID;
}
return stype();
return type_.is_categorical()? child(0).stype()
: stype();
}


Expand Down
4 changes: 4 additions & 0 deletions src/core/column/column_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ class ColumnImpl
SType stype() const { return type_.stype(); }
SType data_stype() const;
const Type& type() const { return type_; }
const Type& data_type() const {
return type_.is_categorical()? child(0).type()
: type_;
}
virtual bool is_virtual() const noexcept = 0;
virtual bool computationally_expensive() const { return false; }
virtual size_t memory_footprint() const noexcept = 0;
Expand Down
2 changes: 1 addition & 1 deletion src/core/column/latent.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class Latent_ColumnImpl : public Virtual_ColumnImpl {
case SType::ARR32:
case SType::ARR64: vivify<Column>(col); break;
default:
throw RuntimeError() << "Unknown stype " << col.stype(); // LCOV_EXCL_LINE
throw RuntimeError() << "Unknown stype " << st; // LCOV_EXCL_LINE
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/core/column/view.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace dt {

SliceView_ColumnImpl::SliceView_ColumnImpl(
Column&& col, const RowIndex& ri)
: Virtual_ColumnImpl(ri.size(), col.type()),
: Virtual_ColumnImpl(ri.size(), col.data_type()),
arg_(std::move(col)),
start_(ri.slice_start()),
step_(ri.slice_step())
Expand All @@ -41,7 +41,7 @@ SliceView_ColumnImpl::SliceView_ColumnImpl(

SliceView_ColumnImpl::SliceView_ColumnImpl(
Column&& col, size_t start, size_t count, size_t step)
: Virtual_ColumnImpl(count, col.type()),
: Virtual_ColumnImpl(count, col.data_type()),
arg_(std::move(col)),
start_(start),
step_(step)
Expand Down Expand Up @@ -97,7 +97,7 @@ template <> const int64_t* get_indices(const RowIndex& ri) { return ri.indices64
template <typename T>
ArrayView_ColumnImpl<T>::ArrayView_ColumnImpl(
Column&& col, const RowIndex& ri, size_t nrows)
: Virtual_ColumnImpl(nrows, col.stype()),
: Virtual_ColumnImpl(nrows, col.data_type()),
arg(std::move(col))
{
set_rowindex(ri);
Expand Down Expand Up @@ -169,7 +169,7 @@ template class ArrayView_ColumnImpl<int64_t>;
static Column _make_view(Column&& col, const RowIndex& ri) {
// This covers the case when ri.size()==0, and when all elements are NAs
if (ri.is_all_missing()) {
return Column::new_na_column(ri.size(), col.stype());
return Column::new_na_column(ri.size(), col.data_type());
}
switch (ri.type()) {
case RowIndexType::SLICE:
Expand Down
5 changes: 3 additions & 2 deletions src/core/types/type_categorical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,9 @@ size_t Type_Cat::hash() const noexcept {
Column Type_Cat::cast_column(Column&& col) const {
switch (col.stype()) {
case SType::VOID:
return Column::new_na_column(col.nrows(), make_type());

// We could immediately `return Column::new_na_column(col.nrows(), make_type());`
// here, however, it appears to be not good from
// the consistency point of view.
case SType::BOOL:
case SType::INT8:
case SType::INT16:
Expand Down
30 changes: 28 additions & 2 deletions tests/types/test-categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tests import assert_equals



#-------------------------------------------------------------------------------
# Type object
#-------------------------------------------------------------------------------
Expand Down Expand Up @@ -148,8 +149,8 @@ def test_categorical_create_from_zero_rows(t):
def test_categorical_create_from_void(t):
src = [None] * 10
DT1 = dt.Frame(src)
DT2 = dt.Frame(src, types = [t(dt.Type.bool8)])
assert DT2.type == t(dt.Type.bool8)
DT2 = dt.Frame(src, types = [t(dt.Type.void)])
assert DT2.type == t(dt.Type.void)
assert DT1.shape == DT2.shape
assert DT1.names == DT2.names
assert DT1.to_list() == DT2.to_list()
Expand Down Expand Up @@ -594,6 +595,31 @@ def test_categorical_repr_numbers_in_terminal(t):
)


#-------------------------------------------------------------------------------
# [i, j] access
#-------------------------------------------------------------------------------

@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_categorical_element_access(cat_type):
src = ["cat", "dog", "hotdog", None, "cat", None]
DT = dt.Frame(src, type=cat_type(str))
assert DT[0, 0] == "cat"
assert DT[3, 0] is None


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_categorical_slice(cat_type):
src = ["cat", "dog", "hotdog", None, "cat", None]
DT = dt.Frame(src, type=cat_type(str))
assert_equals(DT[0, :], dt.Frame([src[0]]))
assert_equals(DT[0:, :], dt.Frame(src))
assert_equals(DT[2:3, :], dt.Frame([src[2]]))
assert_equals(DT[[3, 4], :], dt.Frame([src[3], src[4]]))


#-------------------------------------------------------------------------------
# Categories
Expand Down

0 comments on commit 887ad6b

Please sign in to comment.