/
policy.go
329 lines (287 loc) · 10.2 KB
/
policy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
package types
import (
"bytes"
"encoding/json"
"fmt"
"math"
"reflect"
"strings"
sdkerrors "cosmossdk.io/errors"
sdk "github.com/cosmos/cosmos-sdk/types"
legacyerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/lavanet/lava/utils/decoder"
"github.com/lavanet/lava/utils/lavaslices"
epochstoragetypes "github.com/lavanet/lava/x/epochstorage/types"
"github.com/mitchellh/mapstructure"
)
// WILDCARD_CHAIN_POLICY is the chain ID of a wildcard chain-policy entry.
// When a policy lists it, chains that have no explicit entry are still
// allowed (see Policy.ChainPolicy).
const WILDCARD_CHAIN_POLICY = "*" // lets a policy define only some of the chains and allow all the others
// policyDefaultValues maps a policy field name to the value assigned to it when
// the field was left unset in the parsed input (see HandleUnsetPolicyFields).
// Only fields whose natural zero value is not a good default are listed here.
// The values were chosen so that they will not influence the strictest policy
// calculation.
var policyDefaultValues = map[string]interface{}{
"geolocation_profile": int32(Geolocation_GL),
"max_providers_to_pair": uint64(math.MaxUint64),
}
// ContainsChainID reports whether chainID is covered by the policy's chain
// policies.
//
// A nil policy covers nothing. A policy without any chain policies covers
// every chain. Otherwise chainID has to equal the ChainId of one of the listed
// entries; note that a wildcard entry gets no special treatment here, only an
// exact match counts.
func (policy *Policy) ContainsChainID(chainID string) bool {
	switch {
	case policy == nil:
		return false
	case len(policy.ChainPolicies) == 0:
		// no chain policies at all -> every chain is supported
		return true
	}

	for i := range policy.ChainPolicies {
		if policy.ChainPolicies[i].ChainId == chainID {
			return true
		}
	}
	return false
}
// ChainPolicy returns the chain policy that applies to chainID and whether the
// chain is allowed at all. It is safe to call on a nil receiver.
//
// Results:
//   - nil policy, or a policy without chain policies: every chain is allowed
//     and a ChainPolicy carrying only chainID is returned.
//   - an entry whose ChainId equals chainID: that entry is returned.
//   - no exact entry, but a WILDCARD_CHAIN_POLICY entry exists: the chain is
//     allowed and a ChainPolicy carrying only chainID is returned.
//   - otherwise: the zero ChainPolicy and false.
func (policy *Policy) ChainPolicy(chainID string) (chainPolicy ChainPolicy, allowed bool) {
	unrestricted := ChainPolicy{ChainId: chainID}
	if policy == nil || len(policy.ChainPolicies) == 0 {
		return unrestricted, true
	}

	hasWildcard := false
	for _, entry := range policy.ChainPolicies {
		switch entry.ChainId {
		case chainID:
			// an exact match always wins, even over an earlier wildcard entry
			return entry, true
		case WILDCARD_CHAIN_POLICY:
			hasWildcard = true
		}
	}

	if !hasWildcard {
		return ChainPolicy{}, false
	}
	return unrestricted, true
}
// GetSupportedAddons lists the addons the policy allows for specID.
//
// The first element of the result is always the empty addon; it is followed by
// the Collection.AddOn of every requirement of the chain policy that applies
// to specID, in order and without de-duplication. An error is returned when
// the policy does not allow specID.
func (policy *Policy) GetSupportedAddons(specID string) (addons []string, err error) {
	spec, ok := policy.ChainPolicy(specID)
	if !ok {
		return nil, fmt.Errorf("specID %s not allowed by current policy", specID)
	}

	// slot 0 holds the empty addon, which is always allowed
	addons = make([]string, 1, len(spec.Requirements)+1)
	for i := range spec.Requirements {
		addons = append(addons, spec.Requirements[i].Collection.AddOn)
	}
	return addons, nil
}
// GetSupportedExtensions lists the endpoint services the policy allows for
// specID.
//
// For every requirement of the applicable chain policy it emits one service
// with an empty Extension followed by one service per entry of the
// requirement's Extensions; all of them carry the requirement's
// Collection.ApiInterface and Collection.AddOn. The result is a non-nil, possibly
// empty slice. An error is returned when the policy does not allow specID.
func (policy *Policy) GetSupportedExtensions(specID string) (extensions []epochstoragetypes.EndpointService, err error) {
	spec, ok := policy.ChainPolicy(specID)
	if !ok {
		return nil, fmt.Errorf("specID %s not allowed by current policy", specID)
	}

	extensions = []epochstoragetypes.EndpointService{}
	for _, req := range spec.Requirements {
		// the empty extension is always allowed, so it leads each requirement's entries
		names := append([]string{""}, req.Extensions...)
		for _, name := range names {
			extensions = append(extensions, epochstoragetypes.EndpointService{
				ApiInterface: req.Collection.ApiInterface,
				Addon:        req.Collection.AddOn,
				Extension:    name,
			})
		}
	}
	return extensions, nil
}
// ValidateBasicPolicy runs stateless sanity checks on the policy and returns
// the first violation found as a wrapped error, or nil when all checks pass.
//
// isPlanPolicy selects which rule set applies in addition to the general one:
//   - plan policy: EpochCuLimit and TotalCuLimit must both be non-zero, and the
//     DISABLED selected-providers mode must come with an empty provider list.
//   - non-plan policy: the DISABLED selected-providers mode is rejected, and so
//     is the GLS geolocation.
//
// General rules: EpochCuLimit must not exceed TotalCuLimit, MaxProvidersToPair
// must be greater than 1, the ALLOWED mode must come with an empty provider
// list, the geolocation must pass IsValidGeoEnum, every selected provider must
// be a valid and unique bech32 account address, and every chain requirement
// must define a collection with a non-empty ApiInterface.
func (policy Policy) ValidateBasicPolicy(isPlanPolicy bool) error {
// plan policy checks
if isPlanPolicy {
if policy.EpochCuLimit == 0 || policy.TotalCuLimit == 0 {
return sdkerrors.Wrapf(ErrInvalidPolicyCuFields, `plan's compute units fields can't be zero
(EpochCuLimit = %v, TotalCuLimit = %v)`, policy.EpochCuLimit, policy.TotalCuLimit)
}
if policy.SelectedProvidersMode == SELECTED_PROVIDERS_MODE_DISABLED && len(policy.SelectedProviders) != 0 {
return sdkerrors.Wrap(ErrPolicyInvalidSelectedProvidersConfig, `cannot configure mode = 3 (selected
providers feature is disabled) and non-empty list of selected providers`)
}
// non-plan policy checks
} else if policy.SelectedProvidersMode == SELECTED_PROVIDERS_MODE_DISABLED {
return sdkerrors.Wrap(ErrPolicyInvalidSelectedProvidersConfig, `cannot configure mode = 3 (selected
providers feature is disabled) for a policy that is not plan policy`)
}
// general policy checks
if policy.EpochCuLimit > policy.TotalCuLimit {
return sdkerrors.Wrapf(ErrInvalidPolicyCuFields, "EpochCuLimit can't be larger than TotalCuLimit (EpochCuLimit = %v, TotalCuLimit = %v)", policy.EpochCuLimit, policy.TotalCuLimit)
}
// values 0 and 1 are both rejected
if policy.MaxProvidersToPair <= 1 {
return sdkerrors.Wrapf(ErrInvalidPolicyMaxProvidersToPair, "invalid policy's MaxProvidersToPair fields (MaxProvidersToPair = %v)", policy.MaxProvidersToPair)
}
if policy.SelectedProvidersMode == SELECTED_PROVIDERS_MODE_ALLOWED && len(policy.SelectedProviders) != 0 {
return sdkerrors.Wrap(ErrPolicyInvalidSelectedProvidersConfig, `cannot configure mode = 0 (no
providers restrictions) and non-empty list of selected providers`)
}
// GLS is accepted for plan policies only
if policy.GeolocationProfile == int32(Geolocation_GLS) && !isPlanPolicy {
return sdkerrors.Wrap(ErrPolicyGeolocation, `cannot configure geolocation = GLS (0)`)
}
if !IsValidGeoEnum(policy.GeolocationProfile) {
return sdkerrors.Wrap(ErrPolicyGeolocation, `invalid geolocation enum`)
}
// selected providers must be valid bech32 addresses and must not repeat
seen := map[string]bool{}
for _, addr := range policy.SelectedProviders {
_, err := sdk.AccAddressFromBech32(addr)
if err != nil {
return sdkerrors.Wrapf(legacyerrors.ErrInvalidAddress, "invalid selected provider address (%s)", err)
}
if seen[addr] {
return sdkerrors.Wrapf(ErrPolicyInvalidSelectedProvidersConfig, "found duplicate provider address %s", addr)
}
seen[addr] = true
}
// every requirement of every chain policy needs an apiInterface in its collection
for _, chainPolicy := range policy.ChainPolicies {
for _, requirement := range chainPolicy.GetRequirements() {
if requirement.Collection.ApiInterface == "" {
return sdkerrors.Wrapf(legacyerrors.ErrInvalidRequest, "invalid requirement definition requirement must define collection with an apiInterface (%+v)", chainPolicy)
}
}
}
return nil
}
// GetStrictestChainPolicyForSpec combines the chain policies that the given
// policies define for chainID into a single ChainPolicy.
//
// If any policy does not allow chainID the result is the zero ChainPolicy and
// false. Otherwise the returned ChainPolicy carries chainID and the combined
// requirements: policies without requirements contribute nothing, the first
// policy with requirements provides the initial set, and each further one is
// merged in through lavaslices.UnionByFunc.
// NOTE(review): UnionByFunc presumably de-duplicates entries by their
// Differentiator() - confirm against lavaslices.
func GetStrictestChainPolicyForSpec(chainID string, policies []*Policy) (chainPolicyRet ChainPolicy, allowed bool) {
	merged := []ChainRequirement{}
	for _, policy := range policies {
		current, ok := policy.ChainPolicy(chainID)
		if !ok {
			return ChainPolicy{}, false
		}

		switch {
		case len(current.Requirements) == 0:
			// this policy puts no limits on collections; keep what we have
		case len(merged) == 0:
			// first policy that limits collections: take its requirements as they are
			merged = current.Requirements
		default:
			// both sides limit collections, so combine them
			merged = lavaslices.UnionByFunc(current.Requirements, merged)
		}
	}
	return ChainPolicy{ChainId: chainID, Requirements: merged}, true
}
// VerifyTotalCuUsage reports whether cuUsage is strictly below
// effectiveTotalCu. Usage equal to the limit does not pass.
func VerifyTotalCuUsage(effectiveTotalCu uint64, cuUsage uint64) bool {
	withinLimit := effectiveTotalCu > cuUsage
	return withinLimit
}
// MarshalJSON encodes the mode as a double-quoted string holding the name that
// SELECTED_PROVIDERS_MODE_name maps its numeric value to. A value missing from
// that map is encoded as an empty quoted string. The returned error is always
// nil.
func (s SELECTED_PROVIDERS_MODE) MarshalJSON() ([]byte, error) {
	name := SELECTED_PROVIDERS_MODE_name[int32(s)]

	var quoted bytes.Buffer
	quoted.Grow(len(name) + 2)
	quoted.WriteByte('"')
	quoted.WriteString(name)
	quoted.WriteByte('"')
	return quoted.Bytes(), nil
}
// UnmarshalJSON decodes a quoted JSON string into the enum value that
// SELECTED_PROVIDERS_MODE_value maps that name to. Input that is not a valid
// JSON string yields the json.Unmarshal error and leaves the receiver
// untouched.
func (s *SELECTED_PROVIDERS_MODE) UnmarshalJSON(b []byte) error {
	var name string
	if err := json.Unmarshal(b, &name); err != nil {
		return err
	}

	// A name that is not in the map is not an error: the lookup yields 0, so
	// the receiver ends up with the enum's zero value.
	*s = SELECTED_PROVIDERS_MODE(SELECTED_PROVIDERS_MODE_value[name])
	return nil
}
// ParsePolicyFromYamlString parses a policy from input, which holds the policy
// definition itself rather than a path to it. See parsePolicyFromYaml for the
// result and error behavior.
func ParsePolicyFromYamlString(input string) (*Policy, error) {
	const isPath = false
	return parsePolicyFromYaml(input, isPath)
}
// ParsePolicyFromYamlPath parses a policy from the file located at path. See
// parsePolicyFromYaml for the result and error behavior.
func ParsePolicyFromYamlPath(path string) (*Policy, error) {
	const isPath = true
	return parsePolicyFromYaml(path, isPath)
}
// parsePolicyFromYaml decodes the "Policy" entry of a policy definition into a
// Policy, using PolicyEnumDecodeHookFunc as decode hook.
//
// from is handed to decoder.DecodeFile when isPath is true and to
// decoder.Decode otherwise. Fields the decoder reports as unused make the
// parse fail; fields it reports as unset are filled in through
// HandleUnsetPolicyFields.
//
// The returned pointer is never nil: on failure it points at the policy as far
// as it was populated, next to the error.
func parsePolicyFromYaml(from string, isPath bool) (*Policy, error) {
	var (
		policy Policy
		unset  []string
		unused []string
		err    error
	)
	hooks := []mapstructure.DecodeHookFunc{PolicyEnumDecodeHookFunc}

	switch {
	case isPath:
		err = decoder.DecodeFile(from, "Policy", &policy, hooks, &unset, &unused)
	default:
		err = decoder.Decode(from, "Policy", &policy, hooks, &unset, &unused)
	}

	switch {
	case err != nil:
		return &policy, err
	case len(unused) != 0:
		return &policy, fmt.Errorf("invalid policy: unknown field(s): %v", unused)
	case len(unset) == 0:
		// everything was given explicitly, no defaults to apply
		return &policy, nil
	}

	if err := policy.HandleUnsetPolicyFields(unset); err != nil {
		return &policy, err
	}
	return &policy, nil
}
// HandleUnsetPolicyFields assigns default values to the policy fields named in
// unset. Only fields that have an entry in policyDefaultValues are collected;
// every other field keeps its natural zero value. The collected defaults are
// applied to the policy through decoder.SetDefaultValues, whose error is
// returned as is.
func (p *Policy) HandleUnsetPolicyFields(unset []string) error {
	defaults := make(map[string]interface{})
	for _, name := range unset {
		value, hasDefault := policyDefaultValues[name]
		if !hasDefault {
			continue
		}
		defaults[name] = value
	}
	return decoder.SetDefaultValues(defaults, p)
}
// DecodeSelectedProvidersMode converts the name of a selected-providers mode
// into its SELECTED_PROVIDERS_MODE value, looked up in
// SELECTED_PROVIDERS_MODE_value. A name that is not in that map yields the
// untyped value 0 together with an error.
func DecodeSelectedProvidersMode(dataStr string) (interface{}, error) {
	if value, known := SELECTED_PROVIDERS_MODE_value[dataStr]; known {
		return SELECTED_PROVIDERS_MODE(value), nil
	}
	return 0, fmt.Errorf("invalid selected providers mode: %s", dataStr)
}
// Differentiator returns a string key describing the requirement: the string
// form of its collection directly followed by its extensions joined with
// commas, prefixed with "mixed-" when Mixed is set. A requirement whose
// collection has no ApiInterface yields the empty string.
func (cr ChainRequirement) Differentiator() string {
	if cr.Collection.ApiInterface == "" {
		return ""
	}

	prefix := ""
	if cr.Mixed {
		prefix = "mixed-"
	}
	return prefix + cr.Collection.String() + strings.Join(cr.Extensions, ",")
}
// PolicyEnumDecodeHookFunc is a decode hook that translates the enum fields of
// a policy from their string names into enum values before the data is decoded
// into a Policy.
//
// When the target type t is not Policy, data is passed through unchanged.
// Otherwise data must be a map[string]interface{}; its "geolocation_profile"
// and "selected_providers_mode" entries are replaced in place when they hold
// strings (via ParseGeoEnum and DecodeSelectedProvidersMode respectively), and
// the map is returned. Entries that are missing or not strings are left as
// they are. The source type f is not consulted.
func PolicyEnumDecodeHookFunc(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
	const (
		geoKey  = "geolocation_profile"
		modeKey = "selected_providers_mode"
	)

	if t != reflect.TypeOf(Policy{}) {
		return data, nil
	}

	policyMap, isMap := data.(map[string]interface{})
	if !isMap {
		return nil, fmt.Errorf("unexpected data type for policy field")
	}

	// geolocation enum handling (a missing key simply fails the string assertion)
	if geoStr, isStr := policyMap[geoKey].(string); isStr {
		geoEnum, err := ParseGeoEnum(geoStr)
		if err != nil {
			return nil, err
		}
		policyMap[geoKey] = geoEnum
	}

	// selected providers mode enum handling
	if modeStr, isStr := policyMap[modeKey].(string); isStr {
		modeEnum, err := DecodeSelectedProvidersMode(modeStr)
		if err != nil {
			return nil, err
		}
		policyMap[modeKey] = modeEnum
	}

	return policyMap, nil
}