From c5e892cb3ef21b4ba315389210205b65e46b62aa Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 13 May 2024 14:41:37 -0700 Subject: [PATCH] update tests --- server/manifest_test.go | 110 ++++++++++++++++++------ server/routes_create_test.go | 160 +++++++++++++++++++++++++++++++++++ server/routes_delete_test.go | 71 ++++++++++++++++ server/routes_list_test.go | 61 +++++++++++++ 4 files changed, 375 insertions(+), 27 deletions(-) create mode 100644 server/routes_create_test.go create mode 100644 server/routes_delete_test.go create mode 100644 server/routes_list_test.go diff --git a/server/manifest_test.go b/server/manifest_test.go index 4da8674548..b85976fd2b 100644 --- a/server/manifest_test.go +++ b/server/manifest_test.go @@ -30,35 +30,76 @@ func createManifest(t *testing.T, path, name string) { } func TestManifests(t *testing.T) { - cases := map[string][]string{ + cases := map[string]struct { + ps []string + wantValidCount int + wantInvalidCount int + }{ "empty": {}, "single": { - filepath.Join("host", "namespace", "model", "tag"), + ps: []string{ + filepath.Join("host", "namespace", "model", "tag"), + }, + wantValidCount: 1, }, "multiple": { - filepath.Join("registry.ollama.ai", "library", "llama3", "latest"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"), + ps: []string{ + filepath.Join("registry.ollama.ai", "library", "llama3", "latest"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"), + }, + wantValidCount: 15, }, "hidden": { - filepath.Join("host", "namespace", "model", "tag"), - filepath.Join("host", "namespace", "model", ".hidden"), + ps: []string{ + filepath.Join("host", "namespace", "model", "tag"), + filepath.Join("host", "namespace", "model", ".hidden"), + }, + wantValidCount: 1, + wantInvalidCount: 1, }, "subdir": { - filepath.Join("host", "namespace", "model", "tag", "one"), - filepath.Join("host", "namespace", "model", "tag", "another", "one"), + ps: []string{ + filepath.Join("host", "namespace", "model", "tag", "one"), + filepath.Join("host", "namespace", "model", "tag", "another", "one"), + }, + wantInvalidCount: 2, + }, + "upper tag": { + ps: []string{ + filepath.Join("host", "namespace", "model", "TAG"), + }, + wantValidCount: 1, + }, + "upper model": { + ps: []string{ + filepath.Join("host", "namespace", "MODEL", "tag"), + }, + wantValidCount: 1, + }, + "upper namespace": { + ps: []string{ + filepath.Join("host", "NAMESPACE", "model", "tag"), + }, + wantValidCount: 1, + }, + "upper host": { + ps: []string{ + filepath.Join("HOST", "namespace", "model", "tag"), + }, + wantValidCount: 1, }, } @@ -67,8 +108,8 @@ func TestManifests(t *testing.T) { d := t.TempDir() t.Setenv("OLLAMA_MODELS", d) - for _, want := range wants { - createManifest(t, d, want) + for _, p := range wants.ps { + createManifest(t, d, p) } ms, err := Manifests() @@ -81,14 +122,29 @@ func TestManifests(t *testing.T) { ns = append(ns, k) } - for _, want := range wants { - n := model.ParseNameFromFilepath(want) + var gotValidCount, gotInvalidCount int + for _, p := range wants.ps { + n := model.ParseNameFromFilepath(p) + if n.IsValid() { + gotValidCount++ + } else { + gotInvalidCount++ + } + if !n.IsValid() && slices.Contains(ns, n) { - t.Errorf("unexpected invalid name: %s", want) + t.Errorf("unexpected invalid name: %s", p) } else if n.IsValid() && !slices.Contains(ns, n) { - t.Errorf("missing valid name: %s", want) + t.Errorf("missing valid name: %s", p) } } + + if gotValidCount != wants.wantValidCount { + t.Errorf("got valid count %d, want %d", gotValidCount, wants.wantValidCount) + } + + if gotInvalidCount != wants.wantInvalidCount { + t.Errorf("got invalid count %d, want %d", gotInvalidCount, wants.wantInvalidCount) + } }) } } diff --git a/server/routes_create_test.go b/server/routes_create_test.go new file mode 100644 index 0000000000..e5af1ded2f --- /dev/null +++ b/server/routes_create_test.go @@ -0,0 +1,160 @@ +package server + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "slices" + "testing" + + "github.com/gin-gonic/gin" + "github.com/ollama/ollama/api" +) + +var stream bool = false + +func createBinFile(t *testing.T) string { + t.Helper() + + f, err := os.CreateTemp(t.TempDir(), "") + if err != nil { + t.Fatal(err) + } + defer f.Close() + + if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil { + t.Fatal(err) + } + + if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil { + t.Fatal(err) + } + + if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil { + t.Fatal(err) + } + + if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil { + t.Fatal(err) + } + + return f.Name() +} + +type responseRecorder struct { + *httptest.ResponseRecorder + http.CloseNotifier +} + +func NewRecorder() *responseRecorder { + return &responseRecorder{ + ResponseRecorder: httptest.NewRecorder(), + } +} + +func (t *responseRecorder) CloseNotify() <-chan bool { + return make(chan bool) +} + +func createRequest(t *testing.T, fn func(*gin.Context), body any) *httptest.ResponseRecorder { + t.Helper() + + w := NewRecorder() + c, _ := gin.CreateTestContext(w) + + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(body); err != nil { + t.Fatal(err) + } + + c.Request = &http.Request{ + Body: io.NopCloser(&b), + } + + fn(c) + return w.ResponseRecorder +} + +func checkFileExists(t *testing.T, p string, expect []string) { + t.Helper() + + actual, err := filepath.Glob(p) + if err != nil { + t.Fatal(err) + } + + if !slices.Equal(actual, expect) { + t.Fatalf("expected slices to be equal %v", actual) + } +} + +func TestCreateFromBin(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + + var s Server + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"), + }) +} + +func TestCreateFromModel(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test2", + Modelfile: "FROM test", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"), + }) +} diff --git a/server/routes_delete_test.go b/server/routes_delete_test.go new file mode 100644 index 0000000000..ea098d0577 --- /dev/null +++ b/server/routes_delete_test.go @@ -0,0 +1,71 @@ +package server + +import ( + "fmt" + "net/http" + "path/filepath" + "testing" + + "github.com/ollama/ollama/api" +) + +func TestDelete(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)), + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test2", + Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t)), + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"), + filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"), + }) + + w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test"}) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"), + }) + + w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test2"}) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{}) + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{}) +} diff --git a/server/routes_list_test.go b/server/routes_list_test.go new file mode 100644 index 0000000000..e92b4eab4e --- /dev/null +++ b/server/routes_list_test.go @@ -0,0 +1,61 @@ +package server + +import ( + "encoding/json" + "fmt" + "net/http" + "slices" + "testing" + + "github.com/ollama/ollama/api" +) + +func TestList(t *testing.T) { + t.Setenv("OLLAMA_MODELS", t.TempDir()) + + expectNames := []string{ + "mistral:7b-instruct-q4_0", + "zephyr:7b-beta-q5_K_M", + "apple/OpenELM:latest", + "boreas:2b-code-v1.5-q6_K", + "notus:7b-v1-IQ2_S", + // TODO: host:port currently fails on windows (#4107) + // "localhost:5000/library/eurus:700b-v0.5-iq3_XXS", + "mynamespace/apeliotes:latest", + "myhost/mynamespace/lips:code", + } + + var s Server + for _, n := range expectNames { + createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: n, + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)), + }) + } + + w := createRequest(t, s.ListModelsHandler, nil) + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + var resp api.ListResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatal(err) + } + + if len(resp.Models) != len(expectNames) { + t.Fatalf("expected %d models, actual %d", len(expectNames), len(resp.Models)) + } + + actualNames := make([]string, len(resp.Models)) + for i, m := range resp.Models { + actualNames[i] = m.Name + } + + slices.Sort(actualNames) + slices.Sort(expectNames) + + if !slices.Equal(actualNames, expectNames) { + t.Fatalf("expected slices to be equal %v", actualNames) + } +}