diff --git a/x/ccv/provider/keeper/keeper_test.go b/x/ccv/provider/keeper/keeper_test.go index feb87490c6..90faec1c85 100644 --- a/x/ccv/provider/keeper/keeper_test.go +++ b/x/ccv/provider/keeper/keeper_test.go @@ -9,7 +9,6 @@ import ( ibctesting "github.com/cosmos/ibc-go/v7/testing" "github.com/stretchr/testify/require" - cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec" sdk "github.com/cosmos/cosmos-sdk/types" diff --git a/x/ccv/provider/keeper/relay.go b/x/ccv/provider/keeper/relay.go index 939f6d3995..1644ce263a 100644 --- a/x/ccv/provider/keeper/relay.go +++ b/x/ccv/provider/keeper/relay.go @@ -391,6 +391,14 @@ func (k Keeper) HandleSlashPacket(ctx sdk.Context, chainID string, data ccv.Slas "infractionType", data.Infraction, ) + // Check that the validator belongs to the consumer chain valset + if !k.IsConsumerValidator(ctx, chainID, providerConsAddr) { + k.Logger(ctx).Error("cannot jail validator %s that does not belong to consumer %s valset", + providerConsAddr.String(), chainID) + // drop packet + return + } + // Obtain validator from staking keeper validator, found := k.stakingKeeper.GetValidatorByConsAddr(ctx, providerConsAddr.ToSdkConsAddr()) diff --git a/x/ccv/provider/keeper/relay_test.go b/x/ccv/provider/keeper/relay_test.go index 2c5a24cab8..bb5a5227c6 100644 --- a/x/ccv/provider/keeper/relay_test.go +++ b/x/ccv/provider/keeper/relay_test.go @@ -22,6 +22,7 @@ import ( cryptotestutil "github.com/cosmos/interchain-security/v4/testutil/crypto" testkeeper "github.com/cosmos/interchain-security/v4/testutil/keeper" "github.com/cosmos/interchain-security/v4/x/ccv/provider/keeper" + "github.com/cosmos/interchain-security/v4/x/ccv/provider/types" providertypes "github.com/cosmos/interchain-security/v4/x/ccv/provider/types" ccv "github.com/cosmos/interchain-security/v4/x/ccv/types" ) @@ -136,6 +137,9 @@ func TestOnRecvDowntimeSlashPacket(t *testing.T) { // Now set slash meter to positive value and assert slash packet handled result is returned providerKeeper.SetSlashMeter(ctx, math.NewInt(5)) + // Set the consumer validator + providerKeeper.SetConsumerValidator(ctx, "chain-1", types.ConsumerValidator{ProviderConsAddr: packetData.Validator.Address}) + // Mock call to GetEffectiveValPower, so that it returns 2. providerAddr := providertypes.NewProviderConsAddress(packetData.Validator.Address) calls := []*gomock.Call{ @@ -289,8 +293,11 @@ func TestValidateSlashPacket(t *testing.T) { func TestHandleSlashPacket(t *testing.T) { chainId := "consumer-id" validVscID := uint64(234) + providerConsAddr := cryptotestutil.NewCryptoIdentityFromIntSeed(7842334).ProviderConsAddress() consumerConsAddr := cryptotestutil.NewCryptoIdentityFromIntSeed(784987634).ConsumerConsAddress() + // this "dummy" consensus address won't be stored on the provider states + dummyConsAddr := cryptotestutil.NewCryptoIdentityFromIntSeed(784987639).ConsumerConsAddress() testCases := []struct { name string @@ -299,6 +306,20 @@ func TestHandleSlashPacket(t *testing.T) { expectedCalls func(sdk.Context, testkeeper.MockedKeepers, ccv.SlashPacketData) []*gomock.Call expectedSlashAcksLen int }{ + { + "validator isn't a consumer validator", + ccv.SlashPacketData{ + Validator: abci.Validator{Address: dummyConsAddr.ToSdkConsAddr()}, + ValsetUpdateId: validVscID, + Infraction: stakingtypes.Infraction_INFRACTION_DOWNTIME, + }, + func(ctx sdk.Context, mocks testkeeper.MockedKeepers, + expectedPacketData ccv.SlashPacketData, + ) []*gomock.Call { + return []*gomock.Call{} + }, + 0, + }, { "unfound validator", ccv.SlashPacketData{ @@ -403,34 +424,36 @@ func TestHandleSlashPacket(t *testing.T) { } for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + providerKeeper, ctx, ctrl, mocks := testkeeper.GetProviderKeeperAndCtx( + t, testkeeper.NewInMemKeeperParams(t)) - providerKeeper, ctx, ctrl, mocks := testkeeper.GetProviderKeeperAndCtx( - t, testkeeper.NewInMemKeeperParams(t)) - - // Setup expected mock calls - gomock.InOrder(tc.expectedCalls(ctx, mocks, tc.packetData)...) + // Setup expected mock calls + gomock.InOrder(tc.expectedCalls(ctx, mocks, tc.packetData)...) - // Setup init chain height and a single valid valset update ID to block height mapping. - providerKeeper.SetInitChainHeight(ctx, chainId, 5) - providerKeeper.SetValsetUpdateBlockHeight(ctx, validVscID, 99) + // Setup init chain height and a single valid valset update ID to block height mapping. + providerKeeper.SetInitChainHeight(ctx, chainId, 5) + providerKeeper.SetValsetUpdateBlockHeight(ctx, validVscID, 99) - // Setup consumer address to provider address mapping. - require.NotEmpty(t, tc.packetData.Validator.Address) - providerKeeper.SetValidatorByConsumerAddr(ctx, chainId, consumerConsAddr, providerConsAddr) + // Setup consumer address to provider address mapping. + require.NotEmpty(t, tc.packetData.Validator.Address) + providerKeeper.SetValidatorByConsumerAddr(ctx, chainId, consumerConsAddr, providerConsAddr) + providerKeeper.SetConsumerValidator(ctx, chainId, types.ConsumerValidator{ProviderConsAddr: providerConsAddr.Address.Bytes()}) - // Execute method and assert expected mock calls. - providerKeeper.HandleSlashPacket(ctx, chainId, tc.packetData) + // Execute method and assert expected mock calls. + providerKeeper.HandleSlashPacket(ctx, chainId, tc.packetData) - require.Equal(t, tc.expectedSlashAcksLen, len(providerKeeper.GetSlashAcks(ctx, chainId))) + require.Equal(t, tc.expectedSlashAcksLen, len(providerKeeper.GetSlashAcks(ctx, chainId))) - if tc.expectedSlashAcksLen == 1 { - // must match the consumer address - require.Equal(t, consumerConsAddr.String(), providerKeeper.GetSlashAcks(ctx, chainId)[0]) - require.NotEqual(t, providerConsAddr.String(), providerKeeper.GetSlashAcks(ctx, chainId)[0]) - require.NotEqual(t, providerConsAddr.String(), consumerConsAddr.String()) - } + if tc.expectedSlashAcksLen == 1 { + // must match the consumer address + require.Equal(t, consumerConsAddr.String(), providerKeeper.GetSlashAcks(ctx, chainId)[0]) + require.NotEqual(t, providerConsAddr.String(), providerKeeper.GetSlashAcks(ctx, chainId)[0]) + require.NotEqual(t, providerConsAddr.String(), consumerConsAddr.String()) + } - ctrl.Finish() + ctrl.Finish() + }) } }