Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Amxx committed Mar 21, 2024
1 parent 432df7f commit 162a193
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 40 deletions.
19 changes: 8 additions & 11 deletions libsolidity/ast/Types.cpp
Expand Up @@ -320,7 +320,7 @@ Type const* Type::fullEncodingType(bool _inLibraryCall, bool _encoderV2, bool) c
// Structs are fine in the following circumstances:
// - ABIv2 or,
// - storage struct for a library
if (_inLibraryCall && encodingType && (encodingType->dataStoredIn(DataLocation::Storage) || encodingType->dataStoredIn(DataLocation::TransientStorage)))
if (_inLibraryCall && encodingType && (encodingType->dataStoredInAnyOf({ DataLocation::Storage, DataLocation::TransientStorage })))
return encodingType;
Type const* baseType = encodingType;
while (auto const* arrayType = dynamic_cast<ArrayType const*>(baseType))
Expand Down Expand Up @@ -1293,10 +1293,7 @@ BoolResult StringLiteralType::isImplicitlyConvertibleTo(Type const& _convertTo)
return
arrayType->location() != DataLocation::CallData &&
arrayType->isByteArrayOrString() &&
(!arrayType->isPointer() || (
!arrayType->dataStoredIn(DataLocation::Storage) &&
!arrayType->dataStoredIn(DataLocation::TransientStorage)
));
!(arrayType->dataStoredInAnyOf({ DataLocation::Storage, DataLocation::TransientStorage }) && arrayType->isPointer());
}
else
return false;
Expand Down Expand Up @@ -1633,13 +1630,13 @@ BoolResult ArrayType::isImplicitlyConvertibleTo(Type const& _convertTo) const
if (convertTo.isByteArray() != isByteArray() || convertTo.isString() != isString())
return false;
// memory/calldata to storage can be converted, but only to a direct storage reference
if (convertTo.location() == DataLocation::Storage && location() != DataLocation::Storage && convertTo.isPointer())
if (convertTo.dataStoredIn(DataLocation::Storage) && location() != convertTo.location() && convertTo.isPointer())
return false;
if (convertTo.location() == DataLocation::TransientStorage && location() != DataLocation::TransientStorage && convertTo.isPointer())
if (convertTo.dataStoredIn(DataLocation::TransientStorage) && location() != convertTo.location() && convertTo.isPointer())
return false;
if (convertTo.location() == DataLocation::CallData && location() != convertTo.location())
if (convertTo.dataStoredIn(DataLocation::CallData) && location() != convertTo.location())
return false;
if ((convertTo.location() == DataLocation::Storage || convertTo.location() == DataLocation::TransientStorage) && !convertTo.isPointer())
if ((convertTo.dataStoredIn(DataLocation::Storage) || convertTo.location() == DataLocation::TransientStorage) && !convertTo.isPointer())
{
// Less restrictive conversion, since we need to copy anyway.
if (!baseType()->isImplicitlyConvertibleTo(*convertTo.baseType()))
Expand Down Expand Up @@ -2208,10 +2205,10 @@ Type const* StructType::encodingType() const
{
case DataLocation::CallData:
case DataLocation::Memory:
return TypeProvider::uint256();
return this;
case DataLocation::Storage:
case DataLocation::TransientStorage:
return this;
return TypeProvider::uint256();
}
solAssert(false, "");
}
Expand Down
2 changes: 1 addition & 1 deletion libsolidity/ast/Types.h
Expand Up @@ -1031,7 +1031,7 @@ class ContractType: public Type
class StructType: public ReferenceType
{
public:
explicit StructType(StructDefinition const& _struct, DataLocation _location = DataLocation::Storage):
explicit StructType(StructDefinition const& _struct, DataLocation _location = DataLocation::Storage): // [Amxx] TODO: Transient?
ReferenceType(_location), m_struct(_struct) {}

Category category() const override { return Category::Struct; }
Expand Down
58 changes: 33 additions & 25 deletions libsolidity/codegen/ABIFunctions.cpp
Expand Up @@ -24,6 +24,7 @@
#include <libsolidity/codegen/ABIFunctions.h>

#include <libsolidity/codegen/CompilerUtils.h>
#include <libsolidity/codegen/InstructionsUtils.h>
#include <libsolutil/CommonData.h>
#include <libsolutil/Whiskers.h>
#include <libsolutil/StringUtils.h>
Expand Down Expand Up @@ -324,6 +325,7 @@ std::string ABIFunctions::abiEncodingFunction(
else
return abiEncodingFunctionSimpleArray(*fromArray, *toArray, _options);
case DataLocation::Storage:
case DataLocation::TransientStorage:
if (fromArray->baseType()->storageBytes() <= 16)
return abiEncodingFunctionCompactStorageArray(*fromArray, *toArray, _options);
else
Expand Down Expand Up @@ -364,7 +366,7 @@ std::string ABIFunctions::abiEncodingFunction(
)");
templ("functionName", functionName);

if (_from.dataStoredIn(DataLocation::Storage))
if (_from.dataStoredInAnyOf({ DataLocation::Storage, DataLocation::TransientStorage }))
{
// special case: convert storage reference type to value type - this is only
// possible for library calls where we just forward the storage reference
Expand Down Expand Up @@ -537,7 +539,7 @@ std::string ABIFunctions::abiEncodingFunctionSimpleArray(
solAssert(_from.isDynamicallySized() == _to.isDynamicallySized(), "");
solAssert(_from.length() == _to.length(), "");
solAssert(!_from.isByteArrayOrString(), "");
if (_from.dataStoredIn(DataLocation::Storage))
if (_from.dataStoredInAnyOf({ DataLocation::Storage, DataLocation::TransientStorage }))
solAssert(_from.baseType()->storageBytes() > 16, "");

return createFunction(functionName, [&]() {
Expand Down Expand Up @@ -611,18 +613,19 @@ std::string ABIFunctions::abiEncodingFunctionSimpleArray(
templ("encodeToMemoryFun", abiEncodeAndReturnUpdatedPosFunction(*_from.baseType(), *_to.baseType(), subOptions));
switch (_from.location())
{
case DataLocation::CallData:
templ("arrayElementAccess", calldataAccessFunction(*_from.baseType()) + "(baseRef, srcPtr)");
break;
case DataLocation::Memory:
templ("arrayElementAccess", "mload(srcPtr)");
break;
case DataLocation::Storage:
case DataLocation::TransientStorage:
if (_from.baseType()->isValueType())
templ("arrayElementAccess", m_utils.readFromStorage(*_from.baseType(), 0, false) + "(srcPtr)");
else
templ("arrayElementAccess", "srcPtr");
break;
case DataLocation::CallData:
templ("arrayElementAccess", calldataAccessFunction(*_from.baseType()) + "(baseRef, srcPtr)");
break;
default:
solAssert(false, "");
}
Expand Down Expand Up @@ -683,7 +686,7 @@ std::string ABIFunctions::abiEncodingFunctionCompactStorageArray(

solAssert(_from.isDynamicallySized() == _to.isDynamicallySized(), "");
solAssert(_from.length() == _to.length(), "");
solAssert(_from.dataStoredIn(DataLocation::Storage), "");
solAssert(_from.dataStoredInAnyOf({ DataLocation::Storage, DataLocation::TransientStorage }), "");

return createFunction(functionName, [&]() {
if (_from.isByteArrayOrString())
Expand All @@ -692,7 +695,7 @@ std::string ABIFunctions::abiEncodingFunctionCompactStorageArray(
Whiskers templ(R"(
// <readableTypeNameFrom> -> <readableTypeNameTo>
function <functionName>(value, pos) -> ret {
let slotValue := sload(value)
let slotValue := <load>(value)
let length := <byteArrayLengthFunction>(slotValue)
pos := <storeLength>(pos, length)
switch and(slotValue, 1)
Expand All @@ -706,7 +709,7 @@ std::string ABIFunctions::abiEncodingFunctionCompactStorageArray(
let dataPos := <arrayDataSlot>(value)
let i := 0
for { } lt(i, length) { i := add(i, 0x20) } {
mstore(add(pos, i), sload(dataPos))
mstore(add(pos, i), <load>(dataPos))
dataPos := add(dataPos, 1)
}
ret := add(pos, <lengthPaddedLong>)
Expand All @@ -721,6 +724,7 @@ std::string ABIFunctions::abiEncodingFunctionCompactStorageArray(
templ("lengthPaddedShort", _options.padded ? "0x20" : "length");
templ("lengthPaddedLong", _options.padded ? "i" : "length");
templ("arrayDataSlot", m_utils.arrayDataAreaFunction(_from));
templ("load", LoadCode(_from));
return templ.render();
}
else
Expand Down Expand Up @@ -750,7 +754,7 @@ std::string ABIFunctions::abiEncodingFunctionCompactStorageArray(
for { } lt(add(itemCounter, sub(<itemsPerSlot>, 1)), length)
{ itemCounter := add(itemCounter, <itemsPerSlot>) }
{
let data := sload(srcPtr)
let data := <load>(srcPtr)
<#items>
<encodeToMemoryFun>(<extractFromSlot>(data), pos)
pos := add(pos, <stride>)
Expand All @@ -760,7 +764,7 @@ std::string ABIFunctions::abiEncodingFunctionCompactStorageArray(
}
// Handle the last (not necessarily full) slot specially
if <useSpill> {
let data := sload(srcPtr)
let data := <load>(srcPtr)
<#items>
if <inRange> {
<encodeToMemoryFun>(<extractFromSlot>(data), pos)
Expand All @@ -781,6 +785,7 @@ std::string ABIFunctions::abiEncodingFunctionCompactStorageArray(
templ("lengthFun", m_utils.arrayLengthFunction(_from));
templ("storeLength", arrayStoreLengthForEncodingFunction(_to, _options));
templ("dataArea", m_utils.arrayDataAreaFunction(_from));
templ("load", LoadCode(_from));
// We skip the loop for arrays that fit a single slot.
if (_from.isDynamicallySized() || _from.length() >= itemsPerSlot)
templ("useLoop", "1");
Expand Down Expand Up @@ -863,7 +868,7 @@ std::string ABIFunctions::abiEncodingFunctionStruct(
else
templ("assignEnd", "");
// to avoid multiple loads from the same slot for subsequent members
templ("init", _from.dataStoredIn(DataLocation::Storage) ? "let slotValue := 0" : "");
templ("init", _from.dataStoredInAnyOf({ DataLocation::Storage, DataLocation::TransientStorage }) ? "let slotValue := 0" : "");
u256 previousSlotOffset(-1);
u256 encodingOffset = 0;
std::vector<std::map<std::string, std::string>> members;
Expand All @@ -884,7 +889,20 @@ std::string ABIFunctions::abiEncodingFunctionStruct(

switch (_from.location())
{
case DataLocation::CallData:
{
std::string sourceOffset = toCompactHexWithPrefix(_from.calldataOffsetOfMember(member.name));
members.back()["retrieveValue"] = calldataAccessFunction(*memberTypeFrom) + "(value, add(value, " + sourceOffset + "))";
break;
}
case DataLocation::Memory:
{
std::string sourceOffset = toCompactHexWithPrefix(_from.memoryOffsetOfMember(member.name));
members.back()["retrieveValue"] = "mload(add(value, " + sourceOffset + "))";
break;
}
case DataLocation::Storage:
case DataLocation::TransientStorage:
{
solAssert(memberTypeFrom->isValueType() == memberTypeTo->isValueType(), "");
u256 storageSlotOffset;
Expand All @@ -894,31 +912,21 @@ std::string ABIFunctions::abiEncodingFunctionStruct(
{
if (storageSlotOffset != previousSlotOffset)
{
members.back()["preprocess"] = "slotValue := sload(add(value, " + toCompactHexWithPrefix(storageSlotOffset) + "))";
members.back()["preprocess"] = _from.dataStoredIn(DataLocation::TransientStorage)
? ("slotValue := sload(add(value, " + toCompactHexWithPrefix(storageSlotOffset) + "))")
: ("slotValue := tload(add(value, " + toCompactHexWithPrefix(storageSlotOffset) + "))");
previousSlotOffset = storageSlotOffset;
}
members.back()["retrieveValue"] = m_utils.extractFromStorageValue(*memberTypeFrom, intraSlotOffset) + "(slotValue)";
}
else
{
solAssert(memberTypeFrom->dataStoredIn(DataLocation::Storage), "");
solAssert(memberTypeFrom->dataStoredInAnyOf({ DataLocation::Storage, DataLocation::TransientStorage }), "");
solAssert(intraSlotOffset == 0, "");
members.back()["retrieveValue"] = "add(value, " + toCompactHexWithPrefix(storageSlotOffset) + ")";
}
break;
}
case DataLocation::Memory:
{
std::string sourceOffset = toCompactHexWithPrefix(_from.memoryOffsetOfMember(member.name));
members.back()["retrieveValue"] = "mload(add(value, " + sourceOffset + "))";
break;
}
case DataLocation::CallData:
{
std::string sourceOffset = toCompactHexWithPrefix(_from.calldataOffsetOfMember(member.name));
members.back()["retrieveValue"] = calldataAccessFunction(*memberTypeFrom) + "(value, add(value, " + sourceOffset + "))";
break;
}
default:
solAssert(false, "");
}
Expand Down
1 change: 0 additions & 1 deletion libsolidity/codegen/CompilerUtils.cpp
Expand Up @@ -1190,7 +1190,6 @@ void CompilerUtils::convertType(
std::pair<u256, unsigned> const& offsets = typeOnStack->storageOffsetsOfMember(member.name);
_context << offsets.first << Instruction::DUP3 << Instruction::ADD;
_context << u256(offsets.second);
std::cout << __FILE__ << "@" << __LINE__ << std::endl;
StorageItem(_context, *member.type, typeOnStack->location() == DataLocation::TransientStorage).retrieveValue(SourceLocation(), true);
Type const* targetMemberType = targetType->memberType(member.name);
solAssert(!!targetMemberType, "Member not found in target type.");
Expand Down
4 changes: 2 additions & 2 deletions libsolidity/codegen/YulUtilFunctions.cpp
Expand Up @@ -3288,12 +3288,12 @@ std::string YulUtilFunctions::conversionFunction(Type const& _from, Type const&
solAssert(_from == _to || _to == dynamic_cast<UserDefinedValueType const&>(_from).underlyingType(), "");
return conversionFunction(dynamic_cast<UserDefinedValueType const&>(_from).underlyingType(), _to);
}
if (_to.category() == Type::Category::UserDefinedValueType)
else if (_to.category() == Type::Category::UserDefinedValueType)
{
solAssert(_from == _to || _from.isImplicitlyConvertibleTo(dynamic_cast<UserDefinedValueType const&>(_to).underlyingType()), "");
return conversionFunction(_from, dynamic_cast<UserDefinedValueType const&>(_to).underlyingType());
}
if (_from.category() == Type::Category::Function)
else if (_from.category() == Type::Category::Function)
{
solAssert(_to.category() == Type::Category::Function, "");
FunctionType const& fromType = dynamic_cast<FunctionType const&>(_from);
Expand Down

0 comments on commit 162a193

Please sign in to comment.