Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove claimRewards double call #228

Merged
merged 7 commits into from
Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
139 changes: 76 additions & 63 deletions src/aave-v3/SupplyVault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@ contract SupplyVault is ISupplyVault, SupplyVaultBase {
/// @notice Emitted when rewards of an asset are accrued on behalf of a user.
/// @param rewardToken The address of the reward token.
/// @param user The address of the user that rewards are accrued on behalf of.
/// @param rewardsIndex The index of the asset distribution on behalf of the user.
/// @param accruedRewards The amount of rewards accrued.
/// @param index The index of the asset distribution on behalf of the user.
/// @param unclaimed The new unclaimed amount of the user.
event Accrued(
address indexed rewardToken,
address indexed user,
uint256 rewardsIndex,
uint256 accruedRewards
uint256 index,
uint256 unclaimed
);

/// @notice Emitted when rewards of an asset are claimed on behalf of a user.
/// @param rewardToken The address of the reward token.
/// @param user The address of the user that rewards are claimed on behalf of.
/// @param claimedRewards The amount of rewards claimed.
event Claimed(address indexed rewardToken, address indexed user, uint256 claimedRewards);
/// @param claimed The amount of rewards claimed.
event Claimed(address indexed rewardToken, address indexed user, uint256 claimed);

/// STRUCTS ///

Expand Down Expand Up @@ -91,25 +91,18 @@ contract SupplyVault is ISupplyVault, SupplyVaultBase {
external
returns (address[] memory rewardTokens, uint256[] memory claimedAmounts)
{
_accrueUnclaimedRewards(_user);

rewardTokens = morpho.rewardsController().getRewardsByAsset(poolToken);

claimedAmounts = new uint256[](rewardTokens.length);
(rewardTokens, claimedAmounts) = _accrueUnclaimedRewards(_user);

for (uint256 i; i < rewardTokens.length; ++i) {
address rewardToken = rewardTokens[i];
UserRewardsData storage userRewardsData = userRewards[rewardToken][_user];
uint256 claimedAmount = claimedAmounts[i];
if (claimedAmount == 0) continue;

uint256 unclaimedAmount = userRewardsData.unclaimed;
if (unclaimedAmount > 0) {
claimedAmounts[i] = unclaimedAmount;
userRewardsData.unclaimed = 0;
address rewardToken = rewardTokens[i];
userRewards[rewardToken][_user].unclaimed = 0;

ERC20(rewardToken).safeTransfer(_user, unclaimedAmount);
ERC20(rewardToken).safeTransfer(_user, claimedAmount);

emit Claimed(rewardToken, _user, unclaimedAmount);
}
emit Claimed(rewardToken, _user, claimedAmount);
}
}

Expand Down Expand Up @@ -168,69 +161,89 @@ contract SupplyVault is ISupplyVault, SupplyVaultBase {

/// INTERNAL ///

function _deposit(
address _caller,
address _receiver,
uint256 _assets,
uint256 _shares
) internal override {
_accrueUnclaimedRewards(_receiver);
super._deposit(_caller, _receiver, _assets, _shares);
}

function _withdraw(
address _caller,
address _receiver,
address _owner,
uint256 _assets,
uint256 _shares
) internal override {
_accrueUnclaimedRewards(_owner);
super._withdraw(_caller, _receiver, _owner, _assets, _shares);
}

function _beforeTokenTransfer(
address _from,
address _to,
uint256 _amount
) internal virtual override {
_accrueUnclaimedRewards(_from);
_accrueUnclaimedRewards(_to);
(address[] memory rewardTokens, uint256[] memory rewardsIndexes) = _claimVaultRewards();
_accrueUnclaimedRewardsFromRewardIndexes(_from, rewardTokens, rewardsIndexes);
_accrueUnclaimedRewardsFromRewardIndexes(_to, rewardTokens, rewardsIndexes);

super._beforeTokenTransfer(_from, _to, _amount);
}

function _accrueUnclaimedRewards(address _user) internal {
address[] memory rewardTokens;
uint256[] memory claimedAmounts;
function _claimVaultRewards()
internal
returns (address[] memory rewardTokens, uint256[] memory rewardsIndexes)
{
address[] memory poolTokens = new address[](1);
poolTokens[0] = poolToken;

{
address[] memory poolTokens = new address[](1);
poolTokens[0] = poolToken;
uint256[] memory claimedAmounts;
(rewardTokens, claimedAmounts) = morpho.claimRewards(poolTokens, false);

(rewardTokens, claimedAmounts) = morpho.claimRewards(poolTokens, false);
}
rewardsIndexes = new uint256[](rewardTokens.length);

uint256 supply = totalSupply();
for (uint256 i; i < rewardTokens.length; ++i) {
address rewardToken = rewardTokens[i];
uint256 claimedAmount = claimedAmounts[i];
uint256 rewardsIndexMem = rewardsIndex[rewardToken];
uint256 newRewardIndex = rewardsIndex[rewardToken] +
_getUnaccruedRewardIndex(claimedAmounts[i], supply);

if (claimedAmount > 0) {
rewardsIndexMem += _getUnaccruedRewardIndex(claimedAmount, totalSupply());
rewardsIndex[rewardToken] = rewardsIndexMem.safeCastTo128();
}
rewardsIndexes[i] = newRewardIndex;
rewardsIndex[rewardToken] = newRewardIndex.safeCastTo128();
}
}

function _accrueUnclaimedRewards(address _user)
internal
returns (address[] memory rewardTokens, uint256[] memory unclaimedAmounts)
{
uint256[] memory rewardsIndexes;
(rewardTokens, rewardsIndexes) = _claimVaultRewards();

unclaimedAmounts = _accrueUnclaimedRewardsFromRewardIndexes(
_user,
rewardTokens,
rewardsIndexes
);
}

function _accrueUnclaimedRewardsFromRewardIndexes(
address _user,
address[] memory _rewardTokens,
uint256[] memory _rewardIndexes
) internal returns (uint256[] memory unclaimedAmounts) {
if (_user == address(0)) return unclaimedAmounts;

unclaimedAmounts = new uint256[](_rewardTokens.length);

for (uint256 i; i < _rewardTokens.length; ++i) {
address rewardToken = _rewardTokens[i];
uint256 rewardIndex = _rewardIndexes[i];

UserRewardsData storage userRewardsData = userRewards[rewardToken][_user];
if (rewardsIndexMem > userRewardsData.index) {
uint256 accruedReward = _getUnaccruedRewardsFromRewardsIndexAccrual(

// Safe because we always have `rewardsIndex` >= `userRewardsData.index`.
uint256 rewardsIndexDiff;
unchecked {
rewardsIndexDiff = rewardIndex - userRewardsData.index;
}

uint256 unclaimedAmount = userRewardsData.unclaimed;
if (rewardsIndexDiff > 0) {
unclaimedAmount += _getUnaccruedRewardsFromRewardsIndexAccrual(
_user,
rewardsIndexMem - userRewardsData.index
rewardsIndexDiff
);
userRewardsData.unclaimed += accruedReward.safeCastTo128();
userRewardsData.index = rewardsIndexMem.safeCastTo128();
userRewardsData.unclaimed = unclaimedAmount.safeCastTo128();
userRewardsData.index = rewardIndex.safeCastTo128();

emit Accrued(rewardToken, _user, rewardsIndexMem, accruedReward);
emit Accrued(rewardToken, _user, rewardIndex, unclaimedAmount);
}

unclaimedAmounts[i] = unclaimedAmount;
}
}

Expand Down
65 changes: 28 additions & 37 deletions src/compound/SupplyVault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -78,76 +78,67 @@ contract SupplyVault is ISupplyVault, SupplyVaultBase {
/// @return rewardsAmount The amount of rewards claimed.
function claimRewards(address _user) external returns (uint256 rewardsAmount) {
rewardsAmount = _accrueUnclaimedRewards(_user);
if (rewardsAmount == 0) return rewardsAmount;

if (rewardsAmount > 0) {
userRewards[_user].unclaimed = 0;
userRewards[_user].unclaimed = 0;

comp.safeTransfer(_user, rewardsAmount);
}
comp.safeTransfer(_user, rewardsAmount);

emit Claimed(_user, rewardsAmount);
}

/// INTERNAL ///

function _deposit(
address _caller,
address _receiver,
uint256 _assets,
uint256 _shares
) internal override {
_accrueUnclaimedRewards(_receiver);
super._deposit(_caller, _receiver, _assets, _shares);
}

function _withdraw(
address _caller,
address _receiver,
address _owner,
uint256 _assets,
uint256 _shares
) internal override {
_accrueUnclaimedRewards(_owner);
super._withdraw(_caller, _receiver, _owner, _assets, _shares);
}

function _beforeTokenTransfer(
address from,
address to,
uint256 amount
) internal override {
_accrueUnclaimedRewards(from);
_accrueUnclaimedRewards(to);
uint256 newRewardsIndex = _claimVaultRewards();
_accrueUnclaimedRewardsFromRewardsIndex(from, newRewardsIndex);
_accrueUnclaimedRewardsFromRewardsIndex(to, newRewardsIndex);

super._beforeTokenTransfer(from, to, amount);
}

function _accrueUnclaimedRewards(address _user) internal returns (uint256 unclaimed) {
uint256 supply = totalSupply();
uint256 rewardsIndexMem = rewardsIndex;
function _claimVaultRewards() internal returns (uint256 newRewardsIndex) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why change the name? IMO it's clear enough

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because there's the external function claimRewards and it may suggest that this internal function is the implementation to claimRewards

newRewardsIndex = rewardsIndex;

uint256 supply = totalSupply();
if (supply > 0) {
address[] memory poolTokens = new address[](1);
poolTokens[0] = poolToken;
rewardsIndexMem += morpho.claimRewards(poolTokens, false).divWadDown(supply);
rewardsIndex = rewardsIndexMem;

newRewardsIndex += morpho.claimRewards(poolTokens, false).divWadDown(supply);
rewardsIndex = newRewardsIndex;
}
}

function _accrueUnclaimedRewards(address _user) internal returns (uint256 unclaimed) {
return _accrueUnclaimedRewardsFromRewardsIndex(_user, _claimVaultRewards());
}

function _accrueUnclaimedRewardsFromRewardsIndex(address _user, uint256 _newRewardsIndex)
Rubilmax marked this conversation as resolved.
Show resolved Hide resolved
internal
returns (uint256 unclaimed)
{
if (_user == address(0)) return unclaimed;

UserRewardsData storage userRewardsData = userRewards[_user];
uint256 rewardsIndexDiff;

// Safe because we always have `rewardsIndex` >= `userRewardsData.index`.
unchecked {
rewardsIndexDiff = rewardsIndexMem - userRewardsData.index;
rewardsIndexDiff = _newRewardsIndex - userRewardsData.index;
}

unclaimed = userRewardsData.unclaimed;
if (rewardsIndexDiff > 0) {
unclaimed += balanceOf(_user).mulWadDown(rewardsIndexDiff);
userRewardsData.unclaimed = unclaimed.safeCastTo128();
}
userRewardsData.index = _newRewardsIndex.safeCastTo128();

userRewardsData.index = rewardsIndexMem.safeCastTo128();

emit Accrued(_user, rewardsIndexMem, unclaimed);
emit Accrued(_user, _newRewardsIndex, unclaimed);
}
}
}
15 changes: 0 additions & 15 deletions test/aave-v2/TestSupplyVault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -168,21 +168,6 @@ contract TestSupplyVault is TestSetupVaults {
vaultSupplier1.redeemVault(daiSupplyVault, shares + 1);
}

// TODO: fix this test by using updated indexes in previewMint
// function testShouldMintCorrectAmountWhenMorphoPoolIndexesOutdated() public {
// uint256 amount = 10_000 ether;

// vaultSupplier1.depositVault(daiSupplyVault, amount);

// vm.roll(block.number + 100_000);
// vm.warp(block.timestamp + 1_000_000);

// uint256 assets = vaultSupplier2.mintVault(daiSupplyVault, amount);
// uint256 shares = vaultSupplier2.withdrawVault(daiSupplyVault, assets);

// assertEq(shares, amount, "unexpected redeemed shares");
// }

function testShouldDepositCorrectAmountWhenMorphoPoolIndexesOutdated() public {
uint256 amount = 10_000 ether;

Expand Down
15 changes: 0 additions & 15 deletions test/aave-v3/TestSupplyVault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -798,21 +798,6 @@ contract TestSupplyVault is TestSetupVaults {
// In the end, vaultSupplier1 got 2 * X rewards while vaultSupplier2 got 3 * X
}

// TODO: fix this test by using updated indexes in previewMint
// function testShouldMintCorrectAmountWhenMorphoPoolIndexesOutdated() public {
// uint256 amount = 10_000 ether;

// vaultSupplier1.depositVault(daiSupplyVault, amount);

// vm.roll(block.number + 100_000);
// vm.warp(block.timestamp + 1_000_000);

// uint256 assets = vaultSupplier2.mintVault(daiSupplyVault, amount);
// uint256 shares = vaultSupplier2.withdrawVault(daiSupplyVault, assets);

// assertEq(shares, amount, "unexpected redeemed shares");
// }

function testShouldDepositCorrectAmountWhenMorphoPoolIndexesOutdated() public {
uint256 amount = 10_000 ether;

Expand Down
15 changes: 0 additions & 15 deletions test/compound/TestSupplyVault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -469,21 +469,6 @@ contract TestSupplyVault is TestSetupVaults {
assertApproxEqAbs(uint256(userReward1_1), userReward1_2, 1);
}

// TODO: fix this test by using updated indexes in previewMint
// function testShouldMintCorrectAmountWhenMorphoPoolIndexesOutdated() public {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you re-try this test? Perhaps redundant to new tests though..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is redundant yes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And it is passing, except on aave-v3 on which the lens was not added

// uint256 amount = 10_000 ether;

// vaultSupplier1.depositVault(daiSupplyVault, amount);

// vm.roll(block.number + 100_000);
// vm.warp(block.timestamp + 1_000_000);

// uint256 assets = vaultSupplier2.mintVault(daiSupplyVault, amount);
// uint256 shares = vaultSupplier2.withdrawVault(daiSupplyVault, assets);

// assertEq(shares, amount, "unexpected redeemed shares");
// }

function testShouldDepositCorrectAmountWhenMorphoPoolIndexesOutdated() public {
uint256 amount = 10_000 ether;

Expand Down