Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove claimRewards double call #228

Merged
merged 7 commits into from
Dec 19, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
128 changes: 64 additions & 64 deletions src/aave-v3/SupplyVault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {FixedPointMathLib} from "@rari-capital/solmate/src/utils/FixedPointMathL
import {SafeCastLib} from "@rari-capital/solmate/src/utils/SafeCastLib.sol";

import {SupplyVaultBase} from "./SupplyVaultBase.sol";
import "@forge-std/console.sol";
Rubilmax marked this conversation as resolved.
Show resolved Hide resolved

/// @title SupplyVault.
/// @author Morpho Labs.
Expand All @@ -24,20 +25,20 @@ contract SupplyVault is ISupplyVault, SupplyVaultBase {
/// @notice Emitted when rewards of an asset are accrued on behalf of a user.
/// @param rewardToken The address of the reward token.
/// @param user The address of the user that rewards are accrued on behalf of.
/// @param rewardsIndex The index of the asset distribution on behalf of the user.
/// @param accruedRewards The amount of rewards accrued.
/// @param index The index of the asset distribution on behalf of the user.
/// @param unclaimed The new unclaimed amount of the user.
event Accrued(
address indexed rewardToken,
address indexed user,
uint256 rewardsIndex,
uint256 accruedRewards
uint256 index,
uint256 unclaimed
);

/// @notice Emitted when rewards of an asset are claimed on behalf of a user.
/// @param rewardToken The address of the reward token.
/// @param user The address of the user that rewards are claimed on behalf of.
/// @param claimedRewards The amount of rewards claimed.
event Claimed(address indexed rewardToken, address indexed user, uint256 claimedRewards);
/// @param claimed The amount of rewards claimed.
event Claimed(address indexed rewardToken, address indexed user, uint256 claimed);

/// STRUCTS ///

Expand Down Expand Up @@ -91,25 +92,18 @@ contract SupplyVault is ISupplyVault, SupplyVaultBase {
external
returns (address[] memory rewardTokens, uint256[] memory claimedAmounts)
{
_accrueUnclaimedRewards(_user);

rewardTokens = morpho.rewardsController().getRewardsByAsset(poolToken);

claimedAmounts = new uint256[](rewardTokens.length);
(rewardTokens, claimedAmounts) = _accrueUnclaimedRewards(_user);

for (uint256 i; i < rewardTokens.length; ++i) {
address rewardToken = rewardTokens[i];
UserRewardsData storage userRewardsData = userRewards[rewardToken][_user];
uint256 claimedAmount = claimedAmounts[i];
if (claimedAmount == 0) continue;

uint256 unclaimedAmount = userRewardsData.unclaimed;
if (unclaimedAmount > 0) {
claimedAmounts[i] = unclaimedAmount;
userRewardsData.unclaimed = 0;
address rewardToken = rewardTokens[i];
userRewards[rewardToken][_user].unclaimed = 0;

ERC20(rewardToken).safeTransfer(_user, unclaimedAmount);
ERC20(rewardToken).safeTransfer(_user, claimedAmount);

emit Claimed(rewardToken, _user, unclaimedAmount);
}
emit Claimed(rewardToken, _user, claimedAmount);
}
}

Expand Down Expand Up @@ -168,81 +162,87 @@ contract SupplyVault is ISupplyVault, SupplyVaultBase {

/// INTERNAL ///

function _deposit(
address _caller,
address _receiver,
uint256 _assets,
uint256 _shares
) internal override {
_accrueUnclaimedRewards(_receiver);
super._deposit(_caller, _receiver, _assets, _shares);
}

function _withdraw(
address _caller,
address _receiver,
address _owner,
uint256 _assets,
uint256 _shares
) internal override {
_accrueUnclaimedRewards(_owner);
super._withdraw(_caller, _receiver, _owner, _assets, _shares);
}

function _beforeTokenTransfer(
address _from,
address _to,
uint256 _amount
) internal virtual override {
(address[] memory rewardTokens, uint256[] memory claimedAmounts) = _claimRewards();
_accrueUnclaimedRewardsFromClaimedRewards(_from, rewardTokens, claimedAmounts);
_accrueUnclaimedRewardsFromClaimedRewards(_to, rewardTokens, claimedAmounts);
(address[] memory rewardTokens, uint256[] memory newRewardsIndexes) = _claimVaultRewards();
_accrueUnclaimedRewardsFromRewardsIndexes(_from, rewardTokens, newRewardsIndexes);
_accrueUnclaimedRewardsFromRewardsIndexes(_to, rewardTokens, newRewardsIndexes);

super._beforeTokenTransfer(_from, _to, _amount);
}

function _claimRewards()
function _claimVaultRewards()
internal
returns (address[] memory rewardTokens, uint256[] memory claimedAmounts)
returns (address[] memory rewardTokens, uint256[] memory newRewardsIndexes)
{
address[] memory poolTokens = new address[](1);
poolTokens[0] = poolToken;

uint256[] memory claimedAmounts;
(rewardTokens, claimedAmounts) = morpho.claimRewards(poolTokens, false);

newRewardsIndexes = new uint256[](rewardTokens.length);

uint256 supply = totalSupply();
for (uint256 i; i < rewardTokens.length; ++i) {
address rewardToken = rewardTokens[i];
uint256 newRewardIndex = rewardsIndex[rewardToken] +
_getUnaccruedRewardIndex(claimedAmounts[i], supply);

newRewardsIndexes[i] = newRewardIndex;
rewardsIndex[rewardToken] = newRewardIndex.safeCastTo128();
}
}

function _accrueUnclaimedRewards(address _user) internal {
(address[] memory rewardTokens, uint256[] memory claimedAmounts) = _claimRewards();
function _accrueUnclaimedRewards(address _user)
internal
returns (address[] memory rewardTokens, uint256[] memory unclaimedAmounts)
{
uint256[] memory newRewardsIndexes;
(rewardTokens, newRewardsIndexes) = _claimVaultRewards();

_accrueUnclaimedRewardsFromClaimedRewards(_user, rewardTokens, claimedAmounts);
unclaimedAmounts = _accrueUnclaimedRewardsFromRewardsIndexes(
_user,
rewardTokens,
newRewardsIndexes
);
}

function _accrueUnclaimedRewardsFromClaimedRewards(
function _accrueUnclaimedRewardsFromRewardsIndexes(
address _user,
address[] memory _rewardTokens,
uint256[] memory _claimedAmounts
) internal {
uint256[] memory _newRewardsIndexes
Rubilmax marked this conversation as resolved.
Show resolved Hide resolved
) internal returns (uint256[] memory unclaimedAmounts) {
unclaimedAmounts = new uint256[](_rewardTokens.length);

for (uint256 i; i < _rewardTokens.length; ++i) {
address rewardToken = _rewardTokens[i];
uint256 claimedAmount = _claimedAmounts[i];
uint256 rewardsIndexMem = rewardsIndex[rewardToken];
uint256 newRewardsIndex = _newRewardsIndexes[i];

UserRewardsData storage userRewardsData = userRewards[rewardToken][_user];

if (claimedAmount > 0) {
rewardsIndexMem += _getUnaccruedRewardIndex(claimedAmount, totalSupply());
rewardsIndex[rewardToken] = rewardsIndexMem.safeCastTo128();
// Safe because we always have `rewardsIndex` >= `userRewardsData.index`.
uint256 rewardsIndexDiff;
unchecked {
rewardsIndexDiff = newRewardsIndex - userRewardsData.index;
}

UserRewardsData storage userRewardsData = userRewards[rewardToken][_user];
if (rewardsIndexMem > userRewardsData.index) {
uint256 accruedReward = _getUnaccruedRewardsFromRewardsIndexAccrual(
uint256 unclaimedAmount = userRewardsData.unclaimed;
if (rewardsIndexDiff > 0) {
unclaimedAmount += _getUnaccruedRewardsFromRewardsIndexAccrual(
_user,
rewardsIndexMem - userRewardsData.index
rewardsIndexDiff
);
userRewardsData.unclaimed += accruedReward.safeCastTo128();
userRewardsData.index = rewardsIndexMem.safeCastTo128();
userRewardsData.unclaimed = unclaimedAmount.safeCastTo128();
userRewardsData.index = newRewardsIndex.safeCastTo128();

emit Accrued(rewardToken, _user, rewardsIndexMem, accruedReward);
emit Accrued(rewardToken, _user, newRewardsIndex, unclaimedAmount);
}

unclaimedAmounts[i] = unclaimedAmount;
}
}

Expand Down