diff --git a/cli/local/model_cache.go b/cli/local/model_cache.go index a9109dd215..522f43a7c7 100644 --- a/cli/local/model_cache.go +++ b/cli/local/model_cache.go @@ -77,6 +77,12 @@ func CacheModel(modelPath string, awsClient *aws.Client) (*spec.LocalModelCache, if err != nil { return nil, err } + } else if strings.HasSuffix(modelPath, ".onnx") { + fmt.Println(fmt.Sprintf("caching model %s ...", modelPath)) + err := files.CopyFileOverwrite(modelPath, filepath.Join(modelDir, filepath.Base(modelPath))) + if err != nil { + return nil, err + } } else { fmt.Println(fmt.Sprintf("caching model %s ...", modelPath)) err := files.CopyDirOverwrite(strings.TrimSuffix(modelPath, "/"), s.EnsureSuffix(modelDir, "/"))