From e13c2d9c95422593bf0edd8dd717ecb992928506 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 10 May 2024 15:48:41 -0700 Subject: [PATCH] tmp --- server/images.go | 27 ++++++++++++++++++++++++--- server/layer.go | 2 +- server/model.go | 23 +++++++++-------------- server/routes.go | 19 +++++++++++++++++++ 4 files changed, 53 insertions(+), 18 deletions(-) diff --git a/server/images.go b/server/images.go index c74479353d7..0caf44580a7 100644 --- a/server/images.go +++ b/server/images.go @@ -307,7 +307,24 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio return err } } else if strings.HasPrefix(c.Args, "@") { - blobpath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@")) + digest := strings.TrimPrefix(c.Args, "@") + slog.Info("original", "digest", digest) + if ib, ok := intermediateBlobs.Load(digest); ok { + p, err := GetBlobsPath(ib.(string)) + if err != nil { + return err + } + + if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) { + } else if err != nil { + return err + } else { + fn(api.ProgressResponse{Status: fmt.Sprintf("using cached layer %s", ib.(string))}) + digest = ib.(string) + } + } + + blobpath, err := GetBlobsPath(digest) if err != nil { return err } @@ -318,14 +335,14 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio } defer blob.Close() - baseLayers, err = parseFromFile(ctx, blob, fn) + baseLayers, err = parseFromFile(ctx, blob, digest, fn) if err != nil { return err } } else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil { defer file.Close() - baseLayers, err = parseFromFile(ctx, file, fn) + baseLayers, err = parseFromFile(ctx, file, "", fn) if err != nil { return err } @@ -365,10 +382,14 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio return err } + f16digest := baseLayer.Layer.Digest + baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType) if err != nil { return err } + + intermediateBlobs.Store(f16digest, baseLayer.Layer.Digest) } } diff --git a/server/layer.go b/server/layer.go index 1fdc02c42a6..bffccd5a392 100644 --- a/server/layer.go +++ b/server/layer.go @@ -82,7 +82,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) { }, nil } -func (l *Layer) Open() (io.ReadCloser, error) { +func (l *Layer) Open() (io.ReadSeekCloser, error) { blob, err := GetBlobsPath(l.Digest) if err != nil { return nil, err diff --git a/server/model.go b/server/model.go index 3e8f86ae690..f0885d3e886 100644 --- a/server/model.go +++ b/server/model.go @@ -10,6 +10,7 @@ import ( "net/http" "os" "path/filepath" + "sync" "github.com/ollama/ollama/api" "github.com/ollama/ollama/convert" @@ -17,6 +18,8 @@ import ( "github.com/ollama/ollama/types/model" ) +var intermediateBlobs sync.Map + type layerWithGGML struct { *Layer *llm.GGML @@ -74,7 +77,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe return layers, nil } -func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { +func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { stat, err := file.Stat() if err != nil { return nil, err @@ -167,12 +170,7 @@ func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResp return nil, fmt.Errorf("aaa: %w", err) } - blobpath, err := GetBlobsPath(layer.Digest) - if err != nil { - return nil, err - } - - bin, err := os.Open(blobpath) + bin, err := layer.Open() if err != nil { return nil, err } @@ -183,16 +181,13 @@ func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResp return nil, err } - layer, err = NewLayerFromLayer(layer.Digest, layer.MediaType, "") - if err != nil { - return nil, err - } - layers = append(layers, &layerWithGGML{layer, ggml}) + + intermediateBlobs.Store(digest, layer.Digest) return layers, nil } -func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { +func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { sr := io.NewSectionReader(file, 0, 512) contentType, err := detectContentType(sr) if err != nil { @@ -203,7 +198,7 @@ func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressRespo case "gguf", "ggla": // noop case "application/zip": - return parseFromZipFile(ctx, file, fn) + return parseFromZipFile(ctx, file, digest, fn) default: return nil, fmt.Errorf("unsupported content type: %s", contentType) } diff --git a/server/routes.go b/server/routes.go index e438e224a35..5ed05e83b25 100644 --- a/server/routes.go +++ b/server/routes.go @@ -762,6 +762,25 @@ func (s *Server) HeadBlobHandler(c *gin.Context) { } func (s *Server) CreateBlobHandler(c *gin.Context) { + ib, ok := intermediateBlobs.Load(c.Param("digest")) + if ok { + p, err := GetBlobsPath(ib.(string)) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) { + intermediateBlobs.Delete(c.Param("digest")) + } else if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } else { + c.Status(http.StatusOK) + return + } + } + path, err := GetBlobsPath(c.Param("digest")) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})