Skip to content

Commit

Permalink
Implement CategoricalColumn_Impl and support for basic operations (#…
Browse files Browse the repository at this point in the history
…3158)

- implemented `CategoricalColumn_Impl`;
- added support for categorical columns in `dt.Frame()`;
- added support for categorical columns in a terminal, also allowing the element access through `[i, j]` selector;
- added and modified some relevant tests.

WIP for #1691
  • Loading branch information
oleksiyskononenko committed Sep 8, 2021
1 parent 1af837e commit 95662d0
Show file tree
Hide file tree
Showing 15 changed files with 543 additions and 15 deletions.
12 changes: 9 additions & 3 deletions src/core/column.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,10 @@ static inline py::oobj getelem(const Column& col, size_t i) {
}

py::oobj Column::get_element_as_pyobject(size_t i) const {
switch (stype()) {
dt::SType st = type().is_categorical()? child(0).stype()
: stype();

switch (st) {
case dt::SType::VOID: return py::None();
case dt::SType::BOOL: {
int8_t x;
Expand Down Expand Up @@ -325,7 +328,10 @@ py::oobj Column::get_element_as_pyobject(size_t i) const {
}

bool Column::get_element_isvalid(size_t i) const {
switch (stype()) {
dt::SType st = type().is_categorical()? child(0).stype()
: stype();

switch (st) {
case dt::SType::VOID: return false;
case dt::SType::BOOL:
case dt::SType::INT8: {
Expand Down Expand Up @@ -361,7 +367,7 @@ bool Column::get_element_isvalid(size_t i) const {
}
default:
throw NotImplError() << "Unable to check validity of the element "
<< "for stype: `" << stype() << "`";
<< "for type: `" << type() << "`";
}
}

Expand Down
158 changes: 158 additions & 0 deletions src/core/column/categorical.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
//------------------------------------------------------------------------------
// Copyright 2021 H2O.ai
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
// IN THE SOFTWARE.
//------------------------------------------------------------------------------
#include "column/categorical.h"

namespace dt {


template <typename T>
Type _type_from_cattype(const Type& tcat) {
Type t;
switch (sizeof(T)) {
case 1: t = Type::cat8(tcat); break;
case 2: t = Type::cat16(tcat); break;
case 4: t = Type::cat32(tcat); break;
default : throw RuntimeError() << "Type is not supported";
}
return t;
}


template <typename T>
Categorical_ColumnImpl<T>::Categorical_ColumnImpl(
size_t nrows, Buffer&& codes, Column&& categories
) : Virtual_ColumnImpl(nrows, _type_from_cattype<T>(categories.type())),
codes_(std::move(codes)),
categories_(std::move(categories))
{
xassert(codes_.size() >= sizeof(T) * nrows);
}


template <typename T>
void Categorical_ColumnImpl<T>::materialize(Column&, bool) {
categories_.materialize();
}


template <typename T>
ColumnImpl* Categorical_ColumnImpl<T>::clone() const {
return
new Categorical_ColumnImpl(nrows_, Buffer(codes_), Column(categories_));
}


template <typename T>
size_t Categorical_ColumnImpl<T>::n_children() const noexcept {
return 1;
}


template <typename T>
const Column& Categorical_ColumnImpl<T>::child(size_t i) const {
xassert(i == 0);
return categories_;
}


template <typename T>
size_t Categorical_ColumnImpl<T>::num_buffers() const noexcept {
return 1;
}


template <typename T>
Buffer Categorical_ColumnImpl<T>::get_buffer() const noexcept {
return codes_;
}


template <typename T>
template <typename U>
bool Categorical_ColumnImpl<T>::get_element_(size_t i, U* out) const {
xassert(i < nrows_);
size_t ii = static_cast<size_t>(codes_.get_element<T>(i));
bool valid = categories_.get_element(ii, out);
return valid;
}


template <typename T>
bool Categorical_ColumnImpl<T>::get_element(size_t i, int8_t* out) const {
return get_element_(i, out);
}


template <typename T>
bool Categorical_ColumnImpl<T>::get_element(size_t i, int16_t* out) const {
return get_element_(i, out);
}


template <typename T>
bool Categorical_ColumnImpl<T>::get_element(size_t i, int32_t* out) const {
return get_element_(i, out);
}


template <typename T>
bool Categorical_ColumnImpl<T>::get_element(size_t i, int64_t* out) const {
return get_element_(i, out);
}


template <typename T>
bool Categorical_ColumnImpl<T>::get_element(size_t i, float* out) const {
return get_element_(i, out);
}


template <typename T>
bool Categorical_ColumnImpl<T>::get_element(size_t i, double* out) const {
return get_element_(i, out);
}


template <typename T>
bool Categorical_ColumnImpl<T>::get_element(size_t i, CString* out) const {
return get_element_(i, out);
}


template <typename T>
bool Categorical_ColumnImpl<T>::get_element(size_t i, py::oobj* out) const {
return get_element_(i, out);
}


template <typename T>
bool Categorical_ColumnImpl<T>::get_element(size_t i, Column* out) const {
return get_element_(i, out);
}


template class Categorical_ColumnImpl<uint8_t>;
template class Categorical_ColumnImpl<uint16_t>;
template class Categorical_ColumnImpl<uint32_t>;


} // namespace dt
70 changes: 70 additions & 0 deletions src/core/column/categorical.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
//------------------------------------------------------------------------------
// Copyright 2021 H2O.ai
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
// IN THE SOFTWARE.
//------------------------------------------------------------------------------
#ifndef dt_COLUMN_CATEGORICAL_h
#define dt_COLUMN_CATEGORICAL_h

#include "column/virtual.h"


namespace dt {

template <typename T>
class Categorical_ColumnImpl : public Virtual_ColumnImpl {
private:
Buffer codes_;
Column categories_;

public:
Categorical_ColumnImpl(size_t nrows, Buffer&& codes, Column&& categories);

ColumnImpl* clone() const override;
size_t n_children() const noexcept override;
const Column& child(size_t i) const override;
size_t num_buffers() const noexcept;
Buffer get_buffer() const noexcept;

template <class U>
bool get_element_(size_t, U*) const;

bool get_element(size_t, int8_t*) const override;
bool get_element(size_t, int16_t*) const override;
bool get_element(size_t, int32_t*) const override;
bool get_element(size_t, int64_t*) const override;
bool get_element(size_t, float*) const override;
bool get_element(size_t, double*) const override;
bool get_element(size_t, CString*) const override;
bool get_element(size_t, py::oobj*) const override;
bool get_element(size_t, Column*) const override;

void materialize(Column&, bool) override;
};


extern template class Categorical_ColumnImpl<uint8_t>;
extern template class Categorical_ColumnImpl<uint16_t>;
extern template class Categorical_ColumnImpl<uint32_t>;


} // namespace dt


#endif
4 changes: 2 additions & 2 deletions src/core/column/column_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ class ColumnImpl
// Data buffers
//------------------------------------
public:
virtual NaStorage get_na_storage_method() const noexcept = 0;
virtual size_t get_num_data_buffers() const noexcept = 0;
virtual NaStorage get_na_storage_method() const noexcept = 0;
virtual size_t get_num_data_buffers() const noexcept = 0;
virtual bool is_data_editable(size_t k) const = 0;
virtual size_t get_data_size(size_t k) const = 0;
virtual const void* get_data_readonly(size_t k) const = 0;
Expand Down
4 changes: 2 additions & 2 deletions src/core/column/range.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ namespace dt {
* column is created when a `range` object is passed into Frame
* constructor.
*
* By default, this column will take stype INT32. However, if the
* range is sufficiently large, the stype will become INT64. However,
* By default, this column will take stype INT32. If the range is
* sufficiently large, the stype will become INT64. However,
* we do not support further promotion into FLOAT64 stype for even
* larger integers.
*/
Expand Down
5 changes: 4 additions & 1 deletion src/core/frame/repr/text_column.cc
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,10 @@ tstring Data_TextColumn::_render_value_string(const Column& col, size_t i) const


tstring Data_TextColumn::_render_value(const Column& col, size_t i) const {
switch (col.stype()) {
SType st = col.type().is_categorical()? col.child(0).stype()
: col.stype();

switch (st) {
case SType::VOID: return na_value_;
case SType::BOOL: return _render_value_bool(col, i);
case SType::INT8: return _render_value_int<int8_t>(col, i);
Expand Down
6 changes: 2 additions & 4 deletions src/core/models/dt_ftrl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,15 +309,13 @@ FtrlFitOutput Ftrl<T>::fit(T(*linkfn)(T),
init_helper_params();

// Define features, weight pointers, feature importances storage,
// as well as column hashers.
// column hashers, etc.
define_features();
init_weights();
if (dt_fi == nullptr) create_fi();
auto data_fi = static_cast<T*>(dt_fi->get_column(1).get_data_editable());
auto hashers = create_hashers(dt_X_train);

// Obtain rowindex and data pointers for the target column(s).
const Column& target_col0_train = dt_y_train->get_column(0);
auto data_fi = static_cast<T*>(dt_fi->get_column(1).get_data_editable());

// Since `nepochs` can be a float value
// - the model is trained `niterations - 1` times on
Expand Down
9 changes: 9 additions & 0 deletions src/core/stype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ const char* stype_name(SType stype) {
case SType::TIME64 : return "time64";
case SType::DATE32 : return "date32";
case SType::OBJ : return "obj64";
case SType::CAT8 : return "cat8";
case SType::CAT16 : return "cat16";
case SType::CAT32 : return "cat32";
case SType::AUTO : return "auto";
default : return "unknown";
}
Expand All @@ -104,6 +107,9 @@ size_t stype_elemsize(SType stype) {
case SType::TIME64 : return sizeof(element_t<SType::TIME64>);
case SType::DATE32 : return sizeof(element_t<SType::DATE32>);
case SType::OBJ : return sizeof(element_t<SType::OBJ>);
case SType::CAT8 : return sizeof(element_t<SType::CAT8>);
case SType::CAT16 : return sizeof(element_t<SType::CAT16>);
case SType::CAT32 : return sizeof(element_t<SType::CAT32>);
default : return 0;
}
}
Expand Down Expand Up @@ -156,6 +162,9 @@ void init_py_stype_objs(PyObject* stype_enum) {
_init_py_stype(SType::TIME64);
_init_py_stype(SType::DATE32);
_init_py_stype(SType::OBJ);
_init_py_stype(SType::CAT8);
_init_py_stype(SType::CAT16);
_init_py_stype(SType::CAT32);
}


Expand Down
3 changes: 3 additions & 0 deletions src/core/stype.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ template <> struct _elt<SType::STR64> { using t = uint64_t; };
template <> struct _elt<SType::ARR32> { using t = uint32_t; };
template <> struct _elt<SType::ARR64> { using t = uint64_t; };
template <> struct _elt<SType::OBJ> { using t = PyObject*; };
template <> struct _elt<SType::CAT8> { using t = uint8_t; };
template <> struct _elt<SType::CAT16> { using t = uint16_t; };
template <> struct _elt<SType::CAT32> { using t = uint32_t; };

template <SType s>
using element_t = typename _elt<s>::t;
Expand Down

0 comments on commit 95662d0

Please sign in to comment.