Skip to content

Commit

Permalink
Implement checked exponentiation.
Browse files Browse the repository at this point in the history
  • Loading branch information
chriseth committed Aug 17, 2020
1 parent 660ef79 commit e97d00c
Show file tree
Hide file tree
Showing 9 changed files with 321 additions and 2 deletions.
131 changes: 131 additions & 0 deletions libsolidity/codegen/YulUtilFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,137 @@ string YulUtilFunctions::overflowCheckedIntSubFunction(IntegerType const& _type)
});
}

string YulUtilFunctions::overflowCheckedIntExpFunction(
IntegerType const& _type,
IntegerType const& _exponentType
)
{
solAssert(!_exponentType.isSigned(), "");

string functionName = "checked_exp_" + _type.identifier() + "_" + _exponentType.identifier();
return m_functionCollector.createFunction(functionName, [&]() {
return
Whiskers(R"(
function <functionName>(base, exponent) -> power {
base := <baseCleanupFunction>(base)
exponent := <exponentCleanupFunction>(exponent)
<?signed>
power := <exp>(base, exponent, <minValue>, <maxValue>)
<!signed>
power := <exp>(base, exponent, <maxValue>)
</signed>
}
)")
("functionName", functionName)
("signed", _type.isSigned())
("exp", _type.isSigned() ? overflowCheckedSignedExpFunction() : overflowCheckedUnsignedExpFunction())
("maxValue", toCompactHexWithPrefix(u256(_type.maxValue())))
("minValue", toCompactHexWithPrefix(u256(_type.minValue())))
("baseCleanupFunction", cleanupFunction(_type))
("exponentCleanupFunction", cleanupFunction(_exponentType))
.render();
});
}

string YulUtilFunctions::overflowCheckedUnsignedExpFunction()
{
string functionName = "checked_exp";
return m_functionCollector.createFunction(functionName, [&]() {
return
Whiskers(R"(
function <functionName>(base, exponent, max) -> power {
// This function currently cannot be inlined because of the
// "leave" statements. We have to improve the optimizer.
// Note that 0**0 == 1
if iszero(exponent) { power := 1 leave }
if iszero(base) { power := 0 leave }
power := 1
for { } gt(exponent, 1) {}
{
// overflow check for base * base
if gt(base, div(max, base)) { revert(0, 0) }
if and(exponent, 1)
{
// no check needed here because base >= power
power := mul(power, base)
}
base := mul(base, base)
exponent := <shr_1>(exponent)
}
if gt(power, div(max, base)) { revert(0, 0) }
power := mul(power, base)
}
)")
("functionName", functionName)
("shr_1", shiftRightFunction(1))
.render();
});
}

string YulUtilFunctions::overflowCheckedSignedExpFunction()
{
string functionName = "checked_exp";
return m_functionCollector.createFunction(functionName, [&]() {
return
Whiskers(R"(
function <functionName>(base, exponent, min, max) -> power {
// Currently, `leave` avoids this function being inlined.
// We have to improve the optimizer.
// Note that 0**0 == 1
switch exponent
case 0 { power := 1 leave }
case 1 { power := base leave }
if iszero(base) { power := 0 leave }
power := 1
// We pull out the first iteration because it is the only one in which
// base can be negative.
// Exponent is at least 2 here.
// overflow check for base * base
switch sgt(base, 0)
case 1 { if gt(base, div(max, base)) { revert(0, 0) } }
case 0 { if slt(base, sdiv(max, base)) { revert(0, 0) } }
if and(exponent, 1)
{
power := base
}
base := mul(base, base)
exponent := <shr_1>(exponent)
// Below this point, base is always positive.
for { } gt(exponent, 1) {}
{
// overflow check for base * base
if gt(base, div(max, base)) { revert(0, 0) }
if and(exponent, 1)
{
// no check needed for positive power, because base >= power
if and(slt(power, 0), slt(power, sdiv(min, base))) { revert(0, 0) }
power := mul(power, base)
}
base := mul(base, base)
exponent := <shr_1>(exponent)
}
if and(sgt(power, 0), gt(power, div(max, base))) { revert(0, 0) }
if and(slt(power, 0), slt(power, sdiv(min, base))) { revert(0, 0) }
power := mul(power, base)
}
)")
("functionName", functionName)
("shr_1", shiftRightFunction(1))
.render();
});
}

string YulUtilFunctions::extractByteArrayLengthFunction()
{
string functionName = "extract_byte_array_length";
Expand Down
14 changes: 14 additions & 0 deletions libsolidity/codegen/YulUtilFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,20 @@ class YulUtilFunctions
/// signature: (x, y) -> diff
std::string overflowCheckedIntSubFunction(IntegerType const& _type);

/// @returns the name of the exponentiation function.
/// signature: (base, exponent) -> power
std::string overflowCheckedIntExpFunction(IntegerType const& _type, IntegerType const& _exponentType);

/// Generic unsigned checked exponentiation function.
/// Reverts if the result is larger than max.
/// signature: (base, exponent, max) -> power
std::string overflowCheckedUnsignedExpFunction();

/// Generic signed checked exponentiation function.
/// Reverts if the result is smaller than min or larger than max.
/// signature: (base, exponent, min, max) -> power
std::string overflowCheckedSignedExpFunction();

/// @returns the name of a function that fetches the length of the given
/// array
/// signature: (array) -> length
Expand Down
11 changes: 9 additions & 2 deletions libsolidity/codegen/ir/IRGeneratorForStatements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ bool IRGeneratorForStatements::visit(Assignment const& _assignment)
solAssert(type(_assignment.leftHandSide()).isValueType(), "Compound operators only available for value types.");
solAssert(rightIntermediateType->isValueType(), "Compound operators only available for value types.");
IRVariable leftIntermediate = readFromLValue(*m_currentLValue);
solAssert(binaryOperator != Token::Exp, "");
if (TokenTraits::isShiftOp(binaryOperator))
{
solAssert(type(_assignment) == leftIntermediate.type(), "");
Expand Down Expand Up @@ -593,11 +594,17 @@ bool IRGeneratorForStatements::visit(BinaryOperation const& _binOp)
solAssert(false, "Unknown comparison operator.");
define(_binOp) << expr << "\n";
}
else if (TokenTraits::isShiftOp(op))
else if (TokenTraits::isShiftOp(op) || op == Token::Exp)
{
IRVariable left = convert(_binOp.leftExpression(), *commonType);
IRVariable right = convert(_binOp.rightExpression(), *type(_binOp.rightExpression()).mobileType());
define(_binOp) << shiftOperation(_binOp.getOperator(), left, right) << "\n";
if (op == Token::Exp)
define(_binOp) << m_utils.overflowCheckedIntExpFunction(
dynamic_cast<IntegerType const&>(left.type()),
dynamic_cast<IntegerType const&>(right.type())
) << "(" << left.name() << ", " << right.name() << ")\n";
else
define(_binOp) << shiftOperation(_binOp.getOperator(), left, right) << "\n";
}
else
{
Expand Down
2 changes: 2 additions & 0 deletions test/libsolidity/semanticTests/exponentiation/signed_base.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,7 @@ contract test {
return (x**y1, x**y2);
}
}
// ====
// compileViaYul: also
// ----
// f() -> 9, -27
19 changes: 19 additions & 0 deletions test/libsolidity/semanticTests/viaYul/exp.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
contract C {
function f(uint x, uint y) public returns (uint) {
return x**y;
}
}
// ====
// compileViaYul: also
// ----
// f(uint256,uint256): 0, 0 -> 1
// f(uint256,uint256): 0, 1 -> 0x00
// f(uint256,uint256): 0, 2 -> 0x00
// f(uint256,uint256): 1, 0 -> 1
// f(uint256,uint256): 1, 1 -> 1
// f(uint256,uint256): 1, 2 -> 1
// f(uint256,uint256): 2, 0 -> 1
// f(uint256,uint256): 2, 1 -> 2
// f(uint256,uint256): 2, 2 -> 4
// f(uint256,uint256): 7, 63 -> 174251498233690814305510551794710260107945042018748343
// f(uint256,uint256): 128, 2 -> 0x4000
27 changes: 27 additions & 0 deletions test/libsolidity/semanticTests/viaYul/exp_neg.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
contract C {
function f(int x, uint y) public returns (int) {
return x**y;
}
}
// ====
// compileViaYul: also
// ----
// f(int256,uint256): 0, 0 -> 1
// f(int256,uint256): 0, 1 -> 0x00
// f(int256,uint256): 0, 2 -> 0x00
// f(int256,uint256): 1, 0 -> 1
// f(int256,uint256): 1, 1 -> 1
// f(int256,uint256): 1, 2 -> 1
// f(int256,uint256): 2, 0 -> 1
// f(int256,uint256): 2, 1 -> 2
// f(int256,uint256): 2, 2 -> 4
// f(int256,uint256): 7, 63 -> 174251498233690814305510551794710260107945042018748343
// f(int256,uint256): 128, 2 -> 0x4000
// f(int256,uint256): -1, 0 -> 1
// f(int256,uint256): -1, 1 -> -1
// f(int256,uint256): -1, 2 -> 1
// f(int256,uint256): -2, 0 -> 1
// f(int256,uint256): -2, 1 -> -2
// f(int256,uint256): -2, 2 -> 4
// f(int256,uint256): -7, 63 -> -174251498233690814305510551794710260107945042018748343
// f(int256,uint256): -128, 2 -> 0x4000
38 changes: 38 additions & 0 deletions test/libsolidity/semanticTests/viaYul/exp_neg_overflow.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
contract C {
function f(int8 x, uint y) public returns (int) {
return x**y;
}
function g(int256 x, uint y) public returns (int) {
return x**y;
}
}
// ====
// compileViaYul: true
// ----
// f(int8,uint256): 2, 6 -> 64
// f(int8,uint256): 2, 7 -> FAILURE
// f(int8,uint256): 2, 8 -> FAILURE
// f(int8,uint256): -2, 6 -> 64
// f(int8,uint256): -2, 7 -> -128
// f(int8,uint256): -2, 8 -> FAILURE
// f(int8,uint256): 6, 3 -> FAILURE
// f(int8,uint256): 7, 2 -> 0x31
// f(int8,uint256): 7, 3 -> FAILURE
// f(int8,uint256): -7, 2 -> 0x31
// f(int8,uint256): -7, 3 -> FAILURE
// f(int8,uint256): -7, 4 -> FAILURE
// f(int8,uint256): 127, 31 -> FAILURE
// f(int8,uint256): 127, 131 -> FAILURE
// f(int8,uint256): -128, 0 -> 1
// f(int8,uint256): -128, 1 -> -128
// f(int8,uint256): -128, 31 -> FAILURE
// f(int8,uint256): -128, 131 -> FAILURE
// f(int8,uint256): -11, 2 -> 121
// f(int8,uint256): -12, 2 -> FAILURE
// f(int8,uint256): 12, 2 -> FAILURE
// f(int8,uint256): -5, 3 -> -125
// f(int8,uint256): -6, 3 -> FAILURE
// g(int256,uint256): -7, 90 -> 11450477594321044359340126713545146077054004823284978858214566372120240027249
// g(int256,uint256): -7, 91 -> FAILURE
// g(int256,uint256): -63, 42 -> 3735107253208426854890677539053540390278853997836851167913009474475553834369
// g(int256,uint256): -63, 43 -> FAILURE
31 changes: 31 additions & 0 deletions test/libsolidity/semanticTests/viaYul/exp_overflow.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
contract C {
function f(uint8 x, uint8 y) public returns (uint) {
return x**y;
}
function g(uint x, uint y) public returns (uint) {
return x**y;
}
}
// ====
// compileViaYul: true
// ----
// f(uint8,uint8): 2, 7 -> 0x80
// f(uint8,uint8): 2, 8 -> FAILURE
// f(uint8,uint8): 15, 2 -> 225
// f(uint8,uint8): 6, 3 -> 0xd8
// f(uint8,uint8): 7, 2 -> 0x31
// f(uint8,uint8): 7, 3 -> FAILURE
// f(uint8,uint8): 7, 4 -> FAILURE
// f(uint8,uint8): 255, 31 -> FAILURE
// f(uint8,uint8): 255, 131 -> FAILURE
// g(uint256,uint256): 0x200000000000000000000000000000000, 1 -> 0x0200000000000000000000000000000000
// g(uint256,uint256): 0x100000000000000000000000000000010, 2 -> FAILURE
// g(uint256,uint256): 0x200000000000000000000000000000000, 2 -> FAILURE
// g(uint256,uint256): 0x200000000000000000000000000000000, 3 -> FAILURE
// g(uint256,uint256): 255, 31 -> 400631961586894742455537928461950192806830589109049416147172451019287109375
// g(uint256,uint256): 255, 32 -> -13630939032658036097408813250890608687528184442832962921928608997994916749311
// g(uint256,uint256): 255, 33 -> FAILURE
// g(uint256,uint256): 255, 131 -> FAILURE
// g(uint256,uint256): 258, 31 -> 575719427506838823084316385994930914701079543089399988096291424922125729792
// g(uint256,uint256): 258, 37 -> FAILURE
// g(uint256,uint256): 258, 131 -> FAILURE
50 changes: 50 additions & 0 deletions test/libsolidity/semanticTests/viaYul/exp_various.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
contract C {
function f(uint8 x, uint8 y) public returns (uint) {
return x**y;
}
function g(uint x, uint y) public returns (uint) {
return x**y;
}
}
// ====
// compileViaYul: also
// ----
// f(uint8,uint8): 0, 0 -> 1
// f(uint8,uint8): 0, 1 -> 0x00
// f(uint8,uint8): 0, 2 -> 0x00
// f(uint8,uint8): 0, 3 -> 0x00
// f(uint8,uint8): 1, 0 -> 1
// f(uint8,uint8): 1, 1 -> 1
// f(uint8,uint8): 1, 2 -> 1
// f(uint8,uint8): 1, 3 -> 1
// f(uint8,uint8): 2, 0 -> 1
// f(uint8,uint8): 2, 1 -> 2
// f(uint8,uint8): 2, 2 -> 4
// f(uint8,uint8): 2, 3 -> 8
// f(uint8,uint8): 3, 0 -> 1
// f(uint8,uint8): 3, 1 -> 3
// f(uint8,uint8): 3, 2 -> 9
// f(uint8,uint8): 3, 3 -> 0x1b
// f(uint8,uint8): 10, 0 -> 1
// f(uint8,uint8): 10, 1 -> 0x0a
// f(uint8,uint8): 10, 2 -> 100
// g(uint256,uint256): 0, 0 -> 1
// g(uint256,uint256): 0, 1 -> 0x00
// g(uint256,uint256): 0, 2 -> 0x00
// g(uint256,uint256): 0, 3 -> 0x00
// g(uint256,uint256): 1, 0 -> 1
// g(uint256,uint256): 1, 1 -> 1
// g(uint256,uint256): 1, 2 -> 1
// g(uint256,uint256): 1, 3 -> 1
// g(uint256,uint256): 2, 0 -> 1
// g(uint256,uint256): 2, 1 -> 2
// g(uint256,uint256): 2, 2 -> 4
// g(uint256,uint256): 2, 3 -> 8
// g(uint256,uint256): 3, 0 -> 1
// g(uint256,uint256): 3, 1 -> 3
// g(uint256,uint256): 3, 2 -> 9
// g(uint256,uint256): 3, 3 -> 0x1b
// g(uint256,uint256): 10, 10 -> 10000000000
// g(uint256,uint256): 10, 77 -> -15792089237316195423570985008687907853269984665640564039457584007913129639936
// g(uint256,uint256): 256, 2 -> 0x010000
// g(uint256,uint256): 256, 31 -> 0x0100000000000000000000000000000000000000000000000000000000000000

0 comments on commit e97d00c

Please sign in to comment.