/
trainedmodel_webhook.go
117 lines (98 loc) · 4.31 KB
/
trainedmodel_webhook.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/*
Copyright 2021 The KServe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package v1alpha1
import (
"fmt"
"regexp"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
"strings"
"github.com/kserve/kserve/pkg/agent/storage"
"github.com/kserve/kserve/pkg/utils"
"k8s.io/apimachinery/pkg/runtime"
logf "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/webhook"
)
// Constants used when validating a TrainedModel: the name pattern, the
// separator used to list storage protocols, and the error message formats.
const (
// CommaSpaceSeparator is the separator used to join the storage protocols
// into the single string stored in StorageUriProtocols.
CommaSpaceSeparator = ", "
// TmNameFmt is the unanchored pattern describing the characters a
// TrainedModel name may contain; TmRegexp anchors it at both ends.
TmNameFmt string = "[a-zA-Z0-9_-]+"
// InvalidTmNameFormatError is the message format for an invalid name.
// Arguments, in order: the TrainedModel name and the validating regex.
InvalidTmNameFormatError = "the Trained Model \"%s\" is invalid: a Trained Model name must consist of alphanumeric characters, '_', or '-'. (e.g. \"my-Name\" or \"abc_123\", regex used for validation is '%s')"
// InvalidStorageUriFormatError is the message format for an unsupported
// storage URI. Arguments, in order: the TrainedModel name, the list of
// accepted prefixes, and the storage URI that was given.
InvalidStorageUriFormatError = "the Trained Model \"%s\" storageUri field is invalid. The storage uri must start with one of the prefixes: %s. (the storage uri given is \"%s\")"
// InvalidTmMemoryModification is the message format for an attempted change
// of the memory field on update. Arguments, in order: the TrainedModel name,
// the old memory value, and the new memory value.
InvalidTmMemoryModification = "the Trained Model \"%s\" memory field is immutable. The memory was \"%s\" but it is updated to \"%s\""
)
var (
// tmLogger is the logger used by the TrainedModel validation methods in
// this file.
tmLogger = logf.Log.WithName("trainedmodel-alpha1-validator")
// TmRegexp matches a complete TrainedModel name: TmNameFmt anchored with
// "^" and "$" so the whole name must consist of the allowed characters.
TmRegexp = regexp.MustCompile("^" + TmNameFmt + "$")
// StorageUriProtocols is the list returned by storage.GetAllProtocol joined
// with CommaSpaceSeparator; it is embedded in the invalid-storage-URI error.
StorageUriProtocols = strings.Join(storage.GetAllProtocol(), CommaSpaceSeparator)
)
// Compile-time assertion that *TrainedModel satisfies webhook.Validator.
// +kubebuilder:webhook:verbs=create;update,path=/validate-trainedmodel,mutating=false,failurePolicy=fail,groups=serving.kserve.io,resources=trainedmodels,versions=v1alpha1,name=trainedmodel.kserve-webhook-server.validator
var _ webhook.Validator = &TrainedModel{}
// ValidateCreate implements webhook.Validator so a webhook will be registered
// for the type. It logs the request and validates the fields of the new
// TrainedModel. No warnings are ever returned; the error is whatever
// utils.FirstNonNilError selects from the collected validation results.
func (tm *TrainedModel) ValidateCreate() (admission.Warnings, error) {
	tmLogger.Info("validate create", "name", tm.Name)

	results := []error{tm.validateTrainedModel()}
	err := utils.FirstNonNilError(results)
	return nil, err
}
// ValidateUpdate implements webhook.Validator so a webhook will be registered
// for the type. It validates the fields of the updated TrainedModel and
// checks that the memory field was not changed relative to old.
//
// old must hold a non-nil *TrainedModel; any other value is reported as an
// error instead of panicking inside the webhook handler. No warnings are ever
// returned; the error is whatever utils.FirstNonNilError selects from the
// collected validation results.
func (tm *TrainedModel) ValidateUpdate(old runtime.Object) (admission.Warnings, error) {
	tmLogger.Info("validate update", "name", tm.Name)

	// Check the dynamic type up front: convertToTrainedModel performs an
	// unchecked type assertion, and the memory check dereferences the result,
	// so a nil or foreign object would otherwise panic.
	if previous, ok := old.(*TrainedModel); !ok || previous == nil {
		return nil, fmt.Errorf("expected a TrainedModel as the old object but got %T", old)
	}
	oldTm := convertToTrainedModel(old)

	return nil, utils.FirstNonNilError([]error{
		tm.validateTrainedModel(),
		tm.validateMemorySpecNotModified(oldTm),
	})
}
// ValidateDelete implements webhook.Validator so a webhook will be registered
// for the type. It only logs the request: deletion is never rejected and no
// warnings are produced.
func (tm *TrainedModel) ValidateDelete() (admission.Warnings, error) {
	tmLogger.Info("validate delete", "name", tm.Name)

	var noWarnings admission.Warnings
	return noWarnings, nil
}
// validateMemorySpecNotModified reports an error when the model memory of tm
// differs from the memory recorded in oldTm, the previous TrainedModel state.
// It returns nil when both values are equal.
func (tm *TrainedModel) validateMemorySpecNotModified(oldTm *TrainedModel) error {
	current, previous := tm.Spec.Model.Memory, oldTm.Spec.Model.Memory
	if current.Equal(previous) {
		return nil
	}
	return fmt.Errorf(InvalidTmMemoryModification, tm.Name, previous.String(), current.String())
}
// validateTrainedModel checks the format of the TrainedModel's name and of
// its storage URI. Both checks always run; the returned error is whatever
// utils.FirstNonNilError selects from the two results (name first).
func (tm *TrainedModel) validateTrainedModel() error {
	nameErr := tm.validateTrainedModelName()
	uriErr := tm.validateStorageURI()
	return utils.FirstNonNilError([]error{nameErr, uriErr})
}
// convertToTrainedModel returns old as a *TrainedModel. The type assertion is
// unchecked, so it panics when old does not hold a *TrainedModel.
func convertToTrainedModel(old runtime.Object) *TrainedModel {
	return old.(*TrainedModel)
}
// validateTrainedModelName checks the TrainedModel's name against TmRegexp
// and returns an InvalidTmNameFormatError-formatted error when it does not
// match, or nil when the name is valid.
func (tm *TrainedModel) validateTrainedModelName() error {
	if TmRegexp.MatchString(tm.Name) {
		return nil
	}
	return fmt.Errorf(InvalidTmNameFormatError, tm.Name, TmRegexp)
}
// validateStorageURI checks the TrainedModel's storage URI against the
// protocols returned by storage.GetAllProtocol, using utils.IsPrefixSupported.
// It returns an InvalidStorageUriFormatError-formatted error when the URI is
// not accepted, or nil otherwise.
func (tm *TrainedModel) validateStorageURI() error {
	uri := tm.Spec.Model.StorageURI
	if utils.IsPrefixSupported(uri, storage.GetAllProtocol()) {
		return nil
	}
	return fmt.Errorf(InvalidStorageUriFormatError, tm.Name, StorageUriProtocols, uri)
}