Skip to content

Commit

Permalink
Add basic support for Grouping::GtoFEW (#3370)
Browse files Browse the repository at this point in the history
It seems that `dt.categories()` is the first datatable function that needs to produce `Grouping::GtoFEW` columns. That's because the number of categories in a categorical column could be anything between `0` and `nrows - 1`. Currently, datatable doesn't really support `Grouping::GtoFEW`, but may need it for the cases when `dt.categories()` is combined with other f-expressions, or when `dt.categories()` is applied to columns that have different numbers of underlying categories.

In this PR we
- add some basic support for `Grouping::GtoFEW` grouping mode;
- adjust `dt.categories()` to produce `Grouping::GtoFEW` columns that, in the case of an uneven number of rows, are promoted to `Grouping::GtoALL`;
- do minor refactoring in `dt.alias()` function.

WIP for #1691
  • Loading branch information
oleksiyskononenko authored and samukweku committed Jan 3, 2023
1 parent fd310d0 commit a836fd6
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/core/expr/declarations.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ namespace expr {
// Each group is mapped to exactly `groupsize` rows. This is the
// most common grouping mode. Any simple column, or a function of
// a simple column will be using this mode. Few groupby functions
// may use this mode too
// may use this mode too.
//
// GtoANY
// Groups may be mapped to any number of rows, including having
Expand Down
4 changes: 1 addition & 3 deletions src/core/expr/fexpr_alias.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,8 @@ Workframe FExpr_Alias::evaluate_n(EvalContext& ctx) const {
auto gmode = wf.get_grouping_mode();

for (size_t i = 0; i < wf.ncols(); ++i) {
Workframe arg_out(ctx);
Column col = wf.retrieve_column(i);
arg_out.add_column(std::move(col), std::string(names_[i]), gmode);
wf_out.cbind( std::move(arg_out) );
wf_out.add_column(std::move(col), std::string(names_[i]), gmode);
}

return wf_out;
Expand Down
21 changes: 17 additions & 4 deletions src/core/expr/fexpr_categories.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,31 @@ class FExpr_Categories : public FExpr_Func {

Workframe evaluate_n(EvalContext &ctx) const override {
Workframe wf = arg_->evaluate_n(ctx);
Workframe wf_out(ctx);

for (size_t i = 0; i < wf.ncols(); ++i) {
Column col = wf.retrieve_column(i);
if (!col.type().is_categorical()) {
throw TypeError() << "Invalid column of type `" << col.stype()
<< "` in " << repr();
}
Column categories = col.n_children()? col.child(0)
: Const_ColumnImpl::make_na_column(1);
wf.replace_column(i, std::move(categories));

Column col_cats;
if (col.n_children()) {
// categorical column is backed by `Categorical_ColumnImpl`
xassert(col.n_children() == 1);
col_cats = col.child(0);
} else {
// categorical column is backed by `ConstNa_ColumnImpl`
bool ncats = bool(col.nrows());
col_cats = Const_ColumnImpl::make_na_column(ncats);
}

wf_out.add_column(std::move(col_cats), wf.retrieve_name(i), Grouping::GtoFEW);
}
return wf;

wf_out.sync_gtofew_columns();
return wf_out;
}

};
Expand Down
2 changes: 2 additions & 0 deletions src/core/expr/fexpr_dict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ Workframe FExpr_Dict::evaluate_n(EvalContext& ctx) const {
arg_out.rename(names_[i]);
outputs.cbind( std::move(arg_out) );
}

outputs.sync_gtofew_columns();
return outputs;
}

Expand Down
2 changes: 2 additions & 0 deletions src/core/expr/fexpr_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ Workframe FExpr_List::evaluate_j(EvalContext& ctx) const {
for (const auto& arg : args_) {
outputs.cbind( arg->evaluate_j(ctx) );
}

outputs.sync_gtofew_columns();
return outputs;
}

Expand Down
4 changes: 2 additions & 2 deletions src/core/expr/fexpr_slice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ Workframe FExpr_Slice::evaluate_n(EvalContext& ctx) const {
wfs.push_back(stop_->evaluate_n(ctx));
wfs.push_back(step_->evaluate_n(ctx));
if (wfs[0].ncols() != 1) {
throw TypeError() << "Slice cannot be applied to multi-column expressions";
throw TypeError() << "Slice can only be applied to one-column expressions";
}
if (wfs[1].ncols() != 1 || wfs[2].ncols() != 1 || wfs[3].ncols() != 1) {
throw TypeError() << "Cannot use multi-column expressions inside a slice";
throw TypeError() << "Can only use one-column expressions inside a slice";
}
auto gmode = Workframe::sync_grouping_mode(wfs);

Expand Down
20 changes: 18 additions & 2 deletions src/core/expr/workframe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,19 @@ std::unique_ptr<DataTable> Workframe::convert_to_datatable() && {
// Grouping mode manipulation
//------------------------------------------------------------------------------

// Promote this workframe from `GtoFEW` to `GtoALL` if any of its
// non-placeholder columns has a row count different from `nrows()`.
// Workframes with at most one column, or whose grouping mode is not
// `GtoFEW`, are left untouched.
void Workframe::sync_gtofew_columns() {
  if (ncols() <= 1) return;
  if (grouping_mode_ != Grouping::GtoFEW) return;

  bool nrows_mismatch = false;
  for (auto& entry : entries_) {
    // Placeholder entries carry no column, so there is nothing to compare.
    if (entry.column) {
      if (nrows() != entry.column.nrows()) {
        nrows_mismatch = true;
        break;
      }
    }
  }

  // One promotion is enough, no matter how many columns disagree.
  if (nrows_mismatch) {
    increase_grouping_mode(Grouping::GtoALL);
  }
}


void Workframe::sync_grouping_mode(Workframe& other) {
if (grouping_mode_ != other.grouping_mode_) {
size_t g1 = static_cast<size_t>(grouping_mode_);
Expand All @@ -328,7 +341,7 @@ void Workframe::sync_grouping_mode(Column& col, Grouping gmode) {
if (g1 < g2) increase_grouping_mode(gmode);
else column_increase_grouping_mode(col, gmode, grouping_mode_);
}
xassert(ncols() == 0 || nrows() == col.nrows());
xassert(ncols() == 0 || nrows() == col.nrows() || gmode == Grouping::GtoFEW);
}


Expand Down Expand Up @@ -359,7 +372,7 @@ void Workframe::increase_grouping_mode(Grouping gmode) {
void Workframe::column_increase_grouping_mode(
Column& col, Grouping gfrom, Grouping gto)
{
xassert(gfrom != Grouping::GtoFEW && gfrom != Grouping::GtoANY);
xassert(gfrom != Grouping::GtoANY);
xassert(gto != Grouping::GtoFEW && gto != Grouping::GtoANY);
xassert(static_cast<int>(gfrom) < static_cast<int>(gto));
if (gfrom == Grouping::SCALAR && gto == Grouping::GtoONE) {
Expand All @@ -377,6 +390,9 @@ void Workframe::column_increase_grouping_mode(
}
xassert(col.nrows() == ctx_.nrows());
}
else if (gfrom == Grouping::GtoFEW && gto == Grouping::GtoALL) {
col.resize(ctx_.nrows());
}
else {
throw RuntimeError() << "Unexpected Grouping mode"; // LCOV_EXCL_LINE
}
Expand Down
4 changes: 4 additions & 0 deletions src/core/expr/workframe.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ class Workframe {

std::unique_ptr<DataTable> convert_to_datatable() &&;

// Promotes a `GtoFEW` workframe to `GtoALL` when its columns
// have different numbers of rows. Does nothing if the workframe
// has at most one column or its grouping mode is not `GtoFEW`.
void sync_gtofew_columns();

// This method ensures that two `Workframe` objects have the
// same grouping mode. It can either modify itself, or
// `other` object.
Expand Down
78 changes: 71 additions & 7 deletions tests/types/test-categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,20 @@ def test_create_too_many_cats_error(t):
assert DT1.to_list() == DT2.to_list()


@pytest.mark.parametrize('t', [dt.Type.cat8,
                               dt.Type.cat16,
                               dt.Type.cat32])
def test_create_from_zero_rows(t):
    # A categorical frame built from a single empty column must match
    # the plain frame built from the same source in shape, names and data.
    empty_src = [[]]
    plain = dt.Frame(empty_src)
    categorical = dt.Frame(empty_src, types=[t(dt.Type.bool8)])

    expected_type = t(dt.Type.bool8)
    assert categorical.type == expected_type
    assert plain.shape == categorical.shape
    assert plain.names == categorical.names
    assert plain.to_list() == categorical.to_list()

    # Selecting everything must give back an equal frame.
    everything = categorical[:, :]
    assert_equals(categorical, everything)


@pytest.mark.parametrize('t', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
Expand Down Expand Up @@ -487,6 +501,30 @@ def test_categories_wrong_type():
DT[:, dt.categories(f.C0)]


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
                                      dt.Type.cat16,
                                      dt.Type.cat32])
def test_categories_zero_rows(cat_type):
    # For a zero-row categorical column, `dt.categories()` is expected to
    # give an empty column of the underlying (non-categorical) type.
    empty_src = [[]]
    underlying = dt.Type.int32

    frame = dt.Frame(empty_src, type=cat_type(underlying))
    actual = frame[:, dt.categories(f[:])]
    expected = dt.Frame(empty_src, type=underlying)

    assert_equals(actual, expected)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
                                      dt.Type.cat16,
                                      dt.Type.cat32])
def test_categories_zero_rows_casted(cat_type):
    # Start from a plain zero-row frame, assign a categorical-of-void type
    # to its first column, and check that `dt.categories()` gives a frame
    # equal to a plain frame built from the same empty source.
    empty_src = [[]]

    frame = dt.Frame(empty_src)
    frame[0] = cat_type(dt.Type.void)
    actual = frame[:, dt.categories(f[:])]

    expected = dt.Frame(empty_src)
    assert_equals(actual, expected)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
Expand All @@ -501,21 +539,33 @@ def test_categories_void(cat_type):
@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_categories_simple(cat_type):
src = ["cat", "dog", "mouse", "cat"]
def test_categories_one_column(cat_type):
src = [None, "cat", "dog", None, "mouse", "cat"] * 7
DT = dt.Frame([src], type=cat_type(dt.Type.str32))
DT_cats = DT[:, dt.categories(f.C0)]
DT_ref = dt.Frame(["cat", "dog", "mouse"])
DT_ref = dt.Frame([None, "cat", "dog", "mouse"])
assert_equals(DT_cats, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_categories_multicolumn(cat_type):
N = 123
src_int = [None, 100, 500, None, 100, 100500, 100, 500] * N
src_str = [None, "dog", "mouse", None, "dog", "cat", "dog", "mouse"] * N
def test_categories_one_column_plus_orig(cat_type):
src = [None, "cat", "dog", None, "mouse", "cat"]
DT = dt.Frame([src], type=cat_type(dt.Type.str32))
DT_cats = DT[:, [f.C0, dt.categories(f.C0)]]
DT_ref = dt.Frame([src,
[None, "cat", "dog", "mouse"] + [None] * 2],
types=[cat_type(dt.Type.str32), dt.Type.str32])
assert_equals(DT_cats, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_categories_multicolumn_even_ncats(cat_type):
src_int = [None, 100, 500, None, 100, 100500, 100, 500]
src_str = [None, "dog", "mouse", None, "dog", "cat", "dog", "mouse"]
DT = dt.Frame([src_int, src_str],
types=[cat_type(dt.Type.int32), cat_type(dt.Type.str32)])
DT_cats = DT[:, dt.categories(f[:])]
Expand All @@ -524,3 +574,17 @@ def test_categories_multicolumn(cat_type):
assert_equals(DT_cats, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
                                      dt.Type.cat16,
                                      dt.Type.cat32])
def test_categories_multicolumn_uneven_ncats(cat_type):
    # Two categorical columns with different numbers of categories:
    # each expected category list is padded with `None` up to the
    # number of rows in the source frame.
    int_values = [None, 100, 500, None, 100, 100500, 100, 500]
    str_values = [None, "dog", None, None, "dog", "cat", "dog", None]
    nrows = len(int_values)

    frame = dt.Frame(
        [int_values, str_values],
        types=[cat_type(dt.Type.int32), cat_type(dt.Type.str32)],
    )
    actual = frame[:, dt.categories(f[:])]

    int_cats = [None, 100, 500, 100500]
    str_cats = [None, "cat", "dog"]
    expected = dt.Frame([
        int_cats + [None] * (nrows - len(int_cats)),
        str_cats + [None] * (nrows - len(str_cats)),
    ])
    assert_equals(actual, expected)


0 comments on commit a836fd6

Please sign in to comment.