diff --git a/cli/local/model_cache.go b/cli/local/model_cache.go index 2f582ef8e8..8de2139ab7 100644 --- a/cli/local/model_cache.go +++ b/cli/local/model_cache.go @@ -64,7 +64,7 @@ func CacheLocalModels(apiSpec *spec.API, models []spec.CuratedModelResource) err if wasAlreadyCached { modelsThatWereCachedAlready++ } - if len(model.Versions) == 0 { + if model.IsFilePath || len(model.Versions) == 0 { localModelCache.TargetPath = filepath.Join(model.Name, "1") } else { localModelCache.TargetPath = model.Name @@ -98,7 +98,7 @@ func cacheLocalModel(model spec.CuratedModelResource) (*spec.LocalModelCache, bo destModelDir := filepath.Join(_modelCacheDir, localModelCache.ID) if files.IsDir(destModelDir) { - if len(model.Versions) == 0 { + if model.IsFilePath || len(model.Versions) == 0 { localModelCache.HostPath = filepath.Join(destModelDir, "1") } else { localModelCache.HostPath = destModelDir @@ -110,7 +110,7 @@ func cacheLocalModel(model spec.CuratedModelResource) (*spec.LocalModelCache, bo if err != nil { return nil, false, err } - if len(model.Versions) == 0 { + if model.IsFilePath || len(model.Versions) == 0 { if _, err := files.CreateDirIfMissing(filepath.Join(destModelDir, "1")); err != nil { return nil, false, err } @@ -137,10 +137,16 @@ func cacheLocalModel(model spec.CuratedModelResource) (*spec.LocalModelCache, bo } } - if len(model.Versions) == 0 { + if model.IsFilePath || len(model.Versions) == 0 { destModelDir = filepath.Join(destModelDir, "1") } - if err := files.CopyDirOverwrite(strings.TrimSuffix(model.Path, "/"), s.EnsureSuffix(destModelDir, "/")); err != nil { + + if model.IsFilePath { + err = files.CopyFileOverwrite(model.Path, filepath.Join(destModelDir, filepath.Base(model.Path))) + } else { + err = files.CopyDirOverwrite(strings.TrimSuffix(model.Path, "/"), s.EnsureSuffix(destModelDir, "/")) + } + if err != nil { return nil, false, err } diff --git a/pkg/cortex/downloader/download.py b/pkg/cortex/downloader/download.py index 2fbb6fd426..75b7ee3907 100644 --- a/pkg/cortex/downloader/download.py +++ b/pkg/cortex/downloader/download.py @@ -31,7 +31,7 @@ def start(args): if from_path.startswith("s3://"): bucket_name, prefix = S3.deconstruct_s3_path(from_path) - client = S3(bucket_name, client_config={}) + client = S3(bucket_name) elif from_path.startswith("gs://"): bucket_name, prefix = GCS.deconstruct_gcs_path(from_path) client = GCS(bucket_name) diff --git a/pkg/cortex/serve/cortex_internal/lib/model/type.py b/pkg/cortex/serve/cortex_internal/lib/model/type.py index d2164767ea..0df0cacdbe 100644 --- a/pkg/cortex/serve/cortex_internal/lib/model/type.py +++ b/pkg/cortex/serve/cortex_internal/lib/model/type.py @@ -188,17 +188,34 @@ def get_models_from_api_spec( for model in models: model_resource = {} model_resource["name"] = model["name"] - model_resource["s3_path"] = model["path"].startswith("s3://") - model_resource["gcs_path"] = model["path"].startswith("gs://") - model_resource["local_path"] = ( - not model_resource["s3_path"] and not model_resource["gcs_path"] - ) if not model["signature_key"]: model_resource["signature_key"] = models_spec["signature_key"] else: model_resource["signature_key"] = model["signature_key"] + ends_as_file_path = model["path"].endswith(".onnx") + if ends_as_file_path and os.path.exists( + os.path.join(model_dir, model_resource["name"], "1", os.path.basename(model["path"])) + ): + model_resource["is_file_path"] = True + model_resource["s3_path"] = False + model_resource["gcs_path"] = False + model_resource["local_path"] = True + model_resource["versions"] = [] + model_resource["path"] = os.path.join( + model_dir, model_resource["name"], "1", os.path.basename(model["path"]) + ) + model_resources.append(model_resource) + continue + model_resource["is_file_path"] = False + + model_resource["s3_path"] = model["path"].startswith("s3://") + model_resource["gcs_path"] = model["path"].startswith("gs://") + model_resource["local_path"] = ( + not model_resource["s3_path"] and not model_resource["gcs_path"] + ) + if model_resource["s3_path"] or model_resource["gcs_path"]: model_resource["path"] = model["path"] _, versions, _, _, _, _, _ = find_all_cloud_models( diff --git a/pkg/operator/operator/k8s.go b/pkg/operator/operator/k8s.go index 6d3e4774dc..bf464779c6 100644 --- a/pkg/operator/operator/k8s.go +++ b/pkg/operator/operator/k8s.go @@ -48,6 +48,7 @@ const ( const ( _specCacheDir = "/mnt/spec" + _modelDir = "/mnt/model" _emptyDirMountPath = "/mnt" _emptyDirVolumeName = "mnt" _tfServingContainerName = "serve" @@ -570,20 +571,43 @@ func pythonDownloadArgs(api *spec.API) string { } func onnxDownloadArgs(api *spec.API) string { - downloadConfig := downloadContainerConfig{ - LastLog: fmt.Sprintf(_downloaderLastLog, "onnx"), - DownloadArgs: []downloadContainerArg{ - { - From: config.BucketPath(api.ProjectKey), - To: path.Join(_emptyDirMountPath, "project"), - Unzip: true, - ItemName: "the project code", - HideFromLog: true, - HideUnzippingLog: true, - }, + downloadContainerArs := []downloadContainerArg{ + { + From: config.BucketPath(api.ProjectKey), + To: path.Join(_emptyDirMountPath, "project"), + Unzip: true, + ItemName: "the project code", + HideFromLog: true, + HideUnzippingLog: true, }, } + if api.Predictor.Models.Path != nil && strings.HasSuffix(*api.Predictor.Models.Path, ".onnx") { + downloadContainerArs = append(downloadContainerArs, downloadContainerArg{ + From: *api.Predictor.Models.Path, + To: path.Join(_modelDir, consts.SingleModelName, "1"), + ItemName: "the onnx model", + }) + } + + for _, model := range api.Predictor.Models.Paths { + if model == nil { + continue + } + if strings.HasSuffix(model.Path, ".onnx") { + downloadContainerArs = append(downloadContainerArs, downloadContainerArg{ + From: model.Path, + To: path.Join(_modelDir, model.Name, "1"), + ItemName: fmt.Sprintf("%s onnx model", model.Name), + }) + } + } + + downloadConfig := downloadContainerConfig{ + LastLog: fmt.Sprintf(_downloaderLastLog, "onnx"), + DownloadArgs: downloadContainerArs, + } + downloadArgsBytes, _ := json.Marshal(downloadConfig) return base64.URLEncoding.EncodeToString(downloadArgsBytes) } diff --git a/pkg/types/spec/api.go b/pkg/types/spec/api.go index 0622c5fc08..f34da5ccc4 100644 --- a/pkg/types/spec/api.go +++ b/pkg/types/spec/api.go @@ -57,10 +57,11 @@ type LocalModelCache struct { type CuratedModelResource struct { *userconfig.ModelResource - S3Path bool `json:"s3_path"` - GCSPath bool `json:"gcs_path"` - LocalPath bool `json:"local_path"` - Versions []int64 `json:"versions"` + S3Path bool `json:"s3_path"` + GCSPath bool `json:"gcs_path"` + LocalPath bool `json:"local_path"` + IsFilePath bool `json:"file_path"` + Versions []int64 `json:"versions"` } /* diff --git a/pkg/types/spec/errors.go b/pkg/types/spec/errors.go index 94caec0966..f7130ff791 100644 --- a/pkg/types/spec/errors.go +++ b/pkg/types/spec/errors.go @@ -65,6 +65,7 @@ const ( ErrInvalidPythonModelPath = "spec.invalid_python_model_path" ErrInvalidTensorFlowModelPath = "spec.invalid_tensorflow_model_path" ErrInvalidONNXModelPath = "spec.invalid_onnx_model_path" + ErrInvalidONNXModelFilePath = "spec.invalid_onnx_model_file_path" ErrDuplicateModelNames = "spec.duplicate_model_names" ErrReservedModelName = "spec.reserved_model_name" @@ -438,6 +439,17 @@ func ErrorInvalidONNXModelPath(modelPath string, modelSubPaths []string) error { }) } +func ErrorInvalidONNXModelFilePath(filePath string) error { + message := fmt.Sprintf("%s: invalid %s model file path; specify an ONNX file path or provide a directory with one of the following structures:\n", filePath, userconfig.ONNXPredictorType.CasedString()) + templateModelPath := "path/to/model/directory/" + message += fmt.Sprintf(_onnxVersionedExpectedStructMessage, templateModelPath, templateModelPath) + + return errors.WithStack(&errors.Error{ + Kind: ErrInvalidONNXModelFilePath, + Message: message, + }) +} + func ErrorDuplicateModelNames(duplicateModel string) error { return errors.WithStack(&errors.Error{ Kind: ErrDuplicateModelNames, diff --git a/pkg/types/spec/validations.go b/pkg/types/spec/validations.go index 9e803e6e54..e166fe3525 100644 --- a/pkg/types/spec/validations.go +++ b/pkg/types/spec/validations.go @@ -1197,18 +1197,26 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource, var modelWrapError func(error) error var modelResources []userconfig.ModelResource + var modelFileResources []userconfig.ModelResource if hasSingleModel { modelWrapError = func(err error) error { - return errors.Wrap(err, userconfig.ModelsPathKey) + return errors.Wrap(err, userconfig.ModelsKey, userconfig.ModelsPathKey) } - modelResources = []userconfig.ModelResource{ - { - Name: consts.SingleModelName, - Path: *predictor.Models.Path, - }, + modelResource := userconfig.ModelResource{ + Name: consts.SingleModelName, + Path: *predictor.Models.Path, + } + + if strings.HasSuffix(*predictor.Models.Path, ".onnx") && provider != types.LocalProviderType { + if err := validateONNXModelFilePath(*predictor.Models.Path, projectFiles.ProjectDir(), awsClient, gcpClient); err != nil { + return modelWrapError(err) + } + modelFileResources = append(modelFileResources, modelResource) + } else { + modelResources = append(modelResources, modelResource) + *predictor.Models.Path = s.EnsureSuffix(*predictor.Models.Path, "/") } - *predictor.Models.Path = s.EnsureSuffix(*predictor.Models.Path, "/") } if hasMultiModels { if len(predictor.Models.Paths) > 0 { @@ -1225,8 +1233,15 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource, path.Name, ) } - (*path).Path = s.EnsureSuffix((*path).Path, "/") - modelResources = append(modelResources, *path) + if strings.HasSuffix((*path).Path, ".onnx") && provider != types.LocalProviderType { + if err := validateONNXModelFilePath((*path).Path, projectFiles.ProjectDir(), awsClient, gcpClient); err != nil { + return errors.Wrap(modelWrapError(err), path.Name) + } + modelFileResources = append(modelFileResources, *path) + } else { + (*path).Path = s.EnsureSuffix((*path).Path, "/") + modelResources = append(modelResources, *path) + } } } @@ -1249,6 +1264,23 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource, return modelWrapError(err) } + for _, modelFileResource := range modelFileResources { + s3Path := strings.HasPrefix(modelFileResource.Path, "s3://") + gcsPath := strings.HasPrefix(modelFileResource.Path, "gs://") + localPath := !s3Path && !gcsPath + + *models = append(*models, CuratedModelResource{ + ModelResource: &userconfig.ModelResource{ + Name: modelFileResource.Name, + Path: modelFileResource.Path, + }, + S3Path: s3Path, + GCSPath: gcsPath, + LocalPath: localPath, + IsFilePath: true, + }) + } + if hasMultiModels { for _, model := range *models { if model.Name == consts.SingleModelName { @@ -1264,6 +1296,58 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource, return nil } +func validateONNXModelFilePath(modelPath string, projectDir string, awsClient *aws.Client, gcpClient *gcp.Client) error { + s3Path := strings.HasPrefix(modelPath, "s3://") + gcsPath := strings.HasPrefix(modelPath, "gs://") + localPath := !s3Path && !gcsPath + + if s3Path { + awsClientForBucket, err := aws.NewFromClientS3Path(modelPath, awsClient) + if err != nil { + return err + } + + bucket, modelPrefix, err := aws.SplitS3Path(modelPath) + if err != nil { + return err + } + + isS3File, err := awsClientForBucket.IsS3File(bucket, modelPrefix) + if err != nil { + return err + } + + if !isS3File { + return ErrorInvalidONNXModelFilePath(modelPrefix) + } + } + + if gcsPath { + bucket, modelPrefix, err := gcp.SplitGCSPath(modelPath) + if err != nil { + return err + } + + isGCSFile, err := gcpClient.IsGCSFile(bucket, modelPrefix) + if err != nil { + return err + } + + if !isGCSFile { + return ErrorInvalidONNXModelFilePath(modelPrefix) + } + } + + if localPath { + expandedLocalPath := files.RelToAbsPath(modelPath, projectDir) + if err := files.CheckFile(expandedLocalPath); err != nil { + return err + } + } + + return nil +} + func validatePythonPath(predictor *userconfig.Predictor, projectFiles ProjectFiles) error { if !projectFiles.HasDir(*predictor.PythonPath) { return ErrorPythonPathNotFound(*predictor.PythonPath)