From a94c4014c29f6968d44fe7bd159616cb3b7bf3ae Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 18 Apr 2024 15:23:54 -0700 Subject: [PATCH] update create handler to use model.Name --- server/images.go | 25 ++++++----------------- server/manifest.go | 46 ++++++++++++++++++++++++++++--------------- server/model.go | 5 ++--- server/routes.go | 38 +++++++++++------------------------ server/routes_test.go | 3 ++- 5 files changed, 51 insertions(+), 66 deletions(-) diff --git a/server/images.go b/server/images.go index 06792e5dc8..3b84a7bf5c 100644 --- a/server/images.go +++ b/server/images.go @@ -243,7 +243,7 @@ func realpath(rel, from string) string { return abspath } -func CreateModel(ctx context.Context, name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) (err error) { +func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) (err error) { config := ConfigV2{ OS: "linux", Architecture: "amd64", @@ -491,31 +491,18 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c } } - unref := make(map[string]struct{}) - if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil { - for _, layer := range manifest.Layers { - if !slices.Contains(digests, layer.Digest) { - unref[layer.Digest] = struct{}{} - } - } - - if manifest.Config.Digest != layer.Digest { - unref[manifest.Config.Digest] = struct{}{} + if os.Getenv("OLLAMA_NOPRUNE") == "" { + fn(api.ProgressResponse{Status: "removing unused layers"}) + if manifest, err := ParseNamedManifest(name); err == nil { + _ = manifest.Remove() } } fn(api.ProgressResponse{Status: "writing manifest"}) - if err := WriteManifest(name, layer, layers); err != nil { + if _, err := NewManifest(name, layer, layers); err != nil { return err } - if os.Getenv("OLLAMA_NOPRUNE") == "" && len(unref) > 0 { - fn(api.ProgressResponse{Status: "removing unused layers"}) - if err := deleteUnusedLayers(nil, unref, false); err != nil { - return err - } - } - fn(api.ProgressResponse{Status: "success"}) return nil } diff --git a/server/manifest.go b/server/manifest.go index 613adf61a9..ef41b42cdb 100644 --- a/server/manifest.go +++ b/server/manifest.go @@ -1,7 +1,6 @@ package server import ( - "bytes" "crypto/sha256" "encoding/json" "fmt" @@ -80,30 +79,45 @@ func ParseNamedManifest(name model.Name) (*Manifest, error) { }, nil } -func WriteManifest(name string, config *Layer, layers []*Layer) error { - manifest := ManifestV2{ - SchemaVersion: 2, - MediaType: "application/vnd.docker.distribution.manifest.v2+json", - Config: config, - Layers: layers, +func NewManifest(name model.Name, config *Layer, layers []*Layer) (*Manifest, error) { + manifests, err := GetManifestPath() + if err != nil { + return nil, err } - var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(manifest); err != nil { - return err + manifestpath := filepath.Join(manifests, name.FilepathNoBuild()) + if err := os.MkdirAll(filepath.Dir(manifestpath), 0o755); err != nil { + return nil, err } - modelpath := ParseModelPath(name) - manifestPath, err := modelpath.GetManifestPath() + manifestfile, err := os.Create(manifestpath) if err != nil { - return err + return nil, err + } + defer manifestfile.Close() + + manifest := Manifest{ + ManifestV2: ManifestV2{ + SchemaVersion: 2, + MediaType: "application/vnd.docker.distribution.manifest.v2+json", + Config: config, + Layers: layers, + }, } - if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil { - return err + sha256sum := sha256.New() + if err := json.NewEncoder(io.MultiWriter(manifestfile, sha256sum)).Encode(manifest); err != nil { + return nil, err + } + + manifest.filepath = manifestpath + manifest.stat, err = manifestfile.Stat() + if err != nil { + return nil, err } - return os.WriteFile(manifestPath, b.Bytes(), 0o644) + manifest.digest = fmt.Sprintf("%x", sha256sum.Sum(nil)) + return &manifest, nil } type iter_Seq2[A, B any] func(func(A, B) bool) diff --git a/server/model.go b/server/model.go index 2d7797f0c5..9df923b130 100644 --- a/server/model.go +++ b/server/model.go @@ -19,8 +19,7 @@ import ( ) func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (*ordered.Map[*Layer, *llm.GGML], error) { - modelpath := ParseModelPath(name.DisplayLongest()) - manifest, _, err := GetManifest(modelpath) + manifest, err := ParseNamedManifest(name) switch { case errors.Is(err, os.ErrNotExist): if err := PullModel(ctx, name.DisplayLongest(), ®istryOptions{}, fn); err != nil { @@ -34,7 +33,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe layers := ordered.NewMap[*Layer, *llm.GGML]() for _, layer := range manifest.Layers { - layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname()) + layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest("")) if err != nil { return nil, err } diff --git a/server/routes.go b/server/routes.go index 2270843c3b..8ce63bde6c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -594,40 +594,24 @@ func PushModelHandler(c *gin.Context) { } func CreateModelHandler(c *gin.Context) { - var req api.CreateRequest - err := c.ShouldBindJSON(&req) - switch { - case errors.Is(err, io.EOF): + var r api.CreateRequest + if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return - case err != nil: - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - var model string - if req.Model != "" { - model = req.Model - } else if req.Name != "" { - model = req.Name - } else { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) - return - } - - if err := ParseModelPath(model).Validate(); err != nil { + } else if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - if req.Path == "" && req.Modelfile == "" { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"}) + name := model.ParseName(cmp.Or(r.Model, r.Name), "") + if !name.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))}) return } - var modelfile io.Reader = strings.NewReader(req.Modelfile) - if req.Path != "" && req.Modelfile == "" { - mf, err := os.Open(req.Path) + var modelfile io.Reader = strings.NewReader(r.Modelfile) + if r.Path != "" && r.Modelfile == "" { + mf, err := os.Open(r.Path) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)}) return @@ -653,12 +637,12 @@ func CreateModelHandler(c *gin.Context) { ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - if err := CreateModel(ctx, model, filepath.Dir(req.Path), strings.ToUpper(req.Quantization), commands, fn); err != nil { + if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(r.Quantization), commands, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() - if req.Stream != nil && !*req.Stream { + if r.Stream != nil && !*r.Stream { waitForStream(c, ch) return } diff --git a/server/routes_test.go b/server/routes_test.go index 1c32bb9ddb..6749754be6 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -18,6 +18,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/parser" + "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -61,7 +62,7 @@ func Test_Routes(t *testing.T) { fn := func(resp api.ProgressResponse) { t.Logf("Status: %s", resp.Status) } - err = CreateModel(context.TODO(), name, "", "", commands, fn) + err = CreateModel(context.TODO(), model.ParseName(name, ""), "", "", commands, fn) assert.Nil(t, err) }