-
Notifications
You must be signed in to change notification settings - Fork 0
/
invariants.go
138 lines (112 loc) · 3.84 KB
/
invariants.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package keeper
import (
"github.com/furya-official/mage/x/swap/types"
sdk "github.com/cosmos/cosmos-sdk/types"
)
// RegisterInvariants registers every swap module invariant with the given
// registry. Each invariant is registered under types.ModuleName with its own
// route name.
func RegisterInvariants(ir sdk.InvariantRegistry, k Keeper) {
	routes := []struct {
		route     string
		invariant sdk.Invariant
	}{
		{route: "pool-records", invariant: PoolRecordsInvariant(k)},
		{route: "share-records", invariant: ShareRecordsInvariant(k)},
		{route: "pool-reserves", invariant: PoolReservesInvariant(k)},
		{route: "pool-shares", invariant: PoolSharesInvariant(k)},
	}

	for _, r := range routes {
		ir.RegisterRoute(types.ModuleName, r.route, r.invariant)
	}
}
// AllInvariants returns a single invariant that runs every swap module
// invariant in order: pool records, share records, pool reserves, then pool
// shares. It stops at the first broken invariant and returns that result.
// If none are broken, it returns the result of the last (pool shares) check.
func AllInvariants(k Keeper) sdk.Invariant {
	return func(ctx sdk.Context) (string, bool) {
		// Each invariant is constructed anew on every invocation.
		constructors := []func(Keeper) sdk.Invariant{
			PoolRecordsInvariant,
			ShareRecordsInvariant,
			PoolReservesInvariant,
			PoolSharesInvariant,
		}

		var (
			msg    string
			broken bool
		)
		for _, build := range constructors {
			msg, broken = build(k)(ctx)
			if broken {
				break
			}
		}
		return msg, broken
	}
}
// PoolRecordsInvariant iterates all pool records and asserts that they are valid.
//
// The returned invariant reports broken when any stored pool record fails
// Validate. The broken flag is recomputed on every invocation, so the result
// always reflects the current store. It is not a sticky "ever broken" flag.
func PoolRecordsInvariant(k Keeper) sdk.Invariant {
	message := sdk.FormatInvariant(types.ModuleName, "validate pool records broken", "pool record invalid")

	return func(ctx sdk.Context) (string, bool) {
		broken := false
		k.IteratePools(ctx, func(record types.PoolRecord) bool {
			if err := record.Validate(); err != nil {
				broken = true
				// One invalid record is enough. Returning true presumably asks
				// IteratePools to stop early (confirm iterator convention).
				return true
			}
			return false
		})
		return message, broken
	}
}
// ShareRecordsInvariant iterates all share records and asserts that they are valid.
//
// The returned invariant reports broken when any stored depositor share
// record fails Validate. The broken flag is recomputed on every invocation,
// so the result always reflects the current store. It is not a sticky
// "ever broken" flag.
func ShareRecordsInvariant(k Keeper) sdk.Invariant {
	message := sdk.FormatInvariant(types.ModuleName, "validate share records broken", "share record invalid")

	return func(ctx sdk.Context) (string, bool) {
		broken := false
		k.IterateDepositorShares(ctx, func(record types.ShareRecord) bool {
			if err := record.Validate(); err != nil {
				broken = true
				// One invalid record is enough. Returning true presumably asks
				// IterateDepositorShares to stop early (confirm iterator convention).
				return true
			}
			return false
		})
		return message, broken
	}
}
// PoolReservesInvariant iterates all pools and ensures the total reserves match the module account coins.
//
// It sums the reserves of every pool record and compares the sum with the
// full balance of the swap module account. The invariant is broken unless
// both hold exactly the same denominations with the same amounts.
//
// The comparison is done denom by denom instead of with sdk.Coins.IsEqual.
// In some Cosmos SDK versions (v0.45/v0.46), IsEqual panics when two
// same-length coin sets contain different denominations. A mismatch must be
// reported as a broken invariant, not crash the check.
func PoolReservesInvariant(k Keeper) sdk.Invariant {
	message := sdk.FormatInvariant(types.ModuleName, "pool reserves broken", "pool reserves do not match module account")

	return func(ctx sdk.Context) (string, bool) {
		balance := k.bankKeeper.GetAllBalances(ctx, k.GetSwapModuleAccount(ctx).GetAddress())

		reserves := sdk.Coins{}
		k.IteratePools(ctx, func(record types.PoolRecord) bool {
			for _, coin := range record.Reserves() {
				reserves = reserves.Add(coin)
			}
			return false
		})

		// Coins sets hold no duplicate denoms. Equal length plus a matching
		// amount for every reserve denom therefore means the sets are identical.
		broken := len(reserves) != len(balance)
		if !broken {
			for _, coin := range reserves {
				if !balance.AmountOf(coin.Denom).Equal(coin.Amount) {
					broken = true
					break
				}
			}
		}
		return message, broken
	}
}
// poolShares accumulates two per-pool totals. totalShares is the share total
// recorded on the pool record. totalSharesOwned is the sum of shares held by
// the pool's depositors.
type poolShares struct {
	totalShares, totalSharesOwned sdk.Int
}
// PoolSharesInvariant iterates all pools and shares and ensures the total pool shares match the sum of depositor shares.
//
// Every pool record contributes its TotalShares. Every depositor share record
// adds its SharesOwned to the matching pool's owned total. A share record whose
// pool has no pool record is counted against a pool total of zero, so orphaned
// shares break the invariant. The broken flag is recomputed on every
// invocation, so the result always reflects the current store.
func PoolSharesInvariant(k Keeper) sdk.Invariant {
	message := sdk.FormatInvariant(types.ModuleName, "pool shares broken", "pool shares do not match depositor shares")

	return func(ctx sdk.Context) (string, bool) {
		totals := make(map[string]poolShares)

		k.IteratePools(ctx, func(pr types.PoolRecord) bool {
			totals[pr.PoolID] = poolShares{
				totalShares:      pr.TotalShares,
				totalSharesOwned: sdk.ZeroInt(),
			}
			return false
		})

		k.IterateDepositorShares(ctx, func(sr types.ShareRecord) bool {
			if shares, found := totals[sr.PoolID]; found {
				shares.totalSharesOwned = shares.totalSharesOwned.Add(sr.SharesOwned)
				totals[sr.PoolID] = shares
				return false
			}
			// No pool record for these shares: the expected total is zero.
			totals[sr.PoolID] = poolShares{
				totalShares:      sdk.ZeroInt(),
				totalSharesOwned: sr.SharesOwned,
			}
			return false
		})

		broken := false
		for _, ps := range totals {
			if !ps.totalShares.Equal(ps.totalSharesOwned) {
				broken = true
				break
			}
		}
		return message, broken
	}
}