-
Notifications
You must be signed in to change notification settings - Fork 111
/
expectations.go
139 lines (119 loc) · 5.62 KB
/
expectations.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package keeper
import (
time "time"
sdk "github.com/cosmos/cosmos-sdk/types"
capabilitytypes "github.com/cosmos/cosmos-sdk/x/capability/types"
stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"
clienttypes "github.com/cosmos/ibc-go/v4/modules/core/02-client/types"
conntypes "github.com/cosmos/ibc-go/v4/modules/core/03-connection/types"
channeltypes "github.com/cosmos/ibc-go/v4/modules/core/04-channel/types"
ibctmtypes "github.com/cosmos/ibc-go/v4/modules/light-clients/07-tendermint/types"
providertypes "github.com/cosmos/interchain-security/v2/x/ccv/provider/types"
"github.com/golang/mock/gomock"
host "github.com/cosmos/ibc-go/v4/modules/core/24-host"
ccv "github.com/cosmos/interchain-security/v2/x/ccv/types"
extra "github.com/oxyno-zeta/gomock-extra-matcher"
)
// This file contains groups of commonly used mock expectations.
// Note: each group of mock expectations is associated with a single method
// that may be called during unit tests.

// GetMocksForCreateConsumerClient returns the mock expectations required by a
// call to CreateConsumerClient(). The result holds every expectation produced
// by GetMocksForMakeConsumerGenesis (with a one-hour unbonding time injected),
// followed by a single expected CreateClient call that returns "clientID".
func GetMocksForCreateConsumerClient(ctx sdk.Context, mocks *MockedKeepers,
	expectedChainID string, expectedLatestHeight clienttypes.Height,
) []*gomock.Call {
	// The genesis expectations are registered first, ahead of CreateClient.
	expectations := GetMocksForMakeConsumerGenesis(ctx, mocks, time.Hour)

	// ChainId and LatestHeight are the only client state fields that depend on
	// the arguments passed to CreateConsumerClient, so the client state is
	// matched on those two fields only.
	clientStateMatcher := extra.StructMatcher().
		Field("ChainId", expectedChainID).
		Field("LatestHeight", expectedLatestHeight)

	createClient := mocks.MockClientKeeper.EXPECT().
		CreateClient(gomock.Any(), clientStateMatcher, gomock.Any()).
		Return("clientID", nil).
		Times(1)

	return append(expectations, createClient)
}
// GetMocksForMakeConsumerGenesis returns the mock expectations required by a
// call to MakeConsumerGenesis(). Each of the following is expected once:
//   - StakingKeeper.UnbondingTime (any context), returning unbondingTimeToInject;
//   - ClientKeeper.GetSelfConsensusState at the self height derived from ctx,
//     returning an empty Tendermint consensus state and a nil error;
//   - StakingKeeper.IterateLastValidatorPowers, with any arguments.
func GetMocksForMakeConsumerGenesis(ctx sdk.Context, mocks *MockedKeepers,
	unbondingTimeToInject time.Duration,
) []*gomock.Call {
	unbondingTime := mocks.MockStakingKeeper.EXPECT().
		UnbondingTime(gomock.Any()).
		Return(unbondingTimeToInject).
		Times(1)

	selfHeight := clienttypes.GetSelfHeight(ctx)
	selfConsensusState := mocks.MockClientKeeper.EXPECT().
		GetSelfConsensusState(gomock.Any(), selfHeight).
		Return(&ibctmtypes.ConsensusState{}, nil).
		Times(1)

	lastValidatorPowers := mocks.MockStakingKeeper.EXPECT().
		IterateLastValidatorPowers(gomock.Any(), gomock.Any()).
		Times(1)

	return []*gomock.Call{unbondingTime, selfConsensusState, lastValidatorPowers}
}
// GetMocksForSetConsumerChain returns the mock expectations required by a call
// to SetConsumerChain(). They describe one lookup chain, each step expected once:
//   - the provider-port channel (any channel ID) is OPEN, with the single
//     connection hop "connectionID";
//   - that connection's end points at client "clientID";
//   - that client's Tendermint client state has chainIDToInject as its chain ID.
func GetMocksForSetConsumerChain(ctx sdk.Context, mocks *MockedKeepers,
	chainIDToInject string,
) []*gomock.Call {
	const (
		connectionID = "connectionID"
		clientID     = "clientID"
	)

	openChannel := channeltypes.Channel{
		State:          channeltypes.OPEN,
		ConnectionHops: []string{connectionID},
	}
	connectionEnd := conntypes.ConnectionEnd{ClientId: clientID}
	clientState := &ibctmtypes.ClientState{ChainId: chainIDToInject}

	getChannel := mocks.MockChannelKeeper.EXPECT().
		GetChannel(ctx, ccv.ProviderPortID, gomock.Any()).
		Return(openChannel, true).
		Times(1)
	getConnection := mocks.MockConnectionKeeper.EXPECT().
		GetConnection(ctx, connectionID).
		Return(connectionEnd, true).
		Times(1)
	getClientState := mocks.MockClientKeeper.EXPECT().
		GetClientState(ctx, clientID).
		Return(clientState, true).
		Times(1)

	return []*gomock.Call{getChannel, getConnection, getClientState}
}
// GetMocksForStopConsumerChain returns the mock expectations required by a
// call to StopConsumerChain(). Each of the following is expected once:
//   - fetching provider-port channel "channelID", which is reported as OPEN;
//   - a capability lookup (any arguments) that yields a dummy capability;
//   - ChanCloseInit on that channel with that same dummy capability.
//
// ctx is currently unused: every context argument is matched with gomock.Any().
func GetMocksForStopConsumerChain(ctx sdk.Context, mocks *MockedKeepers) []*gomock.Call {
	const channelID = "channelID"
	dummyCapability := &capabilitytypes.Capability{}
	channelKeeper := mocks.MockChannelKeeper.EXPECT()

	getChannel := channelKeeper.
		GetChannel(gomock.Any(), ccv.ProviderPortID, channelID).
		Return(channeltypes.Channel{State: channeltypes.OPEN}, true).
		Times(1)
	getCapability := mocks.MockScopedKeeper.EXPECT().
		GetCapability(gomock.Any(), gomock.Any()).
		Return(dummyCapability, true).
		Times(1)
	closeInit := channelKeeper.
		ChanCloseInit(gomock.Any(), ccv.ProviderPortID, channelID, dummyCapability).
		Times(1)

	return []*gomock.Call{getChannel, getCapability, closeInit}
}
// GetMocksForHandleSlashPacket returns the mock expectations required by a call
// to HandleSlashPacket() for the validator identified by
// expectedProviderValConsAddr.
//
// Two expectations are always present, each expected once:
// GetValidatorByConsAddr returning valToReturn, and IsTombstoned returning
// false. When expectJailing is true, the jailing path is appended in order:
// Jail, DowntimeJailDuration (returning one hour, expected once) and JailUntil
// with any jail time (expected once).
//
// Unlike the *MockedKeepers helpers above, this helper takes mocks by value.
func GetMocksForHandleSlashPacket(ctx sdk.Context, mocks MockedKeepers,
	expectedProviderValConsAddr providertypes.ProviderConsAddress,
	valToReturn stakingtypes.Validator, expectJailing bool,
) []*gomock.Call {
	consAddr := expectedProviderValConsAddr.ToSdkConsAddr()
	stakingKeeper := mocks.MockStakingKeeper.EXPECT()
	slashingKeeper := mocks.MockSlashingKeeper.EXPECT()

	alwaysExpected := []*gomock.Call{
		stakingKeeper.GetValidatorByConsAddr(ctx, consAddr).Return(valToReturn, true).Times(1),
		slashingKeeper.IsTombstoned(ctx, consAddr).Return(false).Times(1),
	}
	if !expectJailing {
		return alwaysExpected
	}

	// Jailing path: the validator is jailed and its JailUntil time is set.
	return append(alwaysExpected,
		stakingKeeper.Jail(gomock.Eq(ctx), gomock.Eq(consAddr)).Return(),
		slashingKeeper.DowntimeJailDuration(ctx).Return(time.Hour).Times(1),
		slashingKeeper.JailUntil(ctx, consAddr, gomock.Any()).Times(1),
	)
}
// ExpectLatestConsensusStateMock sets up and returns an expectation that
// ClientKeeper.GetLatestClientConsensusState is called once with ctx and
// clientID. That call returns consState and true.
func ExpectLatestConsensusStateMock(ctx sdk.Context, mocks MockedKeepers, clientID string, consState *ibctmtypes.ConsensusState) *gomock.Call {
	expectation := mocks.MockClientKeeper.EXPECT().GetLatestClientConsensusState(ctx, clientID)
	return expectation.Return(consState, true).Times(1)
}
// ExpectCreateClientMock sets up and returns an expectation that
// ClientKeeper.CreateClient is called once with ctx, clientState and consState.
// That call returns clientID and a nil error.
func ExpectCreateClientMock(ctx sdk.Context, mocks MockedKeepers, clientID string, clientState *ibctmtypes.ClientState, consState *ibctmtypes.ConsensusState) *gomock.Call {
	expectation := mocks.MockClientKeeper.EXPECT().CreateClient(ctx, clientState, consState)
	expectation = expectation.Return(clientID, nil)
	return expectation.Times(1)
}
// ExpectGetCapabilityMock sets up and returns an expectation that
// ScopedKeeper.GetCapability is called exactly `times` times with ctx and the
// host port path of ccv.ConsumerPortID. Each call returns a nil capability and
// true.
func ExpectGetCapabilityMock(ctx sdk.Context, mocks MockedKeepers, times int) *gomock.Call {
	consumerPortPath := host.PortPath(ccv.ConsumerPortID)
	return mocks.MockScopedKeeper.EXPECT().
		GetCapability(ctx, consumerPortPath).
		Return(nil, true).
		Times(times)
}