diff --git a/src/core/column.cc b/src/core/column.cc index 69c44f2101..e48e56bfe7 100644 --- a/src/core/column.cc +++ b/src/core/column.cc @@ -168,6 +168,10 @@ const dt::Type& Column::type() const noexcept { return impl_->type_; } +const dt::Type& Column::data_type() const noexcept { + return impl_->data_type(); +} + dt::SType Column::stype() const noexcept { return impl_->type_.stype(); } diff --git a/src/core/column.h b/src/core/column.h index 3ccead1309..dd2d11a551 100644 --- a/src/core/column.h +++ b/src/core/column.h @@ -119,6 +119,7 @@ class Column size_t nrows() const noexcept; size_t na_count() const; const dt::Type& type() const noexcept; + const dt::Type& data_type() const noexcept; dt::LType ltype() const noexcept; dt::SType stype() const noexcept; diff --git a/src/core/column/column_impl.cc b/src/core/column/column_impl.cc index 7f209d0e49..caf11bcc1e 100644 --- a/src/core/column/column_impl.cc +++ b/src/core/column/column_impl.cc @@ -236,11 +236,8 @@ size_t ColumnImpl::null_count() const { SType ColumnImpl::data_stype() const { - if (type_.is_categorical()) { - return n_children()? child(0).stype() - : SType::VOID; - } - return stype(); + return type_.is_categorical()? child(0).stype() + : stype(); } diff --git a/src/core/column/column_impl.h b/src/core/column/column_impl.h index 3ea967cc8e..9d75b94d96 100644 --- a/src/core/column/column_impl.h +++ b/src/core/column/column_impl.h @@ -96,6 +96,10 @@ class ColumnImpl SType stype() const { return type_.stype(); } SType data_stype() const; const Type& type() const { return type_; } + const Type& data_type() const { + return type_.is_categorical()? child(0).type() + : type_; + } virtual bool is_virtual() const noexcept = 0; virtual bool computationally_expensive() const { return false; } virtual size_t memory_footprint() const noexcept = 0; diff --git a/src/core/column/latent.h b/src/core/column/latent.h index b23461eb46..4f0c2ddcb4 100644 --- a/src/core/column/latent.h +++ b/src/core/column/latent.h @@ -94,7 +94,7 @@ class Latent_ColumnImpl : public Virtual_ColumnImpl { case SType::ARR32: case SType::ARR64: vivify(col); break; default: - throw RuntimeError() << "Unknown stype " << col.stype(); // LCOV_EXCL_LINE + throw RuntimeError() << "Unknown stype " << st; // LCOV_EXCL_LINE } } diff --git a/src/core/column/view.cc b/src/core/column/view.cc index 09ba373310..e00b038b3c 100644 --- a/src/core/column/view.cc +++ b/src/core/column/view.cc @@ -30,7 +30,7 @@ namespace dt { SliceView_ColumnImpl::SliceView_ColumnImpl( Column&& col, const RowIndex& ri) - : Virtual_ColumnImpl(ri.size(), col.type()), + : Virtual_ColumnImpl(ri.size(), col.data_type()), arg_(std::move(col)), start_(ri.slice_start()), step_(ri.slice_step()) @@ -41,7 +41,7 @@ SliceView_ColumnImpl::SliceView_ColumnImpl( SliceView_ColumnImpl::SliceView_ColumnImpl( Column&& col, size_t start, size_t count, size_t step) - : Virtual_ColumnImpl(count, col.type()), + : Virtual_ColumnImpl(count, col.data_type()), arg_(std::move(col)), start_(start), step_(step) @@ -97,7 +97,7 @@ template <> const int64_t* get_indices(const RowIndex& ri) { return ri.indices64 template ArrayView_ColumnImpl::ArrayView_ColumnImpl( Column&& col, const RowIndex& ri, size_t nrows) - : Virtual_ColumnImpl(nrows, col.stype()), + : Virtual_ColumnImpl(nrows, col.data_type()), arg(std::move(col)) { set_rowindex(ri); @@ -169,7 +169,7 @@ template class ArrayView_ColumnImpl; static Column _make_view(Column&& col, const RowIndex& ri) { // This covers the case when ri.size()==0, and when all elements are NAs if (ri.is_all_missing()) { - return Column::new_na_column(ri.size(), col.stype()); + return Column::new_na_column(ri.size(), col.data_type()); } switch (ri.type()) { case RowIndexType::SLICE: diff --git a/src/core/types/type_categorical.cc b/src/core/types/type_categorical.cc index a48fbbd50b..9454c1a01b 100644 --- a/src/core/types/type_categorical.cc +++ b/src/core/types/type_categorical.cc @@ -131,8 +131,9 @@ size_t Type_Cat::hash() const noexcept { Column Type_Cat::cast_column(Column&& col) const { switch (col.stype()) { case SType::VOID: - return Column::new_na_column(col.nrows(), make_type()); - + // We could immediately `return Column::new_na_column(col.nrows(), make_type());` + // here, however, it appears to be not good from + // the consistency point of view. case SType::BOOL: case SType::INT8: case SType::INT16: diff --git a/tests/types/test-categorical.py b/tests/types/test-categorical.py index 013e7cca2c..c17784a03a 100644 --- a/tests/types/test-categorical.py +++ b/tests/types/test-categorical.py @@ -27,6 +27,7 @@ from tests import assert_equals + #------------------------------------------------------------------------------- # Type object #------------------------------------------------------------------------------- @@ -148,8 +149,8 @@ def test_categorical_create_from_zero_rows(t): def test_categorical_create_from_void(t): src = [None] * 10 DT1 = dt.Frame(src) - DT2 = dt.Frame(src, types = [t(dt.Type.bool8)]) - assert DT2.type == t(dt.Type.bool8) + DT2 = dt.Frame(src, types = [t(dt.Type.void)]) + assert DT2.type == t(dt.Type.void) assert DT1.shape == DT2.shape assert DT1.names == DT2.names assert DT1.to_list() == DT2.to_list() @@ -594,6 +595,31 @@ def test_categorical_repr_numbers_in_terminal(t): ) +#------------------------------------------------------------------------------- +# [i, j] access +#------------------------------------------------------------------------------- + +@pytest.mark.parametrize('cat_type', [dt.Type.cat8, + dt.Type.cat16, + dt.Type.cat32]) +def test_categorical_element_access(cat_type): + src = ["cat", "dog", "hotdog", None, "cat", None] + DT = dt.Frame(src, type=cat_type(str)) + assert DT[0, 0] == "cat" + assert DT[3, 0] is None + + +@pytest.mark.parametrize('cat_type', [dt.Type.cat8, + dt.Type.cat16, + dt.Type.cat32]) +def test_categorical_slice(cat_type): + src = ["cat", "dog", "hotdog", None, "cat", None] + DT = dt.Frame(src, type=cat_type(str)) + assert_equals(DT[0, :], dt.Frame([src[0]])) + assert_equals(DT[0:, :], dt.Frame(src)) + assert_equals(DT[2:3, :], dt.Frame([src[2]])) + assert_equals(DT[[3, 4], :], dt.Frame([src[3], src[4]])) + #------------------------------------------------------------------------------- # Categories