diff --git a/cmd/cmd.go b/cmd/cmd.go index b0f9681cd5..a26ec5a0af 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -11,6 +11,7 @@ import ( "io" "log" "net" + "net/http" "os" "os/exec" "os/signal" @@ -98,22 +99,21 @@ func RunHandler(cmd *cobra.Command, args []string) error { return err } - models, err := client.List(context.Background()) + name := args[0] + // check if the model exists on the server + _, err = client.Show(context.Background(), &api.ShowRequest{Name: name}) if err != nil { - return err - } - - canonicalModelPath := server.ParseModelPath(args[0]) - for _, model := range models.Models { - if model.Name == canonicalModelPath.GetShortTagname() { - return RunGenerate(cmd, args) + var statusError api.StatusError + switch { + case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound: + if err := PullHandler(cmd, args); err != nil { + return err + } + case err != nil: + return err } } - if err := PullHandler(cmd, args); err != nil { - return err - } - return RunGenerate(cmd, args) } @@ -731,21 +731,6 @@ func RunServer(cmd *cobra.Command, _ []string) error { origins = strings.Split(o, ",") } - if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { - if err := server.PruneLayers(); err != nil { - return err - } - - manifestsPath, err := server.GetManifestPath() - if err != nil { - return err - } - - if err := server.PruneDirectory(manifestsPath); err != nil { - return err - } - } - return server.Serve(ln, origins) } diff --git a/server/routes.go b/server/routes.go index 5c52dbfefa..34aa41461c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -614,6 +614,22 @@ var defaultAllowOrigins = []string{ } func Serve(ln net.Listener, allowOrigins []string) error { + if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { + // clean up unused layers and manifests + if err := PruneLayers(); err != nil { + return err + } + + manifestsPath, err := GetManifestPath() + if err != nil { + return err + } + + if err := PruneDirectory(manifestsPath); err != nil { + return err + } + } + config := cors.DefaultConfig() config.AllowWildcard = true