diff --git a/cli/cmd/lib_traffic_splitters.go b/cli/cmd/lib_traffic_splitters.go index fb62ea216a..63f8187f86 100644 --- a/cli/cmd/lib_traffic_splitters.go +++ b/cli/cmd/lib_traffic_splitters.go @@ -102,7 +102,7 @@ func trafficSplitterListTable(trafficSplitter []schema.TrafficSplitter, envNames lastUpdated := time.Unix(splitAPI.Spec.LastUpdated, 0) var apis []string for _, api := range splitAPI.Spec.APIs { - apis = append(apis, api.Name+":"+s.Int(api.Weight)) + apis = append(apis, api.Name+":"+s.Int32(api.Weight)) } apisStr := s.TruncateEllipses(strings.Join(apis, " "), 50) rows = append(rows, []interface{}{ diff --git a/pkg/lib/k8s/virtual_service.go b/pkg/lib/k8s/virtual_service.go index ba548468aa..91303725ab 100644 --- a/pkg/lib/k8s/virtual_service.go +++ b/pkg/lib/k8s/virtual_service.go @@ -17,6 +17,8 @@ limitations under the License. package k8s import ( + "reflect" + "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/sets/strset" "github.com/cortexlabs/cortex/pkg/lib/urls" @@ -246,3 +248,23 @@ func ExtractVirtualServiceEndpoints(virtualService *istioclientnetworking.Virtua } return endpoints } + +func VirtualServicesMatch(vs1, vs2 istionetworking.VirtualService) bool { + if !strset.New(vs1.Hosts...).IsEqual(strset.New(vs2.Hosts...)) { + return false + } + + if !strset.New(vs1.Gateways...).IsEqual(strset.New(vs2.Gateways...)) { + return false + } + + if !strset.New(vs1.ExportTo...).IsEqual(strset.New(vs2.ExportTo...)) { + return false + } + + if !reflect.DeepEqual(vs1.Http, vs2.Http) { + return false + } + + return true +} diff --git a/pkg/operator/resources/trafficsplitter/api.go b/pkg/operator/resources/trafficsplitter/api.go index f123c18871..b91ece3437 100644 --- a/pkg/operator/resources/trafficsplitter/api.go +++ b/pkg/operator/resources/trafficsplitter/api.go @@ -19,7 +19,6 @@ package trafficsplitter import ( "fmt" "path/filepath" - "reflect" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/k8s" @@ -33,7 +32,7 @@ import ( ) func UpdateAPI(apiConfig *userconfig.API, projectID string, force bool) (*spec.API, string, error) { - prevVirtualService, err := getK8sResources(apiConfig) + prevVirtualService, err := config.K8s.GetVirtualService(operator.K8sName(apiConfig.Name)) if err != nil { return nil, "", err } @@ -55,7 +54,7 @@ func UpdateAPI(apiConfig *userconfig.API, projectID string, force bool) (*spec.A return api, fmt.Sprintf("created %s", api.Resource.UserString()), nil } - if !areVirtualServiceEqual(prevVirtualService, virtualServiceSpec(api)) { + if !areAPIsEqual(prevVirtualService, virtualServiceSpec(api)) { if err := config.AWS.UploadMsgpackToS3(api, config.Cluster.Bucket, api.Key); err != nil { return nil, "", errors.Wrap(err, "upload api spec") } @@ -67,6 +66,7 @@ func UpdateAPI(apiConfig *userconfig.API, projectID string, force bool) (*spec.A } return api, fmt.Sprintf("updated %s", api.Resource.UserString()), nil } + return api, fmt.Sprintf("%s is up to date", api.Resource.UserString()), nil } @@ -105,17 +105,6 @@ func DeleteAPI(apiName string, keepCache bool) error { return nil } -func getK8sResources(apiConfig *userconfig.API) (*istioclientnetworking.VirtualService, error) { - var virtualService *istioclientnetworking.VirtualService - - virtualService, err := config.K8s.GetVirtualService(operator.K8sName(apiConfig.Name)) - if err != nil { - return nil, err - } - - return virtualService, err -} - func applyK8sVirtualService(trafficSplitter *spec.API, prevVirtualService *istioclientnetworking.VirtualService) error { newVirtualService := virtualServiceSpec(trafficSplitter) @@ -133,7 +122,7 @@ func getTrafficSplitterDestinations(trafficSplitter *spec.API) []k8s.Destination for i, api := range trafficSplitter.APIs { destinations[i] = k8s.Destination{ ServiceName: operator.K8sName(api.Name), - Weight: int32(api.Weight), + Weight: api.Weight, Port: uint32(_defaultPortInt32), } } @@ -201,11 +190,10 @@ func deleteS3Resources(apiName string) error { return config.AWS.DeleteS3Dir(config.Cluster.Bucket, prefix, true) } -func areVirtualServiceEqual(vs1, vs2 *istioclientnetworking.VirtualService) bool { - return vs1.ObjectMeta.Name == vs2.ObjectMeta.Name && - reflect.DeepEqual(vs1.ObjectMeta.Labels, vs2.ObjectMeta.Labels) && - reflect.DeepEqual(vs1.ObjectMeta.Annotations, vs2.ObjectMeta.Annotations) && - reflect.DeepEqual(vs1.Spec.Http, vs2.Spec.Http) && - reflect.DeepEqual(vs1.Spec.Gateways, vs2.Spec.Gateways) && - reflect.DeepEqual(vs1.Spec.Hosts, vs2.Spec.Hosts) +func areAPIsEqual(vs1, vs2 *istioclientnetworking.VirtualService) bool { + return vs1.Labels["apiName"] == vs2.Labels["apiName"] && + vs1.Labels["apiKind"] == vs2.Labels["apiKind"] && + vs1.Labels["apiID"] == vs2.Labels["apiID"] && + k8s.VirtualServicesMatch(vs1.Spec, vs2.Spec) && + operator.DoCortexAnnotationsMatch(vs1, vs2) } diff --git a/pkg/operator/resources/trafficsplitter/k8s_specs.go b/pkg/operator/resources/trafficsplitter/k8s_specs.go index b691135a4a..824633c341 100644 --- a/pkg/operator/resources/trafficsplitter/k8s_specs.go +++ b/pkg/operator/resources/trafficsplitter/k8s_specs.go @@ -21,7 +21,6 @@ import ( "github.com/cortexlabs/cortex/pkg/lib/pointer" "github.com/cortexlabs/cortex/pkg/operator/operator" "github.com/cortexlabs/cortex/pkg/types/spec" - "github.com/cortexlabs/cortex/pkg/types/userconfig" istioclientnetworking "istio.io/client-go/pkg/apis/networking/v1alpha3" ) @@ -36,9 +35,7 @@ func virtualServiceSpec(trafficSplitter *spec.API) *istioclientnetworking.Virtua Destinations: getTrafficSplitterDestinations(trafficSplitter), ExactPath: trafficSplitter.Networking.Endpoint, Rewrite: pointer.String("predict"), - Annotations: map[string]string{ - userconfig.EndpointAnnotationKey: *trafficSplitter.Networking.Endpoint, - userconfig.APIGatewayAnnotationKey: trafficSplitter.Networking.APIGateway.String()}, + Annotations: trafficSplitter.ToK8sAnnotations(), Labels: map[string]string{ "apiName": trafficSplitter.Name, "apiKind": trafficSplitter.Kind.String(), diff --git a/pkg/types/spec/errors.go b/pkg/types/spec/errors.go index cd67eed473..65a76290ed 100644 --- a/pkg/types/spec/errors.go +++ b/pkg/types/spec/errors.go @@ -397,7 +397,7 @@ func ErrorInsufficientBatchConcurrencyLevelInf(maxBatchSize int32, threadsPerPro }) } -func ErrorIncorrectTrafficSplitterWeightTotal(totalWeight int) error { +func ErrorIncorrectTrafficSplitterWeightTotal(totalWeight int32) error { return errors.WithStack(&errors.Error{ Kind: ErrIncorrectTrafficSplitterWeight, Message: fmt.Sprintf("expected weights to sum to 100 but found %d", totalWeight), diff --git a/pkg/types/spec/validations.go b/pkg/types/spec/validations.go index 526cfd1d66..1dd27515b2 100644 --- a/pkg/types/spec/validations.go +++ b/pkg/types/spec/validations.go @@ -114,10 +114,10 @@ func multiAPIsValidation() *cr.StructFieldValidation { }, { StructField: "Weight", - IntValidation: &cr.IntValidation{ + Int32Validation: &cr.Int32Validation{ Required: true, - GreaterThanOrEqualTo: pointer.Int(0), - LessThanOrEqualTo: pointer.Int(100), + GreaterThanOrEqualTo: pointer.Int32(0), + LessThanOrEqualTo: pointer.Int32(100), }, }, }, @@ -1267,7 +1267,7 @@ func validateDockerImagePath(image string, providerType types.ProviderType, awsC } func verifyTotalWeight(apis []*userconfig.TrafficSplit) error { - totalWeight := 0 + totalWeight := int32(0) for _, api := range apis { totalWeight += api.Weight } diff --git a/pkg/types/userconfig/api.go b/pkg/types/userconfig/api.go index 7dc5db3fda..39a93b9d2c 100644 --- a/pkg/types/userconfig/api.go +++ b/pkg/types/userconfig/api.go @@ -60,7 +60,7 @@ type Predictor struct { type TrafficSplit struct { Name string `json:"name" yaml:"name"` - Weight int `json:"weight" yaml:"weight "` + Weight int32 `json:"weight" yaml:"weight"` } type ModelResource struct { @@ -343,7 +343,7 @@ func (api *API) UserStr(provider types.ProviderType) string { func (trafficSplit *TrafficSplit) UserStr() string { var sb strings.Builder sb.WriteString(fmt.Sprintf("%s: %s\n", NameKey, trafficSplit.Name)) - sb.WriteString(fmt.Sprintf("%s: %s\n", WeightKey, s.Int(trafficSplit.Weight))) + sb.WriteString(fmt.Sprintf("%s: %s\n", WeightKey, s.Int32(trafficSplit.Weight))) return sb.String() }