Skip to content

Commit

Permalink
reference license, template, system as files
Browse files Browse the repository at this point in the history
  • Loading branch information
mxyng committed May 8, 2024
1 parent 2687f02 commit 4c0b25f
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 49 deletions.
108 changes: 65 additions & 43 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ import (
)

func CreateHandler(cmd *cobra.Command, args []string) error {
filename, _ := cmd.Flags().GetString("file")
filename, err := filepath.Abs(filename)
n, _ := cmd.Flags().GetString("file")
n, err := filepath.Abs(n)
if err != nil {
return err
}
Expand All @@ -56,7 +56,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()

f, err := os.Open(filename)
f, err := os.Open(n)
if err != nil {
return err
}
Expand All @@ -77,39 +77,61 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p.Add(status, spinner)

for i := range modelfile.Commands {
switch modelfile.Commands[i].Name {
case "model", "adapter":
path := modelfile.Commands[i].Args
if path == "~" {
path = home
} else if strings.HasPrefix(path, "~/") {
path = filepath.Join(home, path[2:])
if slices.Contains([]string{"model", "adapter", "license", "template", "system"}, modelfile.Commands[i].Name) {
p := modelfile.Commands[i].Args
if p == "" {
continue
} else if p == "~" {
p = home
} else if strings.HasPrefix(p, "~/") {
p = filepath.Join(home, p[2:])
}

if !filepath.IsAbs(path) {
path = filepath.Join(filepath.Dir(filename), path)
if !filepath.IsAbs(p) {
p = filepath.Join(filepath.Dir(n), p)
}

fi, err := os.Stat(path)
if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
continue
} else if err != nil {
return err
}
switch modelfile.Commands[i].Name {
case "model", "adapter":
fi, err := os.Stat(p)
if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
continue
} else if err != nil {
return err
}

if fi.IsDir() {
// this is likely a safetensors or pytorch directory
// TODO make this work w/ adapters
tempfile, err := tempZipFiles(p)
if err != nil {
return err
}
defer os.RemoveAll(tempfile)

p = tempfile
}
case "license", "template", "system":
ct, err := detectContentType(p)
if errors.Is(err, os.ErrNotExist) {
continue
} else if err != nil {
return err
} else if ct != "text/plain" {
return fmt.Errorf("invalid content type: expected text/plain for %s", p)
}

if fi.IsDir() {
// this is likely a safetensors or pytorch directory
// TODO make this work w/ adapters
tempfile, err := tempZipFiles(path)
fi, err := os.Stat(p)
if err != nil {
return err
}
defer os.RemoveAll(tempfile)

path = tempfile
if fi.Size() > 1<<16 {
return fmt.Errorf("file is too large: %s", p)
}
}

digest, err := createBlob(cmd, client, path)
digest, err := createBlob(cmd, client, p)
if err != nil {
return err
}
Expand Down Expand Up @@ -152,6 +174,24 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return nil
}

func detectContentType(path string) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()

var b bytes.Buffer
b.Grow(512)

if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
return "", err
}

contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
return contentType, nil
}

func tempZipFiles(path string) (string, error) {
tempfile, err := os.CreateTemp("", "ollama-tf")
if err != nil {
Expand All @@ -162,24 +202,6 @@ func tempZipFiles(path string) (string, error) {
zipfile := zip.NewWriter(tempfile)
defer zipfile.Close()

detectContentType := func(path string) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()

var b bytes.Buffer
b.Grow(512)

if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
return "", err
}

contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
return contentType, nil
}

glob := func(pattern, contentType string) ([]string, error) {
matches, err := filepath.Glob(pattern)
if err != nil {
Expand Down
27 changes: 21 additions & 6 deletions server/images.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,18 +339,18 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
return err
}
} else if strings.HasPrefix(c.Args, "@") {
blobpath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
p, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
if err != nil {
return err
}

blob, err := os.Open(blobpath)
b, err := os.Open(p)
if err != nil {
return err
}
defer blob.Close()
defer b.Close()

baseLayers, err = parseFromFile(ctx, blob, fn)
baseLayers, err = parseFromFile(ctx, b, fn)
if err != nil {
return err
}
Expand Down Expand Up @@ -415,8 +415,23 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
layers = append(layers, baseLayer.Layer)
}
case "license", "template", "system":
blob := strings.NewReader(c.Args)
layer, err := NewLayer(blob, mediatype)
var r io.Reader = strings.NewReader(c.Args)
if strings.HasPrefix(c.Args, "@") {
p, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
if err != nil {
return err
}

b, err := os.Open(p)
if err != nil {
return err
}
defer b.Close()

r = b
}

layer, err := NewLayer(r, mediatype)
if err != nil {
return err
}
Expand Down

0 comments on commit 4c0b25f

Please sign in to comment.