Skip to content

Commit

Permalink
update create handler to use model.Name
Browse files Browse the repository at this point in the history
  • Loading branch information
mxyng committed Apr 19, 2024
1 parent b49dea6 commit a94c401
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 66 deletions.
25 changes: 6 additions & 19 deletions server/images.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ func realpath(rel, from string) string {
return abspath
}

func CreateModel(ctx context.Context, name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) (err error) {
func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) (err error) {
config := ConfigV2{
OS: "linux",
Architecture: "amd64",
Expand Down Expand Up @@ -491,31 +491,18 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
}
}

unref := make(map[string]struct{})
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
for _, layer := range manifest.Layers {
if !slices.Contains(digests, layer.Digest) {
unref[layer.Digest] = struct{}{}
}
}

if manifest.Config.Digest != layer.Digest {
unref[manifest.Config.Digest] = struct{}{}
if os.Getenv("OLLAMA_NOPRUNE") == "" {
fn(api.ProgressResponse{Status: "removing unused layers"})
if manifest, err := ParseNamedManifest(name); err == nil {
_ = manifest.Remove()
}
}

fn(api.ProgressResponse{Status: "writing manifest"})
if err := WriteManifest(name, layer, layers); err != nil {
if _, err := NewManifest(name, layer, layers); err != nil {
return err
}

if os.Getenv("OLLAMA_NOPRUNE") == "" && len(unref) > 0 {
fn(api.ProgressResponse{Status: "removing unused layers"})
if err := deleteUnusedLayers(nil, unref, false); err != nil {
return err
}
}

fn(api.ProgressResponse{Status: "success"})
return nil
}
Expand Down
46 changes: 30 additions & 16 deletions server/manifest.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package server

import (
"bytes"
"crypto/sha256"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -80,30 +79,45 @@ func ParseNamedManifest(name model.Name) (*Manifest, error) {
}, nil
}

func WriteManifest(name string, config *Layer, layers []*Layer) error {
manifest := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config,
Layers: layers,
func NewManifest(name model.Name, config *Layer, layers []*Layer) (*Manifest, error) {
manifests, err := GetManifestPath()
if err != nil {
return nil, err
}

var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(manifest); err != nil {
return err
manifestpath := filepath.Join(manifests, name.FilepathNoBuild())
if err := os.MkdirAll(filepath.Dir(manifestpath), 0o755); err != nil {
return nil, err
}

modelpath := ParseModelPath(name)
manifestPath, err := modelpath.GetManifestPath()
manifestfile, err := os.Create(manifestpath)
if err != nil {
return err
return nil, err
}
defer manifestfile.Close()

manifest := Manifest{
ManifestV2: ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config,
Layers: layers,
},
}

if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
return err
sha256sum := sha256.New()
if err := json.NewEncoder(io.MultiWriter(manifestfile, sha256sum)).Encode(manifest); err != nil {
return nil, err
}

manifest.filepath = manifestpath
manifest.stat, err = manifestfile.Stat()
if err != nil {
return nil, err
}

return os.WriteFile(manifestPath, b.Bytes(), 0o644)
manifest.digest = fmt.Sprintf("%x", sha256sum.Sum(nil))
return &manifest, nil
}

type iter_Seq2[A, B any] func(func(A, B) bool)
Expand Down
5 changes: 2 additions & 3 deletions server/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ import (
)

func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (*ordered.Map[*Layer, *llm.GGML], error) {
modelpath := ParseModelPath(name.DisplayLongest())
manifest, _, err := GetManifest(modelpath)
manifest, err := ParseNamedManifest(name)
switch {
case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.DisplayLongest(), &registryOptions{}, fn); err != nil {
Expand All @@ -34,7 +33,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe

layers := ordered.NewMap[*Layer, *llm.GGML]()
for _, layer := range manifest.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest(""))
if err != nil {
return nil, err
}
Expand Down
38 changes: 11 additions & 27 deletions server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,40 +594,24 @@ func PushModelHandler(c *gin.Context) {
}

func CreateModelHandler(c *gin.Context) {
var req api.CreateRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
var r api.CreateRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}

if err := ParseModelPath(model).Validate(); err != nil {
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

if req.Path == "" && req.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
name := model.ParseName(cmp.Or(r.Model, r.Name), "")
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
return
}

var modelfile io.Reader = strings.NewReader(req.Modelfile)
if req.Path != "" && req.Modelfile == "" {
mf, err := os.Open(req.Path)
var modelfile io.Reader = strings.NewReader(r.Modelfile)
if r.Path != "" && r.Modelfile == "" {
mf, err := os.Open(r.Path)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
return
Expand All @@ -653,12 +637,12 @@ func CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()

if err := CreateModel(ctx, model, filepath.Dir(req.Path), strings.ToUpper(req.Quantization), commands, fn); err != nil {
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(r.Quantization), commands, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()

if req.Stream != nil && !*req.Stream {
if r.Stream != nil && !*r.Stream {
waitForStream(c, ch)
return
}
Expand Down
3 changes: 2 additions & 1 deletion server/routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)

Expand Down Expand Up @@ -61,7 +62,7 @@ func Test_Routes(t *testing.T) {
fn := func(resp api.ProgressResponse) {
t.Logf("Status: %s", resp.Status)
}
err = CreateModel(context.TODO(), name, "", "", commands, fn)
err = CreateModel(context.TODO(), model.ParseName(name, ""), "", "", commands, fn)
assert.Nil(t, err)
}

Expand Down

0 comments on commit a94c401

Please sign in to comment.