Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 53 additions & 33 deletions cli/local/model_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,100 +41,120 @@ func CacheModels(apiSpec *spec.API, awsClient *aws.Client) ([]*spec.LocalModelCa
modelPaths[i] = modelResource.ModelPath
}

uncachedModelCount := 0

localModelCaches := make([]*spec.LocalModelCache, len(modelPaths))
for i, modelPath := range modelPaths {
var err error
localModelCaches[i], err = CacheModel(modelPath, awsClient)
modelCacheID, err := modelCacheID(modelPath, awsClient)
if err != nil {
if apiSpec.Predictor.ModelPath != nil {
return nil, errors.Wrap(err, apiSpec.Identify(), userconfig.PredictorKey, userconfig.ModelPathKey)
}
return nil, errors.Wrap(err, apiSpec.Identify(), userconfig.PredictorKey, userconfig.ModelsKey, apiSpec.Predictor.Models[i].Name, userconfig.ModelPathKey)
}
localModelCaches[i].TargetPath = apiSpec.Predictor.Models[i].Name

localModelCache := spec.LocalModelCache{
ID: modelCacheID,
HostPath: filepath.Join(_modelCacheDir, modelCacheID),
TargetPath: apiSpec.Predictor.Models[i].Name,
}

if !files.IsFile(filepath.Join(localModelCache.HostPath, "_SUCCESS")) {
err = cacheModel(modelPath, localModelCache, awsClient)
if err != nil {
if apiSpec.Predictor.ModelPath != nil {
return nil, errors.Wrap(err, apiSpec.Identify(), userconfig.PredictorKey, userconfig.ModelPathKey)
}
return nil, errors.Wrap(err, apiSpec.Identify(), userconfig.PredictorKey, userconfig.ModelsKey, apiSpec.Predictor.Models[i].Name, userconfig.ModelPathKey)
}
uncachedModelCount++
}

localModelCaches[i] = &localModelCache
}

if len(localModelCaches) > 0 {
if uncachedModelCount > 0 {
fmt.Println("") // Newline to group all of the model information
}

return localModelCaches, nil
}

func CacheModel(modelPath string, awsClient *aws.Client) (*spec.LocalModelCache, error) {
localModelCache := spec.LocalModelCache{}
var awsClientForBucket *aws.Client
var err error

func modelCacheID(modelPath string, awsClient *aws.Client) (string, error) {
if strings.HasPrefix(modelPath, "s3://") {
awsClientForBucket, err = aws.NewFromClientS3Path(modelPath, awsClient)
awsClientForBucket, err := aws.NewFromClientS3Path(modelPath, awsClient)
if err != nil {
return nil, err
return "", err
}
bucket, prefix, err := aws.SplitS3Path(modelPath)
if err != nil {
return nil, err
return "", err
}
hash, err := awsClientForBucket.HashS3Dir(bucket, prefix, nil)
if err != nil {
return nil, err
}
localModelCache.ID = hash
} else {
hash, err := localModelHash(modelPath)
if err != nil {
return nil, err
return "", err
}
localModelCache.ID = hash
return hash, nil
}

modelDir := filepath.Join(_modelCacheDir, localModelCache.ID)
hash, err := localModelHash(modelPath)
if err != nil {
return "", err
}
return hash, nil
}

func cacheModel(modelPath string, localModelCache spec.LocalModelCache, awsClient *aws.Client) error {
modelDir := localModelCache.HostPath

if files.IsFile(filepath.Join(modelDir, "_SUCCESS")) {
localModelCache.HostPath = modelDir
return &localModelCache, nil
return nil
}

err = ResetModelCacheDir(modelDir)
err := ResetModelCacheDir(modelDir)
if err != nil {
return nil, err
return err
}

if strings.HasPrefix(modelPath, "s3://") {
err := downloadModel(modelPath, modelDir, awsClientForBucket)
awsClientForBucket, err := aws.NewFromClientS3Path(modelPath, awsClient)
if err != nil {
return err
}

err = downloadModel(modelPath, modelDir, awsClientForBucket)
if err != nil {
return nil, err
return err
}
} else {
if strings.HasSuffix(modelPath, ".zip") {
err := unzipAndValidate(modelPath, modelPath, modelDir)
if err != nil {
return nil, err
return err
}
} else if strings.HasSuffix(modelPath, ".onnx") {
fmt.Println(fmt.Sprintf("○ caching model %s ...", modelPath))
err := files.CopyFileOverwrite(modelPath, filepath.Join(modelDir, filepath.Base(modelPath)))
if err != nil {
return nil, err
return err
}
} else {
fmt.Println(fmt.Sprintf("○ caching model %s ...", modelPath))
tfModelVersion := filepath.Base(modelPath)
err := files.CopyDirOverwrite(strings.TrimSuffix(modelPath, "/"), s.EnsureSuffix(filepath.Join(modelDir, tfModelVersion), "/"))
if err != nil {
return nil, err
return err
}
}
}

err = files.MakeEmptyFile(filepath.Join(modelDir, "_SUCCESS"))
if err != nil {
return nil, err
return err
}

localModelCache.HostPath = modelDir

return &localModelCache, nil
return nil
}

func DeleteCachedModels(apiName string, modelsToDelete []string) error {
Expand Down