-
Notifications
You must be signed in to change notification settings - Fork 3.6k
/
filtered_fee.go
132 lines (108 loc) · 3.58 KB
/
filtered_fee.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package feegrant
import (
"time"
"github.com/gogo/protobuf/proto"
"github.com/cosmos/cosmos-sdk/codec/types"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
)
// TODO: Revisit this once we have a proper gas fee framework.
// Tracking issues https://github.com/cosmos/cosmos-sdk/issues/9054, https://github.com/cosmos/cosmos-sdk/discussions/9072
const (
// gasCostPerIteration is the flat gas charged ("check msg") for each entry of
// AllowedMessages inserted into the lookup set, and again for each transaction
// message checked against that set.
gasCostPerIteration = uint64(10)
)
// Compile-time checks that *AllowedMsgAllowance implements FeeAllowanceI and
// types.UnpackInterfacesMessage.
var _ FeeAllowanceI = (*AllowedMsgAllowance)(nil)
var _ types.UnpackInterfacesMessage = (*AllowedMsgAllowance)(nil)
// UnpackInterfaces implements types.UnpackInterfacesMessage. It hands the packed
// Allowance Any to unpacker, targeting the FeeAllowanceI interface, and returns
// whatever error the unpacker reports.
func (a *AllowedMsgAllowance) UnpackInterfaces(unpacker types.AnyUnpacker) error {
	var inner FeeAllowanceI
	err := unpacker.UnpackAny(a.Allowance, &inner)
	return err
}
// NewAllowedMsgAllowance creates a new AllowedMsgAllowance that wraps allowance
// and restricts its use to the message type URLs listed in allowedMsgs.
//
// It returns an ErrPackAny-wrapped error when allowance does not implement
// proto.Message, and propagates any error from packing it into an Any.
func NewAllowedMsgAllowance(allowance FeeAllowanceI, allowedMsgs []string) (*AllowedMsgAllowance, error) {
	msg, ok := allowance.(proto.Message)
	if !ok {
		// Report the type of the input value: msg is the zero value when the
		// assertion fails, so formatting msg with %T would always print "<nil>".
		return nil, sdkerrors.Wrapf(sdkerrors.ErrPackAny, "cannot proto marshal %T", allowance)
	}

	packed, err := types.NewAnyWithValue(msg)
	if err != nil {
		return nil, err
	}

	return &AllowedMsgAllowance{
		Allowance:       packed,
		AllowedMessages: allowedMsgs,
	}, nil
}
// GetAllowance returns the wrapped fee allowance.
//
// It returns an ErrNoAllowance-wrapped error when no allowance Any is set, or
// when the Any's cached value does not implement FeeAllowanceI.
func (a *AllowedMsgAllowance) GetAllowance() (FeeAllowanceI, error) {
	// Callers such as ExpiresAt reach here without ValidateBasic's nil check;
	// return an error instead of calling GetCachedValue on a nil *Any.
	if a.Allowance == nil {
		return nil, sdkerrors.Wrap(ErrNoAllowance, "failed to get allowance")
	}

	allowance, ok := a.Allowance.GetCachedValue().(FeeAllowanceI)
	if !ok {
		return nil, sdkerrors.Wrap(ErrNoAllowance, "failed to get allowance")
	}
	return allowance, nil
}
// SetAllowance packs allowance into an Any and stores it as the wrapped fee
// allowance.
//
// It returns an ErrPackAny-wrapped error when allowance does not implement
// proto.Message or cannot be packed. In either failure case the previously
// stored allowance is left unchanged.
func (a *AllowedMsgAllowance) SetAllowance(allowance FeeAllowanceI) error {
	// Use the comma-ok form so a non-proto allowance yields an error rather
	// than a panic.
	msg, ok := allowance.(proto.Message)
	if !ok {
		return sdkerrors.Wrapf(sdkerrors.ErrPackAny, "cannot proto marshal %T", allowance)
	}

	packed, err := types.NewAnyWithValue(msg)
	if err != nil {
		// Keep ErrPackAny as the error class but preserve the cause in the message.
		return sdkerrors.Wrapf(sdkerrors.ErrPackAny, "cannot proto marshal %T: %s", allowance, err)
	}

	// Only overwrite the stored allowance once packing has succeeded.
	a.Allowance = packed
	return nil
}
// Accept verifies that every message in msgs has a type URL listed in
// AllowedMessages and, if so, delegates the fee check to the wrapped allowance.
//
// The returned bool is the wrapped allowance's "remove" decision. When the
// wrapped allowance accepts without error and does not ask to be removed, its
// state is re-packed into a.Allowance via SetAllowance.
func (a *AllowedMsgAllowance) Accept(ctx sdk.Context, fee sdk.Coins, msgs []sdk.Msg) (bool, error) {
	if ok := a.allMsgTypesAllowed(ctx, msgs); !ok {
		return false, sdkerrors.Wrap(ErrMessageNotAllowed, "message does not exist in allowed messages")
	}

	inner, err := a.GetAllowance()
	if err != nil {
		return false, err
	}

	remove, acceptErr := inner.Accept(ctx, fee, msgs)
	if acceptErr != nil || remove {
		// Nothing to persist: either the fee was rejected or the grant is spent.
		return remove, acceptErr
	}

	if setErr := a.SetAllowance(inner); setErr != nil {
		return false, setErr
	}
	return false, nil
}
// allowedMsgsToMap builds a lookup set keyed by the entries of AllowedMessages,
// charging gasCostPerIteration gas for every entry inserted.
func (a *AllowedMsgAllowance) allowedMsgsToMap(ctx sdk.Context) map[string]bool {
	lookup := make(map[string]bool, len(a.AllowedMessages))
	for i := range a.AllowedMessages {
		ctx.GasMeter().ConsumeGas(gasCostPerIteration, "check msg")
		typeURL := a.AllowedMessages[i]
		lookup[typeURL] = true
	}
	return lookup
}
// allMsgTypesAllowed reports whether the type URL of every message in msgs is
// present in AllowedMessages. Gas is charged per allowed entry while building
// the lookup set, then per message checked; the scan stops at the first
// message whose type URL is not allowed.
func (a *AllowedMsgAllowance) allMsgTypesAllowed(ctx sdk.Context, msgs []sdk.Msg) bool {
	allowed := a.allowedMsgsToMap(ctx)
	for i := range msgs {
		ctx.GasMeter().ConsumeGas(gasCostPerIteration, "check msg")
		if typeURL := sdk.MsgTypeURL(msgs[i]); !allowed[typeURL] {
			return false
		}
	}
	return true
}
// ValidateBasic implements FeeAllowanceI. It requires a non-nil wrapped
// allowance and a non-empty AllowedMessages list, then defers to the wrapped
// allowance's own ValidateBasic.
func (a *AllowedMsgAllowance) ValidateBasic() error {
	switch {
	case a.Allowance == nil:
		return sdkerrors.Wrap(ErrNoAllowance, "allowance should not be empty")
	case len(a.AllowedMessages) == 0:
		return sdkerrors.Wrap(ErrNoMessages, "allowed messages shouldn't be empty")
	}

	inner, err := a.GetAllowance()
	if err != nil {
		return err
	}
	return inner.ValidateBasic()
}
// ExpiresAt implements FeeAllowanceI by returning the expiry reported by the
// wrapped allowance, or the error from GetAllowance if it cannot be read.
func (a *AllowedMsgAllowance) ExpiresAt() (*time.Time, error) {
	inner, err := a.GetAllowance()
	if err == nil {
		return inner.ExpiresAt()
	}
	return nil, err
}