Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor .extend() and .remove() to use FExpr #3393

Merged
merged 10 commits into from
Dec 15, 2022
100 changes: 50 additions & 50 deletions src/core/column/sumprod.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,66 +29,66 @@ namespace dt {

template <typename T, bool SUM>
class SumProd_ColumnImpl : public Virtual_ColumnImpl {
private:
Column col_;
Groupby gby_;
bool is_grouped_;
size_t : 56;

public:
SumProd_ColumnImpl(Column &&col, const Groupby& gby, bool is_grouped)
: Virtual_ColumnImpl(gby.size(), col.stype()),
col_(std::move(col)),
gby_(gby),
is_grouped_(is_grouped)
{
xassert(col_.can_be_read_as<T>());
}


bool get_element(size_t i, T* out) const override {
T result = !SUM; // 0 for `sum()` and 1 for `prod()`
T value;
size_t i0, i1;
gby_.get_group(i, &i0, &i1);

if (is_grouped_){
size_t nrows = i1 - i0;
bool is_valid = col_.get_element(i, &value);
if (is_valid){
result = SUM? static_cast<T>(nrows) * value
: ipow(value, nrows);
}
} else {
for (size_t gi = i0; gi < i1; ++gi) {
bool is_valid = col_.get_element(gi, &value);
private:
Column col_;
Groupby gby_;
bool is_grouped_;
size_t : 56;

public:
SumProd_ColumnImpl(Column &&col, const Groupby& gby, bool is_grouped)
: Virtual_ColumnImpl(gby.size(), col.stype()),
col_(std::move(col)),
gby_(gby),
is_grouped_(is_grouped)
{
xassert(col_.can_be_read_as<T>());
}


bool get_element(size_t i, T* out) const override {
T result = !SUM; // 0 for `sum()` and 1 for `prod()`
T value;
size_t i0, i1;
gby_.get_group(i, &i0, &i1);

if (is_grouped_){
size_t nrows = i1 - i0;
bool is_valid = col_.get_element(i, &value);
if (is_valid){
result = SUM? result + value
: result * value;
result = SUM? static_cast<T>(nrows) * value
: ipow(value, nrows);
}
} else {
for (size_t gi = i0; gi < i1; ++gi) {
bool is_valid = col_.get_element(gi, &value);
if (is_valid){
result = SUM? result + value
: result * value;
}
}
}
}

*out = result;
return true; // the result is never a missing value
}
*out = result;
return true; // the result is never a missing value
}


ColumnImpl *clone() const override {
return new SumProd_ColumnImpl(Column(col_), Groupby(gby_), is_grouped_);
}
ColumnImpl *clone() const override {
return new SumProd_ColumnImpl(Column(col_), Groupby(gby_), is_grouped_);
}


size_t n_children() const noexcept override {
return 1;
}
size_t n_children() const noexcept override {
return 1;
}


const Column &child(size_t i) const override {
xassert(i == 0);
(void)i;
return col_;
}
const Column &child(size_t i) const override {
xassert(i == 0);
(void)i;
return col_;
}
};
oleksiyskononenko marked this conversation as resolved.
Show resolved Hide resolved


Expand Down
8 changes: 4 additions & 4 deletions src/core/expr/fexpr_extend_remove.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ std::string FExpr_Extend_Remove<EXTEND>::repr() const {
std::string out = EXTEND? "extend" : "remove";
out += '(';
out += arg_->repr();
out += ", arg=";
out += ", ";
oleksiyskononenko marked this conversation as resolved.
Show resolved Hide resolved
out += other_->repr();
out += ')';
return out;
Expand All @@ -47,11 +47,11 @@ std::string FExpr_Extend_Remove<EXTEND>::repr() const {
template <bool EXTEND>
Workframe FExpr_Extend_Remove<EXTEND>::evaluate_n(EvalContext& ctx) const {
Workframe wf = arg_->evaluate_n(ctx);
Workframe out = other_->evaluate_n(ctx);
Workframe wf_other = other_->evaluate_n(ctx);
if (EXTEND){
wf.cbind(std::move(out));
wf.cbind(std::move(wf_other));
} else {
wf.remove(std::move(out));
wf.remove(std::move(wf_other));
}

return wf;
Expand Down
18 changes: 7 additions & 11 deletions tests/test-f.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,15 @@ def test_f_columnset_str():


def test_f_columnset_extend():
assert str(f[:].extend(f.A)) == \
"FExpr<extend(f[:], arg=f.A)>"
assert str(f[int].extend(f[str])) == \
"FExpr<extend(f[int], arg=f[str])>"
assert str(f.A.extend(f['B','C'])) == \
"FExpr<extend(f.A, arg=f[['B', 'C']])>"
assert str(f[:].extend(f.A)) == "FExpr<extend(f[:], f.A)>"
assert str(f[int].extend(f[str])) == "FExpr<extend(f[int], f[str])>"
assert str(f.A.extend(f['B','C'])) == "FExpr<extend(f.A, f[['B', 'C']])>"


def test_f_columnset_remove():
assert str(f[:].remove(f.A)) == "FExpr<remove(f[:], arg=f.A)>"
assert str(f[int].remove(f[0])) == "FExpr<remove(f[int], arg=f[0])>"
assert str(f.A.remove(f['B','C'])) == \
"FExpr<remove(f.A, arg=f[['B', 'C']])>"
assert str(f[:].remove(f.A)) == "FExpr<remove(f[:], f.A)>"
assert str(f[int].remove(f[0])) == "FExpr<remove(f[int], f[0])>"
assert str(f.A.remove(f['B','C'])) == "FExpr<remove(f.A, f[['B', 'C']])>"



Expand Down Expand Up @@ -209,7 +205,7 @@ def test_f_columnset_ltypes(DT):
def test_columnset_sum(DT):
assert_equals(DT[:, f[int].extend(f[float])], DT[:, [int, float]])
assert_equals(DT[:, f[:3].extend(f[-3:])], DT[:, [0, 1, 2, -3, -2, -1]])
assert_equals( DT[:, f['A','B','C'].extend(f['E','F', 'G'])], DT[:, [0, 1, 2, -3, -2, -1]])
assert_equals(DT[:, f['A','B','C'].extend(f['E','F', 'G'])], DT[:, [0, 1, 2, -3, -2, -1]])
assert_equals(DT[:, f.A.extend(f.B)], DT[:, ['A', 'B']])
assert_equals(DT[:, f[:].extend({"extra": f.A + f.C})],
dt.cbind(DT, DT[:, {"extra": f.A + f.C}]))
Expand Down