Skip to content

Commit

Permalink
fix!: Replace GetAllConsumerChains with lightweight version (#1946)
Browse files Browse the repository at this point in the history
* add GetAllConsumerChainIDs

* replace GetAllConsumerChains with GetAllRegisteredConsumerChainIDs

* add changelog entry

* move HasToValidate to grpc_query.go as it's used only there

* apply review suggestions
  • Loading branch information
mpoke authored and insumity committed Jun 13, 2024
1 parent 4f0308f commit c01de00
Show file tree
Hide file tree
Showing 13 changed files with 306 additions and 281 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- Replace `GetAllConsumerChains` with lightweight version
(`GetAllRegisteredConsumerChainIDs`) that doesn't call into the staking module
([\#1946](https://github.com/cosmos/interchain-security/pull/1946))
31 changes: 13 additions & 18 deletions tests/mbt/driver/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
consumerkeeper "github.com/cosmos/interchain-security/v4/x/ccv/consumer/keeper"
consumertypes "github.com/cosmos/interchain-security/v4/x/ccv/consumer/types"
providerkeeper "github.com/cosmos/interchain-security/v4/x/ccv/provider/keeper"
providertypes "github.com/cosmos/interchain-security/v4/x/ccv/provider/types"
"github.com/cosmos/interchain-security/v4/x/ccv/types"
)

Expand Down Expand Up @@ -219,11 +218,7 @@ func (s *Driver) getStateString() string {
state.WriteString("\n")

state.WriteString("Consumers Chains:\n")
consumerChains := s.providerKeeper().GetAllConsumerChains(s.providerCtx())
chainIds := make([]string, len(consumerChains))
for i, consumerChain := range consumerChains {
chainIds[i] = consumerChain.ChainId
}
chainIds := s.providerKeeper().GetAllRegisteredConsumerChainIDs(s.providerCtx())
state.WriteString(strings.Join(chainIds, ", "))
state.WriteString("\n\n")

Expand Down Expand Up @@ -261,11 +256,11 @@ func (s *Driver) getChainStateString(chain ChainId) string {
if !s.isProviderChain(chain) {
// Check whether the chain is in the consumer chains on the provider

consumerChains := s.providerKeeper().GetAllConsumerChains(s.providerCtx())
consumerChainIDs := s.providerKeeper().GetAllRegisteredConsumerChainIDs(s.providerCtx())

found := false
for _, consumerChain := range consumerChains {
if consumerChain.ChainId == string(chain) {
for _, consumerChainID := range consumerChainIDs {
if consumerChainID == string(chain) {
found = true
}
}
Expand Down Expand Up @@ -369,16 +364,16 @@ func (s *Driver) endAndBeginBlock(chain ChainId, timeAdvancement time.Duration)
return header
}

func (s *Driver) runningConsumers() []providertypes.Chain {
consumersOnProvider := s.providerKeeper().GetAllConsumerChains(s.providerCtx())
func (s *Driver) runningConsumerChainIDs() []ChainId {
consumerIDsOnProvider := s.providerKeeper().GetAllRegisteredConsumerChainIDs(s.providerCtx())

consumersWithIntactChannel := make([]providertypes.Chain, 0)
for _, consumer := range consumersOnProvider {
if s.path(ChainId(consumer.ChainId)).Path.EndpointA.GetChannel().State == channeltypes.CLOSED ||
s.path(ChainId(consumer.ChainId)).Path.EndpointB.GetChannel().State == channeltypes.CLOSED {
consumersWithIntactChannel := make([]ChainId, 0)
for _, consumerChainID := range consumerIDsOnProvider {
if s.path(ChainId(consumerChainID)).Path.EndpointA.GetChannel().State == channeltypes.CLOSED ||
s.path(ChainId(consumerChainID)).Path.EndpointB.GetChannel().State == channeltypes.CLOSED {
continue
}
consumersWithIntactChannel = append(consumersWithIntactChannel, consumer)
consumersWithIntactChannel = append(consumersWithIntactChannel, ChainId(consumerChainID))
}
return consumersWithIntactChannel
}
Expand Down Expand Up @@ -447,8 +442,8 @@ func (s *Driver) RequestSlash(
// DeliverAcks delivers, for each path,
// all possible acks (up to math.MaxInt many per path).
func (s *Driver) DeliverAcks() {
for _, chain := range s.runningConsumers() {
path := s.path(ChainId(chain.ChainId))
for _, chainID := range s.runningConsumerChainIDs() {
path := s.path(chainID)
path.DeliverAcks(path.Path.EndpointA.Chain.ChainID, math.MaxInt)
path.DeliverAcks(path.Path.EndpointB.Chain.ChainID, math.MaxInt)
}
Expand Down
72 changes: 36 additions & 36 deletions tests/mbt/driver/mbt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,21 +304,21 @@ func RunItfTrace(t *testing.T, path string) {
// needs a header of height H+1 to accept the packet
// so, we do two blocks, one with a very small increment,
// and then another to increment the rest of the time
runningConsumersBefore := driver.runningConsumers()
runningConsumerChainIDsBefore := driver.runningConsumerChainIDs()

driver.endAndBeginBlock("provider", 1*time.Nanosecond)
for _, consumer := range driver.runningConsumers() {
UpdateProviderClientOnConsumer(t, driver, consumer.ChainId)
for _, consumerChainID := range driver.runningConsumerChainIDs() {
UpdateProviderClientOnConsumer(t, driver, string(consumerChainID))
}
driver.endAndBeginBlock("provider", time.Duration(timeAdvancement)*time.Second-1*time.Nanosecond)

runningConsumersAfter := driver.runningConsumers()
runningConsumerChainIDsAfter := driver.runningConsumerChainIDs()

// the consumers that were running before but not after must have timed out
for _, consumer := range runningConsumersBefore {
for _, consumerChainID := range runningConsumerChainIDsBefore {
found := false
for _, consumerAfter := range runningConsumersAfter {
if consumerAfter.ChainId == consumer.ChainId {
for _, consumerChainIDAfter := range runningConsumerChainIDsAfter {
if consumerChainIDAfter == consumerChainID {
found = true
break
}
Expand All @@ -332,8 +332,8 @@ func RunItfTrace(t *testing.T, path string) {
// because setting up chains will modify timestamps
// when the coordinator is starting chains
lastTimestamps := make(map[ChainId]time.Time, len(consumers))
for _, consumer := range driver.runningConsumers() {
lastTimestamps[ChainId(consumer.ChainId)] = driver.runningTime(ChainId(consumer.ChainId))
for _, consumerChainID := range driver.runningConsumerChainIDs() {
lastTimestamps[consumerChainID] = driver.runningTime(consumerChainID)
}

driver.coordinator.CurrentTime = driver.runningTime("provider")
Expand Down Expand Up @@ -364,12 +364,12 @@ func RunItfTrace(t *testing.T, path string) {
// for all connected consumers, update the clients...
// unless it was the last consumer to be started, in which case it already has the header
// as we called driver.setupConsumer
for _, consumer := range driver.runningConsumers() {
if len(consumersToStart) > 0 && consumer.ChainId == consumersToStart[len(consumersToStart)-1].Value.(string) {
for _, consumerChainID := range driver.runningConsumerChainIDs() {
if len(consumersToStart) > 0 && string(consumerChainID) == consumersToStart[len(consumersToStart)-1].Value.(string) {
continue
}

UpdateProviderClientOnConsumer(t, driver, consumer.ChainId)
UpdateProviderClientOnConsumer(t, driver, string(consumerChainID))
}

case "EndAndBeginBlockForConsumer":
Expand Down Expand Up @@ -490,33 +490,33 @@ func RunItfTrace(t *testing.T, path string) {
t.Logf("Comparing model state to actual state...")

// compare the running consumers
modelRunningConsumers := RunningConsumers(currentModelState)
modelRunningConsumerChainIDs := RunningConsumers(currentModelState)

systemRunningConsumers := driver.runningConsumers()
actualRunningConsumers := make([]string, len(systemRunningConsumers))
for i, chain := range systemRunningConsumers {
actualRunningConsumers[i] = chain.ChainId
systemRunningConsumerChainIDs := driver.runningConsumerChainIDs()
actualRunningConsumerChainIDs := make([]string, len(systemRunningConsumerChainIDs))
for i, chainID := range systemRunningConsumerChainIDs {
actualRunningConsumerChainIDs[i] = string(chainID)
}

// sort the slices so that we can compare them
sort.Slice(modelRunningConsumers, func(i, j int) bool {
return modelRunningConsumers[i] < modelRunningConsumers[j]
sort.Slice(modelRunningConsumerChainIDs, func(i, j int) bool {
return modelRunningConsumerChainIDs[i] < modelRunningConsumerChainIDs[j]
})
sort.Slice(actualRunningConsumers, func(i, j int) bool {
return actualRunningConsumers[i] < actualRunningConsumers[j]
sort.Slice(actualRunningConsumerChainIDs, func(i, j int) bool {
return actualRunningConsumerChainIDs[i] < actualRunningConsumerChainIDs[j]
})

require.Equal(t, modelRunningConsumers, actualRunningConsumers, "Running consumers do not match")
require.Equal(t, modelRunningConsumerChainIDs, actualRunningConsumerChainIDs, "Running consumers do not match")

// check validator sets - provider current validator set should be the one from the staking keeper
CompareValidatorSets(t, driver, currentModelState, actualRunningConsumers, realAddrsToModelConsAddrs)
CompareValidatorSets(t, driver, currentModelState, actualRunningConsumerChainIDs, realAddrsToModelConsAddrs)

// check times - sanity check that the block times match the ones from the model
CompareTimes(driver, actualRunningConsumers, currentModelState, timeOffset)
CompareTimes(driver, actualRunningConsumerChainIDs, currentModelState, timeOffset)

// check sent packets: we check that the package queues in the model and the system have the same length.
for _, consumer := range actualRunningConsumers {
ComparePacketQueues(t, driver, currentModelState, consumer, timeOffset)
for _, consumerChainID := range actualRunningConsumerChainIDs {
ComparePacketQueues(t, driver, currentModelState, consumerChainID, timeOffset)
}
// compare that the sent packets on the proider match the model
CompareSentPacketsOnProvider(driver, currentModelState, timeOffset)
Expand All @@ -526,8 +526,8 @@ func RunItfTrace(t *testing.T, path string) {
CompareJailedValidators(driver, currentModelState, timeOffset, addressMap)

// for all newly sent vsc packets, figure out which vsc id in the model they correspond to
for _, consumer := range actualRunningConsumers {
actualPackets := driver.packetQueue(PROVIDER, ChainId(consumer))
for _, consumerChainID := range actualRunningConsumerChainIDs {
actualPackets := driver.packetQueue(PROVIDER, ChainId(consumerChainID))
actualNewPackets := make([]types.ValidatorSetChangePacketData, 0)
for _, packet := range actualPackets {

Expand All @@ -543,7 +543,7 @@ func RunItfTrace(t *testing.T, path string) {
actualNewPackets = append(actualNewPackets, packetData)
}

modelPackets := PacketQueue(currentModelState, PROVIDER, consumer)
modelPackets := PacketQueue(currentModelState, PROVIDER, consumerChainID)
newModelVscIds := make([]uint64, 0)
for _, packet := range modelPackets {
modelVscId := uint64(packet.Value.(itf.MapExprType)["value"].Value.(itf.MapExprType)["id"].Value.(int64))
Expand Down Expand Up @@ -781,15 +781,15 @@ func CompareValSet(modelValSet map[string]itf.Expr, systemValSet map[string]int6
}

func CompareSentPacketsOnProvider(driver *Driver, currentModelState map[string]itf.Expr, timeOffset time.Time) {
for _, consumer := range driver.runningConsumers() {
vscSendTimestamps := driver.providerKeeper().GetAllVscSendTimestamps(driver.providerCtx(), consumer.ChainId)
for _, consumerChainID := range driver.runningConsumerChainIDs() {
vscSendTimestamps := driver.providerKeeper().GetAllVscSendTimestamps(driver.providerCtx(), string(consumerChainID))

actualVscSendTimestamps := make([]time.Time, 0)
for _, vscSendTimestamp := range vscSendTimestamps {
actualVscSendTimestamps = append(actualVscSendTimestamps, vscSendTimestamp.Timestamp)
}

modelVscSendTimestamps := VscSendTimestamps(currentModelState, consumer.ChainId)
modelVscSendTimestamps := VscSendTimestamps(currentModelState, string(consumerChainID))

for i, modelVscSendTimestamp := range modelVscSendTimestamps {
actualTimeWithOffset := actualVscSendTimestamps[i].Unix() - timeOffset.Unix()
Expand All @@ -798,7 +798,7 @@ func CompareSentPacketsOnProvider(driver *Driver, currentModelState map[string]i
modelVscSendTimestamp,
actualTimeWithOffset,
"Vsc send timestamps do not match for consumer %v",
consumer.ChainId,
consumerChainID,
)
}
}
Expand Down Expand Up @@ -852,9 +852,9 @@ func (s *Stats) EnterStats(driver *Driver) {

// max number of in-flight packets
inFlightPackets := 0
for _, consumer := range driver.runningConsumers() {
inFlightPackets += len(driver.packetQueue(PROVIDER, ChainId(consumer.ChainId)))
inFlightPackets += len(driver.packetQueue(ChainId(consumer.ChainId), PROVIDER))
for _, consumerChainID := range driver.runningConsumerChainIDs() {
inFlightPackets += len(driver.packetQueue(PROVIDER, consumerChainID))
inFlightPackets += len(driver.packetQueue(consumerChainID, PROVIDER))
}
if inFlightPackets > s.maxNumInFlightPackets {
s.maxNumInFlightPackets = inFlightPackets
Expand Down
10 changes: 5 additions & 5 deletions x/ccv/provider/keeper/distribution.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ func (k Keeper) AllocateTokens(ctx sdk.Context) {
}

// Iterate over all registered consumer chains
for _, consumer := range k.GetAllConsumerChains(ctx) {
for _, consumerChainID := range k.GetAllRegisteredConsumerChainIDs(ctx) {
// transfer the consumer rewards to the distribution module account
// note that the rewards transferred are only consumer whitelisted denoms
rewardsCollected, err := k.TransferConsumerRewardsToDistributionModule(ctx, consumer.ChainId)
rewardsCollected, err := k.TransferConsumerRewardsToDistributionModule(ctx, consumerChainID)
if err != nil {
k.Logger(ctx).Error(
"fail to transfer rewards to distribution module for chain %s: %s",
consumer.ChainId,
consumerChainID,
err,
)
continue
Expand All @@ -101,7 +101,7 @@ func (k Keeper) AllocateTokens(ctx sdk.Context) {
// temporary workaround to keep CanWithdrawInvariant happy
// general discussions here: https://github.com/cosmos/cosmos-sdk/issues/2906#issuecomment-441867634
feePool := k.distributionKeeper.GetFeePool(ctx)
if k.ComputeConsumerTotalVotingPower(ctx, consumer.ChainId) == 0 {
if k.ComputeConsumerTotalVotingPower(ctx, consumerChainID) == 0 {
feePool.CommunityPool = feePool.CommunityPool.Add(rewardsCollectedDec...)
k.distributionKeeper.SetFeePool(ctx, feePool)
return
Expand All @@ -116,7 +116,7 @@ func (k Keeper) AllocateTokens(ctx sdk.Context) {
// allocate tokens to consumer validators
feeAllocated := k.AllocateTokensToConsumerValidators(
ctx,
consumer.ChainId,
consumerChainID,
feeMultiplier,
)
remaining = remaining.Sub(feeAllocated)
Expand Down
36 changes: 20 additions & 16 deletions x/ccv/provider/keeper/genesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,47 +108,51 @@ func (k Keeper) InitGenesis(ctx sdk.Context, genState *types.GenesisState) {
// ExportGenesis returns the CCV provider module's exported genesis
func (k Keeper) ExportGenesis(ctx sdk.Context) *types.GenesisState {
// get a list of all registered consumer chains
registeredChains := k.GetAllConsumerChains(ctx)
registeredChainIDs := k.GetAllRegisteredConsumerChainIDs(ctx)

var exportedVscSendTimestamps []types.ExportedVscSendTimestamp
// export states for each consumer chains
var consumerStates []types.ConsumerState
for _, chain := range registeredChains {
gen, found := k.GetConsumerGenesis(ctx, chain.ChainId)
for _, chainID := range registeredChainIDs {
// no need for the second return value of GetConsumerClientId
// as GetAllRegisteredConsumerChainIDs already iterated through
// the entire prefix range
clientID, _ := k.GetConsumerClientId(ctx, chainID)
gen, found := k.GetConsumerGenesis(ctx, chainID)
if !found {
panic(fmt.Errorf("cannot find genesis for consumer chain %s with client %s", chain.ChainId, chain.ClientId))
panic(fmt.Errorf("cannot find genesis for consumer chain %s with client %s", chainID, clientID))
}

// initial consumer chain states
cs := types.ConsumerState{
ChainId: chain.ChainId,
ClientId: chain.ClientId,
ChainId: chainID,
ClientId: clientID,
ConsumerGenesis: gen,
UnbondingOpsIndex: k.GetAllUnbondingOpIndexes(ctx, chain.ChainId),
UnbondingOpsIndex: k.GetAllUnbondingOpIndexes(ctx, chainID),
}

// try to find channel id for the current consumer chain
channelId, found := k.GetChainToChannel(ctx, chain.ChainId)
channelId, found := k.GetChainToChannel(ctx, chainID)
if found {
cs.ChannelId = channelId
cs.InitialHeight, found = k.GetInitChainHeight(ctx, chain.ChainId)
cs.InitialHeight, found = k.GetInitChainHeight(ctx, chainID)
if !found {
panic(fmt.Errorf("cannot find init height for consumer chain %s", chain.ChainId))
panic(fmt.Errorf("cannot find init height for consumer chain %s", chainID))
}
cs.SlashDowntimeAck = k.GetSlashAcks(ctx, chain.ChainId)
cs.SlashDowntimeAck = k.GetSlashAcks(ctx, chainID)
}

cs.PendingValsetChanges = k.GetPendingVSCPackets(ctx, chain.ChainId)
cs.PendingValsetChanges = k.GetPendingVSCPackets(ctx, chainID)
consumerStates = append(consumerStates, cs)

vscSendTimestamps := k.GetAllVscSendTimestamps(ctx, chain.ChainId)
exportedVscSendTimestamps = append(exportedVscSendTimestamps, types.ExportedVscSendTimestamp{ChainId: chain.ChainId, VscSendTimestamps: vscSendTimestamps})
vscSendTimestamps := k.GetAllVscSendTimestamps(ctx, chainID)
exportedVscSendTimestamps = append(exportedVscSendTimestamps, types.ExportedVscSendTimestamp{ChainId: chainID, VscSendTimestamps: vscSendTimestamps})
}

// ConsumerAddrsToPrune are added only for registered consumer chains
consumerAddrsToPrune := []types.ConsumerAddrsToPrune{}
for _, chain := range registeredChains {
consumerAddrsToPrune = append(consumerAddrsToPrune, k.GetAllConsumerAddrsToPrune(ctx, chain.ChainId)...)
for _, chainID := range registeredChainIDs {
consumerAddrsToPrune = append(consumerAddrsToPrune, k.GetAllConsumerAddrsToPrune(ctx, chainID)...)
}

params := k.GetParams(ctx)
Expand Down
Loading

0 comments on commit c01de00

Please sign in to comment.