Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic support for Grouping::GtoFEW #3370

Merged
merged 7 commits into from
Oct 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/core/expr/declarations.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ namespace expr {
// Each group is mapped to exactly `groupsize` rows. This is the
// most common grouping mode. Any simple column, or a function of
// a simple column will be using this mode. A few groupby functions
// may use this mode too
// may use this mode too.
//
// GtoANY
// Groups may be mapped to any number of rows, including having
Expand Down
4 changes: 1 addition & 3 deletions src/core/expr/fexpr_alias.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,8 @@ Workframe FExpr_Alias::evaluate_n(EvalContext& ctx) const {
auto gmode = wf.get_grouping_mode();

for (size_t i = 0; i < wf.ncols(); ++i) {
Workframe arg_out(ctx);
Column col = wf.retrieve_column(i);
arg_out.add_column(std::move(col), std::string(names_[i]), gmode);
wf_out.cbind( std::move(arg_out) );
wf_out.add_column(std::move(col), std::string(names_[i]), gmode);
}

return wf_out;
Expand Down
21 changes: 17 additions & 4 deletions src/core/expr/fexpr_categories.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,31 @@ class FExpr_Categories : public FExpr_Func {

// Evaluates the argument expression and, for every column of the result,
// obtains the column that holds its categories.
// Throws TypeError if any of the evaluated columns is not categorical.
Workframe evaluate_n(EvalContext &ctx) const override {
Workframe wf = arg_->evaluate_n(ctx);
Workframe wf_out(ctx);

for (size_t i = 0; i < wf.ncols(); ++i) {
Column col = wf.retrieve_column(i);
// Only categorical columns are accepted; anything else is rejected here.
if (!col.type().is_categorical()) {
throw TypeError() << "Invalid column of type `" << col.stype()
<< "` in " << repr();
}
Column categories = col.n_children()? col.child(0)
: Const_ColumnImpl::make_na_column(1);
wf.replace_column(i, std::move(categories));

Column col_cats;
if (col.n_children()) {
// categorical column is backed by `Categorical_ColumnImpl`
xassert(col.n_children() == 1);
col_cats = col.child(0);
} else {
// categorical column is backed by `ConstNa_ColumnImpl`
// `ncats` is 0 for a zero-row column and 1 otherwise (bool converts
// to 0/1); presumably it is the row count of the NA column to create
// -- TODO confirm against `make_na_column()`.
bool ncats = bool(col.nrows());
col_cats = Const_ColumnImpl::make_na_column(ncats);
}

// The categories column keeps the name of the column it came from and is
// added in the `GtoFEW` grouping mode.
wf_out.add_column(std::move(col_cats), wf.retrieve_name(i), Grouping::GtoFEW);
}
return wf;

// NOTE(review): the two statements below are unreachable while the
// `return wf;` above is present -- this text looks like a diff showing
// removed and added lines together; confirm which version is intended.
// `sync_gtofew_columns()` promotes the workframe to `GtoALL` when its
// columns have different numbers of rows.
wf_out.sync_gtofew_columns();
return wf_out;
}

};
Expand Down
2 changes: 2 additions & 0 deletions src/core/expr/fexpr_dict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ Workframe FExpr_Dict::evaluate_n(EvalContext& ctx) const {
arg_out.rename(names_[i]);
outputs.cbind( std::move(arg_out) );
}

outputs.sync_gtofew_columns();
return outputs;
}

Expand Down
2 changes: 2 additions & 0 deletions src/core/expr/fexpr_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ Workframe FExpr_List::evaluate_j(EvalContext& ctx) const {
for (const auto& arg : args_) {
outputs.cbind( arg->evaluate_j(ctx) );
}

outputs.sync_gtofew_columns();
return outputs;
}

Expand Down
4 changes: 2 additions & 2 deletions src/core/expr/fexpr_slice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ Workframe FExpr_Slice::evaluate_n(EvalContext& ctx) const {
wfs.push_back(stop_->evaluate_n(ctx));
wfs.push_back(step_->evaluate_n(ctx));
if (wfs[0].ncols() != 1) {
throw TypeError() << "Slice cannot be applied to multi-column expressions";
throw TypeError() << "Slice can only be applied to one-column expressions";
}
if (wfs[1].ncols() != 1 || wfs[2].ncols() != 1 || wfs[3].ncols() != 1) {
throw TypeError() << "Cannot use multi-column expressions inside a slice";
throw TypeError() << "Can only use one-column expressions inside a slice";
}
auto gmode = Workframe::sync_grouping_mode(wfs);

Expand Down
20 changes: 18 additions & 2 deletions src/core/expr/workframe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,19 @@ std::unique_ptr<DataTable> Workframe::convert_to_datatable() && {
// Grouping mode manipulation
//------------------------------------------------------------------------------

// Promotes this workframe from `GtoFEW` to `GtoALL` as soon as one of its
// non-placeholder columns has a row count different from `nrows()`.
// Workframes with fewer than two columns, or in any grouping mode other
// than `GtoFEW`, are left untouched.
void Workframe::sync_gtofew_columns() {
  if (ncols() < 2) return;
  if (grouping_mode_ != Grouping::GtoFEW) return;

  for (auto& entry : entries_) {
    // Placeholder entries carry no column, so there is nothing to compare.
    const bool is_placeholder = !entry.column;
    if (is_placeholder) continue;
    if (nrows() == entry.column.nrows()) continue;

    // A single mismatch is enough: promote the whole workframe and stop.
    increase_grouping_mode(Grouping::GtoALL);
    return;
  }
}


void Workframe::sync_grouping_mode(Workframe& other) {
if (grouping_mode_ != other.grouping_mode_) {
size_t g1 = static_cast<size_t>(grouping_mode_);
Expand All @@ -328,7 +341,7 @@ void Workframe::sync_grouping_mode(Column& col, Grouping gmode) {
if (g1 < g2) increase_grouping_mode(gmode);
else column_increase_grouping_mode(col, gmode, grouping_mode_);
}
xassert(ncols() == 0 || nrows() == col.nrows());
xassert(ncols() == 0 || nrows() == col.nrows() || gmode == Grouping::GtoFEW);
}


Expand Down Expand Up @@ -359,7 +372,7 @@ void Workframe::increase_grouping_mode(Grouping gmode) {
void Workframe::column_increase_grouping_mode(
Column& col, Grouping gfrom, Grouping gto)
{
xassert(gfrom != Grouping::GtoFEW && gfrom != Grouping::GtoANY);
xassert(gfrom != Grouping::GtoANY);
xassert(gto != Grouping::GtoFEW && gto != Grouping::GtoANY);
xassert(static_cast<int>(gfrom) < static_cast<int>(gto));
if (gfrom == Grouping::SCALAR && gto == Grouping::GtoONE) {
Expand All @@ -377,6 +390,9 @@ void Workframe::column_increase_grouping_mode(
}
xassert(col.nrows() == ctx_.nrows());
}
else if (gfrom == Grouping::GtoFEW && gto == Grouping::GtoALL) {
col.resize(ctx_.nrows());
}
else {
throw RuntimeError() << "Unexpected Grouping mode"; // LCOV_EXCL_LINE
}
Expand Down
4 changes: 4 additions & 0 deletions src/core/expr/workframe.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ class Workframe {

std::unique_ptr<DataTable> convert_to_datatable() &&;

// This method promotes a `GtoFEW` workframe to `GtoALL`
// when its columns have different numbers of rows. It does
// nothing for workframes in any other grouping mode, or with
// fewer than two columns.
void sync_gtofew_columns();

// This method ensures that two `Workframe` objects have the
// same grouping mode. It can either modify itself, or
// `other` object.
Expand Down
78 changes: 71 additions & 7 deletions tests/types/test-categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,20 @@ def test_create_too_many_cats_error(t):
assert DT1.to_list() == DT2.to_list()


@pytest.mark.parametrize('t', [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32])
def test_create_from_zero_rows(t):
    # A zero-row frame created with a categorical type is compared against
    # the same frame created without an explicit type: shape, names and
    # contents must agree, and a full `[:, :]` selection must round-trip.
    empty_src = [[]]
    DT_plain = dt.Frame(empty_src)
    DT_cat = dt.Frame(empty_src, types=[t(dt.Type.bool8)])

    assert DT_cat.type == t(dt.Type.bool8)
    assert DT_plain.shape == DT_cat.shape
    assert DT_plain.names == DT_cat.names
    assert DT_plain.to_list() == DT_cat.to_list()
    assert_equals(DT_cat, DT_cat[:, :])


@pytest.mark.parametrize('t', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
Expand Down Expand Up @@ -487,6 +501,30 @@ def test_categories_wrong_type():
DT[:, dt.categories(f.C0)]


@pytest.mark.parametrize('cat_type', [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32])
def test_categories_zero_rows(cat_type):
    # Categories of a zero-row categorical frame are expected to equal a
    # zero-row frame of the underlying (non-categorical) type.
    empty_src = [[]]
    underlying_type = dt.Type.int32
    DT_cat = dt.Frame(empty_src, type=cat_type(underlying_type))
    assert_equals(DT_cat[:, dt.categories(f[:])],
                  dt.Frame(empty_src, type=underlying_type))


@pytest.mark.parametrize('cat_type', [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32])
def test_categories_zero_rows_casted(cat_type):
    # Same as the zero-row case, except that the column becomes categorical
    # through an in-place type assignment instead of at construction time.
    empty_src = [[]]
    DT_casted = dt.Frame(empty_src)
    DT_casted[0] = cat_type(dt.Type.void)
    assert_equals(DT_casted[:, dt.categories(f[:])], dt.Frame(empty_src))


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
Expand All @@ -501,21 +539,33 @@ def test_categories_void(cat_type):
@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_categories_simple(cat_type):
src = ["cat", "dog", "mouse", "cat"]
def test_categories_one_column(cat_type):
src = [None, "cat", "dog", None, "mouse", "cat"] * 7
DT = dt.Frame([src], type=cat_type(dt.Type.str32))
DT_cats = DT[:, dt.categories(f.C0)]
DT_ref = dt.Frame(["cat", "dog", "mouse"])
DT_ref = dt.Frame([None, "cat", "dog", "mouse"])
assert_equals(DT_cats, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_categories_multicolumn(cat_type):
N = 123
src_int = [None, 100, 500, None, 100, 100500, 100, 500] * N
src_str = [None, "dog", "mouse", None, "dog", "cat", "dog", "mouse"] * N
def test_categories_one_column_plus_orig(cat_type):
src = [None, "cat", "dog", None, "mouse", "cat"]
DT = dt.Frame([src], type=cat_type(dt.Type.str32))
DT_cats = DT[:, [f.C0, dt.categories(f.C0)]]
DT_ref = dt.Frame([src,
[None, "cat", "dog", "mouse"] + [None] * 2],
types=[cat_type(dt.Type.str32), dt.Type.str32])
assert_equals(DT_cats, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_categories_multicolumn_even_ncats(cat_type):
src_int = [None, 100, 500, None, 100, 100500, 100, 500]
src_str = [None, "dog", "mouse", None, "dog", "cat", "dog", "mouse"]
DT = dt.Frame([src_int, src_str],
types=[cat_type(dt.Type.int32), cat_type(dt.Type.str32)])
DT_cats = DT[:, dt.categories(f[:])]
Expand All @@ -524,3 +574,17 @@ def test_categories_multicolumn(cat_type):
assert_equals(DT_cats, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32])
def test_categories_multicolumn_uneven_ncats(cat_type):
    # Two categorical columns with a different number of categories (4 vs 3):
    # the reference frame pads both category columns with None up to the
    # row count of the source frame (8).
    int_values = [None, 100, 500, None, 100, 100500, 100, 500]
    str_values = [None, "dog", None, None, "dog", "cat", "dog", None]
    column_types = [cat_type(dt.Type.int32), cat_type(dt.Type.str32)]
    DT_src = dt.Frame([int_values, str_values], types=column_types)

    DT_cats = DT_src[:, dt.categories(f[:])]

    expected_int_cats = [None, 100, 500, 100500] + [None] * 4
    expected_str_cats = [None, "cat", "dog"] + [None] * 5
    DT_ref = dt.Frame([expected_int_cats, expected_str_cats])
    assert_equals(DT_cats, DT_ref)