Skip to content

Commit

Permalink
Implement dt.categories() (#3367)
Browse files Browse the repository at this point in the history
Implement `dt.categories()` to get categories for categorical columns.

WIP for #1691
  • Loading branch information
oleksiyskononenko committed Oct 11, 2022
1 parent 11cf483 commit e610c90
Show file tree
Hide file tree
Showing 11 changed files with 201 additions and 9 deletions.
23 changes: 23 additions & 0 deletions docs/api/dt/categories.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

.. xfunction:: datatable.categories
:src: src/core/expr/fexpr_categories.cc pyfn_categories
:tests: tests/types/test-categorical.py
:cvar: doc_dt_categories
:signature: categories(cols)

.. x-version-added:: 1.1.0

For each column from `cols` get the underlying categories.

Parameters
----------
cols: FExpr
Input categorical data.

return: FExpr
f-expression that returns categories for each column
from `cols`.

except: TypeError
The exception is raised when one of the columns from `cols`
has a non-categorical type.
4 changes: 4 additions & 0 deletions docs/api/fexpr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@
* - :meth:`.as_type()`
- Same as :func:`dt.as_type()`.

* - :meth:`.categories()`
- Same as :func:`dt.categories()`.

* - :meth:`.count()`
- Same as :func:`dt.count()`.

Expand Down Expand Up @@ -303,6 +306,7 @@
.__xor__() <fexpr/__xor__>
.alias() <fexpr/alias>
.as_type() <fexpr/as_type>
.categories() <fexpr/categories>
.count() <fexpr/count>
.countna() <fexpr/countna>
.cummin() <fexpr/cummin>
Expand Down
7 changes: 7 additions & 0 deletions docs/api/fexpr/categories.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

.. xmethod:: datatable.FExpr.categories
:src: src/core/expr/fexpr.cc PyFExpr::categories
:cvar: doc_FExpr_categories
:signature: categories()

Equivalent to :func:`dt.categories(cols)`.
3 changes: 3 additions & 0 deletions docs/api/index-api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ Functions
-
* - :func:`as_type()`
- Cast column into another type
* - :func:`categories()`
- Get categories for categorical columns
* - :func:`ifelse()`
- Ternary if operator
* - :func:`shift()`
Expand Down Expand Up @@ -241,6 +243,7 @@ Other
as_type() <dt/as_type>
build_info <dt/build_info>
by() <dt/by>
categories() <dt/categories>
cbind() <dt/cbind>
corr() <dt/corr>
count() <dt/count>
Expand Down
3 changes: 2 additions & 1 deletion src/core/documentation.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ namespace dt {

extern const char* doc_dt_as_type;
extern const char* doc_dt_by;
extern const char* doc_dt_categories;
extern const char* doc_dt_cbind;
extern const char* doc_dt_corr;
extern const char* doc_dt_count;
Expand Down Expand Up @@ -284,6 +285,7 @@ extern const char* doc_Frame_view;
extern const char* doc_FExpr;
extern const char* doc_FExpr_alias;
extern const char* doc_FExpr_as_type;
extern const char* doc_FExpr_categories;
extern const char* doc_FExpr_count;
extern const char* doc_FExpr_countna;
extern const char* doc_FExpr_cummax;
Expand Down Expand Up @@ -345,6 +347,5 @@ extern const char* doc_Type_name;




} // namespace dt
#endif
10 changes: 10 additions & 0 deletions src/core/expr/fexpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,16 @@ DECLARE_METHOD(&PyFExpr::as_type)
->n_required_args(1);


// Method form of `dt.categories()`: resolves the `categories` attribute of the
// `datatable` module and invokes it with this expression as the sole argument.
oobj PyFExpr::categories(const XArgs&) {
  return oobj::import("datatable", "categories").call({this});
}

// Register `PyFExpr::categories` as a method named "categories", with its
// docstring taken from `dt::doc_FExpr_categories`.
DECLARE_METHOD(&PyFExpr::categories)
->name("categories")
->docs(dt::doc_FExpr_categories);


oobj PyFExpr::count(const XArgs&) {
auto countFn = oobj::import("datatable", "count");
return countFn.call({this});
Expand Down
1 change: 1 addition & 0 deletions src/core/expr/fexpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ class PyFExpr : public py::XObject<PyFExpr> {

py::oobj alias(const py::XArgs&);
py::oobj as_type(const py::XArgs&);
py::oobj categories(const py::XArgs&);
py::oobj count(const py::XArgs&);
py::oobj countna(const py::XArgs&);
py::oobj cummin(const py::XArgs&);
Expand Down
92 changes: 92 additions & 0 deletions src/core/expr/fexpr_categories.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
//------------------------------------------------------------------------------
// Copyright 2022 H2O.ai
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
// IN THE SOFTWARE.
//------------------------------------------------------------------------------
#include "_dt.h"
#include "documentation.h"
#include "column/const.h"
#include "expr/eval_context.h"
#include "expr/fexpr_func.h"
#include "python/xargs.h"
namespace dt {
namespace expr {


//------------------------------------------------------------------------------
// FExpr_Categories
//------------------------------------------------------------------------------

/**
  * F-expression node behind `categories(cols)`.
  *
  * Evaluates the wrapped expression and replaces every resulting column
  * with the column of its categories. A `TypeError` is thrown as soon as
  * a non-categorical column is encountered.
  */
class FExpr_Categories : public FExpr_Func {
  private:
    // The expression whose resulting columns are inspected.
    ptrExpr arg_;

  public:
    // `explicit` so that a `ptrExpr` is never silently converted into
    // an `FExpr_Categories` node.
    explicit FExpr_Categories(ptrExpr&& arg)
      : arg_(std::move(arg)) {}


    // Textual form of this node: `categories(<repr of the argument>)`.
    // The same string is embedded into the error raised by `evaluate_n()`.
    std::string repr() const override {
      std::string out = "categories(";
      out += arg_->repr();
      out += ')';
      return out;
    }


    // Evaluate the argument, then swap each column for its categories.
    //
    // Throws `TypeError` if any of the evaluated columns has
    // a non-categorical type.
    Workframe evaluate_n(EvalContext& ctx) const override {
      Workframe wf = arg_->evaluate_n(ctx);

      for (size_t i = 0; i < wf.ncols(); ++i) {
        Column col = wf.retrieve_column(i);
        if (!col.type().is_categorical()) {
          throw TypeError() << "Invalid column of type `" << col.stype()
              << "` in " << repr();
        }
        // The first child of a categorical column is used as its categories.
        // A column with no children gets an NA column instead (presumably
        // of length one -- the `1` passed to `make_na_column()`).
        // NOTE(review): the replacement columns need not have the same
        // number of rows as each other or as the source frame -- confirm
        // that `Workframe` handles this as intended.
        Column categories = col.n_children()
            ? col.child(0)
            : Const_ColumnImpl::make_na_column(1);
        wf.replace_column(i, std::move(categories));
      }
      return wf;
    }

};



//------------------------------------------------------------------------------
// Python-facing `categories()` function
//------------------------------------------------------------------------------

// Python entry point for `categories(cols)`: converts the first argument
// into an f-expression, wraps it in an `FExpr_Categories` node, and returns
// the node as a Python-level object.
static py::oobj pyfn_categories(const py::XArgs& args) {
  auto py_cols = args[0].to_oobj();
  auto inner_expr = as_fexpr(py_cols);
  return PyFExpr::make(new FExpr_Categories(std::move(inner_expr)));
}


// Register `pyfn_categories` as a function named "categories" that takes
// exactly one required positional argument, `cols`; its docstring is
// `doc_dt_categories`.
DECLARE_PYFN(&pyfn_categories)
->name("categories")
->docs(doc_dt_categories)
->arg_names({"cols"})
->n_positional_args(1)
->n_required_args(1);


}} // dt::expr
7 changes: 4 additions & 3 deletions src/core/types/type_categorical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,14 @@ Column Type_Cat::cast_column(Column&& col) const {
case SType::CAT8: cast_non_compound<uint8_t>(col); break;
case SType::CAT16: cast_non_compound<uint16_t>(col); break;
case SType::CAT32: cast_non_compound<uint32_t>(col); break;
default: throw RuntimeError() << "Unknown categorical type";
default: throw RuntimeError()
<< "Unknown categorical type: " << stype();
}
break;

default:
throw NotImplError() << "Unable to cast a column of type `" << col.type()
<< "` into `" << to_string() << "`";
throw NotImplError() << "Unable to cast a column of type `"
<< col.type() << "` into `" << to_string() << "`";
}

return std::move(col);
Expand Down
4 changes: 3 additions & 1 deletion src/datatable/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python
#-------------------------------------------------------------------------------
# Copyright 2018-2021 H2O.ai
# Copyright 2018-2022 H2O.ai
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
Expand All @@ -26,6 +26,7 @@
from .lib._datatable import (
as_type,
by,
categories,
cbind,
cumcount,
cummax,
Expand Down Expand Up @@ -87,6 +88,7 @@
"as_type",
"bool8",
"by",
"categories",
"cbind",
"corr",
"count",
Expand Down
56 changes: 52 additions & 4 deletions tests/types/test-categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def test_str_to_cat(cat_type, type):
dt.Type.cat32])
@pytest.mark.parametrize('type', [dt.Type.date32,
dt.Type.time64])
def test_str_to_cat(cat_type, type):
def test_datetime_to_cat(cat_type, type):
from datetime import datetime as d
src = [d(1991, 1, 9, 10, 2, 3, 500),
d(2014, 7, 2, 1, 3, 5, 100),
Expand All @@ -340,11 +340,8 @@ def test_str_to_cat(cat_type, type):
d(2020, 2, 22, 1, 3, 5, 129),
d(1950, 8, 2, 4, 3, 5, 10)]
DT = dt.Frame(src, type=type)
print(DT)
DT[0] = cat_type(type)
print(DT)
DT_ref = dt.Frame(src, type=cat_type(type))
print(DT_ref)
assert_equals(DT, DT_ref)


Expand Down Expand Up @@ -476,3 +473,54 @@ def test_repr_numbers_in_terminal(t):
" 4 | 7 \n"
"[5 rows x 1 column]\n"
)



#-------------------------------------------------------------------------------
# Getting categories
#-------------------------------------------------------------------------------

def test_categories_wrong_type():
    # Asking for the categories of a non-categorical column must fail with
    # a TypeError that names both the column type and the expression.
    frame = dt.Frame(range(10))
    expected_msg = r"Invalid column of type int32 in categories\(f\.C0\)"
    with pytest.raises(TypeError, match=expected_msg):
        frame[:, dt.categories(f.C0)]


@pytest.mark.parametrize(
    'cat_type', [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32]
)
def test_categories_void(cat_type):
    # A categorical column built over all-missing `void` data is expected
    # to report a single missing category.
    all_missing = dt.Frame([None] * 11, type=cat_type(dt.Type.void))
    observed = all_missing[:, dt.categories(f[:])]
    expected = dt.Frame([None])
    assert_equals(observed, expected)


@pytest.mark.parametrize(
    'cat_type', [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32]
)
def test_categories_simple(cat_type):
    # Four strings with one duplicate are expected to produce three
    # categories.
    animals = ["cat", "dog", "mouse", "cat"]
    frame = dt.Frame([animals], type=cat_type(dt.Type.str32))
    observed = frame[:, dt.categories(f.C0)]
    expected = dt.Frame(["cat", "dog", "mouse"])
    assert_equals(observed, expected)


@pytest.mark.parametrize(
    'cat_type', [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32]
)
def test_categories_multicolumn(cat_type):
    # `categories(f[:])` is applied to an integer-based and a string-based
    # categorical column at once; each source pattern is repeated many times
    # so that every category occurs more than once.
    repeats = 123
    int_values = [None, 100, 500, None, 100, 100500, 100, 500] * repeats
    str_values = [None, "dog", "mouse", None, "dog", "cat", "dog", "mouse"] * repeats
    column_types = [cat_type(dt.Type.int32), cat_type(dt.Type.str32)]
    frame = dt.Frame([int_values, str_values], types=column_types)
    observed = frame[:, dt.categories(f[:])]
    expected = dt.Frame([[None, 100, 500, 100500],
                         [None, "cat", "dog", "mouse"]])
    assert_equals(observed, expected)


0 comments on commit e610c90

Please sign in to comment.