Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cache and reuse intermediate blobs #4330

Merged
merged 1 commit into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions server/images.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,24 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
return err
}
} else if strings.HasPrefix(c.Args, "@") {
blobpath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
digest := strings.TrimPrefix(c.Args, "@")
if ib, ok := intermediateBlobs.Load(digest); ok {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aren't ib and digest the same thing? Can't you just call this as _, ok := intermediateBlobs.Load(digest); ok { and use digest below?

p, err := GetBlobsPath(ib.(string))
if err != nil {
return err
}

if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
// pass
} else if err != nil {
return err
} else {
fn(api.ProgressResponse{Status: fmt.Sprintf("using cached layer %s", ib.(string))})
digest = ib.(string)
}
}

blobpath, err := GetBlobsPath(digest)
if err != nil {
return err
}
Expand All @@ -351,14 +368,14 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
}
defer blob.Close()

baseLayers, err = parseFromFile(ctx, blob, fn)
baseLayers, err = parseFromFile(ctx, blob, digest, fn)
if err != nil {
return err
}
} else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil {
defer file.Close()

baseLayers, err = parseFromFile(ctx, file, fn)
baseLayers, err = parseFromFile(ctx, file, "", fn)
if err != nil {
return err
}
Expand Down Expand Up @@ -398,10 +415,14 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
return err
}

f16digest := baseLayer.Layer.Digest

baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType)
if err != nil {
return err
}

intermediateBlobs.Store(f16digest, baseLayer.Layer.Digest)
}
}

Expand Down
2 changes: 1 addition & 1 deletion server/layer.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
}, nil
}

func (l *Layer) Open() (io.ReadCloser, error) {
func (l *Layer) Open() (io.ReadSeekCloser, error) {
blob, err := GetBlobsPath(l.Digest)
if err != nil {
return nil, err
Expand Down
23 changes: 9 additions & 14 deletions server/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ import (
"net/http"
"os"
"path/filepath"
"sync"

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/types/model"
)

var intermediateBlobs sync.Map
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To confirm, this is a map of blob digests to ggml model layer digests? Does intermediate mean f16 but not yet quantized?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

intermediate here means any blob that's not referenced in the final manifest but is used to produce something in it. here's a concrete example:

FROM /path/to/safetensors/dir creates a zip (intermediate), which is converted into an f16 (intermediate), which is quantized into a q4_0 (final)

when you want to create another quantization, this map tracks the relationship between the zip and f16 so it's able to skip uploading the zip and reconverting the f16

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this probably deserves a comment describing its operation. I think most people aren't going to know what intermediate means here when they're going through the code.

I was trying to think of some other names, like maybe "blobCache" but then I think people will wonder why only some things are in the blobCache. Although why not shove everything into the blobCache? Why just put the "intermediate" stuff there?

Or the inverse; why not just put the intermediate stuff on disk and just use the normal mechanism for pulling blobs?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also a little confused about why sync.Map here instead of map[string]bool. I don't think there should be any contention here w/ mutexes given you're just shoving the same digest into the map.


type layerWithGGML struct {
*Layer
*llm.GGML
Expand Down Expand Up @@ -76,7 +79,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
return layers, nil
}

func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
stat, err := file.Stat()
if err != nil {
return nil, err
Expand Down Expand Up @@ -169,12 +172,7 @@ func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResp
return nil, fmt.Errorf("aaa: %w", err)
}

blobpath, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}

bin, err := os.Open(blobpath)
bin, err := layer.Open()
if err != nil {
return nil, err
}
Expand All @@ -185,16 +183,13 @@ func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResp
return nil, err
}

layer, err = NewLayerFromLayer(layer.Digest, layer.MediaType, "")
if err != nil {
return nil, err
}

layers = append(layers, &layerWithGGML{layer, ggml})

intermediateBlobs.Store(digest, layer.Digest)
return layers, nil
}

func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
sr := io.NewSectionReader(file, 0, 512)
contentType, err := detectContentType(sr)
if err != nil {
Expand All @@ -205,7 +200,7 @@ func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressRespo
case "gguf", "ggla":
// noop
case "application/zip":
return parseFromZipFile(ctx, file, fn)
return parseFromZipFile(ctx, file, digest, fn)
default:
return nil, fmt.Errorf("unsupported content type: %s", contentType)
}
Expand Down
19 changes: 19 additions & 0 deletions server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,25 @@ func (s *Server) HeadBlobHandler(c *gin.Context) {
}

func (s *Server) CreateBlobHandler(c *gin.Context) {
ib, ok := intermediateBlobs.Load(c.Param("digest"))
if ok {
p, err := GetBlobsPath(ib.(string))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should probably include a debug statement in here for the case where the cache hit but something else removed the underlying storage.

intermediateBlobs.Delete(c.Param("digest"))
} else if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
} else {
c.Status(http.StatusOK)
return
}
}

path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
Expand Down
Loading