diff --git a/cli/local/model_cache.go b/cli/local/model_cache.go index b4081e6b2d..165025f99f 100644 --- a/cli/local/model_cache.go +++ b/cli/local/model_cache.go @@ -41,100 +41,120 @@ func CacheModels(apiSpec *spec.API, awsClient *aws.Client) ([]*spec.LocalModelCa modelPaths[i] = modelResource.ModelPath } + uncachedModelCount := 0 + localModelCaches := make([]*spec.LocalModelCache, len(modelPaths)) for i, modelPath := range modelPaths { var err error - localModelCaches[i], err = CacheModel(modelPath, awsClient) + modelCacheID, err := modelCacheID(modelPath, awsClient) if err != nil { if apiSpec.Predictor.ModelPath != nil { return nil, errors.Wrap(err, apiSpec.Identify(), userconfig.PredictorKey, userconfig.ModelPathKey) } return nil, errors.Wrap(err, apiSpec.Identify(), userconfig.PredictorKey, userconfig.ModelsKey, apiSpec.Predictor.Models[i].Name, userconfig.ModelPathKey) } - localModelCaches[i].TargetPath = apiSpec.Predictor.Models[i].Name + + localModelCache := spec.LocalModelCache{ + ID: modelCacheID, + HostPath: filepath.Join(_modelCacheDir, modelCacheID), + TargetPath: apiSpec.Predictor.Models[i].Name, + } + + if !files.IsFile(filepath.Join(localModelCache.HostPath, "_SUCCESS")) { + err = cacheModel(modelPath, localModelCache, awsClient) + if err != nil { + if apiSpec.Predictor.ModelPath != nil { + return nil, errors.Wrap(err, apiSpec.Identify(), userconfig.PredictorKey, userconfig.ModelPathKey) + } + return nil, errors.Wrap(err, apiSpec.Identify(), userconfig.PredictorKey, userconfig.ModelsKey, apiSpec.Predictor.Models[i].Name, userconfig.ModelPathKey) + } + uncachedModelCount++ + } + + localModelCaches[i] = &localModelCache } - if len(localModelCaches) > 0 { + if uncachedModelCount > 0 { fmt.Println("") // Newline to group all of the model information } return localModelCaches, nil } -func CacheModel(modelPath string, awsClient *aws.Client) (*spec.LocalModelCache, error) { - localModelCache := spec.LocalModelCache{} - var awsClientForBucket *aws.Client - var err error - +func modelCacheID(modelPath string, awsClient *aws.Client) (string, error) { if strings.HasPrefix(modelPath, "s3://") { - awsClientForBucket, err = aws.NewFromClientS3Path(modelPath, awsClient) + awsClientForBucket, err := aws.NewFromClientS3Path(modelPath, awsClient) if err != nil { - return nil, err + return "", err } bucket, prefix, err := aws.SplitS3Path(modelPath) if err != nil { - return nil, err + return "", err } hash, err := awsClientForBucket.HashS3Dir(bucket, prefix, nil) if err != nil { - return nil, err - } - localModelCache.ID = hash - } else { - hash, err := localModelHash(modelPath) - if err != nil { - return nil, err + return "", err } - localModelCache.ID = hash + return hash, nil } - modelDir := filepath.Join(_modelCacheDir, localModelCache.ID) + hash, err := localModelHash(modelPath) + if err != nil { + return "", err + } + return hash, nil +} + +func cacheModel(modelPath string, localModelCache spec.LocalModelCache, awsClient *aws.Client) error { + modelDir := localModelCache.HostPath if files.IsFile(filepath.Join(modelDir, "_SUCCESS")) { - localModelCache.HostPath = modelDir - return &localModelCache, nil + return nil } - err = ResetModelCacheDir(modelDir) + err := ResetModelCacheDir(modelDir) if err != nil { - return nil, err + return err } if strings.HasPrefix(modelPath, "s3://") { - err := downloadModel(modelPath, modelDir, awsClientForBucket) + awsClientForBucket, err := aws.NewFromClientS3Path(modelPath, awsClient) + if err != nil { + return err + } + + err = downloadModel(modelPath, modelDir, awsClientForBucket) if err != nil { - return nil, err + return err } } else { if strings.HasSuffix(modelPath, ".zip") { err := unzipAndValidate(modelPath, modelPath, modelDir) if err != nil { - return nil, err + return err } } else if strings.HasSuffix(modelPath, ".onnx") { fmt.Println(fmt.Sprintf("○ caching model %s ...", modelPath)) err := files.CopyFileOverwrite(modelPath, filepath.Join(modelDir, filepath.Base(modelPath))) if err != nil { - return nil, err + return err } } else { fmt.Println(fmt.Sprintf("○ caching model %s ...", modelPath)) tfModelVersion := filepath.Base(modelPath) err := files.CopyDirOverwrite(strings.TrimSuffix(modelPath, "/"), s.EnsureSuffix(filepath.Join(modelDir, tfModelVersion), "/")) if err != nil { - return nil, err + return err } } } err = files.MakeEmptyFile(filepath.Join(modelDir, "_SUCCESS")) if err != nil { - return nil, err + return err } - localModelCache.HostPath = modelDir - - return &localModelCache, nil + return nil } func DeleteCachedModels(apiName string, modelsToDelete []string) error {