200 changes: 174 additions & 26 deletions source/cppfront.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2188,40 +2188,188 @@ class cppfront
auto emit(binary_expression_node<Name,Term> const& n) -> void
{
assert(n.expr);
assert(n.terms.empty() || n.terms.front().op);

// If this is relational comparison
if (!n.terms.empty() &&
(
n.terms.front().op->type() == lexeme::Less ||
n.terms.front().op->type() == lexeme::LessEq ||
n.terms.front().op->type() == lexeme::Greater ||
n.terms.front().op->type() == lexeme::GreaterEq ||
n.terms.front().op->type() == lexeme::EqualComparison ||
n.terms.front().op->type() == lexeme::NotEqualComparison
)
)
{
auto const& op = *n.terms.front().op;

// If this is relational comparison, and safe comparison was requested
if (flag_safe_comparisons && !n.terms.empty()) {
assert(n.terms.front().op);
token const& op = *n.terms.front().op;
if (op.type() == lexeme::Less ||
op.type() == lexeme::LessEq ||
op.type() == lexeme::Greater ||
op.type() == lexeme::GreaterEq)
// If this is one (non-chained) comparison, just emit it directly
if (std::ssize(n.terms) < 2)
{
if (std::ssize(n.terms) > 1) {
assert (std::ssize(n.terms) == 1);

// emit < <= >= > as cmp_*(a,b) calls (if selected)
if (flag_safe_comparisons) {
switch (op.type()) {
break;case lexeme::Less:
printer.print_cpp2( "cpp2::cmp_less(", n.position());
break;case lexeme::LessEq:
printer.print_cpp2( "cpp2::cmp_less_eq(", n.position());
break;case lexeme::Greater:
printer.print_cpp2( "cpp2::cmp_greater(", n.position());
break;case lexeme::GreaterEq:
printer.print_cpp2( "cpp2::cmp_greater_eq(", n.position());
break;default:
;
}
}

emit(*n.expr);

// emit == and != as infix a @ b operators (since we don't have
// any checking/instrumentation we want to do for those)
if (flag_safe_comparisons) {
switch (op.type()) {
break;case lexeme::EqualComparison:
case lexeme::NotEqualComparison:
emit(op);
break;default:
printer.print_cpp2( ",", n.position() );
}
}
else {
emit(op);
}

emit(*n.terms.front().expr);

if (flag_safe_comparisons) {
switch (op.type()) {
break;case lexeme::Less:
case lexeme::LessEq:
case lexeme::Greater:
case lexeme::GreaterEq:
printer.print_cpp2( ")", n.position() );
break;default:
;
}
}

return;
}

// Else if this is a chained comparison, emit it as a lambda,
// to get single evaluation via the lambda capture
else
{
// To check for the valid chains: all </<=, all >/>=, or all ==
auto found_lt = 0; // < and <=
auto found_gt = 0; // > and >=
auto found_eq = 0; // ==
auto count = 0;

auto const* lhs = n.expr.get();
auto lhs_name = "_" + std::to_string(count);

auto lambda_capture = lhs_name + " = " + print_to_string(*lhs);
auto lambda_body = std::string{};

for (auto const& term : n.terms)
{
assert(term.op && term.expr);
++count;
auto rhs_name = "_" + std::to_string(count);

// Not the first expression? Insert a "&&"
if (found_lt + found_gt + found_eq > 0) {
lambda_body += " && ";
}

// Remember what we've seen
switch (term.op->type()) {
break;case lexeme::Less:
case lexeme::LessEq:
found_lt = 1;
break;case lexeme::Greater:
case lexeme::GreaterEq:
found_gt = 1;
break;case lexeme::EqualComparison:
found_eq = 1;
break;default:
;
}

// emit < <= >= > as cmp_*(a,b) calls (if selected)
if (flag_safe_comparisons) {
switch (term.op->type()) {
break;case lexeme::Less:
lambda_body += "cpp2::cmp_less(";
break;case lexeme::LessEq:
lambda_body += "cpp2::cmp_less_eq(";
break;case lexeme::Greater:
lambda_body += "cpp2::cmp_greater(";
break;case lexeme::GreaterEq:
lambda_body += "cpp2::cmp_greater_eq(";
break;default:
;
}
}

auto rhs_expr = print_to_string(*term.expr);

lambda_body += lhs_name;

// emit == and != as infix a @ b operators (since we don't have
// any checking/instrumentation we want to do for those)
if (flag_safe_comparisons) {
switch (term.op->type()) {
break;case lexeme::EqualComparison:
lambda_body += *term.op;
break;case lexeme::NotEqualComparison:
errors.emplace_back(
n.position(),
"!= comparisons cannot appear in a comparison chain (see https://wg21.link/p0893)"
);
return;
break;default:
lambda_body += ",";
}
}
else {
lambda_body += *term.op;
}

lambda_capture += ", " + rhs_name + " = " + rhs_expr;
lambda_body += rhs_name;

lhs = term.expr.get();
lhs_name = rhs_name;

if (flag_safe_comparisons) {
switch (term.op->type()) {
break;case lexeme::Less:
case lexeme::LessEq:
case lexeme::Greater:
case lexeme::GreaterEq:
lambda_body += ")";
break;default:
;
}
}
}

assert(found_lt + found_gt + found_eq > 0);
if (found_lt + found_gt + found_eq != 1) {
errors.emplace_back(
n.position(),
"comparisons cannot be chained - a future update to cppfront will make expressions like 'a < b < c' meaningful and safe, but note this is a mistake and a pitfall in C and C++ today (see https://wg21.link/p0893)"
"a comparison chain must be all < and <=, all > and >=, or all == (see https://wg21.link/p0893)"
);
return;
}
switch (op.type()) {
break;case lexeme::Less:
printer.print_cpp2( "cpp2::cmp_less(", n.position());
break;case lexeme::LessEq:
printer.print_cpp2( "cpp2::cmp_less_eq(", n.position());
break;case lexeme::Greater:
printer.print_cpp2( "cpp2::cmp_greater(", n.position());
break;case lexeme::GreaterEq:
printer.print_cpp2( "cpp2::cmp_greater_eq(", n.position());
break;default:
assert(!"ICE: switch is not exhaustive");
}

emit(*n.expr);
printer.print_cpp2( ",", n.position() );
emit(*n.terms.front().expr);
printer.print_cpp2( ")", n.position() );
printer.print_cpp2( "[" + lambda_capture + "]{ return " + lambda_body + "; }()", n.position());

return;
}
}
Expand Down