Skip to content

Commit

Permalink
Gallery repository (#663)
Browse files Browse the repository at this point in the history
Signed-off-by: mudler <mudler@localai.io>
  • Loading branch information
mudler committed Jun 24, 2023
1 parent 2a45a99 commit 60db595
Show file tree
Hide file tree
Showing 15 changed files with 640 additions and 190 deletions.
5 changes: 4 additions & 1 deletion api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ func App(opts ...AppOption) (*fiber.App, error) {
// LocalAI API endpoints
applier := newGalleryApplier(options.loader.ModelPath)
applier.start(options.context, cm)
app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C))

app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries))
app.Get("/models/list", listModelFromGallery(options.galleries, options.loader.ModelPath))
app.Get("/models/jobs/:uuid", getOpStatus(applier))

// openAI compatible API endpoint
Expand All @@ -120,6 +122,7 @@ func App(opts ...AppOption) (*fiber.App, error) {
// completion
app.Post("/v1/completions", completionEndpoint(cm, options))
app.Post("/completions", completionEndpoint(cm, options))
app.Post("/v1/engines/:model/completions", completionEndpoint(cm, options))

// embeddings
app.Post("/v1/embeddings", embeddingsEndpoint(cm, options))
Expand Down
106 changes: 105 additions & 1 deletion api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"runtime"

. "github.com/go-skynet/LocalAI/api"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
. "github.com/onsi/ginkgo/v2"
Expand All @@ -24,6 +25,7 @@ import (
)

type modelApplyRequest struct {
ID string `json:"id"`
URL string `json:"url"`
Name string `json:"name"`
Overrides map[string]string `json:"overrides"`
Expand Down Expand Up @@ -52,6 +54,35 @@ func getModelStatus(url string) (response map[string]interface{}) {
}
return
}

func getModels(url string) (response []gallery.GalleryModel) {

//url := "http://localhost:AI/models/apply"

// Create the request payload

// Create the HTTP request
resp, err := http.Get(url)
if err != nil {
return nil
}
defer resp.Body.Close()

body, err := ioutil.ReadAll(resp.Body)
if err != nil {
fmt.Println("Error reading response body:", err)
return
}

// Unmarshal the response into a map[string]interface{}
err = json.Unmarshal(body, &response)
if err != nil {
fmt.Println("Error unmarshaling JSON response:", err)
return
}
return
}

func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {

//url := "http://localhost:AI/models/apply"
Expand Down Expand Up @@ -118,7 +149,33 @@ var _ = Describe("API test", func() {
modelLoader = model.NewModelLoader(tmpdir)
c, cancel = context.WithCancel(context.Background())

app, err = App(WithContext(c), WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir))
g := []gallery.GalleryModel{
{
Name: "bert",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
},
{
Name: "bert2",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
Overrides: map[string]interface{}{"foo": "bar"},
AdditionalFiles: []gallery.File{gallery.File{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}},
},
}
out, err := yaml.Marshal(g)
Expect(err).ToNot(HaveOccurred())
err = ioutil.WriteFile(filepath.Join(tmpdir, "gallery_simple.yaml"), out, 0644)
Expect(err).ToNot(HaveOccurred())

galleries := []gallery.Gallery{
{
Name: "test",
URL: "file://" + filepath.Join(tmpdir, "gallery_simple.yaml"),
},
}

app, err = App(WithContext(c),
WithGalleries(galleries),
WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir))
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")

Expand All @@ -143,6 +200,53 @@ var _ = Describe("API test", func() {
})

Context("Applying models", func() {
It("applies models from a gallery", func() {

models := getModels("http://127.0.0.1:9090/models/list")
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models))

response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
ID: "test@bert2",
})

Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))

uuid := response["uuid"].(string)
resp := map[string]interface{}{}
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
resp = response
return response["processed"].(bool)
}, "360s").Should(Equal(true))
Expect(resp["message"]).ToNot(ContainSubstring("error"))

dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml"))
Expect(err).ToNot(HaveOccurred())

_, err = os.ReadFile(filepath.Join(tmpdir, "foo.yaml"))
Expect(err).ToNot(HaveOccurred())

content := map[string]interface{}{}
err = yaml.Unmarshal(dat, &content)
Expect(err).ToNot(HaveOccurred())
Expect(content["backend"]).To(Equal("bert-embeddings"))
Expect(content["foo"]).To(Equal("bar"))

models = getModels("http://127.0.0.1:9090/models/list")
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2")))
Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2")))
for _, m := range models {
if m.Name == "bert2" {
Expect(m.Installed).To(BeTrue())
} else {
Expect(m.Installed).To(BeFalse())
}
}
})
It("overrides models", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
Expand Down
36 changes: 24 additions & 12 deletions api/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func ReadConfig(file string) (*Config, error) {
return c, nil
}

func (cm ConfigMerger) LoadConfigFile(file string) error {
func (cm *ConfigMerger) LoadConfigFile(file string) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadConfigFile(file)
Expand All @@ -111,7 +111,7 @@ func (cm ConfigMerger) LoadConfigFile(file string) error {
return nil
}

func (cm ConfigMerger) LoadConfig(file string) error {
func (cm *ConfigMerger) LoadConfig(file string) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadConfig(file)
Expand All @@ -123,14 +123,14 @@ func (cm ConfigMerger) LoadConfig(file string) error {
return nil
}

func (cm ConfigMerger) GetConfig(m string) (Config, bool) {
func (cm *ConfigMerger) GetConfig(m string) (Config, bool) {
cm.Lock()
defer cm.Unlock()
v, exists := cm.configs[m]
return v, exists
}

func (cm ConfigMerger) ListConfigs() []string {
func (cm *ConfigMerger) ListConfigs() []string {
cm.Lock()
defer cm.Unlock()
var res []string
Expand All @@ -140,7 +140,7 @@ func (cm ConfigMerger) ListConfigs() []string {
return res
}

func (cm ConfigMerger) LoadConfigs(path string) error {
func (cm *ConfigMerger) LoadConfigs(path string) error {
cm.Lock()
defer cm.Unlock()
entries, err := os.ReadDir(path)
Expand Down Expand Up @@ -316,20 +316,32 @@ func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (strin
func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
// Load a config file if present after the model name
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
if _, err := os.Stat(modelConfig); err == nil {
if err := cm.LoadConfig(modelConfig); err != nil {
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
}

var config *Config
cfg, exists := cm.GetConfig(modelFile)
if !exists {

defaults := func() {
config = defaultConfig(modelFile)
config.ContextSize = ctx
config.Threads = threads
config.F16 = f16
config.Debug = debug
}

cfg, exists := cm.GetConfig(modelFile)
if !exists {
if _, err := os.Stat(modelConfig); err == nil {
if err := cm.LoadConfig(modelConfig); err != nil {
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfg, exists = cm.GetConfig(modelFile)
if exists {
config = &cfg
} else {
defaults()
}
} else {
defaults()
}
} else {
config = &cfg
}
Expand Down

0 comments on commit 60db595

Please sign in to comment.